# Intro to CapsNet

My implementation of CapsNet from the original paper, https://arxiv.org/pdf/1710.09829.pdf, by Sara Sabour, Nick Frosst and Geoff Hinton.

My goal is to understand the dynamic routing algorithm and implement it for the simplest CapsNet. I want an implementation that is understandable to people at an intermediate skill level. I think we can learn a lot by visualizing and examining what CapsNet does at each layer, so that's what I'm going to do in this notebook. In the first section values will be hardcoded. In the following section we'll productionize our CapsNet architecture a bit to make things more flexible.

Architecture:

Input image -> ReLU Conv -> Primary "Caps" -> Digit Caps -> vector norm

The new things here are the primary caps, which are just conv layers whose output is passed through the squash function (instead of ReLU) to prepare it for the actual caps, and the digit caps, which are a layer of actual capsules.
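
To make the squash function a bit more concrete, here is a minimal NumPy sketch of it. The name `squash_np`, the `axis` argument and the small `eps` term are assumptions made for this example (they are not in the paper or in the notebook code below); the formula itself follows Eq. 1 of the paper.

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-8):
    """Shrink vectors along `axis` so their length lands in [0, 1)."""
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    # eps guards against dividing by zero for an all-zero vector (assumption)
    norm = np.sqrt(squared_norm + eps)
    return (squared_norm / (1.0 + squared_norm)) * (s / norm)

short = squash_np(np.array([0.1, 0.0]))    # a short input vector
long_ = squash_np(np.array([30.0, 40.0]))  # a long input vector

print(np.linalg.norm(short))  # close to 0
print(np.linalg.norm(long_))  # close to 1, direction is unchanged
```

Short vectors get squeezed towards zero length and long vectors towards length 1, which is what lets the length of a capsule's output act like a probability.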

There are 3 dense layers at the end: two fully connected ReLU layers followed by a fully connected sigmoid layer. These make up the decoder that reconstructs the input image.

A Quick note:

CapsNet is brand new (1 month old). If you understand why something works, or you discover the reason it works, let people know. Write a blog post, or reach out to me and we can write one together. Everyone is still trying to understand why things work in neural nets, and CapsNet is no exception. Don't hoard your discoveries; they could improve the world. The barrier to discovery is very low; you could make the next great finding.

In [2]:
import os
import numpy as np
import torch
import torch.nn.functional as F

from torch import nn
from torch.utils import data
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms, datasets

In [3]:
torch.cuda.is_available(), torch.__version__

(True, '0.2.0_4')

In [4]:
def to_variable(x):
    if torch.cuda.is_available():
        x = x.cuda()
    return Variable(x)

Load data.

In [6]:
mnist_img = os.path.join('/scratch', 'yns207', 'imgs', 'mnist')
mnist_transform = transforms.Compose([transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))
                                     ])
mnist_dataset = datasets.MNIST(mnist_img, 
                               download=True, 
                               transform=mnist_transform)
mnist_loader = data.DataLoader(dataset=mnist_dataset,
                               batch_size=64,
                               shuffle=True,
                               num_workers=2)

Let's first try to understand what goes into the capsules. Let's make a network up to the primary caps and see what the output of the previous layer is supposed to be.

In [7]:
class PrimaryCapsIn(nn.Module):
    def __init__(self):
        super(PrimaryCapsIn, self).__init__()
        
        self.conv = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, inputs):
        return self.relu(self.conv(inputs))

As per the paper, this is a convolutional layer with kernel size 9 and stride 1, with 1 input channel for black-and-white MNIST and 256 output channels (or feature maps). Let's now load the data and observe the shape at the other end of this layer. This should be a gentle intro to CapsNet, as most people interested in this will already be familiar with convolutional nets.
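
Before running anything we can work out the shape to expect by hand. This is a small sketch using the standard output-size formula for a convolution without dilation; the helper name `conv_output_size` is made up for this example, and the kernel size of 9 and stride of 1 come from the `nn.Conv2d` call above.

```python
def conv_output_size(input_size, kernel_size, stride=1, padding=0):
    """Spatial size of a conv layer's output along one dimension."""
    return (input_size + 2 * padding - kernel_size) // stride + 1

# MNIST images are 28x28; the first layer uses kernel 9, stride 1, no padding
conv1_size = conv_output_size(28, kernel_size=9, stride=1)
print(conv1_size)  # 20
```

So each of the 256 feature maps should come out as 20x20.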

In [8]:
primary_caps_in = PrimaryCapsIn()

if torch.cuda.is_available():
    primary_caps_in.cuda()

In [9]:
n_epochs = 1
for epoch in range(n_epochs):
    for image_batch in mnist_loader:
        # get the actual batch, toss metadata
        image_batch = image_batch[0]
        # get the actual size of this batch
        # for uneven datasets it may not be 64
        # on the last batch
        batch_size = image_batch.shape[0]
        # convert to pytorch variable
        image_batch = to_variable(image_batch)
        # plug through our network
        outputs = primary_caps_in(image_batch)
        print(outputs.data.shape)
        break
    break

torch.Size([64, 256, 20, 20])


So the output of our network up to this point is a batch of 64 tensors (one per image), where each tensor is 256x20x20. This matches what the paper tells us to expect on pages 3-4. Let's create our primary caps now. These are conv layers, BUT afterwards their outputs are turned into vectors that can be fed into the actual capsules, the digit capsules.

In [84]:
class PrimaryCaps(nn.Module):
    def __init__(self):
        super(PrimaryCaps, self).__init__()
        # this creates a layer with 8 conv2d 
        caps = [nn.Conv2d(in_channels=256, out_channels=32, kernel_size=9, stride=2, padding=0) for _ in range(8)]
        self.caps = nn.ModuleList(caps)
        
    def squash(self, s):
        # s has shape batch_size x 8 x 1152; each capsule output is the 8D vector
        # along dim 1, so sum the squares over dim 1 and keep the shape
        squared_norm = torch.sum(torch.pow(s, 2), dim=1, keepdim=True)
        norm = torch.sqrt(squared_norm)
        return (squared_norm / (1 + squared_norm)) * (s / norm)
        
    def forward(self, inputs):
        # for every capsule in the layer
        # plug in the inputs
        u = [cap(inputs) for cap in self.caps]
        # list of 8 64x32x6x6 -> tensor of 64x8x32x6x6
        u = torch.stack(u, dim=1)
        #flatten the outputs
        u = u.view(u.size(0), 8, -1)
        # apply non linearity
        u = self.squash(u)
        return u

In [85]:
primary_caps_in = PrimaryCapsIn()
primary_caps = PrimaryCaps()
if torch.cuda.is_available():
    primary_caps.cuda()
    primary_caps_in.cuda()

In [86]:
for epoch in range(n_epochs):
    for image_batch in mnist_loader:
        image_batch = image_batch[0]
        batch_size = image_batch.shape[0]
        image_batch = to_variable(image_batch)
        outputs = primary_caps_in(image_batch)
        outputs = primary_caps(outputs)
        print(outputs.size())
        break
    break

torch.Size([64, 8, 1152])


So we have 8 'primary' caps, and for each image in the batch each one outputs what is essentially a flattened vector of all its convolutional outputs (just read the tensor shape backwards, as is the convention in PyTorch). Let's look at the actual data in this tensor...

OK, so for one image we get the output of all 8 caps, and each of those outputs has 32 x 6 x 6 = 1152 values, which is what the paper tells us to expect at this point on page 4.
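
To see where the 1152 comes from and how to read the flattened tensor, here is a small NumPy sketch. The array names are made up for this example and the batch size of 2 is arbitrary; the kernel size 9 and stride 2 are taken from the `PrimaryCaps` conv layers above.

```python
import numpy as np

# A 20x20 feature map through a kernel-9, stride-2 conv gives 6x6
height = width = (20 - 9) // 2 + 1
batch, n_units, channels = 2, 8, 32

n_values = batch * n_units * channels * height * width
stacked = np.arange(n_values, dtype=np.float32)
stacked = stacked.reshape(batch, n_units, channels, height, width)

# Same flattening as u.view(u.size(0), 8, -1) in PrimaryCaps.forward
flat = stacked.reshape(batch, n_units, -1)
print(flat.shape)  # (2, 8, 1152)

# The 8 values at one flattened position form one 8D capsule vector
capsule_vec = flat[0, :, 5]
print(capsule_vec.shape)  # (8,)
```

Read this way, there are 1152 primary capsules per image, each with an 8-dimensional output.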

Now let's implement the digit caps.

In [299]:
class DigitCaps(nn.Module):
    def __init__(self):
        super(DigitCaps, self).__init__()
        self.in_units = 8 # number of primary caps
        self.in_channels = 32 * 6 * 6 # number of outputs per primary cap
        self.num_units = 10 # 10 digit caps
        self.unit_size = 16 # number of channels/digit cap
        self.W = nn.Parameter(torch.randn(1, 
                                                self.in_channels, 
                                                self.num_units, 
                                                self.unit_size,
                                                self.in_units))
        
    def squash(self, s):
        # s has shape batch_size x 1 x num_units x unit_size x 1; each digit cap
        # output is the 16D vector along dim 3, so sum the squares over dim 3
        squared_norm = torch.sum(torch.pow(s, 2), dim=3, keepdim=True)
        norm = torch.sqrt(squared_norm)
        return (squared_norm / (1 + squared_norm)) * (s / norm)
        
    def forward(self, x):
        ## --- multiply vectors by weights
        batch_size = x.size(0)
        x = x.transpose(1, 2)
        # batch, features, in_units
        
        x = torch.stack([x] * self.num_units, dim=2).unsqueeze(4)
        # batch, features, num_units, in_units, 1    
        
        W = torch.cat([self.W] * batch_size, dim=0)
        # batch, features, num_units, unit_size, in_units

        u_hat = torch.matmul(W, x)
        # batch, features, num_units, unit_size, 1

        ## --- route the vectors
        b = to_variable(torch.zeros(1, self.in_channels, self.num_units, 1))
        
        for iteration in range(3):
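            # softmax over the digit caps (dim 2) so that each input's coupling
            # coefficients sum to 1; without dim, softmax runs over dim 1 here.
            # Passing dim needs a PyTorch version whose F.softmax accepts it.
            c = F.softmax(b, dim=2)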
            c = torch.cat([c]* batch_size, dim=0).unsqueeze(4)
                        
            # apply routing to weighted inputs u_hat
            # then sum it together.
            s = (c * u_hat).sum(dim=1, keepdim=True)
                        
            # squash the output.
            # batch_size, 1, n_caps, unit_size, 1
            v = self.squash(s)
                    
            # batch_size, feat, n_caps, unit_size, 1
            v_j = torch.cat([v] * self.in_channels, dim=1)
                        
            # 1, feat, n_caps, 1
            u_v = torch.matmul(u_hat.transpose(3,4), v_j).squeeze(4).mean(dim=0, keepdim=True)
                        
            # update route vectors
            b = b + u_v
        return v.squeeze(1)
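
To see the routing loop without the batch and broadcasting bookkeeping, here is a minimal NumPy sketch of routing by agreement for a single example. It is a simplification under stated assumptions: the names `route`, `softmax_np` and `squash_np`, the `eps` term and the tiny shapes are all made up for this example, and it routes one example at a time instead of averaging the agreement over the batch as the `DigitCaps` code above does.

```python
import numpy as np

def squash_np(s, axis=-1, eps=1e-8):
    squared_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    norm = np.sqrt(squared_norm + eps)  # eps avoids division by zero (assumption)
    return (squared_norm / (1.0 + squared_norm)) * (s / norm)

def softmax_np(x, axis):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def route(u_hat, n_iterations=3):
    """u_hat: (n_in, n_out, dim) prediction vectors for one example."""
    n_in, n_out, _ = u_hat.shape
    b = np.zeros((n_in, n_out))                  # routing logits start at zero
    for _ in range(n_iterations):
        c = softmax_np(b, axis=1)                # each input's couplings sum to 1
        s = (c[:, :, None] * u_hat).sum(axis=0)  # weighted sum over inputs: (n_out, dim)
        v = squash_np(s, axis=-1)                # output capsule vectors: (n_out, dim)
        b = b + (u_hat * v[None, :, :]).sum(axis=-1)  # agreement as a dot product
    return v, c

u_hat = np.random.RandomState(0).randn(6, 3, 4)  # 6 inputs, 3 outputs, 4D vectors
v, c = route(u_hat)
print(v.shape, c.shape)  # (3, 4) (6, 3)
```

Inputs whose predictions agree with an output capsule get a larger logit for it, so more of their vote is routed there on the next iteration.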

In [300]:
primary_caps_in = PrimaryCapsIn()
primary_caps = PrimaryCaps()
digit_caps = DigitCaps()
if torch.cuda.is_available():
    primary_caps_in.cuda()
    primary_caps.cuda()
    digit_caps.cuda()

In [301]:
for epoch in range(n_epochs):
    for image_batch in mnist_loader:
        image_batch = image_batch[0]
        batch_size = image_batch.shape[0]
        image_batch = to_variable(image_batch)
        outputs = primary_caps_in(image_batch)
        outputs = primary_caps(outputs)
        outputs = digit_caps(outputs)
        print(outputs.size())
        break
    break

torch.Size([64, 10, 16, 1])


The output makes sense. For every image in the batch it lists the 10 digit caps and the 16 channels in each digit cap.

In [302]:
outputs[0].squeeze().size()

torch.Size([10, 16])

The output above holds, for one image, the 16-dimensional output vector of each of the 10 digit caps. The length (norm) of each vector is what acts as the score for that digit.
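
The last step in the architecture is the vector norm, so here is a small NumPy sketch of turning digit cap outputs into predictions. The function name `predict_digits` and the fake input values are assumptions for illustration only; it expects the trailing singleton dimension to have been squeezed away already.

```python
import numpy as np

def predict_digits(digit_caps_out):
    """digit_caps_out: (batch, 10, 16). Returns capsule lengths and predicted digits."""
    lengths = np.linalg.norm(digit_caps_out, axis=2)  # (batch, 10)
    return lengths, lengths.argmax(axis=1)            # longest capsule wins

fake = np.zeros((2, 10, 16))
fake[0, 3, :] = 0.2  # capsule 3 is the longest for image 0
fake[1, 7, 0] = 0.9  # capsule 7 is the longest for image 1

lengths, preds = predict_digits(fake)
print(preds)  # [3 7]
```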


Next we'll want to create a capsule loss. The paper uses something called a 'margin loss'. See formula 4 on page 3.

In [303]:
def margin_loss(predictions, labels):
    batch_size = predictions.size(0)
    zero = to_variable(torch.zeros(1))
    pred_norm = torch.sqrt((predictions**2).sum( dim=2, keepdim=True))
    max_left = torch.max(zero, 0.9 - pred_norm).view(batch_size, -1)**2
    max_right = torch.max(zero, pred_norm - 0.1).view(batch_size, -1)**2
    loss = labels*max_left + 0.5*(1.0 - labels)*max_right
    margin_loss = loss.sum(dim=1).mean()
    return margin_loss
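
To get a feel for the numbers, here is the same margin loss worked through for a single example in plain Python. The function name `margin_loss_single` and the capsule lengths are made up for this example; the margins 0.9 and 0.1 and the 0.5 down-weighting match the code above.

```python
def margin_loss_single(lengths, target, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss for one example, given capsule lengths and the true class index."""
    total = 0.0
    for k, length in enumerate(lengths):
        if k == target:
            # the correct capsule is penalized for being shorter than m_plus
            total += max(0.0, m_plus - length) ** 2
        else:
            # wrong capsules are penalized for being longer than m_minus
            total += lam * max(0.0, length - m_minus) ** 2
    return total

lengths = [0.05] * 10
lengths[3] = 0.95  # the network is confident the digit is a 3

good = margin_loss_single(lengths, target=3)  # the label agrees
bad = margin_loss_single(lengths, target=4)   # the label disagrees
print(good, bad)
```

When the label agrees with the long capsule the loss is zero; when it disagrees, both the missing capsule and the wrongly confident one contribute.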

Now let's combine everything into one network.

In [309]:
class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        
        self.conv = PrimaryCapsIn()
        self.primary = PrimaryCaps()
        self.digits = DigitCaps()
        
    def forward(self, inputs):
        return self.digits(self.primary(self.conv(inputs)))
    
    def loss(self, inputs, labels):
        # margin_loss is the module-level function defined above, not a method
        return margin_loss(inputs, labels)

In [320]:
n_epochs = 5
lr = 0.01
one_hot_labels = to_variable(torch.eye(10))

capsnet = CapsNet()
optimizer = Adam(capsnet.parameters(), lr=lr)

if torch.cuda.is_available():
    capsnet.cuda()

In [321]:
for epoch in range(n_epochs):
    for images, labels in mnist_loader:
        # labels is already a LongTensor; to_variable only moves it to the GPU if one is available
        mnist_labels = to_variable(labels)
        image_batch = images
        batch_size = images.shape[0]
        image_batch = to_variable(image_batch)
        
        optimizer.zero_grad()
        outputs = capsnet(image_batch)

        batch_labels = one_hot_labels.index_select(dim=0, index=mnist_labels)
        loss = margin_loss(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        
    print('epoch {} loss: {}'.format(epoch, loss.data[0]))

epoch 0 loss: 0.06106878072023392
epoch 1 loss: 0.048292066901922226
epoch 2 loss: 0.02802092768251896
epoch 3 loss: 0.013605736196041107
epoch 4 loss: 0.037176355719566345


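With a train loss between roughly 0.01 and 0.04, it's safe to say the model is fitting the training data! Note that the value printed for each epoch is the loss of the last batch only, not an average over the epoch, which is why it jumps around.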

Regularization is important to avoid overfitting and to help our network learn a more general representation. The paper's authors use a decoder to reconstruct inputs from the capsules and then compare those reconstructed images against the actual images.
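
Here is a rough NumPy sketch of that comparison. It rests on a few assumptions: the function names `mask_by_label` and `reconstruction_loss` are made up, the masking of all but the correct capsule and the 0.0005 scale are taken from my reading of the paper (section 4.1) and not from the code in this notebook, and averaging over the batch is my own choice.

```python
import numpy as np

def mask_by_label(digit_caps_out, labels):
    """Zero every capsule except the true one, then flatten to (batch, 160)."""
    batch, n_caps, dim = digit_caps_out.shape
    mask = np.zeros((batch, n_caps, 1))
    mask[np.arange(batch), labels, 0] = 1.0  # keep only the labelled capsule
    return (digit_caps_out * mask).reshape(batch, n_caps * dim)

def reconstruction_loss(reconstructed, images, scale=0.0005):
    """Scaled sum of squared pixel differences, averaged over the batch."""
    n = len(images)
    diff = reconstructed.reshape(n, -1) - images.reshape(n, -1)
    return scale * (diff ** 2).sum(axis=1).mean()

caps = np.ones((2, 10, 16))
masked = mask_by_label(caps, np.array([3, 7]))
print(masked.shape)        # (2, 160)
print(masked.sum(axis=1))  # 16 ones survive per image

loss = reconstruction_loss(np.zeros((2, 784)), np.ones((2, 1, 28, 28)))
print(loss)
```

The masked 160-value vector is what would go into the first dense layer of the decoder, and the scaled reconstruction loss would be added to the margin loss.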

In [None]:
class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        
        self.fc1 = nn.Linear(16*10, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        
    def forward(self, inputs):
        # flatten batch x 10 x 16 x 1 -> batch x 160 to match fc1's input size
        inputs = inputs.view(inputs.size(0), -1)
        outputs = self.fc1(inputs)
        outputs = self.relu(outputs)
        outputs = self.fc2(outputs)
        outputs = self.relu(outputs)
        outputs = self.fc3(outputs)
        outputs = self.sigmoid(outputs)
        return outputs

In [None]:
primary_caps_in = PrimaryCapsIn()
primary_caps = PrimaryCaps()
digit_caps = DigitCaps()
decoder = Decoder()
if torch.cuda.is_available():
    primary_caps_in.cuda()
    primary_caps.cuda()
    digit_caps.cuda()
    decoder.cuda()

In [None]:
for epoch in range(n_epochs):
    for image_batch in mnist_loader:
        image_batch = image_batch[0]
        batch_size = image_batch.shape[0]
        image_batch = to_variable(image_batch)
        outputs = primary_caps_in(image_batch)
        outputs = primary_caps(outputs)
        outputs = digit_caps(outputs)
        break
    break

Other sources and snippets:

https://github.com/timomernick/pytorch-capsule/blob/master/capsule_network.py

https://github.com/gram-ai/capsule-networks
    
https://github.com/XifengGuo/CapsNet-Keras
    
https://github.com/naturomics/CapsNet-Tensorflow/blob/master/capsNet.py

https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-i-intuition-b4b559d1159b

https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66

https://medium.com/ai³-theory-practice-business/understanding-hintons-capsule-networks-part-iii-dynamic-routing-between-capsules-349f6d30418

Concepts to understand:

Squashing function:

http://mathworld.wolfram.com/VectorNorm.html