In [1]:
import tensorflow as tf
import numpy as np
from datetime import datetime

In [16]:
class Mnist:
    def __init__(self, path, train=True):
        if train:
            image_file_name = "train-images.idx3-ubyte"
            label_file_name = "train-labels.idx1-ubyte"
        else:
            image_file_name = "t10k-images.idx3-ubyte"
            label_file_name = "t10k-labels.idx1-ubyte"
            
        self.f = open("{}\\{}".format(path, image_file_name), "rb")
        magic_number, number_of_images, image_rows, image_columns = int(self.f.read(4).hex(), 16), \
        int(self.f.read(4).hex(), 16), int(self.f.read(4).hex(), 16), int(self.f.read(4).hex(), 16)
        print("Images magic no: {},  No of images: {}, Image rows: {}, Image cols: {}".\
          format(magic_number, number_of_images, image_rows, image_columns))

        self.flab = open("{}\\{}".format(path, label_file_name), "rb")
        magic_number, number_of_items = int(self.flab.read(4).hex(), 16), int(self.flab.read(4).hex(), 16)
        print("Labels magic no: {}, No of items: {}".format(magic_number, number_of_items))

    def _get_next(self):
        val = self.f.read(28*28)
        if len(val)>0:
            img = val.hex()
            idxs = np.arange(0, len(img), 2).astype(int)
            return(np.reshape([int(img[i:(i+2)], 16) for i in idxs], (28, 28)), int(self.flab.read(1).hex(), 16)) 
        else:
            return("", "")

    def get_batch(self, batch_size=20):
        im_lst = np.zeros(shape=[batch_size, 28, 28], dtype="float")
        lbl_lst = np.zeros(shape=[batch_size], dtype="int")
        lbl_one_hot_lst = np.zeros(shape=[batch_size, 10], dtype="int")
        for i in range(batch_size):
            im, lbl = self._get_next()
            if len(im)>0:
                im_lst[i,...] = im
                lbl_lst[i] = lbl
                lbl_one_hot_lst[i,lbl] = 1
        return im_lst, lbl_lst, lbl_one_hot_lst
    

        
        
    
    

In [3]:
graph = tf.Graph()
batch_size=200
with graph.as_default():
    X = tf.placeholder(shape=[batch_size, 28, 28, 1], name="X", dtype=tf.float32)
    Y = tf.placeholder(shape=[batch_size], name="Y", dtype=tf.float32)
    Y_one_hot = tf.one_hot(indices=Y, )
    
    #First Layer
    with tf.name_scope("Conv_Layer_1"):
        conv_1 = tf.contrib.layers.conv2d(inputs=X, 
                                          num_outputs=256,  
                                          kernel_size=9, 
                                          stride=1, padding="VALID", 
                                          activation_fn=tf.nn.relu, 
                                          weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                          biases_initializer=tf.zeros_initializer())
    capsules = list()
    
    #In this implementation every capsule layer has unique weights
    #I think the meaning of sharing is that in each plane [6, 6, 1]
    #each cepsule output sharing the weights
    
    for n in range(32):
        with tf.name_scope("Capsule_{}".format(n)):
            cap_1 = tf.contrib.layers.conv2d(inputs=conv_1,
                                              num_outputs=8,
                                              kernel_size=9, 
                                              stride=2, padding="VALID", 
                                              activation_fn=None, 
                                              weights_initializer=tf.contrib.layers.xavier_initializer(), 
                                              biases_initializer=tf.zeros_initializer())
            
        
            capsules.append(tf.expand_dims(cap_1, axis=3))
    
    
    with tf.name_scope("DigitalCaps"):
        caps_all = tf.concat(capsules, axis=3)
        caps_all = tf.reshape(caps_all, shape=[-1, 1152, 8])
        caps_all = tf.expand_dims(caps_all, axis=2)

        W = tf.get_variable(shape=[10, 1152, 8, 16], name="W")
        b = tf.get_variable(shape=[10, 1152, 1, 1], name="b", initializer=tf.zeros_initializer())
        b_prior = tf.get_variable(shape=[1, 1152, 10, 1], name="b_prior", initializer=tf.zeros_initializer(), trainable=False)

        caps_all_list = list()
    
        for cl in range(10):
            caps_all_list.append(tf.map_fn(fn=lambda x: tf.add(tf.matmul(x, W[cl,:,:,:]), b[cl,:,:,:]), elems=caps_all))

        u_hat = tf.concat(caps_all_list, axis=2)
        
    def get_s(b_prior, u_hat):
        c = tf.nn.softmax(b_prior, dim=-1)
        s = tf.reduce_sum(tf.multiply(c, u_hat), axis=1)
        return(s)


    #Squash
    def get_v(s):
        norm_s = tf.norm(s)
        norm_s_2 = tf.pow(norm_s, 2)
        v = (norm_s_2*s)/((1+norm_s_2)*norm_s)
        print("v shape {}".format(v.shape))
        return(v)
    
    s = get_s(b_prior, u_hat)
    v = get_v(s)
 
    with tf.name_scope("Routing"):
        
        #Routing
        batch_item_cnt = tf.constant(0)
        end_of_batch = lambda batch_item_cnt, b_prior, u_hat: tf.less(batch_item_cnt, batch_size)

        def update_prior(batch_item_cnt, b_prior, u_hat):

            for i in range(3):
                aggr = tf.reduce_sum(tf.multiply(v[batch_item_cnt, :, :], u_hat[batch_item_cnt, :, :]), axis=-1)
                aggr = tf.expand_dims(aggr, axis=0)
                aggr = tf.expand_dims(aggr, axis=3)
                         
                b_prior = tf.add(b_prior, aggr)

            return([tf.add(batch_item_cnt, 1), b_prior, u_hat])


        route_op = tf.while_loop(body=update_prior, cond=end_of_batch, loop_vars=[batch_item_cnt, b_prior, u_hat])
        
    with tf.control_dependencies(route_op):
        v_out = get_v(get_s(b_prior, u_hat))
    
    
    
        
        
        
            
            
            
    
    

    init = tf.global_variables_initializer()
    

v shape (200, 10, 16)
v shape (200, 10, 16)


In [17]:
mnist = Mnist(path=".\\", train=True)
data, lbls, lbls_one = mnist.get_batch(batch_size)    
data = np.expand_dims(data, -1)

Images magic no: 2051,  No of images: 60000, Image rows: 28, Image cols: 28
Labels magic no: 2049, No of items: 60000


In [18]:
lbls_one

array([[0, 0, 0, ..., 0, 0, 0],
       [1, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ..., 
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 1, ..., 0, 0, 0]])

In [5]:
with tf.Session(graph=graph) as sess:
    #saver = tf.summary.FileWriter(filename_suffix="capsNet", logdir="log", graph=graph)
    #saver.flush()
    sess.run(init)

    feed_dict = {X:data}
    before = datetime.now()
    rt = sess.run(v_out, feed_dict = feed_dict)
    print("Time {}\n{}".format((datetime.now() - before).total_seconds(), rt.shape))

Time 12.332
(200, 10, 16)


In [6]:
np.zeros(10)

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.])

In [8]:
o = np.zeros(10)
o[7]=1

In [9]:
o

array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,  0.,  0.])