In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

In [None]:
# Build a toy 2-D, two-class dataset: two isotropic unit-variance Gaussian blobs
# centred at (-3, -3) (label 0) and (3, 3) (label 1), shuffled, then split
# 2/3 train / 1/3 test.
SEED = 0
np.random.seed(SEED)  # reproducible data (and later np.random draws such as the fingerprints)

N = 300
norm00 = np.random.multivariate_normal([-3, -3], [[1, 0], [0, 1]], size=N // 2)
norm11 = np.random.multivariate_normal([3, 3], [[1, 0], [0, 1]], size=N // 2)
labels_np = np.int64(np.hstack((np.zeros(norm00.shape[0]),
                                np.ones(norm11.shape[0]))))
features_np = np.float32(np.vstack((norm00, norm11)))

# Shuffle features and labels together by stacking them column-wise into one
# (float64) array; the 0/1 labels are represented exactly.
data = np.hstack([features_np, labels_np[:, None]])
np.random.shuffle(data)

# Keep labels int64 (not np.intp, which is 32-bit on 32-bit platforms): the
# labels placeholder takes this dtype and is compared with tf.argmax's int64 output.
labels_np, features_np = data[:, -1].astype(np.int64), data[:, :2].astype(np.float32)

cut = N * 2 // 3
features_training, labels_training = features_np[:cut], labels_np[:cut]
features_testing, labels_testing = features_np[cut:], labels_np[cut:]

In [None]:
# Fingerprint hyper-parameters.
FP_NUM = 4      # number of fingerprints (input perturbations)
EPS = 0.06      # each dx component is drawn uniformly from [-EPS, EPS)
alpha = 0.25    # target output change for the wrong classes is -alpha
beta = 0.75     # target output change for the true class is +beta
NUM_CLASS = 2

# dx: one row per fingerprint, one column per input feature.
dx = (np.random.rand(FP_NUM, features_np.shape[1]) - .5) * 2 * EPS

# dy_train[k] is the target response for an example of true class k:
# beta on the diagonal, -alpha everywhere else.
dy_train = np.full((NUM_CLASS, NUM_CLASS), -alpha)
np.fill_diagonal(dy_train, beta)

In [None]:
# Build the TF1 graph: inputs and the shared network layers.
# Start from a fresh default graph so re-running this cell does not accumulate
# duplicate ops/variables, and fix the graph-level seed so the random weight
# initialisation is reproducible.
tf.reset_default_graph()
tf.set_random_seed(0)

features = tf.placeholder(features_np.dtype, [None, 2])        # (batch, 2)
labels = tf.placeholder(labels_np.dtype, [None])               # (batch,) class ids
fp_dx = tf.placeholder(features_np.dtype, [FP_NUM, None, 2])   # this notebook feeds (FP_NUM, 1, 2)
fp_dy = tf.constant(dy_train, dtype=fp_dx.dtype)                # (NUM_CLASS, NUM_CLASS) targets

labels_oh = tf.one_hot(labels, NUM_CLASS)

# The layer objects are created once and later applied to both the clean and
# the perturbed inputs, so both passes share the same weights.
fc1 = tf.layers.Dense(units=200, activation='relu', kernel_initializer=tf.initializers.truncated_normal)
fc2 = tf.layers.Dense(units=200, activation='relu', kernel_initializer=tf.initializers.truncated_normal)
logits_out = tf.layers.Dense(units=NUM_CLASS, activation=None, kernel_initializer=tf.initializers.truncated_normal)

# Clean-input tensors: logits, cross-entropy loss, predictions and accuracy.
logits = logits_out(fc2(fc1(features)))                     # (batch, NUM_CLASS)
loss_vanilla = tf.losses.softmax_cross_entropy(labels_oh, logits)
probs = tf.nn.softmax(logits)
classification = tf.argmax(probs, axis=1)                   # int64 class ids
# tf.equal needs matching dtypes, so `labels` must be fed as int64.
correct_classification = tf.equal(classification, labels)
accuracy = tf.reduce_mean(tf.cast(correct_classification, tf.float32))


# Fingerprint tensors.
# Each perturbation dx_i is broadcast over the batch:
# (batch, 2) + (FP_NUM, 1, 2) -> (FP_NUM, batch, 2).
perturbed = features + fp_dx
fp_logits = logits_out(fc2(fc1(perturbed)))   # (FP_NUM, batch, NUM_CLASS), shared weights

# Normalise logits to unit L2 norm per example, i.e. over the LAST (class) axis.
# For the rank-3 fp_logits, axis=1 would be the batch axis and would make each
# example's fingerprint response depend on the rest of the batch.
logits_unit = logits / tf.norm(logits, axis=-1)[..., None]            # (batch, C)
fp_logits_unit = fp_logits / tf.norm(fp_logits, axis=-1)[..., None]   # (FP_NUM, batch, C)
FxDx = logits_unit[None] - fp_logits_unit                              # (FP_NUM, batch, C)

# Target dy for each example's true class, repeated for every fingerprint.
Dy = tf.gather(fp_dy, labels)      # (batch, C)
Dy = tf.stack([Dy] * FP_NUM)       # (FP_NUM, batch, C)

loss_fp = tf.losses.mean_squared_error(Dy, FxDx)
loss = loss_vanilla + loss_fp

# Detection score. For every candidate class k, take the mean over fingerprints
# of the L2 distance between the observed response FxDx and dy row k; t8 keeps
# the smallest such distance over k.
t1 = tf.expand_dims(FxDx, 2)              # (FP_NUM, batch, 1, C)
t2 = t1 - fp_dy                           # (FP_NUM, batch, K, C)
t3 = t2 ** 2
t4 = tf.reduce_sum(t3, axis=-1)           # (FP_NUM, batch, K) squared distances
t5 = (1.0 / FP_NUM) * tf.sqrt(t4)
t6 = tf.transpose(t5, [1, 2, 0])          # (batch, K, FP_NUM)
t7 = tf.reduce_sum(t6, axis=-1)           # (batch, K) mean distance per candidate class
t8 = tf.reduce_min(t7, axis=-1)           # (batch,)

# Train with Adam on the combined (classification + fingerprint) loss and report
# held-out metrics every LOG_EVERY epochs.
train = tf.train.AdamOptimizer(0.001).minimize(loss)

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# The same fingerprint perturbations are fed for train and test; the inserted
# axis lets each dx broadcast over the batch.
fp_feed = dx[:, np.newaxis, :]
training_dict = {features: features_training, labels: labels_training, fp_dx: fp_feed}
testing_dict = {features: features_testing, labels: labels_testing, fp_dx: fp_feed}

eval_ops = [loss_vanilla, loss_fp, loss, accuracy]
N_EPOCHS = 1000
LOG_EVERY = 10

for epoch in range(N_EPOCHS):
    sess.run(train, feed_dict=training_dict)
    # Evaluate only when logging; otherwise the metrics would be computed and discarded.
    if epoch % LOG_EVERY == 0:
        test_data_loss, test_fp_loss, test_loss, test_acc = sess.run(eval_ops, feed_dict=testing_dict)
        print('epoch', epoch, 'data loss %.6f fp loss %.6f total loss: %.6f accuracy: %.3f' %
              (test_data_loss, test_fp_loss, test_loss, test_acc))

# Held-out metrics of the final model (after the last update).
test_data_loss, test_fp_loss, test_loss, test_acc = sess.run(eval_ops, feed_dict=testing_dict)
print('final', 'data loss %.6f fp loss %.6f total loss: %.6f accuracy: %.3f' %
      (test_data_loss, test_fp_loss, test_loss, test_acc))

In [None]:
# First input coordinate of every fingerprint perturbation.
dx.T[0]

In [None]:
# (Removed a stray bare `plt.arrow` expression: it only displayed the function's
# repr and had no effect on the analysis.)

In [None]:
# Decision regions of the trained classifier on a regular grid, overlaid with the
# test points and the fingerprint directions (scaled 100x so they are visible),
# drawn from the class-0 centre (-3, -3).
x, y = np.mgrid[-8:8:0.05, -8:8:0.05]
grid_points = np.dstack((x, y)).reshape((-1, 2))
# Threshold in NumPy: `sess.run(probs > 0.5, ...)` would add a new op to the
# TF graph every time this cell is run.
prob_c = sess.run(probs, feed_dict={features: grid_points})[:, 1] > 0.5

fig, ax = plt.subplots()
contour = ax.contourf(x, y, prob_c.reshape(x.shape), cmap='gray')
fig.colorbar(contour, ax=ax, label='predicted class 1')
ax.scatter(features_testing[:, 0], features_testing[:, 1], c=labels_testing)
for i in range(FP_NUM):
    ax.arrow(-3, -3, 100 * dx[i, 0], 100 * dx[i, 1], color='r')
ax.set(title='Decision regions, test points and fingerprint directions (x100)',
       xlabel='feature 0', ylabel='feature 1');

In [None]:
# Minimum fingerprint dissimilarity (t8) at every point of the plotting grid.
x, y = np.mgrid[-8:8:0.05, -8:8:0.05]
dissimilarities = sess.run(
    t8,
    feed_dict={
        fp_dx: dx[:, np.newaxis, :],
        features: np.stack((x, y), axis=-1).reshape(-1, 2),
    },
)

In [None]:
# Heat map of the per-point minimum dissimilarity. Low values mark inputs whose
# fingerprint response is close to the target dy of some class.
fig, ax = plt.subplots()
contour = ax.contourf(x, y, dissimilarities.reshape(x.shape), cmap='gray')
fig.colorbar(contour, ax=ax, label='min fingerprint dissimilarity (t8)')
ax.scatter(features_testing[:, 0], features_testing[:, 1], c=labels_testing)
ax.set(title='Fingerprint dissimilarity over the input grid',
       xlabel='feature 0', ylabel='feature 1');

In [None]:
# VERIFY FINGERPRINTS: fetch the observed fingerprint response (FxDx) and its
# target (Dy) for the first ten test examples.
mini_dict = {features: features_testing[:10], labels: labels_testing[:10], fp_dx: dx[:, None, :]}
fxdx, dy = sess.run((FxDx, Dy), feed_dict=mini_dict)

In [None]:
# For one verification example, compare the target dy (one row per fingerprint)
# with the observed response FxDx, and show that example's true label.
i = 6
assert 0 <= i < fxdx.shape[1], 'i must index one of the examples fetched above'
print(dy[:, i])
print(fxdx[:, i])
print(labels_testing[i])  # label of the same example i