In [None]:
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers as kl
import matplotlib.pyplot as plt

In [None]:
def scatter_color(X, y, var=(0, 1), colors=None):
    """Scatter-plot two columns of ``X`` with one colour per distinct label in ``y``.

    A red, unfilled circle of radius 1 centred at the origin is also added to the
    current axes as a visual reference for the feature scale.

    Parameters
    ----------
    X : array-like of shape (N, P)
        Data to plot.
    y : array-like of shape (N,) or (N, 1)
        Label of each row of ``X``; it is flattened, and one ``plt.scatter`` call
        is made per unique label (in ``np.unique`` order).
    var : pair of int
        Column indices of ``X`` (each in ``0..P-1``) for the x and y axes.
    colors : sequence, optional
        One colour per unique label, in ``np.unique(y)`` order. When ``None`` or
        empty, a random RGB colour is drawn for each label.
    """
    X = np.asarray(X)
    # Flatten so the documented (N, 1) label shape works as a boolean row mask.
    y = np.asarray(y).ravel()
    classes = np.unique(y)
    # Build a fresh list on every call: the old mutable default ``colors=[]`` was
    # appended to, so colours leaked between calls and a later call with more
    # classes than the first raised IndexError.
    if colors is None or len(colors) == 0:
        colors = [np.random.rand(3) for _ in classes]
    for idx, label in enumerate(classes):
        mask = y == label
        # ``color=`` rather than ``c=``: a single RGB triple passed as ``c`` can be
        # mistaken for per-point values to colour-map when a class has 3 points.
        plt.scatter(X[mask, var[0]], X[mask, var[1]], color=colors[idx], label=str(label))
    ax = plt.gca()
    ax.add_patch(plt.Circle((0, 0), 1, color='r', fill=False))

In [None]:
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load MNIST, already split into train and test sets
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale images to the [0, 1] range
x_train = x_train.astype("float32") / 255
x_test = x_test.astype("float32") / 255

# Add a trailing channel axis, then check images really have shape `input_shape`
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)
assert x_train.shape[1:] == input_shape and x_test.shape[1:] == input_shape

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

In [None]:
# Metric-learning stage: only the first `num_samples` training images are
# available, and the network embeds each one into `feature_space` dimensions.
num_samples, feature_space = 512, 16

In [None]:
def get_models(input_shape, feature_space, augmentation=True, num_outputs=None):
    """Build an instance-discrimination classifier and its feature extractor.

    Architecture: [optional RandomTranslation] -> Flatten -> Dense(256, relu)
    -> Dense(64, relu) -> BatchNormalization -> Dense(feature_space) (linear
    features) -> Dense(num_outputs, softmax).

    Parameters
    ----------
    input_shape : tuple of int
        Shape of one input example, excluding the batch dimension.
    feature_space : int
        Width of the linear feature (embedding) layer.
    augmentation : bool
        If True, prepend a random translation of up to 10% of height/width
        (a Keras augmentation layer, intended to act during training only).
    num_outputs : int, optional
        Number of softmax outputs. When None, falls back to the notebook-level
        ``num_samples`` (the value the original version always used implicitly).

    Returns
    -------
    model : tf.keras.Model
        Input -> softmax over ``num_outputs`` classes.
    model_features : tf.keras.Model
        Input -> ``feature_space``-dim features; built from the same layers as
        ``model``, so it shares its weights.
    """
    if num_outputs is None:
        # Backward-compatible fallback to the global defined in the config cell.
        num_outputs = num_samples
    xin = tf.keras.Input(shape=input_shape)
    if augmentation:
        xaug = kl.RandomTranslation(height_factor=.1, width_factor=.1)(xin)
        # Adding other augmentations usually helps too, e.g. kl.RandomRotation(.1).
    else:
        xaug = xin
    x = kl.Flatten()(xaug)
    x = kl.Dense(256, 'relu')(x)
    x = kl.Dense(64, 'relu')(x)
    x = kl.BatchNormalization()(x)
    # Linear embedding; a unit-norm kernel constraint
    # (tf.keras.constraints.UnitNorm(axis=0)) is one variant worth trying.
    features = kl.Dense(feature_space)(x)
    outputs = kl.Dense(num_outputs, 'softmax')(features)
    model = tf.keras.Model(xin, outputs)
    # Use the captured feature tensor instead of model.layers[-2].output, which
    # silently points at the wrong layer if layers are added after the embedding.
    model_features = tf.keras.Model(xin, features)
    return model, model_features

In [None]:
# Seed NumPy's global RNG and TensorFlow's global seed before building the models
# so weight initialisation (and, presumably, other draws from these global RNGs
# such as augmentation, shuffling and random plot colours) repeats on
# Restart & Run All. GPU kernels may still add small nondeterminism.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)
model, model_features = get_models(input_shape=input_shape, feature_space=feature_space, augmentation=True)

In [None]:
# Full network (softmax head), then the feature extractor built from the
# same layers up to the embedding.
model.summary()
model_features.summary()

In [None]:
# Instance-discrimination pre-training: image i of the first `num_samples`
# training images gets label i, so the softmax head has `num_samples` classes.
epochs = 512
batch_size = 128
opt = tf.keras.optimizers.Adam(learning_rate=0.001)
model.compile(
    optimizer=opt,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
history = model.fit(
    x_train[:num_samples],
    np.arange(num_samples),
    epochs=epochs,
    batch_size=batch_size,
)

In [None]:
# Embed the `num_samples` images the model was trained on. training=False is
# passed explicitly so the random augmentation layer is guaranteed to be off;
# newer Keras versions default such layers to training=True on direct calls.
x_features = model_features(x_train[:num_samples], training=False).numpy()
x_features.shape

In [None]:
# Visualise the first two feature coordinates, coloured by the true digit.
# Digit labels were not used for this training (labels were sample indices),
# so any grouping by colour comes from the instance-discrimination objective.
scatter_color(x_features, y_train[:num_samples], var=[0, 1])
plt.xlabel("feature 0")
plt.ylabel("feature 1")
plt.title(f"Learned features of {num_samples} training images, coloured by digit")
plt.show()

In [None]:
# Is this feature space good for classification? Linear probe: a single
# softmax layer on top of the `feature_space`-dimensional features.
# `shape` must be a tuple: `(feature_space)` is just an int, which Keras 3's
# `Input` rejects.
xin = tf.keras.Input(shape=(feature_space,))
xclass = kl.Dense(num_classes, 'softmax')(xin)
model_classification = tf.keras.Model(xin, xclass)



In [None]:
# Linear probe: `feature_space` inputs -> `num_classes` softmax outputs.
model_classification.summary()

In [None]:
# Fresh Adam optimiser (lr = 0.01) for the linear probe; `opt` is rebound here.
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
model_classification.compile(
    optimizer=opt,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)


In [None]:
# Embed the full train/test sets with the (not further trained) feature
# extractor, then fit the linear probe on those embeddings. `predict` runs the
# forward pass in batches, in inference mode, and returns NumPy arrays. The
# original made one direct call over each whole set.
features_train = model_features.predict(x_train, batch_size=batch_size, verbose=0)
features_test = model_features.predict(x_test, batch_size=batch_size, verbose=0)
history = model_classification.fit(
    features_train,
    y_train,
    batch_size=batch_size,
    validation_data=(features_test, y_test),
    epochs=80,
)