In [None]:
import os
import re
import sys
sys.path.append("..")

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

import numpy as np
import pandas as pd
from typing import List
from tabulate import tabulate
from dataclasses import dataclass, field

# Trainer
from src.trainer.trainer import Trainer

# Models
from src.models.mlp import MLP
from src.models.p4_allcnn import P4AllCNNC
from src.models.fA_p4_allcnn import fA_P4AllCNNC
from src.models.p4m_allcnn import  P4MAllCNNC
from src.models.fA_p4m_allcnn import fA_P4MAllCNNC
from src.models.p4m_resnet import  P4MResNet
from src.models.fA_p4_resnet import fA_P4MResNet
from src.models.dbageconv import DBAGEConvNet

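## CIFAR-10 dataset

Load CIFAR-10 and hold out 10% of the training images as a validation set; the official test split is only used for the final evaluation.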

In [None]:
root_path = "../data"
# Fixed seed for the train/validation split. Resuming from a checkpoint in a new
# kernel must see the same split, otherwise validation images leak into training.
split_seed = 0

# Augmentation is applied to training images only.
transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor()
])
transform_test = transforms.ToTensor()

full_train = datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform_train)
# The same training images without augmentation. Subsets produced by random_split
# read through their parent dataset's transform, so slicing `full_train` alone
# would validate on randomly cropped/flipped images.
full_train_eval = datasets.CIFAR10(root=root_path, train=True, download=True, transform=transform_test)

train_size = int(0.9 * len(full_train))
val_size = len(full_train) - train_size
train_set, _val_split = random_split(
    full_train,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(split_seed),
)
# Reuse exactly the held-out indices, but read them through the un-augmented view.
val_set = torch.utils.data.Subset(full_train_eval, _val_split.indices)
test_set = datasets.CIFAR10(root=root_path, train=False, download=True, transform=transform_test)

num_workers = 4
batch_size = 128
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True)

## Experiment

In [None]:
# ----- Helper Functions -----
def init_model(name):
    """Construct a fresh model instance from its short registry name.

    Args:
        name: one of "p4_allcnn", "fA_p4_allcnn", "p4m_allcnn", "fA_p4m_allcnn",
            "p4m_resnet", "fA_p4_resnet" or "dbageconv".

    Returns:
        A newly constructed model, built with its constructor's defaults.

    Raises:
        ValueError: if ``name`` matches none of the known models.
    """
    # Each factory is a lambda so the class is only looked up and constructed
    # when its name is actually requested. Candidates are tried in this order.
    # Note: the key "fA_p4_resnet" builds fA_P4MResNet.
    factories = (
        ("p4_allcnn", lambda: P4AllCNNC()),
        ("fA_p4_allcnn", lambda: fA_P4AllCNNC()),
        ("p4m_allcnn", lambda: P4MAllCNNC()),
        ("fA_p4m_allcnn", lambda: fA_P4MAllCNNC()),
        ("p4m_resnet", lambda: P4MResNet()),
        ("fA_p4_resnet", lambda: fA_P4MResNet()),
        ("dbageconv", lambda: DBAGEConvNet()),
    )
    for known_name, build in factories:
        if name == known_name:
            return build()
    raise ValueError(f"Unknown model name: {name}")

def init_optimizer(model, lr, weight_decay, momentum=0.9):
    """Build an SGD optimizer over all parameters of ``model``.

    Args:
        model: module whose ``parameters()`` are optimized.
        lr: initial learning rate.
        weight_decay: weight-decay (L2) coefficient passed to SGD.
        momentum: SGD momentum. Defaults to 0.9, the value that used to be
            hard-coded, so existing three-argument calls behave as before.

    Returns:
        A ``torch.optim.SGD`` instance.
    """
    return optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)

def init_scheduler(optimizer, milestones, gamma=0.1):
    """Build a step-decay LR scheduler for ``optimizer``.

    The learning rate is multiplied by ``gamma`` each time the scheduler's epoch
    counter reaches one of ``milestones``. Previously the ``milestones`` argument
    was ignored and ``[200, 250, 300]`` was always used.

    Args:
        optimizer: the optimizer whose learning rate is scheduled.
        milestones: iterable of epoch indices at which to decay the LR. An empty
            iterable means the LR is never decayed.
        gamma: multiplicative decay factor. Defaults to 0.1, the value that used
            to be hard-coded.

    Returns:
        A ``torch.optim.lr_scheduler.MultiStepLR`` instance.
    """
    return optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=gamma)

# ----- Helper Classes -----
@dataclass
class HyperParams:
    """Per-model training hyperparameters for the experiment loop.

    The generated ``__init__`` takes the fields in declaration order:
    (lr, epochs, weight_decay, momentum, gamma, milestones=[]).
    """
    lr: float  # initial learning rate handed to init_optimizer
    epochs: int  # total epochs the model should reach; a resumed checkpoint's epoch is subtracted from this
    weight_decay: float  # weight decay handed to init_optimizer
    momentum: float  # intended SGD momentum; NOTE(review): the training loop does not forward this to init_optimizer
    gamma: float  # intended LR decay factor; NOTE(review): the training loop does not forward this to init_scheduler
    milestones: List[int] = field(default_factory=list)  # handed to init_scheduler as LR-decay epochs; fresh list per instance

In [None]:
# ----- Configuration -----
num_iterations = 1
log_dir = "../logs"


def _make_hparams(epochs, lr=0.01, weight_decay=1e-3, milestones=(200, 250, 300)):
    """Return a HyperParams using the settings shared by every model below.

    Momentum (0.9) and gamma (0.1) are the same for all entries. Each call gets
    its own fresh ``milestones`` list, so entries never share a mutable list.
    """
    return HyperParams(
        lr=lr,
        epochs=epochs,
        weight_decay=weight_decay,
        momentum=0.9,
        milestones=list(milestones),
        gamma=0.1,
    )


# Dict insertion order is the order in which models are trained and reported.
model_hyperparameters = {
    "dbageconv": _make_hparams(300),
    "p4_allcnn": _make_hparams(350),
    "fA_p4_allcnn": _make_hparams(85),
    "p4m_allcnn": _make_hparams(50),
    "fA_p4m_allcnn": _make_hparams(45),
    "p4m_resnet": _make_hparams(45, lr=0.05, weight_decay=0.0, milestones=(50, 100, 150)),
    "fA_p4_resnet": _make_hparams(30, lr=0.05, weight_decay=0.0, milestones=(50, 100, 150)),
}
model_names = model_hyperparameters.keys()
# One list of test accuracies per model, appended to once per iteration.
accuracies = {model_key: [] for model_key in model_names}

In [None]:
# Folder holding "<name>_epoch<N>.pt" model state_dict checkpoints; created if missing.
checkpoint_dir = "../checkpoints"
os.makedirs(name=checkpoint_dir, exist_ok=True)

def get_latest_checkpoint(model_name, directory=None):
    """Find the highest-epoch checkpoint file for ``model_name``.

    Checkpoint files are named ``"<model_name>_epoch<N>.pt"``. Only file names
    that match that pattern exactly are considered. Leftovers such as
    ``"x_epoch10.pt.tmp"`` or ``"x_epoch10.pth"`` are ignored; before this fix
    the unescaped ``.`` and prefix-only ``match`` accepted them.

    Args:
        model_name: checkpoint key, matched literally (not as a regex).
        directory: folder to scan. Defaults to the notebook-level ``checkpoint_dir``.

    Returns:
        ``(epoch, filepath)`` for the largest epoch found, or ``None`` if no file
        matches.

    Raises:
        FileNotFoundError: if ``directory`` does not exist (from ``os.listdir``).
    """
    if directory is None:
        directory = checkpoint_dir
    pattern = re.compile(re.escape(model_name) + r"_epoch(\d+)\.pt")
    checkpoints = [
        (int(m.group(1)), os.path.join(directory, fname))
        for fname in os.listdir(directory)
        if (m := pattern.fullmatch(fname))
    ]
    # Tuples compare by epoch first, so max() picks the latest epoch.
    return max(checkpoints) if checkpoints else None

# ----- Main Training Loop -----
# Checkpoints are keyed per run, so each of the `num_iterations` runs trains an
# independent model. Run 0 keeps the plain "<name>" key, so checkpoints written by
# earlier versions of this notebook are still picked up. Later runs use
# "<name>_run<it>". Previously every iteration after the first found run 0's
# checkpoint, skipped training and re-reported the same accuracy, which made the
# reported std meaningless.
for it in range(num_iterations):
    print(f"Iteration {it + 1}/{num_iterations}")

    for name in model_names:
        run_name = name if it == 0 else f"{name}_run{it}"
        print(f"\n→ Training model: {run_name}")

        # 1. Grab hyperparams for this model
        hp = model_hyperparameters[name]

        # 2. Resume from the latest checkpoint for this run, if one exists
        checkpoint_info = get_latest_checkpoint(run_name)
        model = init_model(name)
        start_epoch = 0

        if checkpoint_info:
            start_epoch, checkpoint_path = checkpoint_info
            print(f"→ Found checkpoint at epoch {start_epoch}: {checkpoint_path}")
            # The freshly built model is on the CPU. Mapping to CPU lets checkpoints
            # saved from a GPU load on CPU-only machines too.
            model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        else:
            print(f"→ No checkpoint for {run_name}. Training from scratch.")

        # 3. Setup training components
        # NOTE(review): only model weights are checkpointed (see step 4). The
        # optimizer and LR scheduler are rebuilt from scratch on resume, so a resumed
        # run presumably does not follow the same LR schedule as an uninterrupted one.
        # NOTE(review): hp.momentum and hp.gamma are not forwarded here.
        criterion = nn.CrossEntropyLoss()
        optimizer = init_optimizer(model, hp.lr, hp.weight_decay)
        scheduler = init_scheduler(optimizer, hp.milestones)

        trainer = Trainer(
            models={name: model},
            optimizers=[optimizer],
            criterions=[criterion],
            schedulers=[scheduler],
            log_dir=f"{log_dir}/{run_name}"
        )

        # 4. Train only the remaining epochs, then checkpoint the weights
        remaining_epochs = hp.epochs - start_epoch
        if remaining_epochs > 0:
            trainer.train(
                num_epochs=remaining_epochs,
                train_loader=train_loader,
                val_loader=val_loader
            )

            final_epoch = start_epoch + remaining_epochs  # equals hp.epochs
            new_checkpoint_path = os.path.join(checkpoint_dir, f"{run_name}_epoch{final_epoch}.pt")
            torch.save(model.state_dict(), new_checkpoint_path)
            print(f"✓ Saved checkpoint to {new_checkpoint_path}")
        else:
            print(f"✓ {run_name} already trained for {hp.epochs} epochs. Skipping.")

        # 5. Evaluate and store results. The Trainer is keyed by the base model name,
        # and accuracies collect one entry per run under that same name.
        test_acc = trainer.evaluate(test_loader=test_loader)[name]
        accuracies[name].append(test_acc)


In [None]:
# ----- Final Statistics -----
def _summarize_runs(model_name, run_accuracies):
    """Aggregate one model's per-iteration test accuracies into summary columns.

    NOTE(review): assumes each accuracy is a fraction in [0, 1] as returned by
    ``trainer.evaluate``. Confirm this, otherwise the percentages are off by 100x.
    The std is numpy's default population std (ddof=0).
    """
    mean_accuracy = float(np.mean(run_accuracies))
    spread = float(np.std(run_accuracies))
    # A fresh, untrained instance is built only to count its parameters.
    n_params = sum(p.numel() for p in init_model(model_name).parameters())
    return {
        "% Test error": (1 - mean_accuracy) * 100,
        "% std": spread * 100,
        "Num Parameters": n_params,
    }


final_stats = {
    model_key: _summarize_runs(model_key, run_accs)
    for model_key, run_accs in accuracies.items()
}

## Table generation

In [None]:
# One row per model, one column per summary statistic, rounded to 2 decimals.
df = pd.DataFrame.from_dict(final_stats, orient="index").round(2)
print("📊 Model Accuracy Summary in CIFAR10\n")
print(tabulate(df, headers="keys", tablefmt="fancy_grid"))