In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Fingerprinting hyper-parameters.
FP_NUM = 4        # number of fingerprint perturbations (rows of dx)
EPS = 0.1         # each dx entry is drawn uniformly from [-EPS, EPS)
alpha = 0.25      # target output delta for the wrong classes (stored as -alpha)
beta = 0.75       # target output delta for the true class
NUM_CLASS = 10    # number of MNIST digit classes
SEED = 0          # seeds numpy's global RNG so the perturbations dx are reproducible

np.random.seed(SEED)

# MNIST train/test splits; add a trailing channel axis for the Conv2D layers.
(features_tr, labels_tr), (features_tst, labels_tst) = tf.keras.datasets.mnist.load_data()
features_tr = features_tr[..., np.newaxis]
features_tst = features_tst[..., np.newaxis]

In [None]:
# Fingerprint input perturbations: FP_NUM arrays shaped like one training image,
# each entry uniform in [-EPS, EPS).
dx = (np.random.rand(FP_NUM, *features_tr.shape[1:]) - .5) * 2 * EPS

# Target output deltas per true class (row): beta for the class itself,
# -alpha for every other class.
dy_train = np.full((NUM_CLASS, NUM_CLASS), -alpha)
np.fill_diagonal(dy_train, beta)

In [None]:
# Define the models
def nn(input_shape):
    """Build the convolutional classifier used for the clean and fingerprint passes.

    Architecture: two [Conv2D(5x5, ReLU) -> BatchNormalization -> MaxPool2D]
    stages with 32 and 64 filters, then two 200-unit ReLU Dense layers and a
    linear Dense output layer with NUM_CLASS units.

    Args:
        input_shape: shape of a single input example, without the batch axis.

    Returns:
        A ``tf.keras.Sequential`` model whose outputs are raw logits (no softmax).
    """
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(
                32,
                (5,5),
                activation = 'relu',
                input_shape = input_shape
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPool2D(),

        # Only the first layer takes input_shape; Keras infers the rest.
        tf.keras.layers.Conv2D(
                64,
                (5,5),
                activation = 'relu'
        ),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPool2D(),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(200, activation='relu'),
        tf.keras.layers.Dense(200, activation='relu'),
        tf.keras.layers.Dense(NUM_CLASS, activation=None)
    ])
    return model

def print_loss(epoch, data_loss, fp_loss, acc, sess, feed_dict, total_loss=None):
    """Evaluate the given loss/accuracy tensors and print them on one line.

    Args:
        epoch: epoch number shown at the start of the printed line.
        data_loss: tensor for the data (cross-entropy) loss.
        fp_loss: tensor for the fingerprint loss.
        acc: tensor for the accuracy.
        sess: session used to evaluate the tensors.
        feed_dict: feed dictionary passed to ``sess.run``.
        total_loss: tensor for the total training loss. When omitted, the
            notebook-level ``loss`` tensor is used, matching earlier behavior.
    """
    if total_loss is None:
        total_loss = loss  # notebook-level training objective
    # Evaluate the tensors the caller passed in (previously these arguments were
    # ignored in favour of hard-coded globals).
    test_data_loss, test_fp_loss, test_loss, test_acc = sess.run(
        [data_loss, fp_loss, total_loss, acc],
        feed_dict=feed_dict
    )
    print('epoch', epoch, 'data loss %.6f fp loss %.6f total loss: %.6f accuracy: %.3f'
          % (test_data_loss, test_fp_loss, test_loss, test_acc))

def normalize(logits):
    """Scale each class-score vector to unit L2 norm along the last axis.

    Works for any rank. In this notebook it is applied to ``logits``
    (batch, NUM_CLASS) and to ``fp_logits`` (FP_NUM, batch, NUM_CLASS).

    Args:
        logits: float tensor whose last axis holds the class scores.

    Returns:
        Tensor of the same shape, each last-axis vector divided by its L2 norm.
        An all-zero vector yields NaN (0 / 0).
    """
    # keepdims keeps the reduced axis so the division broadcasts for any rank.
    # Reducing over axis=1 would sum over the batch axis for rank-3 input.
    return logits / tf.sqrt(tf.reduce_sum(logits ** 2, axis=-1, keepdims=True))


In [None]:
tf.reset_default_graph()

# Build the graph.
# Full arrays are fed through placeholders when the iterator is initialized;
# the tf.data iterator then yields them in batches of `batch_size`.
features = tf.placeholder(tf.float32, (None,) + features_tr.shape[1:])
labels = tf.placeholder(tf.int64, [None])
dx_tf = tf.constant(dx, dtype=tf.float32)
dy_tf = tf.constant(dy_train, dtype=dx_tf.dtype)
batch_size = tf.placeholder(tf.int64)

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
batched_dataset = dataset.batch(batch_size)

iterator = batched_dataset.make_initializable_iterator()

features_batch, labels_batch = iterator.get_next()

labels_oh = tf.one_hot(labels_batch, NUM_CLASS)

network = nn(features_batch.shape[1:])
# NOTE(review): tf.keras BatchNormalization called without `training=` inside a
# TF1 graph may run in inference mode, and its moving-average updates
# (network.updates) are not attached to `train` below -- confirm BN behaves as intended.

# Dequantize the raw pixel values with uniform noise and rescale by 1/256.
# This single tensor feeds both the clean and the perturbed forward passes, so
# the fingerprint perturbations dx are applied on the scale the network sees
# (and both passes share the same noise draw within one session run).
features_scaled = (features_batch + tf.random.uniform(tf.shape(features_batch))) / 256

logits = network(features_scaled)

loss_vanilla = tf.losses.softmax_cross_entropy(labels_oh, logits)
prediction = tf.argmax(tf.nn.softmax(logits), axis=1)
# Compare with the labels of the current batch, not the full `labels` placeholder.
prediction_correct = tf.equal(prediction, labels_batch)

accuracy = tf.reduce_mean(tf.cast(prediction_correct, tf.float32))

# Apply every fingerprint perturbation to every example in the batch:
# (1, batch, ...) + (FP_NUM, 1, ...) -> (FP_NUM, batch, ...), flattened for the network.
perturbed_input = tf.reshape(
    features_scaled[tf.newaxis] + dx_tf[:, tf.newaxis],
    (-1,) + features_tr.shape[1:]
)
fp_logits = tf.reshape(
    network(perturbed_input),
    (FP_NUM, tf.shape(labels_batch)[0], NUM_CLASS)
)

# Measured fingerprint responses, shape (FP_NUM, batch, NUM_CLASS).
FxDx = normalize(fp_logits) - normalize(logits) # the paper shows it in <- this order but it's then not optimizable
# FxDx = normalize(logits) - normalize(fp_logits)


# Target responses: the dy_train row for each example's true label, repeated
# for every fingerprint -> (FP_NUM, batch, NUM_CLASS).
Dy = tf.gather(dy_tf, labels_batch)
Dy = tf.stack([Dy] * FP_NUM)

loss_fp = tf.reduce_mean((Dy - FxDx)**2)
# loss_fp = tf.losses.mean_squared_error(Dy, FxDx)
loss = loss_vanilla  # + loss_fp  (fingerprint loss currently not part of the objective)
train = tf.train.AdamOptimizer(.0001).minimize(loss)

In [None]:
# Train: each epoch re-initializes the iterator over the full training set and
# runs one optimizer step per batch of 500 until the iterator is exhausted.
sess = tf.Session()
sess.run(tf.global_variables_initializer())
for epoch in range(1):
    train_feed = {features: features_tr, labels: labels_tr, batch_size: 500}
    sess.run(iterator.initializer, train_feed)
    batch_num = 0
    done = False
    while not done:
        try:
            print(batch_num)
            sess.run([train])
        except tf.errors.OutOfRangeError:
            # No batches left: this epoch is complete.
            done = True
        else:
            batch_num += 1
        
    

In [None]:
# FIXME(review): this cell appears to be left over from a 2-D toy experiment and
# cannot run against the image graph built above:
#   - `features` is an image-shaped placeholder ((None,) + features_tr.shape[1:]),
#     but 2-D grid points reshaped to (-1, 2) are fed to it;
#   - `prediction` is computed from the dataset iterator output, not directly
#     from `features`, so this feed does not determine what is evaluated;
#   - `features_training_np` / `labels_training_np` are not defined in this notebook.
x, y = np.mgrid[-8:8:0.04, -8:8:0.03]

# plot decision boundaries
prediction_np = sess.run(prediction, feed_dict={
                  features: np.dstack((x, y)).reshape((-1, 2))})
plt.contourf(x, y, prediction_np.reshape(x.shape), cmap='gray')
plt.scatter(features_training_np[:, 0],
            features_training_np[:, 1], c=labels_training_np)
# for i in range(FP_NUM):
#     plt.arrow(-3, -3, 100 * dx[i, 0], 100 * dx[i, 1], color='r')
plt.show()

# Fingerprint dissimilarity per example. For each candidate class c, average
# (over the FP_NUM fingerprints) the L2 distance between the measured response
# FxDx and the class target row dy_tf[c]. Then keep the smallest value over c.
fp_residual = tf.expand_dims(FxDx, 2) - dy_tf                     # (FP_NUM, batch, NUM_CLASS, NUM_CLASS)
fp_distance = tf.sqrt(tf.reduce_sum(fp_residual ** 2, axis=-1))  # (FP_NUM, batch, NUM_CLASS)
class_dissimilarity = tf.reduce_sum(
    tf.transpose((1.0 / FP_NUM) * fp_distance, [1, 2, 0]),
    axis=-1,
)                                                                # (batch, NUM_CLASS)
t8 = tf.reduce_min(class_dissimilarity, axis=-1)                 # (batch,)

# plot fingerprint loss
# FIXME(review): like the decision-boundary plot, this looks like leftover 2-D
# toy code and will not run against the current graph:
#   - 2-D grid points of shape (-1, 2) are fed to the image-shaped `features`
#     placeholder, while `t8` is computed from the dataset iterator output;
#   - `dx[:, None, :]` adds an axis, so it does not match the shape `dx_tf` was
#     built with (`tf.constant(dx)`);
#   - `features_training_np` / `labels_training_np` are not defined in this notebook.
x, y = np.mgrid[-8:8:0.05, -8:8:0.05]
dissimilarities = sess.run(t8, feed_dict={features: np.dstack(
    (x, y)).reshape((-1, 2)), dx_tf: dx[:, None, :]})
# Clip to [0, 1] so the contour colour scale is not dominated by large values.
plt.contourf(x, y, np.clip(dissimilarities.reshape(x.shape), 0, 1))
plt.colorbar()
plt.scatter(features_training_np[:, 0],
            features_training_np[:, 1], c=labels_training_np)
plt.show()

In [None]:
# VERIFY FINGERPRINTS
# Pull one training batch through the graph and compare the measured fingerprint
# responses FxDx with their targets Dy for a single example.
# FxDx and Dy are driven by the dataset iterator, so initialize it explicitly
# here rather than relying on a feed dict defined in some other (deleted) cell.
verify_feed = {features: features_tr, labels: labels_tr, batch_size: 500}
sess.run(iterator.initializer, verify_feed)
fxdx, dy = sess.run([FxDx, Dy])

# Both arrays are (FP_NUM, batch, NUM_CLASS); index the example axis.
i = 3
print(dy[:, i])
print(fxdx[:, i])

In [None]:
# Per-example squared fingerprint error, summed over fingerprints (axis 0)
# and classes (axis 2) of the (FP_NUM, batch, NUM_CLASS) arrays.
fp_sq_error = np.sum((dy - fxdx)**2, axis=(0, 2))
fig, ax = plt.subplots()
ax.hist(fp_sq_error, log=True)
ax.set(
    title='Fingerprint squared error per example',
    xlabel='sum over fingerprints and classes of (Dy - FxDx)^2',
    ylabel='count (log scale)',
)
plt.show()