In [150]:
!pip install torch



In [151]:
import torch
from torchvision import datasets, transforms 

In [152]:
class Dataset:
  def __init__(self, dataset, batch_size):
    super(Dataset, self).__init__()

    if dataset == 'mnist':
      dataset_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307, ), (0.3081, ))
      ])


      train_dataset = datasets.MNIST('/data/mnist', train = True, download = True,
                                     transform = dataset_transform)
      
      test_dataset = datasets.MNIST('/data/mnist', train = False, download = True,
                                    transform = dataset_transform)
      
      self.load_train = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle = True)
      self.load_test = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle = False)


    elif dataset == 'cifar10':
      dataset_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
      ])


      train_dataset = datasets.CIFAR10('/data/cifar', train = True, download = True,
                                     transform = dataset_transform)
      
      test_dataset = datasets.CIFAR10('/data/cifar', train = False, download = True,
                                    transform = dataset_transform)
      
      self.load_train = torch.utils.data.DataLoader(train_dataset, batch_size, shuffle = True)
      self.load_test = torch.utils.data.DataLoader(test_dataset, batch_size, shuffle = False)


In [153]:
#Capsule Net architecture

import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

USE_CUDA = True if torch.cuda.is_available() else False

In [154]:
class ConvLayer(nn.Module):
  def __init__(self, in_channels = 1, out_channels = 256, kernel_size = 9):
    super(ConvLayer, self).__init__()

    self.conv = nn.Conv2d(in_channels= in_channels,
                          out_channels = out_channels,
                          kernel_size = kernel_size,
                          stride = 1)
    
  def forward(self, x):
    return F.relu(self.conv(x))  

In [155]:
class PrimaryCaps(nn.Module):
  def __init__(self, num_capsules = 8, in_channels = 256, out_channels = 32, kernel_size = 9, num_routes = 32 * 6 * 6):
    super(PrimaryCaps, self).__init__()
    self.num_routes = num_routes
    self.capsules = nn.ModuleList([
                    nn.Conv2d(in_channels=in_channels, out_channels=out_channels, 
                              kernel_size=kernel_size, stride = 2, padding = 0)
                    for _ in range(num_capsules)])
  
  def forward(self, x):
    u = [capsule(x) for capsule in self.capsules]
    u = torch.stack(u, dim = 1)
    u = u.view(x.size(0), self.num_routes, -1)
    u = self.squash(u)
    return u


  def squash(self, in_tensor):
    sq_norm = (in_tensor ** 2).sum(-1, keepdim = True)
    output_tensor = sq_norm * in_tensor / ((1 + sq_norm) * torch.sqrt(sq_norm))
    return output_tensor

In [156]:
class DigitCaps(nn.Module):
  def __init__(self, num_capsules = 10, num_routes = 6 * 6 * 32, in_channels = 8, out_channels = 16):
    super(DigitCaps, self).__init__()

    self.in_channels = in_channels
    self.out_channels = out_channels
    self.num_capsules = num_capsules
    self.num_routes = num_routes

    self.W = nn.Parameter(torch.randn(1, num_routes, num_capsules, out_channels, in_channels))


  def forward(self, x):
    batch_size = x.size(0)
    x = torch.stack([x] * self.num_capsules, dim = 2).unsqueeze(4)
    W = torch.cat([self.W] * batch_size, dim = 0)
    u_hat = torch.matmul(W, x)

    #following the routing algorithm
    b_ij = Variable(torch.zeros(1, self.num_routes, self.num_capsules, 1))
    if USE_CUDA: 
      b_ij = b_ij.cuda()

    n_iterations = 3
    for iteration in range(n_iterations):
      c_ij = F.softmax(b_ij, dim = 1)
      c_ij = torch.cat([c_ij] * batch_size, dim = 0).unsqueeze(4)
      s_j = (c_ij * u_hat).sum(dim = 1, keepdim = True)
      v_j = self.squash(s_j)

      #this is to find similarity of v and u_hat, larger sim gets larger routing coefficient
      if iteration < n_iterations - 1:
        a_ij = torch.matmul(u_hat.transpose(3, 4), torch.cat([v_j] * self.num_routes, dim = 1)) 
        b_ij = b_ij + a_ij.squeeze(4).mean(dim = 0, keepdim = True)

    return v_j    

  def squash(self, in_tensor):
    sq_norm = (in_tensor ** 2).sum(-1, keepdim = True)
    output_tensor = sq_norm * in_tensor / ((1 + sq_norm) * torch.sqrt(sq_norm))
    return output_tensor

In [165]:
class Decoder(nn.Module):
  def __init__(self, w = 28, h = 28, channels = 1):
    super(Decoder, self).__init__()
    self.w = w
    self.h = h
    self.channels = channels

    self.reconstruction_layers = nn.Sequential(
        nn.Linear(16 * 10, 512),
        nn.ReLU(inplace = True),
        nn.Linear(512, 1024),
        nn.ReLU(inplace = True),
        nn.Linear(1024, self.w * self.h * self.channels),
        nn.Sigmoid()
    )

  def forward(self, x, data):
      classes = torch.sqrt((x**2).sum(2)) #sum along the samples for each digit prediction (assume parallel Digit caps layers with >1 data)
      classes = F.softmax(classes, dim = 0)

      _, max_len_idx = classes.max(dim = 1)
      masked = Variable(torch.sparse.torch.eye(10))
      if USE_CUDA : 
        masked = masked.to('cuda:0')

      print(max_len_idx.squeeze(1).data.shape)
      masked = masked.index_select(dim = 0, index = Variable(max_len_idx.squeeze(1).data))
      x_in = (x * masked[:, :, None, None]).view(x.size(0), -1)
      recons = self.reconstruction_layers(x_in)
      recons = recons.view(-1, self.channels, self.w, self.h)
      return recons, masked  

In [166]:
class CapsNet(nn.Module):
    def __init__(self, config=None):
        super(CapsNet, self).__init__()
        if config:
          self.conv_layer = ConvLayer(config.cnn_in_channels, config.cnn_out_channels,
                                      config.cnn_kernel_size)
          
          self.primary_caps = PrimaryCaps(config.pc_num_capsules, config.pc_in_channels, 
                                          config.pc_out_channels, config.pc_kernel_size, config.pc_num_routes)
          
          self.digit_caps = DigitCaps(config.dc_num_capsules, config.dc_num_routes, config.dc_in_channels,
                                      config.dc_out_channels)
          
          self.decoder = Decoder(config.input_width, config.input_height, config.cnn_in_channels)

        else:

          self.conv_layer = ConvLayer()  
          self.primary_caps = PrimaryCaps()  
          self.digit_caps = DigitCaps()  
          self.decoder = Decoder()


        self.mse_loss = nn.MSELoss()



    def forward(self, data):
      out = self.digit_caps(self.primary_caps(self.conv_layer(data)))
      recons, masked = self.decoder(out, data)
      return out, recons, masked

    def loss(self, data, x, target, recons):
      loss_total = self.margin_loss(x, target) + self.recons_loss(data, recons)
      return loss_total

    def margin_loss(self, x, labels, size_avg = True):
      batch_size = x.size(0)
      v_c = torch.sqrt((x ** 2).sum(dim = 2, keepdim = True))

      left_term = F.relu(0.9 - v_c).view(batch_size, -1)  
      right_term = F.relu(v_c - 0.1).view(batch_size, -1)

      loss = labels * left_term + (1 - labels) * right_term * 0.5
      loss = loss.sum(dim = 1, keepdim = True)

      return loss  

    def recons_loss(self, data, recons):
      loss = self.mse_loss(recons.view(recons.size(0), -1), data.view(recons.size(0), -1))  
      lambd = 0.0005
      return lambd * loss


In [167]:
#Test phase

import numpy as np
import torch
from tqdm import tqdm  

In [168]:
USE_CUDA = True if torch.cuda.is_available() else False
batch_size = 100
n_epochs = 25
lr = 0.01
momentum = 0.9

In [169]:
#config class is an easy way to access all the params of the dataset

class Config:

  def __init__(self, dataset = 'mnist'):
        if dataset == 'mnist':
            self.cnn_in_channels = 1
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 6 * 6

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 6 * 6
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 28
            self.input_height = 28

        elif dataset == 'cifar10':
            # CNN (cnn)
            self.cnn_in_channels = 3
            self.cnn_out_channels = 256
            self.cnn_kernel_size = 9

            # Primary Capsule (pc)
            self.pc_num_capsules = 8
            self.pc_in_channels = 256
            self.pc_out_channels = 32
            self.pc_kernel_size = 9
            self.pc_num_routes = 32 * 8 * 8

            # Digit Capsule (dc)
            self.dc_num_capsules = 10
            self.dc_num_routes = 32 * 8 * 8
            self.dc_in_channels = 8
            self.dc_out_channels = 16

            # Decoder
            self.input_width = 32
            self.input_height = 32      

In [170]:
def train(model, optimizer, load_train, epoch):
    capsule_nw = model
    capsule_nw.train()
    num_batch = len(list(enumerate(load_train)))
    total_loss = 0
    for batch_idx, (data, target) in enumerate(tqdm(load_train)):

      target = torch.sparse.torch.eye(10).index_select(dim = 0, index = target)
      data, target = Variable(data), Variable(target)

      if USE_CUDA:
        data, target = data.cuda(), target.cuda()

      optimizer.zero_grad()
      output, recons, masked = capsule_nw(data)
      loss = capsule_nw.loss(data, output, target, recons)  
      loss.backward()
      optimizer.step()
      correct = sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
      train_loss = loss.item()
      total_loss += train_loss

      if batch_idx % 100 == 0:
        tqdm.write("Epoch: [{}/{}], BatchNo: [{}/{}], train_accuracy: {:.6f}, loss: {:.6f}".format(
            epoch,
            n_epochs, 
            batch_idx + 1,
            num_batch,
            correct / float(batch_size),
            train_loss / float(batch_size)
        ))

    tqdm.write('Epoch: [{}/{}], train loss: {:.6f}'.format(epoch, n_epochs,
                                                          total_loss / len(load_train.dataset))) 

In [171]:
def test(model, load_test, epoch):
  capsule_nw = model
  capsule_nw.eval()
  test_loss = 0
  correct = 0

  for batch_idx, (data, target) in enumerate(load_test):

    target = torch.sparse.torch.eye(10).index_select(dim = 0, index = target)
    data, target = Variable(data), Variable(target)

    if USE_CUDA:
      data, target = data.cuda(), target.cuda()

    output, recons, masked = capsule_nw(data)
    loss = capsule_nw.loss(data, output, target, recons)  
    
    test_loss += loss.item()
    correct += sum(np.argmax(masked.data.cpu().numpy(), 1) == np.argmax(target.data.cpu().numpy(), 1))
  
  tqdm.write('Epoch: [{}/{}], test loss: {:.6f}'.format(epoch, n_epochs,
                                                         test_loss / len(load_test))) 

In [172]:
if __name__ == '__main__':
  torch.manual_seed(1)
  dataset = 'mnist'
  config = Config(dataset)
  mnist = Dataset(dataset, batch_size)

  capsule_network = CapsNet(config)
  capsule_network = torch.nn.DataParallel(capsule_network)

  if USE_CUDA:
    capsule_network = capsule_network.cuda()
  capsule_network = capsule_network.module  

  optimizer = torch.optim.Adam(capsule_network.parameters())

  for i in range(n_epochs):
    train(capsule_network, optimizer, mnist.load_train, i)
    test(capsule_network, mnist.load_test, i)


  0%|          | 0/600 [00:00<?, ?it/s]

torch.Size([100, 16, 1])





RuntimeError: ignored