In [43]:
import torchvision.datasets as datasets
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data
from torch.autograd import Variable

In [19]:
root = '/home/wilsonyan/data/mnist'

In [23]:
batch_size = 32

In [28]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

dset_train = datasets.MNIST(root, train=True, download=True, transform=transform)
dset_test = datasets.MNIST(root, train=False, download=True, transform=transform)

loader_train = data.DataLoader(dset_train, batch_size=batch_size, shuffle=True)
loader_test = data.DataLoader(dset_test, batch_size=batch_size, shuffle=True)

In [108]:
def squash(x, dim=-1):
    norm_squared = torch.sum(x ** 2, dim, keepdim=True)
    x = norm_squared / (1 + norm_squared) * x / torch.sqrt(norm_squared)
    return x

In [109]:
class PrimaryCaps(nn.Module):
    def __init__(self, n_dim):
        super(PrimaryCaps, self).__init__()
        self.n_dim = n_dim
        self.conv = nn.Conv2d(256, 32 * 8, 9, stride=2)
    
    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1, self.n_dim)
        x = squash(x)
        return x

In [135]:
class DigitCaps(nn.Module):
    def __init__(self, n_dim, prev_dim, n_iter):
        super(DigitCaps, self).__init__()
        self.n_dim = n_dim
        self.prev_dim = prev_dim
        self.n_iter = n_iter
        self.weights = nn.Parameter(torch.randn(10, 6 * 6 * 32, prev_dim, n_dim))
    
    def forward(self, x):
        x = x.unsqueeze(1).unsqueeze(3)
        u_hat = torch.matmul(x, self.weights).squeeze(3)
        
        b = Variable(torch.zeros(10, 6 * 6 * 32))
        for i in range(self.n_iter):
            c = F.softmax(b)
            s = u_hat.mul(c.unsqueeze(0).unsqueeze(3)).sum(2)
            v = squash(s)
            if i < self.n_iter - 1:
                b += u_hat.mul(v.unsqueeze(2)).sum(-1).sum(0)
        return v

In [136]:
class CapsuleNet(nn.Module):
    def __init__(self):
        super(CapsuleNet, self).__init__()
        self.conv = nn.Conv2d(1, 256, 9)
        self.primary_caps = PrimaryCaps(8)
        self.digit_caps = DigitCaps(16, 8, 5)
        
    def forward(self, x):
        x = F.relu(self.conv(x))
        x = self.primary_caps(x)
        x = self.digit_caps(x)

In [137]:
model = CapsuleNet()
optimizer = optim.RMSprop(model.parameters())

In [138]:
def train(model, optimizer, loader, num_epochs=10):
    for epoch in range(num_epochs):
        for x, y in iter(loader):
            x, y = Variable(x), Variable(y)
            out = model(x)

In [139]:
model = train(model, optimizer, loader_train, num_epochs=10)

torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 1, 16]) torch.Size([32, 10, 1152, 16])
torch.Size([32, 10, 16])


KeyboardInterrupt: 