# In [None]:
import tensorflow as tf
import tensorflow_addons as tfa
import tensorflow_datasets as tfds
from tensorflow.keras import models, layers

# Number of examples per batch in both the training and test pipelines.
BATCH_SIZE: int = 32
# Size of the embedding vector produced by the model's final Dense layer.
# (Spelling "DEM" is kept as-is because the model definition refers to it.)
LATENT_DEM: int = 128

def _normalize_img(img, label):
    """Cast an image to float32 and divide it by 255, passing the label through.

    Used as the ``Dataset.map`` function for ``(image, label)`` pairs (or
    batches of them). Presumably the pixels are 8-bit (0..255), in which case
    the result lies in [0, 1] -- confirm against the dataset.

    Returns:
        A ``(scaled_image, label)`` tuple.
    """
    scaled = tf.cast(img, tf.float32)
    scaled = scaled / 255.0
    return scaled, label

# MNIST as (image, label) pairs because of as_supervised=True.
train_dataset, test_dataset = tfds.load(name="mnist", split=['train', 'test'], as_supervised=True)

# Build your input pipelines.
# Batching before map() lets _normalize_img operate on whole batches at once.
# num_parallel_calls runs the map on several threads (element order is still
# preserved by default), and prefetch() prepares the next batch while the
# model consumes the current one, so the accelerator is not left waiting on
# input. AUTOTUNE lets tf.data pick the degree of parallelism / buffer size.
train_dataset = (
    train_dataset.shuffle(1024)
    .batch(BATCH_SIZE)
    .map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# No shuffling for evaluation: keep the test set in its original order.
test_dataset = (
    test_dataset.batch(BATCH_SIZE)
    .map(_normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

# Embedding network (a single shared encoder):
# randomly initialised ResNet50 trunk without its classifier head, with
# global average pooling -> linear projection to LATENT_DEM units ->
# per-example L2 normalisation, so each output row has unit length.
inputs = layers.Input(shape=(28, 28, 1))
# NOTE(review): Keras documents a 32x32 minimum spatial size for ResNet50;
# passing input_tensor bypasses that check, and 28x28 inputs shrink to a very
# small (about 1x1) feature map by the last stage. Consider padding/resizing.
resNet50 = tf.keras.applications.ResNet50(
    input_tensor=inputs,
    include_top=False,
    weights=None,
    pooling='avg',
)
outputs = layers.Dense(
    LATENT_DEM,
    activation=None,  # linear projection; no squashing before normalisation
)(resNet50.output)
outputs = layers.Lambda(
    lambda x: tf.math.l2_normalize(x, axis=1),  # unit-norm embedding per row
)(outputs)

siamese_model = models.Model(inputs=inputs, outputs=outputs)

# Compile the model: Adam with learning rate 1e-3, minimising the TensorFlow
# Addons semi-hard triplet loss. That loss takes the integer class labels as
# y_true and the model's embeddings as y_pred, forming triplets within each
# batch.
siamese_model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss=tfa.losses.TripletSemiHardLoss(),
)

# Train the network for 3 epochs; the returned History object holds the
# per-epoch loss values.
history = siamese_model.fit(train_dataset, epochs=3)