"""
ATML PA4 — Task 4 from scratch (no reuse):
A clean, single-file PyTorch implementation of a small federated learning framework
with four heterogeneity-mitigation strategies:
- FedProx (local proximal regularization)
- SCAFFOLD (control variates)
- FedGH (server-side gradient harmonization)
- FedSAM (sharpness-aware minimization on clients)


Also includes:
- CIFAR-10 loading and Dirichlet non-IID partitioning
- Simple CNN model
- FedAvg baseline
- Weighted aggregation, client drift metric, logging hooks


HOW TO USE (example):
python ATML-PA4-Task4.py --strategy fedavg --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedprox --mu 0.01 --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy scaffold --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedgh --alpha 0.1 --num-clients 10 --rounds 50 --K 5
python ATML-PA4-Task4.py --strategy fedsam --rho 0.05 --alpha 0.1 --num-clients 10 --rounds 50 --K 5


Notes:
* Keep the model small to fit within assignment constraints.
* SCAFFOLD doubles comms (sends control variates). FedSAM ~2x local compute.
* FedGH adds O(M^2) server-time pairwise projections per round.


This file is deliberately verbose and self-contained for clarity and grading.
"""

In [13]:
from __future__ import annotations
import argparse
import copy
import math
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Tuple, Iterable, Optional


import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Subset

# torchvision is permissible for CIFAR-10
from torchvision import datasets, transforms

# Reproducibility helpers

def set_seed(seed: int = 42):
  random.seed(seed)
  torch.manual_seed(seed)
  torch.cuda.manual_seed_all(seed)
  os.environ["PYTHONHASHSEED"] = str(seed)

In [14]:
# ---------------------------
# Simple CNN for CIFAR-10
# ---------------------------
class SmallCNN(nn.Module):
    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 16x16
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),  # 8x8
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(64 * 8 * 8, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


# ---------------------------
# Dirichlet non-IID partition
# ---------------------------

def dirichlet_partition_indices(
    targets: torch.Tensor, num_clients: int, alpha: float, seed: int = 42
) -> List[List[int]]:
    """Split dataset indices into num_clients using class-wise Dirichlet(α) proportions.
    Smaller α => higher label skew.
    """
    g = torch.Generator().manual_seed(seed)
    num_classes = int(targets.max().item() + 1)
    class_indices = [torch.where(targets == c)[0].tolist() for c in range(num_classes)]
    for ci in class_indices:
        random.shuffle(ci)

    client_indices = [[] for _ in range(num_clients)]

    for c in range(num_classes):
        # sample proportions for this class
        proportions = torch.distributions.Dirichlet(torch.full((num_clients,), alpha)).sample()
        proportions = (proportions / proportions.sum()).tolist()
        cls_ids = class_indices[c]
        # split cls_ids according to proportions
        prev = 0
        for k in range(num_clients):
            take = int(round(proportions[k] * len(cls_ids)))
            client_indices[k].extend(cls_ids[prev : prev + take])
            prev += take
        # in case of rounding leftovers, dump remainder into last client
        if prev < len(cls_ids):
            client_indices[-1].extend(cls_ids[prev:])

    # shuffle each client list
    for k in range(num_clients):
        random.shuffle(client_indices[k])
    return client_indices


# ---------------------------
# Utilities for (de)flattening params and deltas
# ---------------------------

def get_model_params_vector(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.detach().view(-1) for p in model.parameters()])


def get_model_grads_vector(model: nn.Module) -> torch.Tensor:
    return torch.cat([p.grad.detach().view(-1) if p.grad is not None else torch.zeros_like(p).view(-1) for p in model.parameters()])


def assign_params_from_vector(model: nn.Module, vec: torch.Tensor):
    offset = 0
    with torch.no_grad():
        for p in model.parameters():
            numel = p.numel()
            p.copy_(vec[offset : offset + numel].view_as(p))
            offset += numel


def add_inplace(tensors: Iterable[torch.Tensor], alphas: Iterable[float], out: torch.Tensor):
    """out = sum(alpha_i * tensor_i). Assumes flat vectors of equal shape."""
    out.zero_()
    for t, a in zip(tensors, alphas):
        out.add_(t, alpha=a)

In [15]:
# ---------------------------
# Client logic (baseline + hooks)
# ---------------------------
@dataclass
class ClientConfig:
    lr: float = 0.01
    momentum: float = 0.9
    batch_size: int = 64
    local_epochs: int = 5  # K
    mu: float = 0.0  # for FedProx
    rho: float = 0.0  # for FedSAM


class Client:
    def __init__(
        self,
        cid: int,
        dataset: torch.utils.data.Dataset,
        indices: List[int],
        device: torch.device,
        cfg: ClientConfig,
        strategy: str,
        model_template: nn.Module,
        scaffold_ci_template: Optional[List[torch.Tensor]] = None,
    ):
        self.cid = cid
        self.device = device
        self.cfg = cfg
        self.strategy = strategy.lower()
        self.loader = DataLoader(Subset(dataset, indices), batch_size=cfg.batch_size, shuffle=True, num_workers=2, pin_memory=True)
        self.model = copy.deepcopy(model_template).to(device)

        # SCAFFOLD control variate for this client (list of tensors matching params)
        if self.strategy == "scaffold":
            assert scaffold_ci_template is not None
            self.ci = [torch.zeros_like(t, device=device) for t in scaffold_ci_template]
        else:
            self.ci = None

    def set_model_from_global(self, global_model: nn.Module):
        self.model.load_state_dict(copy.deepcopy(global_model.state_dict()))

    def _scaffold_apply_correction(self, model: nn.Module, c_global: List[torch.Tensor]):
        # add (ci - c) to each parameter's gradient
        with torch.no_grad():
            for (p, gi, cg) in zip(model.parameters(), self.ci, c_global):
                if p.grad is None:
                    continue
                p.grad.add_(gi - cg)

    def _fedprox_add_proximal(self, model: nn.Module, global_params: List[torch.Tensor]):
        # add µ/2 * ||theta - theta_g||^2 to loss => grads add µ*(theta - theta_g)
        mu = self.cfg.mu
        if mu <= 0:
            return 0.0
        prox = 0.0
        for p, g in zip(model.parameters(), global_params):
            prox = prox + 0.5 * mu * torch.sum((p - g) ** 2)
        return prox

    def _fedsam_ascent(self, model: nn.Module, rho: float):
        # Perturb weights: w_adv = w + rho * g/||g|| (g is grad w.r.t current w)
        grad_vec = get_model_grads_vector(model)
        eps = 1e-12
        scale = rho / (grad_vec.norm(p=2) + eps)
        offset = 0
        with torch.no_grad():
            for p in model.parameters():
                if p.grad is None:
                    continue
                numel = p.numel()
                p.add_(grad_vec[offset : offset + numel].view_as(p), alpha=scale)
                offset += numel

    def _fedsam_descent_restore(self, model: nn.Module, rho: float):
        # Undo the perturbation by subtracting same delta applied in ascent.
        # NOTE: We recompute using the *current* grad vector, which is at w_adv; to precisely undo, we stored nothing.
        # A more exact impl would store the ascent delta. We'll compute it again from grads-at-w (before ascent),
        # but we no longer have those grads. So we do the simple approach: store ascent deltas.
        pass  # We'll store deltas explicitly below.

    def local_train(
        self,
        global_model: nn.Module,
        c_global: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor | List[torch.Tensor]]:
        device = self.device
        self.set_model_from_global(global_model)
        model = self.model
        model.train()
        opt = optim.SGD(model.parameters(), lr=self.cfg.lr, momentum=self.cfg.momentum)

        # cache global params for FedProx gradient contribution
        global_params = [p.detach().clone() for p in global_model.parameters()]

        rho = self.cfg.rho if self.strategy == "fedsam" else 0.0

        for ep in range(self.cfg.local_epochs):
            for xb, yb in self.loader:
                xb, yb = xb.to(device, non_blocking=True), yb.to(device, non_blocking=True)

                if not torch.isfinite(loss):
                  print(f"[Client {self.cid}] non-finite loss; breaking batch")
                  break

                # ----- standard forward/backward -----
                opt.zero_grad(set_to_none=True)
                logits = model(xb)
                loss = F.cross_entropy(logits, yb)

                # FedProx proximal term
                if self.strategy == "fedprox" and self.cfg.mu > 0:
                    loss = loss + self._fedprox_add_proximal(model, global_params)

                loss.backward()

                # SCAFFOLD gradient correction
                if self.strategy == "scaffold":
                    assert c_global is not None and self.ci is not None
                    self._scaffold_apply_correction(model, c_global)

                # FedSAM two-step
                if self.strategy == "fedsam" and rho > 0:
                    # store ascent deltas
                    ascent_deltas = []
                    with torch.no_grad():
                        grad_vec = get_model_grads_vector(model)
                        eps = 1e-12
                        scale = rho / (grad_vec.norm(p=2) + eps)
                        offset = 0
                        for p in model.parameters():
                            if p.grad is None:
                                ascent_deltas.append(None)
                                continue
                            numel = p.numel()
                            delta = grad_vec[offset : offset + numel].view_as(p) * scale
                            p.add_(delta)
                            ascent_deltas.append(delta)
                            offset += numel

                    # compute grad at perturbed weights
                    opt.zero_grad(set_to_none=True)
                    logits_adv = model(xb)
                    loss_adv = F.cross_entropy(logits_adv, yb)
                    loss_adv.backward()

                    # restore original weights
                    with torch.no_grad():
                        for p, delta in zip(model.parameters(), ascent_deltas):
                            if delta is not None:
                                p.sub_(delta)

                    # now apply optimizer step using grads at w_adv (stored on params)
                    opt.step()
                else:
                    # vanilla or FedProx/SCAFFOLD (after correction)
                    opt.step()

        # return results
        with torch.no_grad():
            theta_i = [p.detach().clone() for p in model.parameters()]
            theta_g = [p.detach().clone() for p in global_model.parameters()]
            # update delta for aggregation
            deltas = [ti - tg for ti, tg in zip(theta_i, theta_g)]

        out: Dict[str, torch.Tensor | List[torch.Tensor]] = {
            "params": theta_i,
            "delta": deltas,
            "num_samples": torch.tensor(len(self.loader.dataset), dtype=torch.long),
        }

        # SCAFFOLD: update ci based on global and local change
        if self.strategy == "scaffold":
            assert self.ci is not None and c_global is not None
            K = self.cfg.local_epochs
            lr = self.cfg.lr
            inv = 1.0 / (K * lr)
            with torch.no_grad():
                for ci_p, c_p, wt_p, wt1i_p in zip(self.ci, c_global, theta_g, theta_i):
                    ci_p.add_(-c_p)                          # ci = ci - c
                    ci_p.add_((wt_p - wt1i_p) * inv)         # + (w_t - w_{t+1}^i)/(K*lr)
            out["ci"] = [t.detach().clone() for t in self.ci]

        return out




In [16]:
# ---------------------------
# Server & strategies
# ---------------------------
class Server:
    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model.to(device)
        self.device = device

    def aggregate_weighted(self, client_params: List[List[torch.Tensor]], weights: List[float]):
        with torch.no_grad():
            for p_idx, p in enumerate(self.model.parameters()):
                acc = None
                for w, params in zip(weights, client_params):
                    term = params[p_idx].to(self.device) * w
                    acc = term if acc is None else acc + term
                p.copy_(acc)

    def aggregate_from_deltas(self, deltas: List[List[torch.Tensor]], weights: List[float]):
        with torch.no_grad():
            for p_idx, p in enumerate(self.model.parameters()):
                acc = torch.zeros_like(p)
                for w, dlist in zip(weights, deltas):
                    acc.add_(dlist[p_idx].to(self.device), alpha=w)
                p.add_(acc)

    # FedGH: harmonize deltas before averaging
    def harmonize_pairwise(self, flat_updates: List[torch.Tensor]) -> List[torch.Tensor]:
        M = len(flat_updates)
        outs = [u.clone() for u in flat_updates]
        for i in range(M):
            for j in range(i + 1, M):
                gi, gj = outs[i], outs[j]
                dot = torch.dot(gi, gj)
                if dot < 0:
                    # project symmetric
                    gi_norm2 = torch.dot(gi, gi) + 1e-12
                    gj_norm2 = torch.dot(gj, gj) + 1e-12
                    proj_i = dot / gj_norm2
                    proj_j = dot / gi_norm2
                    outs[i] = gi - proj_i * gj
                    outs[j] = gj - proj_j * gi
        return outs


# ---------------------------
# Evaluation & metrics
# ---------------------------
@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, device: torch.device) -> Tuple[float, float]:
    model.eval()
    total, correct, loss_sum = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss = F.cross_entropy(logits, yb, reduction='sum')
        preds = logits.argmax(dim=1)
        correct += (preds == yb).sum().item()
        total += yb.size(0)
        loss_sum += loss.item()
    return correct / total, loss_sum / total


def compute_drift(global_model: nn.Module, client_param_lists: List[List[torch.Tensor]], device: torch.device) -> float:
    with torch.no_grad():
        gparams = [p.detach().to(device) for p in global_model.parameters()]
        dists = []
        for plist in client_param_lists:
            s = 0.0
            for gp, cp in zip(gparams, plist):
                s += torch.norm(cp.to(device) - gp, p=2).item() ** 2
            dists.append(math.sqrt(s))
        return float(sum(dists) / len(dists))


# ---------------------------
# Data loading
# ---------------------------

def get_cifar10(root: str = "./data"):
    tfm_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    tfm_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    train = datasets.CIFAR10(root, train=True, download=True, transform=tfm_train)
    test = datasets.CIFAR10(root, train=False, download=True, transform=tfm_test)
    return train, test




In [17]:
 #---------------------------
# Training orchestration
# ---------------------------

def run(
    strategy: str,
    num_clients: int,
    alpha: float,
    rounds: int,
    K: int,
    batch_size: int,
    lr: float,
    momentum: float,
    mu: float,
    rho: float,
    sample_frac: float,
    seed: int = 42,
    device_str: str = "cuda" if torch.cuda.is_available() else "cpu",
):
    set_seed(seed)
    device = torch.device(device_str)

    # data
    train_set, test_set = get_cifar10()
    test_loader = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=2)

    # partition
    targets = torch.tensor(train_set.targets)
    splits = dirichlet_partition_indices(targets, num_clients=num_clients, alpha=alpha, seed=seed)

    # model template
    global_model = SmallCNN().to(device)
    server = Server(global_model, device)

    # client configs
    cfg = ClientConfig(lr=lr, momentum=momentum, batch_size=batch_size, local_epochs=K, mu=mu, rho=rho)

    # SCAFFOLD templates
    scaffold_template = None
    if strategy.lower() == "scaffold":
        scaffold_template = [p.detach().clone() for p in global_model.parameters()]

    # build clients
    clients: List[Client] = []
    for cid in range(num_clients):
        clients.append(
            Client(
                cid=cid,
                dataset=train_set,
                indices=splits[cid],
                device=device,
                cfg=cfg,
                strategy=strategy,
                model_template=global_model,
                scaffold_ci_template=scaffold_template,
            )
        )

    # SCAFFOLD global control variate c
    c_global: Optional[List[torch.Tensor]] = None
    if strategy.lower() == "scaffold":
        c_global = [torch.zeros_like(p, device=device) for p in global_model.parameters()]

    # training loop
    frac = sample_frac
    for rnd in range(1, rounds + 1):
        # sample participating clients
        m = max(1, int(round(frac * num_clients)))
        selected = random.sample(range(num_clients), m)

        # broadcast implicit via copying in local_train
        results = []
        for idx in selected:
            res = clients[idx].local_train(global_model, c_global)
            results.append((idx, res))

        # weights by client data size
        sizes = [int(res["num_samples"]) for _, res in results]
        total = sum(sizes)
        weights = [s / total for s in sizes]

        # metrics: drift before aggregation (based on current local params)
        drift_val = compute_drift(global_model, [res["params"] for _, res in results], device)

        # aggregation
        if strategy.lower() == "fedgh":
            # harmonize flat deltas then add to global
            flat = []
            for _, res in results:
                # concat layers (weighted delta will be applied after harmonization via weights)
                deltas = res["delta"]
                flat.append(torch.cat([d.detach().view(-1).to(device) for d in deltas]))
            flat_h = server.harmonize_pairwise(flat)
            # reconstruct per-layer from flat
            # We'll distribute harmonized flat deltas proportionally by weights
            # First, split shapes
            shapes = [p.shape for p in global_model.parameters()]
            sizes_layer = [int(torch.tensor(s).prod()) for s in shapes]
            per_client_deltas: List[List[torch.Tensor]] = []
            for fh in flat_h:
                offset = 0
                dl = []
                for sz, shp in zip(sizes_layer, shapes):
                    dl.append(fh[offset:offset+sz].view(shp))
                    offset += sz
                per_client_deltas.append(dl)
            server.aggregate_from_deltas(per_client_deltas, weights)
        else:
            # standard weighted average on parameters (FedAvg-style)
            server.aggregate_weighted([res["params"] for _, res in results], weights)

        # SCAFFOLD: update c_global to average of ci
        if strategy.lower() == "scaffold":
            with torch.no_grad():
                agg_ci = None
                for _, res in results:
                    ci_list = res["ci"]  # type: ignore
                    agg_ci = [t.clone() for t in ci_list] if agg_ci is None else [a + b for a, b in zip(agg_ci, ci_list)]
                for i in range(len(agg_ci)):
                    agg_ci[i] = agg_ci[i] / len(results)
                for i, p in enumerate(c_global):
                    p.copy_(agg_ci[i].to(device))

        # eval
        acc, loss = evaluate(global_model, test_loader, device)
        print(f"Round {rnd:03d} | clients {m:02d}/{num_clients} | drift {drift_val:.3f} | acc {acc*100:.2f}% | loss {loss:.4f}")




In [18]:
# ==== Setup: helpers, logging, and parser (run this once) ====
import io, sys, os, re, time, json, shutil
import pandas as pd
import matplotlib.pyplot as plt

# Assumes a function run(**kwargs) is already defined in another cell.

OUTDIR   = "pa4_task4_runs"   # root for artifacts
ROUNDS   = 30                 # default rounds (consistent across methods)
DOWNLOAD = True               # auto-download zips/plots (Colab)
os.makedirs(OUTDIR, exist_ok=True)

# Parse lines like: Round 001 | clients 10/10 | drift 3.103 | acc 33.10% | loss 2.1145
line_re = re.compile(
    r"Round\s+(\d+)\s+\|\s+clients\s+(\d+)\/(\d+)\s+\|\s+drift\s+([0-9.]+)\s+\|\s+acc\s+([0-9.]+)%\s+\|\s+loss\s+([0-9.]+)"
)

def to_jsonable(x):
    if isinstance(x, dict): return {str(k): to_jsonable(v) for k,v in x.items()}
    if isinstance(x, (list, tuple)): return [to_jsonable(v) for v in x]
    if isinstance(x, (str, int, float, bool)) or x is None: return x
    try:
        import numpy as np
        if isinstance(x, np.generic): return x.item()
    except: pass
    try:
        import torch
        if isinstance(x, (torch.device, torch.dtype)): return str(x)
    except: pass
    return str(x)

class Tee(io.TextIOBase):
    """Write to notebook + file simultaneously (live)."""
    def __init__(self, file_obj, mirror): self.f, self.m = file_obj, mirror
    def write(self, s): self.m.write(s); self.m.flush(); self.f.write(s); self.f.flush(); return len(s)
    def flush(self): self.m.flush(); self.f.flush()

def run_and_log(label: str, cfg: dict):
    """Runs one strategy once, prints live, saves logs/CSV/plots, and optionally zips+downloads."""
    stamp  = time.strftime("%Y%m%d-%H%M%S")
    tag    = f"{label}_alpha{cfg['alpha']}_K{cfg['K']}_N{cfg['num_clients']}_{stamp}"
    run_dir = os.path.join(OUTDIR, tag); os.makedirs(run_dir, exist_ok=True)

    # Save config
    cfg_to_save = dict(cfg); cfg_to_save["label"] = label
    with open(os.path.join(run_dir, "config.json"), "w") as f: json.dump(to_jsonable(cfg_to_save), f, indent=2)

    # Live + file logging
    log_path = os.path.join(run_dir, "train.log")
    with open(log_path, "w") as logf:
        old = sys.stdout; sys.stdout = Tee(logf, old)
        try:
            print("\n" + "="*80); print(f"Running {label}"); print("="*80)
            t0 = time.time()
            run(**cfg)  # prints will appear live AND be written to train.log
            print(f"[{label}] elapsed: {time.time()-t0:.2f}s")
        finally:
            sys.stdout = old

    # Parse metrics from train.log
    rows = []
    with open(log_path) as f:
        for line in f:
            m = line_re.search(line)
            if m:
                rows.append((
                    int(m.group(1)), int(m.group(2)), int(m.group(3)),
                    float(m.group(4)), float(m.group(5))/100.0, float(m.group(6))
                ))
    df = pd.DataFrame(rows, columns=["round","m_clients","n_clients","drift","acc","loss"])
    df.to_csv(os.path.join(run_dir, "metrics.csv"), index=False)

    # Plots
    if not df.empty:
        for y, title, name in [
            ("acc",  f"{label} — Accuracy vs Rounds", "acc_vs_rounds.png"),
            ("loss", f"{label} — Loss vs Rounds",     "loss_vs_rounds.png"),
            ("drift",f"{label} — Drift vs Rounds",    "drift_vs_rounds.png"),
        ]:
            plt.figure(); plt.plot(df["round"], df[y]); plt.xlabel("Round"); plt.ylabel(y.capitalize())
            plt.title(title); plt.tight_layout()
            plt.savefig(os.path.join(run_dir, name), dpi=150); plt.close()

        best = df.loc[df["acc"].idxmax()]
        summary = dict(
            final_round=int(df["round"].iloc[-1]),
            final_acc=float(df["acc"].iloc[-1]),
            best_acc=float(best["acc"]),
            best_round=int(best["round"]),
            final_loss=float(df["loss"].iloc[-1]),
            final_drift=float(df["drift"].iloc[-1]),
        )
        with open(os.path.join(run_dir, "summary.json"), "w") as f: json.dump(summary, f, indent=2)
        print(f"[{label}] Final acc {summary['final_acc']:.3f} | Best {summary['best_acc']:.3f} @ r{summary['best_round']} | Drift {summary['final_drift']:.3f}")
    else:
        print(f"[{label}] WARN: no metrics parsed. Check train.log format.")

    # Zip + optional download
    zip_path = shutil.make_archive(run_dir, "zip", run_dir)
    if DOWNLOAD:
        try:
            from google.colab import files
            files.download(zip_path)
        except Exception as e:
            print(f"[{label}] Download skipped ({e}). Zip at: {zip_path}")

    return run_dir


In [8]:
cfg_fedavg = dict(
    strategy="fedavg", num_clients=10, alpha=0.1, rounds=ROUNDS, K=5,
    batch_size=64, lr=0.01, momentum=0.9, sample_frac=1.0, seed=42, mu=0.0, rho=0.0
)
run_dir_fedavg = run_and_log("FedAvg", cfg_fedavg)



Running FedAvg


100%|██████████| 170M/170M [00:04<00:00, 39.9MB/s]


Round 001 | clients 10/10 | drift 3.102 | acc 33.28% | loss 2.1141
Round 002 | clients 10/10 | drift 3.011 | acc 41.39% | loss 1.7344
Round 003 | clients 10/10 | drift 2.845 | acc 49.08% | loss 1.4711
Round 004 | clients 10/10 | drift 2.727 | acc 54.18% | loss 1.3075
Round 005 | clients 10/10 | drift 2.688 | acc 58.09% | loss 1.2105
Round 006 | clients 10/10 | drift 2.699 | acc 61.17% | loss 1.1045
Round 007 | clients 10/10 | drift 2.673 | acc 63.22% | loss 1.0543
Round 008 | clients 10/10 | drift 2.683 | acc 63.94% | loss 1.0244
Round 009 | clients 10/10 | drift 2.690 | acc 66.43% | loss 0.9715
Round 010 | clients 10/10 | drift 2.720 | acc 67.80% | loss 0.9347
Round 011 | clients 10/10 | drift 2.751 | acc 68.05% | loss 0.9231
Round 012 | clients 10/10 | drift 2.771 | acc 69.02% | loss 0.8916
Round 013 | clients 10/10 | drift 2.832 | acc 69.10% | loss 0.9054
Round 014 | clients 10/10 | drift 2.804 | acc 70.24% | loss 0.8656
Round 015 | clients 10/10 | drift 2.844 | acc 70.40% | loss 0.

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [9]:
cfg_fedprox = dict(
    strategy="fedprox", num_clients=10, alpha=0.1, rounds=ROUNDS, K=5,
    batch_size=64, lr=0.01, momentum=0.9, sample_frac=1.0, seed=42, mu=0.01, rho=0.0
)
run_dir_fedprox = run_and_log("FedProx", cfg_fedprox)



Running FedProx
Round 001 | clients 10/10 | drift 2.646 | acc 36.54% | loss 2.1058
Round 002 | clients 10/10 | drift 2.557 | acc 40.36% | loss 1.7754
Round 003 | clients 10/10 | drift 2.403 | acc 47.97% | loss 1.5187
Round 004 | clients 10/10 | drift 2.282 | acc 52.74% | loss 1.3620
Round 005 | clients 10/10 | drift 2.239 | acc 56.20% | loss 1.2667
Round 006 | clients 10/10 | drift 2.234 | acc 58.99% | loss 1.1704
Round 007 | clients 10/10 | drift 2.212 | acc 61.29% | loss 1.1128
Round 008 | clients 10/10 | drift 2.207 | acc 61.46% | loss 1.0938
Round 009 | clients 10/10 | drift 2.206 | acc 64.04% | loss 1.0246
Round 010 | clients 10/10 | drift 2.233 | acc 65.20% | loss 0.9969
Round 011 | clients 10/10 | drift 2.247 | acc 66.69% | loss 0.9459
Round 012 | clients 10/10 | drift 2.261 | acc 67.28% | loss 0.9299
Round 013 | clients 10/10 | drift 2.304 | acc 68.13% | loss 0.9209
Round 014 | clients 10/10 | drift 2.282 | acc 69.29% | loss 0.8891
Round 015 | clients 10/10 | drift 2.298 | acc

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [19]:
cfg_scaffold = dict(
    strategy="scaffold",
    num_clients=10, alpha=0.1, rounds=ROUNDS, K=3,
    batch_size=64, lr=0.005, momentum=0.0, sample_frac=1.0,
    seed=42, mu=0.0, rho=0.0
)
run_dir_scaffold = run_and_log("SCAFFOLD", cfg_scaffold)



Running SCAFFOLD


UnboundLocalError: cannot access local variable 'loss' where it is not associated with a value

In [None]:


# combined accuracy plot
if all_curves:
    comb = pd.concat(all_curves, ignore_index=True)
    plt.figure()
    for label, grp in comb.groupby("label"):
        plt.plot(grp["round"], grp["acc"], label=label)
    plt.xlabel("Round"); plt.ylabel("Accuracy")
    plt.title(f"Accuracy vs Rounds — All Strategies ({common_cfg['rounds']} rounds)")
    plt.legend(); plt.tight_layout()
    combo_path = os.path.join(OUTDIR, f"combined_accuracy_{common_cfg['rounds']}r.png")
    plt.savefig(combo_path, dpi=160); plt.close()
    print(f"Combined accuracy plot saved -> {combo_path}")

print(f"All artifacts under: {os.path.abspath(OUTDIR)}")