
# Federated Training of CNNs on MNIST, CIFAR-10, and CIFAR-100 (No Flower)

This notebook demonstrates a simple, standalone federated learning simulation **without** Flower. It can train ResNet-18, MobileNetV2, or AlexNet on MNIST, CIFAR-10, or CIFAR-100 using a basic FedAvg loop. After every round it saves each client's model update (its locally trained weights minus the global weights it started from), so you can inspect or reuse the updates later.


In [None]:
import copy
import os
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Tuple

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import datasets, models, transforms

# Train on the GPU whenever one is visible to PyTorch; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {DEVICE}")

# Per-client, per-round weight deltas are written here (relative to the working directory).
BASE_DIR = Path("model_updates")
os.makedirs(BASE_DIR, exist_ok=True)



## Dataset utilities

The helpers below download the datasets, apply standard per-channel normalization, and randomly shuffle the training data before splitting it across a configurable number of clients (an IID partition).


In [None]:
def get_transforms(dataset: str):
    """Build the preprocessing pipeline for ``dataset``.

    Every image is converted to a tensor and normalized with the per-channel
    mean/std of the matching training set. MNIST digits are replicated to three
    channels so that the RGB-only torchvision backbones accept them.

    Args:
        dataset: ``"mnist"``, ``"cifar10"`` or ``"cifar100"`` (case-insensitive).
            Any other name falls back to the CIFAR-10 pipeline.

    Returns:
        A ``transforms.Compose`` pipeline.
    """
    dataset = dataset.lower()
    if dataset == "mnist":
        # Replicate the grayscale channel so models expecting RGB inputs still work,
        # reusing the single-channel MNIST statistics for all three channels.
        mean, std = (0.1307,) * 3, (0.3081,) * 3
        tfms = [transforms.Resize(28), transforms.Grayscale(num_output_channels=3)]
    else:
        if dataset == "cifar100":
            # CIFAR-100 has its own channel statistics; CIFAR-10's are noticeably off for it.
            mean, std = (0.5071, 0.4865, 0.4409), (0.2673, 0.2564, 0.2762)
        else:
            mean, std = (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
        tfms = [transforms.Resize(32)]
    tfms.extend([transforms.ToTensor(), transforms.Normalize(mean, std)])
    return transforms.Compose(tfms)


def load_datasets(dataset: str, data_dir: str = "./data") -> Tuple[Dataset, Dataset]:
    """Return the ``(train, test)`` splits of ``dataset``, downloading them if needed.

    Both splits use the preprocessing returned by ``get_transforms``.

    Raises:
        ValueError: if ``dataset`` is not one of mnist / cifar10 / cifar100.
    """
    name = dataset.lower()
    tfm = get_transforms(name)
    factory = {
        "mnist": datasets.MNIST,
        "cifar10": datasets.CIFAR10,
        "cifar100": datasets.CIFAR100,
    }.get(name)
    if factory is None:
        raise ValueError(f"Unsupported dataset: {name}")
    train_split = factory(data_dir, train=True, download=True, transform=tfm)
    test_split = factory(data_dir, train=False, download=True, transform=tfm)
    return train_split, test_split


def partition_dataset(
    dataset: Dataset,
    num_clients: int,
    generator: Optional[torch.Generator] = None,
) -> List[Subset]:
    """Randomly split ``dataset`` into exactly ``num_clients`` disjoint IID shards.

    Shard sizes differ by at most one sample. ``torch.tensor_split`` is used
    rather than ``torch.chunk``: ``chunk`` can return fewer pieces than asked
    for (5 samples into 4 chunks yields only 3), silently dropping clients.

    Args:
        dataset: Any map-style dataset that supports ``len``.
        num_clients: Number of shards; must be between 1 and ``len(dataset)``
            so that no client ends up with an empty shard.
        generator: Optional RNG for a reproducible permutation; ``None`` uses
            the global torch RNG.

    Returns:
        One ``Subset`` per client.

    Raises:
        ValueError: if ``num_clients`` is outside ``[1, len(dataset)]``.
    """
    n = len(dataset)
    if not 1 <= num_clients <= n:
        raise ValueError(f"num_clients must be in [1, {n}], got {num_clients}")
    indices = torch.randperm(n, generator=generator)
    shards = torch.tensor_split(indices, num_clients)
    return [Subset(dataset, shard.tolist()) for shard in shards]


## Model builders

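We reuse the torchvision reference implementations. Each is built from scratch (`weights=None`), and its final classification layer is replaced to match the dataset's number of classes. AlexNet and the other ImageNet-style backbones expect 3-channel inputs, so the transform above replicates MNIST's grayscale channel to produce RGB images.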


In [None]:
def build_model(arch: str, num_classes: int, alexnet_input_size: int = 64) -> nn.Module:
    """Instantiate an untrained torchvision backbone with a ``num_classes``-way head.

    Args:
        arch: ``"resnet18"``, ``"mobilenetv2"`` or ``"alexnet"`` (case-insensitive).
        num_classes: Output size of the replaced final linear layer.
        alexnet_input_size: Side length that AlexNet inputs are bilinearly resized
            to. AlexNet's stride-4 stem and three unpadded max-pools shrink a
            28x28 or 32x32 image to a 0x0 map, so its forward pass fails below
            63 px. The default of 64 is the smallest convenient size that works.

    Returns:
        The model. For AlexNet it is ``nn.Sequential(nn.Upsample(...), alexnet)``,
        so AlexNet's ``state_dict`` keys carry a ``"1."`` prefix.

    Raises:
        ValueError: for an unknown ``arch`` or an ``alexnet_input_size`` below 63.
    """
    arch = arch.lower()
    if arch == "resnet18":
        model = models.resnet18(weights=None)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    elif arch == "mobilenetv2":
        model = models.mobilenet_v2(weights=None)
        model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)
    elif arch == "alexnet":
        min_alexnet_size = 63  # smallest square input that survives AlexNet's feature extractor
        if alexnet_input_size < min_alexnet_size:
            raise ValueError(
                f"alexnet_input_size must be >= {min_alexnet_size}, got {alexnet_input_size}"
            )
        alexnet = models.alexnet(weights=None)
        alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, num_classes)
        # Upsample the small MNIST/CIFAR images so the conv/pool stack keeps a non-empty map.
        model = nn.Sequential(
            nn.Upsample(
                size=(alexnet_input_size, alexnet_input_size),
                mode="bilinear",
                align_corners=False,
            ),
            alexnet,
        )
    else:
        raise ValueError(f"Unsupported architecture: {arch}")
    return model



## Training and evaluation helpers

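Each client trains locally with standard supervised learning (SGD on a cross-entropy loss). The server then aggregates the client weights with FedAvg, implemented here as an unweighted mean of the client weights. After every round we log the global model's test accuracy.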


In [None]:
def train_one_client(
    model: nn.Module,
    loader: DataLoader,
    epochs: int,
    lr: float,
    momentum: float = 0.9,
    weight_decay: float = 5e-4,
    device: Optional[torch.device] = None,
) -> Tuple[nn.Module, float]:
    """Train ``model`` in place on one client's data with SGD and cross-entropy.

    Args:
        model: Network to train. It is not moved here and is expected to
            already live on ``device``.
        loader: The client's training batches of ``(images, labels)``.
        epochs: Number of full passes over ``loader``.
        lr, momentum, weight_decay: SGD hyperparameters.
        device: Where batches are sent; ``None`` uses the module-level ``DEVICE``.

    Returns:
        ``(model, avg_loss)``. ``avg_loss`` is the mean per-sample loss over every
        sample seen across all epochs, or ``nan`` if the loader yielded no samples.
    """
    device = DEVICE if device is None else device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
    model.train()
    total_loss = 0.0
    seen = 0
    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), labels)
            loss.backward()
            optimizer.step()
            # The criterion averages over the batch; re-weight by batch size for a per-sample mean.
            total_loss += loss.item() * images.size(0)
            seen += images.size(0)
    # Divide by samples actually seen (len(dataset) * epochs), not by a single pass.
    avg_loss = total_loss / seen if seen else float("nan")
    return model, avg_loss


def evaluate(model: nn.Module, loader: DataLoader, device: Optional[torch.device] = None) -> float:
    """Return the top-1 accuracy of ``model`` over ``loader``.

    The model is switched to eval mode (and left there) and run without gradients.

    Args:
        model: Classifier producing one logit per class; expected to be on ``device``.
        loader: Batches of ``(images, labels)``.
        device: Where batches are sent; ``None`` uses the module-level ``DEVICE``.

    Returns:
        Fraction of correctly classified samples, or ``nan`` for an empty loader
        (instead of raising ZeroDivisionError).
    """
    device = DEVICE if device is None else device
    model.eval()
    correct = total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            preds = model(images).argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else float("nan")


def fedavg(
    global_model: nn.Module,
    client_states: List[Dict[str, torch.Tensor]],
    weights: Optional[Sequence[float]] = None,
):
    """Load the average of ``client_states`` into ``global_model`` and return it.

    Floating-point entries (e.g. parameters and BatchNorm running statistics) are
    averaged across clients. Non-floating entries (e.g. integer counters) cannot
    be meaningfully averaged and are copied from the first client.

    Args:
        global_model: Model whose ``state_dict`` keys define what is averaged.
        client_states: One state dict per client, all with the same keys/shapes.
        weights: Optional non-negative per-client weights, e.g. local sample
            counts for sample-weighted FedAvg. They are normalized to sum to 1.
            ``None`` (default) gives an unweighted mean.

    Returns:
        ``global_model``, updated in place.

    Raises:
        ValueError: if ``client_states`` is empty or ``weights`` is malformed.
    """
    if not client_states:
        raise ValueError("fedavg requires at least one client state")

    norm_weights = None
    if weights is not None:
        if len(weights) != len(client_states):
            raise ValueError(
                f"got {len(weights)} weights for {len(client_states)} client states"
            )
        w = torch.tensor([float(x) for x in weights], dtype=torch.float64)
        if bool((w < 0).any()) or float(w.sum()) <= 0.0:
            raise ValueError("weights must be non-negative with a positive sum")
        norm_weights = w / w.sum()

    avg_state: Dict[str, torch.Tensor] = {}
    for key in global_model.state_dict():
        tensors = [state[key].detach() for state in client_states]
        if not tensors[0].dtype.is_floating_point:
            avg_state[key] = tensors[0].clone()
            continue
        stacked = torch.stack(tensors, dim=0)
        if norm_weights is None:
            avg_state[key] = stacked.mean(dim=0)
        else:
            # Broadcast one coefficient per client over the remaining tensor dims.
            coeffs = norm_weights.to(device=stacked.device, dtype=stacked.dtype)
            coeffs = coeffs.view(-1, *([1] * (stacked.dim() - 1)))
            avg_state[key] = (stacked * coeffs).sum(dim=0)
    global_model.load_state_dict(avg_state)
    return global_model


def save_client_update(
    round_idx: int,
    client_idx: int,
    global_before: Dict[str, torch.Tensor],
    client_after: Dict[str, torch.Tensor],
    out_dir: Optional[Path] = None,
) -> Path:
    """Save one client's weight delta for a round and return the file path.

    The saved dict maps every key of ``global_before`` to
    ``client_after[k] - global_before[k]``, computed on CPU. It is written with
    ``torch.save`` to ``out_dir / "round_RRR_client_CC.pt"``.

    Args:
        round_idx: Round number (zero-padded to 3 digits in the file name).
        client_idx: Client index (zero-padded to 2 digits in the file name).
        global_before: Global state the client started the round from.
        client_after: Client state after local training.
        out_dir: Destination directory, created if missing; ``None`` uses the
            module-level ``BASE_DIR``.
    """
    out_dir = BASE_DIR if out_dir is None else Path(out_dir)
    # Recreate the directory in case it was removed after module setup.
    out_dir.mkdir(parents=True, exist_ok=True)
    update = {
        k: client_after[k].detach().cpu() - before.detach().cpu()
        for k, before in global_before.items()
    }
    path = out_dir / f"round_{round_idx:03d}_client_{client_idx:02d}.pt"
    torch.save(update, path)
    return path



## Federated simulation loop

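The `run_federated_training` function orchestrates the process:
1. Load the dataset and partition the training split across clients.
2. In each round, train every client locally for `client_epochs` epochs, starting from the current global weights.
3. Save each client's model update (its trained weights minus the global weights) for the round.
4. Average the client weights into a new global model (FedAvg).
5. Evaluate the global model on the test set after each round.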


In [None]:
def run_federated_training(
    dataset: str = "mnist",
    architecture: str = "resnet18",
    num_clients: int = 5,
    client_epochs: int = 1,
    rounds: int = 3,
    batch_size: int = 64,
    lr: float = 0.01,
    seed: Optional[int] = None,
) -> Tuple[nn.Module, Dict[str, list]]:
    """Simulate ``rounds`` rounds of FedAvg over ``num_clients`` randomly partitioned clients.

    In each round every client starts from a copy of the current global model
    and trains locally for ``client_epochs`` epochs. Its weight delta is saved
    with ``save_client_update``. The client weights are then averaged into the
    global model, which is evaluated on the test split.

    Args:
        seed: If given, seeds the global torch RNG before partitioning and model
            initialization so runs are repeatable (GPU kernels may still add
            nondeterminism).

    Returns:
        ``(global_model, history)``. ``history`` holds per-round lists under
        ``"round"``, ``"accuracy"`` (global test accuracy) and ``"loss"``
        (mean of the clients' training losses).
    """
    if seed is not None:
        torch.manual_seed(seed)

    train_ds, test_ds = load_datasets(dataset)
    num_classes = len(train_ds.classes) if hasattr(train_ds, "classes") else 10
    clients = partition_dataset(train_ds, num_clients)

    pin = torch.cuda.is_available()
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=pin)
    # Build each client's loader once; shuffle=True reshuffles on every pass, so
    # reusing them across rounds is equivalent and avoids respawning workers.
    client_loaders = [
        DataLoader(
            subset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=2,
            pin_memory=pin,
            # On 28-32 px inputs ResNet-18/MobileNetV2 reach 1x1 feature maps, where a
            # trailing batch of one sample makes BatchNorm raise in train mode.
            drop_last=len(subset) > 1 and len(subset) % batch_size == 1,
        )
        for subset in clients
    ]

    global_model = build_model(architecture, num_classes).to(DEVICE)

    history = {"round": [], "accuracy": [], "loss": []}
    for rnd in range(1, rounds + 1):
        global_before = {k: v.detach().clone() for k, v in global_model.state_dict().items()}
        client_states = []
        client_losses = []

        for cid, loader in enumerate(client_loaders):
            # A deep copy yields the same starting weights as rebuilding the model and
            # loading the global state, without re-running random initialization.
            client_model = copy.deepcopy(global_model)
            client_model, loss = train_one_client(client_model, loader, epochs=client_epochs, lr=lr)
            client_state = {k: v.detach().clone() for k, v in client_model.state_dict().items()}
            client_states.append(client_state)
            client_losses.append(loss)
            path = save_client_update(rnd, cid, global_before, client_state)
            print(f"Saved update for round {rnd}, client {cid} at {path}")

        global_model = fedavg(global_model, client_states)
        acc = evaluate(global_model, test_loader)
        mean_loss = sum(client_losses) / len(client_losses)
        history["round"].append(rnd)
        history["accuracy"].append(acc)
        history["loss"].append(mean_loss)
        print(f"Round {rnd}: global accuracy={acc:.4f}, mean client loss={mean_loss:.4f}")

    return global_model, history



## Example run (quick sanity check)

The example below runs a very small simulation for speed. Increase the number of rounds, clients, or epochs for more meaningful results.


In [None]:
# Warning: running this cell will download datasets if not already present.
# To keep runtime short in limited environments, use small values
# (raise rounds / clients / epochs for meaningful accuracy numbers).
model, history = run_federated_training(
    dataset="mnist",
    architecture="resnet18",
    num_clients=3,      # number of simulated clients
    client_epochs=1,    # local epochs per client per round
    rounds=2,           # federated rounds
    batch_size=64,
    lr=0.01,
)
print(history)
