In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# IMPORTANT: define the squashing function first -- the capsule layers below call it.

In [None]:
# Define the squashing function
def squash(s, dim=-1):
    """Apply the capsule "squash" non-linearity along one axis of ``s``.

    Each vector taken along ``dim`` keeps its direction and is rescaled by
    ``|s|^2 / (1 + |s|^2)``, so its length ends up in ``[0, 1)``.

    Args:
        s: Tensor holding the vectors to squash.
        dim: Axis along which the vectors lie (defaults to the last axis).

    Returns:
        Tensor with the same shape as ``s``.
    """
    norm_sq = s.pow(2).sum(dim=dim, keepdim=True)
    shrink = norm_sq / (1 + norm_sq)
    # The small constant keeps the division finite for all-zero vectors.
    safe_norm = torch.sqrt(norm_sq + 1e-8)
    return shrink * s / safe_norm

In [None]:
# Define the PrimaryCapsules layer
class PrimaryCapsules(nn.Module):
    """Primary capsule layer built from several parallel convolutions.

    ``num_capsules`` independent ``nn.Conv2d`` layers are applied to the same
    input. The values they produce at one (channel, row, column) position are
    gathered into a single capsule vector, so every capsule has
    ``num_capsules`` components.

    Args:
        num_capsules: Number of parallel convolutions, i.e. the length of
            each capsule vector.
        in_channels: Number of channels in the input feature map.
        out_channels: Number of output channels of each convolution.
        kernel_size: Kernel size of each convolution.
        stride: Stride of each convolution.
    """

    def __init__(self, num_capsules=8, in_channels=256, out_channels=32, kernel_size=9, stride=2):
        super(PrimaryCapsules, self).__init__()
        self.capsules = nn.ModuleList([
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, bias=False)
            for _ in range(num_capsules)
        ])

    def forward(self, x):
        """Return squashed capsules of shape (batch, out_channels * H * W, num_capsules)."""
        # Stack on a new trailing axis so that the last dimension indexes the
        # convolution: (batch, out_channels, H, W, num_capsules). Stacking on
        # dim=1 and then reshaping would instead pack neighbouring spatial
        # values of a single convolution into each capsule vector.
        u = torch.stack([capsule(x) for capsule in self.capsules], dim=-1)
        # Flatten every (channel, row, column) position into one capsule; the
        # vector length follows the number of convolutions instead of a fixed 8.
        u = u.view(x.size(0), -1, len(self.capsules))
        return squash(u)

# Define the DigitCapsules layer
class DigitCapsules(nn.Module):
    """Output capsule layer using iterative routing-by-agreement.

    Every input capsule (a "route") predicts each output capsule through its
    own ``out_channels x in_channels`` matrix stored in ``W``. The predictions
    are combined with coupling coefficients that are refined over
    ``num_iterations`` routing rounds.

    Args:
        num_capsules: Number of output capsules.
        num_routes: Number of input capsules.
        in_channels: Length of each input capsule vector.
        out_channels: Length of each output capsule vector.
        num_iterations: Number of routing rounds; must be at least 1.

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1.
    """

    def __init__(self, num_capsules=10, num_routes=32 * 6 * 6, in_channels=8, out_channels=16, num_iterations=3):
        super(DigitCapsules, self).__init__()
        if num_iterations < 1:
            # With zero rounds forward() would have no output capsule to return.
            raise ValueError("num_iterations must be at least 1")
        self.num_routes = num_routes
        self.num_iterations = num_iterations
        # Small initial weights: with unit-variance weights the sum over all
        # routes is large enough to saturate squash(), every output capsule
        # starts with length ~1 and the margin loss barely moves.
        self.W = nn.Parameter(0.01 * torch.randn(1, num_routes, num_capsules, out_channels, in_channels))

    def forward(self, u):
        """Route input capsules to output capsules.

        Args:
            u: Tensor of shape (batch_size, num_routes, in_channels).

        Returns:
            Tensor of shape (batch_size, num_capsules, out_channels).
        """
        batch_size = u.size(0)
        num_capsules = self.W.size(2)
        # (batch_size, num_routes, 1, in_channels, 1); matmul broadcasts the
        # singleton capsule axis against W, so no explicit repeat is needed.
        u = u.unsqueeze(2).unsqueeze(4)
        u_hat = torch.matmul(self.W, u).squeeze(-1)  # Shape: (batch_size, num_routes, num_capsules, out_channels)
        # Routing logits start at zero, on the same device/dtype as u_hat.
        b_ij = u_hat.new_zeros(batch_size, self.num_routes, num_capsules, 1)

        for iteration in range(self.num_iterations):
            c_ij = F.softmax(b_ij, dim=2)  # Shape: (batch_size, num_routes, num_capsules, 1)
            s_j = (c_ij * u_hat).sum(dim=1, keepdim=True)  # Shape: (batch_size, 1, num_capsules, out_channels)
            v_j = squash(s_j)  # Shape: (batch_size, 1, num_capsules, out_channels)
            if iteration < self.num_iterations - 1:
                # Agreement between each prediction and the current output
                # raises the corresponding routing logit.
                b_ij = b_ij + (u_hat * v_j).sum(dim=-1, keepdim=True)  # Shape: (batch_size, num_routes, num_capsules, 1)

        return v_j.squeeze(1)

In [None]:
# Define the CapsNet architecture
class CapsNet(nn.Module):
    """Capsule network: convolution -> primary capsules -> digit capsules.

    Args:
        num_classes: Number of output capsules produced by the final layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Single-channel input, 256 feature maps from 9x9 kernels.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9, stride=1)
        self.primary_capsules = PrimaryCapsules(
            num_capsules=8,
            in_channels=256,
            out_channels=32,
            kernel_size=9,
            stride=2,
        )
        self.digit_capsules = DigitCapsules(num_capsules=num_classes)

    def forward(self, x):
        """Run ``x`` through the three stages and return the digit capsules."""
        features = F.relu(self.conv1(x))
        primary = self.primary_capsules(features)
        return self.digit_capsules(primary)

In [None]:
# Define the margin loss
def margin_loss(x, labels, m_plus=0.9, m_minus=0.1, lambda_=0.5):
    """Margin loss on capsule lengths.

    Args:
        x: Capsule outputs of shape (batch, num_classes, capsule_dim).
        labels: Integer class indices of shape (batch,).
        m_plus: Length the capsule of the true class should exceed.
        m_minus: Length the capsules of the other classes should stay below.
        lambda_: Weight of the penalty for the other classes.

    Returns:
        Scalar tensor: the mean of the per-class loss terms over the batch
        and over all classes.
    """
    # Capsule lengths; the small constant keeps the gradient of sqrt finite
    # when a capsule is exactly zero.
    v_c = torch.sqrt((x ** 2).sum(dim=2, keepdim=True) + 1e-8)
    # The class count comes from the capsule tensor rather than a fixed 10,
    # and the one-hot mask follows x's dtype and device.
    T_c = F.one_hot(labels, num_classes=x.size(1)).to(dtype=x.dtype, device=x.device).unsqueeze(2)
    present_term = T_c * torch.relu(m_plus - v_c).pow(2)
    absent_term = lambda_ * (1 - T_c) * torch.relu(v_c - m_minus).pow(2)
    return (present_term + absent_term).mean()

In [None]:
# Data loaders
# Preprocessing applied to every image: convert to a tensor, then normalise
# the single channel with a fixed mean and standard deviation
# (presumably the MNIST statistics -- confirm).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

In [None]:
# MNIST loaders: the training split is reshuffled every epoch, the test split
# keeps its order.
train_loader = DataLoader(
    dataset=datasets.MNIST('../data', train=True, download=True, transform=transform),
    batch_size=128,
    shuffle=True,
)
test_loader = DataLoader(
    dataset=datasets.MNIST('../data', train=False, download=True, transform=transform),
    batch_size=128,
    shuffle=False,
)

# Model, loss function and optimizer used by the training and testing loops
model = CapsNet()
criterion = margin_loss
optimizer = optim.Adam(model.parameters(), lr=0.001)

# NOW TRAIN IT WITH SECOND VERSION

In [None]:










# Training loop
def train(epoch):
    """Train the global ``model`` for one pass over ``train_loader``.

    Performs one optimizer step per batch using ``criterion`` and prints the
    batch loss every 100 batches.

    Args:
        epoch: Epoch number, used only in the progress message.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

        # Only report on every 100th batch.
        if batch_idx % 100 != 0:
            continue
        print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)}] Loss: {loss.item():.6f}')

# Testing loop
def test():
    """Evaluate the global ``model`` on ``test_loader`` and print a summary.

    Reports the per-sample average of ``criterion`` and the classification
    accuracy. The predicted class is the output capsule with the largest
    length.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            output = model(data)
            loss = criterion(output, target)
            # criterion returns a batch mean, so weight it by the batch size
            # before accumulating; the last batch may be smaller.
            test_loss += loss.item() * data.size(0)
            # output is (batch, num_classes, capsule_dim): take the length of
            # each class capsule and pick the longest one. Taking max over
            # dim=2 would index capsule components instead of classes.
            predicted = output.norm(dim=-1).argmax(dim=1)
            correct += predicted.eq(target).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({accuracy:.0f}%)\n')

# Train for ten epochs, evaluating on the test set after each one
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(epoch)
    test()

Train Epoch: 1 [0/60000] Loss: 0.363890
Train Epoch: 1 [12800/60000] Loss: 0.364424


KeyboardInterrupt: 