"""
ATML PA4 — Task 4 from scratch (no reuse):
A clean, single-file PyTorch implementation of a small federated learning framework
with four heterogeneity-mitigation strategies:
- FedProx (local proximal regularization)
- SCAFFOLD (control variates)
- FedGH (server-side gradient harmonization)
- FedSAM (sharpness-aware minimization on clients)


Also includes:
- CIFAR-10 loading and Dirichlet non-IID partitioning
- Simple CNN model
- FedAvg baseline
- Weighted aggregation, client drift metric, logging hooks


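HOW TO USE (example; the flags mirror the keyword arguments of run(). The notebook
cells shown here call run(...) directly and do not define an argparse entry point):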
python ATML-PA4-Task4.py --strategy fedavg --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedprox --mu 0.01 --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy scaffold --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedgh --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedsam --rho 0.05 --alpha 0.1 --num-clients 10 --rounds 50 --K 5


Notes:
* Keep the model small to fit within assignment constraints.
* SCAFFOLD doubles comms (sends control variates). FedSAM ~2x local compute.
* FedGH adds O(M^2) server-time pairwise projections per round.


This file is deliberately verbose and self-contained for clarity and grading.
"""

In [1]:
from __future__ import annotations
import argparse
import copy
import math
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# torchvision is permissible for CIFAR-10
from torchvision import datasets, transforms

# Reproducibility helpers

def set_seed(seed: int = 42):
    """Seed every random source this file relies on.

    Seeds Python's ``random`` module, the PyTorch CPU generator and all CUDA
    generators, and exports ``PYTHONHASHSEED`` into the process environment.

    Args:
        seed: Value handed to each seeding call.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(seed)

In [2]:
# ---------------------------
# Simple CNN for CIFAR-10
# ---------------------------
class SmallCNN(nn.Module):
    """Compact CNN: two conv/ReLU/max-pool stages followed by a two-layer MLP head.

    The first linear layer expects ``64 * 8 * 8`` features, i.e. a 3-channel
    input whose spatial size is 8x8 after the two 2x poolings (32x32 images).

    Args:
        num_classes: Width of the final linear layer (number of logits).
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Layers are created in the same order as they appear in the two
        # Sequential containers, so parameter names (features.0, features.3,
        # classifier.1, classifier.3) stay stable.
        conv_stages = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # halves H and W (32 -> 16 for 32x32 inputs)
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # halves again (16 -> 8 for 32x32 inputs)
        ]
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        ]
        self.features = nn.Sequential(*conv_stages)
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        return self.classifier(self.features(x))


# ---------------------------
# Dirichlet non-IID partition
# ---------------------------

def dirichlet_partition_indices(
    targets: torch.Tensor, num_clients: int, alpha: float, seed: int = 42
) -> List[List[int]]:
    """Split dataset indices into ``num_clients`` shards using class-wise Dirichlet(alpha) proportions.

    For every class, a proportion vector over clients is drawn from
    Dirichlet(alpha) and that class's (shuffled) indices are cut into
    consecutive slices of the rounded sizes. Smaller ``alpha`` gives a
    stronger label skew. Every index is assigned to exactly one client.

    The partition is a pure function of its arguments: shuffling uses a
    private ``random.Random(seed)`` and the Dirichlet draws are made with the
    torch CPU generator temporarily seeded to ``seed``; the global Python and
    torch RNG states are left exactly as they were on entry.

    Args:
        targets: 1-D integer tensor of class labels, one per sample.
        num_clients: Number of shards to produce (must be >= 1).
        alpha: Dirichlet concentration (must be > 0; enforced by torch).
        seed: Seed controlling both the shuffles and the Dirichlet samples.

    Returns:
        A list of ``num_clients`` lists of sample indices.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    if targets.numel() == 0:
        # Nothing to distribute; targets.max() below would fail on an empty tensor.
        return [[] for _ in range(num_clients)]

    shuffler = random.Random(seed)
    num_classes = int(targets.max().item() + 1)
    class_indices = [torch.where(targets == c)[0].tolist() for c in range(num_classes)]
    for cls_ids in class_indices:
        shuffler.shuffle(cls_ids)

    # float(alpha) keeps the concentration tensor floating-point even when an
    # int is passed (torch.full would otherwise infer an integer dtype).
    sampler = torch.distributions.Dirichlet(torch.full((num_clients,), float(alpha)))

    # Dirichlet.sample() draws from the default CPU generator and has no
    # generator argument, so seed that generator temporarily and restore it.
    saved_rng_state = torch.get_rng_state()
    try:
        torch.default_generator.manual_seed(seed)
        class_proportions = [sampler.sample() for _ in range(num_classes)]
    finally:
        torch.set_rng_state(saved_rng_state)

    client_indices: List[List[int]] = [[] for _ in range(num_clients)]
    for cls_ids, proportions in zip(class_indices, class_proportions):
        shares = (proportions / proportions.sum()).tolist()
        prev = 0
        for k in range(num_clients):
            take = int(round(shares[k] * len(cls_ids)))
            # Slicing past the end simply yields fewer (or no) items, so
            # rounding overshoot never duplicates an index.
            client_indices[k].extend(cls_ids[prev : prev + take])
            prev += take
        # Rounding undershoot: the remainder goes to the last client.
        if prev < len(cls_ids):
            client_indices[-1].extend(cls_ids[prev:])

    # Mix classes within each client's shard.
    for shard in client_indices:
        shuffler.shuffle(shard)
    return client_indices


# ---------------------------
# Utilities for (de)flattening params and deltas
# ---------------------------

def get_model_params_vector(model: nn.Module) -> torch.Tensor:
    """Return all parameters of ``model`` as one detached 1-D tensor.

    Parameters are flattened and concatenated in ``model.parameters()`` order.
    """
    flat_chunks = []
    for param in model.parameters():
        flat_chunks.append(param.detach().view(-1))
    return torch.cat(flat_chunks)


def get_model_grads_vector(model: nn.Module) -> torch.Tensor:
    """Return all gradients of ``model`` as one detached 1-D tensor.

    Gradients are flattened in ``model.parameters()`` order; a parameter whose
    ``.grad`` is ``None`` contributes a block of zeros of its own size, so the
    result always has as many entries as the model has parameter elements.
    """
    pieces = []
    for param in model.parameters():
        if param.grad is None:
            block = torch.zeros_like(param)
        else:
            block = param.grad.detach()
        pieces.append(block.view(-1))
    return torch.cat(pieces)


def assign_params_from_vector(model: nn.Module, vec: torch.Tensor):
    """Overwrite the parameters of ``model`` in place from the flat vector ``vec``.

    ``vec`` is consumed front to back in ``model.parameters()`` order; each
    parameter takes as many entries as it has elements, reshaped to its shape.
    """
    cursor = 0
    with torch.no_grad():
        for param in model.parameters():
            stop = cursor + param.numel()
            param.copy_(vec[cursor:stop].view_as(param))
            cursor = stop


def add_inplace(tensors: Iterable[torch.Tensor], alphas: Iterable[float], out: torch.Tensor):
    """Write the weighted sum ``sum(alpha_i * tensor_i)`` into ``out`` in place.

    ``out`` is zeroed first and then accumulates each scaled tensor; tensors
    and coefficients are paired positionally.
    """
    out.zero_()
    for pair in zip(tensors, alphas):
        term, coeff = pair
        out.add_(term, alpha=coeff)

In [3]:
# ---------------------------
# Client logic (baseline + hooks)
# ---------------------------
@dataclass
class ClientConfig:
    """Hyperparameters read by a client's local training loop."""

    lr: float = 0.01  # learning rate of the local SGD optimizer
    momentum: float = 0.9  # momentum of the local SGD optimizer
    batch_size: int = 64  # batch size of the client's DataLoader
    local_epochs: int = 5  # K: passes over the client's data per round
    mu: float = 0.0  # FedProx proximal coefficient (only used by the fedprox strategy)
    rho: float = 0.0  # FedSAM perturbation radius (only used by the fedsam strategy)

class Client:
    """One federated client: a private data loader, a local model copy and a local trainer.

    Supported strategies (case-insensitive): ``fedavg``/``fedgh`` (plain local
    SGD), ``fedprox`` (proximal term), ``scaffold`` (control variates) and
    ``fedsam`` (sharpness-aware two-pass step).

    Args:
        cid: Client identifier.
        dataset: Full training dataset; the client sees only ``indices``.
        indices: Sample indices owned by this client.
        device: Device the local model and batches are moved to.
        cfg: Hyperparameters (lr, momentum, batch_size, local_epochs, mu, rho).
        strategy: Strategy name.
        model_template: Model that is deep-copied to create the local model.
        scaffold_ci_template: Tensors whose shapes define the SCAFFOLD control
            variate; required when ``strategy`` is ``scaffold``.

    Raises:
        ValueError: If the strategy is ``scaffold`` and no template is given.
    """

    def __init__(
        self,
        cid: int,
        dataset: torch.utils.data.Dataset,
        indices: List[int],
        device: torch.device,
        cfg: "ClientConfig",
        strategy: str,
        model_template: nn.Module,
        scaffold_ci_template: Optional[List[torch.Tensor]] = None,
    ):
        self.cid = cid
        self.device = device
        self.cfg = cfg
        self.strategy = strategy.lower()
        self.loader = DataLoader(
            Subset(dataset, indices),
            batch_size=cfg.batch_size,
            shuffle=True,
            num_workers=2,
            # Pinned host memory only helps host-to-GPU copies; on CPU it just warns.
            pin_memory=(torch.device(device).type == "cuda"),
        )
        self.model = copy.deepcopy(model_template).to(device)
        # SCAFFOLD control variate c_i for this client (one tensor per parameter).
        if self.strategy == "scaffold":
            if scaffold_ci_template is None:
                raise ValueError("scaffold_ci_template is required for the scaffold strategy")
            self.ci = [torch.zeros_like(t, device=device) for t in scaffold_ci_template]
        else:
            self.ci = None

    def set_model_from_global(self, global_model: nn.Module):
        """Copy the global model's state into the local model."""
        # load_state_dict copies values into the existing tensors, so no
        # intermediate deep copy of the state dict is needed.
        self.model.load_state_dict(global_model.state_dict())

    def _scaffold_apply_correction(self, model: nn.Module, c_global: List[torch.Tensor]):
        """Turn each gradient g into the SCAFFOLD direction ``g - c_i + c`` in place."""
        with torch.no_grad():
            for p, ci, cg in zip(model.parameters(), self.ci, c_global):
                if p.grad is None:
                    continue
                p.grad.add_(cg - ci)

    def _fedprox_add_proximal(self, model: nn.Module, global_params: List[torch.Tensor]):
        """Return the FedProx penalty ``mu/2 * ||theta - theta_g||^2`` (0.0 when mu <= 0).

        Adding it to the loss contributes ``mu * (theta - theta_g)`` to the gradients.
        """
        mu = self.cfg.mu
        if mu <= 0:
            return 0.0
        prox = 0.0
        for p, g in zip(model.parameters(), global_params):
            prox = prox + 0.5 * mu * torch.sum((p - g) ** 2)
        return prox

    def _fedsam_ascent(self, model: nn.Module, rho: float):
        """Move the weights to ``w + rho * g / ||g||`` and return the applied deltas.

        The norm is the global L2 norm over all parameters that have a
        gradient. The returned list is aligned with ``model.parameters()``;
        entries are ``None`` for parameters without a gradient, which are left
        untouched. Working per parameter (instead of slicing one flat vector)
        keeps every delta aligned with its own parameter.
        """
        params = list(model.parameters())
        grads = [p.grad for p in params if p.grad is not None]
        if not grads:
            return [None] * len(params)
        deltas: List[Optional[torch.Tensor]] = []
        with torch.no_grad():
            grad_norm = torch.norm(torch.stack([g.norm(p=2) for g in grads]), p=2)
            scale = rho / (grad_norm + 1e-12)  # epsilon guards a zero gradient
            for p in params:
                if p.grad is None:
                    deltas.append(None)
                    continue
                delta = p.grad * scale
                p.add_(delta)
                deltas.append(delta)
        return deltas

    def _fedsam_descent_restore(
        self,
        model: nn.Module,
        rho: float,
        ascent_deltas: Optional[List[Optional[torch.Tensor]]] = None,
    ):
        """Undo the ascent step by subtracting the stored ``ascent_deltas``.

        ``rho`` is unused and kept only for signature compatibility; calling
        without ``ascent_deltas`` is a no-op.
        """
        if ascent_deltas is None:
            return
        with torch.no_grad():
            for p, delta in zip(model.parameters(), ascent_deltas):
                if delta is not None:
                    p.sub_(delta)

    def local_train(
        self,
        global_model: nn.Module,
        c_global: Optional[List[torch.Tensor]] = None,
    ) -> "Dict[str, torch.Tensor | List[torch.Tensor]]":
        """Run ``cfg.local_epochs`` epochs of local SGD starting from ``global_model``.

        Args:
            global_model: Current global model; it is read, never modified.
            c_global: Global SCAFFOLD control variate (required for ``scaffold``).

        Returns:
            A dict with ``"params"`` (local parameters after training),
            ``"delta"`` (local minus global parameters), ``"num_samples"``
            (size of the client's dataset) and, for SCAFFOLD, ``"ci"`` (a copy
            of the updated client control variate).

        Raises:
            ValueError: If the strategy is ``scaffold`` and a control variate
                is missing.
        """
        is_scaffold = self.strategy == "scaffold"
        if is_scaffold and (c_global is None or self.ci is None):
            raise ValueError("SCAFFOLD needs both the global and the client control variate")

        device = self.device
        self.set_model_from_global(global_model)
        model = self.model
        model.train()
        opt = optim.SGD(model.parameters(), lr=self.cfg.lr, momentum=self.cfg.momentum)

        # Snapshot of the global weights: anchor for FedProx and base for the deltas.
        theta_g = [p.detach().clone() for p in global_model.parameters()]
        use_prox = self.strategy == "fedprox" and self.cfg.mu > 0
        rho = self.cfg.rho if self.strategy == "fedsam" else 0.0
        num_steps = 0  # optimizer steps actually taken (needed by SCAFFOLD)

        for _ in range(self.cfg.local_epochs):
            for xb, yb in self.loader:
                xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

                opt.zero_grad(set_to_none=True)
                loss = F.cross_entropy(model(xb), yb)
                if use_prox:
                    loss = loss + self._fedprox_add_proximal(model, theta_g)
                loss.backward()

                if is_scaffold:
                    self._scaffold_apply_correction(model, c_global)

                if rho > 0:
                    # FedSAM: climb to w_adv, take the gradient there, return
                    # to w, then step with the gradient computed at w_adv.
                    ascent_deltas = self._fedsam_ascent(model, rho)
                    opt.zero_grad(set_to_none=True)
                    F.cross_entropy(model(xb), yb).backward()
                    self._fedsam_descent_restore(model, rho, ascent_deltas)

                opt.step()
                num_steps += 1

        with torch.no_grad():
            theta_i = [p.detach().clone() for p in model.parameters()]
            deltas = [ti - tg for ti, tg in zip(theta_i, theta_g)]

        out = {
            "params": theta_i,
            "delta": deltas,
            "num_samples": torch.tensor(len(self.loader.dataset), dtype=torch.long),
        }

        if is_scaffold:
            # c_i <- c_i - c + (theta_g - theta_i) / (steps * lr), where "steps"
            # is the number of local SGD steps, not the number of epochs.
            # NOTE(review): this estimate assumes plain SGD steps; with
            # momentum != 0 it is only an approximation of the mean gradient.
            if num_steps > 0 and self.cfg.lr > 0:
                inv = 1.0 / (num_steps * self.cfg.lr)
                with torch.no_grad():
                    for ci, cg, tg, ti in zip(self.ci, c_global, theta_g, theta_i):
                        ci.add_((tg - ti) * inv - cg)
            out["ci"] = [t.detach().clone() for t in self.ci]

        return out




In [4]:
# ---------------------------
# Server & strategies
# ---------------------------
class Server:
    """Holds the global model and the aggregation rules applied to client results."""

    def __init__(self, model: nn.Module, device: torch.device):
        self.device = device
        self.model = model.to(device)

    def aggregate_weighted(self, client_params: List[List[torch.Tensor]], weights: List[float]):
        """Replace each global parameter with the weighted sum of the clients' parameters.

        ``client_params[k]`` is client k's parameter list (in
        ``model.parameters()`` order) and ``weights[k]`` its coefficient.
        """
        with torch.no_grad():
            for slot, target in enumerate(self.model.parameters()):
                scaled = (plist[slot].to(self.device) * wt for wt, plist in zip(weights, client_params))
                blended = next(scaled, None)
                for extra in scaled:
                    blended = blended + extra
                target.copy_(blended)

    def aggregate_from_deltas(self, deltas: List[List[torch.Tensor]], weights: List[float]):
        """Add the weighted sum of the clients' per-parameter deltas to the global model."""
        with torch.no_grad():
            for slot, target in enumerate(self.model.parameters()):
                step = torch.zeros_like(target)
                for wt, client_delta in zip(weights, deltas):
                    step.add_(client_delta[slot].to(self.device), alpha=wt)
                target.add_(step)

    # FedGH: harmonize deltas before averaging
    def harmonize_pairwise(self, flat_updates: List[torch.Tensor]) -> List[torch.Tensor]:
        """Return copies of the flat updates with pairwise conflicts projected out.

        Pairs (i, j) with i < j are visited in order. When their dot product is
        negative, each vector loses its component along the other one; both
        projections use the pair's values from before this pair's update, and
        later pairs see the already-updated vectors. The inputs are not modified.
        """
        harmonized = [vec.clone() for vec in flat_updates]
        count = len(harmonized)
        for a in range(count):
            for b in range(a + 1, count):
                u, v = harmonized[a], harmonized[b]
                inner = torch.dot(u, v)
                if not (inner < 0):
                    continue  # no conflict between this pair
                # Small epsilon keeps the division finite for zero vectors.
                u_sq = torch.dot(u, u) + 1e-12
                v_sq = torch.dot(v, v) + 1e-12
                harmonized[a] = u - (inner / v_sq) * v
                harmonized[b] = v - (inner / u_sq) * u
        return harmonized


# ---------------------------
# Evaluation & metrics
# ---------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    """Evaluate ``model`` on every batch of ``loader``.

    The model is switched to eval mode and gradients are disabled.

    Returns:
        ``(accuracy, mean_loss)``: the fraction of correct argmax predictions
        and the summed cross-entropy divided by the number of samples.
    """
    model.eval()
    n_seen = 0
    n_right = 0
    nll_total = 0.0
    for inputs, labels in loader:
        inputs = inputs.to(device)
        labels = labels.to(device)
        scores = model(inputs)
        # Sum (not mean) per batch so the final average is per sample.
        batch_nll = F.cross_entropy(scores, labels, reduction='sum')
        n_right += (scores.argmax(dim=1) == labels).sum().item()
        n_seen += labels.size(0)
        nll_total += batch_nll.item()
    accuracy = n_right / n_seen
    mean_loss = nll_total / n_seen
    return accuracy, mean_loss


def compute_drift(global_model: nn.Module, client_param_lists: List[List[torch.Tensor]], device: torch.device) -> float:
    """Return the mean L2 distance between each client's parameters and the global model.

    For every client, the squared per-parameter L2 norms of
    ``client - global`` are summed and square-rooted; the result is the
    average of these distances over all clients.
    """
    with torch.no_grad():
        reference = [gp.detach().to(device) for gp in global_model.parameters()]

        def _distance(local_params: List[torch.Tensor]) -> float:
            squared = 0.0
            for anchor, local in zip(reference, local_params):
                squared += torch.norm(local.to(device) - anchor, p=2).item() ** 2
            return math.sqrt(squared)

        per_client = [_distance(plist) for plist in client_param_lists]
    return float(sum(per_client) / len(per_client))


# ---------------------------
# Data loading
# ---------------------------

def get_cifar10(root: str = "./data"):
    """Return the CIFAR-10 ``(train, test)`` datasets stored under ``root``.

    Both splits are downloaded if missing. Training images get a padded random
    crop and a random horizontal flip; both splits are converted to tensors
    and normalized with the same per-channel mean and std.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2470, 0.2435, 0.2616)

    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
    train_pipeline = transforms.Compose(
        augmentation + [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )
    test_pipeline = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize(channel_mean, channel_std)]
    )

    train_split = datasets.CIFAR10(root, train=True, download=True, transform=train_pipeline)
    test_split = datasets.CIFAR10(root, train=False, download=True, transform=test_pipeline)
    return train_split, test_split




In [5]:
 #---------------------------
# Training orchestration
# ---------------------------

def run(
    strategy: str,
    num_clients: int,
    alpha: float,
    rounds: int,
    K: int,
    batch_size: int,
    lr: float,
    momentum: float,
    mu: float = 0.0,
    rho: float = 0.0,
    sample_frac: float = 1.0,
    seed: int = 42,
    device_str: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    """Train a SmallCNN on Dirichlet-partitioned CIFAR-10 with the chosen federated strategy.

    Each round samples clients, trains them locally from the current global
    model, aggregates their results, evaluates on the test set and prints one
    summary line.

    Args:
        strategy: One of ``fedavg``, ``fedprox``, ``scaffold``, ``fedgh``,
            ``fedsam`` (case-insensitive).
        num_clients: Total number of clients (>= 1).
        alpha: Dirichlet concentration of the non-IID partition.
        rounds: Number of communication rounds.
        K: Local epochs per round.
        batch_size: Client batch size.
        lr: Client SGD learning rate.
        momentum: Client SGD momentum.
        mu: FedProx coefficient (ignored by other strategies).
        rho: FedSAM radius (ignored by other strategies).
        sample_frac: Fraction of clients sampled each round.
        seed: Seed for the RNGs and the partition.
        device_str: Torch device string.

    Returns:
        A list with one dict per round holding ``round``, ``num_selected``,
        ``drift``, ``acc`` and ``loss``.

    Raises:
        ValueError: On an unknown strategy or ``num_clients < 1``.
    """
    strategy_key = strategy.lower()
    known_strategies = {"fedavg", "fedprox", "scaffold", "fedgh", "fedsam"}
    if strategy_key not in known_strategies:
        # A typo would otherwise silently fall through to plain FedAvg.
        raise ValueError(f"unknown strategy {strategy!r}; expected one of {sorted(known_strategies)}")
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")
    is_scaffold = strategy_key == "scaffold"

    set_seed(seed)
    device = torch.device(device_str)

    # data
    train_set, test_set = get_cifar10()
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

    # partition
    targets = torch.tensor(train_set.targets)
    splits = dirichlet_partition_indices(targets, num_clients=num_clients, alpha=alpha, seed=seed)

    # global model and server
    global_model = SmallCNN().to(device)
    server = Server(global_model, device)

    cfg = ClientConfig(lr=lr, momentum=momentum, batch_size=batch_size, local_epochs=K, mu=mu, rho=rho)

    # SCAFFOLD: shapes for the per-client control variates
    scaffold_template = None
    if is_scaffold:
        scaffold_template = [p.detach().clone() for p in global_model.parameters()]

    clients: List[Client] = [
        Client(
            cid=cid,
            dataset=train_set,
            indices=splits[cid],
            device=device,
            cfg=cfg,
            strategy=strategy,
            model_template=global_model,
            scaffold_ci_template=scaffold_template,
        )
        for cid in range(num_clients)
    ]

    # SCAFFOLD global control variate c
    c_global: Optional[List[torch.Tensor]] = None
    if is_scaffold:
        c_global = [torch.zeros_like(p, device=device) for p in global_model.parameters()]

    # Layer layout used to turn flat FedGH updates back into per-layer tensors.
    param_shapes = [p.shape for p in global_model.parameters()]
    param_sizes = [p.numel() for p in global_model.parameters()]

    history = []
    for rnd in range(1, rounds + 1):
        # sample participating clients (at least one, never more than exist)
        m = min(num_clients, max(1, int(round(sample_frac * num_clients))))
        selected = random.sample(range(num_clients), m)

        # broadcast is implicit: local_train copies the global weights first
        results = [(idx, clients[idx].local_train(global_model, c_global)) for idx in selected]

        # aggregation weights proportional to client data size
        sizes = [int(res["num_samples"]) for _, res in results]
        total = sum(sizes)
        if total > 0:
            weights = [s / total for s in sizes]
        else:
            # No selected client holds data; fall back to a uniform average.
            weights = [1.0 / len(sizes)] * len(sizes)

        # drift of the local models from the global model they started from
        drift_val = compute_drift(global_model, [res["params"] for _, res in results], device)

        if strategy_key == "fedgh":
            # Harmonize the flat client deltas, then apply their weighted sum.
            flat_updates = [
                torch.cat([d.detach().reshape(-1).to(device) for d in res["delta"]])
                for _, res in results
            ]
            harmonized = server.harmonize_pairwise(flat_updates)
            per_client_deltas: List[List[torch.Tensor]] = [
                [chunk.view(shape) for chunk, shape in zip(flat.split(param_sizes), param_shapes)]
                for flat in harmonized
            ]
            server.aggregate_from_deltas(per_client_deltas, weights)
        else:
            # standard weighted average on parameters (FedAvg-style)
            server.aggregate_weighted([res["params"] for _, res in results], weights)

        if is_scaffold:
            # Keep c equal to the mean of ALL clients' control variates.
            # Clients that sat out keep their previous c_i, so this also holds
            # under partial participation; with full participation it equals
            # the mean over this round's participants.
            with torch.no_grad():
                for slot, c_slot in enumerate(c_global):
                    stacked = torch.stack([cl.ci[slot].to(device) for cl in clients])
                    c_slot.copy_(stacked.mean(dim=0))

        # eval
        acc, loss = evaluate(global_model, test_loader, device)
        print(f"Round {rnd:03d} | clients {m:02d}/{num_clients} | drift {drift_val:.3f} | acc {acc*100:.2f}% | loss {loss:.4f}")
        history.append({"round": rnd, "num_selected": m, "drift": drift_val, "acc": acc, "loss": loss})

    return history




In [6]:
# Baseline experiment: plain FedAvg with the default hyperparameters.
baseline_cfg = {
    "strategy": "fedavg",  # aggregation strategy
    "num_clients": 10,  # total number of clients
    "alpha": 0.1,  # Dirichlet concentration
    "rounds": 50,  # communication rounds
    "K": 5,  # local epochs per round
    "batch_size": 64,  # client batch size
    "lr": 0.01,  # client learning rate
    "momentum": 0.9,  # client SGD momentum
    "mu": 0.01,  # FedProx mu
    "rho": 0.05,  # FedSAM rho
    "sample_frac": 1.0,  # fraction of clients sampled per round
    "seed": 42,  # RNG seed
    "device_str": "cuda" if torch.cuda.is_available() else "cpu",  # training device
}
run(**baseline_cfg)

100%|██████████| 170M/170M [00:03<00:00, 44.3MB/s]


Round 001 | clients 10/10 | drift 3.102 | acc 33.26% | loss 2.1159
Round 002 | clients 10/10 | drift 3.010 | acc 41.19% | loss 1.7394
Round 003 | clients 10/10 | drift 2.849 | acc 49.18% | loss 1.4678
Round 004 | clients 10/10 | drift 2.727 | acc 53.92% | loss 1.3078
Round 005 | clients 10/10 | drift 2.689 | acc 58.48% | loss 1.2013
Round 006 | clients 10/10 | drift 2.696 | acc 61.09% | loss 1.1091
Round 007 | clients 10/10 | drift 2.667 | acc 63.22% | loss 1.0548
Round 008 | clients 10/10 | drift 2.679 | acc 64.37% | loss 1.0264
Round 009 | clients 10/10 | drift 2.682 | acc 65.99% | loss 0.9841
Round 010 | clients 10/10 | drift 2.711 | acc 67.67% | loss 0.9363
Round 011 | clients 10/10 | drift 2.740 | acc 68.26% | loss 0.9074
Round 012 | clients 10/10 | drift 2.768 | acc 68.93% | loss 0.8963
Round 013 | clients 10/10 | drift 2.841 | acc 69.33% | loss 0.8960
Round 014 | clients 10/10 | drift 2.810 | acc 70.57% | loss 0.8575
Round 015 | clients 10/10 | drift 2.844 | acc 70.61% | loss 0.

In [None]:
# Strategy sweep: run every strategy with the same shared hyperparameters.
strategies = [
    {"strategy": "fedavg", "label": "FedAvg"},
    {"strategy": "fedprox", "label": "FedProx", "mu": 0.01},
    {"strategy": "scaffold", "label": "SCAFFOLD"},
    {"strategy": "fedgh", "label": "FedGH"},
    {"strategy": "fedsam", "label": "FedSAM", "rho": 0.05},
]


common_cfg = dict(
    num_clients=10,
    alpha=0.1,
    rounds=50,
    K=5,
    batch_size=64,
    lr=0.01,
    momentum=0.9,
    sample_frac=1.0,
    seed=42,
)


for cfg in strategies:
    label = cfg["label"]
    # Copy instead of pop() so the strategy table survives re-running the cell.
    overrides = {key: value for key, value in cfg.items() if key != "label"}
    # run() needs mu and rho for every strategy; 0.0 disables both unless the
    # strategy entry overrides them.
    run_kwargs = {"mu": 0.0, "rho": 0.0, **common_cfg, **overrides}
    print(f"\n{'='*80}\nRunning {label}\n{'='*80}")
    run(**run_kwargs)