In [None]:
import tensorflow_datasets as tfds
import tensorflow as tf

# 1. Load MNIST from tensorflow_datasets as (image, label) pairs.
#    Passing a list of splits makes tfds.load return one dataset per split,
#    in the same order: here [train, test].
mnist = tfds.load("mnist", split=["train", "test"], as_supervised=True)
train_ds, val_ds = mnist

# 2. write function to create the dataset that we want
def preprocess(data, batch_size, type):
    """Build a batched dataset of MNIST image pairs with a pair-level target.

    Each element of the result is ``(x1, x2, target)``. ``x1`` and ``x2`` are
    flattened float32 images rescaled to [-1, 1), drawn from two independently
    shuffled copies of ``data``.

    Args:
        data: ``tf.data.Dataset`` of ``(image, label)`` pairs, e.g. from
            ``tfds.load(..., as_supervised=True)``.
        batch_size: number of pairs per batch.
        type: which target to build:
            * ``"greater_equal"`` -- int32 ``1`` if ``y1 + y2 >= 5`` else ``0``;
            * ``"subtract"`` -- ``|y1 - y2|`` (same dtype as the labels).

    Returns:
        A batched, prefetched ``tf.data.Dataset`` of ``(x1, x2, target)``.

    Raises:
        ValueError: if ``type`` is not one of the supported task names.
    """
    # Validate first: an unknown type used to fall through both branches and
    # silently return an unbatched dataset of ((x1, y1), (x2, y2)) tuples.
    if type not in ("greater_equal", "subtract"):
        raise ValueError(
            f"unknown type {type!r}; expected 'greater_equal' or 'subtract'")

    # Cast to float32, flatten, and rescale in a single map.
    # x / 128 - 1 maps pixel 0 -> -1.0 and pixel 255 -> ~0.992.
    data = data.map(
        lambda x, t: (tf.reshape(tf.cast(x, tf.float32), (-1,)) / 128. - 1., t))

    # Pair examples from two independently shuffled copies of the data, so a
    # single element is ((x1, y1), (x2, y2)).
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), data.shuffle(2000)))

    if type == "greater_equal":
        # target: 1 if the two digits sum to at least 5, else 0
        zipped_ds = zipped_ds.map(
            lambda x1, x2: (x1[0], x2[0], tf.cast(x1[1] + x2[1] >= 5, tf.int32)))
    else:  # "subtract"
        # target: absolute difference of the two digits
        zipped_ds = zipped_ds.map(
            lambda x1, x2: (x1[0], x2[0], tf.abs(x1[1] - x2[1])))

    # batching and prefetching are identical for both tasks
    return zipped_ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)

# Batch size shared by all four pair pipelines below.
BATCH_SIZE = 32

greater_equal_train_ds = preprocess(train_ds, batch_size=BATCH_SIZE, type="greater_equal")
greater_equal_val_ds = preprocess(val_ds, batch_size=BATCH_SIZE, type="greater_equal")

subtract_train_ds = preprocess(train_ds, batch_size=BATCH_SIZE, type="subtract")
subtract_val_ds = preprocess(val_ds, batch_size=BATCH_SIZE, type="subtract")

# Sanity check on one batch: both images are flattened 28x28 vectors and
# there is one target per pair.
for img1, img2, label in greater_equal_train_ds.take(1):
    print(img1.shape, img2.shape, label.shape)
    assert img1.shape[-1] == img2.shape[-1] == 28 * 28
    assert label.shape[0] == img1.shape[0]

In [None]:
class TwinMNISTModel(tf.keras.Model):
    """Shared-weight ("twin") MLP that compares two flattened MNIST images.

    Both images pass through the same two Dense layers. The two embeddings
    are concatenated and fed to a task-specific single-unit output layer.

    Args:
        type: ``"greater_equal"`` for the binary task (sigmoid output, binary
            cross-entropy loss, binary accuracy) or ``"subtract"`` for the
            regression task (ReLU output, mean squared error loss, mean
            absolute error metric).

    Raises:
        ValueError: if ``type`` is not one of the supported task names.
    """

    # 1. constructor
    def __init__(self, type):
        # inherit functionality from parent class
        super().__init__()

        # type-dependent settings: loss function, output layer, task metric
        if type == "greater_equal":
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)
            task_metric = tf.keras.metrics.BinaryAccuracy()
        elif type == "subtract":
            self.loss_function = tf.keras.losses.MeanSquaredError()
            # the target |y1 - y2| is never negative, so ReLU fits the output
            self.out_layer = tf.keras.layers.Dense(1, activation=tf.nn.relu)
            # accuracy is meaningless for a regression target
            task_metric = tf.keras.metrics.MeanAbsoluteError()
        else:
            raise ValueError(
                f"unknown type {type!r}; expected 'greater_equal' or 'subtract'")

        # The loss tracker must stay first: _update_metrics relies on it.
        self.metrics_list = [tf.keras.metrics.Mean(name="loss"), task_metric]

        self.optimizer = tf.keras.optimizers.Adam()

        # the same (shared) layers process both images, for both types
        self.dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(32, activation=tf.nn.relu)

    # 2. call method (forward computation)
    def call(self, images, training=False):
        """Forward pass on a pair ``(img1, img2)`` of flattened image batches."""
        img1, img2 = images

        # apply the *same* layers to both images so their weights are shared
        img1_x = self.dense2(self.dense1(img1))
        img2_x = self.dense2(self.dense1(img2))

        combined_x = tf.concat([img1_x, img2_x], axis=-1)
        return self.out_layer(combined_x)

    # 3. metrics property
    @property
    def metrics(self):
        """List of all metric objects tracked by this model (loss first)."""
        return self.metrics_list

    # 4. reset all metrics objects
    def reset_metrics(self):
        """Reset the state of every tracked metric."""
        for metric in self.metrics:
            # reset_states() is a deprecated alias removed in newer Keras
            metric.reset_state()

    @staticmethod
    def _prepare_label(label, output):
        """Cast ``label`` to the prediction dtype and reshape it to ``output``'s shape.

        Keeps the loss and metrics from ever broadcasting a ``(batch,)``
        target against a ``(batch, 1)`` prediction.
        """
        return tf.reshape(tf.cast(label, output.dtype), tf.shape(output))

    def _update_metrics(self, label, output, loss):
        """Update the loss tracker and the task metrics, and return all results."""
        self.metrics_list[0].update_state(loss)
        for metric in self.metrics_list[1:]:
            metric.update_state(label, output)
        # dictionary with metric names as keys and metric results as values
        return {metric.name: metric.result() for metric in self.metrics}

    # 5. train step method
    @tf.function
    def train_step(self, data):
        """Run one optimization step on a batch ``(img1, img2, label)``.

        Returns:
            dict mapping metric name to its current (accumulated) result.
        """
        img1, img2, label = data

        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            label = self._prepare_label(label, output)
            loss = self.loss_function(label, output)

        gradients = tape.gradient(loss, self.trainable_variables)
        # the original computed gradients but never applied them
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

        return self._update_metrics(label, output, loss)

    # 6. test step method
    @tf.function
    def test_step(self, data):
        """Evaluate one batch ``(img1, img2, label)`` without updating weights.

        Returns:
            dict mapping metric name to its current (accumulated) result.
        """
        img1, img2, label = data
        output = self((img1, img2), training=False)
        label = self._prepare_label(label, output)
        loss = self.loss_function(label, output)
        return self._update_metrics(label, output, loss)

In [None]:
def training_loop(model, train_ds, val_ds, epochs, train_summary_writer,
                  val_summary_writer, save_path=None):
    """Train ``model`` for ``epochs`` epochs, logging metrics after each phase.

    ``model`` must expose Keras-style ``train_step(batch)``,
    ``test_step(batch)`` and ``reset_metrics()``, plus a ``metrics`` list of
    metric objects with ``name`` and ``result()``.

    Args:
        model: the model to train.
        train_ds: iterable of training batches passed to ``model.train_step``.
        val_ds: iterable of validation batches passed to ``model.test_step``.
        epochs: number of passes over ``train_ds``.
        train_summary_writer: ``tf.summary`` writer for training metrics.
        val_summary_writer: ``tf.summary`` writer for validation metrics.
        save_path: if not ``None``, ``model.save_weights(save_path)`` is
            called after the last epoch.

    Returns:
        list with one ``(train_results, val_results)`` pair of
        ``{metric_name: float}`` dicts per epoch.
    """
    history = []
    # start from clean metric state in case the model was used before
    model.reset_metrics()

    # 1. iterate over epochs
    for epoch in range(epochs):
        # 2. train steps on all batches in the training data
        for data in train_ds:
            model.train_step(data)

        # 3./4. log and print training metrics, then reset metric objects
        train_results = _log_epoch_metrics(model, train_summary_writer, epoch, "train")

        # 5. evaluate on validation data
        for data in val_ds:
            model.test_step(data)

        # 6./7. log validation metrics, then reset metric objects
        val_results = _log_epoch_metrics(model, val_summary_writer, epoch, "val")

        history.append((train_results, val_results))

    # 8. save model weights
    if save_path is not None:
        model.save_weights(save_path)

    return history


def _log_epoch_metrics(model, summary_writer, epoch, phase):
    """Write, print, and reset ``model``'s metrics for one epoch phase.

    The values are read from the metric objects rather than from the step
    return values, so an empty dataset cannot leave a result undefined.

    Returns:
        dict mapping metric name to its value as a Python float.
    """
    results = {metric.name: float(metric.result()) for metric in model.metrics}
    with summary_writer.as_default():
        for name, value in results.items():
            tf.summary.scalar(name, value, step=epoch)
    summary = ", ".join(f"{name}: {value:.4f}" for name, value in results.items())
    print(f"Epoch {epoch} [{phase}] {summary}")
    # the next phase must start from fresh metric state
    model.reset_metrics()
    return results