Import packages

In [1]:
import importlib
import tensorflow as tf
from triplet_dataset import MnistDatasetSmallRotations as MnistDataset
# import triplet_dataset_arrows as triplet_dataset
# importlib.reload(MnistDataset)
import visualize_embed as visualize_embed
import visualize_embed_tsne as visualize_embed_tsne
# importlib.reload(visualize_embed)

In [2]:
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


Define generic functions to initialize convolutional/pooling layers

In [3]:
def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.01)
    return tf.get_variable("weights", dtype=tf.float32, initializer=initial)

def bias_variable(shape):
    initial = tf.constant(0.01, shape=shape, dtype=tf.float32)
    return tf.get_variable("biases", dtype=tf.float32, initializer=initial)

def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

def compute_euclidean_distances(x, y, w=None):
    d = tf.square(tf.subtract(x, y))
    if w is not None:
        d = tf.transpose(tf.multiply(tf.transpose(d), w))
    d = tf.sqrt(tf.reduce_sum(d))
    return d

Define the triplet network architecture in a class

In [4]:
class Triplet:
    
    # Create model
    def __init__(self):
        # Input and label placeholders
        with tf.variable_scope('input'):
            self.x = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='x')
            self.xp = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='xp')
            self.xn = tf.placeholder(tf.float32, shape=[None, 28, 28, 1], name='xn')
            self.weights = tf.placeholder(tf.float32, shape=[None], name='weights')
            self.y_ = tf.placeholder(tf.float32, shape=[None, 10], name='y_')
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')
        
        with tf.variable_scope('embedding') as scope:
            self.o = self.embedding_network(self.x)
            scope.reuse_variables()
            self.op = self.embedding_network(self.xp)
            self.on = self.embedding_network(self.xn)
        
        with tf.variable_scope('distances'):
            self.dp = compute_euclidean_distances(self.o, self.op)
            self.dn = compute_euclidean_distances(self.o, self.on)
            self.logits = tf.nn.softmax([self.dp, self.dn], name="logits")
        
        with tf.variable_scope('embed_loss'):
            self.embed_loss = tf.reduce_mean(tf.pow(self.logits[0], 2))
            
        with tf.variable_scope('classifier'):
            self.y = self.classification_network(self.o, self.y_, self.keep_prob)
        
        with tf.variable_scope('class_loss'):
            self.class_loss = tf.reduce_mean(-tf.reduce_sum(self.y_ * tf.log(self.y), reduction_indices=[1]))
    
    def embedding_network(self, x):
        dim = 1
        with tf.variable_scope('conv1'):
            out = 32
            w = weight_variable([5, 5, dim, out])
            b = bias_variable([out])
            h = max_pool_2x2(tf.nn.relu(conv2d(x, w) + b))
            dim = out
            x = h
        with tf.variable_scope('conv2'):
            out = 64
            w = weight_variable([3, 3, dim, out])
            b = bias_variable([out])
            h = max_pool_2x2(tf.nn.relu(conv2d(x, w) + b))
            dim = out
            x = h
        with tf.variable_scope('conv3'):
            out = 128
            w = weight_variable([3, 3, dim, out])
            b = bias_variable([out])
            h = max_pool_2x2(tf.nn.relu(conv2d(x, w) + b))
            dim = out
            x = h
        with tf.variable_scope('readout'):
            gpool = tf.nn.pool(x, [h.get_shape()[1], h.get_shape()[2]], pooling_type="MAX", padding="VALID", name="gpool")
            return tf.reshape(gpool, [-1, 128])
        
    def classification_network(self, x, y_, dropout):
        dim = 128
        with tf.variable_scope('fc1') as scope:
            out = 256
            x = tf.nn.dropout(x, keep_prob=dropout)
            w = weight_variable([dim, out])
            b = bias_variable([out])
            h = tf.nn.relu(tf.matmul(x, w) + b)
            dim = out
            x = h
        """with tf.variable_scope('fc2') as scope:
            out = 64
            x = tf.nn.dropout(x, keep_prob=dropout)
            w = weight_variable([dim, out])
            b = bias_variable([out])
            h = tf.nn.relu(tf.matmul(x, w) + b)
            dim = out
            x = h"""
        with tf.variable_scope('fc3') as scope:
            out = 10
            x = tf.nn.dropout(x, keep_prob=dropout)
            w = weight_variable([dim, out])
            b = bias_variable([out])
            self.y = tf.nn.softmax(tf.matmul(x, w) + b)
            correct_prediction = tf.equal(tf.argmax(self.y, 1), tf.argmax(y_, 1))
            self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
            return self.y

Prepare the network for training

In [None]:
triplet = Triplet()
embed_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "embedding")
embed_train_step = tf.train.AdamOptimizer().minimize(triplet.embed_loss, var_list=embed_train_vars)
class_train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "classifier")
class_train_step = tf.train.AdamOptimizer().minimize(triplet.class_loss)

Start training!

In [None]:
embed_batch_size = 32
embed_iterations = 2
logging_frequency = 5
class_batch_size = 32
class_eval_batch_size = 100
class_iterations = 2
mnist_dataset = MnistDataset()
eval_full = True
kp = 0.75
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(embed_iterations):
        for j in range((int)(50000/embed_batch_size)):
            batch = mnist_dataset.generate_train_data(embed_batch_size)
            if j % logging_frequency == 0:
                loss = sess.run(triplet.embed_loss, feed_dict={triplet.x: batch[0], triplet.xp: batch[1], triplet.xn: batch[2], triplet.weights: batch[3]}) 
                print('epoch %d (%d/50000), training loss %g' % (i, j*embed_batch_size, loss))
            embed_train_step.run(feed_dict={triplet.x: batch[0], triplet.xp: batch[1], triplet.xn: batch[2], triplet.weights: batch[3]})
        batch = mnist_dataset.generate_train_data(embed_batch_size)
        loss = sess.run(triplet.embed_loss, feed_dict={triplet.x: batch[0], triplet.xp: batch[1], triplet.xn: batch[2], triplet.weights: batch[3]}) 
        print('epoch %d, training loss %g' % (i, loss))
    test_batch = mnist_dataset.generate_test_data(100)
    embed = triplet.o.eval({triplet.x: test_batch})
    visualize_embed.visualize(embed, test_batch[:,:,:,0]) # Simple PCA
    visualize_embed_tsne.visualize(embed, test_batch[:,:,:,0]) # t-SNE
    for i in range(class_iterations):
        for j in range((int)(50000/class_batch_size)):
            input_images, correct_predictions = mnist.train.next_batch(class_batch_size)
            input_images = input_images.reshape(-1, 28, 28, 1)
            if j % logging_frequency == 0:
                loss, acc = sess.run([triplet.class_loss, triplet.accuracy], feed_dict={triplet.x: input_images, triplet.y_: correct_predictions, triplet.keep_prob:1.0})
                print('epoch %d (%d/50000), training loss %g, training accuracy %g' % (i, j*class_batch_size, loss, acc))
            class_train_step.run(feed_dict={triplet.x: input_images, triplet.y_: correct_predictions, triplet.keep_prob:kp})
        test_input_images, test_correct_predictions = mnist.test.next_batch(class_eval_batch_size)
        test_input_images = test_input_images.reshape(-1, 28, 28, 1)
        acc = sess.run(triplet.accuracy, feed_dict={triplet.x:test_input_images, triplet.y_:test_correct_predictions, triplet.keep_prob:kp})
        print('epoch %d, test acc: %g', i, acc)
    if eval_full:
        full_input_images = mnist.test.images.reshape(-1, 28, 28, 1)
        full_input_labels = mnist.test.labels
        acc = sess.run(triplet.accuracy, feed_dict={triplet.x:full_input_images, triplet.y_:full_input_labels, triplet.keep_prob:1.0})
        print ('full test acc: %g', acc)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
epoch 0 (0/50000), training loss 0.24948
epoch 0 (160/50000), training loss 0.208888
epoch 0 (320/50000), training loss 0.177448
epoch 0 (480/50000), training loss 0.00574936
epoch 0 (640/50000), training loss 0.00073135
epoch 0 (800/50000), training loss 0.000369516
epoch 0 (960/50000), training loss 2.7824e-08
epoch 0 (1120/50000), training loss 7.16799e-08
epoch 0 (1280/50000), training loss 8.14933e-08
epoch 0 (1440/50000), training loss 4.18311e-11
epoch 0 (1600/50000), training loss 2.73193e-05
epoch 0 (1760/50000), training loss 1.03003e-05
epoch 0 (1920/50000), training loss 1.05038e-11
epoch 0 (2080/50000), training loss 8.1261e-11
epoch 0 (2240/50000), training loss 5.32728e-11
epoch 0 (2400/50000), training loss 3.34033e-11
epoch 0 (2560/50000), training loss 8.35656e-06
epoch 0 (2720

In [None]:
visualize_embed.visualize(embed, test_batch[:,:,:,0]) # Simple PCA
visualize_embed_tsne.visualize(embed, test_batch[:,:,:,0]) # t-SNE

Define a classifier network that takes embeddings as input