In [1]:
import io
import os

import numpy as np
import tensorflow as tf


In [2]:
class Mnist:
    """Sequential reader for the raw MNIST IDX files.

    Reads ``*-images.idx3-ubyte`` and ``*-labels.idx1-ubyte`` from ``path`` and
    hands out images/labels in on-disk order. Both files are read fully into
    in-memory buffers and the OS file handles are closed immediately, so creating
    a fresh reader per epoch does not leak file descriptors. ``self.f`` and
    ``self.flab`` remain file-like objects positioned after their headers.

    Args:
        path: directory that contains the IDX files.
        train: True for the ``train-*`` split, False for the ``t10k-*`` split.
    """

    _N_CLASSES = 10

    def __init__(self, path, train=True):
        if train:
            image_file_name = "train-images.idx3-ubyte"
            label_file_name = "train-labels.idx1-ubyte"
        else:
            image_file_name = "t10k-images.idx3-ubyte"
            label_file_name = "t10k-labels.idx1-ubyte"

        # Image file header: magic, count, rows, cols (big-endian uint32 each).
        self.f = self._load(os.path.join(path, image_file_name))
        magic_number, number_of_images, image_rows, image_columns = [
            self._read_uint32(self.f) for _ in range(4)]
        print("Images magic no: {},  No of images: {}, Image rows: {}, Image cols: {}".format(
            magic_number, number_of_images, image_rows, image_columns))
        self.number_of_images = number_of_images
        self.image_rows = image_rows
        self.image_columns = image_columns

        # Label file header: magic, count (big-endian uint32 each).
        self.flab = self._load(os.path.join(path, label_file_name))
        magic_number, number_of_items = [self._read_uint32(self.flab) for _ in range(2)]
        print("Labels magic no: {}, No of items: {}".format(magic_number, number_of_items))

    @staticmethod
    def _load(file_path):
        """Read a whole file into an in-memory binary buffer and close the OS handle."""
        with open(file_path, "rb") as handle:
            return io.BytesIO(handle.read())

    @staticmethod
    def _read_uint32(stream):
        """Read one big-endian unsigned 32-bit integer (an IDX header field)."""
        return int.from_bytes(stream.read(4), "big")

    def _get_next(self):
        """Return the next ``(image, label)`` pair, or ``("", "")`` when exhausted.

        ``image`` is a ``(rows, cols)`` integer array of raw pixel values (0-255)
        and ``label`` is a Python int. A truncated trailing record counts as end of data.
        """
        n_pixels = self.image_rows * self.image_columns
        raw_image = self.f.read(n_pixels)
        raw_label = self.flab.read(1)
        if len(raw_image) < n_pixels or len(raw_label) < 1:
            return ("", "")
        image = np.frombuffer(raw_image, dtype=np.uint8).reshape(
            self.image_rows, self.image_columns).astype(int)
        return (image, int(raw_label[0]))

    def get_batch(self, batch_size=20, scale=True):
        """Read up to ``batch_size`` consecutive examples.

        Args:
            batch_size: maximum number of examples to return.
            scale: if True divide pixel values by 255 so they lie in [0, 1].

        Returns:
            ``(images, labels, one_hot)``: float array ``(n, rows, cols)``, int array
            ``(n,)`` and int array ``(n, 10)``. ``n == batch_size`` except at the end
            of the file, where ``n`` is the number of examples actually left (0 once
            exhausted). Previously the arrays were always zero-padded to
            ``batch_size``, so callers looping on the batch size never stopped.
        """
        im_lst = np.zeros(shape=[batch_size, self.image_rows, self.image_columns], dtype="float")
        lbl_lst = np.zeros(shape=[batch_size], dtype="int")
        lbl_one_hot_lst = np.zeros(shape=[batch_size, self._N_CLASSES], dtype="int")
        n_read = 0
        while n_read < batch_size:
            im, lbl = self._get_next()
            if len(im) == 0:  # end of data
                break
            if scale:
                im = im / 255.
            im_lst[n_read, ...] = im
            lbl_lst[n_read] = lbl
            lbl_one_hot_lst[n_read, lbl] = 1
            n_read += 1
        return im_lst[:n_read], lbl_lst[:n_read], lbl_one_hot_lst[:n_read]
    

        
        
    
    

In [11]:
def get_s(b_prior, u_hat):
    """Coupling-weighted sum of prediction vectors (one routing step).

    The coupling coefficients are a softmax of the routing logits ``b_prior``
    over axis 2. The digit-capsule input ``s`` is the sum over axis 1 of the
    coupling-weighted ``u_hat``. As called from ``build_graph``, ``b_prior`` is
    ``[1, 1152, 10, 1]`` and ``u_hat`` is ``[batch, 1152, 10, 16]``, giving
    ``s`` of shape ``[batch, 10, 16]``.

    Returns:
        ``(s, c)`` where ``c`` holds the coupling coefficients.
    """
    coupling = tf.nn.softmax(b_prior, dim=2)
    weighted_predictions = tf.multiply(coupling, u_hat)
    summed = tf.reduce_sum(weighted_predictions, axis=1)
    return (summed, coupling)


#Squash
def get_v(s, epsilon=1e-9):
    """Squash non-linearity: v = ||s||^2 / (1 + ||s||^2) * s / ||s||, along axis 2.

    As called from this notebook, ``s`` is ``[batch, 10, 16]`` (see ``get_s``),
    so axis 2 is the capsule-vector axis.

    The norm is computed as ``sqrt(sum(s^2) + epsilon)`` instead of ``tf.norm``.
    The previous form divided by ``||s||`` and relied on ``tf.norm``, whose
    gradient is undefined at 0. An all-zero capsule vector therefore produced
    NaN in the forward pass (0/0) and in the gradients.

    Args:
        s: tensor to squash.
        epsilon: small constant added under the square root for numerical safety.
    """
    squared_norm = tf.reduce_sum(tf.square(s), axis=2, keep_dims=True)
    safe_norm = tf.sqrt(squared_norm + epsilon)
    squash_factor = squared_norm / (1. + squared_norm)
    v = squash_factor * s / safe_norm

    return(v)


def route(n_iter, batch_item_cnt, v, b_prior, u_hat, aggr):
    """Run ``n_iter`` dynamic-routing iterations for a single batch item.

    Each iteration recomputes ``s``/``v`` from the current logits. It then adds
    the agreement between ``v`` and the prediction vectors of batch item
    ``batch_item_cnt`` (dot product over the last axis, expanded back to
    ``[1, ..., 1]``) to ``b_prior``. Only the final iteration uses ``u_hat``
    directly; earlier iterations use a ``tf.stop_gradient`` copy, so gradients
    reach ``u_hat`` only through the last iteration.

    With ``n_iter == 0`` the incoming ``v``, ``b_prior`` and ``aggr`` are returned unchanged.

    NOTE(review): ``b_prior`` has a leading dimension of 1 and is shared by all
    batch items, so the logits accumulate across the batch. The paper instead
    resets them per example.

    Returns:
        ``(v, b_prior, aggr)`` after the last iteration.
    """
    u_hat_stopped = tf.stop_gradient(u_hat)
    final_iteration = n_iter - 1

    for iteration in range(n_iter):
        predictions = u_hat if iteration == final_iteration else u_hat_stopped
        s, _ = get_s(b_prior, predictions)
        v = get_v(s)
        item_v = v[batch_item_cnt, :, :]
        item_predictions = predictions[batch_item_cnt, :, :]
        agreement = tf.reduce_sum(tf.multiply(item_v, item_predictions), axis=-1)
        agreement = tf.expand_dims(agreement, axis=0)
        aggr = tf.expand_dims(agreement, axis=3)
        b_prior = tf.add(b_prior, aggr)

    return (v, b_prior, aggr)

def update_prior(batch_item_cnt, v, b_prior, u_hat, aggr):
    """``tf.while_loop`` body: 3 routing iterations for one batch item, then advance the counter.

    Returns the five loop variables in the same order as ``loop_vars``.
    """
    routed_v, routed_prior, routed_aggr = route(3, batch_item_cnt, v, b_prior, u_hat, aggr)
    next_item = tf.add(batch_item_cnt, 1)
    return [next_item, routed_v, routed_prior, u_hat, routed_aggr]

def update_prior_10(batch_item_cnt, v, b_prior, u_hat, aggr):
    """``tf.while_loop`` body: 10 routing iterations for one batch item, then advance the counter.

    Must return the same five loop variables, in the same order, as the
    ``loop_vars`` passed to ``tf.while_loop`` in ``build_graph``. The previous
    version dropped ``aggr``, so building the graph with ``high_routing=True``
    failed on a loop-variable structure mismatch.
    """
    v, b_prior, aggr = route(10, batch_item_cnt, v, b_prior, u_hat, aggr)
    return([tf.add(batch_item_cnt, 1), v, b_prior, u_hat, aggr])

def build_graph(batch_size, is_train=True, high_routing= False, add_summaries=False, print_shape=False):
    """Build the CapsNet graph (conv -> primary capsules -> digit capsules -> decoder).

    Args:
        batch_size: static batch dimension of the ``X``/``Y`` placeholders.
        is_train: if True, add the dynamic-routing ``tf.while_loop``, whose result
            is assigned back into the ``b_prior`` variable, and the Adam
            ``"optimize"`` op.
        high_routing: use 10 routing iterations per batch item instead of 3.
        add_summaries: add a scalar loss summary and ``"summaries_merged"``.
        print_shape: print the static shape (or type) of every end point.

    Returns:
        ``(graph, end_points)``: a fresh ``tf.Graph`` and a dict of named tensors
        and ops (placeholders ``"X"``/``"Y"``, activations, losses, ``"init"``, ...).
    """

    graph = tf.Graph()
    end_points = dict()

    with graph.as_default():

        X = tf.placeholder(shape=[batch_size, 28, 28, 1], name="X", dtype=tf.float32)
        end_points["X"] = X
        Y = tf.placeholder(shape=[batch_size], name="Y", dtype=tf.int32)
        end_points["Y"] = Y
        Y_one_hot = tf.one_hot(indices=Y, depth=10)
        end_points["Y_one_hot"] = Y_one_hot

        # First layer: 9x9 VALID conv with 256 ReLU channels (28x28 -> 20x20).
        with tf.name_scope("Conv_Layer_1"):

            conv_1 = tf.contrib.layers.conv2d(inputs=X,
                                              num_outputs=256,
                                              kernel_size=9,
                                              stride=1, padding="VALID",
                                              activation_fn=tf.nn.relu,
                                              weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
                                              biases_initializer=tf.zeros_initializer())

            end_points["conv_layer_1_act"] = conv_1

        capsules = list()

        # Primary capsules: 32 independent 9x9 / stride-2 convolutions with 8 outputs each.
        # Every capsule type has its own weights (Capsule_Layer/Conv, Capsule_Layer/Conv_1, ...);
        # weights are shared only across the spatial positions of one capsule type.
        # Using tf.AUTO_REUSE here would make all 32 capsule types produce the same output.
        with tf.variable_scope("Capsule_Layer"):
            for n in range(32):

                cap_1 = tf.contrib.layers.conv2d(inputs=conv_1,
                                                 num_outputs=8, kernel_size=9,
                                                 activation_fn=None,
                                                 stride=2, padding="VALID" ,
                                                 weights_initializer=tf.random_uniform_initializer(),
                                                 biases_initializer=tf.zeros_initializer())

                end_points["capsule_{}_act".format(n)] = cap_1
                capsules.append(tf.expand_dims(cap_1, axis=3))

        with tf.name_scope("Transformation"):
            # Stack capsule types on axis 3, flatten to [batch, 1152, 8],
            # then add two singleton axes for broadcasting against W.
            u = tf.concat(capsules, axis=3)
            end_points["u_b_r"] = u
            u = tf.reshape(u, shape=[-1, 1152, 8])
            u = tf.expand_dims(tf.expand_dims(u, axis=2), axis=2)
            end_points["u"] = u

            # One 16x8 matrix per (primary capsule, digit capsule) pair.
            W = tf.get_variable(shape=[1, 1152, 10, 16, 8], name="W_t",
                                initializer=tf.random_normal_initializer())

            end_points["W_t"] = W

            b = tf.get_variable(shape=[1, 1152, 10, 1], name="b_t",
                                initializer=tf.zeros_initializer())

            end_points["b_t"] = b

            # Prediction vectors u_hat = W u + b: [batch, 1152, 10, 16].
            u_hat =tf.add(tf.reduce_sum(tf.multiply(u, W), axis=4), b)
            end_points["u_hat"] = u_hat

            # Routing logits; non-trainable and shared across the batch (leading dim 1).
            b_prior = tf.get_variable(shape=[1, 1152, 10, 1], name="b_prior", initializer=tf.zeros_initializer(),
                                      trainable=False)

            aggr = tf.get_variable(shape=[1, 1152, 10, 1], name="aggr", initializer=tf.zeros_initializer(),
                                      trainable=False)

            end_points["b_prior"] = b_prior

            # Digit capsules from the current (unrouted) logits.
            s, c = get_s(b_prior, u_hat)
            end_points["s"] = s
            end_points["c"] = c

            v = get_v(s)

            end_points["v"] = v

        if is_train:
            with tf.name_scope("Routing"):

                # Loop counter over batch items. The while_loop starts from the variable's
                # initial value 0. (A former `batch_item_cnt.assign(0)` created an op that
                # was never run and has been removed.)
                batch_item_cnt = tf.Variable(initial_value=0, trainable=False)

                end_of_batch = lambda batch_item_cnt, v, b_prior, u_hat, aggr: tf.less(batch_item_cnt, batch_size)
                if high_routing:
                    route_op = tf.while_loop(body=update_prior_10, cond=end_of_batch,
                                             loop_vars=[batch_item_cnt, v, b_prior, u_hat, aggr])
                    print("Using routing 10")
                else:
                    route_op = tf.while_loop(body=update_prior, cond=end_of_batch,
                                             loop_vars=[batch_item_cnt, v, b_prior, u_hat, aggr])
                    print("Using routing 3")

                end_points["route_op"] = route_op

                [batch_item_cnt, v, b_prior_updated, u_hat, aggr] = route_op
                end_points["aggr"] = aggr
                # Write the routed logits back into the variable.
                # NOTE(review): routing state therefore carries over between sess.run
                # calls instead of being reset every forward pass as in the paper.
                b_prior = tf.assign(b_prior, b_prior_updated)

                end_points["batch_item_cnt"] = batch_item_cnt

        # Digit capsules from the routed logits (or from the stored b_prior when not training).
        s, c = get_s(b_prior, u_hat)
        v = get_v(s)
        end_points["s_routed"] = s
        end_points["c_routed"] = c

        with tf.name_scope("Digicaps"):
            end_points["b_prior_routed"] = b_prior
            end_points["u_hat_routed"] = u_hat
            end_points["v_routed"] = v

        with tf.name_scope("Prepare_for_FC"):
            Y_one_hot_ex = tf.expand_dims(Y_one_hot, axis=1)
            end_points["Y_one_hot_ex"] = Y_one_hot_ex

            # Mask by the true label: [batch, 1, 10] x [batch, 10, 16] -> [batch, 1, 16] -> [batch, 16].
            # Squeeze only axis 1 so that batch_size == 1 keeps its batch dimension.
            v_masked = tf.squeeze(tf.matmul(Y_one_hot_ex, v), axis=1)
            end_points["v_masked"] = v_masked

        # Decoder: reconstruct the 784 input pixels from the masked capsule.
        with tf.name_scope("Fully_connected"):
            fc_1 = tf.contrib.layers.fully_connected(inputs=v_masked, num_outputs=512,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())

            end_points["fc_1"] = fc_1

            fc_2 = tf.contrib.layers.fully_connected(inputs=fc_1, num_outputs=1024,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_2"] = fc_2

            fc_out = tf.contrib.layers.fully_connected(inputs=fc_2, num_outputs=784,
                                                     activation_fn=tf.nn.sigmoid,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_out"] = fc_out

        # Margin loss per digit capsule:
        #   T_k * max(0, 0.9 - ||v_k||)^2 + lambda * (1 - T_k) * max(0, ||v_k|| - 0.1)^2
        lambd_a = 0.5

        v_norm = tf.norm(v, axis=2)
        end_points["v_norm"] = v_norm

        margin_loss_present = tf.multiply(tf.pow(tf.maximum(0., 0.9 - v_norm), 2), Y_one_hot)
        end_points["margin_loss_present"] = margin_loss_present

        margin_loss_not_present = lambd_a * tf.multiply(tf.pow(tf.maximum(0., v_norm - 0.1), 2), (1 - Y_one_hot))
        end_points["margin_loss_not_present"] = margin_loss_not_present

        total_margin_loss = tf.reduce_sum(tf.add(margin_loss_present, margin_loss_not_present), axis=1)
        end_points["total_margin_loss"] = total_margin_loss

        # Reconstruction loss: scaled sum of squared pixel errors.
        scaler = 0.0005
        reconstruction_loss = tf.multiply(scaler,
                                          tf.reduce_sum(
                                              tf.squared_difference(
                                                  tf.reshape(X, shape=[-1, 784]), fc_out), axis=1))
        end_points["reconstruction_loss"] = reconstruction_loss

        loss = tf.reduce_mean(tf.add(total_margin_loss, reconstruction_loss))
        end_points["loss"] = loss

        if add_summaries:
            tf.summary.scalar("loss_avg", loss)

        if is_train:
            optimizer = tf.train.AdamOptimizer()
            grads = optimizer.compute_gradients(loss)
            end_points["grads"] = grads
            optimize = optimizer.apply_gradients(grads)
            end_points["optimize"] = optimize

        if add_summaries:
            summaries_merged = tf.summary.merge_all()
            end_points["summaries_merged"] = summaries_merged

        init = tf.global_variables_initializer()
        end_points["init"] = init
        if print_shape:
            for k in end_points.keys():
                try:
                    k_shape = str(end_points[k].shape)
                except AttributeError:
                    k_shape = type(end_points[k])

                print("\r {} shape={}".format(k, k_shape))

        return(graph, end_points)

The problem is with routing: when I disable routing, everything seems fine.
- Action: try with one batch of 60 and see how b_prior changes.
    - The problem is with the routing algorithm; compare against existing implementations.
      Some of them use `tf.stop_gradient` to make it work, so the problem is likely there — look around.



In [21]:
batch_size = 60
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=True, print_shape=False)
n_epochs = 10

with tf.Session(graph=graph) as sess:
    sess.run(end_points["init"])
    # TensorBoard writer for the merged summaries; closed at the end of the cell.
    summary_writer = tf.summary.FileWriter("..//log", graph=graph)

    # Save every global variable. The previous code passed
    # `trainable_variables.extend(other_variables)` to Saver, but list.extend returns
    # None, so Saver already received None (= all variables). Saving everything is also
    # what the restore cell needs: its tf.train.Saver() expects every variable of the
    # eval graph, including the non-trainable b_prior and aggr.
    model_saver = tf.train.Saver()

    nan_caps = list()
    c_list = list()
    global_step = 0        # increases across epochs so TensorBoard steps never repeat
    nan_detected = False

    for epoch in range(n_epochs):
        mnist = Mnist(path="..//", train=True)
        data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
        loss_report_counter = 0

        # Train only on complete batches (the placeholders have a fixed batch size).
        # A batch is complete when it has batch_size rows and every row received a label.
        # The label check also stops at end of file when get_batch zero-pads instead of
        # truncating; previously this loop never terminated.
        while data.shape[0] == batch_size and lbls_one.sum(axis=1).all():
            batch_images = np.expand_dims(data, -1)
            feed_dict = {end_points["X"]: batch_images, end_points["Y"]: lbls}
            _, loss, grads, u_hat, c_bef, c_aft, summaries_merged = sess.run([end_points["optimize"],
                                                                              end_points["loss"],
                                                                              end_points["grads"],
                                                                              end_points["u_hat"],
                                                                              end_points["c"],
                                                                              end_points["c_routed"],
                                                                              end_points["summaries_merged"]],
                                                                             feed_dict=feed_dict)

            loss_report_counter += 1
            global_step += 1

            if loss_report_counter % 10 == 0:
                print("----- Epoch {} Loss:{}".format(epoch, loss))

                # grads is a list of (gradient, variable) pairs; check every gradient.
                if any(np.isnan(grad).any() for grad, _ in grads):
                    print("Nan detected")
                    nan_detected = True
                    break
                summary_writer.add_summary(summaries_merged, global_step=global_step)
            data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)

        # Stop all epochs on NaN; continuing would only train on NaN weights.
        if nan_detected:
            break

    summary_writer.close()

    if nan_detected:
        print("NaN gradients detected; model not saved")
    else:
        print("Model saves under path {}".format(model_saver.save(sess,save_path="..//saved_model//capsnet")))


Using routing 3
Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
----- Epoch 0 Loss:3.695007085800171
----- Epoch 0 Loss:3.671670436859131
----- Epoch 0 Loss:3.6710259914398193
----- Epoch 0 Loss:3.67142653465271
----- Epoch 0 Loss:3.6666336059570312
----- Epoch 0 Loss:3.670431613922119
----- Epoch 0 Loss:3.6659631729125977
----- Epoch 0 Loss:3.670924425125122
----- Epoch 0 Loss:3.6689774990081787
----- Epoch 0 Loss:3.6682331562042236
----- Epoch 0 Loss:3.6646199226379395
----- Epoch 0 Loss:3.665210247039795
----- Epoch 0 Loss:3.6676857471466064
----- Epoch 0 Loss:3.666738271713257
----- Epoch 0 Loss:3.665623426437378
----- Epoch 0 Loss:3.6645774841308594
----- Epoch 0 Loss:3.664440870285034
----- Epoch 0 Loss:3.6663198471069336
----- Epoch 0 Loss:3.6650636196136475
----- Epoch 0 Loss:3.667222738265991
----- Epoch 0 Loss:3.667555570602417
----- Epoch 0 Loss:3.6676559448242188
----- Epoch 0 Loss:3.6654860973358154
----

KeyboardInterrupt: 

# Testing the structure based on mnist with batch size 2

In [None]:
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, high_routing=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
# NOTE(review): scaled inputs with freshly initialised weights give small activations.
# An earlier comment said scaling is overridden for these checks, but scale=True is
# passed — confirm which is intended.

data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]

    sess.run(init)

    feed_dict = {X:data, Y:lbls}

    # --- Conv layer 1: recompute the top-left output of filter 0 with numpy ---
    x, w, b, conv_1_act = sess.run([X, graph.get_tensor_by_name('Conv/weights:0'),
                                    graph.get_tensor_by_name('Conv/biases:0'),
                                    end_points["conv_layer_1_act"]],
                                   feed_dict=feed_dict)

    # Bias is added once per output unit (not once per kernel tap), then ReLU.
    conv_1_expected = max(0., np.sum(np.matmul(np.expand_dims(x[0, 0:9, 0:9], axis=3), w)[:, :, :, 0]) + b[0])
    assert(abs(conv_1_expected - conv_1_act[0, 0, 0, 0]) < 0.0001)
    assert(conv_1_act.shape == (2, 20, 20, 256))

    # --- Primary capsules: recompute capsule type 7 (Conv_7) at position (0, 0) ---
    w, b, conv_1_act, capsule_7_act, capsule_2_act = sess.run([graph.get_tensor_by_name('Capsule_Layer/Conv_7/weights:0'),
                                                               graph.get_tensor_by_name('Capsule_Layer/Conv_7/biases:0'),
                                                               end_points["conv_layer_1_act"],
                                                               end_points["capsule_7_act"],
                                                               end_points["capsule_2_act"]],
                                                              feed_dict=feed_dict)

    capsule_7_expected = np.sum(np.matmul(np.expand_dims(conv_1_act[0, 0:9, 0:9], axis=2), w), axis=(0, 1, 2)) + b
    assert(np.linalg.norm(capsule_7_expected - capsule_7_act[0, 0, 0, :]) < 0.00005)

    assert(capsule_7_act.shape == (2, 6, 6, 8))

    # Capsule types must not share weights.
    assert(np.linalg.norm(capsule_7_act[0, 0, 0, :] - capsule_2_act[0, 0, 0, :]) > 0.01)

    # --- Transformation, weighted sum and squash before routing ---
    v, s, b_prior, W_t, b_t, capsule_7_act, u, u_hat = sess.run([end_points["v"],
                                                                 end_points["s"],
                                                                 end_points["b_prior"],
                                                                 end_points["W_t"],
                                                                 end_points["b_t"],
                                                                 end_points["capsule_7_act"],
                                                                 end_points["u"],
                                                                 end_points["u_hat"]], feed_dict=feed_dict)

    # Check the predictions of a slice of primary capsules (rows 252:288) for digit 5,
    # the label of the first training image.
    u_rows = np.squeeze(u[0, 252:288, ...], axis=1)
    W_rows_5 = np.squeeze(W_t[:, 252:288, 5, ...])
    b_rows_5 = b_t[:, 252:288, 5, ...]
    u_hat_rows_5 = np.add(np.sum(np.multiply(u_rows, W_rows_5), axis=2), b_rows_5)

    assert(np.linalg.norm(u_hat_rows_5 - u_hat[0, 252:288, 5, ...]) < 0.0001)

    # Coupling coefficients for digit 5: softmax of the logits over the 10 digit capsules.
    b_prior_all_10 = np.squeeze(b_prior[0, :, ...])
    b_prior_5 = np.squeeze(b_prior[0, :, 5, ...])
    c_5 = np.exp(b_prior_5) / np.sum(np.exp(b_prior_all_10), axis=1)
    c_5 = np.expand_dims(c_5, axis=1)
    u_hat_5 = np.squeeze(u_hat[0, :, 5, ...])
    s_5 = np.sum(np.multiply(c_5, u_hat_5), axis=0)

    # Weighted sum for digit 5.
    assert(np.linalg.norm(s_5 - s[0, 5, ...]) < 0.00001)

    # Squash: compare the full vector, not just its norm.
    v_5 = (np.power(np.linalg.norm(s_5), 2) / (1 + np.power(np.linalg.norm(s_5), 2))) * (s_5 / np.linalg.norm(s_5))
    assert(np.linalg.norm(v_5 - v[0, 5, ...]) < 0.00001)

    # --- Routing ---
    _, v_before, s_before, c_before, v_routed, s_routed, \
    c_routed, u_hat_routed, b_prior, b_prior_routed, batch_item_count = sess.run([
                                     end_points["route_op"],
                                     end_points["v"],
                                     end_points["s"],
                                     end_points["c"],
                                     end_points["v_routed"],
                                     end_points["s_routed"],
                                     end_points["c_routed"],
                                     end_points["u_hat_routed"],
                                     end_points["b_prior"],
                                     end_points["b_prior_routed"],
                                     end_points["batch_item_cnt"]], feed_dict=feed_dict)

    assert(np.all(u_hat_routed == u_hat))

    # Agreement (v . u_hat) before and after routing: routing should increase the mean
    # agreement for every digit capsule, i.e. route predictions to better digit capsules.
    aggr_before = np.sum(np.multiply(v_before[0, ...], u_hat[0, ...]), axis=2)
    aggr_after = np.sum(np.multiply(v_routed[0, ...], u_hat[0, ...]), axis=2)

    assert(np.min(aggr_after.mean(axis=0) - aggr_before.mean(axis=0)) > 0)

    # --- Losses ---
    x_fed, v_out, margin_loss_present, margin_loss_not_present, total_margin_loss, fc_out, reconstruction_loss, loss, v_norm = \
        sess.run([end_points[k]
                  for k in ["X", "v_routed",
                            "margin_loss_present",
                            "margin_loss_not_present",
                            'total_margin_loss',
                            "fc_out",
                            'reconstruction_loss',
                            'loss', "v_norm"]], feed_dict=feed_dict)

    # The first two training labels are 5 and 0.
    assert(np.all(np.where(margin_loss_present != 0.)[1] == np.array([5, 0])))

    assert(abs(margin_loss_present[0][5] - np.power(np.max([0, 0.9 - np.linalg.norm(v_out[0][5])]), 2)) < 0.00001)

    assert(np.linalg.norm(margin_loss_not_present[0][0] -
                          0.5 * np.power(np.max([0., np.linalg.norm(v_out[0][0]) - 0.1]), 2)) < 0.00001)

    assert(np.linalg.norm(total_margin_loss[0] - np.sum(margin_loss_present[0] + margin_loss_not_present[0])) < 0.00001)

    assert(np.abs(np.sum(np.square(fc_out[0] - x_fed[0].reshape(784,))) * 0.0005 - reconstruction_loss[0]) < 0.01)

    # Float32 reductions in TF and numpy may differ in the last bits; compare with a tolerance.
    assert(np.isclose(loss, np.mean(np.add(reconstruction_loss, total_margin_loss))))
    print("All done")


# Training the capsule network

The paper mentions shifting images by two pixels, but that is not done here.<br/>
I also did not add noise to the reconstruction; the goal for now is to make sure the essentials work.

In [None]:
batch_size = 2
# build_graph creates and returns its own tf.Graph, so the default graph needs no reset.
graph, end_points = build_graph(batch_size, is_train=False, add_summaries=False, print_shape=True)

# Use the same "..//" relative paths as the training cell. The former "..\\" form only
# resolved on Windows.
mnist = Mnist(path="..//", train=True)

data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
data = np.expand_dims(data, -1)


with tf.Session(graph=graph) as sess:
    # Saver() covers every variable of this eval graph (including the non-trainable
    # b_prior and aggr). The checkpoint written by the training cell must contain them all.
    restorer = tf.train.Saver()
    restorer.restore(sess, save_path="..//saved_model//capsnet")
    feed_dict = {end_points["X"]:data, end_points["Y"]:lbls}
    print(sess.run(end_points["b_prior"], feed_dict=feed_dict))
    
    
    

In [91]:
# Scratch inspection of the first captured NaN entry.
# NOTE(review): the visible training cell creates `nan_caps` but never appends to it,
# so on a fresh kernel this raises IndexError; it relies on leftover kernel state.
nan_caps[0][0].shape

(1, 1152, 10, 1)

In [None]:
# Scratch check: True when `o` contains no NaN values.
# NOTE(review): `o` is not defined anywhere in this notebook, so this cell depends on
# leftover kernel state and fails with NameError on a fresh run.
np.all(np.isnan(o)==False)

In [None]:
# Index tuple of the NaN positions in the fourth captured entry (depends on kernel state).
idxs = np.nonzero(np.isnan(nan_caps[3]))

In [178]:
np.shape(b_prior)

(1, 1152, 10, 1)

In [6]:
np.shape(margin_loss_present)

(60, 10)

In [16]:
np.shape(data)[0]

60

In [168]:
# Digit capsule whose batch- and capsule-averaged prediction vector has the largest norm.
np.linalg.norm(u_hat.mean(axis=0).mean(axis=0), axis=1).argmax()

3