In [1]:
import torch
import torch.nn as nn
import numpy as np
import os

import matplotlib
if os.environ.get('DISPLAY','') == '':
    print('no display found. Using non-interactive Agg backend')
    matplotlib.use('Agg')
import matplotlib.pyplot as plt 
# The squash function specified in Dynamic Routing Between Capsules
# x: input tensor 



  
def squash(x, dim=-1):
  norm_squared = (x ** 2).sum(dim, keepdim=True)
  part1 = norm_squared / (1 +  norm_squared)
  part2 = x / torch.sqrt(norm_squared+ 1e-16)

  output = part1 * part2 
  return output


def quantize(s, word, fraction):
  shift = pow(2, fraction)
  s = s * float(shift)
  s = torch.round(s)
  sat_p = pow(2, word-1)-1
  sat_m = pow(2, word-1)*(-1)
  s = torch.clamp(s, sat_m, sat_p)
  #s = np.int8(s)
  s = s / float(shift)
  s= s.cuda().float()
  return s

def weights_init_xavier(m):
    classname = m.__class__.__name__
    ignore_modules = [
        "SmallNorbConvReconstructionModule",
        "ConvReconstructionModule",
        "ConvLayer"
    ]
    
    if classname.find('Conv') != -1 and classname not in ignore_modules:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('Linear') != -1:
        nn.init.xavier_normal_(m.weight.data, gain=0.02)
    elif classname.find('BatchNorm2d') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0.0)
    elif classname == 'ClassCapsules': 
        nn.init.xavier_normal_(m.W.data, gain=0.002)
        nn.init.xavier_normal_(m.bias.data, gain=0.002)
        
        
def initialize_weights(capsnet):
    capsnet.apply(weights_init_xavier)
    
def denormalize(image):
    image = image - image.min()
    image = image / image.max()
    return image
  
    
def get_path(SAVE_DIR, filename):
    if not os.path.isdir(SAVE_DIR):
        os.makedirs(SAVE_DIR)
    path = os.path.join(SAVE_DIR, filename)
    return path
    
def save_images(SAVE_DIR, filename, images, reconstructions, num_images = 100, imsize=28):
    if len(images) < num_images or len(reconstructions) < num_images:
        print("Not enough images to save.")
        return

    big_image = np.ones((imsize*10, imsize*20+1))
    images = denormalize(images).view(-1, imsize, imsize)
    reconstructions = denormalize(reconstructions).view(-1, imsize, imsize)
    images = images.data.cpu().numpy()
    reconstructions = reconstructions.data.cpu().numpy()
    for i in range(num_images):
        image = images[i]
        rec = reconstructions[i]
        j = i % 10
        i = i // 10
        big_image[i*imsize:(i+1)*imsize, j*imsize:(j+1)*imsize] = image
        j += 10
        big_image[i*imsize:(i+1)*imsize, j*imsize+1:(j+1)*imsize+1] = rec

    path = get_path(SAVE_DIR, filename)
    plt.imsave(path, big_image, cmap="gray")

def save_images_cifar10(SAVE_DIR, filename, images, reconstructions, num_images = 100):
    if len(images) < num_images or len(reconstructions) < num_images:
        print("Not enough images to save.")
        return

    big_image = np.ones((3,32*10, 32*20+1))
    #print('Images : ',big_image.T.shape,',',reconstructions.size())
    images = denormalize(images).view(-1, 3 ,32, 32)
    reconstructions = denormalize(reconstructions).view(-1, 3 ,32, 32)
    images = images.data.cpu().numpy()
    reconstructions = reconstructions.data.cpu().numpy()
    for i in range(num_images):
        image = images[i]
        rec = reconstructions[i]
        j = i % 10
        i = i // 10
        big_image[:,i*32:(i+1)*32, j*32:(j+1)*32] = image
        j += 10
        big_image[:,i*32:(i+1)*32, j*32+1:(j+1)*32+1] = rec

    path = get_path(SAVE_DIR, filename)
    plt.imsave(path, big_image.T)


no display found. Using non-interactive Agg backend


In [2]:
import torch.nn as nn
import torch.nn.functional as functional

import torch
from torch.autograd import Variable
USE_GPU=True

def routing_algorithm(x, weight, bias, routing_iterations):
    """
    x: [batch_size, num_capsules_in, capsule_dim]
    weight: [1,num_capsules_in,num_capsules_out,out_channels,in_channels]
    bias: [1,1, num_capsules_out, out_channels]
    """
    num_capsules_in = x.shape[1]
    num_capsules_out = weight.shape[2]
    batch_size = x.size(0)
    
    x = x.unsqueeze(2).unsqueeze(4)
    #weight = quantize(weight,8,7)
    #bais = quantize(bias,8,7)
    #[batch_size, 32*6*6, 10, 16]
    u_hat = torch.matmul(weight, x).squeeze()
    #u_hat = quantize(u_hat,8,7)
    b_ij = Variable(x.new(batch_size, num_capsules_in, num_capsules_out, 1).zero_())


    for it in range(routing_iterations):
      c_ij = functional.softmax(b_ij, dim=2)
      #c_ij = quantize(c_ij,8,7)
      # [batch_size, 1, num_classes, capsule_size]
      s_j = (c_ij * u_hat).sum(dim=1, keepdim=True) + bias
      s_j = quantize(s_j,8,7)
      # [batch_size, 1, num_capsules, out_channels]
      v_j = squash(s_j, dim=-1)
      if it < routing_iterations - 1: 
        # [batch-size, 32*6*6, 10, 1]
        delta = (u_hat * v_j).sum(dim=-1, keepdim=True)
        b_ij = b_ij + delta
    #v_j = quantize(v_j,8,7)
    return v_j.squeeze()

# First Convolutional Layer
class ConvLayer(nn.Module):
  def __init__(self, 
               in_channels=1, 
               out_channels=256, 
               kernel_size=9,
               batchnorm=False):
    super(ConvLayer, self).__init__()
    
    if batchnorm:
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
    else:
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              kernel_size=kernel_size,
                              stride=1),
            nn.ReLU()
        )
  def forward(self, x):
    #x = quantize(x, 8, 7)
    output = self.conv(x)
    output = quantize(output, 8, 6)
    return output

class PrimaryCapules(nn.Module):
  
  def __init__(self, 
               num_capsules=32, 
               in_channels=256, 
               out_channels=8, 
               kernel_size=9,
               primary_caps_gridsize=6,
               batchnorm=False):

    super(PrimaryCapules, self).__init__()
    self.gridsize = primary_caps_gridsize
    self.num_capsules = num_capsules
    if batchnorm:
        self.capsules = nn.ModuleList([
          nn.Sequential(
          nn.Conv2d(in_channels=in_channels,
                    out_channels=num_capsules,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=0),
          nn.BatchNorm2d(num_capsules)
          )
           for i in range(out_channels)
        ])
    else:
        self.capsules = nn.ModuleList([
          nn.Sequential(
          nn.Conv2d(in_channels=in_channels,
                    out_channels=num_capsules,
                    kernel_size=kernel_size,
                    stride=2,
                    padding=0),

          )
           for i in range(out_channels)
        ])
  
  def forward(self, x):
    output = [caps(x) for caps in self.capsules]
    output = torch.stack(output, dim=1)
    output = quantize(output, 8, 6)
    output = output.view(x.size(0), self.num_capsules*(self.gridsize)*(self.gridsize), -1)
    output = squash(output)
    output = quantize(output, 8, 6)
    return output


class ClassCapsules(nn.Module):
  
  def __init__(self, 
               num_capsules=10,
               num_routes = 32*6*6,
               in_channels=8,
               out_channels=16,
               routing_iterations=3,
               leaky=False):
    super(ClassCapsules, self).__init__()
    

    self.in_channels = in_channels
    self.num_routes = num_routes
    self.num_capsules = num_capsules
    self.routing_iterations = routing_iterations
    
    self.W = nn.Parameter(torch.rand(1,num_routes,num_capsules,out_channels,in_channels))
    self.bias = nn.Parameter(torch.rand(1,1, num_capsules, out_channels))


  # [batch_size, 10, 16, 1]
  def forward(self, x):
    v_j = routing_algorithm(x, self.W, self.bias, self.routing_iterations)
    return v_j.unsqueeze(-1)


class ReconstructionModule(nn.Module):
  def __init__(self, capsule_size=16, num_capsules=10, imsize=28,img_channel=1, batchnorm=False):
    super(ReconstructionModule, self).__init__()
    
    self.num_capsules = num_capsules
    self.capsule_size = capsule_size
    self.imsize = imsize
    self.img_channel = img_channel
    if batchnorm:
        self.decoder = nn.Sequential(
              nn.Linear(capsule_size*num_capsules, 512),
              nn.BatchNorm1d(512),
              nn.ReLU(),
              nn.Linear(512, 1024),        
              nn.BatchNorm1d(1024),
              nn.ReLU(),
              nn.Linear(1024, imsize*imsize*img_channel),
              nn.Sigmoid()
        )
    else:
        self.decoder = nn.Sequential(
              nn.Linear(capsule_size*num_capsules, 512),
              nn.ReLU(),
              nn.Linear(512, 1024),        
              nn.ReLU(),
              nn.Linear(1024, imsize*imsize*img_channel),
              nn.Sigmoid()
        )
        
  def forward(self, x, target=None):
    batch_size = x.size(0)
    if target is None:
      classes = torch.norm(x, dim=2)
      max_length_indices = classes.max(dim=1)[1].squeeze()
    else:
      max_length_indices = target.max(dim=1)[1]
    
    masked = Variable(x.new_tensor(torch.eye(self.num_capsules)))
    masked = masked.cuda()
    masked = masked.index_select(dim=0, index=max_length_indices.data)
    decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)

    reconstructions = self.decoder(decoder_input)
    reconstructions = reconstructions.view(-1, self.img_channel, self.imsize, self.imsize)
    return reconstructions, masked

class ConvReconstructionModule(nn.Module):
  def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False):
    super(ConvReconstructionModule, self).__init__()
    self.num_capsules = num_capsules
    self.capsule_size = capsule_size
    self.imsize = imsize
    self.img_channels = img_channels
    self.grid_size = 6
    if batchnorm:
      self.FC = nn.Sequential(
        nn.Linear(capsule_size * num_capsules, num_capsules * (self.grid_size)**2 ),
        nn.BatchNorm1d(num_capsules * self.grid_size**2),
        nn.ReLU()
      )
      self.decoder = nn.Sequential(
          nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2),
          nn.BatchNorm2d(32),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1),
          nn.Sigmoid()
        )
    else:
        self.FC = nn.Sequential(
            nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
          nn.ConvTranspose2d(in_channels=self.num_capsules, out_channels=32, kernel_size=9, stride=2),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=64, out_channels=1, kernel_size=2, stride=1),
          nn.Sigmoid()
        )
    
  def forward(self, x, target=None):
    batch_size = x.size(0)
    if target is None:
      classes = torch.norm(x, dim=2)
      max_length_indices = classes.max(dim=1)[1].squeeze()
    else:
      max_length_indices = target.max(dim=1)[1]
    
    masked = x.new_tensor(torch.eye(self.num_capsules))
    masked = masked.index_select(dim=0, index=max_length_indices.data)

    decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)
    decoder_input = self.FC(decoder_input)
    decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size)
    reconstructions = self.decoder(decoder_input)
    reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize)
    
    return reconstructions, masked




class SmallNorbConvReconstructionModule(nn.Module):
  def __init__(self, num_capsules=10, capsule_size=16, imsize=28,img_channels=1, batchnorm=False):
    super(SmallNorbConvReconstructionModule, self).__init__()
    self.num_capsules = num_capsules
    self.capsule_size = capsule_size
    self.imsize = imsize
    self.img_channels = img_channels
    
    self.grid_size = 4
    
    if batchnorm:
      self.FC = nn.Sequential(
            nn.Linear(capsule_size * num_capsules, num_capsules *self.grid_size*self.grid_size),
            nn.BatchNorm1d(num_capsules * self.grid_size**2),
            nn.ReLU()
        )
      self.decoder = nn.Sequential(
          nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2),
          nn.BatchNorm2d(32),            
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
          nn.BatchNorm2d(64),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1),
          nn.BatchNorm2d(128),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1),
          nn.Sigmoid()
        )
    else:
        self.FC = nn.Sequential(
            nn.Linear(capsule_size * num_capsules, num_capsules *(self.grid_size**2) ),
            nn.ReLU()
        )
        self.decoder = nn.Sequential(
          nn.ConvTranspose2d(in_channels=num_capsules, out_channels=32, kernel_size=9, stride=2),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=32, out_channels=64, kernel_size=9, stride=1),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=64, out_channels=128, kernel_size=9, stride=1),
          nn.ReLU(),
          nn.ConvTranspose2d(in_channels=128, out_channels=img_channels, kernel_size=2, stride=1),
          nn.Sigmoid()
        )
    
  def forward(self, x, target=None):
    batch_size = x.size(0)
    if target is None:
      classes = torch.norm(x, dim=2)
      max_length_indices = classes.max(dim=1)[1].squeeze()
    else:
      max_length_indices = target.max(dim=1)[1]
    masked = Variable(x.new_tensor(torch.eye(self.num_capsules)))
    masked = masked.index_select(dim=0, index=max_length_indices.data)

    decoder_input = (x * masked[:, :, None, None]).view(batch_size, -1)
    decoder_input = self.FC(decoder_input)
    decoder_input = decoder_input.view(batch_size,self.num_capsules, self.grid_size, self.grid_size)
    reconstructions = self.decoder(decoder_input)
    reconstructions = reconstructions.view(-1, self.img_channels, self.imsize, self.imsize)
    
    return reconstructions, masked




class CapsNet(nn.Module):
  
  def __init__(self,
               reconstruction_type = "FC",
               imsize=28,
               num_classes=10,
               routing_iterations=3,
               primary_caps_gridsize=6,
               img_channels = 1,
               batchnorm = False,
               loss = "L2",
               num_primary_capsules=32,
               leaky_routing = False
              ):
    super(CapsNet, self).__init__()
    self.num_classes = num_classes
    if leaky_routing:
        num_classes += 1
        self.num_classes += 1
        
    self.imsize=imsize
    self.conv_layer = ConvLayer(in_channels=img_channels, batchnorm=batchnorm)
    self.leaky_routing = leaky_routing

    self.primary_capsules = PrimaryCapules(primary_caps_gridsize=primary_caps_gridsize,
                                           batchnorm=batchnorm,
                                           num_capsules = num_primary_capsules)
    
    self.digit_caps = ClassCapsules(num_capsules=num_classes,
                                    num_routes=num_primary_capsules*primary_caps_gridsize*primary_caps_gridsize,
                                    routing_iterations=routing_iterations,
                                    leaky=leaky_routing)

    if reconstruction_type == "FC":
        self.decoder = ReconstructionModule(imsize=imsize,
                                            num_capsules=num_classes,
                                            img_channel=img_channels, 
                                            batchnorm=batchnorm)
    elif reconstruction_type == "Conv32":
        self.decoder = SmallNorbConvReconstructionModule(num_capsules=num_classes,
                                                         imsize=imsize, 
                                                         img_channels=img_channels, 
                                                         batchnorm=batchnorm)            
    else:
        self.decoder = ConvReconstructionModule(num_capsules=num_classes,
                                                imsize=imsize, 
                                                img_channels=img_channels,
                                                batchnorm=batchnorm)
    
    if loss == "L2":
        self.reconstruction_criterion = nn.MSELoss(reduction="none")
    if loss == "L1":
        self.reconstruction_criterion = nn.L1Loss(reduction="none")
  
  def forward(self, x, target=None):
    output = self.conv_layer(x)
    output = self.primary_capsules(output)
    output = self.digit_caps(output)
    reconstruction, masked = self.decoder(output, target)

    return output, reconstruction, masked
  
  def loss(self, images, labels, capsule_output,  reconstruction, alpha):
    marg_loss = self.margin_loss(capsule_output, labels)
    rec_loss = self.reconstruction_loss(images, reconstruction)
    total_loss = (marg_loss + alpha * rec_loss).mean()
    return total_loss, rec_loss.mean(), marg_loss.mean()
  
  def margin_loss(self, x, labels):
    batch_size = x.size(0)
    v_c = torch.norm(x, dim=2, keepdim=True)
    
    left = functional.relu(0.9 - v_c).view(batch_size, -1) ** 2
    right = functional.relu(v_c - 0.1).view(batch_size, -1) ** 2

    loss = labels * left + 0.5 *(1-labels)*right
    loss = loss.sum(dim=1)
    return loss
  
  def reconstruction_loss(self, data, reconstructions):
    batch_size = reconstructions.size(0)
    reconstructions = reconstructions.view(batch_size, -1)
    data = data.view(batch_size, -1)
    loss = self.reconstruction_criterion(reconstructions, data)
    loss = loss.sum(dim=1)
    return loss


In [3]:
# Loader taken from https://github.com/mavanb/vision/blob/448fac0f38cab35a387666d553b9d5e4eec4c5e6/torchvision/datasets/utils.py

from __future__ import print_function
import os
import errno
import struct

import torch
import torch.utils.data as data
import numpy as np
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity


class SmallNORB(data.Dataset):
    """`MNIST <https://cs.nyu.edu/~ylclab/data/norb-v1.0-small//>`_ Dataset.
    Args:
        root (string): Root directory of dataset where processed folder and
            and  raw folder exist.
        train (bool, optional): If True, creates dataset from the training files,
            otherwise from the test files.
        download (bool, optional): If true, downloads the dataset from the internet and
            puts it in root directory. If the dataset is already processed, it is not processed
            and downloaded again. If dataset is only already downloaded, it is not
            downloaded again.
        transform (callable, optional): A function/transform that  takes in an PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        info_transform (callable, optional): A function/transform that takes in the
            info and transforms it.
        mode (string, optional): Denotes how the images in the data files are returned. Possible values:
            - all (default): both left and right are included separately.
            - stereo: left and right images are included as corresponding pairs.
            - left: only the left images are included.
            - right: only the right images are included.
    """

    dataset_root = "https://cs.nyu.edu/~ylclab/data/norb-v1.0-small/"
    data_files = {
        'train': {
            'dat': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-dat.mat',
                "md5_gz": "66054832f9accfe74a0f4c36a75bc0a2",
                "md5": "8138a0902307b32dfa0025a36dfa45ec"
            },
            'info': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-info.mat',
                "md5_gz": "51dee1210a742582ff607dfd94e332e3",
                "md5": "19faee774120001fc7e17980d6960451"
            },
            'cat': {
                "name": 'smallnorb-5x46789x9x18x6x2x96x96-training-cat.mat',
                "md5_gz": "23c8b86101fbf0904a000b43d3ed2fd9",
                "md5": "fd5120d3f770ad57ebe620eb61a0b633"
            },
        },
        'test': {
            'dat': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-dat.mat',
                "md5_gz": "e4ad715691ed5a3a5f138751a4ceb071",
                "md5": "e9920b7f7b2869a8f1a12e945b2c166c"
            },
            'info': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-info.mat',
                "md5_gz": "a9454f3864d7fd4bb3ea7fc3eb84924e",
                "md5": "7c5b871cc69dcadec1bf6a18141f5edc"
            },
            'cat': {
                "name": 'smallnorb-5x01235x9x18x6x2x96x96-testing-cat.mat',
                "md5_gz": "5aa791cd7e6016cf957ce9bdb93b8603",
                "md5": "fd5120d3f770ad57ebe620eb61a0b633"
            },
        },
    }

    raw_folder = 'raw'
    processed_folder = 'processed'
    train_image_file = 'train_img'
    train_label_file = 'train_label'
    train_info_file = 'train_info'
    test_image_file = 'test_img'
    test_label_file = 'test_label'
    test_info_file = 'test_info'
    extension = '.pt'

    def __init__(self, root, train=True, transform=None, target_transform=None, info_transform=None, download=False,
                 mode="all"):

        self.root = os.path.expanduser(root)
        self.transform = transform
        self.target_transform = target_transform
        self.info_transform = info_transform
        self.train = train  # training set or test set
        self.mode = mode

        if download:
            self.download()

        if not self._check_exists():
            raise RuntimeError('Dataset not found or corrupted.' +
                               ' You can use download=True to download it')

        # load test or train set
        image_file = self.train_image_file if self.train else self.test_image_file
        label_file = self.train_label_file if self.train else self.test_label_file
        info_file = self.train_info_file if self.train else self.test_info_file
        # load labels
        self.labels = self._load(label_file)
        # load info files
        self.infos = self._load(info_file)

        # load right set
        if self.mode == "left":
            self.data = self._load("{}_left".format(image_file))

        # load left set
        elif self.mode == "right":
            self.data = self._load("{}_right".format(image_file))

        elif self.mode == "all" or self.mode == "stereo":
            left_data = self._load("{}_left".format(image_file))
            right_data = self._load("{}_right".format(image_file))
            # load stereo
            if self.mode == "stereo":
                self.data = torch.stack((left_data, right_data), dim=1)
    
            # load all
            else:
                self.data = torch.cat((left_data, right_data), dim=0)
    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            mode ``all'', ``left'', ``right'':
                tuple: (image, target, info)
            mode ``stereo'':
                tuple: (image left, image right, target, info)
        """
        target = self.labels[index % 24300] if self.mode is "all" else self.labels[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        info = self.infos[index % 24300] if self.mode is "all" else self.infos[index]
        if self.info_transform is not None:
            info = self.info_transform(info)

        if self.mode == "stereo":
            img_left = self._transform(self.data[index, 0])
            img_right = self._transform(self.data[index, 1])
            return img_left, img_right, target, info

        img = self._transform(self.data[index])
        return img, target

    def __len__(self):
        return len(self.data)

    def _transform(self, img):
        # doing this so that it is consistent with all other data sets
        # to return a PIL Image
        img = Image.fromarray(img.numpy(), mode='L')

        if self.transform is not None:
            img = self.transform(img)
        return img

    def _load(self, file_name):
        return torch.load(os.path.join(self.root, self.processed_folder, file_name + self.extension))

    def _save(self, file, file_name):
        with open(os.path.join(self.root, self.processed_folder, file_name + self.extension), 'wb') as f:
            torch.save(file, f)

    def _check_exists(self):
        """ Check if processed files exists."""
        files = (
            "{}_left".format(self.train_image_file),
            "{}_right".format(self.train_image_file),
            "{}_left".format(self.test_image_file),
            "{}_right".format(self.test_image_file),
            self.test_label_file,
            self.train_label_file
        )
        fpaths = [os.path.exists(os.path.join(self.root, self.processed_folder, f + self.extension)) for f in files]
        return False not in fpaths

    def _flat_data_files(self):
        return [j for i in self.data_files.values() for j in list(i.values())]

    def _check_integrity(self):
        """Check if unpacked files have correct md5 sum."""
        root = self.root
        for file_dict in self._flat_data_files():
            filename = file_dict["name"]
            md5 = file_dict["md5"]
            fpath = os.path.join(root, self.raw_folder, filename)
            if not check_integrity(fpath, md5):
                return False
        return True

    def download(self):
        """Download the SmallNORB data if it doesn't exist in processed_folder already."""
        import gzip

        if self._check_exists():
            return

        # check if already extracted and verified
        if self._check_integrity():
            print('Files already downloaded and verified')
        else:
            # download and extract
            for file_dict in self._flat_data_files():
                url = self.dataset_root + file_dict["name"] + '.gz'
                filename = file_dict["name"]
                gz_filename = filename + '.gz'
                md5 = file_dict["md5_gz"]
                fpath = os.path.join(self.root, self.raw_folder, filename)
                gz_fpath = fpath + '.gz'

                # download if compressed file not exists and verified
                download_url(url, os.path.join(self.root, self.raw_folder), gz_filename, md5)

                print('# Extracting data {}\n'.format(filename))

                with open(fpath, 'wb') as out_f, \
                        gzip.GzipFile(gz_fpath) as zip_f:
                    out_f.write(zip_f.read())

                os.unlink(gz_fpath)

        # process and save as torch files
        print('Processing...')

        # create processed folder
        try:
            os.makedirs(os.path.join(self.root, self.processed_folder))
        except OSError as e:
            if e.errno == errno.EEXIST:
                pass
            else:
                raise

        # read train files
        left_train_img, right_train_img = self._read_image_file(self.data_files["train"]["dat"]["name"])
        train_info = self._read_info_file(self.data_files["train"]["info"]["name"])
        train_label = self._read_label_file(self.data_files["train"]["cat"]["name"])

        # read test files
        left_test_img, right_test_img = self._read_image_file(self.data_files["test"]["dat"]["name"])
        test_info = self._read_info_file(self.data_files["test"]["info"]["name"])
        test_label = self._read_label_file(self.data_files["test"]["cat"]["name"])

        # save training files
        self._save(left_train_img, "{}_left".format(self.train_image_file))
        self._save(right_train_img, "{}_right".format(self.train_image_file))
        self._save(train_label, self.train_label_file)
        self._save(train_info, self.train_info_file)

        # save test files
        self._save(left_test_img, "{}_left".format(self.test_image_file))
        self._save(right_test_img, "{}_right".format(self.test_image_file))
        self._save(test_label, self.test_label_file)
        self._save(test_info, self.test_info_file)

        print('Done!')

    @staticmethod
    def _parse_header(file_pointer):
        # Read magic number and ignore
        struct.unpack('<BBBB', file_pointer.read(4))  # '<' is little endian)

        # Read dimensions
        dimensions = []
        num_dims, = struct.unpack('<i', file_pointer.read(4))  # '<' is little endian)
        for _ in range(num_dims):
            dimensions.extend(struct.unpack('<i', file_pointer.read(4)))

        return dimensions

    def _read_image_file(self, file_name):
        fpath = os.path.join(self.root, self.raw_folder, file_name)
        with open(fpath, mode='rb') as f:
            dimensions = self._parse_header(f)
            assert dimensions == [24300, 2, 96, 96]
            num_samples, _, height, width = dimensions

            left_samples = np.zeros(shape=(num_samples, height, width), dtype=np.uint8)
            right_samples = np.zeros(shape=(num_samples, height, width), dtype=np.uint8)

            for i in range(num_samples):

                # left and right images stored in pairs, left first
                left_samples[i, :, :] = self._read_image(f, height, width)
                right_samples[i, :, :] = self._read_image(f, height, width)

        return torch.ByteTensor(left_samples), torch.ByteTensor(right_samples)

    @staticmethod
    def _read_image(file_pointer, height, width):
        """Read raw image data and restore shape as appropriate. """
        image = struct.unpack('<' + height * width * 'B', file_pointer.read(height * width))
        image = np.uint8(np.reshape(image, newshape=(height, width)))
        return image

    def _read_label_file(self, file_name):
        fpath = os.path.join(self.root, self.raw_folder, file_name)
        with open(fpath, mode='rb') as f:
            dimensions = self._parse_header(f)
            assert dimensions == [24300]
            num_samples = dimensions[0]

            struct.unpack('<BBBB', f.read(4))  # ignore this integer
            struct.unpack('<BBBB', f.read(4))  # ignore this integer

            labels = np.zeros(shape=num_samples, dtype=np.int32)
            for i in range(num_samples):
                category, = struct.unpack('<i', f.read(4))
                labels[i] = category
            return torch.LongTensor(labels)

    def _read_info_file(self, file_name):
        fpath = os.path.join(self.root, self.raw_folder, file_name)
        with open(fpath, mode='rb') as f:

            dimensions = self._parse_header(f)
            assert dimensions == [24300, 4]
            num_samples, num_info = dimensions

            struct.unpack('<BBBB', f.read(4))  # ignore this integer

            infos = np.zeros(shape=(num_samples, num_info), dtype=np.int32)

            for r in range(num_samples):
                for c in range(num_info):
                    info, = struct.unpack('<i', f.read(4))
                    infos[r, c] = info

        return torch.LongTensor(infos)


  target = self.labels[index % 24300] if self.mode is "all" else self.labels[index]
  info = self.infos[index % 24300] if self.mode is "all" else self.infos[index]


In [4]:
import torch
from torch.autograd import Variable
from torchvision import datasets, transforms


from tqdm import tqdm

train_transform = transforms.Compose([
    transforms.Resize(48),
    transforms.RandomCrop(32),
    transforms.ColorJitter(brightness=32./255, contrast=0.5),
    transforms.ToTensor(),
    transforms.Normalize((0.704,), (0.3081,))
    ])

test_transform = transforms.Compose([
    transforms.Resize(48),
    transforms.CenterCrop(32),
    transforms.ToTensor(),
    transforms.Normalize((0.704,), (0.3081,))
    ])

train_dataset = SmallNORB('./datasets/smallNORB/', train=True, download=True, transform=train_transform)
test_dataset = SmallNORB('./datasets/smallNORB/', train=False, transform=test_transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=256, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=128, shuffle=True)

capsnet = CapsNet(reconstruction_type='FC',
        imsize=32,
        num_classes=5,
        routing_iterations = 3,
        primary_caps_gridsize=8,
        num_primary_capsules=32,
        batchnorm=False,
        loss = 'L2',
        leaky_routing=False)

#initialize_weights(capsnet)
checkpoint = torch.load('./checkpoint/95.pt')
capsnet.load_state_dict(checkpoint['net'])

capsnet = capsnet.cuda()

best_acc = 0
optimizer = torch.optim.Adam(capsnet.parameters(), lr=0.01)

In [5]:
print(capsnet.conv_layer.conv[0].weight)
capsnet.conv_layer.conv[0].weight = nn.Parameter(quantize(capsnet.conv_layer.conv[0].weight,8,7))
capsnet.conv_layer.conv[0].bias = nn.Parameter(quantize(capsnet.conv_layer.conv[0].bias,8,7))
capsnet.primary_capsules.capsules[0][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[0][0].weight,8,7))
capsnet.primary_capsules.capsules[0][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[0][0].bias,8,7))
capsnet.primary_capsules.capsules[1][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[1][0].weight,8,7))
capsnet.primary_capsules.capsules[1][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[1][0].bias,8,7))
capsnet.primary_capsules.capsules[2][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[2][0].weight,8,7))
capsnet.primary_capsules.capsules[2][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[2][0].bias,8,7))
capsnet.primary_capsules.capsules[3][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[3][0].weight,8,7))
capsnet.primary_capsules.capsules[3][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[3][0].bias,8,7))
capsnet.primary_capsules.capsules[4][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[4][0].weight,8,7))
capsnet.primary_capsules.capsules[4][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[4][0].bias,8,7))
capsnet.primary_capsules.capsules[5][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[5][0].weight,8,7))
capsnet.primary_capsules.capsules[5][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[5][0].bias,8,7))
capsnet.primary_capsules.capsules[6][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[6][0].weight,8,7))
capsnet.primary_capsules.capsules[6][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[6][0].bias,8,7))
capsnet.primary_capsules.capsules[7][0].weight = nn.Parameter(quantize(capsnet.primary_capsules.capsules[7][0].weight,8,7))
capsnet.primary_capsules.capsules[7][0].bias   = nn.Parameter(quantize(capsnet.primary_capsules.capsules[7][0].bias,8,7))

Parameter containing:
tensor([[[[-3.3871e-02, -3.4372e-02, -2.5176e-02,  ..., -3.3590e-02,
           -3.9744e-02, -4.6258e-02],
          [-3.7661e-02, -3.2365e-02, -2.4548e-02,  ..., -2.7167e-02,
           -3.7441e-02, -4.6642e-02],
          [-3.7373e-02, -3.6723e-02, -3.2977e-02,  ..., -2.6947e-02,
           -3.5803e-02, -4.5367e-02],
          ...,
          [-3.7107e-02, -3.7552e-02, -3.6008e-02,  ..., -5.1857e-02,
           -5.2370e-02, -3.9731e-02],
          [-4.4641e-02, -4.2960e-02, -4.2923e-02,  ..., -4.7699e-02,
           -3.9117e-02, -2.4351e-02],
          [-5.5089e-02, -5.2823e-02, -5.1550e-02,  ..., -2.7869e-02,
           -1.6198e-02, -4.6681e-03]]],


        [[[-4.3599e-02, -4.1154e-02, -3.8226e-02,  ..., -1.7950e-02,
           -1.8593e-02, -1.8396e-02],
          [-4.2964e-02, -3.5716e-02, -2.5391e-02,  ..., -4.0317e-02,
           -3.8780e-02, -3.8879e-02],
          [-3.9582e-02, -3.2994e-02, -2.9308e-02,  ..., -2.4089e-02,
           -2.5968e-02, -4.3178e-0

In [7]:
def train(epoch):
    capsnet.train()
    train_correct = 0
    total = 0
    for batch, (data, target) in tqdm(list(enumerate(train_loader)), ascii=True, desc="Epoch{:3d}".format(epoch)):
        data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()
        target = torch.eye(5).cuda().index_select(dim=0, index=target)
        optimizer.zero_grad()
        capsule_output, reconstructions, _ = capsnet(data, target)
        predictions = torch.norm(capsule_output.squeeze(), dim=2)
        loss, rec_loss, marg_loss = capsnet.loss(data, target, capsule_output, reconstructions, 0.0005)

        loss.backward()
        optimizer.step()

        train_correct += (target.max(dim=1)[1] == predictions.max(dim=1)[1]).sum().item()
        total += target.size(0)

    print("acc = {}%".format(train_correct/total))

def test(epoch):
    global best_acc
    capsnet.eval()
    test_correct = 0
    total = 0

    for batch_id, (data, target) in tqdm(list(enumerate(test_loader)), ascii=True, desc="Test {:3d}".format(epoch)):
        data, target = Variable(data), Variable(target)
        data, target = data.cuda(), target.cuda()
        target = torch.eye(5).cuda().index_select(dim=0, index=target)

        capsule_output, reconstructions, predictions = capsnet(data)
        data = denormalize(data)
        loss, rec_loss, marg_loss = capsnet.loss(data, target, capsule_output, reconstructions, 0.0005)

        test_correct += (target.max(dim=1)[1] == predictions.max(dim=1)[1]).sum().item()
        total += target.size(0)

    print("acc = {}%".format(test_correct/total))
    acc = 100.*test_correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': capsnet.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/{}.pt'.format(str(acc)))
        best_acc = acc

for epoch in range(200):
    test(epoch)

  masked = Variable(x.new_tensor(torch.eye(self.num_capsules)))
Test   0: 100%|######################################################################| 380/380 [00:20<00:00, 18.45it/s]


acc = 0.8473868312757201%


Test   1: 100%|######################################################################| 380/380 [00:20<00:00, 18.41it/s]


acc = 0.8473868312757201%


KeyboardInterrupt: 