This notebook trains an MNIST similarity model (triplet loss) with TensorFlow Similarity.

In [6]:
# imports
import numpy as np
import six
import tensorflow as tf
from absl import app, flags
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten, Input,
                                     MaxPooling2D, Reshape, UpSampling2D, Lambda)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import tabulate
from tensorflow_similarity.api.engine.preprocessing import Preprocessing
from tensorflow_similarity.api.engine.simhash import SimHash
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tqdm.auto import tqdm

In [16]:
# Notebook-level augmentation generator; read_mnist_data uses it by default.
(x_train, y_train), (original_x_test, original_y_test) = tf.keras.datasets.mnist.load_data()
# Initialize the augmentation generator.
# horizontal_flip is disabled: a mirrored digit is generally not a valid example
# of its label, so flipping would inject mislabeled training images.
datagen = ImageDataGenerator(
  rotation_range=40,
  width_shift_range=0.2,
  height_shift_range=0.2,
  shear_range=0.2,
  zoom_range=0.2,
  horizontal_flip=False,
  fill_mode='nearest')

# NOTE: per the Keras docs, fit() is only required when featurewise_center,
# featurewise_std_normalization or zca_whitening is enabled. None are set here,
# so this call does not change the generated images; it is kept so enabling
# those options later works without further edits.
datagen.fit(x_train[..., np.newaxis])

In [17]:
# Cache filename passed to tf.keras.datasets.mnist.load_data(path=...).
# Keras resolves relative paths under ~/.keras/datasets (NOT the notebook's
# working directory) and downloads the file there if it is missing.
DEFAULT_MNIST_DATA_PATH = "./mnist.npz"

In [26]:
def read_mnist_data(data_path, num_augment=8, generator=None, seed=None):
    """Load MNIST and split it into training, test and target sets.

    Training data keeps only the even digits (0, 2, 4, 6, 8) of the MNIST
    training split. The MNIST test split is not filtered. It is divided into
    targets (the first example of each label, in the order labels first
    appear) and tests (every remaining example).

    Args:
        data_path: Path forwarded to ``tf.keras.datasets.mnist.load_data``.
            Keras resolves relative paths under ``~/.keras/datasets``.
        num_augment: Number of augmented images to generate per (even-digit)
            training image. When > 0, the training set is REPLACED by
            ``len(x_train) * num_augment`` augmented images; the original
            images are not kept. 0 disables augmentation.
        generator: ``ImageDataGenerator`` used for augmentation. Defaults to
            the notebook-level ``datagen``.
        seed: Optional seed forwarded to ``generator.flow`` so augmentation
            is reproducible. ``None`` keeps the previous unseeded behavior.

    Returns:
        A 3-tuple ``(train, test, targets)``. Each element is a pair
        ``({"example": images}, labels)``. Here ``images`` has shape
        ``(n, 28, 28, 1)`` and is divided by 255.0, and ``labels`` is a 1-D
        array of digit labels. Its dtype is the same as the loaded labels,
        including after augmentation.
    """
    if generator is None:
        generator = datagen  # notebook-level ImageDataGenerator

    (x_train, y_train), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data(path=data_path)

    # Train on even digits only; odd digits never appear in training.
    even_mask = y_train % 2 == 0
    x_train = x_train[even_mask][..., np.newaxis]
    y_train = y_train[even_mask]

    # The first occurrence of each label becomes a target; the rest are tests.
    _, first_idx = np.unique(y_test_raw, return_index=True)
    is_target = np.zeros(len(y_test_raw), dtype=bool)
    is_target[first_idx] = True
    x_test_raw = x_test_raw[..., np.newaxis]
    # Boolean masking preserves the original order of the test split.
    x_targets, y_targets = x_test_raw[is_target], y_test_raw[is_target]
    x_tests, y_tests = x_test_raw[~is_target], y_test_raw[~is_target]

    if num_augment > 0 and len(x_train) > 0:
        print("performing data augmentation")
        total = len(x_train) * num_augment
        # float32 halves memory versus NumPy's float64 default for a buffer
        # that holds num_augment copies of the training set.
        new_x_train = np.zeros((total, 28, 28, 1), dtype=np.float32)
        # Keep integer labels; a default np.zeros would turn them into floats.
        new_y_train = np.zeros(total, dtype=y_train.dtype)

        image_generator = generator.flow(x=x_train, y=y_train, batch_size=64 * 4, seed=seed)

        idx = 0
        with tqdm(total=total) as pbar:
            while idx < total:
                x_batch, y_batch = next(image_generator)
                # Slice the batch too, so a final partial copy cannot cause a
                # shape mismatch on assignment.
                n = min(len(x_batch), total - idx)
                new_x_train[idx:idx + n] = x_batch[:n]
                new_y_train[idx:idx + n] = y_batch[:n]
                idx += n
                pbar.update(n)

        x_train, y_train = new_x_train, new_y_train

    return (({
        "example": x_train / 255.0
    }, y_train), ({
        "example": x_tests / 255.0
    }, y_tests), ({
        "example": x_targets / 255.0
    }, y_targets))

In [27]:
def model_fn(input_shape=(28, 28, 1), embedding_size=100, dropout_rate=.25):
    """Build a small convolutional tower that maps images to unit-norm embeddings.

    Architecture:
      * Two 5x5 Conv2D layers with 32 filters each, then 2x2 max-pooling and
        dropout.
      * Two 3x3 Conv2D layers with 64 filters each, then 2x2 max-pooling and
        dropout.
      * A 256-unit ReLU dense layer, then dropout.
      * A linear dense layer of ``embedding_size`` units.
      * L2 normalization along axis 1.

    Args:
        input_shape: Shape of a single input image. The input layer is named
            "example", matching the dict key produced by read_mnist_data.
        embedding_size: Width of the final embedding. Defaults to 100.
        dropout_rate: Dropout rate applied after each pooling stage and after
            the 256-unit dense layer. Defaults to 0.25.

    Returns:
        model: A ``tf.keras.Model`` whose output is the L2-normalized embedding.
    """
    i = Input(shape=input_shape, name="example")
    # In the functional API the shape comes from Input; passing input_shape to
    # later layers is ignored, so it is omitted here.
    o = Conv2D(32, kernel_size=(5, 5), padding='same', activation='relu')(i)
    o = Conv2D(32, kernel_size=(5, 5), padding='same', activation='relu')(o)
    o = MaxPooling2D(pool_size=(2, 2))(o)
    o = Dropout(dropout_rate)(o)

    o = Conv2D(64, (3, 3), padding='same', activation='relu')(o)
    o = Conv2D(64, (3, 3), padding='same', activation='relu')(o)
    o = MaxPooling2D(pool_size=(2, 2))(o)
    o = Dropout(dropout_rate)(o)

    o = Flatten()(o)
    o = Dense(256, activation="relu")(o)
    o = Dropout(dropout_rate)(o)
    o = Dense(embedding_size)(o)
    # Unit-normalize each embedding (row-wise).
    o = Lambda(lambda x: tf.math.l2_normalize(x, axis=1), name="l2_norm")(o)
    model = Model(inputs=i, outputs=o)
    return model

In [28]:
def run_mnist_example(data, model, strategy, augment, autoencoder, epochs, prewarm_epochs,
                      save_path='augmented_mnist_model.h5'):
    """Train and evaluate a tf.similarity SimHash model on prepared MNIST data.

    The function proceeds in four steps:
      1. Unpack the training, test and target sets.
      2. Wrap ``model`` in a SimHash ("moirai") model and fit it on the
         training data.
      3. Optionally save the model to ``save_path``.
      4. Evaluate it on the test set against the targets.

    Args:
        data: Tuple ``(train, test, targets)``; each element is an
            ``(x, y)`` pair, as produced by read_mnist_data.
        model: tf.keras.Model, the tower (embedding) model to wrap.
        strategy: String, the SimHash strategy (e.g. "triplet_loss").
        augment: Boolean. Not used here: augmentation is decided when the data
            is loaded (read_mnist_data's ``num_augment``). Kept for interface
            compatibility.
        autoencoder: Boolean. When False, ``prewarm_epochs`` is forced to 0.
        epochs: Integer, number of training epochs.
        prewarm_epochs: Integer, number of prewarm epochs. Only used when
            ``autoencoder`` is True.
        save_path: File path passed to ``moirai.save``. A falsy value skips
            saving. Defaults to 'augmented_mnist_model.h5'.

    Returns:
        A tuple ``(metrics, moirai)``:
          * ``metrics`` is the result of ``moirai.evaluate`` (presumably a
            dict of metric name -> values; confirm against the SimHash API).
          * ``moirai`` is the trained SimHash model.
    """
    (x_train, y_train), (x_test, y_test), (x_targets, y_targets) = data

    moirai = SimHash(
        model,
        strategy=strategy,
        # `lr` is a deprecated alias that newer Keras versions reject.
        optimizer=Adam(learning_rate=.001))

    moirai.fit(
        x_train,
        y_train,
        prewarm_epochs=prewarm_epochs if autoencoder else 0,
        epochs=epochs)

    if save_path:
        moirai.save(save_path)

    metrics = moirai.evaluate(x_test, y_test, x_targets, y_targets)
    return metrics, moirai

In [29]:
# Strategy we want to use.
strategy = "triplet_loss"
# Whether we want to augment our data. True matches read_mnist_data's default
# behavior (num_augment=8), which this notebook previously used implicitly.
augment = True
# Number of augmented images generated per training image when augment is True.
num_augment = 8
# Whether or not we want to use an auxiliary autoencoder task.
autoencoder = False
# Number of epochs
epochs = 5
# Number of prewarm epochs (only used when autoencoder is True)
prewarm_epochs = 0

# Wire the augment flag into data loading so it actually takes effect.
data = read_mnist_data(DEFAULT_MNIST_DATA_PATH, num_augment=num_augment if augment else 0)
tower_model = model_fn()

test_metrics, similar_model = run_mnist_example(data, tower_model, strategy, augment, autoencoder, epochs, prewarm_epochs)

performing data augmentation


HBox(children=(IntProgress(value=0, max=235936), HTML(value='')))

Model: "triplet_loss"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
anchor_example (InputLayer)     [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
anchor_idx (InputLayer)         [(None, 1)]          0                                            
__________________________________________________________________________________________________
neg_example (InputLayer)        [(None, 28, 28, 1)]  0                                            
__________________________________________________________________________________________________
neg_idx (InputLayer)            [(None, 1)]          0                                            
_______________________________________________________________________________________