In [None]:
# Clone the project repository (branch `personal-contribution`). The `data`, `eval`,
# `train` and `models` modules imported in the imports cell come from it.
# NOTE: git refuses to clone into an existing non-empty directory, so re-running this
# cell in the same environment fails harmlessly.
!git clone --branch personal-contribution https://github.com/AlessandroMaini/federated-learning-project.git

In [None]:
# Work from the repository root: the project modules are imported relative to it and
# the relative ./checkpoints directory is created inside it.
%cd federated-learning-project

In [None]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm

from data.cifar100_loader import get_federated_cifar100_dataloaders_with_imbalances, get_class_distribution_matrix, plot_client_distributions
from eval import evaluate
from train import train, train_steps
from models.prepare_model import get_dino_vits16_model, freeze_backbone, unfreeze_backbone, freeze_head, unfreeze_head
from models.model_editing import mask_calculator, freeze_and_clean_client_masks
from models.federated_averaging import train_on_client, average_metrics, average_models, get_trainable_keys

In [None]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed every RNG the notebook may draw from, so Restart & Run All is reproducible:
# client sampling uses `random`; model initialisation and training presumably use torch;
# NumPy is seeded in case the project's data helpers use its global RNG.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

# Directories
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

In [None]:
# Federated-learning setup
K, N = 100, 10  # K: number of clients; N: classes held by each client
C = 0.1         # fraction of clients sampled in each round
J = 4           # local training steps per client in each round

In [None]:
# Build the per-client CIFAR-100 training data (K clients, N classes each, using the
# loader's "with_imbalances" split) together with a shared test loader.
federated_test = False  # NOTE(review): not referenced by any cell in this notebook
train_datasets, test_loader, client_class_map, assigned_sizes, assigned_distributions = get_federated_cifar100_dataloaders_with_imbalances(K, N)

# Later cells index these by client id in range(K); fail fast here if the loader disagrees
# instead of hitting an IndexError/KeyError deep inside a training loop.
assert len(train_datasets) == K, f"expected {K} client datasets, got {len(train_datasets)}"
assert len(client_class_map) == K, f"expected {K} client class maps, got {len(client_class_map)}"
assert len(assigned_sizes) == K, f"expected {K} client sizes, got {len(assigned_sizes)}"

criterion = nn.CrossEntropyLoss()

In [None]:
# Global (server) model shared by all clients; built on `device` by the project helper
# (presumably a DINO ViT-S/16, going by the helper's name).
collaborative_model = get_dino_vits16_model(device)

# Freeze the backbone so the warmup rounds only train the remaining parameters
# (presumably the classification head — it is later frozen via freeze_head).
freeze_backbone(collaborative_model)

In [None]:
# Warmup phase settings (the backbone stays frozen during warmup)
warmup_rounds = 10  # FedAvg rounds
warmup_steps = 8    # local training steps per selected client in each round

In [None]:
# Warmup bookkeeping: which rounds to run and the per-round metric histories
start_round, num_rounds = 0, warmup_rounds
best_test_acc = 0.0  # NOTE(review): never updated by the warmup loop

# The warmup loop appends one value per round to each of these lists
warmup_train_loss, warmup_train_acc = [], []
warmup_test_loss, warmup_test_acc = [], []

In [None]:
print("--- Starting Federated Averaging Warmup ---")
# Warmup FedAvg loop. Each selected client trains locally for `warmup_steps` steps via
# train_on_client; the returned states are averaged uniformly over the keys reported by
# get_trainable_keys (presumably the parameters outside the frozen backbone — confirm in
# models.federated_averaging) and written back into the global model.
#
# The loop index is `round_idx` rather than `round`: a module-level `round` would shadow
# the builtin round() for every later cell of the notebook.
last_round = start_round + num_rounds
# Clients sampled per round: round C*K to the nearest integer instead of truncating
# (int(0.29 * 100) == 28 because of float error) and never sample zero clients.
clients_per_round = max(1, int(C * K + 0.5))
for round_idx in range(start_round, last_round):
    print(f"\n--- Round {round_idx + 1}/{last_round} ---")

    # Sample clients uniformly at random, without replacement
    selected_clients = random.sample(range(K), clients_per_round)

    # Local training on each selected client
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            collaborative_model,
            train_datasets[client_id],
            warmup_steps,
            criterion,
            lr = 0.01,
            device = device,
            client_dim=assigned_sizes[client_id]
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Training metrics are weighted by len(train_datasets[c]) of each selected client
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights_metrics = [count / total_samples for count in client_sample_counts]

    # Model parameters are averaged with uniform weights
    client_weights = [1.0 / len(selected_clients)] * len(selected_clients)

    # Federated averaging: overwrite only the averaged entries of the global state dict
    trainable_keys = get_trainable_keys(collaborative_model)
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = collaborative_model.state_dict()
    for key in averaged_state:
        new_state[key] = averaged_state[key]
    collaborative_model.load_state_dict(new_state)

    # Log the weighted training metrics
    avg_train_loss = average_metrics(train_losses, client_weights_metrics)
    avg_train_acc = average_metrics(train_accs, client_weights_metrics)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    warmup_train_loss.append(avg_train_loss)
    warmup_train_acc.append(avg_train_acc)

    # Evaluate the global model on the test set after every warmup round
    avg_test_loss, avg_test_acc = evaluate(collaborative_model, test_loader, criterion, device)

    print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
    warmup_test_loss.append(avg_test_loss)
    warmup_test_acc.append(avg_test_acc)

# Persist the warmed-up model; it is reloaded before the masked training phase
torch.save(collaborative_model.state_dict(), os.path.join(CHECKPOINT_DIR, f'pre_trained_federated_unbalance_{N}class.pth'))

In [None]:
# Warmup loss curves. Both histories get one entry per warmup round, so they share a
# 1-based round axis that matches the "--- Round r/... ---" log lines.
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_loss) + 1), warmup_train_loss, label='Train Loss')
ax.plot(range(1, len(warmup_test_loss) + 1), warmup_test_loss, label='Test Loss')
ax.set_xlabel('Rounds')
ax.set_ylabel('Loss')
ax.set_title('Warmup: train vs. test loss')
ax.legend()
plt.show()

In [None]:
# Warmup accuracy curves on the same 1-based round axis as the loss plot
# (one entry per warmup round in both histories).
fig, ax = plt.subplots()
ax.plot(range(1, len(warmup_train_acc) + 1), warmup_train_acc, label='Train Accuracy')
ax.plot(range(1, len(warmup_test_acc) + 1), warmup_test_acc, label='Test Accuracy')
ax.set_xlabel('Rounds')
ax.set_ylabel('Accuracy')
ax.set_title('Warmup: train vs. test accuracy')
ax.legend()
plt.show()

In [None]:
# Reload the warmup checkpoint written at the end of the warmup loop.
pretrained_path = os.path.join(CHECKPOINT_DIR, f'pre_trained_federated_unbalance_{N}class.pth')
# map_location lets a checkpoint written on a GPU be loaded on a CPU-only machine;
# load_state_dict then copies the tensors into the model's existing parameters.
collaborative_model.load_state_dict(torch.load(pretrained_path, map_location=device))
# Switch the trainable part for the masked phase: backbone trainable, head frozen
unfreeze_backbone(collaborative_model)
freeze_head(collaborative_model)

# Sparsity passed to mask_calculator; also encoded (as a percentage) in checkpoint filenames
sparsity = 0.10

In [None]:
# Compute one mask per client with mask_calculator.
# The per-class budget is 10 for every class the client holds (per client_class_map) and
# 0 for the other CIFAR-100 classes — presumably the number of samples per class that
# mask_calculator draws; confirm in models.model_editing.
client_masks = {}
for cid in tqdm(range(K)):
    budget = [0 for _ in range(100)]
    for cls in client_class_map[cid]:
        budget[cls] = 10
    client_masks[cid] = mask_calculator(
        collaborative_model,
        train_datasets[cid],
        device,
        rounds=4,
        sparsity=sparsity,
        samples_per_class=budget,
        client_dim=assigned_sizes[cid],
        verbose=False,
    )

In [None]:
# Post-process the raw masks (threshold=0.01); the helper returns the cleaned masks and a
# frozen state whose exact contents are defined in models.model_editing.
client_masks, frozen_state = freeze_and_clean_client_masks(
    collaborative_model, client_masks, threshold=0.01, K=K
)

# Filenames encode N and the sparsity as an integer percentage
sparsity_pct = int(sparsity * 100)
masks_file = os.path.join(CHECKPOINT_DIR, f'client_masks_unbalance_{N}class_{sparsity_pct}.pth')
frozen_file = os.path.join(CHECKPOINT_DIR, f'frozen_state_unbalance_{N}class_{sparsity_pct}.pth')

torch.save(client_masks, masks_file)  # every client's mask in a single file
torch.save(frozen_state, frozen_file)

In [None]:
# Reload the masks and frozen state written by the previous cell, so the mask
# computation can be skipped once the files exist.
sparsity_pct = int(sparsity * 100)
masks_file = os.path.join(CHECKPOINT_DIR, f'client_masks_unbalance_{N}class_{sparsity_pct}.pth')
frozen_file = os.path.join(CHECKPOINT_DIR, f'frozen_state_unbalance_{N}class_{sparsity_pct}.pth')

client_masks = torch.load(masks_file)
frozen_state = torch.load(frozen_file)
# NOTE(review): loading frozen_state here does not itself touch requires_grad on
# collaborative_model; the commented-out resume cell further down shows that step —
# confirm whether it is needed when the mask computation is skipped.

In [None]:
# Mask-density check: count the non-zero entries of each client's mask and report the
# mean over all K clients (not a hard-coded 100, so changing K keeps this cell valid).
tot_non_zero = 0
for client_id in range(K):
    # Only tensor-valued entries are counted; other values in the mask dict are skipped
    non_zero_count = sum(
        torch.sum(value != 0).item()
        for value in client_masks[client_id].values()
        if isinstance(value, torch.Tensor)
    )
    print(f"Number of non-zeros in the mask: {non_zero_count}")
    tot_non_zero += non_zero_count
print(f"Mean non zero elems {tot_non_zero/K}")

In [None]:
# Masked training phase: 200 FedAvg rounds starting from round 0
start_round, num_rounds = 0, 200
best_test_acc = 0.0  # NOTE(review): never updated by the training loop

# Metric histories: train metrics every round, test metrics on evaluation rounds only
hist_train_loss, hist_train_acc = [], []
hist_test_loss, hist_test_acc = [], []

In [None]:
# Optional resume step (disabled): uncomment to continue from a saved checkpoint instead
# of the freshly masked warmup model.
# NOTE(review): 'federated_round_80.pth' does not match the name the training loop saves
# (f"federated_unbalance_round_{r}_{N}class_{int(sparsity * 100)}.pth"); when resuming,
# also set start_round in the cell above to the checkpoint's round.
# # Load previous checkpoint if available
# collaborative_model.load_state_dict(torch.load(os.path.join(CHECKPOINT_DIR, 'federated_round_80.pth')))

# freeze_head(collaborative_model)

# # Apply frozen state
# for name, param in collaborative_model.named_parameters():
#     if name in frozen_state:
#         param.requires_grad = False

In [None]:
# Masked FedAvg loop. Each selected client trains for J local steps with its own mask
# (client_masks[client_id]); the returned states are averaged uniformly over the keys
# reported by get_trainable_keys and written back into the global model.
#
# The loop index is `round_idx` rather than `round`: a module-level `round` would shadow
# the builtin round() for every later cell of the notebook.
last_round = start_round + num_rounds
# Clients sampled per round: round C*K to the nearest integer instead of truncating
# (int(0.29 * 100) == 28 because of float error) and never sample zero clients.
clients_per_round = max(1, int(C * K + 0.5))
eval_every = 5         # evaluate on the test set on every `eval_every`-th round (1-based)
checkpoint_every = 20  # save a checkpoint every `checkpoint_every` rounds
for round_idx in range(start_round, last_round):
    print(f"\n--- Round {round_idx + 1}/{last_round} ---")

    # Sample clients uniformly at random, without replacement
    selected_clients = random.sample(range(K), clients_per_round)

    # Local (masked) training on each selected client
    local_models, train_losses, train_accs = [], [], []
    for client_id in selected_clients:
        model_state, loss, acc = train_on_client(
            client_id,
            collaborative_model,
            train_datasets[client_id],
            J,
            criterion,
            lr = 0.01,
            device = device,
            mask=client_masks[client_id],
            client_dim=assigned_sizes[client_id]
        )
        local_models.append(model_state)
        train_losses.append(loss)
        train_accs.append(acc)

    # Training metrics are weighted by len(train_datasets[c]) of each selected client
    client_sample_counts = [len(train_datasets[c]) for c in selected_clients]
    total_samples = sum(client_sample_counts)
    client_weights_metrics = [count / total_samples for count in client_sample_counts]

    # Model parameters are averaged with uniform weights
    client_weights = [1.0 / len(selected_clients)] * len(selected_clients)

    # Federated averaging: overwrite only the averaged entries of the global state dict
    trainable_keys = get_trainable_keys(collaborative_model)
    averaged_state = average_models(local_models, client_weights, trainable_keys)
    new_state = collaborative_model.state_dict()
    for key in averaged_state:
        new_state[key] = averaged_state[key]
    collaborative_model.load_state_dict(new_state)

    # Log the weighted training metrics
    avg_train_loss = average_metrics(train_losses, client_weights_metrics)
    avg_train_acc = average_metrics(train_accs, client_weights_metrics)
    print(f"Avg Train Loss: {avg_train_loss:.4f}, Avg Train Accuracy: {avg_train_acc:.4f}")
    hist_train_loss.append(avg_train_loss)
    hist_train_acc.append(avg_train_acc)

    # Test-set evaluation only on every `eval_every`-th round to save time
    if (round_idx + 1) % eval_every == 0:
        avg_test_loss, avg_test_acc = evaluate(collaborative_model, test_loader, criterion, device)

        print(f"Avg Test Loss: {avg_test_loss:.4f}, Avg Test Accuracy: {avg_test_acc:.4f}")
        hist_test_loss.append(avg_test_loss)
        hist_test_acc.append(avg_test_acc)

    # Checkpoint periodically, and always after the final round so the last model is kept
    if (round_idx + 1) % checkpoint_every == 0 or round_idx + 1 == last_round:
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"federated_unbalance_round_{round_idx + 1}_{N}class_{int(sparsity * 100)}.pth")
        torch.save(collaborative_model.state_dict(), checkpoint_path)

In [None]:
# Loss curves for the masked FedAvg phase. Train loss is logged every round, but test
# loss only on rounds that are multiples of `eval_every` (the `% 5` check in the training
# loop). Each series is therefore plotted against its own 1-based round numbers; plotting
# both against list indices would squeeze the test curve into the first fifth of the axis.
eval_every = 5  # keep in sync with the evaluation cadence of the training loop
train_rounds = list(range(start_round + 1, start_round + len(hist_train_loss) + 1))
test_rounds = [r for r in train_rounds if r % eval_every == 0]
assert len(test_rounds) == len(hist_test_loss), "test history does not match the evaluation cadence"

fig, ax = plt.subplots()
ax.plot(train_rounds, hist_train_loss, label='Train Loss')
ax.plot(test_rounds, hist_test_loss, label='Test Loss', marker='o')
ax.set_xlabel('Rounds')
ax.set_ylabel('Loss')
ax.set_title('Masked FedAvg: train vs. test loss')
ax.legend()
plt.show()

In [None]:
# Accuracy curves for the masked FedAvg phase. Train accuracy is logged every round,
# test accuracy only on rounds that are multiples of `eval_every`, so each series gets
# its own 1-based round numbers instead of sharing list indices as the x axis.
eval_every = 5  # keep in sync with the evaluation cadence of the training loop
train_rounds = list(range(start_round + 1, start_round + len(hist_train_acc) + 1))
test_rounds = [r for r in train_rounds if r % eval_every == 0]
assert len(test_rounds) == len(hist_test_acc), "test history does not match the evaluation cadence"

fig, ax = plt.subplots()
ax.plot(train_rounds, hist_train_acc, label='Train Accuracy')
ax.plot(test_rounds, hist_test_acc, label='Test Accuracy', marker='o')
ax.set_xlabel('Rounds')
ax.set_ylabel('Accuracy')
ax.set_title('Masked FedAvg: train vs. test accuracy')
ax.legend()
plt.show()