<a href="https://colab.research.google.com/github/taliafabs/CSC413-Project/blob/main/Project.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##Introduction##
This file contains the code that builds the architecture and runs experiments to evaluate its capabilities.

The debugging print statements will show the matrix dimensions during computation.

Running on a GPU significantly reduces the run time and also makes it possible to handle more data.

Example: CPU (subset_ratio=0.1, batch_size <= 16), GPU (subset_ratio=10, batch_size=64)

To run CapsNet alone (without the Hebbian softmax head), pass softmax=False in the model declaration statement.

##CapsNet Architecture##
The following section defines the CapsNet with Hebbian Softmax architecture, and tests whether it is functioning properly.

###Setting###

In [None]:
# Libraries
# !pip install wget
import os
# import wget
from zipfile import ZipFile
from PIL import Image, UnidentifiedImageError
import torch
import torch.nn.functional as F
from torch import nn
from torch.optim import Adam, lr_scheduler
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms
from torchvision.datasets import ImageFolder
from torch.autograd import Variable
from collections import namedtuple, Counter, OrderedDict
from sklearn.metrics import precision_score, recall_score, f1_score
import numpy as np
import matplotlib.pyplot as plt
import wandb

In [None]:
# Device string shared by the models and tensors in this file: 'cuda' when a CUDA device is available, else 'cpu'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Debugger
# Set to True to print the intermediate messages (mostly tensor shapes) emitted throughout this file.
debug = False


def debug_message(msg):
    """Print ``msg`` only while the module-level ``debug`` flag is truthy."""
    if not debug:
        return
    print(msg)

# Save Activations for Hebbian Learning
def save_activation(activations, name):
    """Return a forward hook that records a layer's output under ``activations[name]``.

    Args:
        activations: Mutable mapping that receives the captured output.
        name: Key under which the output is stored.

    Returns:
        A callable with the ``(module, input, output)`` forward-hook signature.
    """
    def _store(_module, _inputs, result):
        # Overwrites the previous entry, so only the latest forward pass is kept.
        activations[name] = result
        debug_message(f"Hook triggered for {name}, activation shape: {result.shape}")

    return _store

def register_hooks(model, activations):
    """Attach an activation-saving forward hook to every named sub-module of ``model``.

    The root module (whose name is the empty string) is skipped. Captured
    outputs land in ``activations`` keyed by the sub-module's dotted name.
    """
    for layer_name, submodule in model.named_modules():
        if not layer_name:
            continue
        submodule.register_forward_hook(save_activation(activations, layer_name))

### Network ###

In [None]:
class CapsuleNet(nn.Module):
    """Capsule network with a fully connected reconstruction decoder.

    Pipeline: 9x9 convolution -> ReLU -> ``PrimaryCapsule`` -> ``DigitCapsule``
    (dynamic routing) -> masked decoder. When ``softmax`` is true the class
    scores are produced by a ``HebbianSoftmax`` layer fed with the flattened
    reconstruction instead of by the capsule lengths.

    Args:
        input_size: ``(channels, height, width)`` of the input images.
        classes: Number of output classes (one digit capsule per class).
        routings: Number of routing iterations used by ``DigitCapsule``.
        softmax: Whether to add the ``HebbianSoftmax`` classification head.

    Raises:
        ValueError: If the spatial size is too small for the two unpadded
            9x9 convolutions to produce at least one primary capsule.
    """

    def __init__(self, input_size, classes, routings, softmax=False):
        super(CapsuleNet, self).__init__()
        self.softmax = softmax
        self.input_size = input_size
        self.classes = classes
        self.routings = routings

        channels, height, width = input_size[0], input_size[1], input_size[2]
        flat_pixels = channels * height * width

        # Spatial size after conv1 (k=9, s=1, p=0) followed by the primary
        # capsule convolution (k=9, s=2, p=0). For 28x28 inputs this is 6x6,
        # which reproduces the previously hard-coded 32 * 6 * 6 capsules.
        caps_h = ((height - 8) - 9) // 2 + 1
        caps_w = ((width - 8) - 9) // 2 + 1
        if caps_h < 1 or caps_w < 1:
            raise ValueError(
                f"input_size {tuple(input_size)} is too small for the capsule convolutions"
            )
        # 256 output channels grouped into 8-dimensional capsules.
        in_num_caps = (256 // 8) * caps_h * caps_w

        self.conv1 = nn.Conv2d(channels, 256, kernel_size=9, stride=1, padding=0)
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0)
        self.digitcaps = DigitCapsule(in_num_caps=in_num_caps, in_dim_caps=8,
                                      out_num_caps=classes, out_dim_caps=16, routings=routings)

        self.decoder = nn.Sequential(
            OrderedDict([
                ("fc1", nn.Linear(16 * classes, 512)),
                ("relu1", nn.ReLU(inplace=True)),
                ("fc2", nn.Linear(512, 1024)),
                ("relu2", nn.ReLU(inplace=True)),
                ("fc3", nn.Linear(1024, flat_pixels)),
                ("sigmoid", nn.Sigmoid())
            ])
        )

        if softmax:
            self.hebbsoftmax = HebbianSoftmax(flat_pixels, classes)

        self.to(DEVICE)

    def forward(self, x, y=None):
        """Classify ``x`` and reconstruct it from the selected class capsule.

        Args:
            x: Batch of images matching ``input_size``.
            y: Optional one-hot labels used to mask the capsules before
                decoding. When omitted, the capsule with the largest length
                is used instead.

        Returns:
            ``(scores, reconstruction)``: ``scores`` has one value per class
            (capsule lengths, or the ``HebbianSoftmax`` output when enabled);
            ``reconstruction`` is reshaped to ``(-1, *input_size)``.
        """
        x = x.to(DEVICE)
        debug_message(f"Input shape: {x.shape}")

        x = self.conv1(x)
        debug_message(f"Conv1 Output shape: {x.shape}")

        x = nn.functional.relu(x)
        debug_message(f"ReLU Ouput shape: {x.shape}")

        x = self.primarycaps(x)

        x = self.digitcaps(x)

        # Capsule length acts as the per-class score.
        length = x.norm(dim=-1)
        debug_message(f"Length shape (class probabilities): {length.shape}")

        if y is None:  # during testing, no label is given, so we need to create one-hot coding using `length`
            index = length.max(dim=1)[1]
            y = torch.zeros(length.size(), device=DEVICE).scatter_(1, index.view(-1, 1), 1.)

        y = y.to(DEVICE)
        # Zero out every capsule except the selected class, then flatten for the decoder.
        reconstruction = self.decoder((x * y[:, :, None]).view(x.size(0), -1))
        debug_message(f"Decoder Output shape: {reconstruction.shape}")

        if self.softmax:
            # The Hebbian head replaces the capsule lengths as class scores.
            length = self.hebbsoftmax(reconstruction)

        return length, reconstruction.view(-1, *self.input_size)

### Layer ###

In [None]:
# Helper Functions
def squash(inputs, axis=-1):
    """Rescale vectors along ``axis`` so each length becomes ``n**2 / (1 + n**2)``.

    ``n`` is the L2 norm of the vector. The direction is preserved, and a
    small epsilon in the denominator keeps zero vectors from dividing by zero.
    """
    length = torch.norm(inputs, p=2, dim=axis, keepdim=True)
    squared = length**2
    shrink = (squared / (1 + squared)) / (length + 1e-8)
    return shrink * inputs


class PrimaryCapsule(nn.Module):
    """Convolution whose output is regrouped into capsules of ``dim_caps`` values and squashed.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the convolution.
        dim_caps: Number of values per capsule.
        kernel_size, stride, padding: Passed straight to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0):
        super(PrimaryCapsule, self).__init__()
        # No device move here: the layer has no parameters yet at this point,
        # and the owning model moves all sub-modules once it is fully built.
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        """Return squashed capsules of shape ``(batch, num_capsules, dim_caps)``."""
        # Tensor.to() is not in-place, so the result must be kept. Following the
        # layer's own weights keeps input and parameters on the same device.
        x = x.to(self.conv2d.weight.device)
        debug_message(f"PrimaryCapsule Input shape: {x.shape}")

        outputs = self.conv2d(x)
        debug_message(f"After Conv2D shape: {outputs.shape}")

        # Regroup all convolution outputs into vectors of length dim_caps.
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        debug_message(f"PrimaryCapsule Output shape: {outputs.shape}")

        return squash(outputs)


class DigitCapsule(nn.Module):
    """Capsule layer that routes ``in_num_caps`` input capsules to ``out_num_caps`` output capsules.

    Args:
        in_num_caps: Number of input capsules.
        in_dim_caps: Size of each input capsule.
        out_num_caps: Number of output capsules.
        out_dim_caps: Size of each output capsule.
        routings: Number of routing iterations (must be at least 1).

    Raises:
        ValueError: If ``routings`` is smaller than 1, since no output would
            ever be computed.
    """

    def __init__(self, in_num_caps, in_dim_caps, out_num_caps, out_dim_caps, routings=3):
        super(DigitCapsule, self).__init__()
        if routings < 1:
            raise ValueError(f"routings must be >= 1, got {routings}")
        self.in_num_caps = in_num_caps
        self.in_dim_caps = in_dim_caps
        self.out_num_caps = out_num_caps
        self.out_dim_caps = out_dim_caps
        self.routings = routings
        self.weight = nn.Parameter(0.01 * torch.randn(out_num_caps, in_num_caps, out_dim_caps, in_dim_caps))

    def forward(self, x):
        """Return routed output capsules of shape ``(batch, out_num_caps, out_dim_caps)``."""
        # Move the input to wherever the weights live. Re-assigning self.weight with
        # the result of .to() is avoided because a moved copy is a plain tensor,
        # which nn.Module refuses to store in a parameter slot.
        x = x.to(self.weight.device)
        debug_message(f"DigitCapsule Input shape: {x.shape}")

        x = x.view(x.size(0), 1, self.in_num_caps, self.in_dim_caps, 1)
        debug_message(f"After reshape for routing shape: {x.shape}")

        # Prediction vectors: every input capsule projected for every output capsule.
        x_hat = torch.squeeze(torch.matmul(self.weight, x), dim=-1)
        debug_message(f"x_hat shape: {x_hat.shape}")

        # Routing statistics are computed without gradient; only the final pass back-propagates.
        x_hat_detached = x_hat.detach()

        b = torch.zeros(x.size(0), self.out_num_caps, self.in_num_caps, device=x_hat.device)
        for i in range(self.routings):
            # Coupling coefficients: softmax over the output capsules.
            c = F.softmax(b, dim=1)
            debug_message(f"Routing {i+1} coupling coefficients shape: {c.shape}")
            if i == self.routings - 1:
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat, dim=-2, keepdim=True))
            else:
                outputs = squash(torch.sum(c[:, :, :, None] * x_hat_detached, dim=-2, keepdim=True))
                # Agreement between each prediction and the current output raises its logit.
                b = b + torch.sum(outputs * x_hat_detached, dim=-1)

        debug_message(f"DigitCapsule Output shape: {outputs.shape}")
        return torch.squeeze(outputs, dim=-2)


class HebbianSoftmax(nn.Module):
    """Bias-free linear layer followed by a softmax over the last dimension.

    The layer itself is an ordinary module; any Hebbian weight update is
    applied from outside, by code that edits ``linear.weight`` directly.

    Args:
        input_dim: Size of each input vector.
        output_dim: Number of classes.
    """

    def __init__(self, input_dim, output_dim):
        super(HebbianSoftmax, self).__init__()
        # No device move here: the layer has no parameters yet at this point,
        # and the owning model moves all sub-modules once it is fully built.
        self.linear = nn.Linear(input_dim, output_dim, bias=False)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """Return class probabilities of shape ``(..., output_dim)``."""
        # Tensor.to() is not in-place, so the result must be kept.
        x = x.to(self.linear.weight.device)
        debug_message(f"HebbianSoftmax Input shape: {x.shape}")
        x = self.linear(x)
        debug_message(f"After Linear shape: {x.shape}")
        x = self.softmax(x)
        debug_message(f"HebbianSoftmax Output shape: {x.shape}")
        return x

### Train & Test ###

In [None]:
# Helper Functions
def init_hebb_param(class_labels, epochs):
    """Derive the Hebbian blending floor ``gamma`` and cut-off ``T`` from class frequencies.

    Args:
        class_labels: Iterable of labels. Plain Python values, NumPy scalars
            and 0-dim tensors are all accepted.
        epochs: Number of training epochs.

    Returns:
        ``(gamma, T)`` where ``gamma = 1 / Nmin`` and ``T = Nmin * epochs``,
        with ``Nmin`` the number of samples in the rarest class.

    Raises:
        ValueError: If ``class_labels`` is empty.
    """
    # Tensors hash by identity, so counting them directly would treat every
    # sample as its own class and always yield Nmin == 1. Convert scalar
    # containers (tensors, NumPy scalars) to plain Python values first.
    counts = Counter(
        label.item() if hasattr(label, "item") else label
        for label in class_labels
    )
    if not counts:
        raise ValueError("class_labels must contain at least one label")

    Nmin = min(counts.values())
    gamma = 1 / Nmin
    T = Nmin * epochs  # should be epochs until convergence
    return gamma, T


def compute_metrics(y_true, y_pred):
    """Return weighted ``(precision, recall, f1)`` for label tensors.

    Both tensors are moved to the CPU and converted to NumPy before scoring.
    Classes without predictions contribute 0 instead of triggering a warning.
    """
    truth = y_true.cpu().numpy()
    guess = y_pred.cpu().numpy()
    options = {"average": 'weighted', "zero_division": 0}

    scorers = (precision_score, recall_score, f1_score)
    precision, recall, f1 = (scorer(truth, guess, **options) for scorer in scorers)
    return precision, recall, f1

def show_reconstruction(x, x_recon):
    """Plot the first image of ``x`` next to the first image of ``x_recon`` in grayscale."""
    panels = (
        ("Original", x[0].cpu().squeeze()),
        ("Reconstruction", x_recon[0].cpu().detach().squeeze()),
    )
    for position, (caption, image) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.title(caption)
        plt.imshow(image, cmap='gray')
    plt.show()

In [None]:
def caps_loss(y_true, y_pred, x, x_recon, lam_recon):
    """
    Capsule loss = Margin loss + lam_recon * reconstruction loss.

    ``y_true`` (one-hot) and ``y_pred`` (per-class scores) must have the same
    shape. The margin term penalises scores below 0.9 for the true class and
    above 0.1 for the others (the latter weighted by 0.5). The reconstruction
    term is the mean squared error between ``x_recon`` and ``x``.
    """
    assert y_true.size() == y_pred.size(), f"Shape mismatch: {y_true.size()} vs {y_pred.size()}"
    y_true, y_pred, x, x_recon = (t.to(DEVICE) for t in (y_true, y_pred, x, x_recon))

    present_term = y_true * torch.clamp(0.9 - y_pred, min=0.) ** 2
    absent_term = 0.5 * (1 - y_true) * torch.clamp(y_pred - 0.1, min=0.) ** 2
    margin = (present_term + absent_term).sum(dim=1).mean()

    reconstruction = nn.functional.mse_loss(x_recon, x)
    return margin + lam_recon * reconstruction


def test(model, test_loader, cfg):
    """Evaluate ``model`` on ``test_loader`` without gradients.

    The model is called without labels, so it picks the class to reconstruct
    on its own. Reads ``cfg["classes"]`` and ``cfg["lam_recon"]``.

    Returns:
        ``(loss, accuracy, reconstruction_error, precision, recall, f1)``.
        Loss and accuracy are averaged over ``len(test_loader.dataset)``;
        the reconstruction error is the mean of the per-batch MSE values.
    """
    model.eval()
    loss_total = 0.0
    hits = 0
    batch_mse = []
    truth_chunks = []
    guess_chunks = []

    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            # Integer labels -> one-hot rows.
            y = torch.zeros(y.size(0), cfg["classes"], device=DEVICE).scatter_(1, y.view(-1, 1), 1.)
            y_pred, x_recon = model(x)

            # Loss, weighted by batch size so the final division gives a per-sample mean
            batch_loss = caps_loss(y, y_pred, x, x_recon, cfg["lam_recon"])
            loss_total += batch_loss.item() * x.size(0)

            # Classification accuracy
            y_pred = y_pred.data.max(1)[1]
            y_true = y.data.max(1)[1]
            hits += y_pred.eq(y_true).cpu().sum().item()

            # Reconstruction error, after mapping the reconstruction to [-1, 1]
            x_recon = 2 * x_recon - 1
            per_image = torch.mean((x - x_recon) ** 2, dim=(1, 2, 3))
            batch_mse.append(per_image.mean().item())

            # Keep labels for the aggregate metrics
            truth_chunks.append(y_true.cpu())
            guess_chunks.append(y_pred.cpu())

    n_samples = len(test_loader.dataset)
    test_loss = loss_total / n_samples
    test_acc = hits / n_samples
    test_recon_err = sum(batch_mse) / len(batch_mse)
    precision, recall, f1 = compute_metrics(torch.cat(truth_chunks), torch.cat(guess_chunks))
    return test_loss, test_acc, test_recon_err, precision, recall, f1

In [None]:
def train(model, train_loader, test_loader, cfg):
    """Train ``model`` with Adam on ``caps_loss`` and blend Hebbian updates into ``hebbsoftmax`` weights.

    For every batch the loss is back-propagated, then each parameter whose name
    contains "hebbsoftmax" is overwritten row by row: an intermediate step
    ``param - lr * (error.T @ h)`` is computed, and for every class present in
    the batch it is blended with a masked batch mean of ``h``. The optimizer
    step runs afterwards. After each epoch the learning rate is decayed, the
    model is evaluated with ``test``, and metrics are printed and logged to wandb.

    Args:
        model: Model called as ``model(x, y)`` returning ``(y_pred, x_recon)``.
        train_loader: Loader yielding ``(x, y)`` batches of images and integer labels.
        test_loader: Loader passed to ``test`` after every epoch.
        cfg: Dict providing "save_dir", "wb" ("project", "run"), "lr", "epochs",
            "batch_size", "lr_decay", "classes" and "lam_recon".

    Returns:
        The trained ``model``.
    """
    print('Training starts...')
    # NOTE(review): `writer` is not referenced again in this function — nothing is written to it and it is never closed.
    writer = SummaryWriter(log_dir=cfg["save_dir"])
    wandb.init(
        project=cfg["wb"]["project"],
        name=cfg["wb"]["run"],
        config={
            "learning_rate": cfg["lr"],
            "epochs": cfg["epochs"],
            "batch_size": cfg["batch_size"]
            })

    optimizer = Adam(model.parameters(), lr=cfg["lr"])
    lr_decay = lr_scheduler.ExponentialLR(optimizer, gamma=cfg["lr_decay"])

    # One full pass over the loader just to collect every label.
    # NOTE(review): the elements are whatever iterating a label batch yields (0-dim tensors for
    # tensor batches) — confirm init_hebb_param counts them by value rather than by identity.
    all_labels = [label for _, labels in train_loader for label in labels]
    gamma, T = init_hebb_param(all_labels, cfg["epochs"])
    # Running count of samples seen per class; drives the blending coefficient below.
    class_occurrences = {i: 0 for i in range(cfg["classes"])}

    model.to(DEVICE)
    # Filled by forward hooks with the latest output of every named sub-module.
    # NOTE(review): hooks are added on every call to train() and never removed.
    activations = {}
    register_hooks(model, activations)

    for epoch in range(cfg["epochs"]):
        debug_message(f"Epoch: {epoch}")
        model.train()
        # Learning rate currently held by the optimizer; reused for the manual update.
        lr = optimizer.param_groups[0]['lr']
        train_loss = 0.0
        correct = 0
        recon_mse = []
        y_true_list = []
        y_pred_list = []

        for _, (x, y) in enumerate(train_loader):
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y = torch.zeros(y.size(0), cfg["classes"], device=DEVICE).scatter_(1, y.view(-1, 1), 1.)  # change to one-hot coding
            debug_message(f"One-hot Encoding y: {y.shape}")

            optimizer.zero_grad()
            y_pred, x_recon = model(x, y)
            loss = caps_loss(y, y_pred, x, x_recon, cfg["lam_recon"])
            loss.backward()

            for name, param in model.named_parameters():
                debug_message(f"Layer: {name}, Shape of theta: {param.shape}")
                if "hebbsoftmax" in name:
                    with torch.no_grad():
                        # intermediate SGD update
                        # (the zeros tensor is immediately replaced a few lines below)
                        intermediate_update = torch.zeros_like(param.data)
                        # Output of the whole decoder sub-module, captured by the forward hook.
                        h = activations["decoder"]
                        debug_message(f"Layer: {name}, Shape of activation: {h.shape}")
                        error = y_pred - y
                        debug_message(f"Shape of error: {error.shape}")
                        grad = torch.matmul(error.T, h)
                        debug_message(f"Shape of gradient: {grad.shape}")
                        intermediate_update = param - lr * grad
                        debug_message(f"Shape of theta_0.5: {intermediate_update.shape}")

                        # Apply the blending mechanism
                        n_t = y.sum(dim=0)  # occurrences of each class
                        debug_message(f"Shape of n_t: {n_t.shape}")
                        for i in range(cfg["classes"]):
                            n_t_i = n_t[i].item()
                            if n_t_i > 0:
                                # Coefficient shrinks as 1/(count+1), floored at gamma, and drops to 0 once count reaches T.
                                lambda_t_i = max(1 / (class_occurrences[i] + 1), gamma) if class_occurrences[i] < T else 0
                                # NOTE(review): rows of other classes are zeroed but the mean still divides by the
                                # full batch size, not by n_t_i — confirm this is intended.
                                h_bar_t_i = (h * y[:, i][:, None]).mean(dim=0)
                                debug_message(f"Shape of h_bar_t_i: {h_bar_t_i.shape}")
                                blended_update = lambda_t_i * h_bar_t_i + (1 - lambda_t_i) * intermediate_update[i]
                                param.data[i].copy_(blended_update)
                            else:
                                # Class absent from this batch: plain intermediate step.
                                param.data[i].copy_(intermediate_update[i])
                            debug_message(f"Shape of updated theta: {param.shape}")
                            class_occurrences[i] += n_t_i  # increment class occurrence counter

            # Loss
            train_loss += loss.item() * x.size(0)

            # Classification Accuracy
            y_pred = y_pred.data.max(1)[1]
            y_true = y.data.max(1)[1]
            correct += y_pred.eq(y_true).cpu().sum().item()

            # Reconstruction Accuracy (MSE)
            x_recon = 2 * x_recon - 1  # Transform to [-1, 1]
            recon_error = torch.mean((x - x_recon) ** 2, dim=(1, 2, 3))
            recon_mse.append(recon_error.mean().item())

            # Metrics
            y_true_list.append(y_true.cpu())
            y_pred_list.append(y_pred.cpu())

            # NOTE(review): the optimizer steps after the manual hebbsoftmax overwrite above, using the
            # gradients from loss.backward(); hebbsoftmax weights are therefore updated twice per batch.
            optimizer.step()
        lr_decay.step()

        # Print All Metrics
        data_size = len(train_loader.dataset)
        train_loss /= data_size
        train_acc = correct / data_size
        # Mean of per-batch means (batches are not weighted by size).
        train_recon_err = sum(recon_mse) / len(recon_mse)
        precision, recall, f1 = compute_metrics(torch.cat(y_true_list), torch.cat(y_pred_list))

        # `test_recon_acc` holds the reconstruction *error* returned by test().
        test_loss, test_acc, test_recon_acc, test_precision, test_recall, test_f1 = test(model, test_loader, cfg)

        print('train: epoch = %d, loss = %.4f, classfication acc = %.4f, reconstruction err = %.4f' % (epoch, train_loss, train_acc, train_recon_err))
        print('                   precision = %.4f, recall = %.4f, f1-score = %.4f' % (precision, recall, f1))
        print('test: epoch = %d, loss = %.4f, classfication acc = %.4f, reconstruction err = %.4f' % (epoch, test_loss, test_acc, test_recon_acc))
        print('                  precision = %.4f, recall = %.4f, f1-score = %.4f' % (test_precision, test_recall, test_f1))
        wandb.log({
              "train_recon_err": train_recon_err,
              "train_loss": train_loss,
              "train_acc": train_acc,
              "train_f1": f1,
              "train_precision": precision,
              "train_recall": recall,
              "test_recon_err": test_recon_acc,
              "test_loss": test_loss,
              "test_acc": test_acc,
              "test_f1": test_f1,
              "test_precision": test_precision,
              "test_recall": test_recall,
          })
    print('Training Finished')
    wandb.finish()
    return model

In [None]:
def train_hebb(model, train_loader, test_loader, cfg):
    """Train ``model`` with Adam on ``caps_loss`` and blend Hebbian updates into decoder weights.

    For every batch the loss is back-propagated, then the weights of
    ``decoder.fc2`` and ``decoder.fc3`` are overwritten with a fixed blend
    (``lambda_t = 0.8``) of a batch-averaged outer product of the layer's input
    and output activations and an intermediate step ``param - lr * grad``.
    The optimizer step runs afterwards. After each epoch the learning rate is
    decayed, the model is evaluated with ``test``, and metrics are printed and
    logged to wandb.

    Args:
        model: Model called as ``model(x, y)`` returning ``(y_pred, x_recon)``; its hooks
            must produce activations named "decoder", "decoder.relu1" and "decoder.relu2".
        train_loader: Loader yielding ``(x, y)`` batches of images and integer labels.
        test_loader: Loader passed to ``test`` after every epoch.
        cfg: Dict providing "save_dir", "wb" ("project", "run"), "lr", "epochs",
            "batch_size", "lr_decay", "classes" and "lam_recon".

    Returns:
        The trained ``model``.
    """
    print('Training starts...')
    # NOTE(review): `writer` is not referenced again in this function — nothing is written to it and it is never closed.
    writer = SummaryWriter(log_dir=cfg["save_dir"])
    wandb.init(
        project=cfg["wb"]["project"],  # Replace with your project name
        name=cfg["wb"]["run"],             # Optional: Customize the run name
        config={                     # Optional: Hyperparameters or config
            "learning_rate": cfg["lr"],
            "epochs": cfg["epochs"],
            "batch_size": cfg["batch_size"]
                                     })
    optimizer = Adam(model.parameters(), lr=cfg["lr"])
    lr_decay = lr_scheduler.ExponentialLR(optimizer, gamma=cfg["lr_decay"])

    model.to(DEVICE)
    # Filled by forward hooks with the latest output of every named sub-module.
    # NOTE(review): hooks are added on every call and never removed.
    activations = {}
    register_hooks(model, activations)

    for epoch in range(cfg["epochs"]):
        debug_message(f"Epoch: {epoch}")
        model.train()
        # Learning rate currently held by the optimizer; reused for the manual update.
        lr = optimizer.param_groups[0]['lr']
        train_loss = 0.0
        correct = 0
        recon_mse = []
        y_true_list = []
        y_pred_list = []

        for _, (x, y) in enumerate(train_loader):
            x = x.to(DEVICE)
            y = y.to(DEVICE)
            y = torch.zeros(y.size(0), cfg["classes"], device=DEVICE).scatter_(1, y.view(-1, 1), 1.)  # change to one-hot coding
            debug_message(f"One-hot Encoding y: {y.shape}")

            optimizer.zero_grad()
            y_pred, x_recon = model(x, y)
            loss = caps_loss(y, y_pred, x, x_recon, cfg["lam_recon"])
            loss.backward()

            for name, param in model.named_parameters():
                debug_message(f"Layer: {name}, Shape of theta: {param.shape}")
                if "decoder.fc2.weight" in name or "decoder.fc3.weight" in name:
                    # h: activation feeding the layer (relu1 output for fc2, relu2 output for fc3).
                    h = activations["decoder.relu1"] if "decoder.fc2.weight" in name else activations["decoder.relu2"]
                    debug_message(f"Shape of h: {h.shape}")
                    # NOTE(review): despite the name, `error` is an activation (relu2 output for fc2,
                    # the decoder output for fc3), not a difference from a target.
                    error = activations["decoder.relu2"] if "decoder.fc2.weight" in name else activations["decoder"]
                    debug_message(f"Shape of error: {error.shape}")
                    with torch.no_grad():
                        # intermediate SGD update
                        # (the zeros tensor is immediately replaced a few lines below)
                        intermediate_update = torch.zeros_like(param.data)
                        grad = torch.matmul(error.T, h)
                        debug_message(f"Shape of gradient: {grad.shape}")
                        intermediate_update = param - lr * grad
                        debug_message(f"Shape of theta_0.5: {intermediate_update.shape}")

                        # Apply the blending mechanism
                        lambda_t = 0.8
                        # Batch mean of the outer product between `error` and `h`; same shape as the weight matrix.
                        h_bar_t = (h.unsqueeze(1) * error[:, :, None]).mean(dim=0)
                        debug_message(f"Shape of h_bar_t: {h_bar_t.shape}")
                        blended_update = lambda_t * h_bar_t + (1 - lambda_t) * intermediate_update
                        param.data.copy_(blended_update)
                        debug_message(f"Shape of updated theta: {param.shape}")

            # Loss
            train_loss += loss.item() * x.size(0)

            # Classification Accuracy
            y_pred = y_pred.data.max(1)[1]
            y_true = y.data.max(1)[1]
            correct += y_pred.eq(y_true).cpu().sum().item()

            # Reconstruction Accuracy (MSE)
            x_recon = 2 * x_recon - 1  # Transform to [-1, 1]
            recon_error = torch.mean((x - x_recon) ** 2, dim=(1, 2, 3))
            recon_mse.append(recon_error.mean().item())

            # Metrics
            y_true_list.append(y_true.cpu())
            y_pred_list.append(y_pred.cpu())

            # NOTE(review): the optimizer steps after the manual fc2/fc3 overwrite above, using the
            # gradients from loss.backward(); those weights are therefore updated twice per batch.
            optimizer.step()
        lr_decay.step()

        # Print All Metrics
        data_size = len(train_loader.dataset)
        train_loss /= data_size
        train_acc = correct / data_size
        # Mean of per-batch means (batches are not weighted by size).
        train_recon_err = sum(recon_mse) / len(recon_mse)
        precision, recall, f1 = compute_metrics(torch.cat(y_true_list), torch.cat(y_pred_list))

        # `test_recon_acc` holds the reconstruction *error* returned by test().
        test_loss, test_acc, test_recon_acc, test_precision, test_recall, test_f1 = test(model, test_loader, cfg)

        print('train: epoch = %d, loss = %.4f, classfication acc = %.4f, reconstruction err = %.4f' % (epoch, train_loss, train_acc, train_recon_err))
        print('                   precision = %.4f, recall = %.4f, f1-score = %.4f' % (precision, recall, f1))
        print('test: epoch = %d, loss = %.4f, classfication acc = %.4f, reconstruction err = %.4f' % (epoch, test_loss, test_acc, test_recon_acc))
        print('                  precision = %.4f, recall = %.4f, f1-score = %.4f' % (test_precision, test_recall, test_f1))
        wandb.log({
              "train_recon_err": train_recon_err,
              "train_loss": train_loss,
              "train_acc": train_acc,
              "train_f1": f1,
              "train_precision": precision,
              "train_recall": recall,

              "test_recon_err": test_recon_acc,
              "test_loss": test_loss,
              "test_acc": test_acc,
              "test_f1": test_f1,
              "test_precision": test_precision,
              "test_recall": test_recall,
          })
    print('Training Finished')
    wandb.finish()
    return model

## Baseline CNN Architecture ##
The following section creates a baseline CNN to compare with CapsNet.

In [None]:
class CNNAutoencoder(nn.Module):
    """Baseline CNN: a convolutional encoder shared by a classifier head and a reconstruction decoder.

    The encoder mirrors the convolution sizes used by the capsule network and
    ends in a 784-channel 6x6 convolution. For 28x28 inputs that leaves a 1x1
    map, so the flattened encoding has 784 features; other spatial sizes are
    not supported by the fixed-width heads.

    Args:
        input_size: ``(channels, height, width)`` of the input images.
        classes: Number of output classes.
        softmax: When true, classify with ``HebbianSoftmax``; otherwise with a
            linear layer followed by a softmax.
    """

    def __init__(self, input_size, classes, softmax=True):
        super(CNNAutoencoder, self).__init__()
        self.hebb = softmax
        self.input_size = input_size
        self.classes = classes

        # Width of the flattened encoder output for 28x28 inputs.
        encoded_dim = 784
        flat_pixels = input_size[0] * input_size[1] * input_size[2]

        self.encoder = nn.Sequential(
            # Take the channel count from input_size instead of assuming grayscale.
            nn.Conv2d(input_size[0], 256, kernel_size=9, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=9, stride=2, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, encoded_dim, kernel_size=6, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Flatten()
        )
        self.decoder = nn.Sequential(
            OrderedDict([
                ("fc1", nn.Linear(encoded_dim, 512)),
                ("relu1", nn.ReLU(inplace=True)),
                ("fc2", nn.Linear(512, 1024)),
                ("relu2", nn.ReLU(inplace=True)),
                ("fc3", nn.Linear(1024, flat_pixels)),
                ("sigmoid", nn.Sigmoid())
            ])
        )
        # Not used by forward(); kept so parameter names and initialisation
        # order stay compatible with previously saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(16 * classes, 128),
            nn.ReLU(),
        )
        # Both heads consume the encoder output, so they are sized by the
        # encoding width (equal to the pixel count only for 1x28x28 inputs).
        if softmax:
            debug_message(f"Input size: {self.input_size}")
            self.hebbsoftmax = HebbianSoftmax(encoded_dim, classes)
        else:
            # Output width follows `classes` rather than a fixed 10.
            self.softmax = nn.Sequential(nn.Linear(encoded_dim, classes),
                                         nn.Softmax(dim=-1))

        self.to(DEVICE)

    def forward(self, x, y=None):
        """Return ``(classification, reconstruction)`` for a batch of images.

        Args:
            x: Batch of images matching ``input_size``.
            y: Unused; accepted so the model can be called like ``CapsuleNet``.

        Returns:
            ``classification`` with one probability per class, and
            ``reconstruction`` reshaped to ``(-1, *input_size)``.
        """
        x = x.to(DEVICE)
        debug_message(f"Input shape: {x.shape}")

        encoded = self.encoder(x)
        debug_message(f"Encoder Output shape: {encoded.shape}")
        if self.hebb:
            classification = self.hebbsoftmax(encoded)
        else:
            classification = self.softmax(encoded)

        reconstruction = self.decoder(encoded)
        reconstruction = reconstruction.view(-1, *self.input_size)

        debug_message(f"Classifier Output shape: {classification.shape}")
        debug_message(f"Decoder Output shape: {reconstruction.shape}")

        return classification, reconstruction

## Running Experiments ##
The following section includes all the experiments to investigate the model performance, complexity, etc.

### Data ###

In [None]:
def create_subsets(train_dataset, test_dataset, subset_ratio):
    """Draw random subsets covering ``subset_ratio`` percent of each dataset.

    Indices are sampled without replacement, first for the training set and
    then for the test set.

    Returns:
        ``(train_subset, test_subset)`` as ``Subset`` objects.
    """
    fraction = subset_ratio / 100
    picked = []
    for dataset in (train_dataset, test_dataset):
        total = len(dataset)
        keep = int(total * fraction)
        indices = np.random.choice(total, keep, replace=False)
        picked.append(Subset(dataset, indices))

    train_subset, test_subset = picked
    return train_subset, test_subset


def load_MNIST(batch_size, subset_ratio=None):
    """Build shuffled train and test loaders for MNIST.

    Images are converted to single-channel tensors and normalised with mean
    0.5 and std 0.5. The data is downloaded to ``./data`` when missing.

    Args:
        batch_size: Batch size for both loaders.
        subset_ratio: Optional percentage of each split to sample at random;
            ``None`` uses the full splits.

    Returns:
        ``(train_loader, test_loader)``.
    """
    preprocess = transforms.Compose([
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # Full splits first; their sizes are reported before any sub-sampling.
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=preprocess)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=preprocess)
    print(f"Full training dataset size: {len(train_dataset)}")
    print(f"Full testing dataset size: {len(test_dataset)}")

    train_source, test_source = train_dataset, test_dataset
    if subset_ratio is not None:
        train_source, test_source = create_subsets(train_dataset, test_dataset, subset_ratio)

    train_loader = DataLoader(train_source, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_source, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader

In [None]:
def download_ImageNet():
    """Download and extract Tiny ImageNet into the current working directory.

    If ``./tiny-imagenet-200`` already exists nothing is downloaded.
    Otherwise the archive is fetched, extracted next to it, and the zip file
    is removed afterwards.

    Returns:
        Path of the ``tiny-imagenet-200`` directory.
    """
    # Standard-library downloader: the file's `import wget` is commented out,
    # so calling wget.download here would fail with a NameError.
    import urllib.request

    url = 'http://cs231n.stanford.edu/tiny-imagenet-200.zip'
    download_dir = os.getcwd()
    dataset_dir = os.path.join(download_dir, 'tiny-imagenet-200')
    if os.path.exists(dataset_dir):
        return dataset_dir

    print("Downloading Tiny ImageNet dataset...")
    tiny_imgdataset_path = os.path.join(download_dir, os.path.basename(url))
    urllib.request.urlretrieve(url, tiny_imgdataset_path)
    print(f"\nDownloaded to: {tiny_imgdataset_path}")

    print("Extracting dataset...")
    with ZipFile(tiny_imgdataset_path, 'r') as zip_ref:
        zip_ref.extractall(download_dir)
    print("Extraction complete.")

    os.remove(tiny_imgdataset_path)
    print(f"Deleted the zip file: {tiny_imgdataset_path}")
    return dataset_dir


class TinyImageNetLoader:
    """Build train and validation loaders from an extracted Tiny ImageNet directory.

    Expected layout under ``data_dir``: ``train/<class>/...`` (read with
    ``ImageFolder``), ``val/images/*`` and ``val/val_annotations.txt``
    (tab-separated, image name first and class id second).

    Args:
        data_dir: Root of the extracted dataset.
        batch_size: Batch size for both loaders.
        img_size: Size passed to ``transforms.Resize``.
        subset_ratio: Optional percentage of each split to sample at random.
    """

    def __init__(self, data_dir, batch_size=32, img_size=28, subset_ratio=None):
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.subset_ratio = subset_ratio
        self.img_size = img_size
        self.train_dir = os.path.join(data_dir, 'train')
        self.val_dir = os.path.join(data_dir, 'val')
        self.val_images_dir = os.path.join(self.val_dir, 'images')
        self.val_annotations_path = os.path.join(self.val_dir, 'val_annotations.txt')

    def _load_val_annotations(self):
        """Return a dict mapping validation image file names to class ids."""
        val_annotations = {}
        with open(self.val_annotations_path, 'r') as f:
            for line in f:
                parts = line.strip().split('\t')
                if len(parts) < 2:
                    # Blank or malformed line: nothing to record.
                    continue
                val_annotations[parts[0]] = parts[1]
        return val_annotations

    def _get_transforms(self):
        """Return the resize / tensor / normalise pipeline used for both splits."""
        return transforms.Compose([
            transforms.Resize(self.img_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _create_val_dataset(self, annotations):
        """Load every annotated validation image as RGB.

        Images that cannot be opened are reported and skipped.

        Returns:
            ``(images, labels)`` as parallel lists with one entry per loaded image.
        """
        val_images = []
        val_labels = []
        for img_name, true_label in annotations.items():
            img_path = os.path.join(self.val_images_dir, img_name)
            try:
                # convert() returns an independent copy, so the file handle can be closed right away.
                with Image.open(img_path) as raw:
                    img = raw.convert('RGB')
            except (UnidentifiedImageError, IOError) as e:
                print(f"Error opening image {img_path}: {e}")
                continue
            # Append exactly once per image; a second append would duplicate the whole validation set.
            val_images.append(img)
            val_labels.append(true_label)
        return val_images, val_labels

    def _encode_labels(self, val_labels):
        """Map class ids to integer indices based on the sorted sub-directories of ``train_dir``."""
        # Only directories count as classes, matching how ImageFolder discovers them.
        class_names = sorted(
            entry for entry in os.listdir(self.train_dir)
            if os.path.isdir(os.path.join(self.train_dir, entry))
        )
        label_mapping = {name: idx for idx, name in enumerate(class_names)}
        return [label_mapping[label] for label in val_labels]

    def load_data(self):
        """Return ``(train_loader, val_loader)`` for the dataset.

        The validation set is loaded fully into memory as a ``TensorDataset``.
        With ``subset_ratio`` set, both splits are sub-sampled and shuffled;
        otherwise only the training loader shuffles.
        """
        # Load train dataset using ImageFolder
        train_transform = self._get_transforms()
        train_dataset = ImageFolder(root=self.train_dir, transform=train_transform)

        # Load validation dataset manually
        val_annotations = self._load_val_annotations()
        val_images, val_labels = self._create_val_dataset(val_annotations)
        val_labels_encoded = self._encode_labels(val_labels)

        # Create validation dataset
        val_dataset = torch.utils.data.TensorDataset(
            torch.stack([train_transform(img) for img in val_images]),
            torch.tensor(val_labels_encoded)
        )

        # DataLoaders
        print(f"Full training dataset size: {len(train_dataset)}")
        print(f"Full testing dataset size: {len(val_dataset)}")
        if self.subset_ratio is not None:
            train_subset, val_subset = create_subsets(train_dataset, val_dataset, self.subset_ratio)
            train_loader = DataLoader(train_subset, batch_size=self.batch_size, shuffle=True)
            val_loader = DataLoader(val_subset, batch_size=self.batch_size, shuffle=True)
        else:
            train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)

        return train_loader, val_loader

In [None]:
def load_CIFAR(batch_size=32, img_size=28, subset_ratio=None):
    """Return shuffled train/test ``DataLoader`` objects for CIFAR-100.

    Every image is resized to ``img_size``, converted to a single grayscale
    channel, turned into a tensor and normalised with mean 0.5 / std 0.5.
    The data is downloaded into ``./data`` when it is not already there.

    Args:
        batch_size: Batch size used by both loaders.
        img_size: Size handed to ``transforms.Resize``.
        subset_ratio: When not ``None``, both splits are first reduced with
            ``create_subsets(train, test, subset_ratio)``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    preprocessing = [
        transforms.Resize(img_size),
        transforms.Grayscale(num_output_channels=1),  # convert to grayscale
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    transform = transforms.Compose(preprocessing)

    train_dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.CIFAR100(root='./data', train=False, download=True, transform=transform)

    print(f"Full training dataset size: {len(train_dataset)}")
    print(f"Full testing dataset size: {len(test_dataset)}")

    # Optionally shrink both splits before wrapping them in loaders.
    train_source, test_source = train_dataset, test_dataset
    if subset_ratio is not None:
        train_source, test_source = create_subsets(train_dataset, test_dataset, subset_ratio)

    train_loader = DataLoader(train_source, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_source, batch_size=batch_size, shuffle=True)
    return train_loader, test_loader

### Model Configurations ###

In [None]:
### Configuration
# Weights & Biases targets: one project, one run name per dataset.
wb = dict(project='CSC413', run='caps-hebb-softmax-lr0.001')
wb_ImageNet = dict(project='CSC413', run='ImageNet-caps-hebb-softmax-lr0.001')
wb_CIFAR = dict(project='CSC413', run='CIFAR-caps-hebb-softmax-lr0.001')

# Hyper-parameters that are identical for every dataset. Only the number of
# classes and the W&B target differ between the three configurations below.
_common_hparams = {
    "batch_size": 8,
    "epochs": 50,
    "lr": 0.001,
    "lr_decay": 0.9,
    "lam_recon": 0.0005 * 784,
    "routings": 3,
    "save_dir": "./log",
}

# Each configuration is an independent dict (the shared values are copied in).
cfg = {"classes": 10, **_common_hparams, "wb": wb}
cfg_ImageNet = {"classes": 200, **_common_hparams, "wb": wb_ImageNet}
cfg_CIFAR = {"classes": 100, **_common_hparams, "wb": wb_CIFAR}

# Logging directory
# exist_ok=True creates the directory only when it is missing, without the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(cfg["save_dir"], exist_ok=True)

### Datasets
# MNIST (only a subset of the data is used; see subset_ratio)
train_MNIST, test_MNIST = load_MNIST(cfg["batch_size"], subset_ratio=0.1)

# # TinyImageNet
# dataset_dir = download_ImageNet()
# loader = TinyImageNetLoader(data_dir=dataset_dir, batch_size=cfg_ImageNet["batch_size"], subset_ratio=0.1)
# train_ImageNet, test_ImageNet = loader.load_data()

# # CIFAR100
# train_CIFAR, test_CIFAR = load_CIFAR(cfg["batch_size"], subset_ratio=0.1)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9.91M/9.91M [00:00<00:00, 54.8MB/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28.9k/28.9k [00:00<00:00, 1.77MB/s]

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz





Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1.65M/1.65M [00:00<00:00, 14.1MB/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4.54k/4.54k [00:00<00:00, 9.37MB/s]

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw

Full training dataset size: 60000
Full testing dataset size: 10000





### Results ###

In [None]:
# Hyper-parameters shared by the three result configurations. Individual
# entries override "epochs" where they differ from the shared value.
_results_hparams = {
    "batch_size": 8,
    "epochs": 50,
    "lr": 0.001,
    "lr_decay": 0.9,
    "lam_recon": 0.0005 * 784,
    "routings": 3,
    "save_dir": "./log",
}

# MNIST runs for 10 epochs here; the other two keep the shared 50.
cfg = {"classes": 10, **_results_hparams, "epochs": 10, "wb": wb}
cfg_ImageNet = {"classes": 200, **_results_hparams, "wb": wb_ImageNet}
cfg_CIFAR = {"classes": 100, **_results_hparams, "wb": wb_CIFAR}

In [None]:
## MNIST
# Experiment driver: builds a CapsuleNet for 28x28 single-channel input and
# trains it with train_hebb on the MNIST loaders using `cfg`.
print("MNIST")
# print("CapsNet")
# Turns on the debug_message() prints (tensor shapes, hook notifications).
debug=True
capsnet = CapsuleNet(input_size=[1, 28, 28], classes=cfg["classes"], routings=cfg["routings"])
# train(capsnet, train_MNIST, test_MNIST, cfg)

# print("CapsNet + Hebbian")
# NOTE(review): if the plain train() call above is re-enabled, this call
# continues from the same, already-trained `capsnet` instance.
train_hebb(capsnet, train_MNIST, test_MNIST, cfg)

# NOTE(review): the disabled TinyImageNet experiment below passes the MNIST
# loaders (train_MNIST / test_MNIST) together with cfg_ImageNet; it presumably
# should use train_ImageNet / test_ImageNet -- confirm before re-enabling.
# ## TinyImageNet
# print("TinyImageNet-200")
# print("CapsNet")
# capsnet = CapsuleNet(input_size=[1, 28, 28], classes=cfg_ImageNet["classes"], routings=cfg_ImageNet["routings"])
# train(capsnet, train_MNIST, test_MNIST, cfg_ImageNet)

# print("CapsNet + Hebbian")
# train_hebb(capsnet, train_MNIST, test_MNIST, cfg_ImageNet)

# ## CIFAR100
# print("CIFAR100")
# print("CapsNet")
# capsnet = CapsuleNet(input_size=[1, 28, 28], classes=cfg_CIFAR["classes"], routings=cfg_CIFAR["routings"])
# train(capsnet, train_CIFAR, test_CIFAR, cfg_CIFAR)

# print("CapsNet + Hebbian")
# train_hebb(capsnet, train_CIFAR, test_CIFAR, cfg_CIFAR)

# ## Extra: CapsNet + HebbianSoftmax
# capsnet_softmax = CapsuleNet(input_size=[1, 28, 28], classes=cfg["classes"], routings=cfg["routings"], softmax=True)
# train(capsnet_softmax, train_MNIST, test_MNIST, cfg)

MNIST
Training starts...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Epoch: 0
One-hot Encoding y: torch.Size([8, 10])
Input shape: torch.Size([8, 1, 28, 28])
Hook triggered for conv1, activation shape: torch.Size([8, 256, 20, 20])
Conv1 Output shape: torch.Size([8, 256, 20, 20])
ReLU Ouput shape: torch.Size([8, 256, 20, 20])
PrimaryCapsule Input shape: torch.Size([8, 256, 20, 20])
Hook triggered for primarycaps.conv2d, activation shape: torch.Size([8, 256, 6, 6])
After Conv2D shape: torch.Size([8, 256, 6, 6])
PrimaryCapsule Output shape: torch.Size([8, 1152, 8])
Hook triggered for primarycaps, activation shape: torch.Size([8, 1152, 8])
DigitCapsule Input shape: torch.Size([8, 1152, 8])
After reshape for routing shape: torch.Size([8, 1, 1152, 8, 1])
x_hat shape: torch.Size([8, 10, 1152, 16])
Routing 1 coupling coefficients shape: torch.Size([8, 10, 1152])
Routing 2 coupling coefficients shape: torch.Size([8, 10, 1152])
Routing 3 coupling coefficients shape: torch.Size([8, 10, 1152])
DigitCapsule Output shape: torch.Size([8, 10, 1, 16])
Hook triggered for

VBox(children=(Label(value='0.264 MB of 0.264 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test_acc,▁▅▇▇███▇▇▇
test_f1,▁▅▇▇███▇▇▇
test_loss,█▄▂▂▁▁▁▁▁▁
test_precision,▁▄█▇███▇▇▇
test_recall,▁▅▇▇███▇▇▇
test_recon_err,█▄▃▃▂▂▁▁▁▁
train_acc,▁▅▆███████
train_f1,▁▅▆███████
train_loss,█▅▄▂▂▁▁▁▁▁
train_precision,▁▅▆███████

0,1
test_acc,0.8
test_f1,0.76
test_loss,0.84035
test_precision,0.78333
test_recall,0.8
test_recon_err,0.88992
train_acc,1.0
train_f1,1.0
train_loss,0.75131
train_precision,1.0


CapsuleNet(
  (conv1): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
  (primarycaps): PrimaryCapsule(
    (conv2d): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
  )
  (digitcaps): DigitCapsule()
  (decoder): Sequential(
    (fc1): Linear(in_features=160, out_features=512, bias=True)
    (relu1): ReLU(inplace=True)
    (fc2): Linear(in_features=512, out_features=1024, bias=True)
    (relu2): ReLU(inplace=True)
    (fc3): Linear(in_features=1024, out_features=784, bias=True)
    (sigmoid): Sigmoid()
  )
)

In [None]:
## CNN autoencoder
# Silence the debug_message() prints for these runs.
debug = False

# Weights & Biases targets for the plain CNN and the Hebbian-softmax variant.
cnn_wb = dict(project='CSC413', run='cnn-lr0.001')
cnn_hebb_wb = dict(project='CSC413', run='cnn-hebb-softmax-lr0.001')

# Both CNN configurations share every hyper-parameter; only "wb" differs.
_cnn_hparams = {
    "classes": 10,
    "batch_size": 8,
    "epochs": 50,
    "lr": 0.001,
    "lr_decay": 0.9,
    "lam_recon": 0.0005 * 784,
    "routings": 3,
    "save_dir": "./log",
}
cnn_cfg = {**_cnn_hparams, "wb": cnn_wb}
cnn_hebb_cfg = {**_cnn_hparams, "wb": cnn_hebb_wb}
# Baseline CNN autoencoder, then the same architecture with softmax=True.
# Each model takes its class count from the configuration it is trained with,
# so the model and its training run cannot drift apart if the unrelated
# global `cfg` is edited or redefined in another cell.
cnn = CNNAutoencoder(input_size=[1, 28, 28], classes=cnn_cfg["classes"], softmax=False)
train(cnn, train_MNIST, test_MNIST, cnn_cfg)

cnn_hebb = CNNAutoencoder(input_size=[1, 28, 28], classes=cnn_hebb_cfg["classes"], softmax=True)
train(cnn_hebb, train_MNIST, test_MNIST, cnn_hebb_cfg)

Training starts...


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


train: epoch = 0, loss = 1.5359, classfication acc = 0.0667, reconstruction err = 0.5224
                   precision = 0.0201, recall = 0.0667, f1-score = 0.0257
test: epoch = 0, loss = 1.2130, classfication acc = 0.3000, reconstruction err = 0.4695
                  precision = 0.0900, recall = 0.3000, f1-score = 0.1385
train: epoch = 1, loss = 1.4573, classfication acc = 0.1000, reconstruction err = 0.4575
                   precision = 0.0100, recall = 0.1000, f1-score = 0.0182
test: epoch = 1, loss = 1.2133, classfication acc = 0.3000, reconstruction err = 0.4395
                  precision = 0.0900, recall = 0.3000, f1-score = 0.1385
train: epoch = 2, loss = 1.4577, classfication acc = 0.1000, reconstruction err = 0.4425
                   precision = 0.0100, recall = 0.1000, f1-score = 0.0182
test: epoch = 2, loss = 1.2133, classfication acc = 0.3000, reconstruction err = 0.4678
                  precision = 0.0900, recall = 0.3000, f1-score = 0.1385
train: epoch = 3, loss = 1.4

0,1
test_acc,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_f1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss,▁███████████████████████████████████████
test_precision,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_recall,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_recon_err,▆▃▅▅▄▁▃▆▇▆▃█▃▄▂▃▅▅▇▁▃▆▆▄▅▄▂▃▃▆▆▁▁▄▂▃▃▅▆▃
train_acc,▁███████████████████████████████████████
train_f1,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_precision,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_acc,0.3
test_f1,0.13846
test_loss,1.21332
test_precision,0.09
test_recall,0.3
test_recon_err,0.43637
train_acc,0.1
train_f1,0.01818
train_loss,1.45769
train_precision,0.01


[34m[1mwandb[0m: Currently logged in as: [33methanoate[0m ([33mtrimunculo-oxygenate[0m). Use [1m`wandb login --relogin`[0m to force relogin


Training starts...


train: epoch = 0, loss = 1.4525, classfication acc = 0.1333, reconstruction err = 0.5308
                   precision = 0.1364, recall = 0.1333, f1-score = 0.1106
test: epoch = 0, loss = 1.5772, classfication acc = 0.0000, reconstruction err = 0.4129
                  precision = 0.0000, recall = 0.0000, f1-score = 0.0000
train: epoch = 1, loss = 1.4972, classfication acc = 0.0667, reconstruction err = 0.4427
                   precision = 0.0044, recall = 0.0667, f1-score = 0.0083
test: epoch = 1, loss = 1.5772, classfication acc = 0.0000, reconstruction err = 0.4793
                  precision = 0.0000, recall = 0.0000, f1-score = 0.0000
train: epoch = 2, loss = 1.4156, classfication acc = 0.1167, reconstruction err = 0.4538
                   precision = 0.0231, recall = 0.1167, f1-score = 0.0383
test: epoch = 2, loss = 1.5772, classfication acc = 0.0000, reconstruction err = 0.4344
                  precision = 0.0000, recall = 0.0000, f1-score = 0.0000
train: epoch = 3, loss = 1.4

0,1
test_acc,▁▁▁▅▅█▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
test_f1,▁▁▁▃▃█▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
test_loss,███▄▅▁▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
test_precision,▁▁██████████████████████████████████████
test_recall,▁▁██████████████████████████████████████
test_recon_err,▂▆▃▁▅▇▃▄▄▄▆▇▆▂▅▆█▄▃▅▃▄▆▆▃▆▇▄▆▆▆▄▁▄▁▅▄▃▆▃
train_acc,█▁▆▄▄▃▁▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
train_f1,█▁▃▂▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train_loss,▄█▁▃▆▆▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
train_precision,█▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test_acc,0.1
test_f1,0.01818
test_loss,1.45571
test_precision,0.01
test_recall,0.1
test_recon_err,0.43173
train_acc,0.11667
train_f1,0.02438
train_loss,1.43649
train_precision,0.01361


CNNAutoencoder(
  (encoder): Sequential(
    (0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(256, 256, kernel_size=(9, 9), stride=(2, 2))
    (3): ReLU(inplace=True)
    (4): Conv2d(256, 784, kernel_size=(6, 6), stride=(1, 1))
    (5): ReLU(inplace=True)
    (6): Flatten(start_dim=1, end_dim=-1)
  )
  (decoder): Sequential(
    (fc1): Linear(in_features=784, out_features=512, bias=True)
    (relu1): ReLU(inplace=True)
    (fc2): Linear(in_features=512, out_features=1024, bias=True)
    (relu2): ReLU(inplace=True)
    (fc3): Linear(in_features=1024, out_features=784, bias=True)
    (sigmoid): Sigmoid()
  )
  (classifier): Sequential(
    (0): Linear(in_features=160, out_features=128, bias=True)
    (1): ReLU()
  )
  (hebbsoftmax): HebbianSoftmax(
    (linear): Linear(in_features=784, out_features=10, bias=False)
    (softmax): Softmax(dim=-1)
  )
)

## Testing Other Ideas ##

### Weight Update Through Genetic Algorithm ###

In [None]:
%pip install pygad
import pygad

Collecting pygad
  Downloading pygad-3.3.1-py3-none-any.whl.metadata (19 kB)
Downloading pygad-3.3.1-py3-none-any.whl (84 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m6.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pygad
Successfully installed pygad-3.3.1


In [None]:
def fitness_function(ga_instance, solution, solution_idx):
    """Evaluate the fitness of a chromosome.

    The flat ``solution`` vector is cut into consecutive slices, in
    ``model.named_parameters()`` order, and written into the parameters of the
    module-level ``model``. The model is then scored on the module-level
    ``train_loader`` with ``caps_loss``.

    Args:
        ga_instance: The GA instance invoking the callback (unused).
        solution: 1-D sequence of candidate weights, one entry per parameter
            element of ``model``.
        solution_idx: Index of the solution within the population (unused).

    Returns:
        The negated mean training loss, so that a lower loss is a higher
        fitness.
    """
    # `model` must only be read here, never assigned: an assignment anywhere in
    # the function would make the name local and the loop below would fail
    # with UnboundLocalError before the global model is ever reached.
    start = 0
    for _, param in model.named_parameters():
        numel = param.numel()
        weights = torch.tensor(solution[start:start + numel], dtype=torch.float32)
        param.data = weights.view(param.shape).to(DEVICE)
        start += numel

    # nn.Module.to() converts the module in place, so no rebinding is needed.
    model.to(torch.float32).to(DEVICE)

    train_loss = 0.0
    with torch.no_grad():
        for x, y in train_loader:
            x = x.float().to(DEVICE)
            # scatter_ needs an integer (int64) index tensor, so the class
            # labels must stay integral rather than being cast to float.
            labels = y.long().to(DEVICE)
            # One-hot encoding, sized from the configuration instead of a literal 10.
            y = torch.zeros(labels.size(0), cfg["classes"], device=DEVICE).scatter_(1, labels.view(-1, 1), 1.)
            y_pred, x_recon = model(x)
            loss = caps_loss(y, y_pred, x, x_recon, cfg["lam_recon"])
            # Weight each batch by its size so the final value is a per-sample mean.
            train_loss += loss.item() * x.size(0)

    train_loss /= len(train_loader.dataset)
    return -train_loss

In [None]:
# PyGAD setup
# NOTE(review): this cell reads the global `model`, which is only created in a
# later cell; run in notebook order it fails with
# "NameError: name 'model' is not defined" (see the recorded output below).
# The model must be constructed before this cell is executed.
num_generations = 50
num_parents_mating = 5
population_size = 10
# Flatten every model parameter into one vector. In this cell the vector is
# used only to determine how many genes a chromosome needs.
initial_weights = torch.cat([param.detach().view(-1) for param in model.parameters()]).cpu().numpy()
num_weights = initial_weights.size

# Initialize PyGAD
# fitness_function is the callback defined above; it returns the negated
# training loss of the model loaded with a candidate weight vector.
ga_instance = pygad.GA(
    num_generations=num_generations,
    num_parents_mating=num_parents_mating,
    fitness_func=fitness_function,
    sol_per_pop=population_size,
    num_genes=num_weights,
    init_range_low=-1.0,  # Weight initialization range
    init_range_high=1.0,
    mutation_probability=0.1,
)

NameError: name 'model' is not defined

In [None]:
class CapsuleNet(nn.Module):
    """Capsule network with a fully connected reconstruction decoder.

    Pipeline: 9x9 convolution + ReLU -> ``PrimaryCapsule`` -> ``DigitCapsule``
    -> three-layer decoder ending in a sigmoid. The convolution, the primary
    capsule layer and every decoder ``Linear`` are created with ``bias=False``.

    Args:
        input_size: ``[channels, height, width]`` of one input image.
        classes: Number of output capsules (one per class).
        routings: Routing iteration count forwarded to ``DigitCapsule``.
        hebb: Stored on the instance as ``self.hebb``; not otherwise used
            inside this class.
    """

    def __init__(self, input_size, classes, routings, hebb=True):
        super().__init__()
        self.input_size = input_size
        self.classes = classes
        self.routings = routings
        self.hebb = hebb

        channels = input_size[0]
        flat_image = input_size[0] * input_size[1] * input_size[2]

        # Layers are built in this exact order; it fixes both the parameter
        # ordering of the module and the order of random initialisation.
        self.conv1 = nn.Conv2d(channels, 256, kernel_size=9, stride=1, padding=0, bias=False)
        self.primarycaps = PrimaryCapsule(256, 256, 8, kernel_size=9, stride=2, padding=0, bias=False)
        self.digitcaps = DigitCapsule(
            in_num_caps=32 * 6 * 6,
            in_dim_caps=8,
            out_num_caps=classes,
            out_dim_caps=16,
            routings=routings,
        )

        decoder_layers = OrderedDict([
            ("fc1", nn.Linear(16 * classes, 512, bias=False)),
            ("relu1", nn.ReLU(inplace=True)),
            ("fc2", nn.Linear(512, 1024, bias=False)),
            ("relu2", nn.ReLU(inplace=True)),
            ("fc3", nn.Linear(1024, flat_image, bias=False)),
            ("sigmoid", nn.Sigmoid()),
        ])
        self.decoder = nn.Sequential(decoder_layers)
        self.to(DEVICE)

    def forward(self, x, y=None):
        """Return ``(length, reconstruction)`` for a batch ``x``.

        ``length`` is the norm of the digit-capsule output over its last
        dimension. ``y`` is the mask applied to the capsule output before
        decoding; when it is omitted, a one-hot mask selecting the capsule
        with the greatest length is used. The reconstruction is reshaped to
        ``(-1, *self.input_size)``.
        """
        x = x.to(DEVICE)
        x = nn.functional.relu(self.conv1(x))
        x = self.digitcaps(self.primarycaps(x))
        length = x.norm(dim=-1)

        if y is None:
            # No mask supplied: keep only the longest capsule per sample.
            _, index = length.max(dim=1)
            y = torch.zeros(length.size(), device=DEVICE)
            y.scatter_(1, index.view(-1, 1), 1.)

        y = y.to(DEVICE)
        masked = x * y[:, :, None]
        reconstruction = self.decoder(masked.view(x.size(0), -1))
        return length, reconstruction.view(-1, *self.input_size)


class PrimaryCapsule(nn.Module):
    """Convolution whose output is regrouped into capsules and squashed.

    Args:
        in_channels: Input channels of the convolution.
        out_channels: Output channels of the convolution.
        dim_caps: Length of each capsule vector the output is regrouped into.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        bias: Whether the convolution has a bias term.
    """

    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride=1, padding=0, bias=True):
        super(PrimaryCapsule, self).__init__()
        self.dim_caps = dim_caps
        self.conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        # Move to the target device only after the convolution is registered;
        # calling .to() on a module with no parameters yet does nothing.
        self.to(DEVICE)

    def forward(self, x):
        """Return the squashed capsules, shaped ``(batch, -1, dim_caps)``."""
        # Tensor.to() returns a new tensor rather than moving in place, so the
        # result has to be kept.
        x = x.to(DEVICE)
        outputs = self.conv2d(x)
        outputs = outputs.view(x.size(0), -1, self.dim_caps)
        return squash(outputs)

In [None]:
# Create the globals that fitness_function reads (`train_loader`, `model`),
# then run the genetic algorithm configured in `ga_instance`.
# NOTE(review): `load_data` is called here as a module-level function, while
# the MNIST loaders elsewhere in this notebook come from
# load_MNIST(batch_size, subset_ratio=...); confirm that a free function named
# load_data exists, or switch to the intended loader.
train_loader, test_loader = load_data(cfg["batch_size"], subset_ratio=0.1)
# NOTE(review): `ga_instance` is built in an earlier cell that itself needs
# `model` (for the gene count), so the model has to exist before that cell runs.
model = CapsuleNet(input_size=[1, 28, 28], classes=cfg["classes"], routings=cfg["routings"], hebb=False)

# Run the genetic algorithm
ga_instance.run()

# Extract the best solution
# The third returned value is not needed here and is discarded.
best_solution, best_solution_fitness, _ = ga_instance.best_solution()
print("Best Fitness Achieved:", best_solution_fitness)

In [None]:
# Load the best solution into the model for testing
# The flat vector is sliced in named_parameters() order, mirroring how
# fitness_function writes candidate weights into the model.
start = 0
for name, param in model.named_parameters():
    numel = param.numel()
    # Cast explicitly to float32, as fitness_function does: without a dtype
    # the tensor inherits the dtype of the solution array (typically float64),
    # which would leave the model with weights that no longer match its
    # float32 inputs.
    param.data = torch.tensor(best_solution[start:start + numel], dtype=torch.float32).view(param.shape).to(DEVICE)
    start += numel

# Test the model with the optimized weights
test_loss, test_acc = test(model, test_loader, cfg)
print(f"Test Loss: {test_loss}, Test Accuracy: {test_acc}")