In [1]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable

import numpy as np

In [2]:
root = '/home/wilsonyan/data/mnist'

In [3]:
batch_size = 32

In [4]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dset_train = datasets.MNIST(root, train=True, download=True, transform=transform)
dset_test = datasets.MNIST(root, train=False, download=True, transform=transform)

loader_train = data.DataLoader(dset_train, batch_size=batch_size, shuffle=True)
loader_test = data.DataLoader(dset_test, batch_size=batch_size, shuffle=True)

In [5]:
def squash(x, dim=-1):
    norm_squared = torch.sum(x ** 2, dim, keepdim=True)
    x = norm_squared / (1 + norm_squared) * x / torch.sqrt(norm_squared)
    return x

In [79]:
def margin_loss(out, y):
    pred = torch.sqrt((out ** 2).sum(-1))
    loss = y * torch.max(0.9 - pred, 0)[0] ** 2 + 0.5 * (1 - y) * torch.max(pred - 0.1, 0)[0] ** 2
    return loss.sum()

In [80]:
def mse_loss(out, y):
    return torch.sum((out - y) ** 2) / out.data.nelement()

In [61]:
class PrimaryCaps(nn.Module):
    def __init__(self, n_dim):
        super(PrimaryCaps, self).__init__()
        self.n_dim = n_dim
        self.conv = nn.Conv2d(256, 32 * 8, 9, stride=2)
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1, self.n_dim)
        x = squash(x)
        return x

In [85]:
class DigitCaps(nn.Module):
    def __init__(self, n_dim, prev_dim, n_iter):
        super(DigitCaps, self).__init__()
        self.n_dim = n_dim
        self.prev_dim = prev_dim
        self.n_iter = n_iter
        self.weights = nn.Parameter(torch.randn(10, 6 * 6 * 32, prev_dim, n_dim))
    
    def forward(self, x):
        x = x.unsqueeze(1).unsqueeze(3)
        u_hat = torch.matmul(x, self.weights).squeeze(3)
        
        b = Variable(torch.zeros(10, 6 * 6 * 32))
        for i in range(self.n_iter):
            c = F.softmax(b)
            s = u_hat.mul(c.unsqueeze(0).unsqueeze(3)).sum(2)
            v = squash(s)
            if i < self.n_iter - 1:
                b = b + u_hat.mul(v.unsqueeze(2)).sum(-1).sum(0)
        return v

In [86]:
class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet, self).__init__()
        self.conv = nn.Conv2d(1, 256, 9)
        self.primary_caps = PrimaryCaps(8)
        self.digit_caps = DigitCaps(16, 8, 5)
        self.fc1 = nn.Linear(160, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, 784)
        
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.primary_caps(x)
        out = self.digit_caps(x)
        x = out.view(out.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.sigmoid(self.fc3(x))
        reconstruction = x.view(x.size()[0], 28, 28)
        return out, reconstruction

In [87]:
model = CapsuleNet()
optimizer = optim.RMSprop(model.parameters())

In [90]:
def train(model, optimizer, loader, num_epochs=10, show_every=20):
    for epoch in range(num_epochs):
        print('Epoch %s' % epoch)
        print('=' * 10)
        losses = []
        for i, (x, y) in enumerate(iter(loader)):
            y_one_hot = torch.FloatTensor(y.size()[0], 10)
            y_one_hot.scatter_(1, y.unsqueeze(1), 1)
            x, y_one_hot = Variable(x), Variable(y_one_hot)
            out, reconstruction = model(x)
            loss_margin = margin_loss(out, y_one_hot)
            loss_rec = mse_loss(x, reconstruction)
            loss = loss_margin + 0.0005 * loss_rec
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            losses.append(loss.data[0])
            if i * show_every == 0:
                print('Loss: %s' % np.mean(losses))
            
        print('Mean Loss: %s' % np.mean(losses))
            

In [None]:
model = train(model, optimizer, loader_train, num_epochs=10)

Epoch 0
Loss: -2.11439733168e+32
