## Train a capsule network (MNIST, Fashion-MNIST or CIFAR-10)

In [1]:
import tensorflow as tf

In [64]:
import matplotlib.pyplot as plt
import numpy as np
from IPython import display
import matplotlib.gridspec as gridspec
import seaborn
import matplotlib.animation as animation
from IPython.core.display import HTML #,display


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data.dataset import Dataset
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torchvision.utils import make_grid

from torchvision import datasets, transforms
# torch.set_default_tensor_type('torch.cuda.FloatTensor')
torch.set_default_tensor_type(torch.cuda.FloatTensor)

from time import time
import datetime
from tqdm import tqdm_notebook as tqdm

from glob import glob
from natsort import natsorted

from sklearn.decomposition import PCA
import scipy.ndimage
import imageio

# import onnx
# from onnx_tf.backend import prepare

t0 = time()


def tick(msg=''):
    """Print ``msg`` with the seconds elapsed since the last reset, then reset.

    The reference time lives in the module-level ``t0`` and is restarted
    after the message has been printed.
    """
    global t0
    elapsed = time() - t0
    print(msg, '{:.2f} sec'.format(elapsed))
    t0 = time()
    
    

class Squash(nn.Module):
    """Capsule squashing non-linearity over the last dimension.

    v = |s|^2 / (1 + |s|^2) * s / |s|

    The direction of ``s`` is kept and its length is mapped into [0, 1).
    """
    def __init__(self, eps=1e-8):
        super().__init__()
        # Added under the square root so a zero vector maps to 0 (instead of
        # 0/0 = NaN) and the backward pass through sqrt stays finite.
        self.eps = eps

    def forward(self, s):
        norm2 = (s**2).sum(dim=-1, keepdim=True)
        norm = torch.sqrt(norm2 + self.eps)

        direction = s / norm
        length = norm2 / (1 + norm2)

        return length * direction
        
        
        
class Capsule(nn.Module):
    """Fully connected capsule layer with routing-by-agreement.

    Maps ``nCapIn`` input capsules of size ``dimCapIn`` to ``nCapOut`` output
    capsules of size ``dimCapOut``. One pass with uniform coupling
    coefficients is followed by ``dynamicRountingIter`` routing refinements.
    """
    def __init__(self, nCapIn, dimCapIn, nCapOut, dimCapOut, dynamicRountingIter=3):
        super().__init__()
        self.nCapIn = nCapIn
        self.nCapOut = nCapOut
        self.dimCapIn = dimCapIn
        self.dimCapOut = dimCapOut
        self.dynamicRountingIter = dynamicRountingIter
        # One [dimCapIn, dimCapOut] transform per (input capsule, output capsule)
        # pair, stored side by side along the last axis.
        self.w = Parameter(
            torch.randn(nCapIn, dimCapIn, nCapOut*dimCapOut, requires_grad=True)
        )
        self.squash = Squash()

    def forward(self, u):
        '''Route input capsules to output capsules.

        u = tensor[exampleId, capInId, capInDim]
        return v = tensor[exampleId, capOutId, capOutDim] (squashed)

        Intermediate results are kept on the module for visualization:
        self.uHat, self.c (latest coupling coefficients), self.cs (coefficients
        of every routing iteration), self.uHatC, self.s, self.v.
        '''
        nExample = u.shape[0]
        # this version (without for loop) is inspired by:
        # https://github.com/adambielski/CapsNet-pytorch/blob/master/net.py

        # [#input, #capsIn, 1, dimCapIn] @ [#capsIn, dimCapIn, #CapOut*dimCapOut]
        #   -> [#input, #capsIn, 1, #CapOut*dimCapOut]
        uHat = u.unsqueeze(2).matmul(self.w)
        # uHat = [#input, #capsIn, #CapOut, dimCapOut]
        uHat = uHat.view(uHat.shape[0], uHat.shape[1], self.nCapOut, self.dimCapOut)
        self.uHat = uHat

        ## routing logits b = [#input, #capIn, #capOut], created next to the
        ## input instead of on the global default device
        b = torch.zeros(nExample, self.nCapIn, self.nCapOut,
                        dtype=uHat.dtype, device=uHat.device)
        c = F.softmax(b, dim=-1).detach()
        self.c = c

        ## (uHat*c) = [#input, #capsIn, #CapOut, dimCapOut]
        uHatC = uHat * c.unsqueeze(-1)
        ## s, v = [#input, #CapOut, dimCapOut]
        s = uHatC.sum(dim=1)
        v = self.squash(s)
        # expose the intermediates even when no routing iteration runs
        self.uHatC = uHatC
        self.s = s
        self.v = v

        ## dynamic routing: uHat is fixed, c and v are refined.
        ## Coefficients are treated as constants for the gradient (detach).
        self.cs = []
        uHatConst = uHat.detach()
        for _ in range(self.dynamicRountingIter):
            # agreement a_ij = <uHat_j|i, v_j>; update b_ij <- b_ij + a_ij
            # (previously b was also added to itself, doubling it every iteration)
            b = b + (uHatConst * v.detach().unsqueeze(1)).sum(-1)
            c = F.softmax(b, dim=-1).detach()
            self.c = c
            self.cs.append(c)
            # recompute v with the new coupling coefficients
            uHatC = uHat * c.unsqueeze(-1)
            s = uHatC.sum(dim=1)
            v = self.squash(s)

            self.uHatC = uHatC
            self.s = s
            self.v = v
        return v
    
        
    


# class Capsule2D(Capsule):
#     def __init__(self, nCapIn, dimCapIn, nCapOut, dimCapOut, dynamicRountingIter):
#         super().__init__(nCapIn, dimCapIn, nCapOut, dimCapOut, dynamicRountingIter)
        
#     def forward(self, x):
#         '''x[n_example, height, width, n_capsule_out, dim_capsule_in]
#         return y[n_example, n_capsule_out, dim_capsule_out]
#         '''
# #         x = torch.sum(x, dim=[1,2])
#         x = x.view(x.shape[0], self.nCapIn, self.dimCapIn)
#         v = super().forward(x)
#         return v
    

    
    
    
class ReconstructionNet(nn.Module):
    """MLP decoder used for the reconstruction loss.

    ``dims`` lists the layer widths, e.g. [10*16, 512, 1024, 784]. Hidden
    layers use ReLU; the output layer uses Sigmoid.
    """
    def __init__(self, dims):
        super().__init__()
        layers = []
        for i in range(len(dims) - 2):
            layers += [nn.Linear(dims[i], dims[i+1]), nn.ReLU()]
        layers += [nn.Linear(dims[-2], dims[-1]), nn.Sigmoid()]

        self.layers = nn.Sequential(*layers)

    def forward(self, x, target=None):
        """Decode capsule activations.

        x      : [N, nCap, dimCap] capsule vectors, or a flat [N, nCap*dimCap]
        target : optional [N] class indices; when given, every capsule except
                 capsule ``target[n]`` is zeroed before decoding
        Returns [N, dims[-1]].
        """
        if target is not None:
            if x.dim() > 2:
                # Mask along the capsule axis; works for any number of capsules.
                mask = F.one_hot(target.long(), num_classes=x.shape[1])
                mask = mask.to(device=x.device, dtype=x.dtype)
                x = x * mask.view(*mask.shape, *([1] * (x.dim() - 2)))
            else:
                # Flat input: assumes the 10 capsules x 16 dims layout.
                e = torch.eye(10, dtype=x.dtype, device=x.device)
                e = e.unsqueeze(-1).expand([10, 10, 16]).contiguous().view(10, 160)
                x = x * e.index_select(0, target.to(x.device))
        if x.dim() > 2:
            # reshape (not view) also accepts non-contiguous inputs
            x = x.reshape(-1, self.layers[0].in_features)
        return self.layers(x)
    
    
    

    
class MnistNet(nn.Module):
    """CapsNet for 1x28x28 images: conv -> primary capsules -> 10 digit capsules.

    ``forward`` returns the digit-capsule vectors [N, 10, 16]. The predicted
    class is the capsule with the largest length. ``self.reconstruct`` is the
    decoder used for the reconstruction loss.
    """
    def __init__(self, dynamicRountingIter):
        super().__init__()
        self.conv1 = nn.Conv2d(1,256,9,stride=1)

        ## correction inspired by
        ## https://github.com/higgsfield/Capsule-Network-Tutorial/blob/master/Capsule%20Network.ipynb
        self.conv2 = nn.Conv2d(256,8*32,9,stride=2)

        # primary-to-digit-cap (aka 'p2d'): 32*6*6 capsules of dim 8 -> 10 of dim 16
        self.p2d = Capsule(32*6*6, 8, 10,16, dynamicRountingIter)
        self.squash = Squash()

        self.relu = nn.ReLU()
        # decoder: masked 10*16 digit capsules -> 28*28 pixels
        self.reconstruct = ReconstructionNet([10*16, 512, 1024, 784])

    def forward(self, x):
        ## FIRST LAYER image -> ReLU_conv1
        ## x[#input, 1, 28, 28] -> [#input, 256, 20, 20]
        y = self.relu(self.conv1(x))

        ## SECOND LAYER: ReLU_conv1 -> PrimaryCaps
        ## [#input, channel=8*32=256, height=6, width=6]
        y = self.conv2(y)
        ## channels last: [#input, height=6, width=6, channel=256]
        s1 = y.permute([0,2,3,1])
        ## split channels into capsules: [#input, 6, 6, nCap=32, dimCap=8]
        ## reshape (not view): the permuted tensor is not contiguous
        s1 = s1.reshape(s1.shape[0], s1.shape[1], s1.shape[2], 32, 8)

        ## squash on last dimension, so each v1[i,x,y,c,:] has length < 1
        ## [#input, 6, 6, nCap=32, dimCap=8]
        v1 = self.squash(s1)

        ## THIRD LAYER: PrimaryCaps -> DigitCaps
        ## [#input, nCap=6*6*32=1152, dimCap=8]; squash may keep the permuted
        ## memory layout, so merge the grid and capsule axes with reshape
        v1_flat = v1.reshape(v1.shape[0], -1, v1.shape[-1])
        ## [#input, 10, 16]
        v2 = self.p2d(v1_flat)
        return v2

    
class Cifar10Net(MnistNet):
    """CapsNet variant for 3x32x32 inputs (CIFAR-10).

    Reuses ``MnistNet.forward``. After the parent constructor runs, the layers
    whose sizes depend on the input image are replaced.
    """
    def __init__(self, dynamicRountingIter):
        super().__init__(dynamicRountingIter)
        inChannels, hiddenChannels = 3, 512
        # 32 -> 24 after conv1 (k=9), then 24 -> 8 after conv2 (k=9, stride 2)
        nPrimaryCaps = 32 * 8 * 8
        nPixels = 32 * 32 * 3

        self.conv1 = nn.Conv2d(inChannels, hiddenChannels, 9, stride=1)
        self.conv2 = nn.Conv2d(hiddenChannels, 8 * 32, 9, stride=2)
        # primary-to-digit-cap (aka 'p2d')
        self.p2d = Capsule(nPrimaryCaps, 8, 10, 16, dynamicRountingIter)
        self.reconstruct = ReconstructionNet([10 * 16, 512, 1024, nPixels])
    
    
def test(model, loader):
    """Evaluate classification accuracy of ``model`` on ``loader``.

    The predicted class is the output capsule with the largest length
    (MnistNet / Cifar10Net return [N, 10, 16] capsule vectors).

    Returns:
        (number correct, number of examples, accuracy as a float)
    """
    wasTraining = model.training
    model.eval()
    # run where the model lives instead of assuming CUDA
    device = next(model.parameters()).device
    correct = 0
    count = 0
    pbar = tqdm(loader)
    try:
        # evaluation needs no autograd graph
        with torch.no_grad():
            for imgs, targets in pbar:
                imgs = imgs.to(device)
                targets = targets.to(device)
                pred = model(imgs).norm(dim=-1).argmax(dim=-1)
                correct += (pred == targets).sum().item()
                count += targets.shape[0]
                pbar.set_postfix({'accuracy': correct / count})
    finally:
        # restore the caller's train/eval mode
        model.train(wasTraining)
    return correct, count, (correct / count if count else 0.0)



def mem(msg=None):
    """Print ``msg`` and the CUDA memory currently allocated, in GB."""
    print(msg, end=', ')
    allocatedGB = torch.cuda.memory_allocated(device=None) / (1e9)
    print(allocatedGB, 'GB')

    
    
def supSample(x, side, stride, sigma=2):
    """Upsample a 2-D map by splatting a Gaussian blob per input cell.

    Each cell x[i, j] adds ``x[i, j] * kern`` to the output window starting at
    (i*stride, j*stride). ``kern`` is a side x side delta blurred with a
    Gaussian of std ``sigma``.

    Returns a tensor of shape (x0*stride + side - 1, x1*stride + side - 1) on
    the same device as ``x``. Height and width are computed separately, so
    non-square inputs work.
    """
    x0, x1 = x.shape
    out = torch.zeros(x0*stride + (side-1), x1*stride + (side-1), device=x.device)

    kern = np.zeros([side, side])
    kern[side//2, side//2] = 1
    kern = scipy.ndimage.gaussian_filter(kern, sigma)
    kern = torch.tensor(kern).float().to(x.device)
    for i in range(x0):
        xStart, xStop = i*stride, i*stride + side
        for j in range(x1):
            yStart, yStop = j*stride, j*stride + side
            out[xStart:xStop, yStart:yStop] += kern*x[i, j]
    return out


mem('CUDA memory usage')
plt.style.use('ggplot')
# 'seaborn-colorblind' was renamed 'seaborn-v0_8-colorblind' in matplotlib 3.6
# and the old name was removed in 3.8; fall back to the new name.
try:
    plt.style.use('seaborn-colorblind')
except OSError:
    plt.style.use('seaborn-v0_8-colorblind')

CUDA memory usage, 2.315244032 GB


In [None]:
# x = torch.rand(4,4)
# out1 = supSample(x, 9, 2)
# out2 = supSample(out1, 9, 1)
# plt.subplot(131)
# plt.imshow(x)
# plt.subplot(132)
# plt.imshow(out1)
# plt.subplot(133)
# plt.imshow(out2)

## Dataset

In [65]:
# transform_mean = (0.1307,)
# transform_std = (0.3081,)

# Training images are shifted randomly by up to 2/28 of the image width/height
# (data augmentation). Test images are only converted to tensors, so the
# evaluation is deterministic.
train_transform = transforms.Compose([
    transforms.RandomAffine(0, [2/28,2/28]), ##augment dataset
    transforms.ToTensor(),
#     transforms.Normalize(transform_mean, transform_std),
])
test_transform = transforms.ToTensor()
transform = train_transform  # kept under the old name


# dataset_name = 'mnist'
# train_dataset = datasets.MNIST('./dataset/'+dataset_name, train=True, download=True, transform=train_transform)
# test_dataset = datasets.MNIST('./dataset/'+dataset_name, train=False, download=True, transform=test_transform)


# dataset_name = 'fashion-mnist'
# train_dataset = datasets.FashionMNIST('./dataset/'+dataset_name, train=True, download=True, transform=train_transform)
# test_dataset = datasets.FashionMNIST('./dataset/'+dataset_name, train=False, download=True, transform=test_transform)

dataset_name = 'cifar10'
train_dataset = datasets.CIFAR10('./dataset/'+dataset_name, train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10('./dataset/'+dataset_name, train=False, download=True, transform=test_transform)

Files already downloaded and verified
Files already downloaded and verified


## Model

In [72]:
# Shuffled training batches of 128; test batches of 64 in fixed order.
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=128)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=64)
# model = MnistNet(dynamicRountingIter=5).train()
model = Cifar10Net(dynamicRountingIter=5)
model.train()
parameters = model.parameters()

In [73]:
# Delete figures from any previous run and recreate the output folders
# used by the training cell.
!rm -r fig
!mkdir -p fig/loss/png
!mkdir -p fig/loss/svg
!mkdir -p fig/dynamic-routing
!mkdir -p fig/reconstruct

In [74]:
dataset_name  # echo the active dataset name

'cifar10'

In [75]:
# Delete previously saved cifar10 checkpoints and recreate the folder.
!rm -r capture/cifar10
!mkdir -p capture/cifar10


In [76]:
# model.parameters() is a one-shot generator, so take a fresh one here.
parameters = model.parameters()
# optimizer = optim.SGD(parameters, lr=1e-4, momentum=0.99)
optimizer = optim.Adam(params=parameters, lr=1e-4)
# optimizer = optim.Adam(parameters, lr=5e-5)
# optimizer = optim.Adam(parameters, lr=5e-5)


In [77]:
# %%time

# Shape the decoder output is reshaped to. Cifar10Net subclasses MnistNet, so
# it must be checked first. Unknown models fail loudly instead of silently
# reusing a stale img_shape.
if isinstance(model, Cifar10Net):
    img_shape = (3,32,32)
elif isinstance(model, MnistNet):
    img_shape = (1,28,28)
else:
    raise TypeError('unsupported model class: {}'.format(type(model).__name__))
            
# Margin-loss thresholds: the true class should have capsule length of at
# least mPlus, every other class at most mMinus.
mPlus = 0.9
mMinus = 0.1
def marginLoss(v, target):
    """Capsule margin loss, summed over classes and averaged over the batch.

    v      : [#input, nClass, dimCap] capsule vectors
    target : [#input, nClass] one-hot labels
    Terms for absent classes are down-weighted by 0.5.
    """
    lengths = v.norm(dim=-1)  # [#input, nClass]
    presentTerm = target * F.relu(mPlus - lengths) ** 2
    absentTerm = 0.5 * (1.0 - target) * F.relu(lengths - mMinus) ** 2
    return (presentTerm + absentTerm).sum(dim=1).mean()



mse = nn.MSELoss()
l1 = nn.L1Loss()
def reconstructionLoss(reconstructed, original):
    """Decoder loss: 0.8 * mean absolute error + mean squared error."""
#     return mse(reconstructed, original)
    weightedAbs = 0.8 * l1(reconstructed, original)
    return weightedAbs + mse(reconstructed, original)


lossHistory = []       # log(loss) per training step
accuracyHistory = []   # test accuracy after each epoch
nepoch = 200

# Status figure: loss curve | predicted capsule lengths | one-hot targets |
# routing coefficients (2 examples) | reconstructions (2 examples)
gs = gridspec.GridSpec(2, 5, width_ratios=[4,1,1,1,1])
gs.update(left=0.05, right=0.95, wspace=0.01, hspace=0.5)

# layouts used only by the commented-out routing visualizations below
gs2 = gridspec.GridSpec(3, 5, width_ratios=[1]*5, height_ratios=[1,1,1])
gs2.update(left=0.05, right=0.95, wspace=0.1, hspace=0.2)

gs3 = gridspec.GridSpec(2, 10, width_ratios=[1]*10, height_ratios=[3,19])
gs3.update(left=0.05, right=0.95, wspace=0.1, hspace=0.05)


def _imshowArray(img):
    """Convert a [C, H, W] image tensor into an array plt.imshow accepts.

    imshow rejects [H, W, 1] arrays, so single-channel images become [H, W].
    """
    arr = img.detach().cpu().permute(1, 2, 0).numpy()
    return arr[:, :, 0] if arr.shape[-1] == 1 else arr


k = -1        # global step counter across epochs
t0 = time()

shouldShow = False  # also render the figures inline (they are always saved)

for epoch in tqdm(range(nepoch)):
    pbar = tqdm(train_loader)
    for imgs, targets in pbar:
        k += 1
        imgs = imgs.cuda()
        targets = targets.cuda()
        targets_onehot = torch.eye(10, device=targets.device).index_select(dim=0, index=targets)
        v = model(imgs)

        # the decoder only sees the true-class capsule (masked via `targets`)
        reconstructed = model.reconstruct(v, targets).view(-1,*img_shape)

        loss = marginLoss(v, targets_onehot)
#             loss += 0.0005 * reconstructionLoss(reconstructed, imgs)
#             loss += 0.005*reconstructionLoss(reconstructed, imgs)
        loss += 5*reconstructionLoss(reconstructed, imgs)

        model.zero_grad()
        loss.backward()
        optimizer.step()
#             scheduler.step(loss)
        lossHistory.append(np.log(loss.item()))

        pbar.set_postfix({'loss': loss.item()})
        if k%100==99:
#         if k%10==9:
            with torch.no_grad():
                if shouldShow:
                    display.clear_output(wait=True)

                fig = plt.figure(figsize=[25,6])
                plt.subplot(gs[:,0])
                plt.plot(lossHistory)

                t = time()-t0
                plt.title('training time: {:.2f} sec\nepoch = {}, loss = {:.2g}, lr={}'\
                          .format(t, epoch, loss.item(), optimizer.param_groups[0]['lr']))

                #show margin loss
                vNorm = v.norm(dim=-1)

                plt.subplot(gs[:,1])
                plt.imshow(vNorm[:30].detach().cpu().numpy(), vmin=0)
                plt.colorbar()
                plt.title('class prediction')

                plt.subplot(gs[:,2])
                plt.imshow(targets_onehot[:30].detach().cpu().numpy(), vmin=0)
                plt.colorbar()
                plt.title('ground truth')

                # coupling coefficients of the last routing iteration,
                # first 20 input capsules of examples 0 and 1
                plt.subplot(gs[0,3])
                c = model.p2d.c[0:2, :20, :].cpu().numpy()
                plt.imshow(c[0], vmin=0)
                plt.colorbar()
                plt.title('dynamic routing')
                plt.subplot(gs[1,3])
                plt.imshow(c[1], vmin=0)
                plt.colorbar()

                plt.subplot(gs[0,4])
                plt.imshow(_imshowArray(reconstructed[0]), vmin=0, vmax=1)
                plt.axis('off')
                plt.colorbar()
                plt.title('reconstruction')
                plt.subplot(gs[1,4])
                plt.imshow(_imshowArray(reconstructed[1]), vmin=0, vmax=1)
                plt.axis('off')
                plt.colorbar()

                if shouldShow:
                    display.display(plt.gcf())
                plt.savefig('fig/loss/png/loss-{}.png'.format(epoch), dpi=fig.dpi)
                plt.savefig('fig/loss/svg/loss-{}.svg'.format(epoch), dpi=fig.dpi)
                plt.close()

                ## show reconstruction: originals interleaved with reconstructions
                orig = imgs
                act = model(orig)
                recon = model.reconstruct(act, targets).reshape([len(targets),*img_shape])
                im = torch.stack([orig, recon], dim=1).view([-1,*img_shape])
                # nrow must be >= 1 even for a tiny final batch
                im = make_grid(im.cpu(), nrow=max(1, len(targets)//4), padding=2, pad_value=1, normalize=True).numpy().transpose([1,2,0])
                # orig = make_grid(orig.cpu(), nrow=10, padding=5, pad_value=1, normalize=True).numpy().transpose([1,2,0])
                plt.figure(figsize=[32,8])
                # plt.subplot(121)
                plt.imshow(im, vmin=0, vmax=1)
                plt.axis('off')
                if shouldShow:
                    display.display(plt.gcf())
                imageio.imwrite('fig/reconstruct/recon-{}.png'.format(epoch), (im*255).astype('uint8'))
                plt.close()

                # image example
#                     fig = plt.figure(figsize=[18,10])
#                     plt.subplot(gs2[0,:])
#                     plt.imshow(imgs[0,0].cpu().numpy())
#                     plt.axis('off')

                # option1: dynamic routing coefficients
#                     c = model.p2d.c[0, :, :].cpu().numpy()
#                     c = c.reshape([6,6,32,10]).sum(axis=2)
#                     vmax = c.max()
#                     for i in range(10):
#                         plt.subplot(gs2[1+i//5,i%5])
#                         plt.imshow(c[:,:,i], vmin=0, vmax=vmax)
#                         plt.title(str(i))
#                         plt.axis('off')

                ## option2: summands of instantiation vectors
#                     uHatC = model.p2d.uHatC[0]
#                     label = targets[0].item()
# #                     ax = plt.subplot(gs2[1,label])
#                     for i in range(10):
#                         plt.subplot(gs2[1+i//5,i%5])
#                         plt.scatter(uHatC[:,i,0], uHatC[:,i,1], s=4, 
#                                     c='C1' if i==label else 'C0')
#                         plt.title('{}: {:.4f}'.format(i, uHatC[:,i,:].sum(dim=0).norm()))


                ## option3: norm of instantiation vectors (attention)
#                     uHatC = model.p2d.uHatC[0].view(6,6,32,10,16)
#                     norm = uHatC.norm(dim=2).norm(dim=-1)
# #                     norm = uHatC[:,:,0].norm(dim=-1)
#                     attentions = [supSample(supSample(norm[:,:,i],9,2),9,1) for i in range(10)]
#                     vmax = torch.max(torch.stack(attentions)).item()
#                     label = targets[0].item()
#                     for i in range(10):

#                         plt.subplot(gs2[1+i//5,i%5])
#                         plt.imshow(imgs[0,0].cpu().numpy(),cmap='gray')

#                         plt.imshow(attentions[i].cpu(),
#                                    alpha=0.9,
#                                    vmax=vmax, vmin=0)
#                         plt.title('{}'.format(i))




#                     ## image example
#                     fig = plt.figure(figsize=[6,22])
#                     plt.subplot(gs3[0,:])
#                     plt.imshow(imgs[0,0].cpu().numpy())
#                     plt.axis('off')
# #                     ## dynamic routing coefficients
#                     c = model.p2d.c[0, :, :].cpu().numpy()
#                     c = c.reshape([-1,6,10])
#                     vmax = c.max()
#                     for i in range(10):
#                         plt.subplot(gs3[1,i])
#                         plt.imshow(c[:,:,i], vmin=0, vmax=vmax)
#                         plt.title(str(i))
#                         plt.axis('off')

#                     if shouldShow:
#                         display.display(plt.gcf())
#                     plt.savefig('fig/dynamic-routing/dr{}.png'.format(epoch), dpi=fig.dpi)
#                     plt.close()


    torch.save(model.state_dict(), 'capture/{}/model.{}'.format(dataset_name, epoch))
    _,_,acc = test(model, test_loader)
    accuracyHistory.append(acc)
    # test() switches to eval mode; make sure the next epoch trains in train mode
    model.train()

    ## shrink learning rate by 2% every second epoch
    if epoch % 2 == 1:
        for pg in optimizer.param_groups:
            pg['lr'] *= 0.98

torch.save(model.state_dict(), 'capture/{}/model.{}'.format(dataset_name, 999))
            

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

HBox(children=(IntProgress(value=0, max=157), HTML(value='')))

HBox(children=(IntProgress(value=0, max=391), HTML(value='')))

KeyboardInterrupt: 

In [40]:
dataset_name  # echo the active dataset name

'cifar10'

In [None]:
# Store the current weights under the epoch-199 checkpoint name.
checkpointState = model.state_dict()
torch.save(checkpointState, 'capture/{}/model.{}'.format(dataset_name, 199))


## test

In [None]:
# Build the architecture that matches the checkpoint being loaded: CIFAR-10
# checkpoints come from Cifar10Net, (Fashion-)MNIST ones from MnistNet.
modelClass = Cifar10Net if dataset_name == 'cifar10' else MnistNet
model = modelClass(dynamicRountingIter=5)
model.load_state_dict(torch.load('capture/{}/model.{}'.format(dataset_name, 199)))

In [None]:
def test(model, loader):
    """Evaluate classification accuracy of ``model`` on ``loader``.

    The predicted class is the output capsule with the largest length
    (MnistNet / Cifar10Net return [N, 10, 16] capsule vectors).

    Returns:
        (number correct, number of examples, accuracy as a float)
    """
    wasTraining = model.training
    model.eval()
    # run where the model lives instead of assuming CUDA
    device = next(model.parameters()).device
    correct = 0
    count = 0
    pbar = tqdm(loader)
    try:
        # evaluation needs no autograd graph
        with torch.no_grad():
            for imgs, targets in pbar:
                imgs = imgs.to(device)
                targets = targets.to(device)
                pred = model(imgs).norm(dim=-1).argmax(dim=-1)
                correct += (pred == targets).sum().item()
                count += targets.shape[0]
                pbar.set_postfix({'accuracy': correct / count})
    finally:
        # restore the caller's train/eval mode
        model.train(wasTraining)
    return correct, count, (correct / count if count else 0.0)

test(model, test_loader)

## Visualize dynamic routing

In [None]:

# Take one training batch and run it through the model.
imgs, targets = next(iter(train_loader))
imgs = imgs.cuda()
with torch.no_grad():
    v = model(imgs)

egIndex = 0
digitIndex = targets[egIndex].item()
uHat = model.p2d.uHat[egIndex].detach().cpu().numpy()
plt.imshow(imgs[egIndex,0].cpu())
for epoch in range(100,101):
    print(epoch)
    # same checkpoint naming as the training cell: capture/<dataset>/model.<epoch>
    model.load_state_dict(torch.load('capture/{}/model.{}'.format(dataset_name, epoch)))
    with torch.no_grad():
        model(imgs[:1])
    # prediction vectors and per-iteration coupling coefficients of example egIndex
    uHat = model.p2d.uHat[egIndex].detach().cpu().numpy()
    cs = model.p2d.cs

    # one figure per routing iteration
    for cIter in cs:
        ceg = cIter[egIndex].detach().cpu().numpy()
        plt.figure(figsize=[46,16])
        for digitIndex in range(10):
            if digitIndex == 0:
                ax = plt.subplot(1,10,digitIndex+1)
            else:
                plt.subplot(1,10,digitIndex+1, sharex=ax, sharey=ax)
            cDigit = ceg[:,digitIndex]
            xy = uHat[:,digitIndex,:]
            # coupling-weighted sum of the predictions = s (before squash) for this digit
            centroid = np.sum(xy*cDigit.reshape([-1,1]), axis=0, keepdims=True)

            pca = PCA()
            pca.fit(xy)
            xy = pca.transform(xy)
            centroid = pca.transform(centroid)

            dim0,dim1 = 0,1
            plt.plot([0,centroid[0,dim0]], [0,centroid[0,dim1]], 'C1:.', zorder=1)
            # marker area grows with the coupling coefficient
            plt.scatter(xy[:,dim0], xy[:,dim1], s=cDigit**2*10, zorder=2)
            plt.title(digitIndex)

        #     vsum = np.zeros([16])
        #     for v,ci in sorted(zip(xy,c), key=lambda x:-np.linalg.norm(x[0])*x[1])[:10]:
        #         plt.arrow(vsum[dim0], vsum[dim1], v[dim0]*ci, v[dim1]*ci, color='C1')
        #         vsum += v*ci

            plt.axis('square')
        plt.show()


## Analyze digitCap statistics

In [None]:
# model = MnistNet(dynamicRountingIter=5)
# model.load_state_dict(torch.load('capture/MnistNet.model.100'))

In [None]:
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=128)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=64)

# Run the whole test set through the model and keep images, digit-capsule
# activations and labels (all on the GPU) for the analysis cells below.
model.eval()
digitCapActivation1, targets1, imgs1 = [], [], []
for imgs, targets in tqdm(test_loader):
    imgs, targets = imgs.cuda(), targets.cuda()
    with torch.no_grad():
        v = model(imgs)
        imgs1.append(imgs)
        digitCapActivation1.append(v)
        targets1.append(targets)

# concatenate the per-batch results along the batch axis
targets1 = torch.cat(targets1)
digitCapActivation1 = torch.cat(digitCapActivation1, dim=0)
imgs1 = torch.cat(imgs1)

## show reconstruction

In [None]:
# first (up to) 1000 test examples
indices = list(range(min(1000, len(imgs1))))
nexample = len(indices)

orig = imgs1[indices]
# reshape the decoder output to the input image shape (works for MNIST and CIFAR-10)
imgShape = tuple(orig.shape[1:])
with torch.no_grad():
    act = digitCapActivation1[indices]
    recon = model.reconstruct(act, targets1[indices]).reshape([nexample, *imgShape])

# interleave each original with its reconstruction
im = torch.stack([orig, recon], dim=1).view([-1, *imgShape])
im = make_grid(im.cpu(), nrow=20, padding=2, pad_value=1, normalize=True).numpy().transpose([1,2,0])
# orig = make_grid(orig.cpu(), nrow=10, padding=5, pad_value=1, normalize=True).numpy().transpose([1,2,0])

plt.figure(figsize=[20,50])
# plt.subplot(121)
plt.imshow(im, vmin=0, vmax=1)
plt.axis('off')
# plt.subplot(122)
# plt.imshow(recon, vmin=0, vmax=1)
# #     plt.axis('off')

## Reconstruction by epoch

In [None]:
indices = list(range(20))
epochs = [0,1,2,4,10,100,199]
nexample = len(indices)
nrow = nexample
padding = 2

# one row per checkpoint plus a final row with the originals
# (only the `gridspec` module is imported, not the GridSpec name itself)
gs = gridspec.GridSpec(nrows=len(epochs)+1, ncols=1)
# gs.update(hspace=0.0, wspace=0)
plt.figure(figsize=[30,12])

orig = imgs1[indices]
# reshape the decoder output to the input image shape (works for MNIST and CIFAR-10)
imgShape = tuple(orig.shape[1:])

# same architecture choice and checkpoint naming as the training cell
modelClass = Cifar10Net if dataset_name == 'cifar10' else MnistNet
model = modelClass(dynamicRountingIter=5)
for i,epoch in enumerate(epochs):
    model.load_state_dict(torch.load('capture/{}/model.{}'.format(dataset_name, epoch)))
    with torch.no_grad():
        act = model(orig)
        recon = model.reconstruct(act, targets1[indices]).reshape([nexample, *imgShape])

    im = make_grid(recon.cpu(), nrow=nrow, padding=padding, pad_value=1, normalize=True).numpy().transpose([1,2,0])
    plt.subplot(gs[i,0])
    plt.imshow(im, vmin=0, vmax=1)
    plt.xticks([])
    plt.yticks([])
    plt.ylabel('Epoch {}'.format(epoch))

im = make_grid(orig.cpu(), nrow=nrow, padding=padding, pad_value=1, normalize=True).numpy().transpose([1,2,0])
plt.subplot(gs[-1,0])
plt.imshow(im, vmin=0, vmax=1)
plt.xticks([])
plt.yticks([])
plt.ylabel('Original')
plt.show()

## Histogram / KDE

In [None]:
# NumPy copies of the collected test-set tensors, for matplotlib / seaborn.
digitCapActivation2, targets2, imgs2 = (
    tensor.cpu().numpy() for tensor in (digitCapActivation1, targets1, imgs1)
)

In [None]:
# same checkpoint naming as the training cell: capture/<dataset>/model.<epoch>
model.load_state_dict(torch.load('capture/{}/model.{}'.format(dataset_name, 199)))

In [None]:
# plt.style.use('seaborn-colorblind')
gs = gridspec.GridSpec(5,2)
gs.update(left=0.05, right=0.95, wspace=0.1, hspace=0.1)
fig = plt.figure(figsize=[20, 40])

# [min, max] per class and capsule dim; only filled by the commented-out variant below
drange = np.zeros([10,16,2])
# [mean, std] per class and capsule dim of the true-class capsule
dmeanstd = np.zeros([10,16,2])

ax = None  # first panel; the others share its x axis



for i in range(10):
    # test examples labelled i
    d = digitCapActivation2[targets2==i]

    dmean = np.mean(d[:,i,:], axis=0, keepdims=True)
    dstd = np.std(d[:,i,:], axis=0, keepdims=True)
    dmeanstd[i,:,0] = dmean
    dmeanstd[i,:,1] = dstd
    dnorm = d#(d-dmean) / dstd

#     dmin = d[:,i,:].min(axis=0, keepdims=True)
#     dmax = d[:,i,:].max(axis=0, keepdims=True)
#     drange[i,:,0] = dmin
#     drange[i,:,1] = dmax
#     dnorm = (d-dmin) / (dmax-dmin)

    # each panel is created exactly once; the first one owns the shared x axis
    panel = plt.subplot(gs[i//2, i%2], sharex=ax)
    if ax is None:
        ax = panel
    # one density curve per dimension of the true-class capsule
    for j in range(16):
        # `bw` is deprecated since seaborn 0.11; `bw_method` is its replacement
        seaborn.kdeplot(dnorm[:,i,j], bw_method=0.2, gridsize=1000, cumulative=False)
#     plt.xlim([-2,2])
plt.show()


## generative model

In [None]:
# per-class [mean, std] statistics as a float32 CUDA tensor
meanstd = torch.from_numpy(dmeanstd).float().cuda()

In [None]:
nexample = 20
# decoded images have the same shape as the dataset images (MNIST or CIFAR-10)
imgShape = tuple(imgs1.shape[1:])
plt.figure(figsize=[20,10])
gs = gridspec.GridSpec(nrows=10, ncols=1)
gs.update(hspace=0)
for targetClass in range(10):
    # Only capsule `targetClass` is non-zero. Each of its 16 dims is drawn from
    # N(mean, (0.4 * std)^2), using the per-class statistics in meanstd.
    sample = torch.zeros(nexample,10,16)
    sample[:,targetClass,:] = torch.randn(nexample,16) * 0.4
    sample[:,targetClass,:] = sample[:,targetClass,:] * meanstd[targetClass,:,1] + meanstd[targetClass,:,0]
    with torch.no_grad():
        recon = model.reconstruct(sample).view(-1, *imgShape)
    recon = make_grid(recon, nrow=nexample,normalize=True)
    recon = recon.cpu().numpy().transpose([1,2,0])
    plt.subplot(gs[targetClass,0])
    plt.imshow(recon)
    plt.axis('off')
plt.show()

## scatter plot, two classes

In [None]:
# x: dimension dim0 of capsule digit0, y: dimension dim1 of capsule digit1;
# every test example is coloured by its true label.
digit0, dim0 = 0, 3
digit1, dim1 = 1, 0
c = plt.cm.tab10(targets2)
plt.figure(figsize=[10, 10])
plt.scatter(
    digitCapActivation2[:, digit0, dim0],
    digitCapActivation2[:, digit1, dim1],
    alpha=0.99,
    s=20,
    facecolors=c,
    edgecolors=c,
    cmap='tab10',
)
plt.xlabel('dim{} of digit{} capsule'.format(dim0, digit0))
plt.ylabel('dim{} of digit{} capsule'.format(dim1, digit1))

## scatter plot, single class

In [None]:
targetClass = 5

# activations, colours and images of the test examples labelled targetClass
d = digitCapActivation2[targets2 == targetClass]
c = plt.cm.tab10(targets2[targets2 == targetClass])
im = imgs2[targets2 == targetClass]

# random 16 -> 2 linear projection of that class's capsule vectors
proj = np.random.randn(16, 2)
coord = d[:, targetClass, :].dot(proj)
imgSize = np.abs(coord).max() / 30  # thumbnail size for the commented-out overlay
plt.figure(figsize=[10, 10])

## show image
# for i,[x,y] in enumerate(coord[:150]):
#     plt.imshow(im[i,0], extent=[x, x+imgSize, y, y+imgSize], cmap='Greys')

plt.scatter(
    coord[:, 0], coord[:, 1],
    alpha=0.6, s=10,
    facecolors=c, edgecolors='none',
    cmap='tab10',
)

# plt.xlim([-0.025,0.025])
# plt.ylim([-0.025,0.025])

In [None]:
%%time

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE, MDS
# from umap import UMAP

# model = UMAP(n_neighbors=10, min_dist=0.4)
npoint = 10000

digitCapActivation3 = digitCapActivation2[:npoint,:,:]
digitCapActivation3 = digitCapActivation3.reshape(digitCapActivation3.shape[0], -1)
# digitCapActivation3 = model.fit_transform(digitCapActivation3)
imgs3 = imgs1.cpu().numpy()


## random projection
# digitCapActivation3 = digitCapActivation2[:,6,:].reshape([digitCapActivation2.shape[0], -1])
# proj = np.random.randn(digitCapActivation3.shape[1], 2)
# digitCapActivation3 = np.dot(digitCapActivation3, proj)

In [None]:
# seaborn.kdeplot(digitCapActivation1[:,0,6], gridsize=1000)
plt.figure(figsize=[20, 20])

c = plt.cm.tab10(targets2)  # colour every example by its true label
# m = plt.cm.ScalarMappable(cmap=plt.cm.tab10)

# for i,[x,y] in enumerate(digitCapActivation3[:300]):
#     plt.imshow(imgs3[i,0], extent=[x, x+0.3, y, y+0.3])

# plt.scatter(digitCapActivation3[:npoint,0],
#             digitCapActivation3[:npoint,1], 
#             alpha=0.7,
#             facecolors=c[:npoint], edgecolors='none', cmap='tab10',
#             s=10)

# x: dim 0 of capsule 4, y: dim 1 of capsule 8
plt.scatter(
    digitCapActivation2[:npoint, 4, 0],
    digitCapActivation2[:npoint, 8, 1],
    alpha=0.5,
    facecolors=c[:npoint],
    cmap='tab10',
    s=20,
)
plt.axis('equal')
# plt.colorbar(m)
# plt.savefig('instantiation-parameter.svg')
# plt.xlim([-0.02, 0.02])
plt.show()

## grand tour

In [None]:
def mix(a,b, p):
    return (1-p)*a + p*b


def embed(x, canvas):
    canvas[:x.shape[0],:x.shape[1]] = x
    return canvas


def rot(ndim, dim1, dim2, theta):
    res = np.eye(ndim)
    res[dim1, dim1] = np.cos(theta)
    res[dim2, dim1] = -np.sin(theta)
    res[dim1, dim2] = np.sin(theta)
    res[dim2, dim2] = np.cos(theta)
    return res


def getMatrix(baseMatrix, d_thetas, stop_dim=10):
    """Right-multiply ``baseMatrix`` by a plane rotation for every ordered axis pair.

    Equivalent to ``baseMatrix @ R(0, 1) @ R(0, 2) @ ... @ R(ndim-1, ndim-2)``.
    Here ``R(i, j)`` rotates the (i, j) plane by ``d_thetas[i, j]``, and the
    pairs (i, j) with ``i != j`` are visited in row-major order. ``R(i, j)``
    has ``cos`` on ``[i, i]`` and ``[j, j]``, ``sin`` on ``[i, j]`` and
    ``-sin`` on ``[j, i]``.

    Right-multiplying by such a rotation only changes columns ``i`` and ``j``.
    They are therefore updated directly in O(ndim) work, instead of building a
    dense rotation matrix and doing an O(ndim**3) product for every pair.

    Args:
        baseMatrix: array with ``ndim`` rows/entries along its first axis;
            it is not modified.
        d_thetas: ``(ndim, ndim)`` array of angles; the diagonal is ignored.
        stop_dim: unused; kept so existing callers keep working.

    Returns:
        A new array holding the rotated matrix.
    """
    ndim = baseMatrix.shape[0]
    # Copy (never mutate the caller's array) and promote to at least float64,
    # matching the dtype the dense ``.dot`` with a float64 rotation produced.
    matrix = np.array(baseMatrix, dtype=np.result_type(baseMatrix, np.float64))
    for i in range(ndim):
        for j in range(ndim):
            if i == j:
                continue
            cos_t = np.cos(d_thetas[i, j])
            sin_t = np.sin(d_thetas[i, j])
            old_i = matrix[..., i].copy()  # needed after column i is overwritten
            old_j = matrix[..., j]
            # (M @ R)[:, i] = cos*M[:, i] - sin*M[:, j]
            # (M @ R)[:, j] = sin*M[:, i] + cos*M[:, j]
            matrix[..., i] = cos_t * old_i - sin_t * old_j
            matrix[..., j] = sin_t * old_i + cos_t * old_j
    return matrix


class GrandTour():
    """Grand-tour style rotation of an ``ndim``-dimensional space.

    At time ``t`` each coordinate plane (i, j) is rotated by
    ``lambdas[i, j] * stepsize * t``. :meth:`proj` projects data onto the
    first two axes of the resulting rotated basis.

    Args:
        ndim: dimensionality of the data to be projected.
        stepsize: factor converting ``lambdas`` into an angle per unit ``t``.
        like: optional existing ``GrandTour`` to derive state from.
            If ``ndim >= like.ndim``, its ``lambdas`` fill the top-left corner
            of fresh random values in ``[0, 2*pi)``, and its ``matrix`` is
            embedded in an identity. Otherwise both are truncated to
            ``ndim x ndim``.
    """

    def __init__(self, ndim=10, stepsize=0.01, like=None):
        self.ndim = ndim
        self.stepsize = stepsize
        if like is not None:
            if ndim >= like.ndim:
                self.lambdas = np.random.rand(ndim,ndim) * 2*np.pi
                self.lambdas[:like.ndim, :like.ndim] = like.lambdas
                self.matrix = embed(like.matrix, np.eye(ndim))
            else:
                # Slices are views, so these share memory with ``like``'s arrays.
                self.lambdas = like.lambdas[:ndim, :ndim]
                self.matrix = like.matrix[:ndim, :ndim]
        else:
            # NOTE(review): speeds here are uniform in [0, 1), unlike the
            # [0, 2*pi) used when growing from ``like`` — confirm intended.
            self.lambdas = np.random.rand(ndim,ndim)
            self.matrix = np.eye(ndim)

        # Not read by any method of this class.
        self.thetas = np.zeros_like(self.lambdas)

    def proj(self, data, t=1):
        """Project ``data`` onto the first two axes of the basis rotated to time ``t``.

        The rotation is rebuilt from the identity on every call, so it does
        not accumulate across calls. The result is stored in ``self.matrix``.

        Args:
            data: array whose last axis has length ``self.ndim``.
            t: time; plane (i, j) is rotated by ``lambdas[i, j] * stepsize * t``.

        Returns:
            ``data @ self.matrix[:, :2]``, the 2-D projection of ``data``.
        """
        thetas = self.lambdas * self.stepsize * t
        # Use this tour's own dimensionality rather than a module-level
        # ``ndim``, so the instance does not depend on an unrelated global.
        self.matrix = getMatrix(np.eye(self.ndim), thetas, stop_dim=2)
        data_t = data.dot(self.matrix[:, :2])
        return data_t
    


class Animation():
    """Blitted matplotlib scatter animation over a precomputed sequence of frames.

    Args:
        data_seq: sequence of frames. Frame ``i`` is passed to
            ``PathCollection.set_offsets`` (presumably an ``(npoint, 2)``
            array — confirm against callers).
        figsize: figure size passed to ``plt.figure``.
        xlim, ylim: axis limits. Each defaults to the symmetric range
            ``[-1.05 * vmax, 1.05 * vmax]``, where ``vmax`` is the largest
            absolute coordinate over all frames.
    """

    # Immutable default (was a list literal, shared across calls).
    def __init__(self, data_seq,
                 figsize=(5, 5),
                 xlim=None, ylim=None):
        self.data_seq = data_seq
        self.fig = plt.figure(figsize=figsize)
        self.ax = self.fig.add_subplot(111)
        self.ax.axis('square')

        self.vmax = np.max(np.abs(data_seq))
        if xlim is None:
            xlim = [-self.vmax*1.05, self.vmax*1.05]
        if ylim is None:
            ylim = [-self.vmax*1.05, self.vmax*1.05]

        self.ax.set_xlim(xlim)
        self.ax.set_ylim(ylim)
        # Empty scatter whose offsets are replaced on every frame.
        self.scatter = self.ax.scatter([], [])

    def anim_init(self):
        """Return the ``init_func`` for ``FuncAnimation``.

        The returned closure draws the first (background) frame: 10 points at
        the origin with the 'Set1' colormap. It returns the artists to blit.
        """
        def init():
            self.scatter.set_offsets(np.zeros([10,2]))
    #       self.scatter.set_array(labels)
            self.scatter.set_cmap('Set1')
            return (self.scatter,)
        return init

    def anim_draw(self):
        """Return the per-frame callback: frame ``i`` shows ``data_seq[i]``."""
        def draw(i):
            xy = self.data_seq[i]
            self.scatter.set_offsets(xy)
            return (self.scatter,)
        return draw

    def set_color(self, c):
        """Set the colour(s) of all scatter points."""
        self.scatter.set_color(c)

    def set_size(self, s):
        """Set the marker size(s) of the scatter points."""
        self.scatter.set_sizes(s)

    def make(self, duration=2000):
        """Build the ``FuncAnimation`` over all frames.

        Args:
            duration: total playback time in milliseconds; it is spread evenly
                over the frames (``interval = duration / len(data_seq)``).

        Returns:
            The ``matplotlib.animation.FuncAnimation`` object.
        """
        anim = animation.FuncAnimation(self.fig, 
                                       self.anim_draw(), 
                                       init_func=self.anim_init(),
                                       frames=len(self.data_seq), 
                                       interval=duration / len(self.data_seq), 
                                       blit=True)
        return anim

In [None]:
# Precompute the grand-tour projection of the first `npoint` samples for
# every animation frame.
npoint = 2000
# One row of 160 features per sample.
data = digitCapActivation2[:npoint].reshape([-1,160])
# Scale each feature column to unit L2 norm over these samples
# (an all-zero column would produce NaN).
data = data / np.linalg.norm(data, axis=0, keepdims=True)
# data = data[:,16*1:16*3]
# Per-sample RGBA colours from the tab10 palette (assumes integer labels).
c = plt.cm.tab10(targets2)

animations = []
gt = None
data_seq = []
ndim = data.shape[1]  # number of features being toured
gt = GrandTour(ndim, stepsize=0.003)

# gt
nframe = 1500
# Frame i is the projection at time t=i; each call rebuilds the full rotation.
for i in tqdm(range(nframe)):
    points = gt.proj(data, t=i)
    data_seq.append(points)
    



In [None]:
# Raise matplotlib's size cap for embedded HTML animations
# (rcParam animation.embed_limit, documented in MB).
plt.rc('animation', embed_limit=100)
anim = Animation(data_seq)
anim.set_color(c[:npoint])
anim.set_size([2 for _ in range(npoint)])
# Total duration nframe/60 seconds (in ms), i.e. about 1000/60 ms per frame.
a = anim.make(nframe/60 * 1000)
HTML(a.to_jshtml())  

## Convert the reconstruction network to a Keras model, then export it to TensorFlow.js

In [None]:
# drange.reshape(160,2).tolist() 
# dmeanstd.reshape(160,2).tolist() 
# Write `dmeanstd` as a JavaScript source file declaring
# `let dmeanstd = [[..., ...], ...]` with 160 rows of 2 values
# (presumably per-feature (mean, std) pairs — confirm).
import json
with open('dmeanstd.js' ,'w') as f:
    f.write('let dmeanstd = \n')
    json.dump(dmeanstd.reshape(160,2).tolist(), f, indent=2)

In [None]:
# Rebuild the PyTorch reconstruction decoder as a Keras Sequential model
# (160 -> 512 -> 1024 -> 784), copy the trained weights into it, save it as
# HDF5 and export it to TensorFlow.js.
import tensorflow.keras as keras
# NOTE(review): the next two `from keras ...` lines import the top-level
# `keras` package, not the `tensorflow.keras` module bound to `keras` above.
# Only the commented-out code below uses them.
from keras.layers import Lambda
from keras import backend as K
import tensorflowjs as tfjs

# drange = np.fromfile('drange.bin').reshape([10,16,2])
# input1 = keras.layers.Input(shape=(160,))
# dmin = K.constant(drange[:,:,0].reshape(-1), name='dmin_const')
# dmax = K.constant(drange[:,:,1].reshape(-1), name='dmax_const')
# y0 = Lambda(lambda input1: input1 *(dmax-dmin) + dmin, 
#             name='unstandardize')(input1)
# undunstandardize = keras.models.Model(inputs=input1, outputs=y0 )

model_keras = keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(160,)),
    keras.layers.Dense(1024, activation='relu'),
    keras.layers.Dense(784, activation='sigmoid'),
])



# Shape check. It is not the cell's last expression, so its value is not displayed.
[w.shape for w in model_keras.get_weights()]

model.load_state_dict(torch.load('./capture/MnistNet.model.199'))
# Indices 0, 2, 4 are presumably the three nn.Linear layers of
# `model.reconstruct.layers` (activations in between) — confirm.
weights = [[model.reconstruct.layers[i].weight,model.reconstruct.layers[i].bias]  for i in [0,2,4]]
# Flatten to [W0, b0, W1, b1, W2, b2], the (kernel, bias) per-layer order set_weights expects.
weights = sum(weights, [])
# nn.Linear stores weights as (out, in) while a Dense kernel is (in, out);
# transposing a 1-D bias is a no-op.
weights = [w.data.cpu().numpy().T for w in weights]

model_keras.set_weights(weights)
model_keras.save('model_keras.h5')
tfjs.converters.save_keras_model(model_keras, 'reconstruct_tfjs/')

In [None]:
# Bundle the exported TensorFlow.js model directory into one archive.
!zip -r reconstruct.zip reconstruct_tfjs/

In [None]:
# Disabled: command-line alternative to `tfjs.converters.save_keras_model`.
# Every line must be commented out: a trailing `\` inside a comment does not
# continue the line, so uncommented argument lines would be parsed as
# Python and raise a SyntaxError.
# !tensorflowjs_converter --input_format keras \
#                        model_keras.h5 \
#                        model_tfjs/

In [None]:
# model_torch = model.reconstruct

# dummy_input = torch.randn(1, 160)
# torch.onnx.export(model_torch, dummy_input, 'MnistNet.onnx')

# model_onnx = onnx.load('MnistNet.onnx')
# model_tf = prepare(model_onnx)
# model_tf.export_graph('MnistNet.pb')

In [None]:
# dir(model_tf)
# model_tf.tensor_dict
# dir(model_tf.graph)

# input1 = model_tf.graph.get_operation_by_name('input.1').outputs[0]
# output = model_tf.graph.get_operation_by_name('Sigmoid')
# model_tf.graph.get_operations()
# model_tf.graph.get_name_scope()
# # help(tf.random_normal)

# model_tf.graph.is_feedable('input.1')

In [None]:
# import tensorflow as tf


# input1 = model_tf.graph.get_operation_by_name('input.1').outputs[0]
# output = model_tf.graph.get_operation_by_name('Sigmoid').outputs[0]

# with model_tf.graph.as_default(), tf.Session() as sess:
#     out = sess.run('Sigmoid:0', feed_dict={'input.1:0': np.random.randn(1,160)})
#     print(out.shape)

In [None]:
# !tensorflowjs_converter \
#     --input_format=tf_saved_model \
#     --output_format=tfjs_graph_model \
#     MnistNet.pb \
#     tfjsModel/

# !tensorflowjs_converter \
#     --input_format=tf_saved_model \
#     --output_format=tfjs_graph_model\
#     checkpoints/saved_model.pb \
#     ./web_model

In [None]:
# Show originals and reconstructions for the first ncol*nrow//2 (= 50)
# images of one training batch. Even grid rows hold originals; the odd row
# directly below holds each reconstruction.
nrow, ncol = 10,10

gs = gridspec.GridSpec(nrow, ncol)
gs.update(left=0.05, right=0.95, wspace=0.01, hspace=0.0)
fig = plt.figure(figsize=[ncol, nrow])
for imgs, targets in train_loader:
    imgs = imgs.cuda()
    targets = targets.cuda()
    # NOTE: runs without torch.no_grad(), so autograd state is built; outputs
    # are detached only for plotting.
    v = model(imgs)
    # Reconstruction conditioned on the true targets, viewed as 1x28x28 images.
    reconstructed = model.reconstruct(v, targets).view(-1,1,28,28)
    
    for i, (img0,img1) in enumerate(zip(imgs, reconstructed)):
        img0 = img0[0]
        img1 = img1[0]
        plt.subplot(gs[(i//ncol)*2, i%ncol])
        plt.imshow(img0.detach().cpu().numpy(), vmin=0, vmax=1)
        plt.axis('off')
        plt.subplot(gs[(i//ncol)*2+1, i%ncol])
        plt.imshow(img1.detach().cpu().numpy(), vmin=0, vmax=1)
        plt.axis('off')
        # Stop once the grid is full (two rows per group of ncol images).
        if i >= ncol*nrow//2-1:
            break
    # Only the first batch is used.
    break

## Accuracy history over saved checkpoints

In [None]:
# Evaluate every saved checkpoint, in natural filename order, on the test set.
# NOTE: this overwrites `model`'s weights; it ends holding the last checkpoint.
accuracyHistory = []
model.eval()
for fn in tqdm(natsorted(glob('./capture/MnistNet.model.*'))):
    model.load_state_dict(torch.load(fn))
#     print(fn)
    # `test` is defined elsewhere; its third return value is used as accuracy.
    _,_,acc = test(model, test_loader)
    accuracyHistory.append(acc)
   

In [None]:
# Accuracy per checkpoint (x = checkpoint index in sorted order).
plt.plot(accuracyHistory, '.-')
plt.savefig('accuracy.svg')


In [None]:
# Heat map of the first 100 rows of `model.cap2d.c`
# (presumably the routing coefficients — confirm).
# NOTE: this rebinds `c`, which other cells use for per-point colours.
c = model.cap2d.c.cpu().numpy()

# plt.figure(figsize=[c.shape[1],c.shape[0]])
plt.imshow(c[:100,:])
# for i,row in enumerate(c):
#     for j, entry in enumerate(row):
#         plt.text(j,i,
#                  '{:.1g}'.format(entry),
#                  ha='center', va='center',
#                  color=[1,1,1])
plt.colorbar()
plt.show()

In [None]:
# Take the first test batch and compute the length (L2 norm over the last
# axis) of every output capsule vector.
for imgs, targets in test_loader:
    imgs = imgs.cuda()
    
    break
v = model(imgs)
vNorm = v.norm(dim=-1)

# NumPy copies for plotting. NOTE: this rebinds `imgs1` to a NumPy array.
imgs1 = imgs.detach().cpu().numpy()
targets1 = targets.detach().cpu().numpy()
v1 = v.detach().cpu().numpy()
vNorm1 = vNorm.detach().cpu().numpy()
# plt.imshow(vNorm1)
# plt.colorbar()

# For the first 10 images: the image next to a stem plot of its capsule lengths.
for i in range(10):
    plt.figure(figsize=[6,2])
    plt.subplot(121)
    plt.imshow(imgs1[i,0])
    plt.subplot(122)
    plt.stem(vNorm1[i,:])
    plt.xticks(range(10))
    plt.show()

In [None]:
# Length of capsule 1 vs capsule 7 for each test image, coloured by label.
plt.scatter(vNorm1[:,1], vNorm1[:,7], c=targets1, cmap='tab10')
plt.colorbar()

## Train a network to fit a Klein bottle

In [None]:
##model def:
class KleinNet(nn.Module):
    """Capsule network mapping 2 scalar inputs to 4 outputs.

    The layers are two stacked ``Capsule`` layers (2 input capsules of dim 1
    -> 4 capsules of dim 20 -> 4 capsules of dim 20), followed by a two-layer
    MLP (80 -> 20 -> 4). ``self.act`` (Tanhshrink) is the only activation
    ``forward`` applies. ``squash``, ``relu``, ``tanh`` and ``celu`` are
    constructed but not used by ``forward``.
    """

    def __init__(self):
        super().__init__()
        # Keep this creation order: submodules are registered, and any random
        # initial weights drawn, in assignment order.
        self.cap1 = Capsule(2,1, 4,20)
        self.cap2 = Capsule(4,20, 4,20)
        self.squash = Squash()
        flat_features = self.cap2.nCapOut * self.cap2.dimCapOut
        self.fc1 = nn.Linear(flat_features, 20)
        self.fc2 = nn.Linear(20, 4)

        # Candidate activations; ``act`` selects the one used in ``forward``.
        self.relu = nn.ReLU()
        self.tanhshrink = nn.Tanhshrink()
        self.tanh = nn.Tanh()
        self.celu = nn.CELU()
        self.act = self.tanhshrink

    def forward(self, x):
        """Run the network on ``x`` and cache the raw capsule outputs.

        In this notebook ``x`` has shape ``(batch, 2, 1)``. The pre-activation
        outputs of both capsule layers are kept on ``self.v1`` / ``self.v2``
        for :meth:`dynamicRouting`.
        """
        act = self.act
        caps1_out = self.cap1(x)
        caps2_out = self.cap2(act(caps1_out))
        flat = act(caps2_out).view(x.shape[0], -1)
        out = self.fc2(act(self.fc1(flat)))

        self.v1 = caps1_out
        self.v2 = caps2_out
        return out

    def dynamicRouting(self):
        """Update each capsule layer's routing from the outputs of the last ``forward``."""
        self.cap1.dynamicRouting(self.v1)
        self.cap2.dynamicRouting(self.v2)

In [None]:
# Smoke test: 10 rounds of forward -> squash -> routing update on a small
# Capsule (3 input capsules of dim 2 -> 4 output capsules of dim 2) with
# random input of shape (10, 3, 2).
cap = Capsule(3,2,4,2)
squash = Squash()
for i in range(10):
    s = cap(torch.rand(10,3,2))
    v = squash(s)
    cap.dynamicRouting(v)
# NOTE(review): nn.Module.register_buffer raises if `b` already exists as a
# plain (non-buffer) attribute — confirm how Capsule stores `b`.
cap.register_buffer('b', cap.b)

In [None]:
model = KleinNet()

# 10001 random parameter pairs, shape (10001, 2, 1), uniform in
# [-1.5*pi, 1.5*pi). The tensor is created on the default (CUDA) tensor type
# set at the top of the notebook.
x = (torch.rand(10001,2,1)-0.5) * 3.0 * np.pi
# `klein` is defined elsewhere; presumably it returns (10001, 4) targets to
# match KleinNet's 4 outputs — confirm.
y = klein(x[:,0,0], x[:,1,0])

In [None]:
##loss
# Training objective: MSE + L1 between prediction and target.
mse = nn.MSELoss()
l1 = nn.L1Loss()


##optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

lossHistory = []
nepoch = 1000
# Full-batch training: every step uses all of `x`.
for i in range(nepoch):
    yPred = model(x)
    loss = mse(yPred, y) + l1(yPred, y)
    model.zero_grad()
    loss.backward()
    optimizer.step()
    # Routing update from the capsule outputs cached by the forward pass above.
    model.dynamicRouting()
    # Natural log of the loss, for a log-scale learning curve.
    lossHistory.append(np.log(loss.item()))
    # Plot every nepoch//10 steps. NOTE: nepoch < 10 gives nepoch//10 == 0,
    # so the modulo would raise ZeroDivisionError.
    if i%(nepoch//10)==nepoch//10-1:
        plt.figure(figsize=[10,5])
        plt.subplot(121)
        plt.plot(lossHistory)
        plt.title(loss.item())
        
        # Targets (grey) vs the predictions that produced this step's loss,
        # in output dims 0 and 1.
        dim1,dim2 = 0,1
        yPred1 = yPred.cpu().detach().numpy()
        y1 = y.cpu().detach().numpy()
        plt.subplot(122)
        plt.plot(y1[:,dim1], y1[:,dim2], '.', color='grey', markersize=1, alpha=0.1)
        plt.scatter(yPred1[:,dim1], yPred1[:,dim2], s=2)
        
        plt.show()

In [None]:

# Compare targets (C1) with the last predictions in every 2-D projection of
# the 4 output dimensions (6 pairs).
yPred1 = yPred.cpu().detach().numpy()
y1 = y.cpu().detach().numpy()

for dim1 in range(4):
    for dim2 in range(dim1+1, 4):
        plt.figure(figsize=[5,5])
        plt.plot(y1[:,dim1], y1[:,dim2], '.', color='C1', markersize=1, alpha=1)
        plt.scatter(yPred1[:,dim1], yPred1[:,dim2], s=2)
        plt.show()