In [2]:
import tensorflow as tf
import numpy as np

In [5]:
# Pin NumPy's global RNG, then start from an empty default graph.
np.random.seed(42)
tf.reset_default_graph()
# The graph-level seed is stored on the default graph, so it has to be set
# after the reset above.
tf.set_random_seed(42)

In [6]:
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from the local cache directory (the cell output below shows the
# archives being downloaded there on first use).
mnist = input_data.read_data_sets(train_dir="/tmp/data/")

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting /tmp/data/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting /tmp/data/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting /tmp/data/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz


In [10]:
# Primary-capsule geometry, as described in the paper: 32 maps on a 6x6 grid,
# each capsule being an 8-dimensional vector.
caps1_maps = 32
caps1_dims = 8
caps1_n_caps = 6 * 6 * caps1_maps

# Keyword arguments for the two convolutions that feed the primary capsules.
conv1_params = dict(
    filters=256,
    kernel_size=9,
    strides=1,
    padding="valid",
    activation=tf.nn.relu,
)
conv2_params = dict(
    filters=caps1_maps * caps1_dims,
    kernel_size=9,
    strides=2,
    padding="valid",
    activation=tf.nn.relu,
)

# Input images; the leading dimension is left open for the batch.
X = tf.placeholder(dtype=tf.float32, shape=[None, 28, 28, 1])

# Primary capsule layer: two stacked convolutions, then the feature maps are
# regrouped into `caps1_n_caps` vectors of `caps1_dims` components each.
conv1 = tf.layers.conv2d(X, **conv1_params)
conv2 = tf.layers.conv2d(conv1, **conv2_params)
caps1_raw = tf.reshape(conv2, shape=[-1, caps1_n_caps, caps1_dims])

In [11]:
def squash(s, axis=-1, epsilon=1e-7):
    """Apply the capsule "squash" non-linearity along `axis`.

    Each vector keeps its direction while its length is rescaled by
    ``||s||^2 / (1 + ||s||^2)``.

    Args:
        s: tensor holding the vectors to squash.
        axis: axis along which the vector components lie.
        epsilon: small constant added under the square root so the norm (and
            the division by it) stays well defined for all-zero vectors.

    Returns:
        A tensor with the same shape as `s`.
    """
    sq_length = tf.reduce_sum(tf.square(s), axis=axis, keep_dims=True)
    length = tf.sqrt(sq_length + epsilon)
    scale = sq_length / (1. + sq_length)
    direction = s / length
    return scale * direction

In [12]:
# Squash every primary capsule vector along its last (component) axis.
caps1_output = squash(caps1_raw, axis=-1)

In [20]:
# Digit capsule layer: 10 output capsules of 16 dimensions each.
caps2_n_caps = 10
caps2_n_dims = 16
init_sigma = 0.1

# One (caps2_n_dims x caps1_dims) transformation matrix for every
# (primary capsule, digit capsule) pair; the leading axis of size 1 is the
# slot that gets tiled across the batch below.
W_init = tf.random_normal(
    [1, caps1_n_caps, caps2_n_caps, caps2_n_dims, caps1_dims],
    stddev=init_sigma,
    dtype=tf.float32,
)
W = tf.Variable(initial_value=W_init)

# Dynamic batch size, read from the input placeholder at run time.
batch_size = tf.shape(X)[0]

# Repeat the shared weights once per sample in the batch.
W_tiled = tf.tile(W, multiples=[batch_size, 1, 1, 1, 1], name="W_tiled")

In [21]:
# Append a trailing axis so each primary capsule output becomes a column
# vector instead of a row of scalars.
caps1_output_expanded = tf.expand_dims(
    caps1_output, axis=-1, name="caps1_output_expanded")

# Insert an axis at position 2 that will index the digit capsules ...
caps1_output_tile = tf.expand_dims(
    caps1_output_expanded, axis=2, name="caps1_output_tile")

# ... and repeat every primary output once per digit capsule.
caps1_output_tiled = tf.tile(
    caps1_output_tile,
    multiples=[1, 1, caps2_n_caps, 1, 1],
    name="caps1_output_tiled")

In [22]:
# Batched matrix product: one predicted digit-capsule vector for every
# (sample, primary capsule, digit capsule) triple.
caps2_predicted = tf.matmul(W_tiled, caps1_output_tiled)

# Routing logits start at zero.
# NOTE(review): `batch_size` must still be bound to the dynamic batch-size
# tensor when this cell runs, not to a plain Python int.
raw_weights = tf.zeros(
    shape=[batch_size, caps1_n_caps, caps2_n_caps, 1, 1],
    dtype=tf.float32)

In [25]:
# Dynamic routing, round 1.
# Softmax over the digit-capsule axis turns the logits into routing weights.
routing_weights = tf.nn.softmax(logits=raw_weights, dim=2)

# Weight each prediction, sum over the primary capsules (axis 1), then squash
# the resulting vectors along their component axis.
weighted_predictions = tf.multiply(routing_weights, caps2_predicted)
weighted_sum = tf.reduce_sum(
    weighted_predictions, axis=1, keep_dims=True)
caps2_output_round_1 = squash(weighted_sum, axis=-2)

# Repeat the round-1 output once per primary capsule so the agreement can be
# computed for every (i, j) capsule pair in a single batched product.
caps2_output_round_1_tiled = tf.tile(
    caps2_output_round_1, multiples=[1, caps1_n_caps, 1, 1, 1])

In [27]:
# Agreement = dot product between every prediction and the round-1 output,
# written as a batched (prediction^T @ output) matrix product.
agreement = tf.matmul(
    caps2_predicted,
    caps2_output_round_1_tiled,
    transpose_a=True)

# Add the agreement to the routing logits for the next round.
raw_weights_round_2 = tf.add(raw_weights, agreement)

In [32]:
# Dynamic routing, round 2: same steps as round 1, using the updated logits.
routing_weights_round_2 = tf.nn.softmax(logits=raw_weights_round_2, dim=2)
weighted_predictions_round_2 = tf.multiply(
    routing_weights_round_2, caps2_predicted)
weighted_sum_round_2 = tf.reduce_sum(
    weighted_predictions_round_2, axis=1, keep_dims=True)
caps2_output_round_2 = squash(weighted_sum_round_2, axis=-2)

# Routing stops after two rounds; this is the digit capsule layer's output.
caps2_output = caps2_output_round_2

In [33]:
# A capsule's length is read as a probability. Taking the plain norm of a
# zero vector leads to NaN problems in the weights, hence the epsilon below.
def safe_norm(s, axis=-1, epsilon=1e-7, keep_dims=False):
    """Return the Euclidean norm of `s` along `axis`, offset by `epsilon`.

    Args:
        s: tensor holding the vectors.
        axis: axis along which the vector components lie.
        epsilon: small constant added to the squared norm before the square
            root is taken.
        keep_dims: whether the reduced axis is kept with size 1.

    Returns:
        ``sqrt(sum(s ** 2, axis) + epsilon)``.
    """
    sum_of_squares = tf.reduce_sum(
        tf.square(s), axis=axis, keep_dims=keep_dims)
    return tf.sqrt(sum_of_squares + epsilon)

In [34]:
# Length of each digit capsule's output vector, read as a class probability.
y_proba = safe_norm(caps2_output, axis=-2)

# Index of the longest capsule along the digit-capsule axis ...
y_proba_argmax = tf.argmax(y_proba, axis=2)

# ... with the two leftover size-1 axes dropped to give one label per sample.
y_pred = tf.squeeze(y_proba_argmax, axis=[1, 2])

In [38]:
# Integer class labels, one per sample.
y = tf.placeholder(shape=[None], dtype=tf.int64)

# Margin loss:
#   L_k = T_k * max(0, m+ - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - m-)^2
m_plus = 0.9
m_minus = 0.1
lambda_ = 0.5

# T[b, k] is 1 when sample b belongs to class k, else 0.
T = tf.one_hot(y, depth=caps2_n_caps)

# Length of every digit capsule output, with the reduced axis kept.
caps2_output_norm = safe_norm(caps2_output, axis=-2, keep_dims=True)

# Penalty for the correct class being too short. The reshape uses
# `caps2_n_caps` rather than a literal 10 so it always lines up with `T`.
present_error_raw = tf.square(tf.maximum(0., m_plus - caps2_output_norm))
present_error = tf.reshape(present_error_raw, shape=(-1, caps2_n_caps))

# Penalty for a wrong class being too long.
absent_error_raw = tf.square(tf.maximum(0., caps2_output_norm - m_minus))
absent_error = tf.reshape(absent_error_raw, shape=(-1, caps2_n_caps))

# Combine both terms per class, sum over classes, average over the batch.
L = tf.add(T * present_error, lambda_ * (1.0 - T) * absent_error)
margin_loss = tf.reduce_mean(tf.reduce_sum(L, axis=1))

In [39]:
# Scalar boolean switch, False unless fed: selects whether the reconstruction
# mask is built from the true labels or from the predictions.
mask_with_labels = tf.placeholder_with_default(input=False, shape=())

In [40]:
# Pick which class drives the reconstruction: the true label when
# `mask_with_labels` is True (training), the predicted label otherwise (testing).
reconstruction_targets = tf.cond(
    mask_with_labels,
    true_fn=lambda: y,
    false_fn=lambda: y_pred)

In [41]:
# One-hot mask over the digit capsules: shape (?, 10), whereas caps2_output
# has shape (?, 1, 10, 16, 1).
reconstruction_mask = tf.one_hot(
    reconstruction_targets, depth=caps2_n_caps)

# Reshape the mask so it broadcasts against caps2_output.
reconstruction_mask_reshaped = tf.reshape(
    reconstruction_mask, shape=[-1, 1, caps2_n_caps, 1, 1])

# Zero out every capsule except the target one.
caps2_output_masked = tf.multiply(
    caps2_output, reconstruction_mask_reshaped)

# Flatten to (?, caps2_n_caps * caps2_n_dims) = (?, 160) for the decoder.
decoder_input = tf.reshape(
    caps2_output_masked, shape=[-1, caps2_n_caps * caps2_n_dims])

In [42]:
# Decoder: three fully connected layers mapping the masked capsule outputs
# back to a flattened 28x28 image.
n_hidden1 = 512
n_hidden2 = 1024
n_output = 28 * 28

hidden1 = tf.layers.dense(
    decoder_input, units=n_hidden1, activation=tf.nn.relu)
hidden2 = tf.layers.dense(
    hidden1, units=n_hidden2, activation=tf.nn.relu)
# Sigmoid keeps every reconstructed pixel inside (0, 1).
decoder_output = tf.layers.dense(
    hidden2, units=n_output, activation=tf.nn.sigmoid)

In [43]:
# Reconstruction loss: mean squared difference between the flattened input
# image and the decoder output.
X_flat = tf.reshape(X, shape=[-1, n_output])
squared_difference = tf.square(X_flat - decoder_output)
reconstruction_loss = tf.reduce_mean(squared_difference)

In [44]:
# Final loss: margin loss plus the reconstruction loss, scaled down by
# `alpha` so the margin term dominates.
alpha = 0.0005
loss = tf.add(x=margin_loss, y=alpha * reconstruction_loss)

In [45]:
# Accuracy: fraction of samples whose predicted digit equals the label.
correct = tf.equal(y, y_pred, name="correct")
accuracy = tf.reduce_mean(tf.cast(correct, dtype=tf.float32))

# Optimizer: Adam with its default hyper-parameters, minimising `loss`.
optimizer = tf.train.AdamOptimizer()
training_op = optimizer.minimize(loss)

# Op that initialises every variable created in the graph so far.
init = tf.global_variables_initializer()

In [None]:
# Training the model.
#
# After every epoch the model is evaluated on the validation set and, if the
# validation loss improved, saved to `checkpoint_path`. The evaluation cell
# restores the weights from there through `saver`.
n_epochs = 10
# NOTE: this rebinds `batch_size`, which the graph-construction cells use for
# the dynamic batch-size tensor (tf.shape(X)[0]). Re-run the notebook from the
# top before re-executing any graph-building cell after this one.
batch_size = 50
# Resume from an existing checkpoint when one is found; set to False to
# always train from freshly initialised weights.
restore_checkpoint = True

n_iterations_per_epoch = mnist.train.num_examples // batch_size
n_iterations_validation = mnist.validation.num_examples // batch_size
best_loss_val = np.inf

checkpoint_path = "./my_capsule_network"
# Created after the whole graph (including the optimizer) has been built, so
# it covers every variable.
saver = tf.train.Saver()

with tf.Session() as sess:
    if restore_checkpoint and tf.train.checkpoint_exists(checkpoint_path):
        saver.restore(sess, checkpoint_path)
    else:
        init.run()

    for epoch in range(n_epochs):
        for iteration in range(1, n_iterations_per_epoch + 1):
            X_batch, y_batch = mnist.train.next_batch(batch_size)
            # Run the training operation and measure the loss. The true
            # labels drive the reconstruction mask while training.
            _, loss_train = sess.run(
                [training_op, loss],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch,
                           mask_with_labels: True})
            print("\rIteration: {}/{} ({:.1f}%)  Loss: {:.5f}".format(
                      iteration, n_iterations_per_epoch,
                      iteration * 100 / n_iterations_per_epoch,
                      loss_train),
                  end="")

        # At the end of each epoch, measure the validation loss and accuracy.
        loss_vals = []
        acc_vals = []
        for iteration in range(1, n_iterations_validation + 1):
            X_batch, y_batch = mnist.validation.next_batch(batch_size)
            loss_val, acc_val = sess.run(
                [loss, accuracy],
                feed_dict={X: X_batch.reshape([-1, 28, 28, 1]),
                           y: y_batch})
            loss_vals.append(loss_val)
            acc_vals.append(acc_val)
            print("\rEvaluating the model: {}/{} ({:.1f}%)".format(
                      iteration, n_iterations_validation,
                      iteration * 100 / n_iterations_validation),
                  end=" " * 10)
        loss_val = np.mean(loss_vals)
        acc_val = np.mean(acc_vals)
        improved = loss_val < best_loss_val
        print("\rEpoch: {}  Val accuracy: {:.4f}%  Loss: {:.6f}{}".format(
            epoch + 1, acc_val * 100, loss_val,
            " (improved)" if improved else ""))

        # Keep only the best model seen so far.
        if improved:
            saver.save(sess, checkpoint_path)
            best_loss_val = loss_val

In [None]:
# Testing the model: restore the trained weights, then average loss and
# accuracy over the test set in batches of `batch_size`.
# `saver` and `checkpoint_path` must already be defined in the notebook.
n_iterations_test = mnist.test.num_examples // batch_size

with tf.Session() as sess:
    saver.restore(sess, checkpoint_path)

    loss_tests, acc_tests = [], []
    for iteration in range(1, n_iterations_test + 1):
        X_batch, y_batch = mnist.test.next_batch(batch_size)
        # `mask_with_labels` is left at its default here, so reconstruction
        # is driven by the model's own predictions.
        loss_test, acc_test = sess.run(
            [loss, accuracy],
            feed_dict={
                X: X_batch.reshape([-1, 28, 28, 1]),
                y: y_batch,
            })
        loss_tests.append(loss_test)
        acc_tests.append(acc_test)
        # Single-line progress indicator; the trailing spaces wipe leftovers
        # from any longer line printed before.
        print(
            "\rEvaluating the model: {}/{} ({:.1f}%)".format(
                iteration,
                n_iterations_test,
                iteration * 100 / n_iterations_test),
            end=" " * 10)

    # Aggregate the per-batch metrics into the final report.
    loss_test = np.mean(loss_tests)
    acc_test = np.mean(acc_tests)
    print(
        "\rFinal test accuracy: {:.4f}%  Loss: {:.6f}".format(
            acc_test * 100, loss_test))