In [None]:
"""Centralized scenario with Masked Aggregation.
Change this cell to run the scenario with different parameters."""

# Mask density: share of mask entries that are non-zero, written as a fraction
# (0.1 ~ 10% of entries kept). NOTE(review): assumes mask_calculator interprets
# `density` as a fraction in (0, 1] rather than a percentage — confirm.
density = 1e-1

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from data.cifar100_loader import get_cifar100_loaders
from model.prepare_model import get_dino_vits16_model, freeze_backbone, unfreeze_backbone, freeze_head, unfreeze_head
from model.hyperparameter_tuning import run_grid_search
from eval import evaluate
from train import train
from model.model_editing import mask_calculator
import matplotlib.pyplot as plt

In [None]:
# Set device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (CPU and all CUDA devices) so torch-based randomness in later
# cells (e.g. DataLoader shuffling, weight init) is repeatable across Restart & Run All.
# NOTE(review): Python's `random` / NumPy are not seeded here — if mask_calculator or
# the data loaders sample with them, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Directory for checkpoints (created if missing).
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Build the model with get_dino_vits16_model on `device`, then freeze its backbone for
# the warm-up phase (the backbone is unfrozen again before masked fine-tuning).
model = get_dino_vits16_model(device)
freeze_backbone(model)

In [None]:
# Cross-entropy objective, shared by warm-up and masked fine-tuning.
criterion = nn.CrossEntropyLoss()

# val_split=0.0 -> no held-out split: the whole training set (train + val) is used for
# training. The middle loader returned (presumably the empty validation one) is discarded.
full_train_loader, _, test_loader = get_cifar100_loaders(val_split=0.0)

In [None]:
# Warm-up: a brief pre-training run on the full training set, done before the mask
# is computed.
start_epoch, warmup_epochs = 0, 3
best_test_acc = 0.0  # not read by the warm-up loop; reset again before fine-tuning

# Per-epoch metric histories, filled by the warm-up loop and plotted afterwards.
warmup_train_loss, warmup_train_acc = [], []
warmup_test_loss, warmup_test_acc = [], []

In [None]:
# Hyper-parameters found during the earlier grid search, hard-coded here so the
# search (run_grid_search is imported but not called in this notebook) need not be re-run.
best_cfg = dict(lr=0.005, momentum=0.9)

In [None]:
# Optimizer and cosine LR schedule for the warm-up phase
# (T_max = warmup_epochs; the loop steps the scheduler once per epoch).
# Only parameters with requires_grad=True are handed to SGD. Frozen backbone tensors
# would gain nothing from being there, and could be moved by weight decay / momentum
# if they ever carried a zero (rather than None) .grad.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=best_cfg['lr'], momentum=best_cfg['momentum'], weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=warmup_epochs)

In [None]:
# Warm-up loop. Each epoch: train on the full training set, evaluate on the test set,
# step the cosine schedule, then record and print the metrics.
last_warmup_epoch = start_epoch + warmup_epochs
for epoch in range(start_epoch, last_warmup_epoch):
    train_loss, train_acc = train(model, full_train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()

    for history, value in (
        (warmup_train_loss, train_loss),
        (warmup_train_acc, train_acc),
        (warmup_test_loss, test_loss),
        (warmup_test_acc, test_acc),
    ):
        history.append(value)

    print(f"Epoch {epoch + 1}/{last_warmup_epoch}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")

# Persist the warmed-up weights; the mask stage reloads them from this file.
pretrained_path = os.path.join(CHECKPOINT_DIR, 'pre_trained_model_centralized.pth')
torch.save(model.state_dict(), pretrained_path)

In [None]:
# Warm-up loss curves. The x-axis numbers epochs from 1, matching the "Epoch k/N"
# lines printed by the warm-up loop (plotting the bare lists would start at 0).
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_loss) + 1), warmup_train_loss, label='Train Loss')
ax.plot(range(1, len(warmup_test_loss) + 1), warmup_test_loss, label='Test Loss')
ax.set(title='Warm-up: loss per epoch', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Warm-up accuracy curves. The x-axis numbers epochs from 1, matching the "Epoch k/N"
# lines printed by the warm-up loop (plotting the bare lists would start at 0).
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_acc) + 1), warmup_train_acc, label='Train Accuracy')
ax.plot(range(1, len(warmup_test_acc) + 1), warmup_test_acc, label='Test Accuracy')
ax.set(title='Warm-up: accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')
ax.legend()
plt.show()

In [None]:
# Reload the warmed-up weights, which makes this stage restartable from the saved
# checkpoint. Then swap the trainable part of the model: backbone on, head off.
pretrained_path = os.path.join(CHECKPOINT_DIR, 'pre_trained_model_centralized.pth')
# map_location keeps the load working when the checkpoint was written on a different
# device (e.g. saved from CUDA, resumed on a CPU-only machine).
model.load_state_dict(torch.load(pretrained_path, map_location=device))
unfreeze_backbone(model)
freeze_head(model)

# Per-class sample counts passed to mask_calculator (presumably how many training
# examples per class it uses to score parameters), one entry per CIFAR-100 class.
NUM_CLASSES = 100
samples_per_class = [5] * NUM_CLASSES

# Compute the sparse mask over the model's parameters at the configured density.
mask = mask_calculator(model, full_train_loader.dataset, device, samples_per_class=samples_per_class, density=density)

In [None]:
# Masked fine-tuning: epoch range and a fresh best-test-accuracy tracker.
start_epoch, num_epochs = 0, 10
best_test_acc = 0.0

# Per-epoch metric histories, filled by the fine-tuning loop and plotted afterwards.
hist_train_loss, hist_train_acc = [], []
hist_test_loss, hist_test_acc = [], []

In [None]:
# Fresh optimizer and cosine schedule for masked fine-tuning
# (T_max = num_epochs; the loop steps the scheduler once per epoch).
# unfreeze_backbone/freeze_head changed which parameters are trainable, so only those
# with requires_grad=True are passed to SGD. The frozen head then cannot be moved by
# weight decay or momentum, even if it still holds a zeroed .grad from warm-up.
# NOTE(review): SGD adds weight_decay * param to the gradient inside step(). If `train`
# only masks .grad before optimizer.step(), parameters outside the mask are still
# decayed every step — confirm this is intended for sparse (masked) fine-tuning.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=best_cfg['lr'], momentum=best_cfg['momentum'], weight_decay=5e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

In [None]:
# Masked fine-tuning loop. `grad_mask=mask` is forwarded to `train`, which is expected
# to restrict updates to the entries selected by the mask. Each epoch: train, evaluate
# on the test set, step the cosine schedule, then record and print the metrics.
# Checkpoints written to CHECKPOINT_DIR:
#   * best_model_overall.pth: overwritten whenever test accuracy improves.
#   * centralized_edited_epoch{k}.pth: every `checkpoint_every` epochs, and always
#     after the final epoch, so the last state survives even when num_epochs is not
#     a multiple of the interval.
# NOTE(review): the "best" model is selected on the test set (val_split=0.0 leaves no
# validation split), so best_test_acc is an optimistic estimate.
checkpoint_every = 10  # periodic snapshot interval, in epochs
last_epoch = start_epoch + num_epochs

for epoch in range(start_epoch, last_epoch):
    train_loss, train_acc = train(model, full_train_loader, optimizer, criterion, device, grad_mask=mask)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()

    hist_train_loss.append(train_loss)
    hist_train_acc.append(train_acc)
    hist_test_loss.append(test_loss)
    hist_test_acc.append(test_acc)

    print(f"Epoch {epoch+1}/{last_epoch}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")

    # Single payload for both checkpoint kinds; 'epoch' counts completed epochs so a
    # resumed run can set start_epoch from it.
    checkpoint = {
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }

    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save({**checkpoint, 'test_acc': test_acc},
                   os.path.join(CHECKPOINT_DIR, 'best_model_overall.pth'))

    if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == last_epoch:
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'centralized_edited_epoch{epoch+1}.pth'))

In [None]:
# Masked fine-tuning loss curves. The x-axis numbers epochs from 1, matching the
# "Epoch k/N" lines printed by the fine-tuning loop (bare lists would start at 0).
fig, ax = plt.subplots()
ax.plot(range(1, len(hist_train_loss) + 1), hist_train_loss, label='Train Loss')
ax.plot(range(1, len(hist_test_loss) + 1), hist_test_loss, label='Test Loss')
ax.set(title='Masked fine-tuning: loss per epoch', xlabel='Epoch', ylabel='Loss')
ax.legend()
plt.show()

In [None]:
# Masked fine-tuning accuracy curves. The x-axis numbers epochs from 1, matching the
# "Epoch k/N" lines printed by the fine-tuning loop (bare lists would start at 0).
fig, ax = plt.subplots()
ax.plot(range(1, len(hist_train_acc) + 1), hist_train_acc, label='Train Accuracy')
ax.plot(range(1, len(hist_test_acc) + 1), hist_test_acc, label='Test Accuracy')
ax.set(title='Masked fine-tuning: accuracy per epoch', xlabel='Epoch', ylabel='Accuracy')
ax.legend()
plt.show()