In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from data.cifar100_loader import get_cifar100_loaders
from models.prepare_model import get_dino_vits16_model, freeze_backbone
from models.hyperparameter_tuning import run_grid_search
from eval import evaluate
from train import train


In [None]:
# Directory that receives the grid-search winners and training checkpoints
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Prefer the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Loss shared by the grid search and the final training run
criterion = nn.CrossEntropyLoss()

# Train / validation / test loaders with the loader module's default settings
train_loader, val_loader, test_loader = get_cifar100_loaders()

In [None]:
# Collects the winning entry of every grid-search group; the selection cell
# later reads each entry's 'accuracy' and 'cfg' keys.
best_results = list()

In [None]:
# Grid for the 'linear' scheduler: one config list per learning rate
# (0.01, 0.001, 0.005), each sweeping momentum over 0.8, 0.9 and 0.95.
configs1, configs2, configs3 = (
    [{'scheduler': 'linear', 'lr': lr, 'momentum': momentum} for momentum in (0.8, 0.9, 0.95)]
    for lr in (0.01, 0.001, 0.005)
)

In [None]:
# Grid search over configs1 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg1, results1 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs1,
    device,
)
torch.save(obj=best_cfg1, f=os.path.join(CHECKPOINT_DIR, 'best_cfg1.pth'))
best_results += [best_cfg1]

In [None]:
# Grid search over configs2 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg2, results2 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs2,
    device,
)
torch.save(obj=best_cfg2, f=os.path.join(CHECKPOINT_DIR, 'best_cfg2.pth'))
best_results += [best_cfg2]

In [None]:
# Grid search over configs3 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg3, results3 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs3,
    device,
)
torch.save(obj=best_cfg3, f=os.path.join(CHECKPOINT_DIR, 'best_cfg3.pth'))
best_results += [best_cfg3]

In [None]:
# Grid for the 'exp' scheduler: one config list per learning rate
# (0.01, 0.001, 0.005), each sweeping momentum over 0.8, 0.9 and 0.95.
configs4, configs5, configs6 = (
    [{'scheduler': 'exp', 'lr': lr, 'momentum': momentum} for momentum in (0.8, 0.9, 0.95)]
    for lr in (0.01, 0.001, 0.005)
)

In [None]:
# Grid search over configs4 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg4, results4 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs4,
    device,
)
torch.save(obj=best_cfg4, f=os.path.join(CHECKPOINT_DIR, 'best_cfg4.pth'))
best_results += [best_cfg4]

In [None]:
# Grid search over configs5 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg5, results5 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs5,
    device,
)
torch.save(obj=best_cfg5, f=os.path.join(CHECKPOINT_DIR, 'best_cfg5.pth'))
best_results += [best_cfg5]

In [None]:
# Grid search over configs6 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg6, results6 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs6,
    device,
)
torch.save(obj=best_cfg6, f=os.path.join(CHECKPOINT_DIR, 'best_cfg6.pth'))
best_results += [best_cfg6]

In [None]:
# Grid for the 'cosine' scheduler: one config list per learning rate
# (0.01, 0.001, 0.005), each sweeping momentum over 0.8, 0.9 and 0.95.
configs7, configs8, configs9 = (
    [{'scheduler': 'cosine', 'lr': lr, 'momentum': momentum} for momentum in (0.8, 0.9, 0.95)]
    for lr in (0.01, 0.001, 0.005)
)

In [None]:
# Grid search over configs7 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg7, results7 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs7,
    device,
)
torch.save(obj=best_cfg7, f=os.path.join(CHECKPOINT_DIR, 'best_cfg7.pth'))
best_results += [best_cfg7]

In [None]:
# Grid search over configs8 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg8, results8 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs8,
    device,
)
torch.save(obj=best_cfg8, f=os.path.join(CHECKPOINT_DIR, 'best_cfg8.pth'))
best_results += [best_cfg8]

In [None]:
# Grid search over configs9 (the model factory is passed uncalled). The winning
# entry is written to CHECKPOINT_DIR and collected for the final comparison.
best_cfg9, results9 = run_grid_search(
    train_loader,
    val_loader,
    get_dino_vits16_model,
    criterion,
    configs9,
    device,
)
torch.save(obj=best_cfg9, f=os.path.join(CHECKPOINT_DIR, 'best_cfg9.pth'))
best_results += [best_cfg9]

In [None]:
# Get the best configuration: the entry with the highest 'accuracy' among the
# per-group grid-search winners collected in best_results.
# Fail loudly here instead of leaving best_cfg as None, which would only
# surface later as a TypeError when the optimizer is built.
if not best_results:
    raise RuntimeError(
        "best_results is empty; run at least one grid-search cell before selecting a configuration."
    )

# max() returns the first entry on ties, matching a strict '>' comparison scan.
best_result = max(best_results, key=lambda result: result['accuracy'])
best_accuracy = best_result['accuracy']
best_cfg = best_result['cfg']

In [None]:
# Build a fresh model for the final training run via the project factory,
# passing the selected device, then hand it to freeze_backbone.
# NOTE(review): presumably freeze_backbone disables gradients on the pretrained
# backbone in place (its return value is ignored here) — confirm in models.prepare_model.
model = get_dino_vits16_model(device)
freeze_backbone(model)

In [None]:
# Train on full training set (train + val)
full_train_loader, _, test_loader = get_cifar100_loaders(val_split=0.0)

In [None]:
# Epoch bookkeeping for the final training run
start_epoch, num_epochs = 0, 30
best_test_acc = 0.0  # highest test accuracy seen so far

# Per-epoch metric histories (four independent lists), plotted after training
hist_train_loss, hist_train_acc = [], []
hist_test_loss, hist_test_acc = [], []

In [None]:
# Build the SGD optimizer and the LR scheduler from the winning grid-search
# configuration (keys read: 'lr', 'momentum', 'scheduler').
optimizer = optim.SGD(
    model.parameters(),
    lr=best_cfg['lr'],
    momentum=best_cfg['momentum'],
    weight_decay=5e-4,
)

scheduler_name = best_cfg['scheduler']
if scheduler_name == 'cosine':
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
elif scheduler_name == 'linear':
    # NOTE(review): with start_factor=1.0 and LinearLR's default end_factor (1.0)
    # the multiplier never changes, so the learning rate stays constant. Confirm
    # this matches the scheduler used during the grid search before changing it.
    scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, total_iters=num_epochs)
elif scheduler_name == 'exp':
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
else:
    # Without this branch `scheduler` would be undefined (or stale from an earlier
    # run of this cell) and the failure would only show up inside the training loop.
    raise ValueError(
        f"Unsupported scheduler {scheduler_name!r}; expected 'cosine', 'linear' or 'exp'."
    )

In [None]:
# Final training run: one pass over full_train_loader per epoch, followed by an
# evaluation on test_loader and one scheduler step. Metrics are appended to the
# hist_* lists; checkpoints are written to CHECKPOINT_DIR.
# NOTE(review): the "best" checkpoint is chosen by test accuracy, so best_test_acc
# is selected on the same data it reports — treat it as an optimistic estimate.
for epoch in range(start_epoch, start_epoch + num_epochs):
    train_loss, train_acc = train(model, full_train_loader, optimizer, criterion, device)
    test_loss, test_acc = evaluate(model, test_loader, criterion, device)
    scheduler.step()  # advance the LR schedule once per epoch

    hist_train_loss.append(train_loss)
    hist_train_acc.append(train_acc)
    hist_test_loss.append(test_loss)
    hist_test_acc.append(test_acc)

    print(f"Epoch {epoch+1}/{start_epoch + num_epochs}")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"  Test Loss:  {test_loss:.4f} | Test Acc:  {test_acc:.4f}")

    # Keep the state with the highest test accuracy seen so far. 'epoch' and
    # 'test_acc' are stored so the file can be traced back to the epoch that
    # produced it, consistent with the periodic checkpoints below.
    if test_acc > best_test_acc:
        best_test_acc = test_acc
        torch.save({
            'epoch': epoch + 1,
            'test_acc': test_acc,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }, os.path.join(CHECKPOINT_DIR, 'best_model_overall.pth'))

    # Periodic checkpoint every 10 epochs (1-based epoch number in the file name).
    if (epoch + 1) % 10 == 0:
        checkpoint = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict()
        }
        torch.save(checkpoint, os.path.join(CHECKPOINT_DIR, f'centralized_epoch{epoch+1}.pth'))

In [None]:
import matplotlib.pyplot as plt

# Loss curves: training vs. test, one point per recorded epoch
fig, ax = plt.subplots()
ax.plot(hist_train_loss, label='Train Loss')
ax.plot(hist_test_loss, label='Test Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()

In [None]:
# Accuracy curves: training vs. test, one point per recorded epoch
fig, ax = plt.subplots()
ax.plot(hist_train_acc, label='Train Accuracy')
ax.plot(hist_test_acc, label='Test Accuracy')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.legend()
plt.show()