
# Federated Training of Multiple CNN Architectures on Common Vision Datasets

This notebook demonstrates how to train three popular convolutional neural networks—ResNet-18, MobileNetV2, and AlexNet—on three standard datasets (CIFAR-10, CIFAR-100, and MNIST) using a simple federated learning setup powered by [Flower](https://flower.dev/). Each scenario uses a simulated set of federated clients so you can experiment locally without standing up separate devices.



## Environment Setup

Uncomment the `pip install` line in the following cell to install the required dependencies if they are not already available in your environment.


In [None]:

# Uncomment to install dependencies when running in a fresh environment
# !pip install torch torchvision "flwr[simulation]" matplotlib



## Imports and Utilities

We define helper functions for creating datasets, partitioning them across clients, building models with the right output dimensions, and running basic training/evaluation steps.


In [None]:

import math
from dataclasses import dataclass
from typing import Callable, Iterable, List, Tuple

import flwr as fl
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset
import torchvision
import torchvision.transforms as T

# Device shared by the training, evaluation and client code in this notebook:
# the current CUDA device when PyTorch reports one is available, else the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@dataclass
class DatasetConfig:
    """Train/test splits of one dataset plus the metadata needed to build a model for it."""

    name: str  # lower-cased dataset key, e.g. "cifar10"
    num_classes: int  # number of target classes; sizes the model's final linear layer
    input_size: Tuple[int, int, int]  # (channels, height, width) of one transformed sample
    train_set: torch.utils.data.Dataset  # training split, wrapped with the training transforms
    test_set: torch.utils.data.Dataset  # test split, wrapped with the evaluation transforms


def get_transforms(dataset: str) -> Tuple[T.Compose, T.Compose]:
    """Build the ``(train, test)`` preprocessing pipelines for a dataset key.

    Every image is resized to 32x32, converted to a tensor and normalized with
    per-channel dataset statistics. MNIST digits are grayscale, so they are
    replicated to three channels to match the 3-channel first convolution of
    the torchvision models used in this notebook. They also skip horizontal
    flipping, which would turn many digits into non-digits.

    Args:
        dataset: Lower-case dataset key: "cifar10", "cifar100" or "mnist".

    Returns:
        ``(train_transform, test_transform)``. The training pipeline adds
        random-crop augmentation, plus horizontal flips for the CIFAR datasets.

    Raises:
        KeyError: if ``dataset`` is not one of the keys above.
    """
    mean, std = {
        "cifar10": ([0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
        "cifar100": ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761]),
        "mnist": ([0.1307], [0.3081]),
    }[dataset]

    is_grayscale = dataset == "mnist"
    if is_grayscale:
        # Copy the single channel three times (and its statistics with it) so the
        # tensors fit models whose first conv layer expects 3 input channels.
        channel_steps = [T.Grayscale(num_output_channels=3)]
        mean, std = mean * 3, std * 3
    else:
        channel_steps = []

    train_steps = [*channel_steps, T.Resize(32), T.RandomCrop(32, padding=4)]
    if not is_grayscale:
        # Mirroring preserves the label for natural images, but not for digits.
        train_steps.append(T.RandomHorizontalFlip())
    train_steps += [T.ToTensor(), T.Normalize(mean, std)]

    test_steps = [*channel_steps, T.Resize(32), T.ToTensor(), T.Normalize(mean, std)]
    return T.Compose(train_steps), T.Compose(test_steps)


def load_dataset(dataset: str, root: str = "./data") -> DatasetConfig:
    """Download (if needed) and load a supported dataset into a ``DatasetConfig``.

    Args:
        dataset: Case-insensitive dataset key: "cifar10", "cifar100" or "mnist".
        root: Directory torchvision downloads the files to and reads them from.

    Returns:
        A ``DatasetConfig`` whose splits use the pipelines from ``get_transforms``.

    Raises:
        ValueError: if ``dataset`` is not one of the supported keys.
    """
    dataset = dataset.lower()
    # Validate before calling get_transforms: its table lookup would otherwise
    # raise a bare KeyError and this ValueError would never be reached.
    if dataset == "cifar10":
        dataset_cls, num_classes = torchvision.datasets.CIFAR10, 10
    elif dataset == "cifar100":
        dataset_cls, num_classes = torchvision.datasets.CIFAR100, 100
    elif dataset == "mnist":
        dataset_cls, num_classes = torchvision.datasets.MNIST, 10
    else:
        raise ValueError(f"Unsupported dataset: {dataset}")

    train_tf, test_tf = get_transforms(dataset)
    train_set = dataset_cls(root=root, train=True, download=True, transform=train_tf)
    test_set = dataset_cls(root=root, train=False, download=True, transform=test_tf)

    # Read (C, H, W) from an actual transformed test sample (deterministic
    # pipeline) so the metadata always matches what the transforms produce.
    input_size = tuple(test_set[0][0].shape)
    return DatasetConfig(dataset, num_classes, input_size, train_set, test_set)


def partition_dataset(dataset: torch.utils.data.Dataset, num_clients: int) -> List[Subset]:
    """Split ``dataset`` into up to ``num_clients`` contiguous, near-equal shards.

    Shard sizes differ by at most one sample, and every sample lands in exactly
    one shard. If the dataset has fewer samples than ``num_clients``, empty
    shards are omitted, so only ``len(dataset)`` single-sample shards are returned.
    Shards are contiguous index ranges with no shuffling, so the label mix per
    client depends on the dataset's stored order.

    Args:
        dataset: Any sized dataset.
        num_clients: Number of shards requested; must be at least 1.

    Returns:
        A list of ``Subset`` views over ``dataset``.

    Raises:
        ValueError: if ``num_clients`` is less than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    # The first `extra` shards get one additional sample so sizes stay balanced.
    base, extra = divmod(len(dataset), num_clients)
    subsets: List[Subset] = []
    start = 0
    for client in range(num_clients):
        size = base + (1 if client < extra else 0)
        if size == 0:
            break  # sizes are non-increasing, so every remaining shard is empty too
        subsets.append(Subset(dataset, list(range(start, start + size))))
        start += size
    return subsets


def build_model(name: str, num_classes: int) -> nn.Module:
    """Create an untrained torchvision classifier with a ``num_classes``-way head.

    Args:
        name: Case-insensitive model key: "resnet18", "mobilenet_v2" or "alexnet".
        num_classes: Output size of the replacement final linear layer.

    Returns:
        The model with randomly initialised weights (``weights=None``).

    Raises:
        ValueError: if ``name`` is not a supported model.
    """
    name = name.lower()
    if name == "resnet18":
        model = torchvision.models.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif name == "mobilenet_v2":
        model = torchvision.models.mobilenet_v2(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif name == "alexnet":
        model = torchvision.models.alexnet(weights=None)
        # The stock 11x11/stride-4 stem shrinks a 32x32 image so far that the last
        # MaxPool2d receives a 1x1 map and fails with "output size is too small".
        # A 3x3/stride-2 stem keeps the maps valid (32 -> 16 -> 7 -> 3 -> 1 before
        # the adaptive average pool); larger inputs still work because the
        # classifier sits behind that adaptive pool.
        stem = model.features[0]
        model.features[0] = nn.Conv2d(
            stem.in_channels, stem.out_channels, kernel_size=3, stride=2, padding=1
        )
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model: {name}")
    return model


def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module, optimizer: optim.Optimizer) -> float:
    """Run one optimisation pass over ``loader`` and return the mean training loss.

    ``criterion`` returns a per-batch loss (``nn.CrossEntropyLoss`` in this
    notebook), which is rescaled by the batch size so the result is averaged
    over ``len(loader.dataset)`` samples rather than over batches.
    """
    model.train()
    epoch_loss = 0.0
    for inputs, targets in loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE)
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()
        epoch_loss += batch_loss.item() * inputs.size(0)
    return epoch_loss / len(loader.dataset)


def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> Tuple[float, float]:
    """Return ``(mean_loss, accuracy)`` of ``model`` over every sample in ``loader``.

    Runs in eval mode with gradients disabled. Per-batch losses are rescaled by
    the batch size, and correct top-1 predictions are counted, before dividing
    both totals by ``len(loader.dataset)``.
    """
    model.eval()
    total_loss, num_correct = 0.0, 0
    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(DEVICE)
            batch_y = batch_y.to(DEVICE)
            logits = model(batch_x)
            total_loss += criterion(logits, batch_y).item() * batch_x.size(0)
            num_correct += (logits.argmax(dim=1) == batch_y).sum().item()
    n_samples = len(loader.dataset)
    return total_loss / n_samples, num_correct / n_samples



## Federated Client Definition

Each client trains locally on its partition of the dataset and reports weights and metrics back to the server.


In [None]:

class FederatedClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates a local model on one data shard.

    Args:
        model_fn: Zero-argument factory returning a fresh model; called once here.
        train_subset: Samples this client trains on in ``fit``.
        test_subset: Samples this client evaluates on in ``evaluate``.
        epochs: Local training epochs per ``fit`` call.
        batch_size: Batch size for both data loaders.
        num_workers: DataLoader worker processes for both loaders (default 2).
    """

    def __init__(
        self,
        model_fn: Callable[[], nn.Module],
        train_subset: Subset,
        test_subset: Subset,
        epochs: int,
        batch_size: int,
        num_workers: int = 2,
    ):
        self.model_fn = model_fn
        self.model = self.model_fn().to(DEVICE)
        self.train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
        self.test_loader = DataLoader(test_subset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
        self.epochs = epochs
        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.SGD(self.model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)

    def get_parameters(self, config=None):
        """Return every ``state_dict`` entry (weights and buffers) as NumPy arrays, in order."""
        return [val.cpu().numpy() for val in self.model.state_dict().values()]

    def set_parameters(self, parameters: Iterable):
        """Load arrays ordered like ``get_parameters`` output into the local model.

        Raises:
            ValueError: if the number of arrays differs from the number of
                ``state_dict`` entries. Previously ``zip`` silently dropped the extras.
        """
        state_dict = self.model.state_dict()
        arrays = list(parameters)
        if len(arrays) != len(state_dict):
            raise ValueError(
                f"Expected {len(state_dict)} parameter arrays, got {len(arrays)}"
            )
        for key, val in zip(state_dict.keys(), arrays):
            state_dict[key] = torch.tensor(val)
        self.model.load_state_dict(state_dict)

    def fit(self, parameters, config=None):
        """Train locally from the given global parameters; return ``(params, n_train, metrics)``."""
        self.set_parameters(parameters)
        for _ in range(self.epochs):
            train_one_epoch(self.model, self.train_loader, self.criterion, self.optimizer)
        return self.get_parameters(), len(self.train_loader.dataset), {}

    def evaluate(self, parameters, config=None):
        """Evaluate the given parameters locally; return ``(loss, n_test, {"accuracy": ...})``."""
        self.set_parameters(parameters)
        # A method's own name is not in scope inside its body, so this call
        # resolves to the module-level `evaluate` helper, not to recursion.
        loss, accuracy = evaluate(self.model, self.test_loader, self.criterion)
        return float(loss), len(self.test_loader.dataset), {"accuracy": float(accuracy)}



## Federated Training Helper

The `run_federated_experiment` function loads a dataset, partitions it among simulated clients, and launches a Flower simulation using FedAvg. Supported dataset names are `cifar10`, `cifar100`, and `mnist`; supported model names are `resnet18`, `mobilenet_v2`, and `alexnet`.


In [None]:

def run_federated_experiment(
    dataset_name: str,
    model_name: str,
    num_clients: int = 5,
    rounds: int = 3,
    local_epochs: int = 1,
    batch_size: int = 64,
):
    """Simulate FedAvg training of ``model_name`` on ``dataset_name`` with Flower.

    Args:
        dataset_name: Dataset key accepted by ``load_dataset``.
        model_name: Model key accepted by ``build_model``.
        num_clients: Requested number of simulated clients (train/test shards).
        rounds: Number of federated rounds.
        local_epochs: Local training epochs per client per round.
        batch_size: Client-side batch size.

    Returns:
        The history object from ``fl.simulation.start_simulation``. The
        example-weighted client accuracy per round is recorded under
        ``metrics_distributed["accuracy"]`` as ``(round, accuracy)`` pairs.
    """
    cfg = load_dataset(dataset_name)
    train_subsets = partition_dataset(cfg.train_set, num_clients)
    test_subsets = partition_dataset(cfg.test_set, num_clients)
    # Each simulated client owns one train shard and evaluates on
    # test_subsets[idx % len(test_subsets)], so both fit and evaluate rounds
    # involve exactly this many clients.
    num_sim_clients = len(train_subsets)

    def client_fn(cid: str):
        idx = int(cid)
        return FederatedClient(
            model_fn=lambda: build_model(model_name, cfg.num_classes),
            train_subset=train_subsets[idx],
            test_subset=test_subsets[idx % len(test_subsets)],
            epochs=local_epochs,
            batch_size=batch_size,
        )

    def weighted_accuracy(metrics):
        """Average client accuracies, weighting each by its number of test examples.

        Flower passes ``metrics`` as ``[(num_examples, metrics_dict), ...]`` and
        expects a metrics dict back. The previous lambda unpacked 3-tuples and
        returned a bare float, so the first evaluation round crashed.
        """
        total = sum(num_examples for num_examples, _ in metrics)
        if total == 0:
            return {}
        weighted = sum(num_examples * m["accuracy"] for num_examples, m in metrics)
        return {"accuracy": weighted / total}

    strategy = fl.server.strategy.FedAvg(
        fraction_fit=1.0,
        fraction_evaluate=1.0,
        min_fit_clients=num_sim_clients,
        # Never require more evaluators than there are simulated clients.
        min_evaluate_clients=num_sim_clients,
        min_available_clients=num_sim_clients,
        evaluate_metrics_aggregation_fn=weighted_accuracy,
    )

    print(f"Starting federated run: dataset={dataset_name}, model={model_name}, clients={num_sim_clients}")
    history = fl.simulation.start_simulation(
        client_fn=client_fn,
        num_clients=num_sim_clients,
        config=fl.server.ServerConfig(num_rounds=rounds),
        strategy=strategy,
    )
    return history



## Running Experiments

The following examples configure short federated runs to keep execution time reasonable. Increase `rounds` or `local_epochs` for deeper training when resources allow.


In [None]:

# Example: ResNet-18 on CIFAR-10
# history_resnet_cifar10 = run_federated_experiment("cifar10", "resnet18", num_clients=5, rounds=3, local_epochs=1)

# Example: MobileNetV2 on CIFAR-100
# history_mobilenet_cifar100 = run_federated_experiment("cifar100", "mobilenet_v2", num_clients=5, rounds=3, local_epochs=1)

# Example: AlexNet on MNIST
# history_alexnet_mnist = run_federated_experiment("mnist", "alexnet", num_clients=3, rounds=3, local_epochs=1)



## Visualizing Accuracy Trends

After running one or more experiments, pass each returned history to `plot_history` to plot its global accuracy (the example-weighted average of the clients' evaluation accuracies) across rounds.


In [None]:

import matplotlib.pyplot as plt


def plot_history(history, label: str, *, show: bool = True):
    """Plot the per-round global accuracy recorded in a Flower simulation history.

    Args:
        history: Object returned by ``run_federated_experiment``. Its
            ``metrics_distributed["accuracy"]`` entries are read as
            ``(round, accuracy)`` pairs.
        label: Legend label for this run's curve.
        show: Call ``plt.show()`` after drawing. Pass ``False`` to overlay several
            runs on one figure, then call ``plt.show()`` yourself.

    Raises:
        ValueError: if the history holds no distributed "accuracy" metric.
    """
    entries = history.metrics_distributed.get("accuracy", [])
    if not entries:
        raise ValueError("history has no distributed 'accuracy' metric to plot")
    # Use the recorded round numbers instead of assuming they start at 1.
    # The old `len(...)["accuracy"]` subscripted an int and always raised TypeError.
    rounds = [rnd for rnd, _ in entries]
    accuracies = [acc for _, acc in entries]
    plt.plot(rounds, accuracies, marker="o", label=label)
    plt.xlabel("Round")
    plt.ylabel("Global accuracy")
    plt.title("Federated Accuracy over Rounds")
    plt.legend()
    plt.grid(True)
    if show:
        plt.show()

# Example usage after running experiments:
# plot_history(history_resnet_cifar10, "ResNet18 CIFAR-10")
