In [1]:
"""Simplest Capsule Implementation Based on my understanding of the Paper: https://arxiv.org/abs/1710.09829
Reference used to clear up doubts: https://github.com/naturomics/CapsNet-Tensorflow
I tried to comment as much as possible; let me know if you find any errors.
Test Acc: 99.19 % # First round with default params of paper
"""
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Build the MNIST dataset object: labels are requested one-hot, and
# reshape=False is passed so the images are fed to the 4-D input placeholder
# of the network as-is.
mnist = input_data.read_data_sets(
    "MNIST_data/",
    one_hot=True,
    reshape=False,
)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [3]:
# --- Checkpointing -----------------------------------------------------------
# The checkpoint below was already trained for 54500 global steps.
ckpt = '../capsnet.ckpt'  # checkpoint file
resume = True             # set to True only if the checkpoint file above exists

# --- Training round ----------------------------------------------------------
train_round = 2  # used in the name of the CSV log written by the training cell
lr = 1e-4        # learning rate for the 2nd round; the 1st round used the default 0.001

# --- Evaluation / saving cadence ---------------------------------------------
do_test_first = True  # evaluate on the test set before training starts
save_after = 500      # every this many steps: save the checkpoint and validate

### Hyperparameters

In [4]:
# Training schedule
batch_size = 128
epoch = 3

# Routing algorithm: number of routing iterations
routing_iter = 3

# Margin loss constants
m_plus, m_minus = 0.9, 0.1
lambda_val = 0.5  # down-weighting of the loss for absent digit classes
epsilon = 1e-9

# Functions

### 1. Squashing

In [5]:
def squash(in_tensor, eps=1e-9):
    """Squashing non-linearity: v = (|s|^2 / (1 + |s|^2)) * (s / |s|).

    The norm is taken along the second-to-last axis (the capsule vector axis),
    so every capsule vector keeps its direction while its length is mapped
    into [0, 1).

    Args:
        (S) in_tensor: A 5-D input tensor of a capsule layer, shape [N, 1, J, next_D, 1]
            (any tensor whose second-to-last axis is the vector axis works).
        eps: Small constant added under the square root. Without it an
            all-zero capsule vector evaluates 0 / 0 and produces NaN in both
            the output and the gradient.
    Returns:
        (V) out_tensor: A tensor with the same shape as in_tensor, squashed
            along its second-to-last axis.
    """
    # Squared length per capsule vector, e.g. [N, 1, J, 1, 1]. Computing the
    # square directly avoids differentiating a square root at zero.
    l2_square = tf.reduce_sum(tf.square(in_tensor), axis=-2, keep_dims=True,
                              name='l2_square_per_vector_per_capsule_unit')
    squash_factor = l2_square / (1 + l2_square)
    unit_vector = in_tensor / tf.sqrt(l2_square + eps)
    return squash_factor * unit_vector  # same shape as in_tensor

### 2. Capsule Operations

In [6]:
# 2. Agreement
def prediction_agreement(current_active_out, prediction):
    """Agreement (dot product) between prediction vectors and capsule outputs.

    Args:
        current_active_out: squashed output of the next-layer capsules,
            shape [N, 1, J, next_D, 1].
        prediction: prediction vectors of the lower-layer capsules,
            shape [N, I, J, next_D, 1].
    Returns:
        Tensor of shape [1, I, J, 1, 1]: the per-sample agreements summed over
        the batch axis.
    """
    # Tile the outputs along the input-capsule axis. The count is read from
    # `prediction` instead of being hard-coded to 1152, so the function works
    # for any number of input capsules.
    num_in_caps = tf.shape(prediction)[1]
    V_J = tf.tile(current_active_out,
                  multiples=[1, num_in_caps, 1, 1, 1])  # [N, I, J, next_D, 1]
    # [N, I, J, 1, next_D] x [N, I, J, next_D, 1] = [N, I, J, 1, 1]
    agreement = tf.matmul(prediction, V_J, transpose_a=True)
    # Learning from samples aggregated # [1, I, J, 1, 1]
    # NOTE(review): summing over axis 0 makes the routing update shared by all
    # samples of a batch, so a sample's output depends on what else is in the
    # batch. Kept as-is to stay compatible with the trained checkpoint.
    agreement = tf.reduce_sum(agreement, axis=0, keep_dims=True)
    return agreement

In [7]:
# 2. Routing Algorithm
def routing(prediction, num_iters, prior_IJ):
    """Routing-by-agreement between two capsule layers.

    Args:
        prediction: prediction vectors, shape [N, I, J, next_D, 1].
        num_iters: number of routing iterations; must be at least 1.
        prior_IJ: initial routing logits, shape [N, I, J, 1, 1].
    Returns:
        Squashed output of the next-layer capsules after the last iteration,
        shape [N, 1, J, next_D, 1].
    Raises:
        ValueError: if num_iters is smaller than 1. Previously this case ran
            no iteration and failed at the return statement with an
            UnboundLocalError.
    """
    num_iters = int(num_iters)
    if num_iters < 1:
        raise ValueError(
            'num_iters must be >= 1, got {}'.format(num_iters))

    with tf.variable_scope('routing'):
        for r in range(num_iters):
            with tf.variable_scope('routing_iter_' + str(r)):
                # Step 4: coupling coefficients, softmax over the J axis
                c_IJ = tf.nn.softmax(prior_IJ, dim=2)  # [N, I, J, 1, 1]
                # Step 5: weighted sum of the predictions over the I axis
                weighted_unactive_out = tf.reduce_sum((c_IJ * prediction),
                                                      axis=1,
                                                      keep_dims=True)  # [N, 1, J, next_D, 1]
                # Step 6: squash to get the candidate output
                current_active_out = squash(
                    weighted_unactive_out)  # [N, 1, J, next_D, 1]
                # Step 7: raise the logits where prediction and output agree.
                # `+=` rebinds the local name to a new tensor; the caller's
                # tensor is not modified.
                prior_IJ += prediction_agreement(
                    current_active_out, prediction)

    return current_active_out  # [N, 1, J, next_D, 1]

In [8]:
# attribs of caps layer = W_IJ, prior_IJ, Conv2dunits
# We know I (the number of input capsules); how do we know J? J is passed in as next_capsules.
# Wij = [8 x 16]

# I see this as information encoded in 4D ---> distributed representation in 3D.
# We can use it to flatten as well, by setting next_capsules = 1 and next_D = vector dimension,
# so naive flattening won't be required.
def capsule_layer_prediction(in_tensor, num_iters=3, D=8, capsules=32, next_D=16, next_capsules=10, kernel_size=9,
                             strides=2, name='caps_conv2d'):
    """Primary capsules (conv) -> prediction vectors -> routing -> next-layer capsules.

    Args:
        in_tensor: 4-D feature map (NHWC); its spatial size must be statically
            known so the number of primary capsules can be computed.
        num_iters: number of routing iterations.
        D: vector length of each primary capsule.
        capsules: number of primary capsule channels (conv filters = capsules * D).
        next_D: vector length of each next-layer capsule.
        next_capsules: number of next-layer capsules (J).
        kernel_size: kernel size of the primary capsule convolution.
        strides: stride of the primary capsule convolution.
        name: variable scope of the layer.
    Returns:
        Tensor of shape [N, next_capsules, next_D].
    Raises:
        ValueError: if the convolution output has an unknown static dimension.
    """
    batch_size = tf.shape(in_tensor)[0]
    J = next_capsules  # 10 by default

    with tf.variable_scope(name):
        # Optimized way from: https://github.com/naturomics/CapsNet-Tensorflow
        capsule = tf.layers.conv2d(in_tensor,
                                   capsules * D,
                                   kernel_size,
                                   strides,
                                   kernel_initializer=tf.contrib.layers.xavier_initializer())  # NHWC*D # e.g. [N, 6, 6, 32*8]

        # Number of primary capsules I = H * W * capsules, taken from the
        # static shape instead of the former hard-coded 1152 (= 6 * 6 * 32).
        conv_dims = capsule.get_shape().as_list()[1:]
        if None in conv_dims:
            raise ValueError(
                'capsule_layer_prediction needs a statically known conv output shape, got {}'.format(conv_dims))
        I = int(np.prod(conv_dims)) // D

        # output of capsule units: [N, I, 1, D, 1]
        U_I = tf.reshape(capsule, shape=[batch_size, I, 1, D, 1])
        # Squash before tiling: squash works along the D axis only, so the
        # result equals squashing the tiled tensor at 1/J of the cost.
        U_I = squash(U_I)  # Technically this is the end of primary capsules
        # Tile by J (was a literal 10, which ignored next_capsules).
        U_I = tf.tile(U_I, multiples=[1, 1, J, 1, 1])  # [N, I, J, D, 1]

        # W_IJ, shared across the batch. Created with tf.Variable exactly as
        # before so variable names stay compatible with existing checkpoints.
        W_IJ = tf.Variable(tf.random_normal([1, I, J, D, next_D], stddev=0.03))  # [1, I, J, D, next_D]
        W_IJ = tf.tile(W_IJ,
                       multiples=[batch_size, 1, 1, 1, 1])  # [N, I, J, D, next_D]

        # prediction vectors
        # [N, I, J, next_D, D] x [N, I, J, D, 1] = [N, I, J, next_D, 1]
        prediction_vectors = tf.matmul(W_IJ, U_I, transpose_a=True)

    with tf.variable_scope('routing'):  # Not required
        prior_IJ = tf.constant(np.zeros([1, I, J, 1, 1]), dtype=np.float32)
        prior_IJ = tf.tile(prior_IJ, multiples=[batch_size, 1, 1, 1, 1])  # [N, I, J, 1, 1]

    activations = routing(prediction_vectors, num_iters, prior_IJ)  # [N, 1, J, next_D, 1]

    # Squeeze only the two singleton axes. A bare tf.squeeze would also drop
    # the batch axis whenever N == 1.
    return tf.squeeze(activations, axis=[1, 4])  # [N, J, next_D]

# Network

In [9]:
## Network Arch:
class CapsNet:
    """Capsule network: conv layer -> capsule layer with routing -> margin loss,
    plus a fully connected decoder that reconstructs the input image.

    Graph attributes used by the helper methods: X / Y (placeholders),
    predictions, pred_image, acc, loss, train_op.
    """

    def __init__(self, inp_dim=28, num_iters=3, pred_vec_len=16, lr=1e-3, 
                 classes=10, m_plus=0.9, m_minus=0.1, lambda_val=0.5, scope="CapsNet"):
        """Build the whole graph in the current default graph.

        Args:
            inp_dim: height and width of the (square, single channel) input.
            num_iters: number of routing iterations.
            pred_vec_len: vector length of each class capsule.
            lr: learning rate of the Adam optimizer.
            classes: number of classes (= number of output capsules).
            m_plus, m_minus, lambda_val: margin loss constants.
            scope: outer variable scope of the network.
        """
        print("Learning Rate={}, m_plus={}, m_minus={}, lambda_val={}".format(lr, 
                                                                              m_plus,
                                                                              m_minus,
                                                                              lambda_val))

        # Needed by predict() to build one-hot masks from predicted classes.
        self.classes = classes

        with tf.variable_scope(scope):
            self.X = tf.placeholder(
                tf.float32, shape=[None, inp_dim, inp_dim, 1], name='inputs')
            # `classes` replaces the former hard-coded 10.
            self.Y = tf.placeholder(tf.float32, shape=[None, classes], name='one_hot_labels')

            with tf.variable_scope('conv_layer'):
                self.conv1 = tf.layers.conv2d(self.X,
                                              filters=256,
                                              kernel_size=9,
                                              kernel_initializer=tf.contrib.layers.xavier_initializer(),
                                              activation=tf.nn.relu,
                                              name='conv1')  # [N, 20, 20, 256] for 28x28 inputs

            # Capsule Layer Out
            caps_out = capsule_layer_prediction(in_tensor=self.conv1,
                                                num_iters=num_iters,
                                                next_D=pred_vec_len,
                                                next_capsules=classes)
            # Explicit reshape instead of a bare tf.squeeze: squeeze drops the
            # batch axis as well when N == 1 and loses the static shape.
            self.pred = tf.reshape(caps_out,
                                   shape=[-1, classes, pred_vec_len],
                                   name='distributed_prediction')  # [N, classes, pred_vec_len]

            with tf.variable_scope('masking'):
                # Keep only the capsule selected by the one-hot vector in Y.
                self.masked_pred = tf.matmul(self.pred, tf.reshape(self.Y, shape=[-1, classes, 1]),
                                             transpose_a=True, name='masked_pred')  # [N, 16, 10] x [N, 10, 1] = [N, 16, 1]
                # Capsule lengths: [N, classes, pred_vec_len] --> [N, classes]
                self.pred_length = tf.sqrt(tf.reduce_sum(tf.square(self.pred) + 1e-9, axis=2, keep_dims=False))  

            # Margin loss. Note that self.m_plus / self.m_minus hold the loss
            # term tensors, not the scalar constants passed to __init__.
            # [N, classes] * [N, classes] = [N, classes]
            self.m_plus = self.Y * tf.square(tf.maximum(0., m_plus - self.pred_length))
            self.m_minus = lambda_val * (1 - self.Y) * tf.square(tf.maximum(0., self.pred_length - m_minus))
            self.margin_loss = tf.reduce_mean(tf.reduce_sum(self.m_plus + self.m_minus, axis=-1), name='margin_loss')

            with tf.variable_scope('decoder'):
                # [N, pred_vec_len, 1] --> [N, pred_vec_len]
                decoder_inp = tf.reshape(self.masked_pred, 
                                         shape=[-1, pred_vec_len])
                self.fc1 = tf.layers.dense(decoder_inp,
                                           units=512,
                                           activation=tf.nn.relu,
                                           name='fc1')
                self.fc2 = tf.layers.dense(self.fc1,
                                           units=1024,
                                           activation=tf.nn.relu,
                                           name='fc2')

                self.pred_X = tf.layers.dense(self.fc2,
                                              units=inp_dim * inp_dim,
                                              activation=tf.nn.sigmoid,
                                              name='pred_X')  # [N, inp_dim*inp_dim]

                self.pred_image = tf.reshape(self.pred_X, shape=[-1, inp_dim, inp_dim, 1])
            
            # Backward
            # NOTE(review): the squared error is summed over the batch as well
            # as over pixels, while the margin loss is a batch mean, so the
            # relative weight of the reconstruction term grows with the batch
            # size. Left unchanged to keep training comparable with the
            # existing checkpoint.
            self.reconstruction_loss = tf.reduce_sum(tf.square(tf.reshape(self.X, shape=[-1, inp_dim * inp_dim]) - self.pred_X),
                                                     name='reconstruction_loss')

            self.loss = tf.add(self.margin_loss, 0.0005 * self.reconstruction_loss, name='loss')
            self.optimizer = tf.train.AdamOptimizer(learning_rate=lr)
            self.global_step = tf.Variable(0, name='global_step', trainable=False)
            self.train_op = self.optimizer.minimize(self.loss, global_step=self.global_step)
            
            # function variables/ops:
            # argmax over axis 1 already yields shape [N]; the former extra
            # tf.squeeze only turned a batch of one into a scalar.
            # [N, classes] --> [N]
            self.predictions = tf.argmax(self.pred_length, axis=1, output_type=tf.int32, name='predictions')
            # [N, classes] --> [N]
            self.true = tf.argmax(self.Y, axis=1, output_type=tf.int32, name='true_values')
            
            self.correct = tf.cast(tf.equal(self.true, self.predictions), dtype=tf.float32)
            self.acc = tf.reduce_mean(self.correct) * 100  # accuracy in percent

            # meta variables:
            self.tvars = tf.trainable_variables()

    def predict(self, xs, get_recon_images=False, sess=None):
        """Returns Predicted Number and, optionally, the Reconstructed Image.

        Args:
            xs: batch of input images to feed into X.
            get_recon_images: if True, also return the decoder output.
            sess: session to run in; defaults to the default session.
        Returns:
            The predicted classes, or [predicted classes, reconstructed images]
            when get_recon_images is True.
        """
        sess = sess or tf.get_default_session()
        predictions = sess.run(self.predictions, feed_dict={self.X: xs})
        if not get_recon_images:
            return predictions

        # The decoder input is masked with Y, so Y has to be fed. No labels
        # are available here, therefore mask with the predicted class.
        # (Fetching pred_image without feeding Y fails on the unfed placeholder.)
        one_hot = np.eye(self.classes, dtype=np.float32)[np.reshape(predictions, [-1])]
        images = sess.run(self.pred_image, feed_dict={self.X: xs, self.Y: one_hot})
        return [predictions, images]

    def accuracy(self, xs, ys, sess=None):
        """Predicts and returns accuracy (in percent) at current state."""
        sess = sess or tf.get_default_session()
        return sess.run(self.acc, feed_dict={self.X: xs, self.Y: ys})

    def learn(self, xs, ys, val_xs=None, val_ys=None, sess=None):
        """Train step, or evaluation when validation data is given.

        If both val_xs and val_ys are passed, no training happens: xs / ys are
        ignored and the accuracy on the validation batch is returned.
        Otherwise one optimizer step is run on (xs, ys).

        Returns:
            val_acc in evaluation mode, else the tuple (train_acc, loss).
        """
        sess = sess or tf.get_default_session()
        
        if val_xs is not None and val_ys is not None:
            val_acc = self.accuracy(val_xs, val_ys, sess=sess)
            return val_acc
        else:
            train_acc, loss, _ = sess.run([self.acc, self.loss, self.train_op], feed_dict={self.X: xs, self.Y: ys})
            return train_acc, loss

# Train

In [10]:
## Training:
# MNIST splits used below:
#   55,000 data points of training data (mnist.train),
#   10,000 points of test data (mnist.test)
#   and 5,000 points of validation data (mnist.validation).
tf.reset_default_graph()

# Pass the values of the hyperparameter cell explicitly; before, they were
# defined but never handed to the network (only the defaults were used).
caps_net = CapsNet(lr=lr,
                   num_iters=routing_iter,
                   m_plus=m_plus,
                   m_minus=m_minus,
                   lambda_val=lambda_val)
print('Network Built..')

# Only trainable variables are checkpointed, so global_step and the optimizer
# state start fresh on every resume.
tvars = tf.trainable_variables()
saver = tf.train.Saver(var_list=tvars)

eval_batch_size = 250

# avg. acc over the last `save_after` train steps, validation acc to file.
# The `with` block closes the log even if training is interrupted.
with open('avg_log_{}.csv'.format(train_round), 'w') as log_file, tf.Session() as sess:
    log_file.write('step,avg_train_acc,step_val_acc\n')
    sess.run(tf.global_variables_initializer())

    if resume:
        saver.restore(sess, ckpt)
        print("RESUMED..")

    if do_test_first:
        ta = 0
        total_steps = len(mnist.test.images) // eval_batch_size
        for step in range(1, total_steps + 1):
            test_x, test_y = mnist.test.next_batch(eval_batch_size)
            test_batch_accuracy = caps_net.accuracy(test_x, test_y, sess=sess)
            ta += test_batch_accuracy
            print('{},{:>3.3f}'.format(step, test_batch_accuracy))
        # Printed inside the branch: `ta` and `total_steps` only exist when
        # the test evaluation has actually run.
        print('Test Accuracy: ', ta / total_steps)

    print('*'*90)
    print('Training..learning_rate = {}, total_epochs = {}'.format(lr, epoch))
    total_steps_approx = 500 * epoch  # 500 is a rough steps-per-epoch figure
    # One pass over the validation set per checkpoint (was 200 batches of 250,
    # i.e. ten passes over the 5,000 validation images).
    val_batches = max(1, len(mnist.validation.images) // eval_batch_size)
    ta = 0
    for step in range(1, total_steps_approx + 1):

        # Train on every step. Previously every `save_after`-th batch was
        # fetched but skipped, while the logged average still divided by 500.
        batch_x, batch_y = mnist.train.next_batch(batch_size, shuffle=True)
        train_accuracy, train_loss = caps_net.learn(batch_x, batch_y, sess=sess)
        print('{},{:>3.3f},{:>3.3f}'.format(step, train_accuracy, train_loss))
        ta += train_accuracy

        if step % save_after == 0:
            saver.save(sess, ckpt)
            print("SAVED..")

            va = 0
            for _ in range(val_batches):
                val_x, val_y = mnist.validation.next_batch(eval_batch_size, shuffle=True)
                va += caps_net.accuracy(val_x, val_y, sess=sess)

            print()
            log_file.write('{},{:>3.3f},{:>3.3f}\n'.format(step, ta / save_after, va / val_batches))
            log_file.flush()
            ta = 0

Learning Rate=0.0001, m_plus=0.9, m_minus=0.1, lambda_val=0.5
Network Built..
INFO:tensorflow:Restoring parameters from capsnet.ckpt
RESUMED..
1,98.800
2,99.200
3,99.600
4,98.800
5,99.200
6,99.600
7,98.800
8,98.800
9,99.600
10,99.600
11,99.600
12,99.600
13,98.000
14,98.400
15,99.600
16,98.000
17,100.000
18,99.600
19,99.600
20,99.600
21,99.600
22,98.800
23,99.600
24,99.600
25,98.800
26,98.400
27,99.200
28,98.800
29,100.000
30,98.000
31,99.200
32,99.600
33,99.600
34,99.200
35,99.600
36,98.400
37,98.400
38,99.600
39,100.000
40,99.200
Test Accuracy:  99.1900085449
******************************************************************************************
Training..learning_rate = 0.0001, total_epochs = 3
1,100.000,0.167
2,100.000,0.176
3,100.000,0.156
4,100.000,0.172
5,100.000,0.164
6,100.000,0.154
7,100.000,0.175
8,100.000,0.172
9,100.000,0.165
10,100.000,0.160
11,100.000,0.174
12,100.000,0.175
13,100.000,0.157
14,100.000,0.158
15,100.000,0.150
16,100.000,0.165
17,100.000,0.179
18,100.000,

423,100.000,0.163
424,100.000,0.165
425,100.000,0.180
426,100.000,0.177
427,100.000,0.161
428,100.000,0.159
429,100.000,0.172
430,100.000,0.171
431,100.000,0.153
432,100.000,0.176
433,100.000,0.175
434,100.000,0.166
435,100.000,0.176
436,100.000,0.169
437,100.000,0.160
438,100.000,0.151
439,100.000,0.157
440,100.000,0.162
441,100.000,0.171
442,100.000,0.169
443,100.000,0.157
444,100.000,0.180
445,100.000,0.169
446,100.000,0.168
447,100.000,0.162
448,100.000,0.173
449,100.000,0.160
450,100.000,0.157
451,100.000,0.156
452,100.000,0.163
453,100.000,0.158
454,100.000,0.154
455,100.000,0.183
456,100.000,0.175
457,100.000,0.151
458,100.000,0.160
459,100.000,0.155
460,100.000,0.176
461,100.000,0.159
462,100.000,0.156
463,100.000,0.166
464,100.000,0.171
465,100.000,0.165
466,100.000,0.169
467,100.000,0.157
468,100.000,0.155
469,100.000,0.168
470,100.000,0.163
471,100.000,0.173
472,100.000,0.162
473,100.000,0.160
474,100.000,0.170
475,100.000,0.155
476,100.000,0.161
477,100.000,0.166
478,100.00

879,100.000,0.168
880,100.000,0.152
881,100.000,0.182
882,100.000,0.163
883,100.000,0.166
884,100.000,0.169
885,100.000,0.171
886,100.000,0.156
887,100.000,0.155
888,100.000,0.156
889,100.000,0.170
890,100.000,0.168
891,100.000,0.160
892,100.000,0.154
893,100.000,0.160
894,100.000,0.172
895,100.000,0.156
896,100.000,0.143
897,100.000,0.166
898,100.000,0.180
899,100.000,0.164
900,100.000,0.165
901,100.000,0.155
902,100.000,0.160
903,100.000,0.163
904,100.000,0.172
905,100.000,0.174
906,100.000,0.151
907,100.000,0.171
908,100.000,0.159
909,100.000,0.164
910,100.000,0.170
911,100.000,0.166
912,100.000,0.152
913,100.000,0.168
914,100.000,0.174
915,100.000,0.157
916,100.000,0.165
917,100.000,0.165
918,100.000,0.163
919,100.000,0.160
920,100.000,0.164
921,100.000,0.156
922,100.000,0.168
923,100.000,0.151
924,100.000,0.149
925,100.000,0.163
926,100.000,0.158
927,100.000,0.164
928,100.000,0.152
929,100.000,0.164
930,100.000,0.162
931,100.000,0.165
932,100.000,0.163
933,100.000,0.156
934,100.00

1318,100.000,0.163
1319,100.000,0.168
1320,100.000,0.168
1321,100.000,0.173
1322,100.000,0.170
1323,100.000,0.168
1324,99.219,0.146
1325,100.000,0.163
1326,100.000,0.173
1327,100.000,0.164
1328,100.000,0.187
1329,100.000,0.147
1330,100.000,0.146
1331,100.000,0.166
1332,100.000,0.170
1333,100.000,0.164
1334,100.000,0.160
1335,100.000,0.157
1336,100.000,0.169
1337,100.000,0.150
1338,100.000,0.166
1339,100.000,0.164
1340,100.000,0.174
1341,100.000,0.175
1342,100.000,0.171
1343,100.000,0.172
1344,100.000,0.176
1345,100.000,0.143
1346,100.000,0.165
1347,100.000,0.158
1348,100.000,0.156
1349,100.000,0.146
1350,100.000,0.175
1351,100.000,0.168
1352,100.000,0.173
1353,100.000,0.161
1354,100.000,0.163
1355,100.000,0.163
1356,100.000,0.158
1357,100.000,0.173
1358,100.000,0.160
1359,100.000,0.165
1360,100.000,0.163
1361,100.000,0.174
1362,100.000,0.169
1363,100.000,0.158
1364,100.000,0.169
1365,100.000,0.146
1366,100.000,0.172
1367,100.000,0.162
1368,100.000,0.159
1369,100.000,0.172
1370,100.000,