In [None]:
# Clone github repository
!git clone https://github.com/AlessandroMaini/federated-learning-project.git

In [None]:
# Work from the repository root so the project packages imported in the next cell
# (data, models, tools, eval, train) can be resolved.
%cd federated-learning-project

In [None]:
import os
import random

import torch
import torch.nn as nn
import torch.optim as optim

from data.cifar100_loader import get_cifar100_loaders
from models.prepare_model import get_frozen_dino_vits16_model
from tools.hyperparameter_tuning import run_grid_search
from eval import evaluate
from train import train

In [None]:
# Fix the random seeds so torch-based randomness (presumably data shuffling/splitting in the
# loaders and any weight initialisation) is repeatable across Restart & Run All.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, including CUDA

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Training checkpoints are written here; exist_ok makes re-running this cell safe.
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Classification loss shared by the grid search, training and evaluation cells.
criterion = nn.CrossEntropyLoss()

# Train / validation / test loaders built with get_cifar100_loaders' default arguments.
train_loader, val_loader, test_loader = get_cifar100_loaders()

In [None]:
# Hyperparameter grid search on the train/validation split.
# best_cfg is later indexed positionally as (lr, momentum, T_max); best_model_state is a
# state dict that is loaded back into a fresh model in the next cell.
best_cfg, best_model_state = run_grid_search(
    train_loader,
    val_loader,
    get_frozen_dino_vits16_model,  # model factory; presumably called as factory(device)
    criterion,
    device,
)

In [None]:
# Build a fresh model on `device` and restore the weights selected by the grid search.
# strict=True (the default) requires every key to match; the returned key report is
# left as the cell's displayed output.
model = get_frozen_dino_vits16_model(device)
model.load_state_dict(best_model_state, strict=True)

In [None]:
# SGD with momentum using the learning rate and momentum picked by the grid search
# (best_cfg is indexed positionally as (lr, momentum, T_max)).
WEIGHT_DECAY = 5e-4  # L2 penalty, fixed here rather than taken from best_cfg
optimizer = optim.SGD(
    model.parameters(),
    lr=best_cfg[0],
    momentum=best_cfg[1],
    weight_decay=WEIGHT_DECAY,
)

# NOTE(review): the scheduler is stepped once per epoch in the training loop, which runs for
# 50 epochs. CosineAnnealingLR keeps following the cosine past T_max, so if best_cfg[2] is
# smaller than the number of training epochs the learning rate rises again — confirm that
# T_max is meant to match the retraining length.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=best_cfg[2])

In [None]:
# Optionally resume from a checkpoint written by the training loop. To resume, set
# RESUME_CHECKPOINT to a path such as os.path.join(CHECKPOINT_DIR, "dino_vits16_epoch30.pth");
# leave it as None to start from the grid-search weights.
RESUME_CHECKPOINT = None

start_epoch = 0
if RESUME_CHECKPOINT is not None:
    checkpoint = torch.load(RESUME_CHECKPOINT, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # The training loop stores the number of completed epochs under 'epoch'.
    start_epoch = checkpoint['epoch']

In [None]:
# Train on the full training set (train + val): val_split=0.0 presumably disables the
# validation split — confirm in data/cifar100_loader.py.
# The unused validation loader is bound to a named throwaway rather than `_`, because
# assigning `_` in IPython stops it from tracking the last cell output.
full_train_loader, _unused_val_loader, test_loader = get_cifar100_loaders(val_split=0.0)

In [None]:
# Train on train+val for `num_epochs` epochs, logging test metrics after each epoch and
# periodically checkpointing model/optimizer/scheduler state.
# Keep `start_epoch` if the optional checkpoint-resume cell above already set it;
# otherwise this is a fresh run starting at epoch 0.
start_epoch = globals().get("start_epoch", 0)
num_epochs = 50
CHECKPOINT_EVERY = 10  # save every this many epochs, and always after the final one
end_epoch = start_epoch + num_epochs

hist_train_loss = []
hist_train_acc = []
hist_test_loss = []
hist_test_acc = []

# Exactly `num_epochs` iterations; `epoch` is 0-based and reported 1-based.
for epoch in range(start_epoch, end_epoch):
    train_loss, train_acc = train(model, full_train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()  # one scheduler step per epoch

    hist_train_loss.append(train_loss)
    hist_train_acc.append(train_acc)
    hist_test_loss.append(test_loss)
    hist_test_acc.append(test_acc)

    print(f"Epoch {epoch+1}/{end_epoch}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")

    completed = epoch + 1
    # Also save after the last epoch so the final weights are kept even when num_epochs
    # is not a multiple of CHECKPOINT_EVERY.
    if completed % CHECKPOINT_EVERY == 0 or completed == end_epoch:
        checkpoint = {
            'epoch': completed,  # completed epochs; a resumed run continues from here
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'dino_vits16_epoch{completed}.pth'))