# 2nd Course, Part 2: Matching digits

We now turn to the second task and build a model that

- receives two images of hand-written digits as input and
- outputs a probability that both images show the same digit.

In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from IPython.display import set_matplotlib_formats
%matplotlib inline
set_matplotlib_formats('svg')


## Step 1: Preparing the data

We use the MNIST dataset again:

In [2]:
import tensorflow_datasets as tfds

tfds.disable_progress_bar()

mnist_train = tfds.load(name='mnist', split='train')
mnist_test = tfds.load(name='mnist', split='test')

mnist_train, mnist_test

(<DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>,
 <DatasetV1Adapter shapes: {image: (28, 28, 1), label: ()}, types: {image: tf.uint8, label: tf.int64}>)

We now group consecutive samples into pairs, scale the images to [0, 1] as before, and check whether the two labels of each pair coincide:

In [6]:
mnist_train.batch(2)

<DatasetV1Adapter shapes: {image: (None, 28, 28, 1), label: (None,)}, types: {image: tf.uint8, label: tf.int64}>

In [42]:
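def match_pairs(samples):
    # `samples` is a batch of two consecutive MNIST examples.
    images, digits = samples['image'], samples['label']
    # Target is 1.0 if both images show the same digit, else 0.0.
    matching = tf.cast(tf.equal(digits[0], digits[1]), tf.float32)
    # Scale the uint8 pixels to floats in [0, 1].
    images = tf.cast(images, tf.float32) / 255.
    return (images[0], images[1]), matching


# drop_remainder=True guarantees that every batch really contains two samples.
Xy_train = mnist_train.batch(2, drop_remainder=True).map(match_pairs)
Xy_test = mnist_test.batch(2, drop_remainder=True).map(match_pairs)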

Xy_train

<DatasetV1Adapter shapes: (((28, 28, 1), (28, 28, 1)), ()), types: ((tf.float32, tf.float32), tf.float32)>
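
To see what `batch(2)` followed by `match_pairs` does to the data, here is a NumPy-only sketch on a tiny hypothetical array of seven fake "images" (random noise) with made-up labels. The helper `make_pairs` is hypothetical. It groups consecutive samples into pairs, scales the pixels to [0, 1] and derives a 1.0/0.0 target per pair. It mirrors the `tf.data` pipeline above but is not part of it.

```python
import numpy as np

def make_pairs(images, labels):
    """Group consecutive samples into pairs, as batch(2) + match_pairs does.

    images: uint8 array of shape (n, 28, 28, 1); labels: int array of shape (n,).
    A trailing odd sample is dropped, like batch(2, drop_remainder=True).
    """
    n = len(labels) // 2 * 2
    imgs = images[:n].astype(np.float32) / 255.0                 # scale to [0, 1]
    imgs = imgs.reshape(n // 2, 2, *images.shape[1:])            # (pairs, 2, 28, 28, 1)
    labs = labels[:n].reshape(n // 2, 2)                         # (pairs, 2)
    matching = (labs[:, 0] == labs[:, 1]).astype(np.float32)     # 1.0 if same digit
    return (imgs[:, 0], imgs[:, 1]), matching

rng = np.random.default_rng(0)
fake_images = rng.integers(0, 256, size=(7, 28, 28, 1), dtype=np.uint8)  # hypothetical data
fake_labels = np.array([3, 3, 1, 7, 5, 5, 9])                            # hypothetical labels

(first, second), y = make_pairs(fake_images, fake_labels)
print(first.shape, second.shape, y)  # (3, 28, 28, 1) (3, 28, 28, 1) [1. 0. 1.]
```

The seventh sample has no partner and is dropped.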

Let us count how many of the training pairs match:

In [43]:
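# State is (number of matching pairs, number of pairs); each pair adds (label, 1).
Xy_train.reduce(tf.constant((0, 0)),
                lambda count, sample: count + tf.stack([tf.cast(sample[1], tf.int32), 1]))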

<tf.Tensor: id=43117, shape=(2,), dtype=int32, numpy=array([ 3026, 30000], dtype=int32)>
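
The output says that 3026 of the 30000 training pairs match. The following plain-Python sketch interprets these two numbers. If the two digits of a pair were independent and the ten classes equally likely (MNIST is only roughly balanced), we would expect a match rate of 1/10. A trivial model that always answers "no match" would therefore already reach about 90% accuracy. A model that always outputs the match rate as its probability would reach a binary cross-entropy (natural log, as in Keras) of about 0.33. These are the yardsticks a useful matcher has to beat.

```python
import math

matching, total = 3026, 30000  # counts from the reduce() output above

match_rate = matching / total          # fraction of matching pairs
expected_rate = 1 / 10                 # if digits were independent and uniform
always_no_accuracy = 1 - match_rate    # accuracy of always predicting 0.0

# Binary cross-entropy (in nats) of always predicting p = match_rate.
p = match_rate
constant_loss = -(p * math.log(p) + (1 - p) * math.log(1 - p))

print(f"match rate:             {match_rate:.4f}")          # 0.1009
print(f"expected (uniform):     {expected_rate:.4f}")
print(f"'always no' accuracy:   {always_no_accuracy:.4f}")  # 0.8991
print(f"constant-predictor BCE: {constant_loss:.4f}")
```

Keep these two numbers in mind when reading the training log. During the first epoch, the accuracy stays just under 0.90 and the running loss stays above 0.33.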

## Step 2: Building the model

The `Sequential` class allows us to conveniently construct a neural network by stacking layers.

But if we need more flexibility, for example, to construct

- a model with multiple inputs or multiple outputs, or
- a general directed acyclic graph of layers,

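we can use the functional API of Keras. We create symbolic inputs with `tf.keras.layers.Input`, call layers on them like functions, and finally wrap the inputs and outputs in a `tf.keras.Model`.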

The idea for our model is that we

1. apply our pre-trained digit classifier to both images,
2. obtain two probability distributions $p^{(1)}$ and $p^{(2)}$ over the ten digits, and
3. feed both distributions, concatenated, into a small dense network that outputs the desired probability.
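
For intuition about what the dense layers have to learn, assume the two classifier outputs are independent distributions over the digits 0 to 9. Under that assumption, the probability that both images show the same digit is $\sum_{k=0}^{9} p^{(1)}_k p^{(2)}_k$. The NumPy sketch below computes this quantity for two made-up distributions. The helper `match_probability` is hypothetical. It is a hand-crafted reference point under the independence assumption, not the learned model used in this notebook.

```python
import numpy as np

def match_probability(p1, p2):
    """P(same digit) = sum_k p1[k] * p2[k], assuming the two digits are independent.

    p1, p2: arrays of shape (..., 10); each row is a probability distribution.
    """
    return np.sum(p1 * p2, axis=-1)

# Hypothetical classifier outputs: a confident "3" and a confident "8".
p_three = np.full(10, 0.01)
p_three[3] = 0.91
p_eight = np.full(10, 0.01)
p_eight[8] = 0.91

print(match_probability(p_three, p_three))  # high, about 0.83
print(match_probability(p_three, p_eight))  # low, about 0.02
```

The learned dense head is free to find a different (and possibly better calibrated) mapping. This formula only shows that the two distributions alone carry enough information for the task.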


In [44]:
CLASSIFIER_PATH = 'classifier'  # directory of our pre-trained digit classifier

# Load the pre-trained classifier and freeze its weights.
classifier = tf.keras.models.load_model(CLASSIFIER_PATH)
classifier.trainable = False

def build_matcher_dense():
    image_1 = tf.keras.layers.Input((28, 28, 1))
    image_2 = tf.keras.layers.Input((28, 28, 1))
    # The same classifier instance is applied to both images, so its weights are shared.
    probs_1 = classifier(image_1)
    probs_2 = classifier(image_2)
    # Join the two distributions and map them to a single match probability.
    both_probs = tf.keras.layers.Concatenate()([probs_1, probs_2])
    dense = tf.keras.layers.Dense(32, activation='relu')(both_probs)
    prediction = tf.keras.layers.Dense(1, activation='sigmoid')(dense)
    matcher = tf.keras.Model(inputs=[image_1, image_2], outputs=[prediction])
    return matcher
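
To make the data flow of `build_matcher_dense` concrete, here is a NumPy sketch of one forward pass through the trainable head for a hypothetical batch of 4 pairs. It assumes the classifier outputs 10 class probabilities. The classifier outputs are replaced by random distributions. The weights `W1`, `b1`, `W2`, `b2` are hypothetical random stand-ins for what the two `Dense` layers learn. Only the shapes and the order of operations follow the Keras model.

```python
import numpy as np

rng = np.random.default_rng(42)
batch = 4

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

# Stand-ins for classifier(image_1) and classifier(image_2): shape (batch, 10).
probs_1 = softmax(rng.normal(size=(batch, 10)))
probs_2 = softmax(rng.normal(size=(batch, 10)))

# Concatenate() -> (batch, 20)
both_probs = np.concatenate([probs_1, probs_2], axis=-1)

# Dense(32, activation='relu') with hypothetical random weights -> (batch, 32)
W1, b1 = rng.normal(scale=0.1, size=(20, 32)), np.zeros(32)
dense = np.maximum(0.0, both_probs @ W1 + b1)

# Dense(1, activation='sigmoid') -> one match probability per pair, (batch, 1)
W2, b2 = rng.normal(scale=0.1, size=(32, 1)), np.zeros(1)
prediction = 1.0 / (1.0 + np.exp(-(dense @ W2 + b2)))

print(both_probs.shape, dense.shape, prediction.shape)  # (4, 20) (4, 32) (4, 1)
```

The classifier is frozen and shared between both branches. Under the 10-class assumption, only the 20·32 + 32 + 32 + 1 = 705 parameters of the two `Dense` layers are trained.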


Let's train our matcher!

In [45]:
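def train(model, nr_batches=400, nr_epochs=5):
    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # The pipeline is not shuffled, so each epoch iterates over the same first
    # `nr_batches` batches of 32 pairs; validation uses half as many test batches.
    history = model.fit(Xy_train.repeat().batch(32).take(nr_batches),
                        validation_data=Xy_test.repeat().batch(32).take(nr_batches // 2),
                        epochs=nr_epochs)
    return history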

matcher = build_matcher_dense()
train(matcher)

Epoch 1/5

    235/Unknown - 2s 8ms/step - loss: 0.3882 - accuracy: 0.8975

    236/Unknown - 2s 8ms/step - loss: 0.3883 - accuracy: 0.8972

    237/Unknown - 2s 8ms/step - loss: 0.3882 - accuracy: 0.8972

    238/Unknown - 2s 8ms/step - loss: 0.3882 - accuracy: 0.8971

    239/Unknown - 2s 8ms/step - loss: 0.3879 - accuracy: 0.8971

    240/Unknown - 2s 8ms/step - loss: 0.3878 - accuracy: 0.8970

    241/Unknown - 2s 8ms/step - loss: 0.3875 - accuracy: 0.8970

    242/Unknown - 2s 8ms/step - loss: 0.3868 - accuracy: 0.8972

    243/Unknown - 2s 8ms/step - loss: 0.3860 - accuracy: 0.8975



    244/Unknown - 2s 8ms/step - loss: 0.3859 - accuracy: 0.8974

    245/Unknown - 2s 8ms/step - loss: 0.3861 - accuracy: 0.8972

    246/Unknown - 2s 8ms/step - loss: 0.3858 - accuracy: 0.8972

    247/Unknown - 2s 8ms/step - loss: 0.3858 - accuracy: 0.8971

    248/Unknown - 2s 8ms/step - loss: 0.3858 - accuracy: 0.8971

    249/Unknown - 2s 8ms/step - loss: 0.3849 - accuracy: 0.8973

    250/Unknown - 2s 8ms/step - loss: 0.3854 - accuracy: 0.8970

    251/Unknown - 2s 8ms/step - loss: 0.3856 - accuracy: 0.8968

    252/Unknown - 2s 8ms/step - loss: 0.3858 - accuracy: 0.8966

    253/Unknown - 2s 8ms/step - loss: 0.3855 - accuracy: 0.8966

    254/Unknown - 2s 8ms/step - loss: 0.3854 - accuracy: 0.8965

    255/Unknown - 2s 8ms/step - loss: 0.3851 - accuracy: 0.8966

    256/Unknown - 2s 8ms/step - loss: 0.3852 - accuracy: 0.8964

    257/Unknown - 2s 8ms/step - loss: 0.3844 - accuracy: 0.8966

    258/Unknown - 2s 8ms/step - loss: 0.3843 - accuracy: 0.8966

    259/Unknown - 2s 8ms/step - loss: 0.3837 - accuracy: 0.8967

    260/Unknown - 2s 8ms/step - loss: 0.3833 - accuracy: 0.8968

    261/Unknown - 2s 8ms/step - loss: 0.3829 - accuracy: 0.8968

    262/Unknown - 2s 8ms/step - loss: 0.3829 - accuracy: 0.8967

    263/Unknown - 2s 8ms/step - loss: 0.3830 - accuracy: 0.8965

    264/Unknown - 2s 8ms/step - loss: 0.3828 - accuracy: 0.8964

    265/Unknown - 2s 8ms/step - loss: 0.3828 - accuracy: 0.8963

    266/Unknown - 2s 8ms/step - loss: 0.3828 - accuracy: 0.8963

    267/Unknown - 2s 8ms/step - loss: 0.3822 - accuracy: 0.8964

    268/Unknown - 2s 8ms/step - loss: 0.3821 - accuracy: 0.8963

    269/Unknown - 2s 8ms/step - loss: 0.3815 - accuracy: 0.8965

    270/Unknown - 2s 8ms/step - loss: 0.3810 - accuracy: 0.8966

    271/Unknown - 2s 8ms/step - loss: 0.3813 - accuracy: 0.8963

    272/Unknown - 2s 8ms/step - loss: 0.3807 - accuracy: 0.8965

    273/Unknown - 2s 8ms/step - loss: 0.3809 - accuracy: 0.8963

    274/Unknown - 2s 8ms/step - loss: 0.3811 - accuracy: 0.8961

    275/Unknown - 2s 8ms/step - loss: 0.3813 - accuracy: 0.8959

    276/Unknown - 2s 8ms/step - loss: 0.3813 - accuracy: 0.8958

    277/Unknown - 2s 8ms/step - loss: 0.3810 - accuracy: 0.8959

    278/Unknown - 2s 8ms/step - loss: 0.3816 - accuracy: 0.8955

    279/Unknown - 2s 8ms/step - loss: 0.3806 - accuracy: 0.8958

    280/Unknown - 2s 8ms/step - loss: 0.3799 - accuracy: 0.8961

    281/Unknown - 2s 8ms/step - loss: 0.3791 - accuracy: 0.8964

    282/Unknown - 2s 8ms/step - loss: 0.3793 - accuracy: 0.8962

    283/Unknown - 2s 8ms/step - loss: 0.3785 - accuracy: 0.8964

    284/Unknown - 2s 8ms/step - loss: 0.3780 - accuracy: 0.8966

    285/Unknown - 2s 8ms/step - loss: 0.3780 - accuracy: 0.8965

    286/Unknown - 2s 8ms/step - loss: 0.3778 - accuracy: 0.8965

    287/Unknown - 2s 8ms/step - loss: 0.3774 - accuracy: 0.8966

    288/Unknown - 2s 8ms/step - loss: 0.3765 - accuracy: 0.8969

    289/Unknown - 2s 8ms/step - loss: 0.3761 - accuracy: 0.8971

    290/Unknown - 2s 8ms/step - loss: 0.3756 - accuracy: 0.8972

    291/Unknown - 2s 8ms/step - loss: 0.3753 - accuracy: 0.8972

    292/Unknown - 2s 8ms/step - loss: 0.3757 - accuracy: 0.8969

    293/Unknown - 2s 8ms/step - loss: 0.3752 - accuracy: 0.8971

    294/Unknown - 2s 8ms/step - loss: 0.3755 - accuracy: 0.8968

    295/Unknown - 2s 8ms/step - loss: 0.3752 - accuracy: 0.8968

    296/Unknown - 2s 8ms/step - loss: 0.3748 - accuracy: 0.8970

    297/Unknown - 2s 8ms/step - loss: 0.3746 - accuracy: 0.8970

    298/Unknown - 2s 8ms/step - loss: 0.3743 - accuracy: 0.8970

    299/Unknown - 2s 8ms/step - loss: 0.3743 - accuracy: 0.8969

    300/Unknown - 2s 8ms/step - loss: 0.3738 - accuracy: 0.8971

    301/Unknown - 2s 8ms/step - loss: 0.3734 - accuracy: 0.8971

    302/Unknown - 2s 8ms/step - loss: 0.3726 - accuracy: 0.8975

    303/Unknown - 2s 8ms/step - loss: 0.3725 - accuracy: 0.8974

    304/Unknown - 2s 8ms/step - loss: 0.3716 - accuracy: 0.8977

    305/Unknown - 2s 8ms/step - loss: 0.3710 - accuracy: 0.8980

    306/Unknown - 2s 8ms/step - loss: 0.3703 - accuracy: 0.8982

    307/Unknown - 2s 8ms/step - loss: 0.3707 - accuracy: 0.8979

    308/Unknown - 2s 8ms/step - loss: 0.3705 - accuracy: 0.8979

    309/Unknown - 2s 8ms/step - loss: 0.3701 - accuracy: 0.8981

    310/Unknown - 2s 8ms/step - loss: 0.3696 - accuracy: 0.8982

    311/Unknown - 2s 8ms/step - loss: 0.3695 - accuracy: 0.8982

    312/Unknown - 2s 8ms/step - loss: 0.3696 - accuracy: 0.8980

    313/Unknown - 2s 8ms/step - loss: 0.3689 - accuracy: 0.8983

    314/Unknown - 2s 8ms/step - loss: 0.3685 - accuracy: 0.8984

    315/Unknown - 2s 8ms/step - loss: 0.3684 - accuracy: 0.8983

    316/Unknown - 2s 8ms/step - loss: 0.3682 - accuracy: 0.8983

    317/Unknown - 2s 8ms/step - loss: 0.3687 - accuracy: 0.8980

    318/Unknown - 2s 8ms/step - loss: 0.3687 - accuracy: 0.8979

    319/Unknown - 2s 8ms/step - loss: 0.3689 - accuracy: 0.8977

    320/Unknown - 2s 8ms/step - loss: 0.3685 - accuracy: 0.8979

    321/Unknown - 2s 8ms/step - loss: 0.3683 - accuracy: 0.8979

    322/Unknown - 2s 8ms/step - loss: 0.3683 - accuracy: 0.8978

    323/Unknown - 2s 8ms/step - loss: 0.3674 - accuracy: 0.8981

    324/Unknown - 2s 8ms/step - loss: 0.3670 - accuracy: 0.8982

    325/Unknown - 2s 8ms/step - loss: 0.3670 - accuracy: 0.8982

    326/Unknown - 2s 8ms/step - loss: 0.3667 - accuracy: 0.8982

    327/Unknown - 3s 8ms/step - loss: 0.3663 - accuracy: 0.8983

    328/Unknown - 3s 8ms/step - loss: 0.3663 - accuracy: 0.8982

    329/Unknown - 3s 8ms/step - loss: 0.3661 - accuracy: 0.8983

    330/Unknown - 3s 8ms/step - loss: 0.3661 - accuracy: 0.8982

    331/Unknown - 3s 8ms/step - loss: 0.3657 - accuracy: 0.8983

    332/Unknown - 3s 8ms/step - loss: 0.3655 - accuracy: 0.8983

    333/Unknown - 3s 8ms/step - loss: 0.3651 - accuracy: 0.8985

    334/Unknown - 3s 8ms/step - loss: 0.3649 - accuracy: 0.8985

    335/Unknown - 3s 8ms/step - loss: 0.3647 - accuracy: 0.8985

    336/Unknown - 3s 8ms/step - loss: 0.3643 - accuracy: 0.8986

    337/Unknown - 3s 8ms/step - loss: 0.3639 - accuracy: 0.8987

    338/Unknown - 3s 8ms/step - loss: 0.3642 - accuracy: 0.8985

    339/Unknown - 3s 8ms/step - loss: 0.3636 - accuracy: 0.8987

    340/Unknown - 3s 8ms/step - loss: 0.3632 - accuracy: 0.8988

    341/Unknown - 3s 8ms/step - loss: 0.3627 - accuracy: 0.8990

    342/Unknown - 3s 8ms/step - loss: 0.3629 - accuracy: 0.8988

    343/Unknown - 3s 8ms/step - loss: 0.3625 - accuracy: 0.8990

    344/Unknown - 3s 8ms/step - loss: 0.3619 - accuracy: 0.8992

    345/Unknown - 3s 8ms/step - loss: 0.3613 - accuracy: 0.8994

    346/Unknown - 3s 8ms/step - loss: 0.3610 - accuracy: 0.8994

    347/Unknown - 3s 8ms/step - loss: 0.3607 - accuracy: 0.8995

    348/Unknown - 3s 8ms/step - loss: 0.3614 - accuracy: 0.8991

    349/Unknown - 3s 8ms/step - loss: 0.3614 - accuracy: 0.8990

    350/Unknown - 3s 8ms/step - loss: 0.3610 - accuracy: 0.8991

    351/Unknown - 3s 8ms/step - loss: 0.3610 - accuracy: 0.8990

    352/Unknown - 3s 8ms/step - loss: 0.3608 - accuracy: 0.8991

    353/Unknown - 3s 8ms/step - loss: 0.3609 - accuracy: 0.8990

    354/Unknown - 3s 8ms/step - loss: 0.3609 - accuracy: 0.8989

    355/Unknown - 3s 8ms/step - loss: 0.3613 - accuracy: 0.8987

    356/Unknown - 3s 7ms/step - loss: 0.3615 - accuracy: 0.8985

    357/Unknown - 3s 7ms/step - loss: 0.3611 - accuracy: 0.8986



    358/Unknown - 3s 7ms/step - loss: 0.3611 - accuracy: 0.8986

    359/Unknown - 3s 7ms/step - loss: 0.3608 - accuracy: 0.8987

    360/Unknown - 3s 7ms/step - loss: 0.3610 - accuracy: 0.8985

    361/Unknown - 3s 7ms/step - loss: 0.3611 - accuracy: 0.8984

    362/Unknown - 3s 7ms/step - loss: 0.3611 - accuracy: 0.8983

    363/Unknown - 3s 7ms/step - loss: 0.3607 - accuracy: 0.8984

    364/Unknown - 3s 7ms/step - loss: 0.3604 - accuracy: 0.8985

    365/Unknown - 3s 7ms/step - loss: 0.3598 - accuracy: 0.8987

    366/Unknown - 3s 7ms/step - loss: 0.3598 - accuracy: 0.8987

    367/Unknown - 3s 7ms/step - loss: 0.3595 - accuracy: 0.8987

    368/Unknown - 3s 7ms/step - loss: 0.3592 - accuracy: 0.8988

    369/Unknown - 3s 7ms/step - loss: 0.3593 - accuracy: 0.8986

    370/Unknown - 3s 7ms/step - loss: 0.3593 - accuracy: 0.8986

    371/Unknown - 3s 7ms/step - loss: 0.3595 - accuracy: 0.8984

    372/Unknown - 3s 7ms/step - loss: 0.3594 - accuracy: 0.8984

    373/Unknown - 3s 7ms/step - loss: 0.3595 - accuracy: 0.8982

    374/Unknown - 3s 7ms/step - loss: 0.3590 - accuracy: 0.8984

    375/Unknown - 3s 7ms/step - loss: 0.3590 - accuracy: 0.8983

    376/Unknown - 3s 7ms/step - loss: 0.3593 - accuracy: 0.8981

    377/Unknown - 3s 7ms/step - loss: 0.3594 - accuracy: 0.8980

    378/Unknown - 3s 7ms/step - loss: 0.3594 - accuracy: 0.8979

    379/Unknown - 3s 7ms/step - loss: 0.3597 - accuracy: 0.8977

    380/Unknown - 3s 7ms/step - loss: 0.3593 - accuracy: 0.8978

    381/Unknown - 3s 7ms/step - loss: 0.3590 - accuracy: 0.8979

    382/Unknown - 3s 7ms/step - loss: 0.3585 - accuracy: 0.8981

    383/Unknown - 3s 7ms/step - loss: 0.3582 - accuracy: 0.8982

    384/Unknown - 3s 7ms/step - loss: 0.3583 - accuracy: 0.8980

    385/Unknown - 3s 7ms/step - loss: 0.3578 - accuracy: 0.8982

    386/Unknown - 3s 7ms/step - loss: 0.3578 - accuracy: 0.8982

    387/Unknown - 3s 7ms/step - loss: 0.3576 - accuracy: 0.8982

    388/Unknown - 3s 7ms/step - loss: 0.3574 - accuracy: 0.8982

    389/Unknown - 3s 7ms/step - loss: 0.3573 - accuracy: 0.8981

    390/Unknown - 3s 7ms/step - loss: 0.3570 - accuracy: 0.8982

    391/Unknown - 3s 7ms/step - loss: 0.3564 - accuracy: 0.8985

    392/Unknown - 3s 7ms/step - loss: 0.3564 - accuracy: 0.8984

    393/Unknown - 3s 7ms/step - loss: 0.3559 - accuracy: 0.8986

    394/Unknown - 3s 7ms/step - loss: 0.3561 - accuracy: 0.8985

    395/Unknown - 3s 7ms/step - loss: 0.3556 - accuracy: 0.8987

    396/Unknown - 3s 7ms/step - loss: 0.3553 - accuracy: 0.8988

    397/Unknown - 3s 7ms/step - loss: 0.3558 - accuracy: 0.8984



    398/Unknown - 3s 7ms/step - loss: 0.3555 - accuracy: 0.8985

    399/Unknown - 3s 7ms/step - loss: 0.3559 - accuracy: 0.8982

    400/Unknown - 3s 7ms/step - loss: 0.3561 - accuracy: 0.8980



Epoch 2/5
 89/400 [=====>........................] - ETA: 1s - loss: 0.3027 - accuracy: 0.8968
Epoch 3/5
 90/400 [=====>........................] - ETA: 1s - loss: 0.2392 - accuracy: 0.8972
Epoch 4/5
 88/400 [=====>........................] - ETA: 1s - loss: 0.1585 - accuracy: 0.9414
Epoch 5/5
 89/400 [=====>........................] - ETA: 1s - loss: 0.1009 - accuracy: 0.9582

Let us now evaluate the matcher:

In [52]:
from sklearn.metrics import classification_report, confusion_matrix

def evaluate(model, nr_samples=1000):
    # Ground-truth labels (1. = same digit) of the first nr_samples test pairs.
    y_true = np.stack(list(Xy_test.map(lambda _, p: p).take(nr_samples)))
    # Predict the same pairs in one batch and threshold the probabilities at 0.5.
    y_pred = np.round(np.concatenate(model.predict(Xy_test.batch(nr_samples).take(1))))
    print(confusion_matrix(y_true, y_pred))
    print(classification_report(y_true, y_pred))

evaluate(matcher)

[[887   1]
 [ 17  95]]
              precision    recall  f1-score   support

         0.0       0.98      1.00      0.99       888
         1.0       0.99      0.85      0.91       112

    accuracy                           0.98      1000
   macro avg       0.99      0.92      0.95      1000
weighted avg       0.98      0.98      0.98      1000
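To see where the numbers in the report come from, here is a small NumPy sketch that recomputes per-class precision, recall and F1 directly from the confusion matrix printed above (rows are true labels, columns are predicted labels, as in scikit-learn's `confusion_matrix`). It is only a cross-check, not a replacement for `classification_report`.

```python
import numpy as np

# Confusion matrix printed above: rows = true class, columns = predicted class.
cm = np.array([[887, 1],
               [17, 95]])

tp = np.diag(cm)                 # correctly classified pairs per class
precision = tp / cm.sum(axis=0)  # TP / (all pairs predicted as this class)
recall = tp / cm.sum(axis=1)     # TP / (all pairs that truly are this class)
f1 = 2 * precision * recall / (precision + recall)
accuracy = tp.sum() / cm.sum()

for label, p, r, f in zip(["0.0", "1.0"], precision, recall, f1):
    print(f"{label}: precision={p:.2f} recall={r:.2f} f1={f:.2f}")
print(f"accuracy={accuracy:.2f}")

# 'macro avg' is the unweighted mean over the two classes.
print(f"macro avg: precision={precision.mean():.2f} "
      f"recall={recall.mean():.2f} f1={f1.mean():.2f}")
```

The low recall of class `1.0` (95 of 112 matching pairs found) is the main weakness here; non-matching pairs are far more frequent and are almost always classified correctly.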



## Step 3: Building a model that need not learn

We now replace the last dense classification layer with a `Lambda` layer that has no trainable parameters, based on the following observation:

Given the predicted probability distributions $p^{(1)}$ and $p^{(2)}$ over the ten digit classes for the two images, and assuming the two predictions are independent, the probability that both digits coincide is $\sum_{i=0}^{9} p^{(1)}_i p^{(2)}_i$.
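Written out, the step from the two distributions to the matching probability uses the independence assumption explicitly. Here $D_1, D_2 \in \{0, \dots, 9\}$ denote the (random) digits assigned to the two images; these symbols are introduced only for this derivation:

$$
\begin{aligned}
P(D_1 = D_2) &= \sum_{i=0}^{9} P(D_1 = i,\; D_2 = i) \\
             &= \sum_{i=0}^{9} P(D_1 = i)\, P(D_2 = i) && \text{(independence)} \\
             &= \sum_{i=0}^{9} p^{(1)}_i\, p^{(2)}_i .
\end{aligned}
$$

For example, two one-hot distributions give $1$ if they put all their mass on the same digit and $0$ otherwise, while two uniform distributions give $10 \cdot 0.1^2 = 0.1$.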

We can implement this formula with a `Lambda` layer. The catch is that all
tensors flowing through the neural network are *batches* of data, so the sum must run over the last (class) axis only, giving one value per pair:

In [53]:
p1_batch = tf.constant([[0.1, 0.9], [0.5, 0.5]])
p2_batch = tf.constant([[0.2, 0.8], [0.4, 0.6]])

def compute_prob_equality(p1_batch, p2_batch):
    return tf.reduce_sum(p1_batch * p2_batch, axis=-1)


compute_prob_equality(p1_batch, p2_batch)

<tf.Tensor: id=58038, shape=(2,), dtype=float32, numpy=array([0.73999995, 0.5       ], dtype=float32)>
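The same batched computation can be mirrored in NumPy to make the shapes explicit. Summing over `axis=-1` reduces only the class axis, so each row of the batch keeps its own probability; `keepdims=True`, as used in the `Lambda` layer of `build_matcher_lambda`, keeps a trailing axis of size 1 so the result has shape `(batch_size, 1)`. The helper name `prob_equality` and the extra 10-class distributions are illustrative assumptions, not part of the original notebook.

```python
import numpy as np

def prob_equality(p1_batch, p2_batch, keepdims=False):
    """Row-wise sum_i p1[i] * p2[i] for two batches of distributions."""
    return np.sum(p1_batch * p2_batch, axis=-1, keepdims=keepdims)

p1_batch = np.array([[0.1, 0.9], [0.5, 0.5]])
p2_batch = np.array([[0.2, 0.8], [0.4, 0.6]])

print(prob_equality(p1_batch, p2_batch))                       # one value per pair
print(prob_equality(p1_batch, p2_batch, keepdims=True).shape)  # (2, 1)

# Sanity checks with 10-class distributions (illustrative inputs):
one_hot_3 = np.eye(10)[[3]]       # certain the digit is 3, shape (1, 10)
one_hot_7 = np.eye(10)[[7]]       # certain the digit is 7
uniform = np.full((1, 10), 0.1)   # no preference for any digit
print(prob_equality(one_hot_3, one_hot_3))  # same digit for sure -> about 1
print(prob_equality(one_hot_3, one_hot_7))  # different digits for sure -> 0
print(prob_equality(uniform, uniform))      # 10 * 0.1**2 -> about 0.1
```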

In [54]:
def build_matcher_lambda():
    image_1 = tf.keras.layers.Input((28, 28, 1))
    image_2 = tf.keras.layers.Input((28, 28, 1))
    # Apply the same digit classifier (shared weights) to both images.
    probs_1 = classifier(image_1)
    probs_2 = classifier(image_2)
    # Matching probability sum_i p1_i * p2_i, computed row-wise over the batch.
    prediction = tf.keras.layers.Lambda(
        lambda p: tf.reduce_sum(p[0] * p[1], axis=-1, keepdims=True)
    )([probs_1, probs_2])
    matcher = tf.keras.Model(inputs=[image_1, image_2], outputs=[prediction])
    return matcher
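Because `classifier` is applied to both inputs, the two branches share one set of weights, and the `Lambda` layer adds no parameters of its own. The NumPy sketch below mimics that forward pass; the random linear-softmax `toy_classifier` is a hypothetical stand-in for the trained `classifier`, used only to show the data flow and the output shape.

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in for the trained classifier: one linear layer + softmax.
W = rng.normal(scale=0.01, size=(28 * 28, 10))

def toy_classifier(images):
    """Map a batch of (28, 28, 1) images to (batch, 10) digit probabilities."""
    logits = images.reshape(len(images), -1) @ W
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(logits)
    return exp / exp.sum(axis=-1, keepdims=True)

def matcher_forward(images_1, images_2):
    """Same structure as build_matcher_lambda: shared classifier, then dot product."""
    probs_1 = toy_classifier(images_1)  # the same weights W are used ...
    probs_2 = toy_classifier(images_2)  # ... for both branches
    return np.sum(probs_1 * probs_2, axis=-1, keepdims=True)

images_1 = rng.random((4, 28, 28, 1))
images_2 = rng.random((4, 28, 28, 1))
scores = matcher_forward(images_1, images_2)
print(scores.shape)  # (4, 1): one matching probability per pair
```

A consequence of this design is that the matcher is symmetric in its two inputs and only as good as the classifier it wraps, which is why it can be evaluated right away without calling `fit`.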

Let's see how this model performs!

In [55]:
matcher = build_matcher_lambda()
evaluate(matcher)


[[885   3]
 [  5 107]]
              precision    recall  f1-score   support

         0.0       0.99      1.00      1.00       888
         1.0       0.97      0.96      0.96       112

    accuracy                           0.99      1000
   macro avg       0.98      0.98      0.98      1000
weighted avg       0.99      0.99      0.99      1000



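This is quite good, isn't it? Without any training of its own, the `Lambda`-based matcher misclassifies only 8 of the 1000 test pairs (the trained matcher misclassified 18), and its recall on matching pairs rises from 0.85 to 0.96.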