In [3]:
import os

import numpy as np
import tensorflow as tf


In [4]:
class Mnist:
    """Sequential reader for the raw MNIST IDX files.

    The constructor opens the image and label files of one split, parses and
    prints their headers, and leaves both files positioned at the first
    record; `get_batch` then serves records in file order. The handles are
    kept in `self.f` (images) and `self.flab` (labels) -- release them with
    `close()` or by using the instance as a context manager.
    """

    def __init__(self, path, train=True):
        """Open the IDX files of one split and read their headers.

        Args:
            path: directory containing the MNIST IDX files.
            train: read the ``train-*`` files if True, else the ``t10k-*`` files.

        Raises:
            OSError: if a file cannot be opened.
            ValueError: if a header is truncated.
        """
        if train:
            image_file_name = "train-images.idx3-ubyte"
            label_file_name = "train-labels.idx1-ubyte"
        else:
            image_file_name = "t10k-images.idx3-ubyte"
            label_file_name = "t10k-labels.idx1-ubyte"

        self.flab = None
        self.f = open(os.path.join(path, image_file_name), "rb")
        try:
            magic_number, number_of_images, image_rows, image_columns = self._read_uint32(self.f, 4)
            print("Images magic no: {},  No of images: {}, Image rows: {}, Image cols: {}".
                  format(magic_number, number_of_images, image_rows, image_columns))
            # Records are read with the dimensions stated in the header (28x28 for MNIST).
            self.image_rows, self.image_columns = image_rows, image_columns

            self.flab = open(os.path.join(path, label_file_name), "rb")
            magic_number, number_of_items = self._read_uint32(self.flab, 2)
            print("Labels magic no: {}, No of items: {}".format(magic_number, number_of_items))
        except BaseException:
            # Do not leak the handle(s) opened so far.
            self.close()
            raise

    @staticmethod
    def _read_uint32(handle, count):
        """Read `count` big-endian unsigned 32-bit integers (IDX header fields) from `handle`."""
        raw = handle.read(4 * count)
        if len(raw) < 4 * count:
            raise ValueError("truncated IDX header")
        return [int.from_bytes(raw[i:i + 4], "big") for i in range(0, 4 * count, 4)]

    def close(self):
        """Close the image and label files (safe to call more than once)."""
        for handle in (getattr(self, "f", None), getattr(self, "flab", None)):
            if handle is not None:
                handle.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def _get_next(self):
        """Read the next (image, label) record.

        Returns:
            (image, label): `image` is an int ndarray of shape
            (image_rows, image_columns) with raw pixel values 0-255 and
            `label` is an int. Returns ("", "") when no complete record is
            left (a trailing partial image or a missing label counts as the
            end of the data instead of raising).
        """
        n_pixels = self.image_rows * self.image_columns
        raw = self.f.read(n_pixels)
        if not raw or len(raw) < n_pixels:
            return ("", "")
        label_byte = self.flab.read(1)
        if not label_byte:
            return ("", "")
        image = np.frombuffer(raw, dtype=np.uint8).astype(int).reshape(self.image_rows, self.image_columns)
        return (image, label_byte[0])

    def get_batch(self, batch_size=20, scale=True):
        """Read the next `batch_size` records.

        Args:
            batch_size: number of records to read.
            scale: if True divide pixel values by 255 (range [0, 1]).

        Returns:
            (images, labels, one_hot): float array (batch_size, rows, cols),
            int array (batch_size,), int array (batch_size, 10). If the data
            runs out, the remaining rows are left as zeros.
        """
        im_lst = np.zeros(shape=[batch_size, self.image_rows, self.image_columns], dtype="float")
        lbl_lst = np.zeros(shape=[batch_size], dtype="int")
        lbl_one_hot_lst = np.zeros(shape=[batch_size, 10], dtype="int")
        for i in range(batch_size):
            im, lbl = self._get_next()
            if isinstance(im, str):
                break  # end of data: the remaining rows stay zero-filled
            if scale:
                im = im / 255.
            im_lst[i, ...] = im
            lbl_lst[i] = lbl
            lbl_one_hot_lst[i, lbl] = 1
        return im_lst, lbl_lst, lbl_one_hot_lst
    

        
        
    
    

In [15]:
def get_s(b_prior, u_hat):
    """Weighted sum of prediction vectors using softmax coupling coefficients.

    The routing logits `b_prior` are normalised with a softmax over axis 2
    (the digit-capsule axis in build_graph). The resulting coefficients scale
    `u_hat`, and the products are summed over axis 1 (the primary-capsule axis).

    Returns:
        (s, c): the summed vectors and the coupling coefficients.
    """
    coupling = tf.nn.softmax(b_prior, dim=2)
    weighted_predictions = tf.multiply(coupling, u_hat)
    return tf.reduce_sum(weighted_predictions, axis=1), coupling


#Squash
def get_v(s, epsilon=1e-9):
    """Squash non-linearity along axis 2: v = |s|^2 / (1 + |s|^2) * s / |s|.

    The norm in the denominator is computed as sqrt(|s|^2 + epsilon). The
    previous version divided by tf.norm(s) directly. That gives NaN outputs
    for an all-zero capsule vector, and NaN gradients, because the derivative
    of the norm is undefined at 0.

    Args:
        s: tensor whose axis 2 holds the capsule vectors ([batch, 10, 16]
            when produced by get_s in build_graph).
        epsilon: small positive constant keeping the division finite.

    Returns:
        Tensor with the same shape as `s`.
    """
    squared_norm = tf.reduce_sum(tf.square(s), axis=2, keep_dims=True)
    safe_norm = tf.sqrt(squared_norm + epsilon)
    scale = squared_norm / (1. + squared_norm)
    return scale * (s / safe_norm)


def route(n_iter, batch_item_cnt, v, b_prior, u_hat, aggr):
    """Run `n_iter` routing-by-agreement iterations for one batch item.

    Every iteration recomputes the coupling coefficients and the capsule
    outputs `v` for the whole batch from the current logits `b_prior`
    (get_s / get_v). It then adds the agreement between batch item
    `batch_item_cnt`'s outputs and its prediction vectors to `b_prior`.

    NOTE(review): `b_prior`, as created in build_graph, has a leading axis of
    size 1. The agreement of every batch item is therefore accumulated into
    logits shared by the whole batch. Under the while_loop in build_graph the
    items are processed one after another, and each sees the previous items'
    updates. This may be related to the routing problems noted in this
    notebook -- worth checking against a reference implementation.

    Args:
        n_iter: number of routing iterations (Python int).
        batch_item_cnt: scalar integer tensor indexing the batch item whose
            agreement is added to the logits.
        v: overwritten in the first iteration; only returned unchanged when
            n_iter == 0.
        b_prior: routing logits, [1, 1152, 10, 1] when called from build_graph.
        u_hat: prediction vectors, [batch, 1152, 10, 16] when called from
            build_graph.
        aggr: overwritten in every iteration; only returned unchanged when
            n_iter == 0.

    Returns:
        (v, b_prior, aggr) after the last iteration.
    """
    for i in range(n_iter):
        s, c = get_s(b_prior, u_hat)
        v = get_v(s)
        # Block gradients through v on every iteration except the last one.
        v = tf.stop_gradient(v) if i<(n_iter-1) else v
        # Agreement: dot product over the last axis of this item's outputs
        # (v[item] is [10, 16]) with its predictions (u_hat[item] is
        # [1152, 10, 16]), giving [1152, 10].
        aggr = tf.reduce_sum(tf.multiply(v[batch_item_cnt, :, :], u_hat[batch_item_cnt, :, :]), axis=-1)
        # Reshape to [1, 1152, 10, 1] so it can be added to b_prior.
        aggr = tf.expand_dims(aggr, axis=0)
        aggr = tf.expand_dims(aggr, axis=3)
        b_prior = tf.add(b_prior, aggr)
    return (v, b_prior, aggr)

def _update_prior_n(n_iter, batch_item_cnt, v, b_prior, u_hat, aggr):
    """Shared tf.while_loop body: route one batch item, then advance the counter.

    tf.while_loop requires the body to return the loop variables in the same
    structure it received them: [batch_item_cnt, v, b_prior, u_hat, aggr].
    """
    v, b_prior, aggr = route(n_iter, batch_item_cnt, v, b_prior, u_hat, aggr)
    return [tf.add(batch_item_cnt, 1), v, b_prior, u_hat, aggr]


def update_prior(batch_item_cnt, v, b_prior, u_hat, aggr):
    """tf.while_loop body using 3 routing iterations per batch item."""
    return _update_prior_n(3, batch_item_cnt, v, b_prior, u_hat, aggr)


def update_prior_10(batch_item_cnt, v, b_prior, u_hat, aggr):
    """tf.while_loop body using 10 routing iterations per batch item.

    Previously this returned only 4 values (it dropped `aggr`), which does
    not match the 5 loop_vars passed to tf.while_loop, so building the graph
    with high_routing=True failed.
    """
    return _update_prior_n(10, batch_item_cnt, v, b_prior, u_hat, aggr)

def build_graph(batch_size, is_train=True, high_routing=False, add_summaries=False, print_shape=False):
    """Build the capsule-network graph (conv -> primary caps -> digit caps -> decoder).

    Args:
        batch_size: static batch size of the X / Y placeholders.
        is_train: if True, add the per-batch-item routing while_loop (whose
            result is assigned back into the `b_prior` variable) and the Adam
            optimizer ops.
        high_routing: with is_train, use 10 routing iterations per batch item
            (update_prior_10) instead of 3 (update_prior).
        add_summaries: add a scalar loss summary and a merged-summary op.
        print_shape: print the static shape (or the type) of every end point.

    Returns:
        (graph, end_points): the new tf.Graph and a dict of named tensors/ops.
    """
    graph = tf.Graph()
    end_points = dict()

    with graph.as_default():

        X = tf.placeholder(shape=[batch_size, 28, 28, 1], name="X", dtype=tf.float32)
        end_points["X"] = X
        Y = tf.placeholder(shape=[batch_size], name="Y", dtype=tf.int32)
        end_points["Y"] = Y
        Y_one_hot = tf.one_hot(indices=Y, depth=10)
        end_points["Y_one_hot"] = Y_one_hot

        # First layer: 9x9 ReLU convolution with 256 channels.
        with tf.name_scope("Conv_Layer_1"):
            conv_1 = tf.contrib.layers.conv2d(inputs=X,
                                              num_outputs=256,
                                              kernel_size=9,
                                              stride=1, padding="VALID",
                                              activation_fn=tf.nn.relu,
                                              weights_initializer=tf.contrib.layers.xavier_initializer_conv2d(),
                                              biases_initializer=tf.zeros_initializer())
            end_points["conv_layer_1_act"] = conv_1

        # Primary capsules: 32 separate 9x9 / stride-2 convolutions with 8
        # outputs each. Every call creates its own variables (Conv, Conv_1, ...
        # under "Capsule_Layer"), so the 32 capsule types do not share weights.
        # Reusing a single scope would make all of them produce the same output.
        # Within one capsule type the kernel is shared across spatial positions.
        capsules = list()
        with tf.variable_scope("Capsule_Layer"):
            for n in range(32):
                cap_1 = tf.contrib.layers.conv2d(inputs=conv_1,
                                                 num_outputs=8, kernel_size=9,
                                                 activation_fn=None,
                                                 stride=2, padding="VALID",
                                                 weights_initializer=tf.variance_scaling_initializer(),
                                                 biases_initializer=tf.zeros_initializer())
                end_points["capsule_{}_act".format(n)] = cap_1
                capsules.append(tf.expand_dims(cap_1, axis=3))

        with tf.name_scope("Transformation"):
            # Stack the 32 capsule types, then flatten to 1152 capsules of 8 values.
            u = tf.concat(capsules, axis=3)
            end_points["u_b_r"] = u
            u = tf.reshape(u, shape=[-1, 1152, 8])
            u = tf.expand_dims(tf.expand_dims(u, axis=2), axis=2)
            end_points["u"] = u

            W = tf.get_variable(shape=[1, 1152, 10, 16, 8], name="W_t",
                                initializer=tf.variance_scaling_initializer())
            end_points["W_t"] = W

            b = tf.get_variable(shape=[1, 1152, 10, 1], name="b_t",
                                initializer=tf.zeros_initializer())
            end_points["b_t"] = b

            # Prediction vectors: u_hat = W u + b, shape [batch, 1152, 10, 16].
            u_hat = tf.add(tf.reduce_sum(tf.multiply(u, W), axis=4), b)
            end_points["u_hat"] = u_hat

            # Routing logits: a non-trainable variable with a batch axis of 1,
            # i.e. shared by all batch items (and, when training, overwritten
            # by the routing result on every run).
            b_prior = tf.get_variable(shape=[1, 1152, 10, 1], name="b_prior", initializer=tf.zeros_initializer(),
                                      trainable=False)

            aggr = tf.get_variable(shape=[1, 1152, 10, 1], name="aggr", initializer=tf.zeros_initializer(),
                                   trainable=False)

            end_points["b_prior"] = b_prior

            s, c = get_s(b_prior, u_hat)
            end_points["s"] = s
            end_points["c"] = c

            v = get_v(s)
            end_points["v"] = v

        if is_train:
            with tf.name_scope("Routing"):
                # Loop counter over batch items. A constant is enough: the
                # previous tf.Variable was only read once as the initial loop
                # value, and its `.assign(0)` op was never run.
                batch_item_cnt = tf.constant(0, dtype=tf.int32)
                end_of_batch = lambda batch_item_cnt, v, b_prior, u_hat, aggr: tf.less(batch_item_cnt, batch_size)
                body = update_prior_10 if high_routing else update_prior
                route_op = tf.while_loop(body=body, cond=end_of_batch,
                                         loop_vars=[batch_item_cnt, v, b_prior, u_hat, aggr])
                print("Using routing 10" if high_routing else "Using routing 3")

                end_points["route_op"] = route_op

                [batch_item_cnt, v, b_prior_updated, u_hat, aggr] = route_op
                end_points["aggr"] = aggr
                # Persist the routed logits; everything below reads the assigned value.
                b_prior = tf.assign(b_prior, b_prior_updated)

                end_points["batch_item_cnt"] = batch_item_cnt

        # Digit capsules computed from the (possibly routed) logits.
        s, c = get_s(b_prior, u_hat)
        v = get_v(s)

        with tf.name_scope("Digicaps"):
            end_points["b_prior_routed"] = b_prior
            end_points["u_hat_routed"] = u_hat
            end_points["v_routed"] = v
            end_points["s_routed"] = s
            end_points["c_routed"] = c

        with tf.name_scope("Prepare_for_FC"):
            # Mask the digit capsules with the true label:
            # [batch, 1, 10] x [batch, 10, 16] -> [batch, 1, 16].
            Y_one_hot_ex = tf.expand_dims(Y_one_hot, axis=1)
            end_points["Y_one_hot_ex"] = Y_one_hot_ex

            # Squeeze only axis 1: a bare squeeze would also drop the batch
            # axis when batch_size == 1 and break the dense layers below.
            v_masked = tf.squeeze(tf.matmul(Y_one_hot_ex, v), axis=1)
            end_points["v_masked"] = v_masked

        with tf.name_scope("Fully_connected"):
            # Reconstruction decoder: 16 -> 512 -> 1024 -> 784 (sigmoid).
            fc_1 = tf.contrib.layers.fully_connected(inputs=v_masked, num_outputs=512,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_1"] = fc_1

            fc_2 = tf.contrib.layers.fully_connected(inputs=fc_1, num_outputs=1024,
                                                     activation_fn=tf.nn.relu,
                                                     weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                     biases_initializer=tf.zeros_initializer())
            end_points["fc_2"] = fc_2

            fc_out = tf.contrib.layers.fully_connected(inputs=fc_2, num_outputs=784,
                                                       activation_fn=tf.nn.sigmoid,
                                                       weights_initializer=tf.contrib.layers.xavier_initializer(),
                                                       biases_initializer=tf.zeros_initializer())
            end_points["fc_out"] = fc_out

            # Margin loss for each digit capsule; absent classes are
            # down-weighted by lambd_a.
            lambd_a = 0.5
            m_plus = 0.9
            m_minus = 0.1

            v_norm = tf.norm(v, axis=2)
            end_points["v_norm"] = v_norm

            margin_loss_present = tf.multiply(tf.pow(tf.maximum(0., m_plus - v_norm), 2), Y_one_hot)
            end_points["margin_loss_present"] = margin_loss_present

            margin_loss_not_present = lambd_a * tf.multiply(tf.pow(tf.maximum(0., v_norm - m_minus), 2), (1 - Y_one_hot))
            end_points["margin_loss_not_present"] = margin_loss_not_present

            total_margin_loss = tf.reduce_sum(tf.add(margin_loss_present, margin_loss_not_present), axis=1)
            end_points["total_margin_loss"] = total_margin_loss

            # Scaled sum-of-squares reconstruction error per example.
            scaler = 0.0005
            reconstruction_loss = tf.multiply(scaler,
                                              tf.reduce_sum(
                                                  tf.squared_difference(
                                                      tf.reshape(X, shape=[-1, 784]), fc_out), axis=1))
            end_points["reconstruction_loss"] = reconstruction_loss

            loss = tf.reduce_mean(tf.add(total_margin_loss, reconstruction_loss))
            end_points["loss"] = loss

            if add_summaries:
                tf.summary.scalar("loss_avg", loss)

            if is_train:
                optimizer = tf.train.AdamOptimizer()
                grads = optimizer.compute_gradients(loss)
                end_points["grads"] = grads
                optimize = optimizer.apply_gradients(grads)
                end_points["optimize"] = optimize

            if add_summaries:
                summaries_merged = tf.summary.merge_all()
                end_points["summaries_merged"] = summaries_merged

        init = tf.global_variables_initializer()
        end_points["init"] = init
        if print_shape:
            for k in end_points.keys():
                k_shape = "NA"
                try:
                    k_shape = str(end_points[k].shape)
                except AttributeError:
                    k_shape = type(end_points[k])

                print("\r {} shape={}".format(k, k_shape))

        return(graph, end_points)

The problem is with routing: when I disable routing, everything seems fine.
- Action: try with one batch of 60 and see how b_prior changes.
    - The problem is in the routing algorithm; look at the current implementations.
      Some of them use the stop-gradient function to make it work. The problem is somewhere around there.



In [None]:
batch_size = 6
graph, end_points = build_graph(batch_size, is_train=True, add_summaries=True, print_shape=False)
n_epochs = 1
n_steps_per_epoch = 1000  # batches drawn from the training file per epoch
report_every = 1          # print / log a summary every `report_every` steps

with tf.Session(graph=graph) as sess:
    sess.run(end_points["init"])
    summary_writer = tf.summary.FileWriter("..//log", graph=graph)

    # Checkpoint the trainable weights plus the non-trainable routing state.
    # The previous `trainable_variables.extend(...)` returned None, so the
    # Saver silently fell back to saving every global variable. `aggr` is
    # included as well: the evaluation graph (is_train=False) restores all of
    # its global variables, which are exactly the trainable ones plus
    # b_prior and aggr.
    routing_state = [var for var in graph.get_collection("variables")
                     if var.name in ("b_prior:0", "aggr:0")]
    variables_to_restore = graph.get_collection("trainable_variables") + routing_state
    model_saver = tf.train.Saver(variables_to_restore)

    loss_report_counter = 0
    nan_caps = list()  # kept for the NaN-debugging cells below; nothing appends to it here
    c_list = list()

    for epoch in range(n_epochs):
        mnist = Mnist(path="..//", train=True)
        try:
            for step in range(n_steps_per_epoch):
                data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
                feed_dict = {end_points["X"]: np.expand_dims(data, -1), end_points["Y"]: lbls}
                _, loss, grads, u_hat, c_bef, c_aft, summaries_merged = sess.run(
                    [end_points["optimize"],
                     end_points["loss"],
                     end_points["grads"],
                     end_points["u_hat"],
                     end_points["c"],
                     end_points["c_routed"],
                     end_points["summaries_merged"]], feed_dict=feed_dict)

                loss_report_counter += 1
                if loss_report_counter % report_every == 0:
                    print("----- Epoch {} Loss:{}-----c: {} --u_hat {}".
                          format(epoch, loss, c_aft.mean(axis=1).squeeze(),
                                 np.argmax(np.linalg.norm(u_hat.mean(axis=0).mean(axis=0), axis=1))))
                    c_list.append([c_bef, c_aft])
                    summary_writer.add_summary(summaries_merged, global_step=loss_report_counter)
        finally:
            # Release the IDX file handles opened for this epoch.
            mnist.f.close()
            mnist.flab.close()

    # Flush pending summaries to disk.
    summary_writer.close()
    print("Model saves under path {}".format(model_saver.save(sess, save_path="..//saved_model//capsnet")))


Using routing 3
Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000
----- Epoch 0 Loss:0.9001719355583191-----c: [ 0.09999897  0.09999897  0.09999897  0.09999897  0.09999897  0.09999897
  0.09999897  0.09999897  0.09999897  0.09999897] --u_hat 6
----- Epoch 0 Loss:0.7267899513244629-----c: [ 0.10064867  0.10040376  0.10058739  0.09916154  0.10061565  0.10060498
  0.09916152  0.09916157  0.09916154  0.10049354] --u_hat 0
----- Epoch 0 Loss:1.0509648323059082-----c: [ 0.0976347   0.10708978  0.09756611  0.09764081  0.10766241  0.10663421
  0.09608519  0.09608519  0.09608527  0.09751626] --u_hat 4
----- Epoch 0 Loss:0.7843363881111145-----c: [ 0.09697973  0.10850625  0.09669541  0.10031777  0.10705373  0.10627137
  0.09568817  0.09573886  0.09587558  0.0968732 ] --u_hat 3
----- Epoch 0 Loss:1.210477590560913-----c: [ 0.1140355   0.10168911  0.09425597  0.09811378  0.10025128  0.09968051
  0.09700204  0.08954023  0.09005612 

----- Epoch 0 Loss:0.5985404849052429-----c: [ 0.11578533  0.10386945  0.0945731   0.1118882   0.08774342  0.08350245
  0.09375134  0.09618071  0.07886557  0.13384052] --u_hat 3
----- Epoch 0 Loss:0.6083140969276428-----c: [ 0.11213044  0.1021642   0.09161146  0.11667018  0.08959571  0.08646039
  0.09512945  0.09362085  0.07760669  0.13501048] --u_hat 5
----- Epoch 0 Loss:0.7118145823478699-----c: [ 0.10872924  0.10117492  0.08888269  0.11663088  0.09244476  0.09015572
  0.0959783   0.09124099  0.07626155  0.13850079] --u_hat 5
----- Epoch 0 Loss:0.7190137505531311-----c: [ 0.10565737  0.10122105  0.08639272  0.11387302  0.09669486  0.09320644
  0.09578297  0.08924493  0.07481696  0.14310949] --u_hat 4
----- Epoch 0 Loss:0.7633551955223083-----c: [ 0.10314551  0.10173413  0.08409546  0.11072061  0.10265194  0.09435727
  0.09487439  0.0875688   0.07350021  0.14735179] --u_hat 4
----- Epoch 0 Loss:0.7254934310913086-----c: [ 0.10195319  0.10206617  0.08202343  0.10774257  0.10947385  0.0

----- Epoch 0 Loss:0.9995203018188477-----c: [ 0.07381076  0.11327007  0.07792756  0.10536363  0.07868953  0.22039197
  0.06818417  0.08798849  0.07760511  0.09676885] --u_hat 5
----- Epoch 0 Loss:0.8406946063041687-----c: [ 0.07226957  0.11764942  0.07525238  0.10653722  0.07733247  0.22264436
  0.06856988  0.08510659  0.07799364  0.0966445 ] --u_hat 6
----- Epoch 0 Loss:0.5204634070396423-----c: [ 0.07160705  0.12826578  0.07278224  0.10773551  0.07667819  0.21612141
  0.06889611  0.08266047  0.07820179  0.09705175] --u_hat 1
----- Epoch 0 Loss:0.6908552050590515-----c: [ 0.07054631  0.13188653  0.07056231  0.11134589  0.07641289  0.21392663
  0.06926415  0.08046913  0.07833728  0.09724881] --u_hat 6
----- Epoch 0 Loss:0.8246449828147888-----c: [ 0.06871471  0.13153186  0.06784756  0.11362222  0.07599458  0.22234538
  0.06894398  0.07774167  0.07720425  0.09605377] --u_hat 6
----- Epoch 0 Loss:0.8205428123474121-----c: [ 0.0668152   0.12840919  0.06515156  0.11446318  0.07582212  0.2

----- Epoch 0 Loss:0.6891588568687439-----c: [ 0.08113972  0.12558635  0.17595537  0.05695274  0.05889788  0.13007955
  0.07644743  0.09686743  0.11644112  0.08163252] --u_hat 8
----- Epoch 0 Loss:0.6098141670227051-----c: [ 0.08031894  0.12332499  0.17478885  0.05588611  0.06077207  0.12759903
  0.07592679  0.09651358  0.12352439  0.08134513] --u_hat 4
----- Epoch 0 Loss:0.5793821215629578-----c: [ 0.07932644  0.12094731  0.17114632  0.0558446   0.0619587   0.12584268
  0.07460514  0.09775923  0.13167211  0.08089771] --u_hat 4
----- Epoch 0 Loss:0.5757226347923279-----c: [ 0.07912029  0.11906968  0.16887818  0.05622406  0.06537653  0.1251469
  0.07365818  0.09761299  0.13359307  0.08132008] --u_hat 4
----- Epoch 0 Loss:0.4751546382904053-----c: [ 0.07972871  0.11718837  0.1721466   0.05719011  0.06692722  0.12429444
  0.07191346  0.09957443  0.13120395  0.0798327 ] --u_hat 4
----- Epoch 0 Loss:0.4852518141269684-----c: [ 0.0805515   0.11590365  0.1776454   0.05992308  0.06899688  0.12

----- Epoch 0 Loss:0.3667741119861603-----c: [ 0.11006358  0.10309723  0.12247136  0.10048988  0.10057548  0.08097097
  0.08530938  0.13086767  0.09950289  0.06665169] --u_hat 9
----- Epoch 0 Loss:0.20984865725040436-----c: [ 0.11251362  0.10580336  0.1216228   0.0995044   0.10059282  0.07979515
  0.08442421  0.13103373  0.0976155   0.06709459] --u_hat 9
----- Epoch 0 Loss:0.3006325960159302-----c: [ 0.11463448  0.10534853  0.12185504  0.09916768  0.09979878  0.07956205
  0.08452848  0.13140576  0.09611221  0.06758735] --u_hat 9
----- Epoch 0 Loss:0.45734819769859314-----c: [ 0.11339863  0.10420983  0.1236051   0.10072878  0.09939256  0.08118713
  0.08304304  0.13103682  0.09475615  0.06864211] --u_hat 9
----- Epoch 0 Loss:0.33005672693252563-----c: [ 0.11335569  0.10350165  0.12359076  0.10339013  0.09904667  0.08099099
  0.08200774  0.13077153  0.09302755  0.07031732] --u_hat 9
----- Epoch 0 Loss:0.20812834799289703-----c: [ 0.11851401  0.10282325  0.12396879  0.10397483  0.09680822 

----- Epoch 0 Loss:0.26957669854164124-----c: [ 0.09474055  0.08705369  0.1141009   0.09574497  0.11795545  0.06761346
  0.08134813  0.13325854  0.11082378  0.09736052] --u_hat 9
----- Epoch 0 Loss:0.43706345558166504-----c: [ 0.09613983  0.0864736   0.11545204  0.09499647  0.11725227  0.06860563
  0.07978226  0.13335735  0.10867651  0.09926412] --u_hat 9
----- Epoch 0 Loss:0.4075956344604492-----c: [ 0.09908758  0.08532362  0.11415192  0.09383383  0.11783314  0.06953366
  0.08209208  0.1327327   0.10765132  0.09776007] --u_hat 6
----- Epoch 0 Loss:0.3074035942554474-----c: [ 0.09882631  0.08490945  0.1131091   0.09339055  0.11714672  0.06915295
  0.08315252  0.13583475  0.10719094  0.09728635] --u_hat 6
----- Epoch 0 Loss:0.39857327938079834-----c: [ 0.0983487   0.08482257  0.11305266  0.09231343  0.11716281  0.06906858
  0.08306073  0.1380285   0.10688479  0.09725744] --u_hat 6
----- Epoch 0 Loss:0.3397235870361328-----c: [ 0.09730717  0.08421084  0.11210258  0.09225871  0.11694062  

----- Epoch 0 Loss:0.3397844731807709-----c: [ 0.11057916  0.07914664  0.11294761  0.10074674  0.14257698  0.06470861
  0.07912646  0.13831729  0.0805299   0.09132081] --u_hat 3
----- Epoch 0 Loss:0.29852747917175293-----c: [ 0.11299235  0.07993878  0.1124538   0.1000939   0.14340757  0.0641247
  0.07911903  0.13791798  0.07981545  0.09013628] --u_hat 6
----- Epoch 0 Loss:0.15570460259914398-----c: [ 0.113365    0.07974076  0.11180913  0.10106279  0.1426709   0.06523181
  0.07966816  0.13831525  0.07929218  0.08884399] --u_hat 5
----- Epoch 0 Loss:0.13284747302532196-----c: [ 0.11279324  0.08074578  0.11119772  0.10067371  0.14461501  0.06460433
  0.07975662  0.13820836  0.07905546  0.08835001] --u_hat 1
----- Epoch 0 Loss:0.06459647417068481-----c: [ 0.11376251  0.08278029  0.11077378  0.10156394  0.14409716  0.06448895
  0.0791603   0.13803875  0.07842324  0.08691092] --u_hat 1
----- Epoch 0 Loss:0.19664783775806427-----c: [ 0.11492152  0.08282773  0.1103284   0.10126835  0.1440853  

# Testing the graph structure on MNIST with batch size 2

In [None]:
batch_size = 2
graph, end_points = build_graph(batch_size, is_train=True, high_routing=True, add_summaries=False, print_shape=False)

mnist = Mnist(path="..//", train=True)
# Inputs are scaled to [0, 1]. (Earlier note: with scaling the freshly
# initialised weights are small, so the activations are small too.)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
mnist.f.close()
mnist.flab.close()
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:

    init = end_points["init"]
    X = end_points["X"]
    Y = end_points["Y"]

    sess.run(init)

    feed_dict = {X: data, Y: lbls}

    # Test output of conv 1: output (0, 0), channel 0 is ReLU(patch * kernel + bias).
    # The bias is added once, after summing over the patch.
    x, w, b, conv_1_act = sess.run([X, graph.get_tensor_by_name('Conv/weights:0'),
                                    graph.get_tensor_by_name('Conv/biases:0'),
                                    end_points["conv_layer_1_act"]],
                                   feed_dict=feed_dict)

    expected_conv = np.max([0, np.sum(np.matmul(np.expand_dims(x[0, 0:9, 0:9], axis=3), w)[:, :, :, 0]) + b[0]])
    assert(np.abs(expected_conv - conv_1_act[0, 0, 0, 0]) < 0.0001)
    assert(conv_1_act.shape == (2, 20, 20, 256))

    # Test capsule output values: capsule 8 lives in scope Capsule_Layer/Conv_7.
    w, b, conv_1_act, capsule_7_act, capsule_2_act = sess.run([graph.get_tensor_by_name('Capsule_Layer/Conv_7/weights:0'),
                                                               graph.get_tensor_by_name('Capsule_Layer/Conv_7/biases:0'),
                                                               end_points["conv_layer_1_act"],
                                                               end_points["capsule_7_act"],
                                                               end_points["capsule_2_act"]],
                                                              feed_dict=feed_dict)

    expected_cap = np.sum(np.matmul(np.expand_dims(conv_1_act[0, 0:9, 0:9], axis=2), w), axis=(0, 1, 2)) + b
    assert(np.linalg.norm(expected_cap - capsule_7_act[0, 0, 0, :]) < 0.00005)

    assert(capsule_7_act.shape == (2, 6, 6, 8))

    # Capsule types must not share weights.
    assert(np.linalg.norm(capsule_7_act[0, 0, 0, :] - capsule_2_act[0, 0, 0, :]) > 0.01)

    # Test the squash function before routing.
    v, s, b_prior, W_t, b_t, capsule_7_act, u, u_hat = sess.run([end_points["v"],
                                                                 end_points["s"],
                                                                 end_points["b_prior"],
                                                                 end_points["W_t"],
                                                                 end_points["b_t"],
                                                                 end_points["capsule_7_act"],
                                                                 end_points["u"],
                                                                 end_points["u_hat"]], feed_dict=feed_dict)

    # Prediction vectors of capsule type 8 (rows 252:288) for digit class 5, batch item 0.
    u = np.squeeze(u[0, 252:288, ...], axis=1)
    W_t = np.squeeze(W_t[:, 252:288, 5, ...])
    b_t = b_t[:, 252:288, 5, ...]
    u_hat_7_5 = np.add(np.sum(np.multiply(u, W_t), axis=2), b_t)

    assert(np.linalg.norm(u_hat_7_5 - u_hat[0, 252:288, 5, ...]) < 0.0001)
    b_prior_all_10 = np.squeeze(b_prior[0, :, ...])
    b_prior_7_5 = np.squeeze(b_prior[0, :, 5, ...])
    c_7_5 = np.exp(b_prior_7_5) / np.sum(np.exp(b_prior_all_10), axis=1)
    c_7_5 = np.expand_dims(c_7_5, axis=1)
    u_hat_5 = np.squeeze(u_hat[0, :, 5, ...])
    s_5 = np.sum(np.multiply(c_7_5, u_hat_5), axis=0)

    # Weighted sum for digit class 5.
    assert(np.linalg.norm(s_5 - s[0, 5, ...]) < 0.00001)

    # Squash: compare the whole vector, not a one-sided norm difference.
    v_5 = (np.power(np.linalg.norm(s_5), 2) / (1 + np.power(np.linalg.norm(s_5), 2))) * (s_5 / np.linalg.norm(s_5))
    assert(np.linalg.norm(v_5 - v[0, 5, ...]) < 0.00001)

    # Test routing.
    _, v_before, s_before, c_before, v_routed, s_routed, \
    c_routed, u_hat_routed, b_prior, b_prior_routed, batch_item_count = sess.run([
                                     end_points["route_op"],
                                     end_points["v"],
                                     end_points["s"],
                                     end_points["c"],
                                     end_points["v_routed"],
                                     end_points["s_routed"],
                                     end_points["c_routed"],
                                     end_points["u_hat_routed"],
                                     end_points["b_prior"],
                                     end_points["b_prior_routed"],
                                     end_points["batch_item_cnt"]], feed_dict=feed_dict)

    assert(np.all(u_hat_routed == u_hat))

    ###############################################################################################
    # Measure the agreement before and after routing and check that, for every class, the mean    #
    # agreement increased, i.e. routing sent the predictions to better-matching digit capsules.   #
    ###############################################################################################

    aggr_before, aggr_after = np.sum(np.multiply(v_before[0, ...], u_hat[0, ...]), axis=2), \
        np.sum(np.multiply(v_routed[0, ...], u_hat[0, ...]), axis=2)

    assert(np.min(aggr_after.mean(axis=0) - aggr_before.mean(axis=0)) > 0)

    # Test the losses.
    x_in, v_out, margin_loss_present, margin_loss_not_present, total_margin_loss, fc_out, reconstruction_loss, loss, v_norm = \
        sess.run([end_points[k]
                  for k in ["X", "v_routed",
                            "margin_loss_present",
                            "margin_loss_not_present",
                            'total_margin_loss',
                            "fc_out",
                            'reconstruction_loss',
                            'loss', "v_norm"]], feed_dict=feed_dict)

    # The only non-zero "present" losses are at each item's true class.
    assert(np.array_equal(np.nonzero(margin_loss_present)[1], lbls))

    present_0 = lbls[0]              # true class of batch item 0
    absent_0 = (present_0 + 1) % 10  # some class batch item 0 is not

    assert(np.abs(margin_loss_present[0][present_0] -
                  np.power(np.max([0, 0.9 - np.linalg.norm(v_out[0][present_0])]), 2)) < 0.00001)

    assert(np.linalg.norm(margin_loss_not_present[0][absent_0] -
                          0.5 * np.power(np.max([0., np.linalg.norm(v_out[0][absent_0]) - 0.1]), 2)) < 0.00001)

    assert(np.linalg.norm(total_margin_loss[0] - np.sum(margin_loss_present[0] + margin_loss_not_present[0])) < 0.00001)

    assert(np.abs(np.sum(np.square(fc_out[0] - x_in[0].reshape(784,))) * 0.0005 - reconstruction_loss[0]) < 0.01)

    # Float reductions may differ in the last bits; compare with a tolerance.
    assert(np.isclose(loss, np.mean(np.add(reconstruction_loss, total_margin_loss))))
    print("All done")


# Training the capsule network

The paper mentions shifting images by up to two pixels, but that is not done here.<br/>
I also did not add noise to the reconstruction; the goal for now is to make sure the essentials work.

In [None]:
batch_size = 2
tf.reset_default_graph()
graph, end_points = build_graph(batch_size, is_train=False, add_summaries=False, print_shape=True)

# Use the same "..//" path style as the training cell. Backslash paths only
# resolve on Windows, and the checkpoint was written to "..//saved_model//capsnet".
mnist = Mnist(path="..//", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size, scale=True)
mnist.f.close()
mnist.flab.close()
data = np.expand_dims(data, -1)

with tf.Session(graph=graph) as sess:
    restorer = tf.train.Saver()
    restorer.restore(sess, save_path="..//saved_model//capsnet")
    feed_dict = {end_points["X"]: data, end_points["Y"]: lbls}
    print(sess.run(end_points["b_prior"], feed_dict=feed_dict))
    
    
    

In [91]:
# NOTE(review): relies on state from the training cell. `nan_caps` is only
# appended to by a commented-out line there, so on a fresh kernel this raises
# IndexError; the output below came from an earlier session.
nan_caps[0][0].shape

(1, 1152, 10, 1)

In [None]:
# NOTE(review): `o` is not defined anywhere in this notebook, so this cell
# raises NameError on a fresh kernel. Presumably meant to check that some
# array contains no NaNs -- confirm which one.
np.all(np.isnan(o)==False)

In [None]:
# Indices of NaN entries in nan_caps[3].
# NOTE(review): needs `nan_caps` to hold at least 4 entries, but nothing in
# the training cell appends to it, so this fails on a fresh kernel.
idxs = np.where(np.isnan(nan_caps[3]) == True)

In [178]:
# Shape of the `b_prior` value last fetched with sess.run in the
# structure-test cell (a numpy array there, not the TF variable).
b_prior.shape

(1, 1152, 10, 1)

In [6]:
# Set by the structure-test cell, where it has shape (batch_size, 10).
# NOTE(review): the (60, 10) output below comes from an earlier session that
# used a batch of 60, not from the current code.
margin_loss_present.shape

(60, 10)

In [13]:
# Labels of the batch most recently loaded by whichever cell last assigned `lbls`.
# NOTE(review): the 60-label output below is from an earlier session.
lbls

array([8, 6, 5, 7, 7, 8, 8, 9, 7, 4, 7, 3, 2, 0, 8, 6, 8, 6, 1, 6, 8, 9, 4,
       0, 9, 0, 4, 1, 5, 4, 7, 5, 3, 7, 4, 9, 8, 5, 8, 6, 3, 8, 6, 9, 9, 1,
       8, 3, 5, 8, 6, 5, 9, 7, 2, 5, 0, 8, 5, 1])

In [168]:
# Digit class whose prediction vector, averaged over the batch and the 1152
# primary capsules, has the largest norm. Uses whichever `u_hat` was fetched last.
np.argmax(np.linalg.norm(u_hat.mean(axis=0).mean(axis=0), axis=1))

3