In [1]:
import tensorflow as tf
import numpy as np
from datetime import datetime

In [2]:
class Mnist:
    def __init__(self, path, train=True):
        if train:
            image_file_name = "train-images.idx3-ubyte"
            label_file_name = "train-labels.idx1-ubyte"
        else:
            image_file_name = "t10k-images.idx3-ubyte"
            label_file_name = "t10k-labels.idx1-ubyte"
            
        self.f = open("{}\\{}".format(path, image_file_name), "rb")
        magic_number, number_of_images, image_rows, image_columns = int(self.f.read(4).hex(), 16), \
        int(self.f.read(4).hex(), 16), int(self.f.read(4).hex(), 16), int(self.f.read(4).hex(), 16)
        print("Images magic no: {},  No of images: {}, Image rows: {}, Image cols: {}".\
          format(magic_number, number_of_images, image_rows, image_columns))

        self.flab = open("{}\\{}".format(path, label_file_name), "rb")
        magic_number, number_of_items = int(self.flab.read(4).hex(), 16), int(self.flab.read(4).hex(), 16)
        print("Labels magic no: {}, No of items: {}".format(magic_number, number_of_items))

    def _get_next(self):
        val = self.f.read(28*28)
        if len(val)>0:
            img = val.hex()
            idxs = np.arange(0, len(img), 2).astype(int)
            return(np.reshape([int(img[i:(i+2)], 16) for i in idxs], (28, 28)), int(self.flab.read(1).hex(), 16)) 
        else:
            return("", "")

    def get_batch(self, batch_size=20):
        im_lst = np.zeros(shape=[batch_size, 28, 28], dtype="float")
        lbl_lst = np.zeros(shape=[batch_size], dtype="int")
        lbl_one_hot_lst = np.zeros(shape=[batch_size, 10], dtype="int")
        for i in range(batch_size):
            im, lbl = self._get_next()
            if len(im)>0:
                im_lst[i,...] = im
                lbl_lst[i] = lbl
                lbl_one_hot_lst[i,lbl] = 1
        return im_lst, lbl_lst, lbl_one_hot_lst
    

        
        
    
    

In [30]:
def get_s(b_prior, u_hat):
    c = tf.nn.softmax(b_prior, dim=-1)
    s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1)
    return(s)


#Squash
def get_v(s):
    norm_s = tf.norm(s)
    norm_s_2 = tf.pow(norm_s, 2)
    v = (norm_s_2*s)/((1+norm_s_2)*norm_s)
    print("v shape {}".format(v.shape))
    return(v)

def update_prior(batch_item_cnt, b_prior, u_hat):

    for i in range(3):
        aggr = tf.reduce_sum(tf.multiply(v[batch_item_cnt, :, :], u_hat[batch_item_cnt, :, :]), axis=-1)
        aggr = tf.expand_dims(aggr, axis=0)
        aggr = tf.expand_dims(aggr, axis=3)
        b_prior = tf.add(b_prior, aggr)

    return([tf.add(batch_item_cnt, 1), b_prior, u_hat])


graph = tf.Graph()
batch_size=2
end_points = dict()

with graph.as_default():
    X = tf.placeholder(shape=[batch_size, 28, 28, 1], name="X", dtype=tf.float32)
    Y = tf.placeholder(shape=[batch_size], name="Y", dtype=tf.int32)
    Y_one_hot = tf.one_hot(indices=Y, depth=10)
    print("Y_one_hot shape={}\n".format(Y_one_hot.shape))

    
    #First Layer
    with tf.name_scope("Conv_Layer_1"):
      
        
        conv_1 = tf.contrib.layers.conv2d(inputs=X, 
                                          num_outputs=256,  
                                          kernel_size=9, 
                                          stride=1, padding="VALID", 
                                          activation_fn=tf.nn.relu, 
                                          weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                          biases_initializer=tf.zeros_initializer())
        end_points["conv_layer_1_act"] = conv_1
        print("Conv_1 shape={}\n".format(conv_1.shape))
    capsules = list()
    
    #In this implementation every capsule layer has unique weights
    #I think the meaning of sharing is that in each plane [6, 6, 1]
    #each cepsule output sharing the weights
    with tf.variable_scope("Capsule_Layer"):
        for n in range(32):
         
           #If we user auto resuse then the output of each capsule would be the same and it doesn't make any sense
        
            cap_1 = tf.contrib.layers.conv2d(inputs=conv_1, 
                                             #reuse=tf.AUTO_REUSE, 
                                             #scope="Capsule_Layer",
                                             num_outputs=8, kernel_size=9, 
                                             stride=2, padding="VALID", 
                                             activation_fn=None, 
                                             weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                             biases_initializer=tf.zeros_initializer())
            print("Cap_{} shape={}\n".format(n, conv_1.shape))
            end_points["capsule_{}_act".format(n)] = cap_1
         
            capsules.append(tf.expand_dims(cap_1, axis=3))
     
    with tf.name_scope("DigitalCaps"):
        caps_all = tf.concat(capsules, axis=3)
        end_points["caps_all"] = caps_all
        
        caps_all = tf.reshape(caps_all, shape=[-1, 1152, 8])
        caps_all = tf.expand_dims(caps_all, axis=2)
        end_points["caps_all"] = caps_all
        print("Caps_all shape={}\n".format(caps_all.shape))

        W = tf.get_variable(shape=[10, 1152, 8, 16], name="W")
        end_points["W_t"] = W
        
        b = tf.get_variable(shape=[10, 1152, 1, 1], name="b", initializer=tf.zeros_initializer())
        end_points["b"] = b
        
        b_prior = tf.get_variable(shape=[1, 1152, 10, 1], name="b_prior", initializer=tf.zeros_initializer(), trainable=False)
        end_points["b_prior"] = b_prior
        
        caps_all_list = list()
    
        for cl in range(10):
            caps_all_list.append(tf.map_fn(fn=lambda x: tf.add(tf.matmul(x, W[cl,:,:,:]), b[cl,:,:,:]), elems=caps_all))

        u_hat = tf.concat(caps_all_list, axis=2)
        end_points["u_hat"] = u_hat
        print("u_hat shape={}\n".format(u_hat.shape))
        
    s = get_s(b_prior, u_hat)
    v = get_v(s)
 
    with tf.name_scope("Routing"):
        
        #Routing
        batch_item_cnt = tf.constant(0)
        end_of_batch = lambda batch_item_cnt, b_prior, u_hat: tf.less(batch_item_cnt, batch_size)
        route_op = tf.while_loop(body=update_prior, cond=end_of_batch, loop_vars=[batch_item_cnt, b_prior, u_hat])
        
    with tf.control_dependencies(route_op):
        v_out = get_v(get_s(b_prior, u_hat))
        end_points["v_out"] = v_out
        print("v_out shape={}\n".format(v_out.shape))
    
    with tf.name_scope("Prepare_for_FC"):
        Y_one_hot_ex = tf.expand_dims(Y_one_hot, axis=1)
        end_points["Y_one_hot_ex"] = Y_one_hot_ex
        
        v_out_masked = tf.squeeze(tf.matmul(Y_one_hot_ex, v_out))
        end_points["v_out_masked"] = v_out_masked
        print("v_out_masked shape={}\n".format(v_out_masked.shape))
    
    with tf.name_scope("Fully_connected"):
        fc_1 = tf.contrib.layers.fully_connected(inputs=v_out_masked, num_outputs=512, 
                                                 activation_fn=tf.nn.relu,
                                                 weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                                 biases_initializer=tf.zeros_initializer())
        
        end_points["fc_1"] = fc_1
        print("fc_1 shape={}\n".format(fc_1.shape))
        
        fc_2 = tf.contrib.layers.fully_connected(inputs=v_out_masked, num_outputs=1024, 
                                                 activation_fn=tf.nn.relu,
                                                 weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                                 biases_initializer=tf.zeros_initializer())
        end_points["fc_2"] = fc_2
        print("fc_2 shape={}\n".format(fc_2.shape))
        
        fc_out = tf.contrib.layers.fully_connected(inputs=v_out_masked, num_outputs=784, 
                                                 activation_fn=tf.nn.sigmoid,
                                                 weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                                 biases_initializer=tf.zeros_initializer())
        end_points["fc_out"] = fc_out
        print("fc_out shape={}\n".format(fc_out.shape))
        
        
        #Calsulatong loss for each capsule
        lambd_a = 0.5
        #margin_loss =Y_one_hot_ex 
        v_out_norm = tf.norm(v_out, axis=2)
        end_points["v_out_norm"] = v_out_norm
        print("v_out_norm shape={}\n".format(v_out_norm.shape))
        
        margin_loss_present = tf.multiply(tf.pow(tf.maximum(0., 0.9 - v_out_norm), 2), Y_one_hot)
        end_points["margin_loss_present"] = margin_loss_present
        print("margin_loss_present shape={}\n".format(margin_loss_present.shape))
        
        margin_loss_not_present = lambd_a * tf.multiply(tf.pow(tf.maximum(0., v_out_norm - 0.1), 2), (1 - Y_one_hot))
        end_points["margin_loss_not_present"] = margin_loss_not_present
        print("margin_loss_not_present shape={}\n".format(margin_loss_not_present.shape))
        
        total_margin_loss = tf.add(margin_loss_present, margin_loss_not_present)
        end_points["total_margin_loss"] = total_margin_loss
        print("total_margin_loss shape={}\n".format(total_margin_loss.shape))
        
        

    
    
  
        
        
        
            
            
            
    
    

    init = tf.global_variables_initializer()
    

Y_one_hot shape=(2, 10)

Conv_1 shape=(2, 20, 20, 256)

Cap_0 shape=(2, 20, 20, 256)

Cap_1 shape=(2, 20, 20, 256)

Cap_2 shape=(2, 20, 20, 256)

Cap_3 shape=(2, 20, 20, 256)

Cap_4 shape=(2, 20, 20, 256)

Cap_5 shape=(2, 20, 20, 256)

Cap_6 shape=(2, 20, 20, 256)

Cap_7 shape=(2, 20, 20, 256)

Cap_8 shape=(2, 20, 20, 256)

Cap_9 shape=(2, 20, 20, 256)

Cap_10 shape=(2, 20, 20, 256)

Cap_11 shape=(2, 20, 20, 256)

Cap_12 shape=(2, 20, 20, 256)

Cap_13 shape=(2, 20, 20, 256)

Cap_14 shape=(2, 20, 20, 256)

Cap_15 shape=(2, 20, 20, 256)

Cap_16 shape=(2, 20, 20, 256)

Cap_17 shape=(2, 20, 20, 256)

Cap_18 shape=(2, 20, 20, 256)

Cap_19 shape=(2, 20, 20, 256)

Cap_20 shape=(2, 20, 20, 256)

Cap_21 shape=(2, 20, 20, 256)

Cap_22 shape=(2, 20, 20, 256)

Cap_23 shape=(2, 20, 20, 256)

Cap_24 shape=(2, 20, 20, 256)

Cap_25 shape=(2, 20, 20, 256)

Cap_26 shape=(2, 20, 20, 256)

Cap_27 shape=(2, 20, 20, 256)

Cap_28 shape=(2, 20, 20, 256)

Cap_29 shape=(2, 20, 20, 256)

Cap_30 shape=(2, 20, 20,

Lk = Tk max(0;m+ 􀀀 jjvkjj)2 +  (1 􀀀 Tk) max(0; jjvkjj 􀀀 m􀀀)2

In [13]:
mnist = Mnist(path=".\\", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size)    
data = np.expand_dims(data, -1)

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000


In [27]:
with tf.Session(graph=graph) as sess:
   # saver = tf.summary.FileWriter(filename_suffix="capsNet", logdir="log", graph=graph)
   # saver.flush()
    sess.run(init)

    feed_dict = {X:data, Y:lbls}
    before = datetime.now()
    print(sess.run([margin_loss_present, margin_loss_not_present], feed_dict=feed_dict))
    #cap_1, cap_2 = sess.run([end_points['capsule_0_act'], end_points['capsule_1_act']], feed_dict = feed_dict)
    
    #assert(np.all(cap_1==cap_2))
    #print("Time {}\n{}".format((datetime.now() - before).total_seconds(), rt.shape))

[array([[ 0.        ,  0.        ,  0.        ,  0.        ,  0.        ,
         0.4468953 ,  0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.49729478,  0.        ,  0.        ,  0.        ,  0.        ,
         0.        ,  0.        ,  0.        ,  0.        ,  0.        ]], dtype=float32), array([[ 0.00904059,  0.00757149,  0.00328853,  0.00535245,  0.00924961,
         0.        ,  0.00210414,  0.01467124,  0.0071143 ,  0.00482973],
       [ 0.        ,  0.00775989,  0.00282875,  0.00294855,  0.0170508 ,
         0.01836263,  0.00634866,  0.0115443 ,  0.00478459,  0.01005425]], dtype=float32)]


In [31]:
end_points.keys()

dict_keys(['conv_layer_1_act', 'capsule_0_act', 'capsule_1_act', 'capsule_2_act', 'capsule_3_act', 'capsule_4_act', 'capsule_5_act', 'capsule_6_act', 'capsule_7_act', 'capsule_8_act', 'capsule_9_act', 'capsule_10_act', 'capsule_11_act', 'capsule_12_act', 'capsule_13_act', 'capsule_14_act', 'capsule_15_act', 'capsule_16_act', 'capsule_17_act', 'capsule_18_act', 'capsule_19_act', 'capsule_20_act', 'capsule_21_act', 'capsule_22_act', 'capsule_23_act', 'capsule_24_act', 'capsule_25_act', 'capsule_26_act', 'capsule_27_act', 'capsule_28_act', 'capsule_29_act', 'capsule_30_act', 'capsule_31_act', 'caps_all', 'W_t', 'b', 'b_prior', 'u_hat', 'v_out', 'Y_one_hot_ex', 'v_out_masked', 'fc_1', 'fc_2', 'fc_out', 'v_out_norm', 'margin_loss_present', 'margin_loss_not_present', 'total_margin_loss'])