In [None]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import numpy as np
import pandas as pd
from typing import List
from tabulate import tabulate
from dataclasses import dataclass, field
import matplotlib.pyplot as plt

# Trainer
from src.trainer.trainer import Trainer

# Models
from src.models.chA_p4_cnn import A_Ch_P4CNN
from src.models.spA_p4_cnn import A_Sp_P4CNN
from src.models.fA_p4_allcnn import fA_P4AllCNNC
from src.models.big_cnn import B15_P4CNN, B11_P4CNN

# Data Utils
from src.datasets.rot_mnist_dataset import get_dataset

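## Rotated MNIST (rot-MNIST) dataset

Build the train / validation / test loaders and display one training sample as a sanity check.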

In [None]:
# Build the rot-MNIST train / validation / test loaders used throughout the notebook.
train_loader, val_loader, test_loader = get_dataset(
    batch_size=128,
    num_workers=2,
)

In [None]:
# Peek at one training sample to sanity-check what the loader yields.
# (matplotlib.pyplot is already imported as `plt` in the imports cell.)
images, labels = next(iter(train_loader))

print("Image shape:", images[0].shape)
# NOTE(review): assumes a single-channel image, so squeeze() leaves a 2-D
# (H, W) array that imshow can render — confirm against the dataset.
image = images[0].squeeze().cpu().numpy()
print("Squeezed shape:", image.shape)

# Plot with the explicit Figure/Axes API rather than the pyplot state machine.
fig, ax = plt.subplots()
ax.imshow(image, cmap="gray")
ax.set_title(f"Label: {labels[0].item()}")
ax.axis("off")
plt.show()


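## Experiment

Train each model several times from scratch, evaluate it on the test set after every run, and collect the per-run test accuracies.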

In [None]:
# ----- Helper Functions -----
def init_model(name):
    """Build a fresh, untrained model from its string identifier.

    Args:
        name: Model key; one of ``"big15_p4_cnn"``, ``"chA_p4_cnn"``,
            ``"spA_p4_cnn"``, ``"big11_p4_cnn"`` or ``"fA_p4_allcnn"``.

    Returns:
        A newly constructed instance of the matching model class, created
        with its default constructor arguments.

    Raises:
        ValueError: If ``name`` does not match any known model.
    """
    # Lambdas defer the class lookup until an entry matches, so only the
    # requested model class is ever touched.
    candidates = (
        ("big15_p4_cnn", lambda: B15_P4CNN()),
        ("chA_p4_cnn", lambda: A_Ch_P4CNN()),
        ("spA_p4_cnn", lambda: A_Sp_P4CNN()),
        ("big11_p4_cnn", lambda: B11_P4CNN()),
        ("fA_p4_allcnn", lambda: fA_P4AllCNNC()),
    )
    for key, build in candidates:
        if name == key:
            return build()
    raise ValueError(f"Unknown model name: {name}")

def init_optimizer(model, lr, weight_decay, momentum=0.9):
    """Create an SGD-with-momentum optimizer over all of ``model``'s parameters.

    Args:
        model: Module whose ``parameters()`` will be optimised.
        lr: Initial learning rate.
        weight_decay: Weight-decay (L2 penalty) coefficient passed to SGD.
        momentum: Momentum factor. Defaults to 0.9, the value that was
            previously hard-coded, so existing calls behave the same.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    return optim.SGD(
        model.parameters(),
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
    )

def init_scheduler(optimizer, milestones, gamma=0.1):
    """Create a step-wise learning-rate schedule for ``optimizer``.

    Fix: the ``milestones`` argument used to be ignored in favour of a
    hard-coded ``[200, 250, 300]``; it is now honoured.

    Args:
        optimizer: Optimizer whose learning rate will be decayed.
        milestones: Epoch indices at which the learning rate is decayed.
        gamma: Multiplicative decay factor applied at each milestone
            (defaults to 0.1, the value that was previously hard-coded).

    Returns:
        A ``torch.optim.lr_scheduler.MultiStepLR`` instance.
    """
    return optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# ----- Helper Classes -----
@dataclass
class HyperParams:
    """Training recipe for one model, read by the main training loop.

    Field order defines the positional constructor signature; keep it stable.
    """
    lr: float  # initial learning rate passed to SGD
    epochs: int  # passed as num_epochs to Trainer.train
    weight_decay: float  # weight-decay coefficient passed to SGD
    momentum: float  # SGD momentum factor
    gamma: float  # multiplicative LR decay factor used by MultiStepLR
    milestones: List[int] = field(default_factory=list)  # epochs at which MultiStepLR decays the LR; fresh empty list per instance

In [None]:
# ----- Configuration -----
num_iterations = 3  # independent training runs per model; the summary averages over them
log_dir = "../logs"

# Seed once so the whole sweep is reproducible. Each iteration still gets a
# different initialisation because the RNG state advances between runs.
# NOTE(review): anything random inside get_dataset ran in an earlier cell and
# is not covered by this seed.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# All models currently share one recipe. Every entry is its own HyperParams
# instance (with its own milestones list), so a single model can be tuned
# without affecting the others.
# NOTE(review): with epochs=100, none of the milestones (200, 250, 300) is
# ever reached, so MultiStepLR never decays the learning rate. Confirm whether
# this is intended or whether epochs/milestones should change.
_shared_hp = dict(lr=0.001, epochs=100, weight_decay=0.0001, momentum=0.9, gamma=0.1)
model_hyperparameters = {
    name: HyperParams(**_shared_hp, milestones=[200, 250, 300])
    for name in ("big15_p4_cnn", "chA_p4_cnn", "spA_p4_cnn", "big11_p4_cnn", "fA_p4_allcnn")
}
model_names = list(model_hyperparameters)  # stable, ordered snapshot of the model keys
accuracies = {name: [] for name in model_names}  # model name -> list of per-run test accuracies

In [None]:
# ----- Main Training Loop -----
# Train every model `num_iterations` times from scratch and record the test
# accuracy of each run so the summary can report mean and spread.
for it in range(num_iterations):
    print(f"Iteration {it + 1}/{num_iterations}")

    for name in model_names:
        print(f"\n→ Training model: {name}")

        # 1. Grab hyperparams for this model
        hp = model_hyperparameters[name]

        # 2. Build a fresh model, criterion, optimizer and scheduler for every
        #    run so no state carries over from a previous iteration.
        model = init_model(name)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            model.parameters(),
            lr=hp.lr,
            momentum=hp.momentum,
            weight_decay=hp.weight_decay,
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=hp.milestones,
            gamma=hp.gamma,
        )

        # 3. Wrap in our Trainer (single-model). Each run gets its own log
        #    sub-directory. Previously every iteration of a model logged to
        #    the same directory, so runs could not be told apart and,
        #    depending on Trainer, may have overwritten each other.
        trainer = Trainer(
            models={name: model},
            optimizers=[optimizer],
            criterions=[criterion],
            schedulers=[scheduler],
            log_dir=f"{log_dir}/{name}/run_{it + 1}",
        )

        # 4. Train & validate with the model-specific epoch count
        trainer.train(
            num_epochs=hp.epochs,
            train_loader=train_loader,
            val_loader=val_loader,
        )

        # 5. Evaluate on the test set and record this run's accuracy.
        #    Presumably evaluate returns {model_name: accuracy}; verify in Trainer.
        test_acc = trainer.evaluate(test_loader=test_loader)[name]
        accuracies[name].append(test_acc)

In [None]:
# ----- Final Statistics -----
def _summarise_runs(model_name, run_accuracies):
    """Collapse one model's per-run test accuracies into a summary row.

    NOTE(review): assumes accuracies are fractions in [0, 1], so that
    ``1 - mean`` is the error rate — confirm against Trainer.evaluate.
    """
    mean_accuracy = float(np.mean(run_accuracies))
    spread = float(np.std(run_accuracies))  # population std (ddof=0) across runs
    # A throw-away model instance is built only to count its parameters.
    n_params = sum(p.numel() for p in init_model(model_name).parameters())
    return {
        "% Test error": (1 - mean_accuracy) * 100,
        "% std": spread * 100,
        "Num Parameters": n_params,
    }


final_stats = {
    model_name: _summarise_runs(model_name, run_accuracies)
    for model_name, run_accuracies in accuracies.items()
}

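## Results table

Mean test error and its standard deviation across runs (both in %), plus each model's parameter count.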

In [None]:
# One row per model; round the percentages to 2 decimals for display.
df = pd.DataFrame.from_dict(final_stats, orient="index").round(2)
# Fix: the caption previously said "CIFAR10", but this experiment trains and
# tests on rot-MNIST. The table reports test error, not accuracy.
print("📊 Model Test Error Summary on rot-MNIST\n")
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))