In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [None]:
# Define the squashing function
def squash(s, dim=-1, eps=1e-8):
    """Capsule non-linearity: v = |s|^2 / (1 + |s|^2) * s / |s|.

    Preserves the direction of each vector along ``dim``. Short vectors shrink
    toward zero and long vectors approach (but never reach) unit length.

    Args:
        s: tensor of capsule vectors.
        dim: axis holding the vector components.
        eps: added under the square root so an all-zero vector maps to zero
            instead of NaN (0/0), and gradients at zero stay finite.

    Returns:
        Tensor with the same shape as ``s``.
    """
    squared_norm = (s ** 2).sum(dim=dim, keepdim=True)
    scale = squared_norm / (1 + squared_norm)
    return scale * s / torch.sqrt(squared_norm + eps)

In [None]:
# Define the Capsule Layer
class CapsuleLayer(nn.Module):
    """Capsule layer (Sabour et al., 2017, "Dynamic Routing Between Capsules").

    The mode is selected by ``num_route_nodes``:

    * ``num_route_nodes == -1`` (primary capsules): ``num_capsules`` parallel
      ``Conv2d(in_channels, out_channels, kernel_size, stride)`` units. Their
      flattened outputs are stacked on the last axis, giving
      ``(batch, out_channels * H_out * W_out, num_capsules)``. Each row is one
      ``num_capsules``-dimensional capsule vector, squashed along that axis.
    * otherwise (routed capsules): each of the ``num_route_nodes`` input
      capsules (size ``in_channels``) is mapped to a prediction for each of the
      ``num_capsules`` output capsules by its own learned
      ``(in_channels, out_channels)`` matrix. The predictions are combined by
      ``num_iterations`` rounds of routing-by-agreement. The output has shape
      ``(batch, num_capsules, out_channels)``.

    Args:
        num_capsules: output capsules (routed) / parallel conv units, i.e. the
            capsule dimensionality (primary).
        num_route_nodes: number of input capsules, or -1 for primary mode.
        in_channels: input capsule size (routed) / input feature maps (primary).
        out_channels: output capsule size (routed) / conv channels (primary).
        num_iterations: routing iterations (routed mode only, must be >= 1).
        kernel_size: conv kernel size (primary mode only).
        stride: conv stride (primary mode only).
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, num_iterations=3,
                 kernel_size=9, stride=2):
        super(CapsuleLayer, self).__init__()
        self.num_capsules = num_capsules
        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations
        self.out_channels = out_channels
        self.in_channels = in_channels

        if num_route_nodes != -1:
            if num_iterations < 1:
                # forward() would otherwise fail with an unbound output variable.
                raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
            # One (in_channels x out_channels) transform per (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList([
                nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                for _ in range(num_capsules)
            ])

    def forward(self, x, routing_logits=None):
        """Apply the layer.

        Args:
            x: a feature map (primary mode). In routed mode, x is either
                ``(batch, num_route_nodes, in_channels)`` or any tensor that
                reshapes to it.
            routing_logits: optional initial logits of shape
                ``(batch, num_capsules, num_route_nodes)`` (routed mode only).
                Defaults to zeros.
        """
        if self.num_route_nodes == -1:
            return self._forward_primary(x)
        return self._forward_routed(x, routing_logits)

    def _forward_primary(self, x):
        # Each conv unit supplies one component of every capsule vector.
        u = torch.stack([capsule(x).reshape(x.size(0), -1) for capsule in self.capsules], dim=-1)
        return squash(u)

    def _forward_routed(self, x, routing_logits):
        batch_size = x.size(0)
        if x.dim() != 3:
            # reshape (not view) also accepts non-contiguous inputs.
            x = x.reshape(batch_size, self.num_route_nodes, self.in_channels)

        # Predictions u_hat[b, c, r] = x[b, r] @ W[c, r] -> (batch, num_capsules, num_route_nodes, out_channels).
        # einsum avoids materialising a (batch, C, R, in, out) broadcast product.
        priors = torch.einsum('bri,crio->bcro', x, self.route_weights)

        if routing_logits is None:
            routing_logits = priors.new_zeros(batch_size, self.num_capsules, self.num_route_nodes)

        for i in range(self.num_iterations):
            # Coupling coefficients are normalised over the *output* capsules (dim=1),
            # so each input capsule distributes unit weight among its parents.
            c = F.softmax(routing_logits, dim=1)
            s = (c.unsqueeze(-1) * priors).sum(dim=2, keepdim=True)  # (batch, num_capsules, 1, out_channels)
            v = squash(s)
            if i != self.num_iterations - 1:
                # Agreement = dot product between each prediction and the current output.
                routing_logits = routing_logits + (priors * v).sum(dim=-1)

        return v.squeeze(2)

In [None]:
# Define the Margin Loss
class MarginLoss(nn.Module):
    """Capsule margin loss (Sabour et al., 2017).

    Per class k:
        L_k = T_k * max(0, m_pos - v_k)^2 + lambda * (1 - T_k) * max(0, v_k - m_neg)^2
    The class terms are summed per sample and then averaged over the batch.

    Args:
        m_pos: length a present class's capsule should exceed.
        m_neg: length an absent class's capsule should stay below.
        lambda_val: weight of the absent-class term.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_val=0.5):
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_val = lambda_val

    def forward(self, v_c, t_c):
        """Return the scalar loss.

        In this notebook both ``v_c`` (capsule lengths) and ``t_c`` (one-hot
        targets) are ``(batch, num_classes)``; the sum runs over dim 1.
        """
        absent = 1.0 - t_c
        present_term = t_c * F.relu(self.m_pos - v_c) ** 2
        absent_term = self.lambda_val * absent * F.relu(v_c - self.m_neg) ** 2
        per_sample = (present_term + absent_term).sum(dim=1)
        return per_sample.mean()

In [None]:
# Define the Capsule Network
class CapsuleNet(nn.Module):
    """CapsNet for 1x28x28 inputs and 10 classes.

    Pipeline:
    1. Conv(1->256, 9x9) gives 20x20 maps.
    2. Primary capsules (8 conv units x 32 channels, 9x9 stride 2) give
       32*6*6 = 1152 capsules of size 8.
    3. Digit capsules: 10 capsules of size 16, via dynamic routing.
    4. The capsule lengths are the class scores.

    If one-hot labels ``y`` of shape (batch, 10) are given, the digit capsules
    are masked with ``y`` and decoded to 784 pixels. The decoder ends in Sigmoid,
    so those pixels lie in (0, 1).
    """

    def __init__(self):
        super(CapsuleNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 256, kernel_size=9, stride=1)
        # 8 conv units x 32 channels on a 6x6 grid -> 1152 capsules, each 8-D.
        self.primary_capsules = CapsuleLayer(num_capsules=8, num_route_nodes=-1, in_channels=256, out_channels=32)
        # Route the 1152 8-D primary capsules themselves (not 288 chunks of 32 values,
        # each of which would glue together four unrelated capsules).
        self.digit_capsules = CapsuleLayer(num_capsules=10, num_route_nodes=32 * 6 * 6, in_channels=8, out_channels=16)
        self.decoder = nn.Sequential(
            nn.Linear(16 * 10, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, 784),
            nn.Sigmoid()
        )

    def forward(self, x, y=None):
        """Return class scores ``v_c`` (batch, 10), plus reconstructions (batch, 784) when ``y`` is given."""
        x = F.relu(self.conv1(x), inplace=True)
        u = self.primary_capsules(x)
        # Capsule size == number of primary conv units: (batch, 1152, 8).
        u = u.reshape(u.size(0), -1, self.primary_capsules.num_capsules)
        v = self.digit_capsules(u)
        v_c = torch.norm(v, dim=-1)

        if y is None:
            return v_c
        # Keep only the labelled class's capsule before decoding.
        reconstructions = self.decoder((v * y[:, :, None]).reshape(v.size(0), -1))
        return v_c, reconstructions

In [None]:
# Define the training loop
def train(model, device, train_loader, optimizer, criterion, epoch,
          recon_weight=0.0005, num_classes=10, log_interval=100):
    """Run one training epoch.

    ``model(data, one_hot_labels)`` must return ``(class_scores, reconstructions)``.
    The loss is ``criterion(class_scores, one_hot)`` plus ``recon_weight`` times
    the per-image sum of squared reconstruction errors, averaged over the batch.

    Args:
        model, device, train_loader, optimizer, criterion: usual training objects.
        epoch: epoch number, used only for logging.
        recon_weight: weight of the reconstruction term.
        num_classes: size of the one-hot label vectors.
        log_interval: print the loss every this many batches.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        labels = F.one_hot(target, num_classes).float()  # used for both masking and the margin loss
        optimizer.zero_grad()
        v_c, reconstructions = model(data, labels)
        margin_loss = criterion(v_c, labels)
        # The 0.0005 weight (Sabour et al.) applies to the *sum* of squared pixel errors.
        # A mean-reduced MSE would shrink this term by the pixel count (784x) and make
        # the regulariser practically inert.
        recon_loss = F.mse_loss(reconstructions, data.reshape(data.size(0), -1), reduction='sum') / data.size(0)
        loss = margin_loss + recon_weight * recon_loss
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')

# Define the testing loop
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            v_c, reconstructions = model(data, F.one_hot(target, 10).float())
            test_loss += criterion(v_c, F.one_hot(target, 10).float()).item()
            pred = v_c.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.0f}%)\n')

# Main function
def main(epochs=10, batch_size=64, learning_rate=0.001, seed=0, data_dir='./data'):
    """Train CapsuleNet on MNIST and evaluate after every epoch.

    Args:
        epochs: number of training epochs.
        batch_size: mini-batch size for both loaders.
        learning_rate: Adam learning rate.
        seed: torch RNG seed, for reproducible initialisation and shuffling.
        data_dir: where MNIST is stored or downloaded.
    """
    torch.manual_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # No Normalize: the images double as reconstruction targets for a decoder that ends in
    # Sigmoid, so they must stay in [0, 1]. Standardising with mean 0.1307 / std 0.3081 would
    # map background pixels to about -0.42, which a Sigmoid can never output.
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    train_dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transform)
    test_dataset = datasets.MNIST(data_dir, train=False, download=True, transform=transform)

    use_cuda = device.type == "cuda"
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=use_cuda)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, pin_memory=use_cuda)

    model = CapsuleNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    criterion = MarginLoss()

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, criterion, epoch)
        test(model, device, test_loader, criterion)

if __name__ == '__main__':
    main()


Test set: Average loss: 0.0003, Accuracy: 9838/10000 (98%)



KeyboardInterrupt: 