# Fashion-MNIST model testing
In this notebook we test whether the results obtained on MNIST also hold for Fashion-MNIST.

# Make the necessary imports

In [None]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.datasets import fashion_mnist
from tensorflow.keras import layers
from tensorflow.keras.preprocessing.image import img_to_array, array_to_img
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from matplotlib.pyplot import imshow
import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.metrics import roc_curve, auc

import pandas as pd
import numpy as np
import random
import math
import os

# Import and preprocess the Fashion-MNIST dataset

In [None]:
# Download (if needed) and load the raw Fashion-MNIST train/test splits.
(Xtrain_orig, ytrain_orig), (Xtest_orig, ytest_orig) = fashion_mnist.load_data()

# Divide pixel values by 255 to bring 8-bit intensities into [0, 1] as float32.
Xtrain, Xtest = [split.astype('float32') / 255. for split in (Xtrain_orig, Xtest_orig)]

# Integer class labels are used unchanged.
ytrain, ytest = ytrain_orig, ytest_orig

for split_name, split_array in [("Xtrain", Xtrain), ("ytrain", ytrain),
                                ("Xtest", Xtest), ("ytest", ytest)]:
    print(split_name + " shape: " + str(split_array.shape))

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
Xtrain shape: (60000, 28, 28)
ytrain shape: (60000,)
Xtest shape: (10000, 28, 28)
ytest shape: (10000,)


# Support functions

In [None]:
def create_random_set(X, y, strip_size=4, set_size=60000):
    """Build a balanced "does the target's class appear in the strip?" dataset.

    Each example is a strip of ``strip_size`` context images drawn uniformly at
    random from ``X`` followed by one target image. A fair coin decides the
    label: for label 1 the target is a random image of a class already present
    in the context; for label 0 it is a random image of a class absent from it.

    Args:
        X: Array of single-channel images, shape (N, H, W).
        y: Integer class labels in ``range(10)`` aligned with ``X``.
        strip_size: Number of context images per example.
        set_size: Number of examples to generate.

    Returns:
        (X_processed, y_processed): a float64 array of shape
        (set_size, strip_size + 1, H, W, 1) holding the context images followed
        by the target image, and a float64 array of shape (set_size,) with the
        0/1 labels.

    Raises:
        ValueError: if a label-0 example is requested but the context strip
            already covers all 10 classes (only possible when strip_size >= 10).
    """
    # Indices of the images of each class, used to draw a target of a given class.
    indices_by_class = [np.where(y == c)[0] for c in range(10)]

    strips = []   # per example: strip_size context indices followed by the target index
    labels = []   # per example: 1 if the target class occurs in the context, else 0
    for _ in range(set_size):  # one strip per requested example
        strip = []
        strip_classes = []
        while len(strip) < strip_size:
            image_idx = random.randint(0, len(X) - 1)
            strip_classes.append(y[image_idx])
            strip.append(image_idx)

        # Fair coin so that both labels are equally likely.
        repeated = np.random.choice([0, 1], p=[0.50, 0.50])
        if repeated:
            # Positive example: reuse the class of a random context image.
            target_class = strip_classes[random.randint(0, len(strip_classes) - 1)]
            labels.append(1)
        else:
            # Negative example: pick a class that does not occur in the context.
            possible_classes = [c for c in range(10) if c not in strip_classes]
            if not possible_classes:
                raise ValueError(
                    "Cannot build a negative example: the strip of %d images "
                    "already contains all 10 classes." % strip_size)
            target_class = random.choice(possible_classes)
            labels.append(0)
        class_members = indices_by_class[target_class]
        strip.append(class_members[random.randint(0, len(class_members) - 1)])
        strips.append(strip)

    n_examples = len(strips)
    height, width = X.shape[1], X.shape[2]
    X_processed = np.zeros([n_examples, strip_size + 1, height, width, 1])
    strip_idx = np.asarray(strips, dtype=np.intp).reshape(n_examples, strip_size + 1)
    # Gather one strip position at a time: vectorised (no per-image TF calls),
    # while the temporary copy stays at one image per example.
    for position in range(strip_size + 1):
        X_processed[:, position, :, :, 0] = X[strip_idx[:, position]]
    y_processed = np.asarray(labels, dtype=np.float64)

    return X_processed, y_processed

In [None]:
def create_random_set_RGB(X, y, strip_size=4, set_size=60000):
    """RGB variant of the strip dataset builder.

    Each example is a strip of ``strip_size`` context images drawn uniformly at
    random from ``X`` followed by one target image. A fair coin decides the
    label: for label 1 the target is a random image of a class already present
    in the context; for label 0 it is a random image of a class absent from it.

    Args:
        X: Array of 3-channel images, shape (N, H, W, 3).
        y: Integer class labels in ``range(10)`` aligned with ``X``.
        strip_size: Number of context images per example.
        set_size: Number of examples to generate.

    Returns:
        (X_processed, y_processed): a float64 array of shape
        (set_size, strip_size + 1, H, W, 3) holding the context images followed
        by the target image, and a float64 array of shape (set_size,) with the
        0/1 labels.

    Raises:
        ValueError: if a label-0 example is requested but the context strip
            already covers all 10 classes (only possible when strip_size >= 10).
    """
    # Indices of the images of each class, used to draw a target of a given class.
    indices_by_class = [np.where(y == c)[0] for c in range(10)]

    strips = []   # per example: strip_size context indices followed by the target index
    labels = []   # per example: 1 if the target class occurs in the context, else 0
    for _ in range(set_size):  # one strip per requested example
        strip = []
        strip_classes = []
        while len(strip) < strip_size:
            image_idx = random.randint(0, len(X) - 1)
            strip_classes.append(y[image_idx])
            strip.append(image_idx)

        # Fair coin so that both labels are equally likely.
        repeated = np.random.choice([0, 1], p=[0.50, 0.50])
        if repeated:
            # Positive example: reuse the class of a random context image.
            target_class = strip_classes[random.randint(0, len(strip_classes) - 1)]
            labels.append(1)
        else:
            # Negative example: pick a class that does not occur in the context.
            possible_classes = [c for c in range(10) if c not in strip_classes]
            if not possible_classes:
                raise ValueError(
                    "Cannot build a negative example: the strip of %d images "
                    "already contains all 10 classes." % strip_size)
            target_class = random.choice(possible_classes)
            labels.append(0)
        class_members = indices_by_class[target_class]
        strip.append(class_members[random.randint(0, len(class_members) - 1)])
        strips.append(strip)

    n_examples = len(strips)
    height, width = X.shape[1], X.shape[2]
    X_processed = np.zeros([n_examples, strip_size + 1, height, width, 3])
    strip_idx = np.asarray(strips, dtype=np.intp).reshape(n_examples, strip_size + 1)
    # Vectorised gather, one strip position at a time to bound the temporary copy.
    for position in range(strip_size + 1):
        X_processed[:, position] = X[strip_idx[:, position]]
    y_processed = np.asarray(labels, dtype=np.float64)

    return X_processed, y_processed

In [None]:
def create_random_data_sets(X, y, Xt, yt, strip_size, training_size=30000, test_size=2000, RGB=False):
    """Generate train/validation/test strip datasets and print their label balance.

    ``training_size`` strips are generated from (X, y) and split 80/20 into
    training and validation sets; ``test_size`` strips are generated from
    (Xt, yt) for testing. ``RGB`` selects the 3-channel builder.

    Returns:
        Xtrain, Xval, Xtest, ytrain, yval, ytest
    """
    make_set = create_random_set_RGB if RGB else create_random_set

    X_full, y_full = make_set(X, y, strip_size, set_size=training_size)
    # Hold out 20% of the generated training strips for validation.
    Xtrain, Xval, ytrain, yval = train_test_split(X_full, y_full, test_size=0.2)
    # Drop the unsplit copy before building the test set to limit peak memory.
    del X_full, y_full
    Xtest, ytest = make_set(Xt, yt, strip_size, set_size=test_size)

    for split_label, split_y in (("Training", ytrain), ("Validation", yval), ("Test", ytest)):
        for value in (0, 1):
            print (split_label + " examples classified as " + str(value) + ": "
                   + str(np.count_nonzero(split_y == value)))

    return Xtrain, Xval, Xtest, ytrain, yval, ytest

In [None]:
def train_model(model, Xtrain, ytrain, Xval, yval, Xtest, ytest, lr=1e-3, batch_size=32, model_save_name="best_model", epochs=80, patience=10):
    """Train ``model`` with checkpointing and early stopping, then evaluate it.

    The model is (re)compiled with Adam(``lr``), binary cross-entropy and the
    accuracy metric, and fitted for at most ``epochs`` epochs. Callbacks:
      * ModelCheckpoint keeps the weights with the lowest ``val_loss`` in
        ``<cwd>/<model_save_name>/model``;
      * EarlyStopping stops after ``patience`` epochs without ``val_loss``
        improvement;
      * ReduceLROnPlateau multiplies the learning rate by 0.65 after 5 epochs
        without ``val_accuracy`` improvement (down to 1e-5).
    The best checkpoint of this run is restored before evaluating on the test set.

    Args:
        model: Keras model producing one prediction per example.
        Xtrain, ytrain: training inputs and 0/1 labels.
        Xval, yval: validation inputs and labels (drive the callbacks).
        Xtest, ytest: test inputs and labels used for the final evaluation.
        lr: initial Adam learning rate.
        batch_size: mini-batch size for ``fit``.
        model_save_name: checkpoint directory, relative to the CWD.
        epochs: maximum number of training epochs.
        patience: EarlyStopping patience in epochs.

    Returns:
        (test_accuracy, model): the accuracy reported by ``model.evaluate`` on
        (Xtest, ytest) and the trained model.
    """
    checkpoint_path = os.path.join(os.getcwd(), model_save_name, 'model')

    model_checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, monitor='val_loss',
        mode='min', verbose=0, save_best_only=True, save_weights_only=True)
    early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor='val_loss', mode='min', patience=patience, verbose=0)
    reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_accuracy', factor=0.65, patience=5, min_lr=1e-5)
    callbacks = [model_checkpoint_cb, early_stopping_cb, reduce_lr]

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=lr),
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    history = model.fit(Xtrain, ytrain,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_data=(Xval, yval),
                        callbacks=callbacks,
                        verbose=1)

    # With save_best_only and mode='min', a checkpoint is written only when
    # val_loss beats the previous best (initially +inf), i.e. only if some
    # epoch produced a finite val_loss. Otherwise, checkpoint_path may hold a
    # stale checkpoint from an earlier run with the same model_save_name (or
    # nothing at all), so keep the in-memory weights instead of loading it.
    val_losses = history.history.get('val_loss', [])
    if any(np.isfinite(v) for v in val_losses):
        model.load_weights(checkpoint_path)
    else:
        print("Warning: no checkpoint was saved during this run; "
              "evaluating the final in-memory weights.")

    score = model.evaluate(Xtest, ytest, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])

    return score[1], model

In [None]:
def see_embeddings(model, X, y, save_name, preprocess_func, RGB, embeddings_per_class=10, classes=10):
    """Plot encoder embeddings of randomly drawn images, one row per class, and save the figure.

    For each class ``c`` in ``range(classes)``, ``embeddings_per_class`` images
    of that class are drawn at random (with replacement) from ``X``. They are
    passed through ``model.encoder`` (after ``preprocess_func`` if one is
    given), and each resulting embedding is drawn as a grayscale column image.
    Row ``c`` of the grid shows the embeddings of class ``c``.

    Args:
        model: object exposing a callable ``encoder`` whose output has ``.numpy()``.
        X: images, shape (N, H, W, 3) if ``RGB`` else (N, H, W).
        y: integer labels aligned with ``X``.
        save_name: path passed to ``plt.savefig``.
        preprocess_func: optional callable applied to the image batch before encoding.
        RGB: whether ``X`` already has a trailing 3-channel axis.
        embeddings_per_class: images drawn per class (grid columns).
        classes: number of classes, labels ``0 .. classes-1`` (grid rows).
    """
    # Draw `embeddings_per_class` images of every class; the batch is ordered by class.
    indices_by_class = [np.where(y == c)[0] for c in range(classes)]
    sampled_idx = []
    for class_idx in indices_by_class:
        sampled_idx.extend(random.choices(class_idx, k=embeddings_per_class))

    images = np.asarray(X[sampled_idx], dtype=np.float64)
    if not RGB:
        images = images[..., np.newaxis]  # add the single channel axis

    encoder_input = preprocess_func(images) if preprocess_func else images
    encoded_images = model.encoder(encoder_input).numpy()

    plt.figure(figsize=(15, 30))
    for cls in range(classes):
        for k in range(embeddings_per_class):
            flat_idx = cls * embeddings_per_class + k
            ax = plt.subplot(classes, embeddings_per_class, flat_idx + 1)
            # Add a trailing axis so a 1-D embedding is shown as a column image.
            plt.imshow(encoded_images[flat_idx][..., np.newaxis])
            plt.title('Class: ' + str(cls), fontsize=16)
            plt.gray()
            ax.get_xaxis().set_visible(False)
            ax.get_yaxis().set_visible(False)
    plt.suptitle('Image embeddings per class', fontsize=20)
    plt.tight_layout()
    plt.subplots_adjust(top=0.95)
    plt.savefig(save_name)
    plt.show()

In [None]:
def save_to_db(db_name, column_name, list_to_save):
    """Write ``list_to_save`` as column ``column_name`` of ``<db_name>.csv``.

    The whole CSV is read, the column is added (or overwritten), and the table
    is written back without the DataFrame index.
    """
    csv_path = db_name + '.csv'
    table = pd.read_csv(csv_path)
    table[column_name] = list_to_save
    table.to_csv(csv_path, index=False)

def create_db(db_name):
    """Create (or overwrite) ``<db_name>.csv`` as an empty table.

    NOTE: index=True is kept deliberately. Presumably it makes the file
    non-empty, so later ``pd.read_csv`` calls succeed; the resulting
    'Unnamed: 0' column is skipped by plot_db_columns.
    """
    csv_path = db_name + '.csv'
    empty_table = pd.DataFrame()
    empty_table.to_csv(csv_path, index=True)

def plot_db_columns(db_name, title, xlabel, ylabel, save_name, x_start=2):
    """Plot every result column of ``<db_name>.csv`` as a line and save the figure.

    Row ``r`` of the table is plotted at x = ``x_start + r``. With the default
    ``x_start=2`` and the nine rows written by run_model (strip sizes 1..9),
    x runs from 2 to 10, i.e. strip_size + 1.

    Args:
        db_name: CSV path without the ``.csv`` extension.
        title, xlabel, ylabel: figure title and axis labels.
        save_name: output path without the ``.png`` extension.
        x_start: x value of the first row.
    """
    df = pd.read_csv(db_name + '.csv')
    # One x value per stored row, instead of assuming exactly nine rows.
    x = list(range(x_start, x_start + len(df)))

    fig, ax = plt.subplots(figsize=(9, 7))
    for column in df.columns:
        # Skip blank-header columns (pandas names them 'Unnamed: N'), e.g. the
        # index column that create_db writes.
        if str(column).startswith('Unnamed:'):
            continue
        ax.plot(x, df[column], label=column)

    ax.set_title(title, fontsize=20)
    ax.set_xlabel(xlabel, fontsize=16)
    ax.set_ylabel(ylabel, fontsize=16)
    ax.legend()
    ax.grid(axis='y', color='gray', linestyle='--', linewidth=0.5)
    ax.tick_params(labelsize=14)
    fig.savefig(save_name + '.png')
    plt.show()

In [None]:
def run_model(X, y, Xt, yt, model_class, latent_dim, model_save_name, db_name, column_name, preprocess_func, preprocess_before, num_iterations=3, RGB=True, pre_model=None, lr=1e-3):
    """Train and evaluate ``model_class`` for strip sizes 1..9 and store the accuracies.

    For each strip size, ``num_iterations`` fresh random datasets are generated
    with create_random_data_sets. A new model is trained on each with
    train_model, and the mean test accuracy is recorded. After the last
    iteration of each strip size, the encoder embeddings are plotted into
    ``embedding_images/``. Finally, the nine mean accuracies are written to
    column ``column_name`` of ``<db_name>.csv``.

    Args:
        X, y: images and labels the training/validation strips are built from.
        Xt, yt: images and labels the test strips are built from.
        model_class: called as ``model_class(latent_dim, channels, shape_in)``,
            with ``pre_model`` appended when one is given.
        latent_dim: embedding size passed to the model.
        model_save_name: checkpoint directory name and embedding-figure prefix.
        db_name: CSV table name (without extension) for the results.
        column_name: column of that table to write.
        preprocess_func: optional preprocessing callable (also passed to see_embeddings).
        preprocess_before: if true, apply ``preprocess_func`` to X and Xt up front.
        num_iterations: repetitions averaged per strip size.
        RGB: whether the images carry a trailing 3-channel axis.
        pre_model: optional extra constructor argument for ``model_class``.
        lr: Adam learning rate passed to train_model.
    """
    if preprocess_before:
        X_processed = preprocess_func(X)
        Xt_processed = preprocess_func(Xt)
    else:
        X_processed, Xt_processed = X, Xt

    # see_embeddings saves into this directory; make sure it exists so the run
    # does not crash after training the first strip size.
    embedding_dir = 'embedding_images'
    os.makedirs(embedding_dir, exist_ok=True)

    accuracy_per_strip_size = []
    for strip_size in range(1, 10):
        print('-------------------' + str(strip_size) + '-------------------')
        channels = strip_size + 1  # context images + target image
        accuracy_per_iteration = []
        for i in range(num_iterations):
            print('--------------Iteration ' + str(i+1) + '--------------')
            Xtrain, Xval, Xtest, ytrain, yval, ytest = create_random_data_sets(X_processed, y, Xt_processed, yt, strip_size, RGB=RGB)
            tf.keras.backend.clear_session()
            # Per-image shape (H, W, C) of the strips, behind a placeholder first entry.
            shape_in = (None,) + tuple(Xtrain[0].shape[1:4])
            extra_args = (pre_model,) if pre_model else ()
            model = model_class(latent_dim, channels, shape_in, *extra_args)
            score, model = train_model(model, Xtrain, ytrain, Xval, yval, Xtest, ytest, lr=lr, batch_size=32, model_save_name=model_save_name)
            if i == (num_iterations-1):
                embedding_file_path = os.path.join(embedding_dir, model_save_name + str(channels) + '.png')
                see_embeddings(model, X, y, embedding_file_path, preprocess_func, RGB)
            accuracy_per_iteration.append(score)
            # Release the large per-iteration arrays and the model before the next run.
            del Xtrain, Xval, Xtest, ytrain, yval, ytest, model
        accuracy_per_strip_size.append(np.mean(accuracy_per_iteration))
    save_to_db(db_name, column_name, accuracy_per_strip_size)

# Build models

In [None]:
# Create the output directory for the embedding figures; safe to re-run if it already exists.
os.makedirs('embedding_images', exist_ok=True)
# Start fresh result tables for the 32- and 10-dimensional latent-space experiments.
create_db('mnist_fashion_32')
create_db('mnist_fashion_10')

## Build models that combine the context images with an element-wise maximum
We test each model with two different encoder blocks:
- Models whose names end in **1** use the same encoder as in the MNIST experiments.
- Models whose names end in **2** use a new, deeper encoder designed for the Fashion-MNIST dataset.

### Models with the dot product

In [None]:
class DotConv1(keras.Model):
    """Cosine similarity between the max-combined context embedding and the target embedding.

    Every image of a strip is encoded by the small MNIST encoder: 3x3 convs
    with 1, 32 (stride 2) and 64 filters, each followed by BatchNorm + ReLU,
    then Flatten and Dense(latent_dim). The context embeddings (all but the
    last image) are combined with an element-wise maximum and compared with
    the last (target) embedding through a normalised dot product.

    NOTE(review): the output is a cosine similarity in [-1, 1] without a
    sigmoid. When trained with binary_crossentropy (as train_model does),
    Keras presumably clips predictions outside (0, 1), which zeroes their
    gradient — confirm this is intended.

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        # (filters, strides) of the three 3x3 convolution stages.
        conv_stages = [(1, 1), (32, 2), (64, 1)]
        encoder_layers = []
        for stage, (n_filters, stride) in enumerate(conv_stages):
            # Only the very first layer declares the per-image input shape.
            shape_kwargs = {'input_shape': self.shape_in[1:]} if stage == 0 else {}
            encoder_layers.extend([
                layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding="same", strides=stride, **shape_kwargs),
                layers.BatchNormalization(),
                layers.ReLU(),
            ])
        encoder_layers.extend([layers.Flatten(), layers.Dense(units=latent_dim)])
        self.encoder = tf.keras.Sequential(encoder_layers)

    def call(self, x):
        # Encode each image of the strip separately; assumes x is (batch, channels, H, W, C).
        embeddings = layers.TimeDistributed(self.encoder)(x)

        # Element-wise maximum over context positions 0 .. channels-2, folded
        # pairwise like the Maximum merge layer; one context image is used as-is.
        pooled = embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            pooled = tf.maximum(pooled, embeddings[:, position, :])

        target = embeddings[:, -1, :]
        return layers.Dot(axes=1, normalize=True)([pooled, target])

In [None]:
class DotConv2(keras.Model):
    """Cosine similarity between the max-combined context embedding and the target embedding,
    using the deeper Fashion-MNIST encoder.

    Encoder: pairs of 3x3 ReLU conv + BatchNorm layers with 32, 64, 128 and
    256 filters; 2x2 max-pooling after the 32- and 64-filter pairs; then
    Flatten and Dense(latent_dim).

    NOTE(review): the output is a cosine similarity in [-1, 1] without a
    sigmoid. When trained with binary_crossentropy (as train_model does),
    Keras presumably clips predictions outside (0, 1), which zeroes their
    gradient — confirm this is intended.

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        # Filter counts of the conv + BatchNorm steps; 'pool' inserts a 2x2 max-pooling layer.
        plan = [32, 32, 'pool', 64, 64, 'pool', 128, 128, 256, 256]
        stack = []
        for step in plan:
            if step == 'pool':
                stack.append(layers.MaxPool2D(pool_size=(2, 2)))
                continue
            # The first conv of the stack declares the per-image input shape.
            shape_kwargs = {} if stack else {'input_shape': self.shape_in[1:]}
            stack.append(layers.Conv2D(step, (3, 3), activation='relu', padding='same', **shape_kwargs))
            stack.append(layers.BatchNormalization())
        stack += [layers.Flatten(), layers.Dense(units=latent_dim)]
        self.encoder = tf.keras.Sequential(stack)

    def call(self, x):
        # Per-image embeddings; assumes x is (batch, channels, H, W, C).
        strip_embeddings = layers.TimeDistributed(self.encoder)(x)

        # Pairwise running maximum over the context positions (all but the last).
        context_max = strip_embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            context_max = tf.maximum(context_max, strip_embeddings[:, position, :])

        target_embedding = strip_embeddings[:, -1, :]
        return layers.Dot(axes=1, normalize=True)([context_max, target_embedding])

### Models with the dot product and a dense layer

In [None]:
class DotDenseConv1(keras.Model):
    """Cosine similarity of the max-combined context vs. target embedding, followed by a 1 -> 1 Dense layer.

    Encoder: the small MNIST encoder (3x3 convs with 1, 32 (stride 2) and 64
    filters, each followed by BatchNorm + ReLU, then Flatten and
    Dense(latent_dim)).

    NOTE(review): the final Dense(1) has no activation, so the output is not
    restricted to [0, 1]. When trained with binary_crossentropy (as
    train_model does), Keras presumably clips such predictions, which removes
    their gradient — confirm this is intended.

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        encoder_layers = []
        for n_filters, stride in ((1, 1), (32, 2), (64, 1)):
            # Declare the per-image input shape on the first convolution only.
            first = not encoder_layers
            shape_kwargs = {'input_shape': self.shape_in[1:]} if first else {}
            encoder_layers += [
                layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding="same", strides=stride, **shape_kwargs),
                layers.BatchNormalization(),
                layers.ReLU(),
            ]
        encoder_layers += [layers.Flatten(), layers.Dense(units=latent_dim)]
        self.encoder = tf.keras.Sequential(encoder_layers)

        # Learned affine map applied to the scalar similarity.
        self.classifier = tf.keras.Sequential([layers.Dense(1, input_shape=[1])])

    def call(self, x):
        # Per-image embeddings; assumes x is (batch, channels, H, W, C).
        strip_embeddings = layers.TimeDistributed(self.encoder)(x)

        # Pairwise running maximum over context positions 0 .. channels-2.
        context_max = strip_embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            context_max = tf.maximum(context_max, strip_embeddings[:, position, :])

        target_embedding = strip_embeddings[:, -1, :]
        similarity = layers.Dot(axes=1, normalize=True)([context_max, target_embedding])
        return self.classifier(similarity)

In [None]:
class DotDenseConv2(keras.Model):
    """Cosine similarity of the max-combined context vs. target embedding followed by a 1 -> 1 Dense layer,
    using the deeper Fashion-MNIST encoder.

    Encoder: pairs of 3x3 ReLU conv + BatchNorm layers with 32, 64, 128 and
    256 filters; 2x2 max-pooling after the 32- and 64-filter pairs; then
    Flatten and Dense(latent_dim).

    NOTE(review): the final Dense(1) has no activation, so the output is not
    restricted to [0, 1]. When trained with binary_crossentropy (as
    train_model does), Keras presumably clips such predictions, which removes
    their gradient — confirm this is intended.

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        # (filters, add 2x2 max-pool after the pair) for each conv pair.
        conv_pairs = [(32, True), (64, True), (128, False), (256, False)]
        stack = []
        for n_filters, pool_after in conv_pairs:
            for _ in range(2):
                shape_kwargs = {} if stack else {'input_shape': self.shape_in[1:]}
                stack.append(layers.Conv2D(n_filters, (3, 3), activation='relu', padding='same', **shape_kwargs))
                stack.append(layers.BatchNormalization())
            if pool_after:
                stack.append(layers.MaxPool2D(pool_size=(2, 2)))
        stack.append(layers.Flatten())
        stack.append(layers.Dense(units=latent_dim))
        self.encoder = tf.keras.Sequential(stack)

        # Learned affine map applied to the scalar similarity.
        self.classifier = tf.keras.Sequential([layers.Dense(1, input_shape=[1])])

    def call(self, x):
        # Encode each strip image independently; assumes x is (batch, channels, H, W, C).
        embeddings = layers.TimeDistributed(self.encoder)(x)

        # Element-wise maximum across the context images, folded pairwise.
        pooled = embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            pooled = tf.maximum(pooled, embeddings[:, position, :])

        target = embeddings[:, -1, :]
        cosine = layers.Dot(axes=1, normalize=True)([pooled, target])
        return self.classifier(cosine)

### Models with a dense classifier

In [None]:
class MaxClassifier1(keras.Model):
    """Dense sigmoid classifier on the stacked (max-combined context, target) embeddings.

    Encoder: the small MNIST encoder (3x3 convs with 1, 32 (stride 2) and 64
    filters, each followed by BatchNorm + ReLU, then Flatten and
    Dense(latent_dim)). The element-wise maximum of the context embeddings
    and the target embedding are stacked on a new last axis, flattened and
    fed to Dense(128, relu) -> Dense(64, relu) -> Dense(1, sigmoid).

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        conv_stages = [(1, 1), (32, 2), (64, 1)]  # (filters, strides)
        encoder_layers = []
        for stage, (n_filters, stride) in enumerate(conv_stages):
            shape_kwargs = {'input_shape': self.shape_in[1:]} if stage == 0 else {}
            encoder_layers.append(layers.Conv2D(filters=n_filters, kernel_size=(3, 3), padding="same", strides=stride, **shape_kwargs))
            encoder_layers.append(layers.BatchNormalization())
            encoder_layers.append(layers.ReLU())
        encoder_layers.append(layers.Flatten())
        encoder_layers.append(layers.Dense(units=latent_dim))
        self.encoder = tf.keras.Sequential(encoder_layers)

        head = [layers.Flatten()]
        head += [layers.Dense(units, activation='relu') for units in (128, 64)]
        head.append(layers.Dense(1, activation='sigmoid'))
        self.classifier = tf.keras.Sequential(head)

    def call(self, x):
        # Per-image embeddings; assumes x is (batch, channels, H, W, C).
        embeddings = layers.TimeDistributed(self.encoder)(x)

        # Pairwise running maximum over the context positions (all but the last).
        context_max = embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            context_max = tf.maximum(context_max, embeddings[:, position, :])

        target_embedding = embeddings[:, -1, :]
        # (batch, latent_dim, 2): context summary and target side by side.
        paired = tf.stack([context_max, target_embedding], axis=-1)
        return self.classifier(paired)

In [None]:
class MaxClassifier2(keras.Model):
    """Dense sigmoid classifier on the stacked (max-combined context, target) embeddings,
    using the deeper Fashion-MNIST encoder.

    Encoder: pairs of 3x3 ReLU conv + BatchNorm layers with 32, 64, 128 and
    256 filters; 2x2 max-pooling after the 32- and 64-filter pairs; then
    Flatten and Dense(latent_dim). Classifier: Flatten -> Dense(128, relu) ->
    Dense(64, relu) -> Dense(1, sigmoid).

    Args:
        latent_dim: size of each image embedding.
        channels: number of images per example (context images + 1 target).
        shape_in: input shape with a leading placeholder; ``shape_in[1:]`` is
            the per-image shape given to the encoder.
    """

    def __init__(self, latent_dim, channels, shape_in):
        super().__init__()
        self.latent_dim = latent_dim
        self.channels = channels
        self.shape_in = shape_in

        # Filter counts of the conv + BatchNorm steps; 'pool' inserts a 2x2 max-pooling layer.
        plan = [32, 32, 'pool', 64, 64, 'pool', 128, 128, 256, 256]
        stack = []
        for step in plan:
            if step == 'pool':
                stack.append(layers.MaxPool2D(pool_size=(2, 2)))
            else:
                shape_kwargs = {} if stack else {'input_shape': self.shape_in[1:]}
                stack.append(layers.Conv2D(step, (3, 3), activation='relu', padding='same', **shape_kwargs))
                stack.append(layers.BatchNormalization())
        stack += [layers.Flatten(), layers.Dense(units=latent_dim)]
        self.encoder = tf.keras.Sequential(stack)

        self.classifier = tf.keras.Sequential([
            layers.Flatten(),
            layers.Dense(128, activation='relu'),
            layers.Dense(64, activation='relu'),
            layers.Dense(1, activation='sigmoid'),
        ])

    def call(self, x):
        # Encode each strip image independently; assumes x is (batch, channels, H, W, C).
        strip_embeddings = layers.TimeDistributed(self.encoder)(x)

        # Element-wise maximum across the context images, folded pairwise.
        pooled = strip_embeddings[:, 0, :]
        for position in range(1, self.channels - 1):
            pooled = tf.maximum(pooled, strip_embeddings[:, position, :])

        target = strip_embeddings[:, -1, :]
        return self.classifier(tf.stack([pooled, target], axis=-1))

In [None]:
# Latent dimension 32: run every model once per strip size (num_iterations=1) and
# store its accuracies in a column named after the model class.
for model_cls in (DotConv1, DotConv2, DotDenseConv1, DotDenseConv2, MaxClassifier1, MaxClassifier2):
    run_model(Xtrain, ytrain, Xtest, ytest, model_cls, 32, model_cls.__name__, 'mnist_fashion_32',
              model_cls.__name__, None, False, num_iterations=1, RGB=False)

In [None]:
# Latent dimension 10: same experiment as above with smaller embeddings.
for model_cls in (DotConv1, DotConv2, DotDenseConv1, DotDenseConv2, MaxClassifier1, MaxClassifier2):
    run_model(Xtrain, ytrain, Xtest, ytest, model_cls, 10, model_cls.__name__, 'mnist_fashion_10',
              model_cls.__name__, None, False, num_iterations=1, RGB=False)

# Plot results

In [None]:
plot_db_columns(
    'mnist_fashion_32',
    title='Fashion MNIST model testing',
    xlabel='Strip size',
    ylabel='Accuracy',
    save_name='mnist_fashion_32',
)

In [None]:
plot_db_columns(
    'mnist_fashion_10',
    title='Fashion MNIST model testing',
    xlabel='Strip size',
    ylabel='Accuracy',
    save_name='mnist_fashion_10',
)