In [None]:
import tensorflow as tf
from tensorflow import keras
from keras.models import Sequential, Input, Model
from keras.layers import Dense, Input, Lambda, Layer, Dropout, Flatten, GlobalAveragePooling2D, GlobalMaxPooling2D
from keras import backend as K
from keras.callbacks import ModelCheckpoint

from tensorflow.keras.layers.experimental.preprocessing import RandomFlip, RandomRotation, RandomTranslation, RandomZoom
from tensorflow.keras.applications import Xception
import tensorflow_addons as tfa 

import numpy as np
import pandas as pd
import os

In [None]:
'''Triplet loss functions adapted from omoindrot's GitHub repository (https://github.com/omoindrot/tensorflow-triplet-loss).
Updated to work with TensorFlow 2, with an added wrapper function for Keras compatibility.
'''

def _pairwise_distances(embeddings, squared=False):
    """Compute the 2D matrix of distances between all the embeddings.

    Args:
        embeddings: float tensor of shape (batch_size, embed_dim). The result
            has the same dtype as `embeddings`. Note that the 1e-16 epsilon used
            below underflows to 0 in float16, so cast half-precision embeddings
            to float32 first when `squared=False` to keep gradients finite.
        squared: Boolean. If true, output is the pairwise squared euclidean distance matrix.
                 If false, output is the pairwise euclidean distance matrix.
    Returns:
        pairwise_distances: tensor of shape (batch_size, batch_size)
    """
    # Gram matrix of all embeddings: dot_product[i, j] = <e_i, e_j>
    # shape (batch_size, batch_size)
    dot_product = tf.matmul(embeddings, tf.transpose(embeddings))

    # The squared L2 norm of each embedding is the diagonal of the Gram matrix.
    # Reusing it (instead of recomputing norms) makes the diagonal of the result exactly 0.
    # shape (batch_size,)
    square_norm = tf.linalg.diag_part(dot_product)

    # ||a - b||^2 = ||a||^2 - 2 <a, b> + ||b||^2, evaluated for all pairs via broadcasting
    # shape (batch_size, batch_size)
    distances = tf.expand_dims(square_norm, 1) - 2.0 * dot_product + tf.expand_dims(square_norm, 0)

    # Floating-point cancellation can produce tiny negative values; clamp everything to >= 0.0
    distances = tf.maximum(distances, 0.0)

    if not squared:
        # The gradient of sqrt is infinite at 0 (e.g. on the diagonal), so add a
        # small epsilon where distances == 0.0 before taking the root.
        # The mask uses the dtype of `distances` (not a hard-coded float32) so
        # non-float32 embeddings do not trigger a dtype mismatch.
        mask = tf.cast(tf.equal(distances, 0.0), dtype=distances.dtype)
        distances = distances + mask * 1e-16

        distances = tf.sqrt(distances)

        # Undo the epsilon: masked entries become exactly 0.0 again
        distances = distances * (1.0 - mask)

    return distances


def _get_anchor_positive_triplet_mask(labels):
    """Build a boolean mask of valid anchor-positive pairs.

    Entry (a, p) is True exactly when a and p are different samples
    (a != p) and they share a label (labels[a] == labels[p]).

    Args:
        labels: `Tensor` of class labels with shape [batch_size]
    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    batch = tf.shape(labels)[0]
    # Off-diagonal entries only: a sample can never be its own positive.
    distinct = tf.logical_not(tf.cast(tf.eye(batch), tf.bool))
    # same_label[i, j] = (labels[i] == labels[j]), via (batch, 1) vs (1, batch) broadcasting.
    same_label = tf.equal(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0))
    return tf.logical_and(distinct, same_label)


def _get_anchor_negative_triplet_mask(labels):
    """Build a boolean mask of valid anchor-negative pairs.

    Entry (a, n) is True exactly when labels[a] != labels[n]. The diagonal is
    always False because a sample always shares its own label.

    Args:
        labels: `Tensor` of class labels with shape [batch_size]
    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size]
    """
    # Compare every label against every other label: (batch, 1) vs (1, batch).
    return tf.not_equal(tf.expand_dims(labels, 1), tf.expand_dims(labels, 0))


def _get_triplet_mask(labels):
    """Return a 3D boolean mask where mask[a, p, n] is True iff (a, p, n) is a valid triplet.

    A triplet (i, j, k) is valid when:
        - i, j and k are pairwise distinct, and
        - labels[i] == labels[j] while labels[i] != labels[k].

    Batch-hard mining (`batch_hard_triplet_loss`) does not use this mask; it
    relies on the 2D anchor-positive / anchor-negative masks instead.

    Args:
        labels: `Tensor` of class labels with shape [batch_size]
    Returns:
        mask: tf.bool `Tensor` with shape [batch_size, batch_size, batch_size]
    """
    not_same_index = tf.logical_not(tf.cast(tf.eye(tf.shape(labels)[0]), tf.bool))
    # Lift the 2D "index differs" matrix onto each pair of axes of the 3D cube.
    distinct_indices = tf.logical_and(
        tf.logical_and(tf.expand_dims(not_same_index, 2),   # i != j
                       tf.expand_dims(not_same_index, 1)),  # i != k
        tf.expand_dims(not_same_index, 0))                  # j != k

    same_label = tf.equal(tf.expand_dims(labels, 0), tf.expand_dims(labels, 1))
    anchor_matches_positive = tf.expand_dims(same_label, 2)                      # labels[i] == labels[j]
    anchor_differs_from_negative = tf.logical_not(tf.expand_dims(same_label, 1))  # labels[i] != labels[k]
    valid_labels = tf.logical_and(anchor_matches_positive, anchor_differs_from_negative)

    return tf.logical_and(distinct_indices, valid_labels)


def batch_hard_triplet_loss(labels, embeddings, margin, squared=False):
    """Build the triplet loss over a batch of embeddings.

    For each anchor, we get the hardest positive and hardest negative to form a triplet.
    Anchors that have no positive in the batch get a hardest-positive distance of 0,
    so they still contribute max(margin - d(a, hardest_negative), 0) to the mean.

    Args:
        labels: labels of the batch, of size (batch_size,)
        embeddings: tensor of shape (batch_size, embed_dim); the masks below follow
            the dtype of the distance matrix, so float32 is not required.
        margin: margin for triplet loss
        squared: Boolean. If true, use the pairwise squared euclidean distance matrix.
                 If false, use the pairwise euclidean distance matrix.
    Returns:
        triplet_loss: scalar tensor containing the triplet loss
    """
    # Get the pairwise distance matrix, shape (batch_size, batch_size)
    pairwise_dist = _pairwise_distances(embeddings, squared=squared)
    # Cast masks to the distance dtype instead of a hard-coded float32, so
    # multiplying/adding them to `pairwise_dist` never mixes dtypes.
    dist_dtype = pairwise_dist.dtype

    # For each anchor, get the hardest positive.
    # Valid positives: a != p and label(a) == label(p).
    mask_anchor_positive = tf.cast(_get_anchor_positive_triplet_mask(labels), dtype=dist_dtype)

    # Zero out every invalid (a, p) entry; distances are >= 0, so zeros never win the max.
    anchor_positive_dist = tf.multiply(mask_anchor_positive, pairwise_dist)

    # shape (batch_size, 1)
    hardest_positive_dist = tf.reduce_max(anchor_positive_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_positive_dist", tf.reduce_mean(hardest_positive_dist))

    # For each anchor, get the hardest negative.
    # Valid negatives: label(a) != label(n).
    mask_anchor_negative = tf.cast(_get_anchor_negative_triplet_mask(labels), dtype=dist_dtype)

    # Push invalid negatives (label(a) == label(n)) out of the running for the min
    # by adding the row's maximum distance to them.
    max_anchor_negative_dist = tf.reduce_max(pairwise_dist, axis=1, keepdims=True)
    anchor_negative_dist = pairwise_dist + max_anchor_negative_dist * (1.0 - mask_anchor_negative)

    # shape (batch_size, 1) because of keepdims=True
    hardest_negative_dist = tf.reduce_min(anchor_negative_dist, axis=1, keepdims=True)
    tf.summary.scalar("hardest_negative_dist", tf.reduce_mean(hardest_negative_dist))

    # Combine biggest d(a, p) and smallest d(a, n) into the per-anchor hinge loss
    triplet_loss = tf.maximum(hardest_positive_dist - hardest_negative_dist + margin, 0.0)

    # Final loss is the mean over all anchors in the batch
    triplet_loss = tf.reduce_mean(triplet_loss)

    return triplet_loss


def keras_batch_hard_triplet_loss(labels, y_pred):
    """Adapt `batch_hard_triplet_loss` to Keras' `loss(y_true, y_pred)` signature.

    The labels are flattened to 1-D first, since Keras may pass `y_true` with
    extra axes. Distances are non-squared (the default of `batch_hard_triplet_loss`).

    NOTE: `margin` is not a parameter here. It is looked up in the notebook's
    global scope each time the loss is evaluated, so the hyper-parameter cell
    that defines `margin` must have run before training.
    """
    flat_labels = K.flatten(labels)
    return batch_hard_triplet_loss(labels=flat_labels, embeddings=y_pred, margin=margin)

In [None]:
# Load the images (X) and their identity labels (Y) from Google Drive.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code; only use it with trusted files.
X = np.load('/content/drive/My Drive/Data/FaceDataset/celebA_mtcnn_X.npy', allow_pickle=True)
Y = np.load('/content/drive/My Drive/Data/FaceDataset/celebA_mtcnn_Y.npy', allow_pickle=True)
# Coerce every stored label to a Python int (the on-disk dtype is not shown here).
Y = np.array(list(map(int, Y)))
assert X.shape[0] == Y.shape[0]

In [None]:
# Shuffle images and labels with the same reproducible permutation before splitting.
np.random.seed(42)
idx = np.random.permutation(len(Y))
X = X[idx]
Y = Y[idx]

In [None]:
# Split the (already shuffled) data into a 95% train / 5% test partition.
train_split = int(X.shape[0] * 0.95)
X_train, Y_train = X[:train_split], Y[:train_split]
X_test, Y_test = X[train_split:], Y[train_split:]
print(X_train.shape, X_test.shape, Y_train.shape, Y_test.shape)
# NOTE(review): basic slicing returns NumPy views, so the train/test slices still
# reference the full arrays' buffers; these `del`s only drop the names and do
# not by themselves release that memory.
del X, Y

(188924, 96, 96, 3) (9944, 96, 96, 3) (188924,) (9944,)


In [None]:
# Wrap the training arrays in a tf.data pipeline of (image, label) pairs.
# Validation data stays as NumPy arrays (X_test, Y_test).
train_data = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
del X_train, Y_train

In [None]:
# Apply image augmentation to hopefully compensate for relatively small dataset

def rand_brightness(x, p=0.85):
    """With probability `p`, apply `tf.image.random_brightness` with max_delta=0.5.

    A single scalar draw decides for the whole input tensor. When `x` is a
    batch, every image in it is either adjusted or left alone together.
    """
    return tf.cond(tf.random.uniform([]) < p,
                   lambda: tf.image.random_brightness(x, 0.5),
                   lambda: x)


def saturate(x, p=0.85):
    """With probability `p`, scale the saturation of `x` by a random factor.

    NOTE(review): this calls `random_saturation(x, 1, 8)`, i.e. lower=1 and
    upper=8, so saturation is only ever increased, by up to 8x. The original
    spelling `1,8` may have been a typo; confirm the intended range.
    """
    return tf.cond(tf.random.uniform([]) < p,
                   lambda: tf.image.random_saturation(x, 1, 8),
                   lambda: x)

def rand_contrast(x, p=0.85):
    """With probability `p`, rescale the contrast of `x` by a factor in [0.1, 0.8).

    Because the factor range lies below 1, this step only ever reduces contrast.
    """
    return tf.cond(tf.random.uniform([]) < p,
                   lambda: tf.image.random_contrast(x, 0.1, 0.8),
                   lambda: x)


def hue(x, p=0.85):
    """With probability `p`, shift the hue of `x` by a random delta with max_delta=0.1."""
    return tf.cond(tf.random.uniform([]) < p,
                   lambda: tf.image.random_hue(x, 0.1),
                   lambda: x)

class Augment(Layer):
    """Keras layer that applies the probabilistic colour augmentations in sequence.

    Order: brightness -> saturation -> contrast -> hue. Each step is gated by
    its own random draw (default probability 0.85 per step).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, x):
        for transform in (rand_brightness, saturate, rand_contrast, hue):
            x = transform(x)
        return x

# Augmentation pipeline: colour jitter, then geometric transforms
# (flip, rotation of up to 0.2 * 2*pi rad, translation, zoom),
# then two further rounds of colour jitter.
data_augmentation = tf.keras.Sequential(
    [
        Augment(),
        RandomFlip("horizontal"),
        RandomRotation(0.2),
        RandomTranslation(height_factor=(-0.3, 0.3), width_factor=(-0.3, 0.3)),
        RandomZoom(0.3, 0.3),
        Augment(),
        Augment(),
    ]
)



In [None]:
def prepare(ds, shuffle=False, augment=True, buffer_size=None):
    """Shuffle, batch, optionally augment, and prefetch a `tf.data.Dataset`.

    Reads the globals `batch_size` and `AUTOTUNE`, which are defined in the
    hyper-parameter cell and must exist when this function is *called*.

    Args:
        ds: unbatched dataset of `(image, label)` pairs.
        shuffle: if True, shuffle elements before batching.
        augment: if True, apply `data_augmentation` to each batch of images;
            labels pass through unchanged.
        buffer_size: shuffle buffer size. Defaults to `batch_size` (the
            original behaviour), which only mixes elements that are already
            close together in `ds`. Pass a larger value for a more thorough
            per-epoch shuffle.
    Returns:
        The batched, prefetched dataset.
    """
    if shuffle:
        ds = ds.shuffle(batch_size if buffer_size is None else buffer_size)
    ds = ds.batch(batch_size)
    if augment:
        # Augmentation is mapped over whole batches, so any scalar random draw
        # inside the augmentation layers applies to the entire batch at once.
        ds = ds.map(lambda x, y: (data_augmentation(x), y),
                    num_parallel_calls=AUTOTUNE)
    return ds.prefetch(buffer_size=AUTOTUNE)

In [None]:
# Hyper-parameters for the triplet-loss model
input_shape = (96, 96, 3)   # height, width, channels of the input images
embedding_size = 128        # size of the final L2-normalised embedding
batch_size = 256
epochs = 10
learning_rate = 1e-4
# NOTE(review): embeddings are L2-normalised, so Euclidean distances lie in [0, 2].
# With margin = 2, the hinge max(d_ap - d_an + margin, 0) is zero only when
# d_ap == 0 and d_an == 2, so the loss can almost never reach 0. Confirm this is intended.
margin = 2
AUTOTUNE = tf.data.experimental.AUTOTUNE

In [None]:
# Build the training input pipeline: shuffled, batched, augmented and prefetched.
train_data = prepare(train_data, augment=True, shuffle=True)


In [None]:
# Adam optimizer; stochastic gradient descent (SGD) would be a reasonable alternative to explore.
optimizer = tf.keras.optimizers.Adam(learning_rate)

In [None]:
# Embedding network: frozen ImageNet Xception backbone plus a small trainable head.
xception = Xception(weights="imagenet", input_shape=input_shape, include_top=False)
xception.trainable = False

inputs = Input(shape=input_shape)
# Xception preprocessing: map pixels to [-1, 1]
# (assumes inputs arrive in [0, 255]; confirm against the dataset's dtype/range).
layer = Lambda(lambda x: (x/127.5)-1)(inputs)
# training=False keeps the backbone's BatchNormalization layers in inference mode,
# even if `xception.trainable` is later set to True for fine-tuning.
layer = xception(layer, training=False)
layer = GlobalMaxPooling2D()(layer)
layer = Dropout(0.8)(layer)
layer = Dense(embedding_size*4, activation='relu')(layer)
layer = Dropout(0.5)(layer)
layer = Dense(embedding_size)(layer)
# Project embeddings onto the unit hypersphere (row-wise L2 normalisation).
layer = Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(layer)
model = Model(inputs, layer)
model.summary()

Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_2 (InputLayer)         [(None, 96, 96, 3)]       0         
_________________________________________________________________
lambda (Lambda)              (None, 96, 96, 3)         0         
_________________________________________________________________
xception (Functional)        (None, 3, 3, 2048)        20861480  
_________________________________________________________________
global_max_pooling2d (Global (None, 2048)              0         
_________________________________________________________________
dropout (Dropout)            (None, 2048)              0         
_________________________________________________________________
dense (Dense)                (None, 512)               1049088   
_________________________________________________________________
dropout_1 (Dropout)          (None, 512)              

In [None]:
# Resume from weights saved by a previous run (presumably val_loss 1.4245, per the filename pattern).
weights = r'/content/drive/My Drive/Data/FaceDataset/global_max_pooling_FC_4_1_xception.1.4245.hdf5'
model.load_weights(weights)

# Save weights after every epoch (save_best_only=False), tagging each file with its val_loss.
checkpoint_path = "/content/drive/My Drive/Data/FaceDataset/global_max_pooling_FC_4_1_xception.{val_loss:.4f}.hdf5"
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor='val_loss',
    mode='auto',
    save_best_only=False,
    save_weights_only=True,
    verbose=1,
)


In [None]:
model.compile(loss=keras_batch_hard_triplet_loss,
              optimizer=optimizer)

# Batch-hard triplet mining depends on which samples share a batch. Without an
# explicit validation_batch_size, Keras evaluates the NumPy validation arrays
# in batches of 32, so val_loss (the metric the checkpoint tracks) would not be
# comparable to the training loss computed on batches of `batch_size`.
history = model.fit(train_data,
                    epochs=epochs,
                    validation_data=(X_test, Y_test),
                    validation_batch_size=batch_size,
                    callbacks=[checkpoint])

Epoch 1/10

Epoch 00001: saving model to /content/drive/My Drive/Data/FaceDataset/global_max_pooling_FC_4_1_xception.1.4292.hdf5
