In [38]:
import tensorflow_datasets as tfds
import tensorflow as tf
import datetime

# 1. get mnist from tensorflow_datasets
mnist = tfds.load("mnist", split =["train","test"], as_supervised=True)
train_ds = mnist[0]
val_ds = mnist[1]

# 2. write function to create the dataset that we want
def preprocess(data, batch_size, type):
    # image should be float
    data = data.map(lambda x, t: (tf.cast(x, float), t))
    # image should be flattened
    data = data.map(lambda x, t: (tf.reshape(x, (-1,)), t))
    # image vector will here have values between -1 and 1
    data = data.map(lambda x,t: ((x/128.)-1., t))
    # we want to have two mnist images in each example
    # this leads to a single example being ((x1,y1),(x2,y2))
    zipped_ds = tf.data.Dataset.zip((data.shuffle(2000), data.shuffle(2000)))

    if type == 'greater_equal':
        # map ((x1,y1),(x2,y2)) to (x1,x2, y1==y2*) *boolean
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1] + x2[1] >= 5))
        # transform boolean target to int
        zipped_ds = zipped_ds.map(lambda x1, x2, t: (x1,x2, tf.cast(t, tf.int32)))
        # batch the dataset
        zipped_ds = zipped_ds.batch(batch_size)
        # prefetch
        zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    elif type == 'subtract':
        # map ((x1,y1),(x2,y2)) to (x1,x2, y1 - y2)
        #zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], abs(x1[1] - x2[1])))
        zipped_ds = zipped_ds.map(lambda x1, x2: (x1[0], x2[0], x1[1] - x2[1]))
        # batch the dataset
        zipped_ds = zipped_ds.batch(batch_size)
        # prefetch
        zipped_ds = zipped_ds.prefetch(tf.data.AUTOTUNE)
    return zipped_ds

greater_equal_train_ds = preprocess(train_ds, batch_size=32, type='greater_equal')
greater_equal_val_ds = preprocess(val_ds, batch_size=32, type="greater_equal")

subtract_train_ds = preprocess(train_ds, batch_size=32, type='subtract') #train_ds.apply(preprocess)
subtract_val_ds = preprocess(val_ds, batch_size=32, type="subtract") 

for img1, img2, label in greater_equal_train_ds.take(1):
    print(img1.shape, img2.shape, label.shape)

(32, 784) (32, 784) (32,)


2022-11-28 15:24:15.970500: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
2022-11-28 15:24:15.972477: W tensorflow/core/kernels/data/cache_dataset_ops.cc:856] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.


In [44]:
class TwinMNISTModel(tf.keras.Model):

    # 1. constructor
    def __init__(self, type):
        super().__init__()
        # inherit functionality from parent class

        # optimizer, loss function and metrics
        self.metrics_list = [tf.keras.metrics.BinaryAccuracy(),
                             tf.keras.metrics.Mean(name="loss")]
                
        #type-dependent settings
        if type == "greater_equal":
            self.loss_function = tf.keras.losses.BinaryCrossentropy()
            self.out_layer = tf.keras.layers.Dense(1,activation=tf.nn.sigmoid)
        elif type == "subtract":
            self.loss_function = tf.keras.losses.MeanSquaredError()
            self.out_layer = tf.keras.layers.Dense(1)

        #same layers for both types
        self.dense1 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(32, activation=tf.nn.relu)
        
        
        
    # 2. call method (forward computation)
    def call(self, images, training=False):
        img1, img2 = images
        
        img1_x = self.dense1(img1)
        img1_x = self.dense2(img1_x)
        
        img2_x = self.dense1(img2)
        img2_x = self.dense2(img2_x)
        
        combined_x = tf.concat([img1_x, img2_x], axis=1)
        
        return self.out_layer(combined_x)



    # 3. metrics property
    @property
    def metrics(self):
        return self.metrics_list
        # return a list with all metrics in the model



    # 4. reset all metrics objects
    def reset_metrics(self):
        for metric in self.metrics:
            metric.reset_states()



    # 5. train step method
    @tf.function
    def train_step(self, data):
        img1, img2, label = data
        
        with tf.GradientTape() as tape:
            output = self((img1, img2), training=True)
            loss = self.loss_function(label, output)
            
        gradients = tape.gradient(loss, self.trainable_variables)
        
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
        
        # update the state of the metrics according to loss
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        # return a dictionary with metric names as keys and metric results as values
        return {m.name : m.result() for m in self.metrics}
    
    # 6. test_step method
    @tf.function
    def test_step(self, data):
        img1, img2, label = data
        # same as train step (without parameter updates)
        output = self((img1, img2), training=False)
        loss = self.loss_function(label, output)
        self.metrics[0].update_state(label, output)
        self.metrics[1].update_state(loss)
        
        return {m.name : m.result() for m in self.metrics}

In [40]:
def create_summary_writers(config_name):
    
    # Define where to save the logs
    # along with this, you may want to save a config file with the same name so you know what the hyperparameters were used
    # alternatively make a copy of the code that is used for later reference
    
    current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

    train_log_path = f"logs/{config_name}/{current_time}/train"
    val_log_path = f"logs/{config_name}/{current_time}/val"

    # log writer for training metrics
    train_summary_writer = tf.summary.create_file_writer(train_log_path)

    # log writer for validation metrics
    val_summary_writer = tf.summary.create_file_writer(val_log_path)
    
    return train_summary_writer, val_summary_writer

train_summary_writer, val_summary_writer = create_summary_writers(config_name="RUN1")

In [43]:
import tqdm

def train_model(type, optimizer):
    #1. create model
    model = TwinMNISTModel(type)
    model.optimizer = optimizer

    #2. create dataset
    mnist = tfds.load("mnist", split =["train","test"], as_supervised=True)
    train_ds = mnist[0]
    val_ds = mnist[1]
    train_ds = preprocess(train_ds, batch_size=32, type=type)
    val_ds = preprocess(val_ds, batch_size=32, type=type)

    # 3. training loop
    start_epoch = 0
    epochs = 30
    for e in range(start_epoch, epochs):

        # 1. train steps on all batches in the training data
        for data in tqdm.tqdm(train_ds, position=0, leave=True):
            metrics = model.train_step(data)

        # 2. log and print training metrics

        with train_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
        
        #print the metrics
        print([f"{key}: {value.numpy()}" for (key, value) in metrics.items()])
        
        # 3. reset metric objects
        model.reset_metrics()


        # 4. evaluate on validation data
        for data in val_ds:
            metrics = model.test_step(data)
        

        # 5. log validation metrics

        with val_summary_writer.as_default():
            # for scalar metrics:
            for metric in model.metrics:
                    tf.summary.scalar(f"{metric.name}", metric.result(), step=e)
            # alternatively, log metrics individually (allows for non-scalar metrics such as tf.keras.metrics.MeanTensor)
            # e.g. tf.summary.image(name="mean_activation_layer3", data = metrics["mean_activation_layer3"],step=e)
            
        print([f"val_{key}: {value.numpy()}" for (key, value) in metrics.items()])
        # 6. reset metric objects
        model.reset_metrics()

    # 7. save model weights if save_path is given
    save_path = "test-speicher"
    if save_path:
        model.save_weights(save_path)

    return


In [15]:
%load_ext tensorboard 

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [26]:
# open the tensorboard logs
%tensorboard --logdir logs/

Possible types
    "greater_equal"
    "subtract"

Possible optimizers
    tf.keras.optimizers.Adam()
    tf.keras.optimizers.SGD()

In [46]:
train_model("subtract", tf.keras.optimizers.Adam())

100%|██████████| 1875/1875 [00:05<00:00, 371.21it/s]


['binary_accuracy: 0.11953333020210266', 'loss: 3.7735795974731445']
['val_binary_accuracy: 0.12950000166893005', 'val_loss: 2.3355093002319336']


100%|██████████| 1875/1875 [00:03<00:00, 484.32it/s]


['binary_accuracy: 0.13093332946300507', 'loss: 2.1583542823791504']
['val_binary_accuracy: 0.1316000074148178', 'val_loss: 1.9378808736801147']


100%|██████████| 1875/1875 [00:03<00:00, 501.96it/s]


['binary_accuracy: 0.13783332705497742', 'loss: 1.8315470218658447']
['val_binary_accuracy: 0.14270000159740448', 'val_loss: 1.7500616312026978']


100%|██████████| 1875/1875 [00:03<00:00, 562.88it/s]


['binary_accuracy: 0.1394166648387909', 'loss: 1.6436582803726196']
['val_binary_accuracy: 0.14059999585151672', 'val_loss: 1.7341139316558838']


100%|██████████| 1875/1875 [00:03<00:00, 579.25it/s]


['binary_accuracy: 0.14104999601840973', 'loss: 1.534754991531372']
['val_binary_accuracy: 0.14059999585151672', 'val_loss: 1.5739374160766602']


100%|██████████| 1875/1875 [00:03<00:00, 573.19it/s]


['binary_accuracy: 0.13983333110809326', 'loss: 1.4474979639053345']
['val_binary_accuracy: 0.14229999482631683', 'val_loss: 1.5402039289474487']


100%|██████████| 1875/1875 [00:03<00:00, 580.92it/s]


['binary_accuracy: 0.1423500031232834', 'loss: 1.3694908618927002']
['val_binary_accuracy: 0.14569999277591705', 'val_loss: 1.5085698366165161']


100%|██████████| 1875/1875 [00:03<00:00, 513.18it/s]


['binary_accuracy: 0.1426166594028473', 'loss: 1.3092228174209595']
['val_binary_accuracy: 0.14069999754428864', 'val_loss: 1.5767685174942017']


100%|██████████| 1875/1875 [00:03<00:00, 551.61it/s]


['binary_accuracy: 0.14348334074020386', 'loss: 1.2463176250457764']
['val_binary_accuracy: 0.147599995136261', 'val_loss: 1.4491242170333862']


100%|██████████| 1875/1875 [00:03<00:00, 587.52it/s]


['binary_accuracy: 0.1446666717529297', 'loss: 1.2092443704605103']
['val_binary_accuracy: 0.1467999964952469', 'val_loss: 1.4348481893539429']


100%|██████████| 1875/1875 [00:03<00:00, 515.30it/s]


['binary_accuracy: 0.14641666412353516', 'loss: 1.1727663278579712']
['val_binary_accuracy: 0.1453000009059906', 'val_loss: 1.3890846967697144']


100%|██████████| 1875/1875 [00:03<00:00, 494.66it/s]


['binary_accuracy: 0.14641666412353516', 'loss: 1.1379125118255615']
['val_binary_accuracy: 0.13840000331401825', 'val_loss: 1.5018763542175293']


100%|██████████| 1875/1875 [00:04<00:00, 453.55it/s]


['binary_accuracy: 0.14785000681877136', 'loss: 1.1070386171340942']
['val_binary_accuracy: 0.15199999511241913', 'val_loss: 1.4117687940597534']


100%|██████████| 1875/1875 [00:03<00:00, 522.00it/s]


['binary_accuracy: 0.14710000157356262', 'loss: 1.0678279399871826']
['val_binary_accuracy: 0.14409999549388885', 'val_loss: 1.334977626800537']


100%|██████████| 1875/1875 [00:03<00:00, 512.61it/s]


['binary_accuracy: 0.14880000054836273', 'loss: 1.0757677555084229']
['val_binary_accuracy: 0.15219999849796295', 'val_loss: 1.3230575323104858']


100%|██████████| 1875/1875 [00:05<00:00, 361.57it/s]


['binary_accuracy: 0.148416668176651', 'loss: 1.0302684307098389']
['val_binary_accuracy: 0.1599999964237213', 'val_loss: 1.408939242362976']


100%|██████████| 1875/1875 [00:03<00:00, 547.30it/s]


['binary_accuracy: 0.1525166630744934', 'loss: 1.0013231039047241']
['val_binary_accuracy: 0.14920000731945038', 'val_loss: 1.4005138874053955']


100%|██████████| 1875/1875 [00:03<00:00, 544.69it/s]


['binary_accuracy: 0.1526000052690506', 'loss: 0.9868374466896057']
['val_binary_accuracy: 0.15129999816417694', 'val_loss: 1.3658167123794556']


100%|██████████| 1875/1875 [00:03<00:00, 549.75it/s]


['binary_accuracy: 0.15336667001247406', 'loss: 0.9679520130157471']
['val_binary_accuracy: 0.1535000056028366', 'val_loss: 1.259988784790039']


100%|██████████| 1875/1875 [00:03<00:00, 550.81it/s]


['binary_accuracy: 0.1509000062942505', 'loss: 0.9547527432441711']
['val_binary_accuracy: 0.15000000596046448', 'val_loss: 1.3599681854248047']


100%|██████████| 1875/1875 [00:03<00:00, 554.42it/s]


['binary_accuracy: 0.15504999458789825', 'loss: 0.9317067265510559']
['val_binary_accuracy: 0.1500999927520752', 'val_loss: 1.3787431716918945']


100%|██████████| 1875/1875 [00:03<00:00, 527.01it/s]


['binary_accuracy: 0.1537666618824005', 'loss: 0.9239733219146729']
['val_binary_accuracy: 0.14900000393390656', 'val_loss: 1.308957815170288']


100%|██████████| 1875/1875 [00:03<00:00, 551.71it/s]


['binary_accuracy: 0.15158332884311676', 'loss: 0.9116506576538086']
['val_binary_accuracy: 0.14159999787807465', 'val_loss: 1.3555996417999268']


100%|██████████| 1875/1875 [00:03<00:00, 497.64it/s]


['binary_accuracy: 0.1516333371400833', 'loss: 0.8861414194107056']
['val_binary_accuracy: 0.1509000062942505', 'val_loss: 1.2649399042129517']


100%|██████████| 1875/1875 [00:03<00:00, 541.01it/s]


['binary_accuracy: 0.15168333053588867', 'loss: 0.8802497982978821']
['val_binary_accuracy: 0.14630000293254852', 'val_loss: 1.2794545888900757']


100%|██████████| 1875/1875 [00:03<00:00, 498.59it/s]


['binary_accuracy: 0.1527666598558426', 'loss: 0.8758882284164429']
['val_binary_accuracy: 0.1459999978542328', 'val_loss: 1.2820228338241577']


100%|██████████| 1875/1875 [00:03<00:00, 499.98it/s]


['binary_accuracy: 0.15494999289512634', 'loss: 0.8568170070648193']
['val_binary_accuracy: 0.15119999647140503', 'val_loss: 1.2820227146148682']


100%|██████████| 1875/1875 [00:03<00:00, 481.99it/s]


['binary_accuracy: 0.15584999322891235', 'loss: 0.8492461442947388']
['val_binary_accuracy: 0.15489999949932098', 'val_loss: 1.2210263013839722']


100%|██████████| 1875/1875 [00:03<00:00, 512.53it/s]


['binary_accuracy: 0.15205000340938568', 'loss: 0.8340210914611816']
['val_binary_accuracy: 0.1412000060081482', 'val_loss: 1.371034860610962']


100%|██████████| 1875/1875 [00:03<00:00, 484.94it/s]


['binary_accuracy: 0.15209999680519104', 'loss: 0.828899621963501']
['val_binary_accuracy: 0.14880000054836273', 'val_loss: 1.4175634384155273']


greater_equal, Adam
    nach 10 Epochen:
        ['binary_accuracy: 0.9675999879837036', 'loss: 0.09635035693645477']
        ['val_binary_accuracy: 0.9659000039100647', 'val_loss: 0.10654401779174805']

greater_equal, SGD
    nach 10 Epochen:
        ['binary_accuracy: 0.9610000252723694', 'loss: 0.10990440100431442']
        ['val_binary_accuracy: 0.9629999995231628', 'val_loss: 0.10369819402694702']

subtract, Adam
    nach 10 Epochen:
        loss: 0.828899621963501
        val_loss: 1.4175634384155273

subtract, SGD
    nach 30 Epochen:
        loss: 1.0312447547912598
        val_loss: 1.4527225494384766