# tf.Similarity Visualization Demo: MNIST

## Background

**tf.similarity**

TensorFlow Similarity (tf.similarity) is a soon-to-be open-sourced package for TensorFlow that makes training and deploying similarity models easier. More information about TensorFlow Similarity can be found in its [design doc](https://docs.google.com/document/d/1fEUrWd-XGIHeUoerPPazpKtZVeceBGgyXAuvY41Y2xc/edit#) and [user guide](https://docs.google.com/document/d/1Cx6E1o5o-0wEngNtKhYr1OaKRwKPZAiwRZWyhQJLOc8/edit).

**MNIST** 

**Data Type:** Images, handwritten-digits, black-and-white

**Number of Classes:** 10

**Description:** Dataset of 60,000 28x28 grayscale images of the 10 digits, along with a test set of 10,000 images. This is also the dataset used in the [Embedding Learner Experiments](https://docs.google.com/document/d/1J8miO0KEu9tjPTeWSubtItPAIAgJ9SGbHKJFwGuzIxI/edit#). It makes zero/one/few-shot learning easy to understand: we can remove either the odd or the even digit samples at training time, and then check at validation/test time whether the model is able to recognize the removed digits. Some research papers use MNIST as a benchmark, but this dataset may be too simple.
See also the [tf.similarity MNIST experiment](https://security-and-privacy-group-research.googlesource.com/similarity/+/refs/heads/master/moirai/experiments/mnist/).

**Single-shot learning**

Single-shot learning refers to a type of machine learning problem in which we only have a few examples of labeled data. During training we therefore provide only a few labeled examples to the model, and at test time the model tries to classify unlabeled data or to find the most similar unlabeled items.
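
As a rough illustration of the test-time step, the sketch below labels each query embedding with the label of its nearest labeled example. It is a minimal NumPy sketch under stated assumptions, not the tf.similarity implementation: the function name `nearest_target_labels` and the toy 2-D embeddings are hypothetical, and Euclidean distance is assumed.

```python
import numpy as np

def nearest_target_labels(query_embeddings, target_embeddings, target_labels):
    """Label each query with the label of its closest target embedding."""
    # Pairwise Euclidean distances, shape (num_queries, num_targets).
    diffs = query_embeddings[:, None, :] - target_embeddings[None, :, :]
    distances = np.linalg.norm(diffs, axis=-1)
    # Index of the closest target for every query.
    return target_labels[np.argmin(distances, axis=1)]

# Toy 2-D embeddings (hypothetical): one labeled example per class.
target_embeddings = np.array([[0.0, 0.0], [10.0, 10.0]])
target_labels = np.array([3, 7])

# Unlabeled queries to classify.
queries = np.array([[1.0, -1.0], [9.0, 11.0], [0.5, 0.5]])
predicted = nearest_target_labels(queries, target_embeddings, target_labels)
print(predicted)
```

With a well-trained embedding, a single labeled example per class can be enough for this lookup to work, which is what makes the single-shot setting feasible.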

## Setup and Import

In [None]:
# this block will not be necessary soon
try:
    # %tensorflow_version only exists in Colab.
    %tensorflow_version 2.x
except Exception:
    pass

In [None]:
# uncomment the following lines to install the required packages
# !pip install umap-learn
# !pip install plotly
# !pip install altair
# !pip install MulticoreTSNE
# !pip install -U altair vega_datasets notebook vega

### Download TSNE-CUDA for embedding visualization, as the TensorBoard Embedding Projector does not work in Colab

In [None]:
# download and unpack tsnecuda from anaconda.org

# uncomment this cell if running on Colab; otherwise run these commands on the command line
'''
!wget https://anaconda.org/CannyLab/tsnecuda/2.1.0/download/linux-64/tsnecuda-2.1.0-cuda100.tar.bz2
!tar xvjf tsnecuda-2.1.0-cuda100.tar.bz2
!cp -r site-packages/* /usr/local/lib/python3.6/dist-packages/

# create a symbolic link between the downloaded libfaiss.so file and the location python's looking at

!echo $LD_LIBRARY_PATH 
# this is probably /usr/lib64-nvidia

!ln -s /content/lib/libfaiss.so $LD_LIBRARY_PATH/libfaiss.so
'''

### Tensorflow imports

In [None]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Conv2D, Dense, Dropout, Flatten,
                                     Input, MaxPooling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

### Tensorflow similarity imports

In [None]:
from tensorflow_similarity.api.engine.preprocessing import Preprocessing
from tensorflow_similarity.api.engine.simhash import SimHash

### Other imports

In [None]:
import numpy as np
from numba import cuda
import altair as alt

# enable us to visualize more than 5,000 items for Altair
alt.data_transformers.disable_max_rows()

# do not need the below line for colab
alt.renderers.enable('notebook')

## Helper methods

In [None]:
from sklearn.metrics import pairwise_distances
from collections import defaultdict
import numpy as np

def read_mnist_data(zero_shot=False, collapse=True):
    """Returns MNIST training, testing, and target data.

    Loads MNIST through Keras. The per-class medoid of the test
    split is held out as the target set, and the remaining test
    images form the testing set.

    Args:
        zero_shot: If True, keep only even digits in the training set.
        collapse: If True and zero_shot is False, relabel the training
            digits modulo 5, so that digits d and d + 5 share a class.

    Returns:
        A tuple of three elements: the training data, the testing
        data, and the target data. Each element is itself a tuple
        of two items: a dictionary that maps "example" to an
        np array of 28x28 images, and an np array of class labels.
    """
    
    (x_train, y_train), (x_test_raw, y_test_raw) = mnist.load_data()
    
    
    if zero_shot:
        # train on only even digits and not odd digits
        keep = y_train % 2 == 0
        x_train = x_train[keep]
        y_train = y_train[keep]
    elif collapse:
        # collapse the ten digits into five classes (0/5, 1/6, ...)
        y_train = y_train % 5
    
        

    x_tests = []
    y_tests = []

    x_targets = []
    y_targets = []

    test_dicts = defaultdict(list)
    for x, y in zip(x_test_raw, y_test_raw):
        test_dicts[y].append(np.array(x).flatten())
        
    for label in test_dicts:
        label_test_raw = np.array(test_dicts[label])
        
        # find the medoid for each label
        distances = pairwise_distances(label_test_raw, label_test_raw)
        med_idx = np.argmin(distances.sum(axis=0))
        med = label_test_raw[med_idx]
        x_targets.append(med.reshape((28,28)))
        y_targets.append(label)
        label_test_raw = np.delete(label_test_raw, med_idx, axis=0)
        x_tests.extend(label_test_raw.reshape((len(label_test_raw), 28, 28)))
        labels = [label] * len(label_test_raw)
        y_tests.extend(labels)
        

    return (({
        "example": np.array(x_train)
    }, np.array(y_train)), ({
        "example": np.array(x_tests)
    }, np.array(y_tests)), ({
        "example": np.array(x_targets)
    }, np.array(y_targets)))
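
The medoid step in `read_mnist_data` can be illustrated on a toy example. The sketch below uses plain NumPy instead of scikit-learn's `pairwise_distances`, and the five points are made up for illustration; the selection rule (the point with the smallest summed distance to all others) mirrors the helper above.

```python
import numpy as np

# Toy "class" of five 2-D points (hypothetical data); the last is an outlier.
points = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0], [4.0, 0.0], [10.0, 0.0]])

# Pairwise Euclidean distances, shape (5, 5).
diffs = points[:, None, :] - points[None, :, :]
distances = np.linalg.norm(diffs, axis=-1)

# The medoid is the actual data point with the smallest total distance
# to every other point in the class.
medoid_idx = int(np.argmin(distances.sum(axis=0)))
medoid = points[medoid_idx]

# The remaining points play the role of the test set for this class.
remaining = np.delete(points, medoid_idx, axis=0)
print(medoid_idx, medoid, remaining.shape)
```

Unlike a mean, the medoid is always one of the original samples, so it can be displayed as a real image and is less affected by outliers.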

In [None]:
def model_fn():
    """A simple tower model for mnist dataset.
    
    Returns:
        model: A tensorflow model.
    """
    
    i = Input(shape=(28, 28, 1), name="example")
    o = Conv2D(
        32,
        kernel_size=(5, 5),
        padding='same',
        activation='relu',
        input_shape=(28, 28, 1))(i)
    # stack the second convolution on the output of the first one
    o = Conv2D(
        32,
        kernel_size=(5, 5),
        padding='same',
        activation='relu')(o)
    o = MaxPooling2D(pool_size=(2, 2))(o)
    o = Dropout(.25)(o)

    o = Conv2D(64, (3, 3), padding='same', activation='relu')(o)
    o = Conv2D(64, (3, 3), padding='same', activation='relu')(o)
    o = MaxPooling2D(pool_size=(2, 2))(o)
    o = Dropout(.25)(o)

    o = Flatten()(o)
    o = Dense(256, activation="relu")(o)
    o = Dropout(.25)(o)
    o = Dense(32)(o)
    model = Model(inputs=i, outputs=o)
    return model

In [None]:
class Normalize(Preprocessing):
    """A Preprocessing class that normalizes the MNIST example inputs."""
    
    def preprocess(self, img):
        """Normalizes and reshapes the input image."""
        
        normed = img["example"] / 255.0
        normed = normed.reshape((28, 28, 1))
        out = {"example": normed}
        return out


In [None]:
def run_mnist_example(data, model, strategy, epochs):
    """An example usage of tf.similarity.

    This basic similarity run first unpacks the training, testing,
    and target data from the arguments, then constructs a simple
    moirai model around the tower model and fits it on the
    training data.

    Args:
        data: Tuple, contains the training, testing, and target datasets.
        model: tf.Model, the tower model to fit into moirai.
        strategy: String, the strategy to use for mining triplets.
        epochs: Integer, number of epochs to fit our moirai model.

    Returns:
        tf_similarity_model: The fitted tf.similarity model instance.
    """
        
    (x_train, y_train), (x_test, y_test), (x_targets, y_targets) = data

    tf_similarity_model = SimHash(
        model,
        preprocessing=Normalize(),
        strategy=strategy,
        optimizer=Adam(learning_rate=.001))
    
    tf_similarity_model.fit(
        x_train,
        y_train,
        epochs=epochs)

    return tf_similarity_model
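
The `strategy` argument above selects how triplets are mined and scored. As a rough illustration of what a triplet loss optimizes, here is a generic textbook formulation in NumPy. It is not necessarily what tf.similarity computes internally: the function name `triplet_loss`, the squared Euclidean distance, and the `margin` value are assumptions made for this sketch.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Hinge-style triplet loss on batches of embeddings (illustrative only)."""
    # Squared Euclidean distance from each anchor to its positive and negative.
    pos_dist = np.sum((anchor - positive) ** 2, axis=-1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=-1)
    # Zero loss once the negative is at least `margin` farther than the positive.
    return np.maximum(pos_dist - neg_dist + margin, 0.0)

# Two toy triplets (hypothetical embeddings).
anchor = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[1.0, 0.0], [1.0, 0.0]])
negative = np.array([[3.0, 0.0], [1.0, 0.5]])

losses = triplet_loss(anchor, positive, negative)
print(losses)
```

The first triplet already satisfies the margin and contributes no loss; the second has a negative that is too close to the anchor, so it yields a positive loss and would drive a gradient update.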

### Get Embeddings for test and target dataset when we only train on even digits (zero-shot)

In [None]:
#@title Parameters for tensorflow similarity model
data = read_mnist_data(zero_shot=True)
model = model_fn()
# Strategy we want to use.
strategy = "triplet_loss" #@param ["triplet_loss", "quadruplet_loss", "stable_quadruplet_loss"]
# Number of epochs
epochs = 2 #@param {type:"integer"}

similarity_model = run_mnist_example(data, model, strategy, epochs)

(x_train, y_train), (x_test, y_test), (x_targets, y_targets) = data
zero_shot_test_embeddings = similarity_model.predict(x_test)
zero_shot_targets_embeddings = similarity_model.predict(x_targets)

In [None]:
# only need to run this cell in an IPython Notebook, not on Colab
cuda.select_device(0)
cuda.close()

## Visualizations

### Read in MNIST Data

In [None]:
(x_train, y_train), (x_test, y_test), (x_targets, y_targets) = read_mnist_data()
# the value inside x_test's example key is our images
images = x_test["example"]

### Import Visualization Methods

In [None]:
from tensorflow_similarity.visualization import *

### Visualization for Zero Shot on MNIST

The visualizations below show how well the model clusters the odd digits, given that it was trained only on even digits.
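
To put a number on the unseen classes, one option is to build a confusion matrix from true and predicted labels and read off the accuracy on the odd digits only. The sketch below is a hedged illustration with made-up labels; `confusion_counts` is a hypothetical helper and is not part of `tensorflow_similarity.visualization`.

```python
import numpy as np

def confusion_counts(true_labels, predicted_labels, num_classes=10):
    """Count (true, predicted) label pairs; rows are true classes."""
    counts = np.zeros((num_classes, num_classes), dtype=int)
    # np.add.at accumulates correctly even when index pairs repeat.
    np.add.at(counts, (true_labels, predicted_labels), 1)
    return counts

# Made-up labels: the digit 1 is misclassified once as 7.
true_labels = np.array([1, 1, 3, 3, 2, 2])
predicted_labels = np.array([1, 7, 3, 3, 2, 2])

counts = confusion_counts(true_labels, predicted_labels)

# Accuracy restricted to the odd digits, i.e. the classes unseen in training.
odd = np.arange(1, 10, 2)
odd_accuracy = counts[odd, odd].sum() / counts[odd].sum()
print(odd_accuracy)
```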

In [None]:
title = "Confusion Matrix of nearest neighbor search trained using tf similarity"
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
figure = plot_confusion_matrix(zero_shot_test_embeddings, y_test, zero_shot_targets_embeddings, y_targets, title, classes=classes)
figure

In [None]:
title = "2D Embedding Projector of MNIST dataset"
figure = plot_embedding_projector(zero_shot_test_embeddings, y_test, title=title)
figure

In [None]:
title = "Nearest neighbors of targets trained using tf similarity"
N = 5
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
figure = plot_nearest_neighbors(zero_shot_test_embeddings, y_test, zero_shot_targets_embeddings, y_targets, x_test["example"], x_targets["example"], title, N, classes=classes)
figure

In [None]:
title = "Nearest neighbors table of targets trained using tf similarity"
N = 8
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
figure = plot_nearest_neighbors_table(zero_shot_test_embeddings, y_test, zero_shot_targets_embeddings, y_targets, title, N, classes=classes)
figure

In [None]:
title = "Distance histogram of nearest neighbor search trained using tf similarity"
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
figure = plot_distance_histograms(zero_shot_test_embeddings, y_test, zero_shot_targets_embeddings, y_targets, title, classes=classes)
figure

In [None]:
title = "Violin plots of distances trained using tf similarity"
classes = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
interactive = False
figure = plot_distance_violins(zero_shot_test_embeddings, y_test, zero_shot_targets_embeddings, y_targets, title, interactive=interactive, classes=classes)
figure