In [1]:
import tensorflow as tf
import numpy as np


In [2]:
class Mnist:
    """Sequential reader for a pair of MNIST IDX files (images + labels).

    The constructor opens ``train-images.idx3-ubyte`` / ``train-labels.idx1-ubyte``
    (or the ``t10k-*`` pair when ``train`` is False) found under ``path``, reads
    the big-endian uint32 header fields of both files and prints them.  Images
    and labels are then consumed in file order by ``get_batch``.

    The object keeps both files open; call ``close()`` (or use the instance as
    a context manager) once the data has been read.
    """

    def __init__(self, path, train=True):
        """Open the image/label files below ``path`` and parse their headers.

        Args:
            path: directory containing the IDX files.
            train: True for the ``train-*`` files, False for the ``t10k-*`` files.

        Raises:
            OSError: if a file cannot be opened.
            ValueError: if a header is shorter than expected.
        """
        if train:
            image_file_name = "train-images.idx3-ubyte"
            label_file_name = "train-labels.idx1-ubyte"
        else:
            image_file_name = "t10k-images.idx3-ubyte"
            label_file_name = "t10k-labels.idx1-ubyte"

        self.f = None
        self.flab = None
        try:
            self.f = open("{}//{}".format(path, image_file_name), "rb")
            magic_number, number_of_images, image_rows, image_columns = [
                self._read_uint32(self.f) for _ in range(4)]
            print("Images magic no: {},  No of images: {}, Image rows: {}, Image cols: {}".\
              format(magic_number, number_of_images, image_rows, image_columns))
            # Image geometry comes from the header instead of being hard-coded.
            self.image_rows = image_rows
            self.image_columns = image_columns

            self.flab = open("{}//{}".format(path, label_file_name), "rb")
            magic_number, number_of_items = [self._read_uint32(self.flab) for _ in range(2)]
            print("Labels magic no: {}, No of items: {}".format(magic_number, number_of_items))
        except Exception:
            # Do not leak a handle when the second open or a header read fails.
            self.close()
            raise

    @staticmethod
    def _read_uint32(handle):
        """Read one big-endian unsigned 32-bit header field from ``handle``."""
        raw = handle.read(4)
        if len(raw) != 4:
            raise ValueError("Unexpected end of file while reading an IDX header")
        return int.from_bytes(raw, "big")

    def close(self):
        """Close the image and label files; safe to call more than once."""
        for handle in (self.f, self.flab):
            if handle is not None and not handle.closed:
                handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def _get_next(self):
        """Return the next ``(image, label)`` pair, or ``("", "")`` at end of file.

        The image is an integer array of shape (image_rows, image_columns) with
        raw pixel values 0..255; the label is a Python int.

        Raises:
            ValueError: if the image record is truncated or its label is missing.
        """
        n_pixels = self.image_rows * self.image_columns
        val = self.f.read(n_pixels)
        if len(val) == 0:
            return("", "")
        if len(val) != n_pixels:
            raise ValueError("Truncated image record: got {} of {} bytes".format(len(val), n_pixels))

        lbl = self.flab.read(1)
        if len(lbl) != 1:
            raise ValueError("Label file ended before the image file")

        # One vectorised decode instead of a per-byte hex-string round trip.
        img = np.frombuffer(val, dtype=np.uint8).astype(int)
        return(np.reshape(img, (self.image_rows, self.image_columns)), lbl[0])

    def get_batch(self, batch_size=20, scale=True):
        """Read up to ``batch_size`` consecutive samples.

        Args:
            batch_size: number of rows in the returned arrays.
            scale: when True, pixel values are divided by 255.

        Returns:
            (images, labels, one_hot_labels) with shapes
            (batch_size, rows, cols) float, (batch_size,) int and
            (batch_size, 10) int.  If the files run out early, the remaining
            rows are left as zeros (label 0, all-zero one-hot row).
        """
        im_lst = np.zeros(shape=[batch_size, self.image_rows, self.image_columns], dtype="float")
        lbl_lst = np.zeros(shape=[batch_size], dtype="int")
        lbl_one_hot_lst = np.zeros(shape=[batch_size, 10], dtype="int")
        for i in range(batch_size):
            im, lbl = self._get_next()

            if len(im) == 0:
                # End of file: every further read would be empty as well.
                break
            if scale:
                im = im/255.
            im_lst[i,...] = im
            lbl_lst[i] = lbl
            lbl_one_hot_lst[i,lbl] = 1
        return im_lst, lbl_lst, lbl_one_hot_lst


# Helper functions

In [3]:
def get_s(b_prior, u_hat):
    """Weight the prediction vectors by their coupling coefficients and sum them.

    The coupling coefficients are a softmax of ``b_prior`` along axis 2; the
    weighted ``u_hat`` is then summed over axis 1.

    Returns:
        (s, c): the summed tensor and the coupling coefficients.
    """
    coupling = tf.nn.softmax(b_prior, dim=2)
    weighted_votes = tf.multiply(coupling, u_hat)
    return tf.reduce_sum(weighted_votes, axis=1), coupling

#Squash
def get_v(s, epsilon=1e-9):
    """Squash ``s`` along axis 2 so that each vector has a length in [0, 1).

    Computes ``v = s * ||s|| / (1 + ||s||^2)``, which is algebraically the same
    as ``||s||^2 * s / ((1 + ||s||^2) * ||s||)`` but contains no division by
    ``||s||``, so an all-zero vector maps to zero instead of NaN.

    Args:
        s: tensor whose axis 2 holds the vector components.
        epsilon: small constant added under the square root so that the
            gradient of the norm stays finite when a vector is exactly zero.

    Returns:
        Tensor of the same shape as ``s``.
    """
    norm_s_2 = tf.reduce_sum(tf.square(s), axis=2, keep_dims=True)
    # sqrt has an infinite derivative at 0; epsilon keeps the backward pass finite.
    norm_s = tf.sqrt(norm_s_2 + epsilon)
    v = tf.multiply(s, norm_s / (1. + norm_s_2))

    return(v)

def route(n_iter, batch_item_cnt, v, b_prior, u_hat, aggr, enable_b_prior_update):
    """Run ``n_iter`` routing updates of ``b_prior`` driven by one batch item.

    Each iteration recomputes ``v`` from the current ``b_prior``, takes the dot
    product (over the last axis) between ``v`` and ``u_hat`` for the batch item
    at index ``batch_item_cnt``, multiplies it by ``enable_b_prior_update`` and
    adds the result to ``b_prior``.  ``u_hat`` is wrapped in
    ``tf.stop_gradient`` so no gradient flows through the routing updates.

    The incoming ``v`` and ``aggr`` are only returned unchanged when
    ``n_iter`` is 0; otherwise they are overwritten by the loop.

    Returns:
        (v, b_prior, aggr) after the last iteration.
    """
    frozen_u_hat = tf.stop_gradient(u_hat)

    for step in range(n_iter):
        with tf.name_scope("Iteration_{}".format(step)):
            s, _coupling = get_s(b_prior, frozen_u_hat)
            v = get_v(s)

            # Agreement between the outputs and the predictions of one batch item.
            item_v = v[batch_item_cnt, :, :]
            item_u_hat = frozen_u_hat[batch_item_cnt, :, :]
            agreement = tf.reduce_sum(tf.multiply(item_v, item_u_hat), axis=-1, keep_dims=True)

            # Add a leading axis of size 1 before combining with b_prior.
            agreement = tf.expand_dims(agreement, axis=0)
            aggr = tf.multiply(agreement, enable_b_prior_update)
            b_prior = tf.add(b_prior, aggr)

    return (v, b_prior, aggr)

def update_prior(batch_item_cnt, v, b_prior, u_hat, aggr, enable_b_prior_update):
    """Loop body: apply two routing iterations for one batch item.

    Returns the loop variables in the order they were received, with the item
    counter incremented by one and ``v``, ``b_prior`` and ``aggr`` replaced by
    the values produced by ``route``.
    """
    routed_v, routed_b_prior, routed_aggr = route(
        2, batch_item_cnt, v, b_prior, u_hat, aggr, enable_b_prior_update)
    next_item = tf.add(batch_item_cnt, 1)
    return [next_item, routed_v, routed_b_prior, u_hat, routed_aggr, enable_b_prior_update]

# Build CapsNet architecture
The primary capsule layer consists of 32 capsule channels, each producing a (6, 6) grid of 8D vectors.

In [4]:
def build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False):
    """Build the CapsNet graph and collect its tensors and ops in a dict.

    Args:
        batch_size: fixed batch dimension of the ``X`` / ``Y`` placeholders; it
            is also the number of iterations of the routing while-loop.
        is_train: when True, the routing while-loop, the ``b_prior`` assign op
            and the Adam optimizer are added.  When False they are omitted and
            the capsule outputs ``v`` computed from the stored ``b_prior`` feed
            the decoder, the losses and the accuracy.
        add_summaries: when True, loss/accuracy scalars (and, in training mode,
            one histogram of routed coupling coefficients per class) are
            registered and the merged op is exposed as "summaries_merged".
        print_shape: when True, print the shape of every end point (or its
            type when it has no ``shape`` attribute).

    Returns:
        (graph, end_points): the new ``tf.Graph`` and a dict mapping names to
        placeholders, intermediate tensors, losses and ops.
    """

    graph = tf.Graph()
    end_points = dict()

    with graph.as_default():

        X = tf.placeholder(shape=[batch_size, 28, 28, 1], name="X", dtype=tf.float32)
        end_points["X"] = X
        Y = tf.placeholder(shape=[batch_size], name="Y", dtype=tf.int32)
        end_points["Y"] = Y
        Y_one_hot = tf.one_hot(indices=Y, depth=10)
        end_points["Y_one_hot"] = Y_one_hot

        # Scalar multiplier applied to every routing update (0 disables the update).
        enable_b_prior_update = tf.placeholder(shape = [], dtype=tf.float32)
        end_points["enable_b_prior_update"] = enable_b_prior_update

        #First Layer
        with tf.name_scope("Conv_Layer_1"):

            conv_1 = tf.contrib.layers.conv2d(inputs=X,
                                              num_outputs=256,
                                              kernel_size=9,
                                              stride=1, padding="VALID",
                                              activation_fn=tf.nn.relu,
                                              weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
                                              biases_initializer=tf.zeros_initializer())

            end_points["conv_layer_1_act"] = conv_1

        # 32 independent (non weight-sharing) stride-2 convolutions with 8 outputs each.
        capsules = list()
        with tf.variable_scope("Capsule_Layer"):
            for n in range(32):
                cap_1 = tf.contrib.layers.conv2d(inputs=conv_1,
                                                 num_outputs=8, kernel_size=9,
                                                 activation_fn=None,
                                                 stride=2, padding="VALID" ,
                                                 weights_initializer=tf.random_uniform_initializer(minval=-0.1,
                                                                                                   maxval=0.1),
                                                 biases_initializer=tf.zeros_initializer())

                end_points["capsule_{}_act".format(n)] = cap_1
                capsules.append(tf.expand_dims(cap_1, axis=3))

        with tf.name_scope("Transformation"):
            u = tf.concat(capsules, axis=3)
            end_points["u_b_r"] = u

            # Flatten to 1152 8-D vectors per sample, then add two singleton
            # axes so that u broadcasts against W ([1, 1152, 10, 16, 8]).
            u = tf.reshape(u, shape=[-1, 1152, 8])
            u = tf.expand_dims(tf.expand_dims(u, axis=2), axis=2)
            end_points["u"] = u

            W = tf.get_variable(shape=[1, 1152, 10, 16, 8], name="W_t",
                                initializer=tf.random_uniform_initializer(minval=-0.1, maxval=0.1))
            end_points["W_t"] = W

            # Matrix-vector product written as broadcast multiply + sum over the 8-D axis.
            u_hat =tf.reduce_sum(tf.multiply(u, W), axis=4)
            end_points["u_hat"] = u_hat

            # Routing logits live in a non-trainable variable that is updated by tf.assign below.
            b_prior= tf.get_variable(shape=[1, 1152, 10, 1], name="b_prior", initializer=tf.zeros_initializer(),
                                      trainable=False)
            end_points["b_prior"] = b_prior

            aggr = tf.constant(value=0, shape=[1, 1152, 10, 1], dtype=tf.float32, name="aggr")

            s, c = get_s(b_prior, u_hat)
            end_points["s"],  end_points["c"] = s, c

            v = get_v(s)
            end_points["v"] = v

        if is_train:
            with tf.name_scope("Routing"):

                #Routing
                batch_item_cnt = tf.constant(0)
                end_points["batch_item_cnt_before"] = batch_item_cnt

                # Loop condition: keep going until every batch item has been visited.
                end_of_batch = lambda batch_item_cnt, v, b_prior, u_hat, aggr, enable_b_prior_update: \
                                        tf.less(batch_item_cnt, batch_size)

                route_op = tf.while_loop(body=update_prior, cond=end_of_batch,
                                         loop_vars=[batch_item_cnt, v, b_prior, u_hat, aggr, enable_b_prior_update])

                end_points["route_op"] = route_op

                [batch_item_cnt, _, b_prior_updated, u_hat, aggr, enable_b_prior_update] = route_op
                end_points["aggr"] = aggr
                end_points["batch_item_cnt_after"] = batch_item_cnt

                # Persist the routed logits, then recompute s / c / v from the assigned value.
                b_prior_op = tf.assign(b_prior, b_prior_updated)

                s_r, c_r = get_s(b_prior_op, u_hat)
                v_r = get_v(s_r)
                end_points["s_routed"] = s_r
                end_points["c_routed"] = c_r
                end_points["b_prior_routed"] = b_prior_op
                end_points["u_hat_routed"] = u_hat
                end_points["v_routed"] = v_r
                if add_summaries:
                    c_squeezed = tf.squeeze(c_r)
                    for j in range(10):
                        tf.summary.histogram(values=c_squeezed[:,j], name="c{}".format(j))
        else:
            # No routing ops are built outside training.  Without this fallback
            # v_r would be undefined and the code below would raise NameError.
            v_r = v

        with tf.name_scope("Prepare_for_FC"):
            Y_one_hot_ex = tf.expand_dims(Y_one_hot, axis=1)
            end_points["Y_one_hot_ex"] = Y_one_hot_ex

            # The one-hot row vector times v_r selects the capsule of the labelled class.
            v_masked = tf.squeeze(tf.matmul(Y_one_hot_ex, v_r))
            end_points["v_masked"] = v_masked

        with tf.name_scope("Fully_connected"):
            fc_1 = tf.contrib.layers.fully_connected(inputs=v_masked, num_outputs=512,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())

            end_points["fc_1"] = fc_1

            fc_2 = tf.contrib.layers.fully_connected(inputs=fc_1, num_outputs=1024,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_2"] = fc_2

            fc_out = tf.contrib.layers.fully_connected(inputs=fc_2, num_outputs=784,
                                                     activation_fn=tf.nn.sigmoid,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_out"] = fc_out

        with tf.name_scope("Calculate_Losses"):
            lambd_a = 0.5      # weight of the margin term for absent classes
            scaler = 0.0005    # weight of the reconstruction term

            v_norm = tf.norm(v_r, axis=2)
            end_points["v_norm"] = v_norm

            margin_loss_present = tf.multiply(tf.pow(tf.maximum(0., 0.9 - v_norm), 2), Y_one_hot)
            end_points["margin_loss_present"] = margin_loss_present

            margin_loss_not_present = lambd_a * tf.multiply(tf.pow(tf.maximum(0., v_norm - 0.1), 2), (1 - Y_one_hot))
            end_points["margin_loss_not_present"] = margin_loss_not_present

            total_margin_loss = tf.reduce_sum(tf.add(margin_loss_present, margin_loss_not_present), axis=1)
            end_points["total_margin_loss"] = total_margin_loss

            reconstruction_loss = tf.multiply(scaler,
                                              tf.reduce_sum(
                                                  tf.squared_difference(
                                                      tf.reshape(X, shape=[-1, 784]), fc_out), axis=1))
            end_points["reconstruction_loss"] = reconstruction_loss

            loss = tf.reduce_mean(tf.add(total_margin_loss, reconstruction_loss))
            end_points["loss"] = loss

        with tf.name_scope("Calculate_Accuracy"):

            # The predicted class is the capsule with the longest output vector.
            predicted_labels = tf.argmax(v_norm, axis=1, output_type=tf.int32)
            end_points["predicted_labels"] = predicted_labels

            no_of_corrects = tf.reduce_sum(tf.cast(tf.equal(predicted_labels, Y), tf.int32))

            accuracy = no_of_corrects/batch_size
            end_points["accuracy"] = accuracy

        if add_summaries:
            tf.summary.scalar("loss_avg", loss)
            tf.summary.scalar("accuracy", accuracy)

        if is_train:
            optimizer = tf.train.AdamOptimizer()
            grads = optimizer.compute_gradients(loss)
            end_points["grads"] = grads
            optimize = optimizer.apply_gradients(grads)
            end_points["optimize"] = optimize

        if add_summaries:
            summaries_merged = tf.summary.merge_all()
            end_points["summaries_merged"] = summaries_merged

        init = tf.global_variables_initializer()
        end_points["init"] = init

        if print_shape:
            for k in end_points.keys():
                k_shape="NA"
                try:
                    k_shape = str(end_points[k].shape)
                except AttributeError:
                    # Lists (route_op, grads) and ops have no shape attribute.
                    k_shape = type(end_points[k])

                print("\r {} shape={}".format(k, k_shape))

        return(graph, end_points)

# Training the capsule network

In [None]:
# Load the full training file once, then carve a shuffled hold-out batch out of it.
batch_size = 200
n_samples = 60000

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(n_samples, scale=True)
# Everything is in memory now; release the two file handles.
mnist.f.close()
mnist.flab.close()

all_idxs = np.arange(0, n_samples)
np.random.seed(10)  # fixed seed so the train / hold-out split is reproducible
np.random.shuffle(all_idxs)

# The last `batch_size` shuffled samples are held out for evaluation.
number_of_training_samples = n_samples - batch_size

# Integer arithmetic: np.split needs a whole number of equally sized sections.
number_of_splits, leftover = divmod(number_of_training_samples, batch_size)
assert leftover == 0, "training set size must be a multiple of batch_size"

train_idxs, test_idxs = all_idxs[0:number_of_training_samples], all_idxs[number_of_training_samples::]

batches = np.split(train_idxs, number_of_splits)

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000


In [None]:
tf.reset_default_graph()

graph, end_points = build_graph(batch_size, is_train=True, add_summaries=True, print_shape=False)
n_epochs = 5

with tf.Session(graph=graph) as sess:
    sess.run(end_points["init"])

    saver_train = tf.summary.FileWriter("..//log//train", graph=graph)
    saver_test = tf.summary.FileWriter("..//log//test", graph=graph)

    # Save the trainable weights plus the non-trainable routing logits.
    # list.extend() returns None, which made the Saver fall back to *all*
    # variables; build the list by concatenation instead.
    other_variables = [v for v in graph.get_collection("variables") if v.name=='b_prior:0']
    trainable_variables = graph.get_collection("trainable_variables")
    variables_to_restore = trainable_variables + other_variables

    enable_b_prior_update = end_points["enable_b_prior_update"]
    model_saver = tf.train.Saver(variables_to_restore)

    # The hold-out batch never changes, so build it once.
    test_data = np.expand_dims(data[test_idxs,:], -1)
    test_lbls = lbls[test_idxs]

    loss_report_counter = 0
    nan_detected = False

    try:
        for epoch in range(n_epochs):

            for batch_idxs in batches:
                train_data = np.expand_dims(data[batch_idxs,:], -1)
                train_lbls = lbls[batch_idxs]

                feed_dict = {end_points["X"]:train_data, end_points["Y"]:train_lbls, enable_b_prior_update:1}
                _, loss_train, accuracy_train, grads, summaries_train = sess.run([
                                                            end_points["optimize"],
                                                            end_points["loss"],
                                                            end_points["accuracy"],
                                                            end_points["grads"],
                                                            end_points["summaries_merged"]], feed_dict=feed_dict)

                loss_report_counter += 1

                # Every 10th step: evaluate on the hold-out batch with routing updates disabled.
                if loss_report_counter % 10 == 0:
                    feed_dict = {end_points["X"]:test_data, end_points["Y"]:test_lbls, enable_b_prior_update:0}
                    loss_test, accuracy_test, summaries_test = sess.run([end_points["loss"],
                                                                         end_points["accuracy"],
                                                                         end_points["summaries_merged"]], feed_dict=feed_dict)

                    print("--Epoch:{} Loss train:{} Accuracy train:{}  Accuracy test:{}--".format(epoch,
                                                                                        loss_train,
                                                                                        accuracy_train,
                                                                                        accuracy_test))
                    saver_train.add_summary(summaries_train, global_step=loss_report_counter)
                    saver_test.add_summary(summaries_test, global_step=loss_report_counter)

                    # grads[0][0] is the gradient array of the first (gradient, variable) pair.
                    if np.any(np.isnan(grads[0][0])):
                        print("Nan detected")
                        nan_detected = True
                        break

            # A bare `break` only leaves the batch loop; stop the epochs as well.
            if nan_detected:
                break
    finally:
        # Flush pending summary events to disk.
        saver_train.close()
        saver_test.close()

    print("Model saves under path {}".format(model_saver.save(sess,save_path="..//saved_model//capsnet")))


--Epoch:0 Loss train:0.5697051882743835 Accuracy train:0.355  Accuracy test:0.365--
--Epoch:0 Loss train:0.31487447023391724 Accuracy train:0.74  Accuracy test:0.745--
--Epoch:0 Loss train:0.20486176013946533 Accuracy train:0.87  Accuracy test:0.85--
--Epoch:0 Loss train:0.15141722559928894 Accuracy train:0.885  Accuracy test:0.895--
--Epoch:0 Loss train:0.12692837417125702 Accuracy train:0.93  Accuracy test:0.915--
--Epoch:0 Loss train:0.1214141920208931 Accuracy train:0.905  Accuracy test:0.94--
--Epoch:0 Loss train:0.09358124434947968 Accuracy train:0.945  Accuracy test:0.945--
--Epoch:0 Loss train:0.07398714125156403 Accuracy train:0.975  Accuracy test:0.935--
--Epoch:0 Loss train:0.08707387745380402 Accuracy train:0.94  Accuracy test:0.935--
--Epoch:0 Loss train:0.07883093506097794 Accuracy train:0.96  Accuracy test:0.945--
--Epoch:0 Loss train:0.0810299888253212 Accuracy train:0.955  Accuracy test:0.945--
--Epoch:0 Loss train:0.07164362818002701 Accuracy train:0.955  Accuracy tes

The paper mentions shifting images by two pixels, but that is not done here.<br/>
I also did not add noise for the reconstruction, in order to make sure the essentials work first.

## Unit testing for capsNET
Building the graph and getting the data 

In [130]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X:data, Y:lbls, enable_b_prior_update:1}

    #"Test output of conv 1"
    x, w, b, conv_1_act =  sess.run([X, graph.get_tensor_by_name('Conv/weights:0'),
                                     graph.get_tensor_by_name('Conv/biases:0'),
                                     end_points["conv_layer_1_act"]],
                                     feed_dict=feed_dict)

    # Reference for output pixel (0, 0), channel 0 of sample 0: the 9x9 input
    # patch times the kernel, plus the bias added ONCE, followed by ReLU.
    expected_act = max(0., np.sum(x[0, 0:9, 0:9, 0] * w[:, :, 0, 0]) + b[0])

    # Two-sided tolerance: a one-sided "<" would also accept a far too large activation.
    assert(abs(expected_act - conv_1_act[0, 0, 0, 0]) < 0.0001)
    assert(conv_1_act.shape == (2, 20, 20, 256))
    print("Test output of conv1 passed")
    

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
Test output of conv1 passed


In [131]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X:data, Y:lbls, enable_b_prior_update:1}

    #Test for capsule out values"
    #Capsule 8
    w, b, conv_1_act, capsule_7_act, capsule_2_act =  sess.run([graph.get_tensor_by_name('Capsule_Layer/Conv_7/weights:0'),
                                     graph.get_tensor_by_name('Capsule_Layer/Conv_7/biases:0'),
                                                 end_points["conv_layer_1_act"],
                                                 end_points["capsule_7_act"],
                                                 end_points["capsule_2_act"]],
                                                 feed_dict=feed_dict)

    # Reference for output position (0, 0) of sample 0: contract the
    # 9x9x256 input patch with the kernel, then add the bias ONCE
    # (adding it per kernel tap only works while the bias is all zeros).
    expected_caps = np.tensordot(conv_1_act[0, 0:9, 0:9], w, axes=3) + b
    assert(np.linalg.norm(expected_caps - capsule_7_act[0, 0, 0, :])<0.0005)

    assert(capsule_7_act.shape==(2, 6, 6, 8))

    #Test for not sharing the weights
    assert(np.linalg.norm(capsule_7_act[0, 0, 0, :] - capsule_2_act[0, 0, 0, :])>0.01)
print("Capsule 8 output test passed")
    

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
Capsule 8 output test passed


In [134]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X:data, Y:lbls, enable_b_prior_update:1}

    #Test the squash function before routing
    v, s, b_prior, W_t, capsule_7_act, u, u_hat = sess.run([end_points["v"],
                                                                 end_points["s"],
                                                                 end_points["b_prior"],
                                                                 end_points["W_t"],
                                                                 end_points["capsule_7_act"],
                                                                 end_points["u"],
                                                                 end_points["u_hat"]], feed_dict=feed_dict)

    #Test for the first one digit 5
    # Recompute u_hat for rows 252:288 of sample 0 and output class 5.
    u = np.squeeze(u[0,252:288,...], axis=1)
    W_t = np.squeeze(W_t[:, 252:288, 5, ...])

    u_hat_7_5 = np.sum(np.multiply(u, W_t), axis=2)

    assert(np.linalg.norm(u_hat_7_5 - u_hat[0,252:288, 5,...])<0.001)

    # Recompute the coupling coefficients of class 5 (softmax over the 10 classes).
    b_prior_all_10= np.squeeze(b_prior[0,:,...])
    b_prior_7_5 =np.squeeze(b_prior[0,:, 5,...])
    c_7_5 = np.exp(b_prior_7_5)/np.sum(np.exp(b_prior_all_10), axis=1)
    c_7_5 = np.expand_dims(c_7_5, axis=1)
    u_hat_5 =  np.squeeze(u_hat[0, :, 5,...])
    s_5 = np.sum(np.multiply(c_7_5, u_hat_5), axis=0)

    #Testing the vector for element 5
    assert(np.linalg.norm(s_5 - s[0, 5,...])<0.00001)

    #test squash
    v_5 = (np.power(np.linalg.norm(s_5), 2)/(1+np.power(np.linalg.norm(s_5), 2)))*(s_5/np.linalg.norm(s_5))
    # Compare the vectors themselves: a one-sided comparison of the two norms
    # would also pass for a wrong direction or for a too-small output.
    assert(np.linalg.norm(v_5 - v[0, 5, ...])<0.00001)

print("Squash test passed")

b_prior:0
Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
Squash test passed


In [144]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init, X, Y = end_points["init"], end_points["X"], end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X: data, Y: lbls, enable_b_prior_update: 1}

    # Routing test: run the while-loop once and inspect what it produced.
    b_prior_var = graph.get_tensor_by_name("b_prior:0")

    routing_fetches = [end_points[key] for key in ("route_op",
                                                   "v", "s", "c",
                                                   "v_routed", "s_routed", "c_routed",
                                                   "u_hat", "u_hat_routed",
                                                   "b_prior_routed", "aggr",
                                                   "batch_item_cnt_after")]
    (_, v_before, s_before, c_before,
     v_routed, s_routed, c_routed,
     u_hat, u_hat_routed,
     b_prior_routed, aggr, batch_item_count) = sess.run(routing_fetches, feed_dict=feed_dict)

    # The loop visited both batch items and moved the logits away from zero.
    assert batch_item_count == 2
    assert np.linalg.norm(b_prior_routed) != 0

    # The assign op must have written the routed logits back into the variable.
    b_prior = sess.run(end_points["b_prior"], feed_dict=feed_dict)

    assert np.linalg.norm(b_prior) != 0
    # u_hat passes through the loop unchanged.
    assert np.all(u_hat_routed==u_hat)

print("routing test passed")

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
routing test passed


In [155]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X:data, Y:lbls, enable_b_prior_update:1}
    #Test the losses
    # The fetched input is stored as x_val so the placeholder X is not shadowed.
    x_val,  margin_loss_present, margin_loss_not_present, total_margin_loss, \
    fc_out, reconstruction_loss, loss, v_norm, b_prior, v_out =\
                                                            sess.run([end_points[k]
                                                              for k in ["X", "margin_loss_present",
                                                                        "margin_loss_not_present",
                                                                        'total_margin_loss', "fc_out",
                                                                        'reconstruction_loss',
                                                                        'loss', "v_norm", "b_prior_routed", "v_routed"]], feed_dict=feed_dict)

    # Use the actual label of sample 0 instead of assuming that it is a 5,
    # and pick any other class as the "absent" one.
    present = int(lbls[0])
    absent = (present + 1) % 10

    # abs(): a one-sided "<" would also pass for an arbitrarily small loss value.
    assert(abs(margin_loss_present[0][present] -
               np.power(np.max([0, 0.9 - np.linalg.norm(v_out[0][present])]), 2))<0.00001)

    assert(np.linalg.norm(margin_loss_not_present[0][absent] -
                          0.5*np.power(np.max([0., np.linalg.norm(v_out[0][absent])-0.1]), 2))<0.00001)

    assert(np.linalg.norm(total_margin_loss[0] - np.sum(margin_loss_present[0]+ margin_loss_not_present[0]))<0.00001)

    assert(np.abs(np.sum(np.square(fc_out[0] - x_val[0].reshape(784,)))*0.0005- reconstruction_loss[0])<0.01)

    # Tolerance instead of exact float equality between the TF and NumPy means.
    assert(abs(loss - np.mean(np.add(reconstruction_loss, total_margin_loss)))<0.00001)

    print("All done")
# The `with` block has already closed the session; no explicit close is needed.

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
All done


In [161]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

# The loop counter must start from 0 on every run and end at batch_size.
with tf.Session(graph=graph) as sess:

    init, X, Y = end_points["init"], end_points["X"], end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    sess.run(init)

    feed_dict = {X: data, Y: lbls, enable_b_prior_update: 1}

    counter_fetches = [end_points["batch_item_cnt_after"],
                       end_points["batch_item_cnt_before"]]

    # Two consecutive runs with the same feed: the second one shows the reset.
    for attempt in range(2):
        batch_count_aft, batch_count_before = sess.run(counter_fetches, feed_dict=feed_dict)
        assert batch_count_before == 0
        assert batch_count_aft == 2

print("test to see if batch_update_count gets reset for the next batch passed")



Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
test to see if batch_update_count gets reset for the next batch passed


In [162]:
tf.reset_default_graph()
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]
    enable_b_prior_update = end_points["enable_b_prior_update"]

    feed_dict = {X:data, Y:lbls, enable_b_prior_update:1}
    sess.run(init)

    # Test that enable_b_prior_update=0 freezes b_prior.
    # First run with routing enabled: b_prior must move away from zero.
    v_routed, b_prior = sess.run([end_points["v_routed"], end_points["b_prior_routed"]], feed_dict=feed_dict)
    assert(np.linalg.norm(b_prior)!=0)

    # Second run with the update multiplier set to 0: the stored b_prior must
    # be exactly what the first run left behind.
    feed_dict[enable_b_prior_update] = 0
    v_routed_not_routed, b_prior_aft_not_routed = sess.run([end_points["v_routed"], end_points["b_prior"]], feed_dict=feed_dict)
    assert(np.array_equal(b_prior, b_prior_aft_not_routed))

    print("All done")
# The `with` block has already closed the session; no explicit close is needed.

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
All done
