In [1]:
import numpy as np
import glob
import sys
import time
import imageio
import scipy.misc
import tensorflow as tf
from tensorflow.python.framework import ops
from sklearn.model_selection import train_test_split
import math
import matplotlib.pyplot as plt
%pylab inline
import pandas as pd
import seaborn as sns
sns.set_style("darkgrid")
import os
import gc
# Keep this after %pylab: the pylab star-import binds `random` to numpy.random.
import random
from PIL import Image

  from ._conv import register_converters as _register_converters


Populating the interactive namespace from numpy and matplotlib


In [2]:
def crop(img, margin=32):
    """
    Trim a border of `margin` pixels from every side of an image.

    img: object exposing `size` as (width, height) and a `crop(box)` method,
         e.g. a PIL image
    margin: number of pixels removed from each of the four sides (default 32,
            the value that used to be hard-coded)

    Returns whatever `img.crop` returns for the box
    (margin, margin, width - margin, height - margin).
    """
    width, height = img.size[0], img.size[1]
    box = (margin, margin, width - margin, height - margin)
    return img.crop(box)

def get_img_array(path):
    """
    Open the image stored at `path`, trim its border with `crop`, and
    return the remaining pixels as a numpy array.
    """
    cropped = crop(Image.open(path))
    pixels = np.asarray(cropped)
    return pixels

def resize_img(img_arr, target_dim=(50, 50)):
    """
    Resize an image given as a numpy array.

    img_arr: image as a numpy array
    target_dim: (height, width) of the result

    `scipy.misc.imresize` was removed in SciPy 1.3. It is still used when the
    installed SciPy provides it; otherwise the resize is done with Pillow
    using bilinear resampling, which was imresize's default interpolation.
    """
    imresize = getattr(scipy.misc, 'imresize', None)
    if imresize is not None:
        return imresize(img_arr, target_dim)

    height, width = target_dim
    # NOTE(review): unlike imresize, this path does not rescale values to
    # 0-255 first, so it presumably needs uint8 input -- confirm with callers.
    pil_img = Image.fromarray(np.asarray(img_arr))
    # PIL takes the size as (width, height), the reverse of target_dim.
    return np.asarray(pil_img.resize((width, height), Image.BILINEAR))

In [3]:
def get_files(folder,suffix):
    """
    Collect the PNG files stored at <folder>/<any sub-directory>/<suffix>/.

    Returns a list of paths in the order glob reports them.
    """
    pattern = folder + '/*/' + suffix + '/*.png'
    return list(glob.glob(pattern))

In [4]:
def augmentation(x, max_offset=2):
    """
    Randomly translate a batch of images, filling uncovered pixels with zeros.

    x: array of shape (batch, height, width, channels)
    max_offset: largest shift, in pixels, along each spatial axis; 0 returns
                a float copy of `x`

    One (vertical, horizontal) shift is drawn per call with np.random and
    applied to every image of the batch. The result has the shape of `x`
    and the default float dtype of np.zeros.
    """
    bz, h, w, c = x.shape
    # Canvas axes follow the layout of x (height first, then width); the old
    # code swapped them, which broke every non-square input.
    padded = np.zeros([bz, h + 2 * max_offset, w + 2 * max_offset, c])
    offsets = np.random.randint(0, 2 * max_offset + 1, 2)
    # Paste the batch at the random position ...
    padded[:, offsets[0]:offsets[0] + h, offsets[1]:offsets[1] + w, :] = x
    # ... and cut the centred window back out, which yields the shifted batch.
    return padded[:, max_offset:max_offset + h, max_offset:max_offset + w, :]

def get_images_labels(folder,suffix,batch_size):
    """
    Draw a random batch of images and labels from the files under `folder`.

    folder: root directory handed to `get_files`
    suffix: sub-directory name handed to `get_files` ('train' or 'test' in
            this notebook)
    batch_size: number of distinct files to sample

    Returns (X, y):
      X: np.array stacking the cropped images, pixel values divided by 255
         (presumably (batch, height, width, channels) when all images share
         one size -- confirm against the data)
      y: np.array of float labels, each parsed from the first two characters
         of the file name

    Raises ValueError when fewer than `batch_size` files are available or a
    file name does not start with a number.
    """
    candidates = get_files(folder, suffix)
    if batch_size > len(candidates):
        raise ValueError(
            'requested a batch of {} images but only {} files match {!r}/*/{}'
            .format(batch_size, len(candidates), folder, suffix))
    files = random.sample(candidates, batch_size)

    images = []
    labels = []
    for f in files:
        # basename copes with whatever path separator glob produced
        label = float(os.path.basename(f)[:2])
        images.append(get_img_array(f) / 255.0)
        labels.append(label)
    # A single collection per batch instead of one per image.
    gc.collect()

    return np.array(images), np.array(labels)

In [5]:
def train_valid_iter(folder,iters=1000,batch_size=32,is_shift_ag=True):
    """
    Generator producing `iters` random (train, validation) batch pairs.

    Every item is (X_train, y_train, X_valid, y_valid). Images pass through
    `augmentation` with a maximum shift of 2 * int(is_shift_ag) pixels and
    labels are reshaped into column vectors.
    """
    shift = 2 * int(is_shift_ag)
    for _ in range(iters):
        train_images, train_labels = get_images_labels(folder, 'train', batch_size)
        valid_images, valid_labels = get_images_labels(folder, 'test', batch_size)
        batch = (
            augmentation(train_images, shift),
            train_labels.reshape(-1, 1),
            augmentation(valid_images, shift),
            valid_labels.reshape(-1, 1),
        )
        yield batch

In [6]:
class CapsNet(object):
    """
    Convolutional regression network built as a TensorFlow 1.x graph.

    Despite its name, the graph built here contains no capsule or routing
    layers: it is four strided convolutions (the first two followed by max
    pooling) with optional batch normalization, ending in one non-negative
    value per image.

    Attributes created by __init__:
      x, y      placeholders for images (None, 192, 192, 3) and labels (None, 1)
      on_train  bool placeholder selecting batch vs. moving-average statistics
      lr        learning rate: a placeholder when lr_find is True, otherwise
                the exponentially decayed rate
      y_pred    predictions, shape (batch,)
      y_test    alias of the label placeholder
      loss      summed squared error plus the regularization term
      train     Adam training op
      accuracy  coefficient of determination (R^2) of the batch; not defined
                when every label in the batch is equal (division by zero)
    """

    def __init__(self,
                 routing_iterations=3,
                 batch_size=32,
                 steps=5000,
                 norm=True,
                 lr_find=False):
        """
        routing_iterations: stored as self.iterations; unused by this graph
        batch_size: stored as self.batch_size; the placeholders accept any
                    batch size
        steps: planned number of training steps; the learning rate is decayed
               every steps / 10 steps (ignored when lr_find is True)
        norm: whether to batch-normalize the input and the first three
              convolution outputs
        lr_find: if True the learning rate is the placeholder self.lr and has
                 to be fed on every training step
        """
        self.iterations = routing_iterations
        self.batch_size = batch_size

        self.x = tf.placeholder(tf.float32, [None, 192, 192, 3])
        self.y = tf.placeholder(tf.float32, [None, 1])

        if lr_find:
            self.lr = tf.placeholder(tf.float32)
        else:
            # exponentially decayed learning rate; the step counter is kept
            # out of the trainable variables, it is only incremented
            global_step = tf.Variable(0, trainable=False)
            self.lr = tf.train.exponential_decay(
                0.001, global_step, steps / 10, 0.96, staircase=True)
        self.norm = norm
        self.on_train = tf.placeholder(tf.bool)

        # Drop only the label column so that a batch of one keeps its batch
        # axis and lines up with the predictions.
        y = tf.squeeze(self.y, axis=1)

        # reg_term is the penalty term (0 unless get_CapsNet is given reg=True)
        length_v, reg_term = self.get_CapsNet(self.x, self.norm, self.on_train)

        self.y_pred = length_v
        self.y_test = self.y

        squared_error = tf.reduce_sum(tf.square(tf.subtract(y, length_v)))
        self.loss = squared_error + reg_term

        # adam optimizer
        optimizer = tf.train.AdamOptimizer(learning_rate=self.lr)
        if lr_find:
            self.train = optimizer.minimize(self.loss)
        else:
            self.train = optimizer.minimize(self.loss, global_step=global_step)

        # R^2 = 1 - SS_res / SS_tot
        total_error = tf.reduce_sum(tf.square(tf.subtract(y, tf.reduce_mean(y))))
        self.accuracy = tf.subtract(1.0, tf.divide(squared_error, total_error))

    def _conv_params(self, index, kernel_shape):
        """Create (or fetch, under reuse) kernel 'wconv<index>' and bias 'bconv<index>'."""
        weights = tf.get_variable(
            'wconv%d' % index, kernel_shape,
            initializer=tf.random_normal_initializer(mean=0, stddev=0.1))
        biases = tf.get_variable(
            'bconv%d' % index, [kernel_shape[-1]],
            initializer=tf.random_normal_initializer(mean=0, stddev=0.1))
        return weights, biases

    def _batch_norm(self, inputs, on_train):
        """
        Batch-normalize a (batch, height, width, channels) tensor.

        Per-channel batch statistics are used (and folded into a moving
        average) when on_train is True; the moving averages are used
        otherwise. A single scalar scale and shift are learned per call.
        """
        fc_mean, fc_var = tf.nn.moments(inputs, axes=[0, 1, 2])
        scale = tf.Variable(tf.ones([1]))
        shift = tf.Variable(tf.zeros([1]))
        epsilon = 0.001
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            # tie the moving-average update to reading the batch statistics
            ema_apply_op = ema.apply([fc_mean, fc_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(fc_mean), tf.identity(fc_var)

        mean, var = tf.cond(
            on_train, mean_var_with_update,
            lambda: (ema.average(fc_mean), ema.average(fc_var)))
        return tf.nn.batch_normalization(inputs, mean, var, shift, scale, epsilon)

    def get_CapsNet(self, x, norm, on_train, reg=False, reuse=False):
        """
        Build the forward graph and return (predictions, reg_term).

        x: image batch tensor
        norm: to use batch-normalization or not
        on_train: bool tensor, True for training and False for evaluation
        reg: add an L2 penalty on the four convolution kernels
        reuse: passed to the 'CapsNet' variable scope; note that the
               batch-norm scale/shift are plain tf.Variable objects and are
               created anew on every call regardless of this flag
        """
        # Four conv layers; kernels and biases live in the 'CapsNet' scope
        with tf.variable_scope('CapsNet', reuse=reuse):
            wconv1, bconv1 = self._conv_params(1, [9, 9, 3, 256])
            wconv2, bconv2 = self._conv_params(2, [5, 5, 256, 8 * 32])
            wconv3, bconv3 = self._conv_params(3, [3, 3, 256, 256])
            wconv4, bconv4 = self._conv_params(4, [3, 3, 256, 1])

        # L2-regularization
        reg_term = 0
        if reg:
            tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv1)
            tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv2)
            tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv3)
            tf.add_to_collection(tf.GraphKeys.WEIGHTS, wconv4)
            regularizer = tf.contrib.layers.l2_regularizer(scale=5.0/50000)
            reg_term = tf.contrib.layers.apply_regularization(regularizer)

        if norm:
            # BN for the input images
            x = self._batch_norm(x, on_train)

        conv1 = tf.nn.conv2d(x, wconv1, [1, 2, 2, 1], padding='VALID') + bconv1
        conv1 = tf.nn.max_pool(conv1, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
        if norm:
            conv1 = self._batch_norm(conv1, on_train)
        conv1 = tf.nn.relu(conv1)

        conv2 = tf.nn.conv2d(
            conv1, wconv2, [1, 2, 2, 1], padding='VALID') + bconv2
        conv2 = tf.nn.max_pool(conv2, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='VALID')
        if norm:
            conv2 = self._batch_norm(conv2, on_train)
        conv2 = tf.nn.relu(conv2)

        conv3 = tf.nn.conv2d(
            conv2, wconv3, [1, 2, 2, 1], padding='VALID') + bconv3
        if norm:
            conv3 = self._batch_norm(conv3, on_train)
        conv3 = tf.nn.relu(conv3)

        conv4 = tf.nn.conv2d(
            conv3, wconv4, [1, 2, 2, 1], padding='VALID') + bconv4
        conv4 = tf.nn.relu(conv4)

        # For 192x192 inputs the spatial size shrinks 92 -> 46 -> 21 -> 10 ->
        # 4 -> 1, so conv4 is (batch, 1, 1, 1). Squeeze only those unit axes:
        # a batch of one keeps its batch axis.
        v_digit = tf.squeeze(conv4, axis=[1, 2, 3])

        return v_digit,reg_term

In [8]:
# Data location and batch settings
folder = 'data/'
batch_size = 10
is_shift_ag = True

# Training-loop state
steps = 10 ** 5
irun = 0

# Learning-rate search: switch plus the (learning rate, loss) history
lr_find = False
lr_list, loss_list = [], []

### 1. Learning-rate search: we first increase the learning rate a little on every iteration and look at which learning rate gives the lowest loss

In [9]:
# Turn the learning-rate search on and start the sweep at 1e-4.
lr = 0.0001
lr_find = True

In [10]:
# Restart the step counter so the log and the checkpoint cadence do not
# continue from an earlier run of a training cell.
irun = 0

train_iter = train_valid_iter(
    folder, iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)

# Build the model in its own graph: constructing CapsNet a second time in the
# default graph fails because the 'CapsNet/...' variables already exist.
graph = tf.Graph()
with graph.as_default():
    net = CapsNet(steps=steps, lr_find=lr_find)
    init = tf.global_variables_initializer()
    # needed by the checkpoint branch below
    saver = tf.train.Saver()

sess = tf.Session(graph=graph)
sess.run(init)

checkpoint_path = "./checkpoint_CapsNet/MyModel"

if lr_find:
    # Constant increment of the linear sweep from 1e-4 to 1e-1, computed once
    # instead of rebuilding two `steps`-long arrays on every iteration.
    lr_sweep = np.linspace(1e-4, 1e-1, steps)
    lr_step = lr_sweep[2] - lr_sweep[1]

for X, Y, X_TEST, Y_TEST in train_iter:
    a = time.time()
    if lr_find:
        LS, ACC, _ = sess.run(
            [net.loss, net.accuracy, net.train],
            feed_dict={net.x: X, net.y: Y, net.on_train: True, net.lr: lr})
        lr_list.append(lr)
        loss_list.append(LS)
        b = time.time()
        print(irun, LS, ACC, '{}s per epoch'.format(b - a))
        lr += lr_step

    else:
        LS, ACC, _ = sess.run(
            [net.loss, net.accuracy, net.train],
            feed_dict={
                net.x: X,
                net.y: Y,
                net.on_train: True
            })
        # R^2 on the validation batch, using the moving-average statistics
        ACC_TEST = sess.run(
            net.accuracy,
            feed_dict={
                net.x: X_TEST,
                net.y: Y_TEST,
                net.on_train: False
            })

        b = time.time()
        if irun % 100 == 0:
            print(irun, LS, ACC, ACC_TEST, '{}s per epoch'.format(b - a))

        if (irun + 1) % 1000 == 0:
            # Saver.save does not create a missing parent directory
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            saver.save(sess, checkpoint_path)
            print('Model Saved!')

    irun += 1

0 277.24847 -0.5059668 1.0292859077453613s per epoch
1 91.9349 -0.7152034 0.03593635559082031s per epoch
2 44.133404 -0.6406472 0.035400390625s per epoch
3 284.12225 -0.71054935 0.03535866737365723s per epoch
4 135.64552 -0.6052724 0.03509688377380371s per epoch
5 202.09058 -0.3832345 0.03605771064758301s per epoch
6 187.36809 -0.7527417 0.03700828552246094s per epoch
7 115.00783 -0.6524112 0.037119388580322266s per epoch
8 39.69991 -1.8155963 0.03674435615539551s per epoch
9 84.618324 -1.6777952 0.03500723838806152s per epoch
10 42.265438 -0.31668043 0.03591513633728027s per epoch
11 430.526 -0.31699598 0.035755157470703125s per epoch
12 345.06775 -0.95063746 0.03673553466796875s per epoch
13 113.667595 -1.4656744 0.03698086738586426s per epoch
14 70.15604 -0.33885562 0.03653216361999512s per epoch
15 147.04784 -0.74227285 0.03587675094604492s per epoch
16 153.78345 -0.35372758 0.03656411170959473s per epoch
17 10.149239 0.6035453 0.036231279373168945s per epoch
18 102.61274 -0.665791

149 119.47732 -0.22415292 0.03604626655578613s per epoch
150 132.62794 -1.1530511 0.036853790283203125s per epoch
151 135.27368 -0.30070853 0.03621935844421387s per epoch
152 188.16609 -0.39485598 0.036765336990356445s per epoch
153 89.39699 -2.9731994 0.03604459762573242s per epoch
154 65.790245 -0.92933273 0.036302804946899414s per epoch
155 41.09408 0.074457586 0.03603959083557129s per epoch
156 148.8532 -0.09773755 0.03539872169494629s per epoch
157 667.1627 -0.05214119 0.03714895248413086s per epoch
158 34.997704 0.7105236 0.03661537170410156s per epoch
159 122.66239 0.30305457 0.03934144973754883s per epoch
160 29.178526 -1.0262866 0.03828763961791992s per epoch
161 382.911 -0.87242544 0.03644371032714844s per epoch
162 227.12404 -0.5599178 0.03616642951965332s per epoch
163 635.3375 -0.49985266 0.03678154945373535s per epoch
164 227.40323 -0.5272212 0.036585092544555664s per epoch
165 316.26337 0.04278642 0.0367281436920166s per epoch
166 358.08017 0.1556704 0.035759925842285156

295 110.372154 0.36604166 0.03628396987915039s per epoch
296 125.18375 -0.1634177 0.03608083724975586s per epoch
297 178.75925 -0.61773074 0.03666806221008301s per epoch
298 197.16139 -1.2899115 0.03614473342895508s per epoch
299 155.65237 -0.06611216 0.03563833236694336s per epoch
300 117.95283 0.10370195 0.037125587463378906s per epoch
301 203.4211 0.16630697 0.03597092628479004s per epoch
302 67.70094 -0.064480186 0.03618192672729492s per epoch
303 53.73301 -0.26728797 0.03781461715698242s per epoch
304 238.51796 -2.507617 0.035910844802856445s per epoch
305 143.06332 -16.031351 0.036449432373046875s per epoch
306 120.49118 0.3010952 0.03644752502441406s per epoch
307 140.60167 -0.52166307 0.03632926940917969s per epoch
308 69.805374 -0.30233908 0.037177324295043945s per epoch
309 291.97067 -0.91456175 0.036202192306518555s per epoch
310 484.1712 -0.057142377 0.040384531021118164s per epoch
311 61.30978 -0.049825072 0.03616929054260254s per epoch
312 36.701614 -0.016665101 0.0363492

441 207.76208 -0.16983151 0.03600931167602539s per epoch
442 180.50734 -0.21227229 0.036214590072631836s per epoch
443 233.07153 -0.118385434 0.03595399856567383s per epoch
444 209.03235 0.4980011 0.036696434020996094s per epoch
445 113.25218 -3.718841 0.03585314750671387s per epoch
446 82.97388 -0.5252552 0.036043643951416016s per epoch
447 87.98004 -0.55992985 0.03580117225646973s per epoch
448 338.93506 -0.36888158 0.035486698150634766s per epoch
449 55.38969 -0.45762348 0.037700653076171875s per epoch
450 222.17615 0.37256104 0.03691244125366211s per epoch
451 75.16176 -1.4724264 0.0359039306640625s per epoch
452 134.0955 -0.33961558 0.036054134368896484s per epoch
453 92.72101 -0.28067696 0.03705263137817383s per epoch
454 170.11577 -1.9897327 0.037322044372558594s per epoch
455 96.540886 -0.66163313 0.03600811958312988s per epoch
456 97.34692 0.33324027 0.03587841987609863s per epoch
457 92.654205 -1.6548481 0.0359041690826416s per epoch
458 202.45476 -0.17433155 0.03588914871215

587 69.66092 0.42142087 0.037560462951660156s per epoch
588 203.33841 -0.081587315 0.03764629364013672s per epoch
589 207.4774 0.19112122 0.038101911544799805s per epoch
590 89.16419 0.28324598 0.03848719596862793s per epoch
591 22.90083 0.3284214 0.0366368293762207s per epoch
592 190.42339 -5.7766333 0.036696434020996094s per epoch
593 358.1455 0.3006336 0.037119388580322266s per epoch
594 222.99948 -1.5900054 0.03731846809387207s per epoch
595 176.40042 -0.21739411 0.03601980209350586s per epoch
596 88.66066 -0.6149483 0.03638505935668945s per epoch
597 76.39396 0.70572436 0.036505937576293945s per epoch
598 250.84122 0.036708176 0.03618788719177246s per epoch
599 72.977905 0.15534824 0.035733699798583984s per epoch
600 172.42078 -0.13137007 0.035970449447631836s per epoch
601 14.311449 0.039499998 0.037888288497924805s per epoch
602 169.26035 -0.71837914 0.03660106658935547s per epoch
603 296.76035 -0.2703781 0.03572344779968262s per epoch
604 172.96686 0.014995158 0.036781072616577

734 42.490368 0.4824559 0.03612065315246582s per epoch
735 42.069748 -0.049120784 0.037931203842163086s per epoch
736 120.53966 0.18773818 0.03578519821166992s per epoch
737 75.213104 0.20325094 0.03603219985961914s per epoch
738 574.0319 -0.059881687 0.03613781929016113s per epoch
739 55.497337 -1.351582 0.03828597068786621s per epoch
740 114.31892 0.4943878 0.037947654724121094s per epoch
741 196.31198 -0.2393434 0.03707146644592285s per epoch
742 85.09473 0.13168645 0.04394364356994629s per epoch
743 39.94043 -0.6369028 0.03674888610839844s per epoch
744 66.00836 -0.69687295 0.03711223602294922s per epoch
745 95.04161 -0.74708855 0.03724217414855957s per epoch
746 68.2788 0.35586035 0.04016399383544922s per epoch
747 829.938 -0.27899206 0.03597402572631836s per epoch
748 141.09871 0.4444933 0.03615236282348633s per epoch
749 41.482742 0.72270894 0.035808563232421875s per epoch
750 102.81784 -0.63462377 0.03598761558532715s per epoch
751 16.32135 0.3817671 0.03658461570739746s per ep

KeyboardInterrupt: 

In [None]:
# Loss recorded at every learning rate tried during the search
fig, ax = plt.subplots(figsize=(30, 10))
ax.plot(lr_list, loss_list)
ax.set_xlabel('learning rate')
ax.set_ylabel('training loss')
ax.set_title('Learning-rate search: loss vs. learning rate');

### 2. Train with the exponentially decayed learning rate

In [None]:
# Leave learning-rate search mode and train for 5000 steps.
steps = 5000
lr_find = False

In [None]:
# Restart the step counter: without this it carries on from wherever the
# previous training cell stopped, which shifts the checkpoint cadence.
irun = 0

train_iter = train_valid_iter(
    folder, iters=steps, batch_size=batch_size, is_shift_ag=is_shift_ag)

# Build the model in its own graph: constructing CapsNet a second time in the
# default graph fails because the 'CapsNet/...' variables already exist.
graph = tf.Graph()
with graph.as_default():
    net = CapsNet(steps=steps, lr_find=lr_find)
    init = tf.global_variables_initializer()
    # needed by the checkpoint branch below
    saver = tf.train.Saver()

sess = tf.Session(graph=graph)
sess.run(init)

checkpoint_path = "./checkpoint_CapsNet/MyModel"

if lr_find:
    # Constant increment of the linear sweep from 1e-3 to 1, computed once
    # instead of rebuilding two `steps`-long arrays on every iteration.
    lr_sweep = np.linspace(1e-3, 1, steps)
    lr_step = lr_sweep[2] - lr_sweep[1]

for X, Y, X_TEST, Y_TEST in train_iter:
    a = time.time()
    if lr_find:
        LS, ACC, _ = sess.run(
            [net.loss, net.accuracy, net.train],
            feed_dict={net.x: X, net.y: Y, net.on_train: True, net.lr: lr})
        lr_list.append(lr)
        loss_list.append(LS)
        b = time.time()
        print(irun, LS, ACC, '{}s per epoch'.format(b - a))
        lr += lr_step

    else:
        LS, ACC, _ = sess.run(
            [net.loss, net.accuracy, net.train],
            feed_dict={
                net.x: X,
                net.y: Y,
                net.on_train: True
            })
        # R^2 on the validation batch, using the moving-average statistics
        ACC_TEST = sess.run(
            net.accuracy,
            feed_dict={
                net.x: X_TEST,
                net.y: Y_TEST,
                net.on_train: False
            })

        b = time.time()
        print(irun, LS, ACC, ACC_TEST, '{}s per epoch'.format(b - a))

        if (irun + 1) % 1000 == 0:
            # Saver.save does not create a missing parent directory
            os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
            saver.save(sess, checkpoint_path)
            print('Model Saved!')

    irun += 1

In [12]:
# Draw one random batch of 32 training images with their labels for inspection.
X, Y = get_images_labels(folder, 'train', 32)

In [15]:
# Labels of the sampled batch: floats parsed from the start of each file name.
Y

array([ 0., 10., 13.,  0.,  2.,  0.,  1., 10.,  1.,  0.,  0.,  1.,  1.,
        0.,  1.,  0.,  1.,  0.,  7.,  7.,  2.,  8.,  6.,  0.,  3.,  1.,
        0.,  0.,  0.,  4.,  1.,  1.])

In [13]:
# First image of the sampled batch; get_images_labels divided its pixels by 255.
X[0]

array([[[0.51372549, 0.41176471, 0.41176471],
        [0.63921569, 0.55294118, 0.55686275],
        [0.67058824, 0.60784314, 0.59607843],
        ...,
        [0.82352941, 0.78431373, 0.76078431],
        [0.84705882, 0.80784314, 0.78823529],
        [0.83921569, 0.79215686, 0.76862745]],

       [[0.65490196, 0.55686275, 0.56078431],
        [0.69411765, 0.61176471, 0.60392157],
        [0.63529412, 0.56470588, 0.54117647],
        ...,
        [0.82352941, 0.78039216, 0.75686275],
        [0.82352941, 0.78039216, 0.75686275],
        [0.83921569, 0.79215686, 0.76862745]],

       [[0.6745098 , 0.57647059, 0.56862745],
        [0.65098039, 0.56078431, 0.54117647],
        [0.59607843, 0.5254902 , 0.49019608],
        ...,
        [0.81960784, 0.77647059, 0.75294118],
        [0.82745098, 0.78431373, 0.76078431],
        [0.85098039, 0.80784314, 0.78431373]],

       ...,

       [[0.85882353, 0.84705882, 0.82745098],
        [0.85490196, 0.84705882, 0.82745098],
        [0.85098039, 0