In [1]:
import sys
import numpy as np
from tensorflow.keras.datasets import mnist 
from tensorflow.keras.layers import Input, Lambda, Dense
from tensorflow.keras.models import Model

In [2]:
# Extend the module search path with '../' before importing from the
# `siamese_networks` package below (presumably the package lives one
# directory above this notebook -- verify against the repository layout).
# NOTE: '../' is a relative entry, so it is resolved against the kernel's
# current working directory.
sys.path.append('../')
from siamese_networks.siamese_network import build_siamese_model
from siamese_networks.build_siamese_pairs import make_pairs
from siamese_networks import config, utils

In [3]:
# Load MNIST and rescale the pixel values of BOTH splits to the range [0, 1].
print("[INFO] loading MNIST dataset...")

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train / 255.0
# Fix: the scaled test images used to be bound to an unused `_xtest` name,
# which left `x_test` unscaled while `x_train` was scaled. The test split
# must receive the same preprocessing as the training split.
x_test = x_test / 255.0

# add a trailing channel dimension to the images
x_train = np.expand_dims(x_train, axis=-1)
x_test = np.expand_dims(x_test, axis=-1)

# build the positive/negative image pairs and their labels for each split
print("[INFO] preparing positive and negative pairs...")

(x_pair_train, y_label_train) = make_pairs(x_train, y_train)
(x_pair_test, y_label_test) = make_pairs(x_test, y_test)

[INFO] loading MNIST dataset...
[INFO] preparing positive and negative pairs...


In [4]:
# Set up the two branches of the siamese network.
print("[INFO] building siamese network...")

# one input per image of a pair, each with the configured image shape
img_a, img_b = [Input(shape=config.IMG_SHAPE) for _ in range(2)]

# a single extractor object is built once and applied to both inputs
featureExtractor = build_siamese_model(config.IMG_SHAPE)
feats_a, feats_b = [featureExtractor(img) for img in (img_a, img_b)]

[INFO] building siamese network...


In [5]:
# Join the two feature branches into the final siamese model:
# distance between the feature vectors -> single sigmoid unit.
distance_layer = Lambda(utils.euclidean_distance)
distance = distance_layer([feats_a, feats_b])

classifier_head = Dense(1, activation="sigmoid")
outputs = classifier_head(distance)

model = Model(inputs=[img_a, img_b], outputs=outputs)

In [None]:
# --- compile ---
print("[INFO] compiling model...")
model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"])

# --- train ---
# Each pair array is split along axis 1 into the two per-branch inputs.
print("[INFO] training model...")
train_inputs = [x_pair_train[:, 0], x_pair_train[:, 1]]
val_inputs = [x_pair_test[:, 0], x_pair_test[:, 1]]
history = model.fit(
    train_inputs,
    y_label_train[:],
    validation_data=(val_inputs, y_label_test[:]),
    batch_size=config.BATCH_SIZE,
    epochs=config.EPOCHS,
)

# --- persist the trained model at the configured path ---
print("[INFO] saving siamese model...")
model.save(config.MODEL_PATH)

# --- hand the fit history to the project's plotting helper ---
print("[INFO] plotting training history...")
utils.plot_training(history, config.PLOT_PATH)