# AML Project 5 - Federated Learning - Track B



## Setup & Installation

### Importing libraries

In [None]:
import copy
import dataclasses
import logging
import os
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Literal, Optional, Tuple

import numpy as np
import numpy.typing as npt
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from torchvision.datasets.cifar import CIFAR100
from torch.utils.tensorboard.writer import SummaryWriter
from tqdm import tqdm


DEVICE: cuda
cuda


: 

### Logging Setup

In [2]:
class TqdmLoggingHandler(logging.Handler):
    """Logging handler that prints through ``tqdm.write``.

    Each message is prefixed with a carriage return and the ANSI
    "erase line" sequence before being handed to ``tqdm.write``.
    """

    def emit(self, record) -> None:
        """Format ``record`` and write it via tqdm; delegate failures to ``handleError``."""
        try:
            # "\r\033[K" = return to column 0 and erase to end of line.
            tqdm.write("\r\033[K" + self.format(record))
            self.flush()
        except Exception:
            self.handleError(record)


class ColoredFormatter(logging.Formatter):
    """Formatter that wraps the record's level name in an ANSI colour code.

    The colour is applied only for the duration of a single ``format`` call;
    the record itself is left unchanged afterwards so that other handlers or
    formatters sharing the same ``LogRecord`` still see the plain level name.
    """

    # ANSI escape sequences keyed by level name; "RESET" ends the coloured span.
    COLORS = {
        "DEBUG": "\033[1;34m",
        "INFO": "\033[1;32m",
        "WARNING": "\033[1;33m",
        "ERROR": "\033[1;31m",
        "CRITICAL": "\033[1;35m",
        "RESET": "\033[0m",
    }

    def format(self, record):
        """Return the formatted message with a coloured ``levelname``.

        Level names that are not keys of ``COLORS`` are formatted unchanged.
        """
        plain_levelname = record.levelname
        if plain_levelname not in self.COLORS:
            return super().format(record)

        record.levelname = (
            f"{self.COLORS[plain_levelname]}{plain_levelname}{self.COLORS['RESET']}"
        )
        try:
            return super().format(record)
        finally:
            # A LogRecord is shared by every handler attached to the logger;
            # restore the plain name so the colour codes do not leak.
            record.levelname = plain_levelname


def setup_logging(level=logging.INFO):
    """Reset the root logger to a single tqdm-aware, colourised handler.

    Args:
        level: Threshold assigned to the root logger (defaults to ``logging.INFO``).
    """
    handler = TqdmLoggingHandler()
    handler.setFormatter(
        ColoredFormatter(
            fmt="%(asctime)s [%(levelname)s] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )

    root = logging.getLogger()
    # Drop whatever handlers are already installed so re-running this
    # function does not stack duplicate handlers.
    root.handlers.clear()
    root.setLevel(level)
    root.addHandler(handler)


# Install the tqdm-aware handler on the root logger (INFO level by default).
setup_logging()

## Config Setup

In [3]:
@dataclass(frozen=True)
class BaseConfig:
    """Shared, immutable settings for every experiment in this notebook.

    Only the annotated ``*_DIR`` attributes are dataclass fields. The
    un-annotated attributes (DEVICE, SEED, BATCH_SIZE, ...) are plain class
    attributes, so they cannot be overridden through the constructor or
    ``dataclasses.replace``.
    """

    # Prefer the GPU when one is visible to PyTorch.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # os.cpu_count() may return None; fall back to 1.
    CPU_COUNT = os.cpu_count() or 1
    # NOTE(review): CPU_COUNT is at least 1, so min(1, CPU_COUNT) is always 1.
    # Confirm a single DataLoader worker is intended.
    NUM_WORKERS = min(1, CPU_COUNT)
    SEED = 42

    # Paths: evaluated once, when the class body runs, relative to the
    # working directory at that moment.
    ROOT_DIR: Path = Path.cwd()
    CONFIGS_DIR: Path = ROOT_DIR / "configs"
    DATA_DIR: Path = ROOT_DIR / "data"
    MODELS_DIR: Path = ROOT_DIR / "models"
    RESULTS_DIR: Path = ROOT_DIR / "results"
    RUNS_DIR: Path = ROOT_DIR / "runs"
    OLD_RUNS_DIR: Path = RUNS_DIR / "old_runs"

    # Training parameters (optimizer hyper-parameters and output size)
    BATCH_SIZE = 64
    LEARNING_RATE = 0.01
    NUM_EPOCHS = 20
    MOMENTUM = 0.9
    WEIGHT_DECAY = 4e-4
    NUM_CLASSES = 100


# Create every project directory up front so later cells can write to them.
# mkdir(parents=True, exist_ok=True) makes this cell safe to re-run.
config = BaseConfig()
for dir_path in [
    config.DATA_DIR,
    config.MODELS_DIR,
    config.RESULTS_DIR,
    config.CONFIGS_DIR,
    config.RUNS_DIR,
    config.OLD_RUNS_DIR,
]:
    dir_path.mkdir(parents=True, exist_ok=True)


In [4]:
@dataclass(frozen=True)
class FederatedConfig(BaseConfig):
    """Federated Learning specific configuration.

    All attributes below are annotated, so they are dataclass fields and can
    be set through the constructor or ``dataclasses.replace``.
    """

    NUM_CLIENTS: int = 100  # K: total number of simulated clients
    PARTICIPATION_RATE: float = 0.1  # C: fraction of clients selected each round
    LOCAL_EPOCHS: int = 4  # J: local epochs a selected client runs per round
    NUM_ROUNDS: int = 2000  # number of communication rounds
    # None selects the IID sharding path; any other value selects the non-IID path.
    CLASSES_PER_CLIENT: Optional[int] = None  # None for IID
    # "skewed" draws client selection probabilities from a Dirichlet
    # distribution; any other value selects clients uniformly.
    PARTICIPATION_MODE: str = "uniform"
    # Dirichlet concentration; must be set when PARTICIPATION_MODE is "skewed".
    DIRICHLET_ALPHA: Optional[float] = None

## Model

In [5]:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/ReLU/max-pool stages followed by an MLP head.

    The first linear layer expects 5 * 5 * 64 flattened features, which is
    what the two convolutional stages produce for 3x32x32 inputs.

    Args:
        config: Any object exposing ``NUM_CLASSES`` (size of the output layer).
    """

    def __init__(self, config):
        super().__init__()
        channels = 64

        # Layers are created in the same order they are applied in ``forward``.
        self.conv_block1 = self._conv_stage(3, channels)
        self.conv_block2 = self._conv_stage(channels, channels)

        self.connected = nn.Sequential(
            nn.Linear(5 * 5 * channels, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
        )

        self.classifier = nn.Linear(192, config.NUM_CLASSES)

    @staticmethod
    def _conv_stage(in_channels, out_channels):
        """Build one 5x5 convolution -> ReLU -> 2x2 max-pool stage."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        """Return class logits for a batch of images."""
        features = self.conv_block2(self.conv_block1(x))
        # Collapse everything except the batch dimension.
        flattened = features.view(features.size(0), -1)
        return self.classifier(self.connected(flattened))

## Metrics Logger

This class logs the training metrics in a format that can be inspected with TensorBoard.

In [6]:
class MetricsManager:
    """Manages logging of training metrics to TensorBoard.

    Creating an instance archives previous run directories of the same
    training type and model, then opens a ``SummaryWriter`` on a new,
    timestamped run directory under ``config.RUNS_DIR``.
    """

    def __init__(
        self,
        config: BaseConfig,
        model_name: str,
        training_type: Literal["centralized", "federated"],
        experiment_name: Optional[str] = None,
    ):
        """Archive older runs and open a writer for a new run.

        Args:
            config: Provides ``RUNS_DIR`` and ``OLD_RUNS_DIR``; for federated
                runs without ``experiment_name`` its federated fields are
                used to build a descriptive run name.
            model_name: Used as part of the run directory name.
            training_type: Prefix of the run directory name.
            experiment_name: Optional label that replaces the auto-generated
                description in the run name.
        """
        self.config = config
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

        # Archive old runs. The pattern matches every earlier run with the
        # same training type and model name, whatever its experiment label.
        old_runs = list(config.RUNS_DIR.glob(f"{training_type}_{model_name}_*"))
        if old_runs:
            archive_dir = config.OLD_RUNS_DIR
            archive_dir.mkdir(parents=True, exist_ok=True)
            for run in old_runs:
                run.rename(archive_dir / run.name)

        # Create descriptive run name for FL experiments
        if experiment_name:
            experiment_suffix = f"{experiment_name}_{timestamp}"
        elif training_type == "federated" and isinstance(config, FederatedConfig):
            distribution = (
                "iid"
                if config.CLASSES_PER_CLIENT is None
                else f"noniid_{config.CLASSES_PER_CLIENT}cls"
            )
            participation = f"{config.PARTICIPATION_MODE}"
            if config.PARTICIPATION_MODE == "skewed":
                participation += f"_alpha{config.DIRICHLET_ALPHA}"
            clients_info = f"C{config.NUM_CLIENTS}_P{config.PARTICIPATION_RATE}_E{config.LOCAL_EPOCHS}"
            experiment_suffix = (
                f"{distribution}_{participation}_{clients_info}_{timestamp}"
            )
        else:
            experiment_suffix = timestamp

        run_name = f"{training_type}_{model_name}_{experiment_suffix}"
        self.writer = SummaryWriter(config.RUNS_DIR / run_name)

    def log_metrics(
        self,
        split: Literal["train", "validation", "test"],
        loss: float,
        accuracy: float,
        step: int,
    ) -> None:
        """Log loss and accuracy for the given split at ``step``."""
        self.writer.add_scalars("metrics/loss", {split: loss}, step)
        self.writer.add_scalars("metrics/accuracy", {split: accuracy}, step)

    def log_fl_metrics(
        self, round_idx: int, metrics: Dict, client_stats: Optional[Dict] = None
    ) -> None:
        """Log federated-learning metrics for one round.

        Args:
            round_idx: Step value used for every scalar logged here.
            metrics: Either ``val_loss``/``val_accuracy`` or
                ``test_loss``/``test_accuracy``. A value of ``0.0`` is a valid
                metric and is logged.
            client_stats: Optional scalars logged under
                ``federated/client_stats``.

        Raises:
            ValueError: If none of the four metric keys is present, or if a
                complete validation pair and a complete test pair are both
                supplied.
        """
        # dict.get returns None for a missing key, so presence is tested with
        # "is not None" rather than truthiness (0.0 is a legitimate value).
        val_loss = metrics.get("val_loss")
        val_accuracy = metrics.get("val_accuracy")
        test_loss = metrics.get("test_loss")
        test_accuracy = metrics.get("test_accuracy")

        if (
            val_loss is None
            and val_accuracy is None
            and test_loss is None
            and test_accuracy is None
        ):
            raise ValueError("No validation or test metrics provided.")

        has_validation = val_loss is not None and val_accuracy is not None
        has_test = test_loss is not None and test_accuracy is not None

        if has_validation and has_test:
            raise ValueError(
                "Both validation and test metrics provided. Provide only one."
            )

        if has_validation:
            self.log_metrics("validation", val_loss, val_accuracy, round_idx)
        elif has_test:
            self.log_metrics("test", test_loss, test_accuracy, round_idx)
        else:
            # Only one half of a loss/accuracy pair was supplied.
            logging.warning(
                "Incomplete metrics for round %s; nothing logged: %s",
                round_idx,
                sorted(metrics),
            )

        # Log client participation if available
        if client_stats:
            self.writer.add_scalars(
                "federated/client_stats",
                client_stats,
                round_idx,
            )

    def close(self) -> None:
        """Close TensorBoard writer."""
        self.writer.close()

## Dataset Manager

Here the CIFAR-100 dataset is downloaded and the train, validation and test splits are built for later use.

In [7]:
class Cifar100DatasetManager:
    """Downloads CIFAR-100 and exposes train/validation/test data loaders.

    The training split is augmented (random crop + horizontal flip). The
    validation and test splits only receive tensor conversion and
    normalisation, so evaluation metrics are deterministic.
    """

    config: BaseConfig
    validation_split: float
    train_transform: transforms.Compose
    test_transform: transforms.Compose
    train_loader: DataLoader[CIFAR100]
    val_loader: DataLoader[CIFAR100]
    test_loader: DataLoader[CIFAR100]

    def __init__(self, config: BaseConfig, validation_split: float = 0.1) -> None:
        """Build the transforms and the three loaders.

        Args:
            config: Provides DATA_DIR, SEED, BATCH_SIZE, NUM_WORKERS and DEVICE.
            validation_split: Fraction of the official training set held out
                for validation; must lie strictly between 0 and 1.

        Raises:
            ValueError: If ``validation_split`` is outside ``(0, 1)``.
        """
        if not 0.0 < validation_split < 1.0:
            raise ValueError(
                f"validation_split must be in (0, 1), got {validation_split}"
            )

        self.config = config
        self.validation_split = validation_split

        # Per-channel mean / std used to normalise every split.
        normalize = transforms.Normalize(
            [0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]
        )

        self.train_transform = transforms.Compose(
            [
                transforms.RandomCrop(32, padding=4),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]
        )

        self.test_transform = transforms.Compose(
            [
                transforms.ToTensor(),
                normalize,
            ]
        )

        self.train_loader, self.val_loader, self.test_loader = self._prepare_data()

    def _prepare_data(
        self,
    ) -> Tuple[DataLoader[CIFAR100], DataLoader[CIFAR100], DataLoader[CIFAR100]]:
        """Create the datasets, split train/validation, and wrap them in loaders."""
        full_trainset: CIFAR100 = CIFAR100(
            root=self.config.DATA_DIR,
            train=True,
            download=True,
            transform=self.train_transform,
        )
        # Second view of the same training images without augmentation; the
        # validation subset is taken from this view so that validation
        # metrics are not computed on randomly cropped / flipped images.
        eval_trainset: CIFAR100 = CIFAR100(
            root=self.config.DATA_DIR,
            train=True,
            download=False,
            transform=self.test_transform,
        )

        train_size: int = int((1 - self.validation_split) * len(full_trainset))
        val_size: int = len(full_trainset) - train_size

        # The seeded generator makes the split reproducible.
        train_dataset, val_split = random_split(
            full_trainset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(self.config.SEED),
        )
        val_dataset = Subset(eval_trainset, val_split.indices)

        test_dataset: CIFAR100 = CIFAR100(
            root=self.config.DATA_DIR,
            train=False,
            download=False,
            transform=self.test_transform,
        )

        loader_kwargs = {
            "num_workers": self.config.NUM_WORKERS,
            # Pinned host memory only helps host-to-GPU copies.
            "pin_memory": self.config.DEVICE.type == "cuda",
        }

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=True,
            **loader_kwargs,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
            **loader_kwargs,
        )

        test_loader: DataLoader[CIFAR100] = DataLoader(
            test_dataset,
            batch_size=self.config.BATCH_SIZE,
            shuffle=False,
            **loader_kwargs,
        )

        return train_loader, val_loader, test_loader

    @property
    def train_dataset(self) -> Dataset[CIFAR100]:
        """Dataset behind ``train_loader`` (a ``Subset`` of the training set)."""
        return self.train_loader.dataset

    @property
    def val_dataset(self) -> Dataset[CIFAR100]:
        """Dataset behind ``val_loader``."""
        return self.val_loader.dataset

    @property
    def test_dataset(self) -> Dataset[CIFAR100]:
        """Dataset behind ``test_loader``."""
        return self.test_loader.dataset

## Centralized Trainer

This class trains and evaluates the model in the traditional, centralized way (the type hints assume the LeNet model).
Training runs locally on the full training set, with standard train and evaluate methods.

In [8]:
class CentralizedTrainer:
    """Trains and evaluates a model on the full training set (no federation).

    Uses SGD with momentum, a cosine-annealed learning rate, early stopping on
    validation loss, and logs metrics through ``MetricsManager``.
    """

    model: LeNet
    config: BaseConfig
    device: torch.device
    metrics: MetricsManager

    def __init__(self, model: LeNet, config: BaseConfig) -> None:
        """Move ``model`` to ``config.DEVICE`` and open a metrics writer."""
        self.model = model.to(config.DEVICE)
        self.config = config
        self.device = config.DEVICE
        self.metrics = MetricsManager(
            config, model.__class__.__name__.lower(), "centralized"
        )

    def evaluate_model(
        self, model: LeNet, data_loader: DataLoader[CIFAR100]
    ) -> Tuple[float, float]:
        """Evaluate ``model`` on ``data_loader``.

        Returns:
            ``(average loss per sample, accuracy in percent)``.
        """
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        criterion = nn.CrossEntropyLoss()

        with torch.no_grad():
            for inputs, targets in data_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)

                # The criterion returns a batch mean; weight it by the batch
                # size so the final average is per sample.
                total_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        avg_loss: float = total_loss / total
        accuracy: float = 100.0 * correct / total
        return avg_loss, accuracy

    def train(
        self,
        train_loader: DataLoader[CIFAR100],
        val_loader: DataLoader[CIFAR100],
        test_loader: DataLoader[CIFAR100],
        max_epochs: Optional[int] = None,
    ) -> None:
        """Train with early stopping, then evaluate the best model on the test set.

        Args:
            train_loader: Batches used for optimisation.
            val_loader: Batches used for model selection / early stopping.
            test_loader: Batches used for the final report.
            max_epochs: Number of epochs to run; ``None`` (or 0) falls back to
                ``config.NUM_EPOCHS``.

        The metrics writer is closed when this method returns or raises.
        """
        num_epochs = max_epochs or self.config.NUM_EPOCHS

        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.SGD(
            self.model.parameters(),
            lr=self.config.LEARNING_RATE,
            momentum=self.config.MOMENTUM,
            weight_decay=self.config.WEIGHT_DECAY,
        )

        # Anneal over the epochs actually run so the schedule reaches its
        # minimum at the end of training, also when max_epochs is overridden.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer=optimizer, T_max=num_epochs
        )

        best_val_loss = float("inf")
        best_model_state = None
        patience = 10  # epochs without validation improvement before stopping
        patience_counter = 0
        train_acc = 0.0
        avg_train_loss = 0.0
        # autocast wants the bare device type ("cuda"), not e.g. "cuda:0".
        device_type = self.device.type

        epoch_pbar = tqdm(
            range(num_epochs),
            desc="Training",
            unit="epoch",
            position=0,
            leave=True,
        )
        epoch = 0

        try:
            for epoch in epoch_pbar:
                self.model.train()
                train_loss = 0.0
                correct = 0
                total = 0

                for batch_idx, (inputs, targets) in enumerate(train_loader):
                    inputs = inputs.to(self.device, non_blocking=True)
                    targets = targets.to(self.device, non_blocking=True)

                    optimizer.zero_grad(set_to_none=True)

                    with torch.amp.autocast_mode.autocast(device_type=device_type):
                        outputs = self.model(inputs)
                        loss = criterion(outputs, targets)

                    loss.backward()
                    optimizer.step()

                    train_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += targets.size(0)
                    correct += predicted.eq(targets).sum().item()

                    # Running averages over the current epoch
                    train_acc = 100.0 * correct / total
                    avg_train_loss = train_loss / (batch_idx + 1)

                    global_step = epoch * len(train_loader) + batch_idx
                    self.metrics.log_metrics(
                        "train", avg_train_loss, train_acc, global_step
                    )

                # Validation phase
                val_loss, val_acc = self.evaluate_model(self.model, val_loader)
                scheduler.step()

                self.metrics.log_metrics("validation", val_loss, val_acc, epoch)
                epoch_pbar.set_postfix(
                    {
                        "ep": f"{epoch+1}/{num_epochs}",
                        "tr_loss": f"{avg_train_loss:.3f}",
                        "tr_acc": f"{train_acc:.1f}%",
                        "val_loss": f"{val_loss:.3f}",
                        "val_acc": f"{val_acc:.1f}%",
                    },
                    refresh=True,
                )

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    # state_dict() tensors alias the live parameters, so a
                    # deep copy is needed to keep a real snapshot.
                    best_model_state = copy.deepcopy(self.model.state_dict())
                    patience_counter = 0
                else:
                    patience_counter += 1

                if patience_counter >= patience:
                    logging.info(f"Early stopping at epoch {epoch}")
                    break

            else:
                logging.info("Training completed!")
            # Final evaluation on the best weights seen during training
            if best_model_state is not None:
                self.model.load_state_dict(best_model_state)

            test_loss, test_acc = self.evaluate_model(self.model, test_loader)
            self.metrics.log_metrics("test", test_loss, test_acc, epoch)
            logging.info(
                f"Final Test Results - Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%"
            )

        finally:
            self.metrics.close()

## Federated Class

This is the core of the project.
After setting a baseline with the `CentralizedTrainer` in the cell above, the following class creates the client data shards, trains the client models, aggregates them into a global model and evaluates it, under different settings:

### Dataset modularity:
* iid distribution of the classes for the clients
* non_iid distribution
* custom number of classes per clients

### Clients Participation
* Uniform selection of the clients
* Dirichlet distribution to select the clients

### Others
This class also saves a checkpoint after every round in which the global validation loss improves, so that training can be resumed later. This is useful, for example, in the following cases:
1. Federated learning takes a very long time; in a real scenario, if too few clients are available to participate, training can be paused and resumed once availability increases.
2. Academic reasons:
   1. The code takes about 50 hours to complete: pause it if you need your PC for something else, then resume it when the machine is free to compute further rounds.
   2. Google Colab limits the amount of resources per day: when the GPU session ends, wait until the resources become available again and resume from the latest best model found.

In [9]:
class FederatedTrainer:
    """Simulates federated averaging over a set of clients.

    Every round a subset of clients is sampled, each selected client trains a
    copy of the global model on its own data shard, and the client models are
    averaged (weighted by shard size) into the global model. The global model
    is checkpointed whenever the validation loss improves so that training
    can be resumed later.
    """

    def __init__(
        self,
        model: LeNet,
        train_dataset: Dataset[CIFAR100],
        val_loader: DataLoader[CIFAR100],
        test_loader: DataLoader[CIFAR100],
        config: FederatedConfig,
        experiment_name: Optional[str] = None,
    ) -> None:
        """Shard the data, set up client sampling, metrics and checkpoint paths.

        Args:
            model: Global model; moved to ``config.DEVICE``.
            train_dataset: Data split across the clients.
            val_loader: Used for model selection after every round.
            test_loader: Used for the final evaluation.
            config: Federated hyper-parameters.
            experiment_name: Optional label for the TensorBoard run.

        Raises:
            ValueError: If ``PARTICIPATION_MODE`` is ``"skewed"`` without a
                ``DIRICHLET_ALPHA``, or if the data cannot be sharded.
        """
        self.config = config
        self.global_model = model.to(config.DEVICE)
        self.device = config.DEVICE
        # autocast / GradScaler want the bare device type ("cuda"), not "cuda:0".
        self.device_type = config.DEVICE.type
        self.val_loader = val_loader
        self.test_loader = test_loader

        # Pre-create all client shards and dataloaders
        self.num_clients = config.NUM_CLIENTS
        self.client_loaders = self._setup_client_data(train_dataset)

        # Setup client selection
        if config.PARTICIPATION_MODE == "skewed":
            if config.DIRICHLET_ALPHA is None:
                raise ValueError("dirichlet_alpha required for skewed mode")
            self.selection_probs = np.random.dirichlet(
                [config.DIRICHLET_ALPHA] * config.NUM_CLIENTS
            )
        else:
            self.selection_probs = np.ones(config.NUM_CLIENTS) / config.NUM_CLIENTS

        self.metrics = MetricsManager(
            config, model.__class__.__name__.lower(), "federated", experiment_name
        )

        # Pre-allocate client models
        self.client_models = [copy.deepcopy(model) for _ in range(config.NUM_CLIENTS)]
        self.checkpoint_dir = config.MODELS_DIR / "federated"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
        self.checkpoint_name = (
            f"{'iid' if config.CLASSES_PER_CLIENT is None else f'noniid_{config.CLASSES_PER_CLIENT}cls'}"
            f"_{config.PARTICIPATION_MODE}"
            f"_C{config.NUM_CLIENTS}_P{config.PARTICIPATION_RATE}_E{config.LOCAL_EPOCHS}.pt"
        )
        self.checkpoint_path = self.checkpoint_dir / self.checkpoint_name

    def _setup_client_data(
        self, dataset: Dataset[CIFAR100]
    ) -> List[DataLoader[CIFAR100]]:
        """Split ``dataset`` into one shard per client and wrap each in a loader."""
        shards: List[Subset[CIFAR100]] = (
            self._create_iid_shards(dataset)
            if self.config.CLASSES_PER_CLIENT is None
            else self._create_noniid_shards(dataset)
        )

        num_workers = self.config.NUM_WORKERS
        # persistent_workers / prefetch_factor are rejected by DataLoader
        # when num_workers == 0, so only pass them with worker processes.
        worker_kwargs = (
            {"persistent_workers": True, "prefetch_factor": 3}
            if num_workers > 0
            else {}
        )

        return [
            DataLoader(
                shard,
                batch_size=self.config.BATCH_SIZE,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=True,
                # Dropping the last batch of a shard smaller than one batch
                # would leave the client with nothing to train on.
                drop_last=len(shard) >= self.config.BATCH_SIZE,
                **worker_kwargs,
            )
            for shard in shards
        ]

    def _create_iid_shards(self, dataset: Dataset[CIFAR100]) -> List[Subset[CIFAR100]]:
        """Create IID data shards.

        The sample indices are shuffled and split into exactly
        ``num_clients`` disjoint, near-equal shards that cover the dataset.

        Raises:
            ValueError: If the dataset is empty or has fewer samples than clients.
        """
        num_samples = len(dataset)
        if num_samples == 0:
            raise ValueError("Empty dataset")
        if num_samples < self.num_clients:
            raise ValueError(
                f"Cannot split {num_samples} samples across {self.num_clients} clients"
            )

        shuffled = np.random.permutation(num_samples)
        return [
            Subset(dataset, chunk.tolist())
            for chunk in np.array_split(shuffled, self.num_clients)
        ]

    @staticmethod
    def _dataset_targets(dataset: Dataset[CIFAR100]) -> npt.NDArray:
        """Return the labels of ``dataset`` in dataset order.

        ``Subset`` wrappers (e.g. from ``random_split``) are unwrapped
        recursively, since they do not expose ``targets`` themselves.

        Raises:
            ValueError: If the underlying dataset has no ``targets`` attribute.
        """
        if isinstance(dataset, Subset):
            parent_targets = FederatedTrainer._dataset_targets(dataset.dataset)
            return parent_targets[np.asarray(dataset.indices, dtype=np.int64)]
        if not hasattr(dataset, "targets"):
            raise ValueError(
                "Dataset must have 'targets' attribute for non-IID sharding"
            )
        return np.asarray(dataset.targets)

    def _create_noniid_shards(
        self, dataset: Dataset[CIFAR100]
    ) -> List[Subset[CIFAR100]]:
        """Create non-IID data shards using class distribution.

        Samples are sorted by label and cut into
        ``num_clients * CLASSES_PER_CLIENT`` contiguous shards; each client
        receives ``CLASSES_PER_CLIENT`` randomly chosen shards. Clients get
        disjoint, near-equal amounts of data drawn from roughly
        ``CLASSES_PER_CLIENT`` classes (a shard may straddle a class boundary
        when class sizes are not an exact multiple of the shard size).

        Raises:
            ValueError: If ``CLASSES_PER_CLIENT`` is not a positive integer,
                the dataset exposes no labels, or there are fewer samples
                than shards.
        """
        classes_per_client = self.config.CLASSES_PER_CLIENT
        if classes_per_client is None or classes_per_client < 1:
            raise ValueError(
                "CLASSES_PER_CLIENT must be a positive integer for non-IID sharding"
            )

        targets = self._dataset_targets(dataset)
        num_shards = self.num_clients * classes_per_client
        if len(targets) < num_shards:
            raise ValueError(
                f"Cannot split {len(targets)} samples into {num_shards} shards"
            )

        # Positions in `dataset`, grouped by label (stable keeps dataset order
        # inside each class).
        label_sorted = np.argsort(targets, kind="stable")
        label_shards = np.array_split(label_sorted, num_shards)
        shard_order = np.random.permutation(num_shards)

        client_subsets = []
        for client in range(self.num_clients):
            owned = shard_order[
                client * classes_per_client : (client + 1) * classes_per_client
            ]
            client_indices = np.concatenate([label_shards[s] for s in owned])
            client_subsets.append(Subset(dataset, client_indices.tolist()))

        return client_subsets

    def _evaluate(self, loader: DataLoader[CIFAR100]) -> Tuple[float, float]:
        """Evaluate the global model on the given data loader.

        Returns:
            Tuple of (average loss per sample, accuracy in percent).
        """
        self.global_model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        criterion = nn.CrossEntropyLoss()

        with (
            torch.no_grad(),
            torch.amp.autocast_mode.autocast(device_type=self.device_type),
        ):
            for inputs, targets in loader:
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                outputs = self.global_model(inputs)
                loss = criterion(outputs, targets)

                total_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        return total_loss / total, 100.0 * correct / total

    def _aggregate_models(self, selected_clients: List[int] | npt.NDArray) -> None:
        """Aggregate models using weighted average based on dataset sizes.

        The global model's tensors are overwritten in place with the
        size-weighted mean of the selected clients' tensors.

        Raises:
            ValueError: If ``selected_clients`` is empty or the selected
                clients hold no samples (averaging would zero the model).
        """
        client_sizes = [
            len(self.client_loaders[idx].dataset) for idx in selected_clients
        ]
        total_samples = sum(client_sizes)
        if total_samples == 0:
            raise ValueError("Cannot aggregate: no samples among selected clients")

        # Fetch each client's state dict once, paired with its weight.
        weighted_states = [
            (self.client_models[idx].state_dict(), size / total_samples)
            for idx, size in zip(selected_clients, client_sizes)
        ]

        with torch.no_grad():
            for key, global_tensor in self.global_model.state_dict().items():
                if not global_tensor.is_floating_point():
                    # Integer buffers (e.g. counters) cannot be averaged;
                    # take the first client's value.
                    global_tensor.copy_(weighted_states[0][0][key])
                    continue

                weighted_sum = torch.zeros_like(global_tensor)
                for state, weight in weighted_states:
                    weighted_sum.add_(state[key].to(self.device), alpha=weight)

                # state_dict() tensors share storage with the parameters, so
                # this in-place copy updates the global model.
                global_tensor.copy_(weighted_sum)

    def save_checkpoint(self, round_idx: int, best_val_loss: float) -> None:
        """Save model checkpoint with training state.

        Args:
            round_idx: Value stored under ``"round"``; ``load_checkpoint``
                returns it as the round to resume from.
            best_val_loss: Best validation loss reached so far.
        """
        checkpoint = {
            "round": round_idx,
            "model_state_dict": self.global_model.state_dict(),
            "best_val_loss": best_val_loss,
            "config": self.config,
        }
        torch.save(checkpoint, self.checkpoint_path)
        logging.info(f"Checkpoint saved: {self.checkpoint_path}")

    def load_checkpoint(self) -> Tuple[int, float]:
        """Load model checkpoint and return training state.

        Returns:
            ``(round to resume from, best validation loss)``; ``(0, inf)``
            when no usable checkpoint exists.
        """
        if not self.checkpoint_path.exists():
            return 0, float("inf")

        # map_location lets a checkpoint written on a GPU machine be resumed
        # on a different device.
        # SECURITY: weights_only=False unpickles arbitrary objects (needed
        # because the checkpoint stores the config dataclass). Only load
        # checkpoint files produced by this notebook.
        checkpoint = torch.load(
            self.checkpoint_path, map_location=self.device, weights_only=False
        )

        # Verify config matches
        saved_config = checkpoint["config"]
        if (
            saved_config.NUM_CLIENTS != self.config.NUM_CLIENTS
            or saved_config.CLASSES_PER_CLIENT != self.config.CLASSES_PER_CLIENT
            or saved_config.PARTICIPATION_MODE != self.config.PARTICIPATION_MODE
        ):
            logging.warning("Config mismatch in checkpoint, starting fresh training")
            return 0, float("inf")

        self.global_model.load_state_dict(checkpoint["model_state_dict"])
        logging.info(f"Resumed from checkpoint: {self.checkpoint_path}")
        return checkpoint["round"], checkpoint["best_val_loss"]

    def train_client(self, client_idx: int, model: LeNet) -> None:
        """Train a single client in-place.

        Runs ``LOCAL_EPOCHS`` epochs of Nesterov SGD on the client's shard
        with mixed precision; ``model`` is modified in place.
        """
        model.train()
        optimizer = torch.optim.SGD(
            model.parameters(),
            lr=self.config.LEARNING_RATE,
            momentum=self.config.MOMENTUM,
            weight_decay=self.config.WEIGHT_DECAY,
            nesterov=True,
        )
        criterion = nn.CrossEntropyLoss()
        scaler = torch.amp.grad_scaler.GradScaler(device=self.device_type)

        for _ in range(self.config.LOCAL_EPOCHS):
            for inputs, targets in self.client_loaders[client_idx]:
                inputs = inputs.to(self.device, non_blocking=True)
                targets = targets.to(self.device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)

                with torch.amp.autocast_mode.autocast(device_type=self.device_type):
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)

                # Scale the loss so low-precision gradients do not underflow.
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

    def train(self) -> None:
        """Run the federated rounds, then evaluate the best model on the test set.

        Resumes from the checkpoint when one matching the configuration
        exists. The metrics writer is closed when this method returns or raises.
        """
        # Load existing checkpoint if available
        start_round, best_val_loss = self.load_checkpoint()
        # state_dict() tensors alias the live parameters; deep-copy to keep a
        # real snapshot of the best weights.
        best_model_state = (
            copy.deepcopy(self.global_model.state_dict()) if start_round > 0 else None
        )

        best_val_acc = 0.0

        if start_round > 0:
            logging.info(f"Resuming training from round {start_round}")

        round_pbar = tqdm(
            range(start_round, self.config.NUM_ROUNDS),
            desc="Training",
            unit="round",
            # The bar starts at start_round, so its total is the absolute
            # number of rounds rather than the remaining ones.
            total=self.config.NUM_ROUNDS,
            initial=start_round,
            colour="green",
        )

        try:
            for round_idx in round_pbar:
                # Select clients
                num_selected = max(
                    1, int(self.config.PARTICIPATION_RATE * self.config.NUM_CLIENTS)
                )
                selected_clients = np.random.choice(
                    self.config.NUM_CLIENTS,
                    size=num_selected,
                    replace=False,
                    p=self.selection_probs,
                )

                # Train the selected clients one after another, each starting
                # from the current global weights.
                for idx in selected_clients:
                    self.client_models[idx].load_state_dict(
                        self.global_model.state_dict()
                    )
                    self.train_client(idx, self.client_models[idx])

                # Aggregate models
                self._aggregate_models(selected_clients)

                # Evaluate
                val_loss, val_acc = self._evaluate(self.val_loader)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_acc = val_acc
                    best_model_state = copy.deepcopy(self.global_model.state_dict())
                    # Store the next round to run, so a resume does not repeat
                    # the round that produced this checkpoint.
                    self.save_checkpoint(round_idx + 1, best_val_loss)

                round_pbar.set_postfix(
                    {
                        "val_loss": f"{val_loss:.4f}",
                        "val_acc": f"{val_acc:.2f}%",
                        "best_val_loss": f"{best_val_loss:.4f}",
                        "best_val_acc": f"{best_val_acc:.2f}%",
                    },
                    refresh=True,
                )

                self.metrics.log_fl_metrics(
                    round_idx,
                    {"val_loss": val_loss, "val_accuracy": val_acc},
                    {"num_selected": len(selected_clients)},
                )

            # Final evaluation on the best weights seen
            if best_model_state:
                self.global_model.load_state_dict(best_model_state)
            test_loss, test_acc = self._evaluate(self.test_loader)
            self.metrics.log_metrics(
                "test", test_loss, test_acc, self.config.NUM_ROUNDS
            )

            self.save_checkpoint(self.config.NUM_ROUNDS, best_val_loss)

        finally:
            self.metrics.close()

In [10]:
# Download CIFAR-100 if needed and build the train/validation/test loaders
# (default validation_split of 0.1).
data = Cifar100DatasetManager(config)

Files already downloaded and verified


## Centralized - Baseline

The model is able to achieve the following result:
* Test Loss: 1.9331
* Test Accuracy: 49.29%

In [11]:
# model = LeNet(config)
# trainer = CentralizedTrainer(model, config)
# trainer.train(data.train_loader, data.val_loader, data.test_loader)


## FL - Baseline

In [None]:
# IID baseline (K=100, C=0.1, J=4)
iid_config = FederatedConfig(
    NUM_CLIENTS=100,
    PARTICIPATION_RATE=0.1,
    LOCAL_EPOCHS=4,
    NUM_ROUNDS=2000,
    CLASSES_PER_CLIENT=None,  # IID
)

# Non-IID with Nc=1
noniid_config = dataclasses.replace(
    iid_config,
    CLASSES_PER_CLIENT=1,  # Most heterogeneous case
)

# Run experiments. The loop variables deliberately avoid the name `config`,
# which is the module-level BaseConfig instance used by earlier cells;
# rebinding it here would change what those cells see when re-run.
for fl_config, fl_experiment_name in [
    (iid_config, "iid"),
    (noniid_config, "noniid_1class"),
]:
    model = LeNet(fl_config)
    trainer = FederatedTrainer(
        model=model,
        train_dataset=data.train_dataset,
        val_loader=data.val_loader,
        test_loader=data.test_loader,
        config=fl_config,
        experiment_name=fl_experiment_name,
    )
    trainer.train()

Training:   0%|[32m          [0m| 1/2000 [00:44<24:53:24, 44.82s/round, val_loss=4.6015, val_acc=1.90%, best_val_loss=4.6015, best_val_acc=1.90%]

[K2025-01-18 16:59:26 [[1;32mINFO[0m] Checkpoint saved: c:\Users\nick\Documents\GitHub\AdvanceML_project5\models\federated\iid_uniform_C100_P0.1_E4.pt


Training:   0%|[32m          [0m| 2/2000 [01:24<23:10:57, 41.77s/round, val_loss=4.5917, val_acc=2.32%, best_val_loss=4.5917, best_val_acc=2.32%]

[K2025-01-18 17:00:05 [[1;32mINFO[0m] Checkpoint saved: c:\Users\nick\Documents\GitHub\AdvanceML_project5\models\federated\iid_uniform_C100_P0.1_E4.pt


Training:   0%|[32m          [0m| 3/2000 [02:07<23:29:16, 42.34s/round, val_loss=4.5596, val_acc=1.84%, best_val_loss=4.5596, best_val_acc=1.84%]

[K2025-01-18 17:00:48 [[1;32mINFO[0m] Checkpoint saved: c:\Users\nick\Documents\GitHub\AdvanceML_project5\models\federated\iid_uniform_C100_P0.1_E4.pt
