In [1]:
import tensorflow as tf
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm

  from ._conv import register_converters as _register_converters


In [2]:
def safe_norm(ivec,eps=10**-7,axis=-1):
    ovec = tf.reduce_sum(tf.square(ivec),axis=axis,keepdims=True)
    ovec = tf.sqrt(ovec + eps)
    return ovec
def squash(s_j, eps=10**-7, name = None, axis=-1):
    #v_j = (np.linalg.norm(s_j)**2/(1 + np.linalg.norm(s_j)**2)) * (s_j/np.linalg.norm(s_j))
    with tf.name_scope(name, default_name='squash'):
        sq_norm = tf.square(safe_norm(s_j,axis=axis))
        v_j = (sq_norm / (1 + sq_norm)) * (s_j/(safe_norm(s_j) + eps))
        return v_j


In [3]:
im_shape = (28,28,1)
k = 9

X = tf.placeholder(tf.float32,(None,) + im_shape)
num_pri_caps, num_pri_dims, num_dig_caps, num_dig_dims = 32*36, 8, 10, 16
#Feature extraction layer
c1 = tf.layers.Conv2D(256,(k,k),name='conv1')(X)
#Primary capsules
c2 = tf.layers.Conv2D(256,(k,k),strides=(2,2),name='conv2')(c1)
#c2 = tf.layers.Conv2D(256,(3,3),name='conv2')(c2)
pri_caps_raw = tf.reshape(c2, [-1,num_pri_caps,num_pri_dims], name = 'pri_caps_raw')
pri_caps_output = squash(pri_caps_raw)

#Digit capsules
init_sigma = 0.01
W_init = tf.random_normal(shape=(1,num_pri_caps,num_dig_caps,num_dig_dims, num_pri_dims),
                         stddev=init_sigma, dtype=tf.float32,name='W_init')
W = tf.Variable(W_init, name='W')
batch_size = tf.shape(X)[0]
#Following steps for one-shot matmul 
W_tiled = tf.tile(W,[batch_size,1,1,1,1],name='W_tiled')

'''
Reshape pri_cap output as column vector 
None,1152,8 to None,1152,8,1
'''

pri_caps_expanded = tf.expand_dims(pri_caps_output,-1,name='pri_caps_expanded')
'''
Add dimension for number of digit caps 
None,1152,8,1 to None,1152,1,8,1
'''
pri_caps_expanded = tf.expand_dims(pri_caps_expanded,2,name='pri_caps_expanded2')
'''None,1152,1,8,1 to None,1152,10,8,1'''
pri_caps_tiled = tf.tile(pri_caps_expanded, [1,1,num_dig_caps,1,1],name='pri_caps_tiled')

'''
Prediction u_j|i = dot(W_ij, u_i) --> Affine transform
where W_ij is a transformation matrix that will be learned
'''

dig_caps_predicted = tf.matmul(W_tiled,pri_caps_tiled,name='dig_caps_predicted')

##########-----ROUTING-----############

def condition(b_ij, v_j, counter, num_iterations):
    return counter < num_iterations
#Takes in routing weights and outputs updated weights
def routing(b_ij, agreement, counter, num_iterations):
    #c_i = softmax(b_i) --> along dig_caps dimension
    c_ij = tf.nn.softmax(b_ij,dim=2,name='routing_coeffs')

    '''
    s_j = sum(c_ij * u_j|i)
    where c_ij contains coupling coefficients 
    that will be determined by routing
    s_j is the raw output vector that represents the 
    instantiation parameters of each digit
    shape will be batch_size x 10 x 16 x 1
    '''
    #Keep_dims = true, otherwise s_j will be rank 4
    #==> Shape = None,1,10,16,1
    s_j = tf.reduce_sum(c_ij * dig_caps_predicted,axis=1,keepdims=True, name='weighted_sum')
    v_j = squash(s_j)
    #Make a copy of v_j for each primary capsule
    v_j_tiled = tf.tile(v_j,[1,num_pri_caps,1,1,1],name='v_j_tiled')
    agreement = tf.matmul(dig_caps_predicted,v_j_tiled,transpose_a=True,name='agreement')
    counter += 1
    return b_ij + agreement, v_j, counter, num_iterations

#extra ones at the end to keep it same rank as W
b_ij = tf.zeros([batch_size,num_pri_caps,num_dig_caps,1,1],dtype=tf.float32,name='b_ij')
num_iterations = 3
counter = tf.constant(0)
v_j=tf.zeros([batch_size,1,num_dig_caps,num_dig_dims,1],dtype=tf.float32,name='v_j_init')
#Dynamic loop keeps tf graph small
final_routing_weights, dig_caps_output, _,_ = tf.while_loop(condition,routing,[b_ij,v_j,counter,num_iterations])
y_prob = safe_norm(dig_caps_output,axis=-2)
y=tf.placeholder(tf.int32,shape=[None,1])
######----Margin Loss----#########
T = tf.one_hot(y,depth=num_dig_caps,name='T')
T = tf.reshape(T,tf.shape(y_prob))
m_pos = 0.9
m_neg = 0.1
lamb = 0.5
L_margin = tf.reduce_mean(tf.reduce_sum(T*tf.square(tf.maximum(0.,m_pos-y_prob)) + lamb*(tf.ones_like(T)-T) * tf.square(tf.maximum(0.,y_prob-m_neg)),axis=2))

#############-------------Decoder-----------------#######################
n_hidden = 512
n2_hidden = 1024
out_size = 28 * 28
alpha = 0.2
r_alpha = 0.0005
with tf.name_scope('decoder'):
    dc_1 = tf.layers.Dense(n_hidden,input_shape=(None,16))(tf.reshape(dig_caps_output[:,:,y[0][0],:,:],[-1,16]))
    dc_1 = tf.maximum(dc_1, alpha*dc_1)
    dc_2 = tf.layers.Dense(n2_hidden)(dc_1)
    dc_2 = tf.maximum(dc_2, alpha*dc_2)
    dc_out = tf.layers.Dense(out_size,activation=tf.nn.sigmoid)(dc_2)
    #Squared difference between reconstructed images and input images
    reconstruction_loss = tf.reduce_sum(tf.square(tf.reshape(X,[-1,out_size]) - dc_out))
L_total = L_margin + r_alpha*reconstruction_loss
#L_sum = tf.reduce_sum(L_total)
adam = tf.train.AdamOptimizer()
train = adam.minimize(L_total)
epochs = 50


Instructions for updating:
dim is deprecated, use axis instead


In [4]:
df = pd.read_csv('train.csv')

In [5]:
labels = np.asarray(df.label).reshape(42000,1)
train_x = np.asarray(df[df.columns[1:]])/255.0
print(labels[1], train_x.shape)

[0] (42000, 784)


In [None]:
saver = tf.train.Saver()
with tf.Session() as sess:
    try:
        sess.run(tf.global_variables_initializer())
        #saver.restore(sess, './checkpoints/capsnet_v1.ckpt')
        for e in range(epochs):
            print('Epoch ' + str(e))
            for i in range(40000):
                sess.run(train,{X:train_x[i].reshape((1,)+im_shape),y:labels[i].reshape(1,1)})
                if i%100==0:
                    saver.save(sess,'./checkpoints/capsnet_v1.ckpt')
                    print('Loss: ' + str(sess.run(L_total,{X:train_x[i].reshape((1,)+im_shape),y:labels[i].reshape(1,1)})))
    except Exception as e1:
        print(e1)
    finally:
        saver.save(sess,'./checkpoints/capsnet_v1.ckpt')
        print('Model saved')

Epoch 0
Loss: 0.9025147
Loss: 0.57550716
Loss: 0.15906316
Loss: 0.15309846
Loss: 0.13638262
Loss: 0.22661163
Loss: 0.043537788
Loss: 0.08224846
Loss: 0.19687043
Loss: 0.27934554
Loss: 0.085030854
Loss: 0.04475738
Loss: 0.099840626
Loss: 0.17799541
Loss: 0.043480173
Loss: 0.004708177
Loss: 0.26931676
Loss: 0.06343314
Loss: 0.054838274
Loss: 0.0642954
Loss: 0.21839091
Loss: 0.070228904
Loss: 0.22046064
Loss: 0.13529962
Loss: 0.22464588
Loss: 0.026945611
Loss: 0.10880646
Loss: 0.043994904
Loss: 0.022421796
Loss: 0.18667723
Loss: 0.084155984
Loss: 0.06519295
Loss: 0.010208081
Loss: 0.027051006
Loss: 0.012591977
Loss: 0.19425549
Loss: 0.108961314
Loss: 0.19390076
Loss: 0.06259301
Loss: 0.0043389946
Loss: 0.10591846
Loss: 0.18261445
Loss: 0.030512255
Loss: 0.07434363
Loss: 0.26745307
Loss: 0.26190683
Loss: 0.1872355
Loss: 0.06404972
Loss: 0.050143003
Loss: 0.060049742
Loss: 0.121065386
Loss: 0.17374215
Loss: 0.028255496
Loss: 0.0062751914
Loss: 0.23436908
Loss: 0.033650775
Loss: 0.031693712


Loss: 0.03815904
Loss: 0.018776705
Loss: 0.025646713
Loss: 0.037218723
Loss: 0.08525793
Loss: 0.02512806
Loss: 0.049307227
Loss: 0.0115834335
Loss: 0.017373534
Loss: 0.04393162
Loss: 0.044301733
Loss: 0.042408325
Loss: 0.038442094
Loss: 0.023612304
Loss: 0.1469315
Loss: 0.3085532
Loss: 0.018997375
Loss: 0.020256534
Loss: 0.02809542
Loss: 0.12353712
Loss: 0.013097966
Loss: 0.3681657
Loss: 0.047818653
Loss: 0.09217334
Loss: 0.024429122
Loss: 0.059250064
Loss: 0.013940566
Loss: 0.027842704
Loss: 0.059937984
Loss: 0.08166695
Loss: 0.043946683
Loss: 0.32206148
Loss: 0.052722238
Loss: 0.052091956
Loss: 0.03186442
Loss: 0.02597025
Loss: 0.11936837
Loss: 0.011074582
Loss: 0.016601693
Loss: 0.012200383
Loss: 0.39329243
Loss: 0.021328371
Loss: 0.038722966
Loss: 0.15718383
Loss: 0.17323303
Loss: 0.011504695
Loss: 0.018115
Loss: 0.0037447033
Loss: 0.057458125
Loss: 0.03477659
Loss: 0.018235773
Loss: 0.033484854
Loss: 0.10855407
Loss: 0.031072387
Loss: 0.003761169
Loss: 0.0117635885
Loss: 0.0348604

Loss: 0.020871416
Loss: 0.052594084
Loss: 0.043644916
Loss: 0.012341786
Loss: 0.06794155
Loss: 0.10273828
Loss: 0.02346218
Loss: 0.016276564
Loss: 0.015424283
Loss: 0.033980504
Loss: 0.058065053
Loss: 0.0126328925
Loss: 0.056885004
Loss: 0.025187515
Loss: 0.08631794
Loss: 0.068802506
Loss: 0.010134547
Loss: 0.025188293
Loss: 0.0073671415
Loss: 0.0035489427
Loss: 0.15897754
Loss: 0.0050766114
Loss: 0.0060933838
Loss: 0.04620026
Loss: 0.035185084
Loss: 0.083890885
Loss: 0.014603107
Loss: 0.018548332
Loss: 0.023167074
Loss: 0.0093827415
Loss: 0.054050542
Loss: 0.007955837
Loss: 0.026568945
Loss: 0.016743649
Loss: 0.01935308
Loss: 0.024688017
Loss: 0.09640251
Loss: 0.025720302
Loss: 0.08269817
Loss: 0.056894496
Loss: 0.059469312
Loss: 0.0413313
Loss: 0.0889547
Loss: 0.05270465
Loss: 0.03178314
Loss: 0.011817651
Loss: 0.037235513
Loss: 0.022430845
Loss: 0.008657421
Loss: 0.021425487
Loss: 0.020011287
Loss: 0.10215991
Loss: 0.071340814
Loss: 0.037638023
Loss: 0.0983891
Loss: 0.10569637
Loss:

Loss: 0.009377304
Loss: 0.028497
Loss: 0.04430437
Loss: 0.034397133
Loss: 0.067460306
Loss: 0.026926924
Loss: 0.022223918
Loss: 0.01536198
Loss: 0.030563043
Loss: 0.015147187
Loss: 0.06799955
Loss: 0.024554618
Loss: 0.02537703
Loss: 0.041204024
Loss: 0.016563248
Loss: 0.03886228
Loss: 0.013268883
Loss: 0.028014224
Loss: 0.008143699
Loss: 0.0032612856
Loss: 0.043682665
Loss: 0.043819867
Loss: 0.07283649
Loss: 0.05034629
Loss: 0.026777217
Loss: 0.06829397
Loss: 0.010713449
Loss: 0.022723451
Loss: 0.024757557
Loss: 0.029496673
Loss: 0.025385316
Loss: 0.12999293
Loss: 0.039460473
Loss: 0.013178356
Loss: 0.0694109
Loss: 0.038837917
Loss: 0.031485837
Loss: 0.013667991
Loss: 0.025469417
Loss: 0.008861426
Loss: 0.017346486
Loss: 0.007875792
Loss: 0.01950153
Loss: 0.026727604
Loss: 0.06562279
Loss: 0.017610442
Loss: 0.025759064
Loss: 0.011739721
Loss: 0.119028226
Loss: 0.014311457
Loss: 0.01978495
Loss: 0.020880561
Loss: 0.12661155
Loss: 0.022520794
Loss: 0.011973837
Loss: 0.026837168
Loss: 0.0

Loss: 0.014154528
Loss: 0.109984815
Loss: 0.014853676
Loss: 0.023016155
Loss: 0.018156145
Loss: 0.003616836
Loss: 0.065599814
Loss: 0.016282946
Loss: 0.007483941
Loss: 0.012149259
Loss: 0.120451145
Loss: 0.003398385
Loss: 0.049150966
Loss: 0.008506646
Loss: 0.073399365
Loss: 0.026408166
Loss: 0.00549078
Loss: 0.0154077355
Loss: 0.013307806
Loss: 0.019012442
Loss: 0.01719108
Loss: 0.03867983
Loss: 0.04886088
Loss: 0.058620762
Loss: 0.15972972
Loss: 0.039480276
Loss: 0.039946385
Loss: 0.060478404
Loss: 0.03202456
Loss: 0.06948336
Loss: 0.01141917
Loss: 0.009000655
Loss: 0.018724969
Loss: 0.061830945
Loss: 0.0068766745
Loss: 0.0054652966
Loss: 0.028017337
Loss: 0.019186044
Loss: 0.11300038
Loss: 0.01798553
Loss: 0.05193229
Loss: 0.0155667495
Loss: 0.025790308
Loss: 0.058358673
Loss: 0.019221418
Loss: 0.02721993
Loss: 0.011891831
Loss: 0.022259507
Loss: 0.016617537
Loss: 0.030606233
Loss: 0.017613653
Loss: 0.021765897
Loss: 0.030003956
Loss: 0.09835872
Loss: 0.015900228
Loss: 0.02384266
Lo

Loss: 0.032409802
Loss: 0.052931644
Loss: 0.018713895
Loss: 0.025043514
Loss: 0.028908737
Loss: 0.032578755
Loss: 0.01833889
Loss: 0.050832417
Loss: 0.026085552
Loss: 0.005133179
Loss: 0.0056496128
Loss: 0.018093247
Loss: 0.0860932
Loss: 0.056011103
Loss: 0.06643279
Loss: 0.01818143
Loss: 0.014344279
Loss: 0.062533595
Loss: 0.004614897
Loss: 0.014070431
Loss: 0.04322815
Loss: 0.0074243248
Loss: 0.010581163
Loss: 0.010376219
Loss: 0.08245708
Loss: 0.08526508
Loss: 0.016668105
Loss: 0.0062191444
Loss: 0.06250331
Loss: 0.054871283
Loss: 0.04038546
Loss: 0.008282892
Loss: 0.058144458
Loss: 0.02971484
Loss: 0.021055423
Loss: 0.011991838
Loss: 0.037192106
Loss: 0.06300098
Loss: 0.014161792
Loss: 0.016967185
Loss: 0.017914819
Loss: 0.01563453
Loss: 0.0034712334
Loss: 0.009049303
Loss: 0.02301471
Loss: 0.0065949457
Loss: 0.045591228
Loss: 0.07122144
Loss: 0.021497268
Loss: 0.030381382
Loss: 0.009624492
Loss: 0.01548061
Loss: 0.055596568
Loss: 0.060642526
Loss: 0.01474653
Loss: 0.066642575
Loss

In [17]:
y_prob.shape

TensorShape([Dimension(None), Dimension(1), Dimension(10), Dimension(1), Dimension(1)])

