# Decentralized baseline

## Requirements

This cell sets up Weights & Biases to log training metrics and hyperparameters. It’s useful for comparing experiments (e.g. FedAvg vs. model editing) and tracking performance over time.

In [None]:
!pip install wandb
import wandb
wandb.login()

## Dataset Partitioning Strategies for Federated Learning

This section defines two key dataset partitioning functions that simulate different data distributions among clients in a Federated Learning (FL) setup using the CIFAR-100 dataset. These functions ensure a realistic simulation of decentralized data scenarios by controlling how training data is split and assigned to each client.

### 1. `iid_shard_train_val`
This function simulates an i.i.d. (independent and identically distributed) setting by distributing samples equally across clients while maintaining uniform class distribution. Each client receives a balanced subset of the dataset and performs a local train/validation split. This setup mimics a scenario where clients have similar data distributions, which is idealized but useful for baseline comparisons.
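
The stratified idea can be illustrated on toy data. The following is a minimal, dependency-free sketch and not the notebook's implementation (which works on NumPy index arrays and also performs the train/validation split); the name `stratified_split` and the toy labels are assumptions made for illustration.

```python
import random
from collections import Counter, defaultdict

def stratified_split(labels, num_clients, seed=42):
    """Deal the indices of each class out to clients in near-equal portions."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    clients = {i: [] for i in range(num_clients)}
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        base, extra = divmod(len(idxs), num_clients)
        cursor = 0
        for i in range(num_clients):
            # The first `extra` clients absorb the remainder, one sample each.
            take = base + (1 if i < extra else 0)
            clients[i].extend(idxs[cursor:cursor + take])
            cursor += take
    return clients

# Toy data: 3 classes with 8 samples each, split across 4 clients.
toy_labels = [c for c in range(3) for _ in range(8)]
toy_clients = stratified_split(toy_labels, num_clients=4)
for cid, idxs in toy_clients.items():
    print(cid, Counter(toy_labels[i] for i in idxs))
```

With these toy sizes every client ends up with two samples of each class, which is the per-class balance the i.i.d. setting aims for.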

### 2. `non_iid_shard_train_val`
This function implements a non-i.i.d. label-skew sharding strategy in which each client receives data from at most `Nc` distinct classes. The data from each class is first divided into shards; the shards are then shuffled and handed out so that no client holds two shards of the same class, although different clients may share a class. A local train/validation split is also performed. This simulates real-world heterogeneity in FL systems, where clients often observe biased or non-representative data distributions.
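
A shard-based label-skew split can be sketched on toy data as follows. This is a simplified, dependency-free illustration under stated assumptions, not the notebook's function: `label_skew_split` and its argument names are hypothetical, and the train/validation split is omitted.

```python
import random
from collections import defaultdict

def label_skew_split(labels, num_clients, classes_per_client, seed=42):
    """Cut each class into shards, then give each client shards of distinct classes."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    shards_per_class = (num_clients * classes_per_client) // len(by_class)
    if shards_per_class == 0:
        raise ValueError("num_clients * classes_per_client must be >= number of classes")

    shards = []  # list of (class label, sample indices)
    for label in sorted(by_class):
        idxs = by_class[label]
        rng.shuffle(idxs)
        size = len(idxs) // shards_per_class
        for s in range(shards_per_class):
            shard = idxs[s * size:(s + 1) * size]
            if shard:
                shards.append((label, shard))
    rng.shuffle(shards)

    clients = {i: [] for i in range(num_clients)}
    used = set()
    for i in range(num_clients):
        seen = set()
        for s_idx, (label, shard) in enumerate(shards):
            if len(seen) == classes_per_client:
                break
            if s_idx in used or label in seen:
                continue  # shard already taken, or client already has this class
            used.add(s_idx)
            seen.add(label)
            clients[i].extend(shard)
    return clients

# Toy data: 4 classes with 6 samples each, 4 clients, 2 classes per client.
toy_labels = [c for c in range(4) for _ in range(6)]
toy_clients = label_skew_split(toy_labels, num_clients=4, classes_per_client=2)
for cid, idxs in toy_clients.items():
    print(cid, sorted({toy_labels[i] for i in idxs}))
```

Because the assignment is greedy, a client served late may find only shards of classes it already holds and end up with fewer than `classes_per_client` classes; the same class can also appear on several clients.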

Together, these two utilities provide a foundational component for studying the impact of data heterogeneity in Federated Learning experiments.

The following cell imports the core libraries needed for training: PyTorch, timm for loading pretrained models, torchvision for datasets and transforms, and utilities for data handling and visualization.

In [None]:
import torch
import torch.nn as nn
import timm
import torchvision
from torchvision import transforms, datasets
import numpy as np
from collections import defaultdict, Counter
import random
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple


In [None]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import random_split, Dataset

class TransformedSubset(Dataset):
    """
    Utility class that applies a transform to a subset of a dataset.

    Wraps a PyTorch Subset (or any indexable dataset of (sample, label) pairs)
    and applies the given transform to each sample at access time.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y

    def __len__(self):
        return len(self.subset)


In [None]:
def iid_shard_train_val(dataset, K, val_split=0.2, seed=42):
    """
    This function implements stratified i.i.d. sharding with local train/validation split.

    - Ensures each of the K clients receives (approximately) the same number of samples per class.
    - Maintains class balance across all clients for fair i.i.d. distribution.
    - Performs a local train/validation split within each client, controlled by `val_split`.
    - Uses a fixed `seed` to guarantee reproducibility of the sharding.
    - Returns a dictionary: {client_id: {'train': [...], 'val': [...]}} mapping client indices to their dataset partitions.
    """

    rng = np.random.RandomState(seed)
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    n_classes = len(np.unique(labels))
    class_indices = {c: np.where(labels == c)[0] for c in range(n_classes)}
    for c in class_indices:
        rng.shuffle(class_indices[c])

    # Number of examples of each class given to every client
    examples_per_class = {c: len(class_indices[c]) // K for c in class_indices}
    # Leftovers (when the class size is not divisible by K) go to the first clients
    leftovers = {c: len(class_indices[c]) % K for c in class_indices}

    client_indices = {i: [] for i in range(K)}
    for c in range(n_classes):
        idxs = class_indices[c]
        cursor = 0
        for i in range(K):
            take = examples_per_class[c] + (1 if i < leftovers[c] else 0)
            client_indices[i].extend(idxs[cursor:cursor+take])
            cursor += take

    client_data = {}
    for i in range(K):
        idxs = np.array(client_indices[i])
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_split)
        val_idxs = idxs[:n_val]
        train_idxs = idxs[n_val:]
        client_data[i] = {'train': train_idxs.tolist(), 'val': val_idxs.tolist()}
    return client_data


In [None]:
def non_iid_shard_train_val(dataset, K, Nc, val_split=0.2, seed=42):
    """
    This function implements non-i.i.d. sharding (label-skew) combined with local train/validation splitting.

    - Each of the K clients is assigned up to Nc shards, each from a **different** class.
    - For each class, the examples are partitioned into small shards, which are then shuffled and distributed to clients.
    - A client never holds two shards of the same class, so it sees at most Nc classes; different clients may share a class.
    - Within each client's shard, a local train/val split is performed according to `val_split`.
    - The process is randomized using a fixed `seed` to ensure reproducibility.
    - Returns a dictionary: {client_id: {'train': [...], 'val': [...]}} mapping client indices to their dataset splits.
    """

    rng = np.random.RandomState(seed)
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    n_classes = len(np.unique(labels))
    class_indices = {c: rng.permutation(np.where(labels == c)[0]).tolist() for c in range(n_classes)}
    # Build the shards for each class
    shards_per_class = (K * Nc) // n_classes
    if shards_per_class == 0:
        raise ValueError("K * Nc must be at least the number of classes.")
    shards = []
    for c in range(n_classes):
        idxs = class_indices[c]
        shard_size = len(idxs) // shards_per_class
        for i in range(shards_per_class):
            shard = idxs[i*shard_size:(i+1)*shard_size]
            if len(shard) > 0:
                shards.append((c, shard))
    rng.shuffle(shards)
    # Assign each client up to Nc shards, each from a different class
    client_shards = {i: [] for i in range(K)}
    used = set()
    for i in range(K):
        chosen = []
        class_seen = set()
        for s_idx, (c, shard) in enumerate(shards):
            if c not in class_seen and s_idx not in used:
                chosen.append(s_idx)
                class_seen.add(c)
            if len(chosen) == Nc:
                break
        for s_idx in chosen:
            used.add(s_idx)
            client_shards[i].extend(shards[s_idx][1])
    # Local train/val split for each client
    client_data = {}
    for i in range(K):
        idxs = np.array(client_shards[i])
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_split)
        val_idxs = idxs[:n_val]
        train_idxs = idxs[n_val:]
        client_data[i] = {'train': train_idxs.tolist(), 'val': val_idxs.tolist()}
    return client_data




## Data Partitioning and Visualization for Federated Learning

This section presents the practical implementation and validation of both i.i.d. and non-i.i.d. data sharding strategies for the CIFAR-100 dataset in the context of Federated Learning (FL). It also includes diagnostic visualizations to assess class distribution across clients.

### i.i.d. Partitioning
The CIFAR-100 training set is partitioned into `K = 100` clients using a stratified i.i.d. strategy, ensuring that each client receives a balanced and representative subset of all 100 classes. A local validation set is also carved out for each client. The correctness of the splits is verified by checking for overlapping indices and analyzing global distribution statistics.

### non-i.i.d. Partitioning (Label Skew)
In contrast, a non-i.i.d. strategy is used to give each client shards from up to `Nc = 5` distinct classes; a class may still be shared by several clients. This setup reflects real-world data heterogeneity in FL. The resulting partitions are validated in the same way, with additional printouts of class distributions per client to confirm label skew.

### Visualizations
Two heatmaps are generated (one for i.i.d. and one for non-i.i.d. sharding) showing the distribution of training samples per class across all clients. In the i.i.d. case, the heatmap shows uniform color intensity, while in the non-i.i.d. case, the matrix is sparse with strong class clustering per client.
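
The matrix behind such a heatmap is just a per-client class count. A minimal sketch on toy data, without plotting and without NumPy, could look like this; `label_count_matrix` and the toy partitions are illustrative assumptions, not part of the notebook.

```python
from collections import Counter

def label_count_matrix(labels, client_indices, n_classes):
    """Return rows[client][class] = number of that client's samples with that label."""
    rows = []
    for cid in sorted(client_indices):
        counts = Counter(labels[i] for i in client_indices[cid])
        rows.append([counts.get(c, 0) for c in range(n_classes)])
    return rows

toy_labels = [0, 0, 1, 1, 2, 2]
iid_like = {0: [0, 2, 4], 1: [1, 3, 5]}  # every client sees every class
skewed = {0: [0, 1, 2], 1: [3, 4, 5]}    # each client misses one class

iid_matrix = label_count_matrix(toy_labels, iid_like, n_classes=3)
skew_matrix = label_count_matrix(toy_labels, skewed, n_classes=3)
print(iid_matrix)   # [[1, 1, 1], [1, 1, 1]]
print(skew_matrix)  # [[2, 1, 0], [0, 1, 2]]
```

A uniform matrix corresponds to the even color intensity of the i.i.d. heatmap, while zeros in a row correspond to the classes a client never sees under label skew.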

These steps provide a robust foundation for simulating decentralized data scenarios and are critical to understanding how data heterogeneity impacts model training in Federated Learning.


In [None]:
"""
This cell loads the CIFAR-100 dataset, performs i.i.d. sharding into K clients, and verifies the correctness of the resulting data splits.

- Loads the CIFAR-100 dataset without applying any transformation.
- Applies the previously defined `iid_shard_train_val` function to split the training set across K clients with a local train/validation split.
- Defines a utility function `check_federated_splits` that validates the integrity of the sharding:
  - Ensures no overlap between train and validation sets within a client.
  - Ensures no sample is assigned to more than one client.
  - Prints distribution statistics such as global train/val ratio and per-client class counts.
- This verification helps confirm that the simulated federated dataset is consistent and suitable for experimentation.
"""


import numpy as np
import torch
import torchvision

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)

K = 100
val_split = 0.2
seed = 42

client_data = iid_shard_train_val(full_train, K=K, val_split=val_split, seed=seed)

def check_federated_splits(dataset, client_data, K, val_split=0.2, mode='iid'):
    all_train = []
    all_val = []
    for i in range(K):
        train_idx = client_data[i]['train']
        val_idx = client_data[i]['val']
        all_train.extend(train_idx)
        all_val.extend(val_idx)
        # Check overlap train/val local
        overlap = set(train_idx) & set(val_idx)
        assert len(overlap) == 0, f"[Client {i}] Overlap train/val!"
    all_indices = all_train + all_val
    assert len(all_indices) == len(set(all_indices)), "Some samples are assigned to multiple clients!"
    print(f"Total samples distributed: {len(all_indices)} / {len(dataset)}")
    n_total = len(dataset)
    print(f"Per client (approx.): train={int((1-val_split)*n_total/K)}, val={int(val_split*n_total/K)}")
    print(f"Mode: {mode.upper()}")
    print(f"Train/val ratio (global): {len(all_train)/n_total:.3f} / {len(all_val)/n_total:.3f}")
    print(f"Unique indices: {len(set(all_indices))} (should = {len(all_indices)})")
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    print(f"Example: [Client 0] classes: {set(labels[idx] for idx in client_data[0]['train'])}")

    from collections import Counter
    class_counts = Counter([labels[idx] for idx in client_data[0]['train']])
    print("[Client 0] Samples per class:")
    for cls in sorted(class_counts.keys()):
        print(f"  Class {cls}: {class_counts[cls]}")

check_federated_splits(full_train, client_data, K=K, val_split=val_split, mode='iid')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

"""
This cell defines and uses a function to visualize the class distribution across clients using a heatmap.

- Computes a (K × 100) matrix where each row corresponds to a client and each column to a CIFAR-100 class.
- The matrix entries indicate how many training samples of each class are assigned to each client.
- Uses seaborn to plot a heatmap showing class imbalance or uniformity across the K clients.
- This visualization is useful for verifying the effectiveness of the i.i.d. or non-i.i.d. partitioning strategies.
"""

def plot_label_heatmap(dataset, client_data, K):

    n_classes = 100  # CIFAR-100
    label_matrix = np.zeros((K, n_classes), dtype=int)
    labels = np.array([dataset[i][1] for i in range(len(dataset))])

    for i in range(K):
        idxs = client_data[i]['train']
        lbls = labels[idxs]
        for c in range(n_classes):
            label_matrix[i, c] = np.sum(lbls == c)

    plt.figure(figsize=(18, 6))
    sns.heatmap(label_matrix, cmap="viridis", annot=False, cbar=True)
    plt.xlabel("Class")
    plt.ylabel("Client")
    plt.title("Label distribution per client (train set)")
    plt.show()

plot_label_heatmap(full_train, client_data, K)


In [None]:
import numpy as np
import torch
import torchvision

"""
This cell loads the CIFAR-100 dataset, applies a non-i.i.d. label-skew sharding, and verifies the resulting splits.

- Loads the raw CIFAR-100 training dataset (no transformations applied yet).
- Uses the `non_iid_shard_train_val` function to distribute data such that each client receives examples from Nc distinct classes.
- Performs a local train/validation split per client, with reproducibility ensured by a fixed seed.
- Redefines and invokes the `check_federated_splits` function to validate:
  - No overlap between training and validation samples.
  - No data leakage across clients.
  - Global and per-client stats, including class distributions for client 0.
- This setup is essential for simulating statistical heterogeneity in Federated Learning.
"""


full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)

K = 100
Nc = 5
val_split = 0.2
seed = 42

client_data = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)

def check_federated_splits(dataset, client_data, K, val_split=0.2, mode='non-iid'):
    all_train = []
    all_val = []
    for i in range(K):
        train_idx = client_data[i]['train']
        val_idx = client_data[i]['val']
        all_train.extend(train_idx)
        all_val.extend(val_idx)
        # Check overlap train/val local
        overlap = set(train_idx) & set(val_idx)
        assert len(overlap) == 0, f"[Client {i}] Overlap train/val!"
    all_indices = all_train + all_val
    assert len(all_indices) == len(set(all_indices)), "Some samples are assigned to multiple clients!"
    print(f"Total samples distributed: {len(all_indices)} / {len(dataset)}")
    n_total = len(dataset)
    print(f"Per client (approx.): train={int((1-val_split)*n_total/K)}, val={int(val_split*n_total/K)}")
    print(f"Mode: {mode.upper()}")
    print(f"Train/val ratio (global): {len(all_train)/n_total:.3f} / {len(all_val)/n_total:.3f}")
    print(f"Unique indices: {len(set(all_indices))} (should = {len(all_indices)})")
    labels = np.array([dataset[i][1] for i in range(len(dataset))])
    print(f"Example: [Client 0] classes: {set(labels[idx] for idx in client_data[0]['train'])}")

    from collections import Counter
    class_counts = Counter([labels[idx] for idx in client_data[0]['train']])
    print("[Client 0] Samples per class:")
    for cls in sorted(class_counts.keys()):
        print(f"  Class {cls}: {class_counts[cls]}")

check_federated_splits(full_train, client_data, K=K, val_split=val_split, mode='non-iid')


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torchvision

"""
This cell visualizes the class distribution across clients in a non-i.i.d. label-skew scenario using a heatmap.

- Loads the CIFAR-100 training dataset and applies non-i.i.d. sharding with Nc distinct classes per client.
- Constructs a matrix of shape (K × 100), where each entry represents the number of training samples of a given class held by a specific client.
- Uses seaborn to render a heatmap, clearly illustrating the label imbalance across clients.
- This visualization confirms the effectiveness and severity of the non-i.i.d. partitioning and helps identify class sparsity at the client level.
"""

K = 100
Nc = 5
val_split = 0.2
seed = 42

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
client_data = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)

labels = np.array([full_train[i][1] for i in range(len(full_train))])

mat = np.zeros((K, 100), dtype=int)
for i in range(K):
    train_idx = client_data[i]['train']
    client_labels = labels[train_idx]
    for c in range(100):
        mat[i, c] = np.sum(client_labels == c)

plt.figure(figsize=(20, 7))
sns.heatmap(mat, cmap="viridis", cbar=True)
plt.xlabel("Class")
plt.ylabel("Client")
plt.title("Label distribution per client (train set, non-iid)")
plt.show()


## Model Architecture: ViT-S/16 Pretrained with DINO for CIFAR-100

This section defines the neural architecture used throughout the Federated Learning experiments. The model is based on the Vision Transformer (ViT) architecture, specifically the ViT-Small variant with a patch size of 16×16, pretrained using the DINO self-supervised learning framework.

Using the `timm` library, the DINO ViT-S/16 model is loaded with its pretrained weights. The final DINO classification head is removed and replaced by a custom linear classifier with 100 output units, corresponding to the 100 classes of the CIFAR-100 dataset. The feature extractor produces 384-dimensional embeddings, which are passed to the classifier for supervised learning.

This lightweight yet expressive architecture serves as the base model for both centralized training and Federated Learning settings, and will later be used in combination with sparse fine-tuning and model editing techniques.


In [None]:
import torch
import torch.nn as nn
import timm
from torchvision import datasets, transforms

"""
This cell defines the neural network architecture used for classification: a modified ViT-S/16 model pre-trained with DINO.

- Leverages the `timm` library to load a pretrained ViT-Small (patch size 16) model trained with DINO self-supervision.
- Replaces the backbone's head with `nn.Identity` so that the backbone returns raw features.
- Adds a custom linear classifier (`nn.Linear`) to adapt the model for the CIFAR-100 classification task (100 output classes).
- The resulting architecture combines strong pre-trained representations with a lightweight classifier suitable for fine-tuning or sparse updates.
"""

class DinoViT_CIFAR100(nn.Module):
    def __init__(self, num_classes=100):
        super().__init__()
        self.backbone = timm.create_model('vit_small_patch16_224.dino', pretrained=True)
        # The feature dimension of ViT-S/16 is 384 (see timm/models/vision_transformer.py)
        self.backbone.head = nn.Identity()  # replace the head so the backbone returns features
        self.classifier = nn.Linear(384, num_classes)
    def forward(self, x):
        # timm ViT returns (batch, 384) if the head is nn.Identity
        feats = self.backbone(x)   # (batch, 384)
        out = self.classifier(feats)
        return out



## Sparse Fine-Tuning and Model Editing with Fisher-Based Masking

This section introduces the core components enabling sparse fine-tuning for model editing in a Federated Learning context. The approach leverages Fisher Information to identify and selectively update the most relevant parameters, reducing interference and improving communication efficiency.

### Fisher Score Approximation

The function `_compute_approximated_fisher_scores` estimates the diagonal of the Fisher Information Matrix by accumulating squared gradients over multiple batches. This approximation identifies parameter sensitivity, with low-scoring parameters considered less critical. These scores serve as the foundation for constructing binary masks that govern which weights are updated during fine-tuning. Additional utilities such as `_num_total_params`, `_num_zero_params`, and `_compute_sparsity` allow monitoring the proportion of masked parameters and evaluating sparsity levels.
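
Written out, the quantity accumulated for each parameter entry is the following, where the notation is introduced here for illustration: $B$ is the number of batches used, $\mathcal{L}_b$ is the loss on batch $b$ (assumed to be a mean over the batch, as with the default `nn.CrossEntropyLoss`), and $\theta_i$ is a single parameter entry.

```latex
\hat{F}_{ii} \;=\; \frac{1}{B} \sum_{b=1}^{B}
  \left( \frac{\partial \mathcal{L}_b(\theta)}{\partial \theta_i} \right)^{2}
```

Because the gradient of the batch-averaged loss is squared, rather than each per-sample gradient, this is a coarse approximation of the empirical Fisher diagonal. When a mask is supplied, entries whose mask value is 0 are forced to a score of 0.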

### Progressive Gradient Mask Calibration

`calibrate_gradient_mask_progressive` progressively builds a binary gradient mask over several rounds. In each round it recomputes the Fisher scores and keeps trainable only the lowest-scoring (least important) parameters among those still active, freezing the rest. The keep ratio shrinks from round to round, gradually refining the mask until the target sparsity is reached. Only the "train_least_important" strategy is supported in the current implementation.
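
The keep ratio used in the calibration code is `(1 - sparsity) ** (r / rounds)` for round `r`. A small standalone sketch of that schedule follows; the helper name `keep_fraction_schedule` is an assumption introduced for illustration.

```python
def keep_fraction_schedule(sparsity, rounds):
    """Fraction of parameters left trainable after each calibration round."""
    return [(1.0 - sparsity) ** (r / rounds) for r in range(1, rounds + 1)]

schedule = keep_fraction_schedule(sparsity=0.9, rounds=5)
for r, frac in enumerate(schedule, start=1):
    print(f"round {r}: keep {frac:.4f}")
```

The schedule is geometric: each round multiplies the kept fraction by the same factor, and the final round lands on `1 - sparsity`.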

### Sparse Optimizer with Gradient Masking

The custom optimizer `SparseSGDM` extends PyTorch's SGD to incorporate gradient masking. During each optimization step, gradients for masked parameters are explicitly zeroed out using the precomputed binary masks. The optimizer maps each parameter to its corresponding mask using unique identifiers, ensuring selective updates.
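
The effect of gradient masking can be seen in a scalar sketch on plain Python lists. It is an illustration only, assuming no momentum and assuming weight decay is folded into the gradient after masking, as in classic SGD with an L2 term; `masked_sgd_step` is a hypothetical helper, not the notebook's optimizer.

```python
def masked_sgd_step(params, grads, mask, lr, weight_decay=0.0):
    """One plain SGD step on flat lists; the gradient is masked before the update."""
    updated = []
    for p, g, m in zip(params, grads, mask):
        g = g * m                 # gradient masking: frozen entries get zero gradient
        g = g + weight_decay * p  # L2 term is added afterwards and is not masked
        updated.append(p - lr * g)
    return updated

params = [1.0, 2.0, 3.0]
grads = [0.5, 0.5, 0.5]
mask = [1, 0, 1]  # the middle entry is frozen

no_decay = masked_sgd_step(params, grads, mask, lr=0.1)
with_decay = masked_sgd_step(params, grads, mask, lr=0.1, weight_decay=0.1)
print(no_decay)
print(with_decay)
```

Without weight decay the frozen entry keeps its value exactly. With a non-zero weight decay it still shrinks slightly, which is worth keeping in mind when masked parameters are expected to stay untouched.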

### Sparse Fine-Tuning Procedure

`sparse_fine_tune` orchestrates the fine-tuning loop by applying the gradient mask and updating only the unmasked parameters. It sets the `requires_grad` flags from the mask before training and re-enables them for all parameters afterward to avoid side effects when the model is reused. Combined with `SparseSGDM`, it ensures that the model undergoes low-interference updates consistent with the calibrated mask, enabling efficient and targeted model editing.

These tools together provide a robust framework for exploring model editing in federated settings, allowing flexible control over parameter updates and enabling experimentation with various sparsity and sensitivity-driven strategies.


In [None]:
"""
This cell defines helper functions for sparse mask analysis and for computing approximated Fisher scores.

- `_num_total_params`, `_num_zero_params`, and `_compute_sparsity`:
  - Provide statistics on a binary mask by calculating the total number of parameters, how many are zeroed (masked), and the resulting sparsity ratio.

- `_compute_approximated_fisher_scores`:
  - Estimates the diagonal of the Fisher Information Matrix using squared gradients averaged over a number of batches.
  - Operates on a validation or local client dataloader.
  - Supports optional masking to compute scores only over active (unmasked) parameters.
  - Returns a dictionary `{param_name: Fisher_diag_tensor}` to be used for identifying low-sensitivity weights during mask calibration.

These functions form the core utilities behind sensitivity-based sparse model editing in Federated Learning.
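"""

# Imports required by this cell and the calibration cell below
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm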

def _num_total_params(mask: Dict[str, torch.Tensor]) -> int:
    """Returns the total number of parameters (elements) across all tensors in the mask."""
    return sum(t.numel() for t in mask.values())

def _num_zero_params(mask: Dict[str, torch.Tensor]) -> int:
    """Returns the number of parameters set to zero in the mask (i.e., masked out)."""
    return sum((t == 0).sum().item() for t in mask.values())

def _compute_sparsity(mask: Dict[str, torch.Tensor]) -> float:
    """Returns the sparsity, i.e., the fraction of parameters that are masked (value in [0, 1])."""
    return _num_zero_params(mask) / _num_total_params(mask)


def _compute_approximated_fisher_scores(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
    num_batches: Optional[int] = None,
    mask: Optional[Dict[str, torch.Tensor]] = None
):
    """
    Approximate the diagonal of the Fisher Information Matrix via empirical average.
    Args:
        model: torch.nn.Module
        dataloader: DataLoader (local client data)
        loss_fn: torch.nn loss function (e.g. nn.CrossEntropyLoss())
        device: torch.device
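        num_batches: number of batches to use for approximation (all batches if None)
        mask: optional {param_name: binary tensor}; scores of masked (0) entries are zeroed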
    Returns:
        Dict {param_name: tensor of Fisher diagonal}
    """
    model.eval()
    fisher_diag = {
        name: torch.zeros_like(param, device=device)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    # Never normalize by more batches than the dataloader can provide
    total_batches = len(dataloader) if num_batches is None else min(num_batches, len(dataloader))

    for batch_idx, (inputs, targets) in enumerate(
        tqdm(dataloader, total=total_batches, desc="Computing Fisher")
    ):
        if num_batches is not None and batch_idx >= num_batches:
            break

        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is not None:
                fisher_diag[name] += param.grad.detach() ** 2
                if mask is not None:
                    fisher_diag[name] *= mask[name]

    for name in fisher_diag:
        fisher_diag[name] /= total_batches

    return fisher_diag


In [None]:
def calibrate_gradient_mask_progressive(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    sparsity: float = 0.9,
    rounds: int = 5,
    num_batches: Optional[int] = None,
    loss_fn: nn.Module = nn.CrossEntropyLoss(),
    approximate_fisher: bool = True,
    strategy: str = "train_least_important",
) -> Dict[str, torch.Tensor]:
    """
    This function performs progressive gradient mask calibration using approximate Fisher Information scores.

    - Supports the "train_least_important" strategy, which keeps the least sensitive (lowest-score) parameters trainable and freezes the rest.
    - At each round, scores are computed (currently only approximate Fisher scores), and the active parameter set is pruned further by updating a binary mask.
    - The fraction of parameters retained decreases progressively over multiple calibration rounds to reach the desired sparsity level.
    - The function maintains and updates a dictionary of binary masks (`{param_name: mask}`), which can later be used to freeze selected weights during fine-tuning.
    - Useful for implementing sparse fine-tuning strategies in Federated Learning or model editing without retraining the entire model.
    """

    print("*" * 50)
    print(f"Progressive Mask Calibration - Strategy: {strategy}")
    print("*" * 50)

    model.to(device)

    mask = {
        n: torch.ones_like(p, device=device)
        for n, p in model.named_parameters()
        if p.requires_grad
    }

    for r in range(1, rounds + 1):
        print(f"[Round {r}]")

        # Score computation (only approximate Fisher supported here)
        if approximate_fisher:
            scores = _compute_approximated_fisher_scores(
                model=model,
                dataloader=dataloader,
                loss_fn=loss_fn,
                num_batches=num_batches,
                device=device,
                mask=mask,
            )
        else:
            raise NotImplementedError("Only approximate Fisher is implemented.")

        # 1. Gather all scores (for logging and debugging)
        all_scores = torch.cat([v.flatten() for v in scores.values()])
        # 2. Retain only the scores of parameters that remain active (i.e., where mask == 1)
        active_scores = torch.cat([
            score[mask[name] != 0].flatten()
            for name, score in scores.items()
        ])
        total_params = all_scores.numel()
        total_active_params = active_scores.numel()

        # Exponentially decrease keep_fraction for progressive pruning
        keep_fraction = (1-sparsity) ** (r / rounds)
        k = int(keep_fraction * total_params)
        print(f"Current keep fraction: {keep_fraction:.4f} | Keeping only top k: {k}")

        if strategy == "train_least_important":
            #To prevent bugs: ensure that k does not exceed the number of active parameters
            k = max(1, min(k, total_active_params))
            threshold, _ = torch.kthvalue(active_scores, k)
            print("Threshold (below which params are kept):", threshold)
            for name, score in scores.items():
                # Mask only newly selected parameters; keep previously zeroed (masked) ones unchanged
                new_mask = (score <= threshold).float()
                mask[name] = mask[name] * new_mask
        else:
            raise ValueError(f"Unknown strategy: {strategy}")

        print(
            f"After round {r} mask sparsity: { _compute_sparsity(mask):.4f} "
            f"({_num_zero_params(mask)}/{_num_total_params(mask)} zeroed params)"
        )
        print()

    print("Progressive Mask Calibration completed.")
    return mask


In [None]:
import torch
from torch.optim import SGD
from typing import Dict, Iterable

"""
This cell defines `SparseSGDM`, a custom optimizer that extends PyTorch's SGD to support gradient masking.

- Inherits from `torch.optim.SGD` and adds support for per-parameter binary masks.
- During each optimization step, gradients of masked-out parameters are zeroed out before the update.
- Accepts both the list of parameters (`params`) and their associated names (`named_params`) to align each gradient with its corresponding mask.
- This mechanism enables sparse fine-tuning by ensuring that only a specific subset of parameters (those with mask=1) are updated.
- Especially useful in Federated Learning and model editing where memory and compute constraints require sparse updates.
"""

class SparseSGDM(SGD):
    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        named_params: Dict[str, torch.nn.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        mask: Dict[str, torch.Tensor] = None,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        self.mask = mask  # Dict {param_name: mask_tensor}
        self.named_params = named_params  # Dict {name: param}
        self.param_id_to_name = {id(p): n for n, p in named_params.items()}

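    @torch.no_grad()
    def step(self, closure=None):
        # Evaluate the closure once, here, so the gradients it produces can be masked.
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        # Zero the gradient of every masked-out entry before the SGD update.
        if self.mask is not None:
            for group in self.param_groups:
                for p in group["params"]:
                    name = self.param_id_to_name.get(id(p))
                    if p.grad is not None and name in self.mask:
                        p.grad.mul_(self.mask[name])

        # Do not forward the closure: SGD would re-evaluate it and recompute unmasked gradients.
        # Note: with weight_decay > 0, SGD still applies the decay term to masked entries.
        super().step()
        return loss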


In [None]:
def sparse_fine_tune(
    model: nn.Module,
    dataloader,
    device,
    mask,
    lr=1e-3,
    epochs=1,
    momentum=0.9,
    weight_decay=5e-4,
):
    """
    This function performs sparse fine-tuning of a model using a fixed binary mask.

    - Freezes all parameters except those marked with 1 in the provided mask, by setting `requires_grad` accordingly.
    - Uses the `SparseSGDM` optimizer to ensure that masked gradients are zeroed out during training.
    - Runs standard training for a given number of epochs using cross-entropy loss on the specified dataloader.
    - Only the subset of parameters defined by the mask are updated; all others remain unchanged throughout fine-tuning.
    - After training, it resets all parameters’ `requires_grad` flags to True to avoid unintended side effects if the model is reused.
    - This approach is central to efficient sparse model editing in Federated Learning and continual learning contexts.
    """

    model.to(device)
    # Set requires_grad according to the mask
    for name, param in model.named_parameters():
        if name in mask:
            param.requires_grad = (mask[name] == 1).any().item()
        else:
            param.requires_grad = False

    # Prepare params to optimize (only those with requires_grad)
    named_params = dict(model.named_parameters())
    params = [p for p  in named_params.values() if p.requires_grad]

    # SparseSGDM masks gradients element-wise; requires_grad only skips tensors that are fully masked
    optimizer = SparseSGDM(
        params=params,
        named_params=named_params,
        lr=lr,
        momentum=momentum,
        weight_decay=weight_decay,
        mask=mask,
    )

    criterion = torch.nn.CrossEntropyLoss()
    model.train()
    for epoch in range(epochs):
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

    # Optional: reset requires_grad to True for all params if you reuse model elsewhere
    for param in model.parameters():
        param.requires_grad = True



## Federated Orchestration: Client Logic, Server Aggregation, and Evaluation

This section implements the full Federated Learning pipeline, including the local client behavior, the central training loop (FedAvg), and global model evaluation. It supports both standard and sparse fine-tuning regimes, enabling experimentation with model editing in a decentralized setting.

### `Client` Class

Each client represents a participant in the Federated Learning setup and encapsulates all logic for local training. Clients hold their own private dataset and device context, and can perform:

- **Standard training**, where all parameters are updated using SGD.
- **Sparse fine-tuning**, where only a subset of parameters—selected via a precomputed binary mask—are updated.
- **Mask calibration**, using Fisher Information to identify the least important parameters and construct a sparsity-aware update mask.

This design makes each client modular and self-contained, allowing seamless switching between dense and sparse update strategies.
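
To make the difference between freezing whole tensors and masking individual entries concrete, here is a minimal sketch on a toy tensor. It mirrors the gradient-masking step visible in `SparseSGDM.step` (multiply the gradient by the mask, then take a normal SGD step), but it is not the notebook's optimizer: the tensor `w`, the `mask` values, and the quadratic loss are all made up for illustration.

```python
import torch

# Toy parameter with four entries; only the first two should be trainable.
w = torch.nn.Parameter(torch.ones(4))
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])  # hypothetical binary mask

# Plain SGD without momentum or weight decay keeps the arithmetic easy to follow.
optimizer = torch.optim.SGD([w], lr=0.5)

# Quadratic loss whose gradient is 2 * w for every entry.
loss = (w ** 2).sum()
optimizer.zero_grad()
loss.backward()

# Element-wise masking: zero the gradient wherever mask == 0.
w.grad.mul_(mask)
optimizer.step()

updated = w.detach().tolist()
print(updated)  # [0.0, 0.0, 1.0, 1.0]
```

Setting `w.requires_grad = False` would instead freeze all four entries, while leaving it `True` without the mask would update all four. One caveat worth keeping in mind: `torch.optim.SGD` adds the weight-decay term inside `step()`, so a mask applied to `p.grad` beforehand does not cover that term when `weight_decay > 0`.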

### `FederatedTrainer` Class

This class orchestrates the full FL loop, using the FedAvg algorithm:

- At each round, a random subset of clients is selected according to a specified participation rate.
- Each client performs local training (either full or sparse, depending on the `use_sparse` flag).
- The server aggregates the clients’ model weights using a sample-weighted average to update the global model.
- Optional evaluation can be performed after each round using a provided evaluation function.

The trainer supports non-i.i.d. settings and sparse model editing strategies, making it highly configurable for a variety of FL experiments.
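
The sample-weighted average used for aggregation can be sketched in plain Python. The snippet below assumes scalar "weights" and a hypothetical helper name `fedavg` purely for readability; the notebook's `aggregate_weights` applies the same formula to tensors in a `state_dict`.

```python
def fedavg(client_states, client_sizes):
    """Sample-weighted average of per-client parameter dicts."""
    total = sum(client_sizes)
    keys = client_states[0].keys()
    return {
        k: sum(state[k] * n for state, n in zip(client_states, client_sizes)) / total
        for k in keys
    }

# Two hypothetical clients; the second holds three times as many samples.
states = [{"w": 1.0, "b": 0.0}, {"w": 3.0, "b": 4.0}]
sizes = [100, 300]

global_state = fedavg(states, sizes)
print(global_state)  # {'w': 2.5, 'b': 3.0}
```

The result sits closer to the second client's values because that client contributes 300 of the 400 samples.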

### `evaluate` Function

A utility function to assess model performance on a validation or test set. It computes classification accuracy and average loss over a given dataloader. This function is used during federated training to track model convergence and generalization.
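
The reason `evaluate` multiplies each batch loss by the batch size before averaging can be shown with a small, made-up example: when the last batch is smaller than the others, a plain mean of batch means gives it too much weight.

```python
# Hypothetical per-batch results: (batch_size, mean_loss_of_batch).
batches = [(128, 1.0), (128, 1.0), (44, 4.0)]

# Naive mean of batch means over-weights the small final batch.
naive = sum(loss for _, loss in batches) / len(batches)

# Sample-weighted mean, the pattern behind `total_loss += loss.item() * X.size(0)`.
total = sum(n for n, _ in batches)
weighted = sum(n * loss for n, loss in batches) / total

print(naive, weighted)
```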

Together, these components enable a complete and flexible Federated Learning pipeline that integrates sparse fine-tuning techniques, offering scalability, efficiency, and adaptability to real-world data heterogeneity.


In [None]:
class Client:
    """
    This class defines a `Client` in a Federated Learning setting, capable of performing both standard local training
    and sparse fine-tuning using a binary mask.

    - Each client holds a private dataset and its own device context for local computation.
    - `calibrate_mask` computes a binary gradient mask using Fisher scores (or other strategies) to identify the least important parameters.
    - `apply_mask_requires_grad` sets the `requires_grad` flag based on the mask, effectively freezing non-selected weights.
    - `sparse_fine_tune` fine-tunes only the unmasked parameters using SGD, keeping others fixed during training.
    - `local_train` is the main orchestration method:
      - If `use_sparse` is `True`, it triggers mask calibration and sparse fine-tuning for a few epochs.
      - Otherwise, it performs standard local training using all parameters with optional scheduler support.
    - This class encapsulates all logic needed for local updates in a Federated Learning loop, supporting experimentation with model editing techniques.
    """

    def __init__(self, client_id, dataset, device):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.last_mask = None  # Store the mask if needed

    def calibrate_mask(
        self,
        model,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        batch_size=128,
        num_batches: Optional[int] = None,
        loss_fn=None,
        strategy: str = "train_least_important",
    ):
        """Calibrate a binary mask based on importance strategy (Fisher, magnitude, or random)."""
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()
        mask = calibrate_gradient_mask_progressive(
            model=model,
            dataloader=dataloader,
            device=self.device,
            sparsity=sparsity_ratio,
            rounds=num_calib_rounds,
            num_batches=num_batches,
            loss_fn=loss_fn,
            approximate_fisher=True,
            strategy=strategy,
        )
        self.last_mask = mask
        return mask

    def apply_mask_requires_grad(self, model, mask):
        """
        Sets requires_grad=True for params where mask == 1, False otherwise.
        """
        for name, param in model.named_parameters():
            if name in mask:
                param.requires_grad = (mask[name] == 1).any().item()
            else:
                param.requires_grad = False

    def sparse_fine_tune(
        self,
        model,
        mask,
        lr=1e-3,
        epochs=1,
        batch_size=128,
        momentum=0.9,
        weight_decay=5e-4,
    ):
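        """
        Sparse fine-tuning: only entries with mask == 1 are updated.
        """
        self.apply_mask_requires_grad(model, mask)
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

        # requires_grad only freezes whole tensors, so use SparseSGDM to also
        # zero the gradients of masked entries inside trainable tensors.
        named_params = dict(model.named_parameters())
        optimizer = SparseSGDM(
            params=[p for p in named_params.values() if p.requires_grad],
            named_params=named_params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            mask=mask,
        )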
        criterion = nn.CrossEntropyLoss()
        model.train()
        for epoch in range(epochs):
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()

        # Reset requires_grad
        for param in model.parameters():
            param.requires_grad = True

    def local_train(
        self,
        global_model,
        epochs,
        batch_size,
        lr,
        momentum,
        weight_decay,
        scheduler_fn=None,
        use_sparse=False,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        num_batches: Optional[int] = None,
        sparse_ft_epochs=1,
        strategy: str = "train_least_important",
    ):
        """
        Performs standard local training or (if use_sparse) sparse fine-tuning.
        """
        model = DinoViT_CIFAR100(num_classes=100).to(self.device)
        model.load_state_dict(global_model.state_dict())

        if use_sparse:
            mask = self.calibrate_mask(
                model,
                sparsity_ratio=sparsity_ratio,
                num_calib_rounds=num_calib_rounds,
                batch_size=batch_size,
                num_batches=num_batches,
                strategy=strategy,  # forward the mask-selection strategy
            )
            self.sparse_fine_tune(
                model,
                mask,
                lr=lr,
                epochs=sparse_ft_epochs,
                batch_size=batch_size,
                momentum=momentum,
                weight_decay=weight_decay,
            )
        else:
            # Standard local training
            loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
            optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
            )
            scheduler = scheduler_fn(optimizer) if scheduler_fn else None
            criterion = nn.CrossEntropyLoss()
            model.train()
            for epoch in range(epochs):
                for X, y in loader:
                    X, y = X.to(self.device), y.to(self.device)
                    optimizer.zero_grad()
                    loss = criterion(model(X), y)
                    loss.backward()
                    optimizer.step()
                if scheduler:
                    scheduler.step()

        return model.state_dict()


In [None]:
class FederatedTrainer:
    """
    This class defines a `FederatedTrainer` that orchestrates the entire Federated Learning process using FedAvg,
    with optional support for sparse model editing.

    - The constructor initializes hyperparameters, the global model, and flags for enabling sparse fine-tuning strategies.
    - `aggregate_weights` performs FedAvg-style aggregation by computing a sample-weighted average of client models.
    - `train_round` executes one communication round:
      - Randomly selects a fraction of clients.
      - Each selected client performs either standard training or sparse fine-tuning, based on the `use_sparse` flag.
      - Their resulting model weights and dataset sizes are collected for weighted aggregation.
    - `fit` coordinates multiple training rounds, optionally evaluating the global model every `eval_every` rounds using a user-provided `eval_fn`.

    This class encapsulates both standard and sparsity-aware federated training workflows, making it a flexible engine for experimentation in FL settings.
    """


    def __init__(
        self,
        clients,
        global_model,
        device,
        client_fraction,
        local_epochs,
        batch_size,
        lr,
        momentum,
        weight_decay,
        scheduler_fn=None,
        use_sparse=False,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        num_batches: Optional[int] = None,
        sparse_ft_epochs=1,
        strategy: str = "train_least_important",
    ):
        self.clients = clients
        self.global_model = global_model
        self.device = device
        self.client_fraction = client_fraction
        self.local_epochs = local_epochs
        self.batch_size = batch_size
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.scheduler_fn = scheduler_fn

        self.use_sparse = use_sparse
        self.sparsity_ratio = sparsity_ratio
        self.num_calib_rounds = num_calib_rounds
        self.num_batches = num_batches
        self.sparse_ft_epochs = sparse_ft_epochs
        self.strategy = strategy

    def aggregate_weights(self, client_states, client_sizes):
        """
        Weighted average (FedAvg) of the selected client weights.
        client_states: list of state_dicts (one per client)
        client_sizes: list of int (number of samples per client)
        """
        total = sum(client_sizes)
        avg_state = {}
        for key in client_states[0].keys():
            weighted_sum = sum(state[key].float() * size for state, size in zip(client_states, client_sizes))
            avg_state[key] = weighted_sum / total
        return avg_state

    def train_round(self):
        """
        Runs one FedAvg round with optional model editing (sparse fine-tune).
        Aggregates using sample-weighted mean (FedAvg-style).
        """
        num_clients = len(self.clients)
        m = max(int(self.client_fraction * num_clients), 1)
        selected = np.random.choice(self.clients, m, replace=False)
        client_states = []
        client_sizes = []

        for client in selected:
            client_state = client.local_train(
                global_model=self.global_model,
                epochs=self.local_epochs,
                batch_size=self.batch_size,
                lr=self.lr,
                momentum=self.momentum,
                weight_decay=self.weight_decay,
                scheduler_fn=self.scheduler_fn,
                use_sparse=self.use_sparse,
                sparsity_ratio=self.sparsity_ratio,
                num_calib_rounds=self.num_calib_rounds,
                num_batches=self.num_batches,
                sparse_ft_epochs=self.sparse_ft_epochs,
                strategy=self.strategy,  # forward the mask-selection strategy
            )
            client_states.append(client_state)
            client_sizes.append(len(client.dataset))

        avg_state = self.aggregate_weights(client_states, client_sizes)
        self.global_model.load_state_dict(avg_state)

    def fit(self, n_rounds, eval_fn=None, eval_every=1):
        for rnd in range(1, n_rounds + 1):
            print(f'---- FedAvg Round {rnd} {"(SPARSE-EDITING)" if self.use_sparse else ""} ----')
            self.train_round()
            if eval_fn and (rnd % eval_every == 0 or rnd == n_rounds):
                acc, loss = eval_fn(self.global_model)
                print(f'[Round {rnd}] Eval: Acc={acc:.3f} | Loss={loss:.3f}')


In [None]:
import torch
import torch.nn as nn

def evaluate(model, dataloader, device):
    """
    This function evaluates a classification model on a given dataset.

    - Runs inference on the provided `dataloader` using `torch.no_grad()` to disable gradient tracking.
    - Computes both the total number of correct predictions and the average cross-entropy loss.
    - Returns a tuple `(accuracy, average_loss)` which can be used to track model performance over time.
    - Used during federated training (e.g., in the `fit` method of `FederatedTrainer`) to evaluate the global model at regular intervals.
    """

    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()
    correct, total, total_loss = 0, 0, 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            loss = criterion(outputs, y)
            total_loss += loss.item() * X.size(0)
            _, preds = outputs.max(1)
            correct += (preds == y).sum().item()
            total += X.size(0)
    if total == 0:
        return 0.0, 0.0
    return correct / total, total_loss / total


## Dataset Sharding Setup: IID and Non-IID Splits

This section initializes the dataset and prepares the client-specific data splits for all future Federated Learning experiments.

- The CIFAR-100 dataset is loaded in its raw form, without transformations.
- Two types of data partitioning are applied to the training set:
  - **IID Split**: Each of the 100 clients receives a balanced subset of the dataset, covering all classes uniformly.
  - **Non-IID Split**: Each client receives examples from a limited number of classes (`Nc`, e.g., 50), simulating label skew and statistical heterogeneity.
- The resulting splits (`iid_split` and `non_iid_split`) include both training and validation indices for every client.
- These partitions are fixed and reusable across all runs, ensuring consistent conditions for evaluating different models and strategies.
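
The two partitioning ideas can be illustrated on synthetic labels. This is only a toy sketch under simplifying assumptions, not the notebook's `iid_shard_train_val` or `non_iid_shard_train_val`: the helper names are hypothetical, there is no train/validation split, and in the label-skew variant several clients may share the same classes (each sample still belongs to exactly one client).

```python
import random
from collections import defaultdict

def group_by_class(labels):
    """Map each class label to the list of sample indices carrying it."""
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    return by_class

def toy_iid_split(labels, K, seed=0):
    """Shuffle each class and deal its indices round-robin to K clients."""
    rng = random.Random(seed)
    clients = [[] for _ in range(K)]
    for idxs in group_by_class(labels).values():
        rng.shuffle(idxs)
        for pos, idx in enumerate(idxs):
            clients[pos % K].append(idx)
    return clients

def toy_label_skew_split(labels, K, Nc, num_classes):
    """Client k holds classes (k*Nc + j) % num_classes for j < Nc."""
    holders = defaultdict(list)
    for k in range(K):
        for j in range(Nc):
            holders[(k * Nc + j) % num_classes].append(k)
    clients = [[] for _ in range(K)]
    for c, idxs in group_by_class(labels).items():
        owners = holders[c]
        # Split the samples of class c evenly among the clients that hold it.
        for pos, idx in enumerate(idxs):
            clients[owners[pos % len(owners)]].append(idx)
    return clients

# 4 classes with 8 samples each, distributed over 4 clients.
labels = [c for c in range(4) for _ in range(8)]
iid = toy_iid_split(labels, K=4)
skew = toy_label_skew_split(labels, K=4, Nc=2, num_classes=4)

iid_classes = [sorted({labels[i] for i in c}) for c in iid]
skew_classes = [sorted({labels[i] for i in c}) for c in skew]
print(iid_classes)   # every client sees all four classes
print(skew_classes)  # every client sees exactly two classes
```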


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np

# Hyperparams
K = 100
val_split = 0.2
seed = 42
Nc = 50

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

iid_split = iid_shard_train_val(full_train, K=K, val_split=val_split, seed=seed)

non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)




## FedAvg IID

This cell launches a complete Federated Learning experiment using the FedAvg algorithm on the CIFAR-100 dataset with 100 i.i.d. clients and a DINO ViT-S/16 model. It sets the main hyperparameters, initializes Weights & Biases for experiment tracking, and handles training, evaluation, and plotting. The trainer's `fit` method is replaced by `fit_with_all_logs`, which adds W&B logging and optional checkpointing with resume (disabled in this cell, where `checkpoint_path=None`), and appends each configuration's best validation accuracy to a results file for later comparison.

## Federated Training with FedAvg and Grid Search

This cell sets up and launches a full Federated Learning experiment using the FedAvg algorithm on the CIFAR-100 dataset with ViT-S/16 (pretrained with DINO). The training is performed under an i.i.d. data distribution across 100 clients and includes grid search over key hyperparameters.

### Experiment Setup
- **Dataset**: CIFAR-100 is partitioned using an i.i.d. strategy with local train/validation splits.
- **Model**: A DINO-pretrained ViT-S/16 model adapted for CIFAR-100 classification.
- **Clients**: Each client is assigned a local dataset and instantiated with transformation pipelines.
- **Global validation**: A centralized validation set is constructed by aggregating all local validation samples.

### Hyperparameter Grid Search
- A grid search is performed over combinations of learning rates (`lr_list`) and momentum values (`momentum_list`).
- For each configuration, a new global model and `FederatedTrainer` instance is initialized.
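
As a quick sketch of the size of this grid, the two lists can be crossed with `itertools.product`. The notebook itself uses two nested `for` loops, which enumerate the same pairs in the same order.

```python
from itertools import product

lr_list = [0.001, 0.01, 0.1]
momentum_list = [0.8, 0.9, 0.95]

# One FedAvg run per (lr, momentum) pair.
grid = list(product(lr_list, momentum_list))
for run_idx, (lr, momentum) in enumerate(grid, start=1):
    print(f"run {run_idx}: lr={lr}, momentum={momentum}")
```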

### Training Logic
- The trainer runs `n_rounds` of FedAvg with a client fraction `C`.
- Validation performance is evaluated every `eval_every` rounds (2 in this cell) and logged to Weights & Biases (wandb).
- For each run, the configuration and its best validation accuracy are appended to `fedavg_grid_results.json` as one JSON object per line.
- A custom `fit_with_all_logs` method adds optional checkpointing, live plotting, and structured logging.

This workflow enables large-scale, reproducible Federated Learning experiments with robust logging, checkpointing, and evaluation, serving as a baseline for future comparisons with sparse model editing or non-i.i.d. setups.
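
Because each run appends its result with `f.write(json.dumps(result) + "\n")`, the results file holds one JSON object per line rather than a single JSON document. A minimal sketch of reading such a file back and picking the best run is shown below; the file path, the `name` field, and the accuracy values are placeholders, not results from the experiments.

```python
import json
import os
import tempfile

# Placeholder records in the one-object-per-line layout; values are made up.
records = [
    {"name": "run_a", "best_val_acc": 0.25},
    {"name": "run_b", "best_val_acc": 0.5},
]

path = os.path.join(tempfile.mkdtemp(), "toy_results.json")
with open(path, "a") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# json.load() rejects a file with several records; parse it line by line.
with open(path) as f:
    runs = [json.loads(line) for line in f if line.strip()]

best = max(runs, key=lambda r: r["best_val_acc"])
print(best["name"], best["best_val_acc"])
```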


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

# --- FL params ---
K = 100
C = 0.1
J = 4
n_rounds = 20
batch_size = 128
weight_decay = 5e-4

# --- Hyperparams ---
lr_list = [0.001, 0.01, 0.1]
momentum_list = [0.8, 0.9, 0.95]

CHECKPOINT_PATH = "fedavg_grid_ckpt.pt"
RESULTS_PATH = "fedavg_grid_results.json"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Loading CIFAR-100 ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding with a local train/val split per client ---
client_data = iid_shard_train_val(full_train, K=K, val_split=0.2, seed=42)

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

val_indices = np.concatenate([client_data[i]['val'] for i in range(K)])
val_set = TransformedSubset(Subset(full_train, val_indices), val_transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
def make_model():
    return DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def eval_fn(model):
    return evaluate(model, val_loader, device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']})")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def plot_history(acc_history, loss_history, eval_every):
    rounds = np.arange(0, len(acc_history))*eval_every + eval_every
    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1)
    plt.plot(rounds, acc_history, label='Val Acc')
    plt.xlabel('Round')
    plt.ylabel('Accuracy')
    plt.title('Val Accuracy')
    plt.subplot(1,2,2)
    plt.plot(rounds, loss_history, label='Val Loss')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.title('Val Loss')
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=None, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume and checkpoint_path:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            val_acc, val_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Val Acc={val_acc:.3f} | Val Loss={val_loss:.3f}')
            wandb.log({"round": rnd, "val_acc": val_acc, "val_loss": val_loss})
            acc_history.append(val_acc)
            loss_history.append(val_loss)
        # Checkpoint independently of the evaluation schedule so that the final round is always saved
        if checkpoint_path and (rnd % 5 == 0 or rnd == n_rounds):
            save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
        if rnd % 5 == 0 or rnd == n_rounds:
            plot_history(acc_history, loss_history, eval_every)
    # Log best hyperparams for this run
    best_acc = max(acc_history) if acc_history else 0
    result = dict(wandb.config)
    result["best_val_acc"] = best_acc
    with open(RESULTS_PATH, "a") as f:
        f.write(json.dumps(result) + "\n")

# --- GRID SEARCH RUNS ---
run_idx = 0
for lr in lr_list:
    for momentum in momentum_list:
        run_idx += 1
        print(f"\n=== NEW HP RUN {run_idx}: lr={lr}, momentum={momentum} ===\n")
        global_model = make_model()

        run_name = f"fedavg_iid_grid_lr{lr}_mom{momentum}_J{J}_nrounds{n_rounds}_bs{batch_size}"

        wandb.init(
            project="fl-fedavg",
            name=run_name,
            config={
                "model": "DINO ViT-S/16",
                "K": K,
                "C": C,
                "J": J,
                "n_rounds": n_rounds,
                "batch_size": batch_size,
                "lr": lr,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "sharding": "iid",
                "Nc": None,
                "use_sparse": False
            }
        )

        trainer = FederatedTrainer(
            clients=clients,
            global_model=global_model,
            device=device,
            client_fraction=C,
            local_epochs=J,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            scheduler_fn=make_scheduler,
            use_sparse=False
        )

        trainer.fit = types.MethodType(fit_with_all_logs, trainer)
        trainer.fit(
            n_rounds,
            eval_fn=eval_fn,
            eval_every=2,
            checkpoint_path=None,
            resume=False
        )
        wandb.finish()


## FedAvg Non IID

This script runs a full non-i.i.d. grid search over multiple combinations of local epochs (Jc), communication rounds, and class partitions per client (Nc). For each configuration, a new non-i.i.d. sharding is generated, a new global model is initialized, and a separate training run is launched using the FedAvg algorithm on the CIFAR-100 dataset with a DINO ViT-S/16 model. Training is logged with Weights & Biases, checkpoints are saved for resuming or analysis, and performance metrics (accuracy and loss) are plotted and stored for each experiment.

## Federated Grid Search under Non-IID Conditions

This cell runs a comprehensive set of Federated Learning experiments using the FedAvg algorithm under non-i.i.d. conditions on the CIFAR-100 dataset. It systematically explores the impact of different degrees of data heterogeneity and local update intensity.

### Dataset Setup
- CIFAR-100 is partitioned across 100 clients using **non-i.i.d. label-skew** splitting, controlled by the `Nc` parameter (number of classes per client).
- Each client receives a personalized dataset along with appropriate train transformations.
- A global validation set is built from the union of all clients' validation splits.

### Fixed & Tuned Hyperparameters
- The best learning rate (`lr=0.001`) and momentum (`momentum=0.8`) values found in previous i.i.d. experiments are reused.
- Three values for local epochs (`Jc = 4, 8, 16`) are paired with decreasing numbers of communication rounds (`n_rounds = 50, 25, 12`) so that the product `Jc × n_rounds` stays roughly constant (200, 200, 192).
- The number of classes per client (`Nc`) is varied over `[1, 5, 10, 50]` to simulate different degrees of heterogeneity: the smaller `Nc`, the stronger the label skew.

### Training Workflow
- For each `(Jc, Nc)` combination, a new `FederatedTrainer` and `DinoViT_CIFAR100` model are instantiated.
- Training proceeds for `n_rounds` with client sampling fraction `C = 0.1`, using FedAvg for aggregation.
- Validation performance is evaluated every 2 rounds and logged to Weights & Biases.
- Checkpoints are saved periodically in a dedicated directory (`fedavg_non_iid_ckpts`), and the best results are stored in a results file (`fedavg_non_iid_results.json`).

This experimental protocol allows fine-grained analysis of how local update intensity and data heterogeneity affect convergence and generalization in federated settings.
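
The shape of this grid can be checked with a few lines of arithmetic. The sketch below only reuses the lists from the cell that follows; note that local epochs and round counts are paired with `zip`, not crossed.

```python
Jc_list = [4, 8, 16]
n_rounds_list = [50, 25, 12]
Nc_list = [1, 5, 10, 50]

# Each local-epoch value is paired with exactly one round count.
schedules = list(zip(Jc_list, n_rounds_list))
budgets = {(J, R): J * R for J, R in schedules}
print(budgets)  # {(4, 50): 200, (8, 25): 200, (16, 12): 192}

# Every (Jc, n_rounds) pair is then combined with every Nc value.
n_runs = len(schedules) * len(Nc_list)
print(n_runs)  # 12
```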


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

K = 100
C = 0.1
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

Jc_list = [4, 8, 16]
n_rounds_list = [50, 25, 12]
Nc_list = [1, 5, 10, 50]

RESULTS_PATH = "fedavg_non_iid_results.json"
CKPT_DIR = "fedavg_non_iid_ckpts"
os.makedirs(CKPT_DIR, exist_ok=True)

# --- Transforms ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

def make_scheduler(optimizer, J_value):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J_value)

# ==== GRIDSEARCH NON-IID FL-READY ====
run_idx = 0

for Jc, n_rounds in zip(Jc_list, n_rounds_list):
    for Nc in Nc_list:
        run_idx += 1

        non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=0.2, seed=42)
        clients = []
        for i in range(K):
            train_idxs = non_iid_split[i]['train']
            client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
            clients.append(Client(i, client_train_dataset, device))

        val_indices = np.concatenate([non_iid_split[i]['val'] for i in range(K)])
        val_set = TransformedSubset(Subset(full_train, val_indices), val_transform)
        val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)

        global_model = DinoViT_CIFAR100(num_classes=100).to(device)

        ckpt_name = f"fedavg_non_iid_checkpoint_Nc{Nc}_J{Jc}_nrounds{n_rounds}.pt"
        checkpoint_path = os.path.join(CKPT_DIR, ckpt_name)

        run_name = f"fedavg_non_iid_J{Jc}_nrounds{n_rounds}_Nc{Nc}_lr{lr}_mom{momentum}"
        wandb.init(
            project="fl-fedavg",
            name=run_name,
            config={
                "model": "DINO ViT-S/16",
                "K": K,
                "C": C,
                "J": Jc,
                "n_rounds": n_rounds,
                "batch_size": batch_size,
                "lr": lr,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "sharding": "non-iid",
                "Nc": Nc,
                "use_sparse": False
            }
        )

        trainer = FederatedTrainer(
            clients=clients,
            global_model=global_model,
            device=device,
            client_fraction=C,
            local_epochs=Jc,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            scheduler_fn=lambda opt: make_scheduler(opt, Jc),
            use_sparse=False
        )

        def eval_fn(model):
            return evaluate(model, val_loader, device)

        trainer.fit = types.MethodType(fit_with_all_logs, trainer)
        print(f"\n=== RUN {run_idx}: J={Jc}, n_rounds={n_rounds}, Nc={Nc} ===\n")
        trainer.fit(n_rounds, eval_fn=eval_fn, eval_every=2, checkpoint_path=checkpoint_path, resume=True)
        wandb.finish()



## Model Editing FedAvg IID: Search for the Best Hyperparameters

This is the i.i.d. model-editing experiment, in which we search for the best hyperparameters over 50 rounds. The code runs a grid search over various sparsity ratios and calibration rounds to evaluate the impact of model editing using sparse fine-tuning in a federated learning setting. Each configuration corresponds to a separate run, where a DINO ViT-S/16 model is trained with i.i.d. sharding on CIFAR-100 using the FedAvg algorithm. Although this code was executed on a different Google Colab instance, all training curves and metrics remain accessible via the associated Weights & Biases (wandb) project, enabling transparent and centralized monitoring across experiments.

## Federated Sparse Fine-Tuning via Model Editing (IID Setup)

This cell launches a Federated Learning experiment using **model editing via sparse fine-tuning** under i.i.d. conditions, based on a DINO-pretrained ViT-S/16 model and the CIFAR-100 dataset.

### Objective
The goal is to assess the impact of sparse fine-tuning on global model performance by updating only a subset of parameters per communication round. This is achieved through binary gradient masks calibrated via approximated Fisher Information.

### Experimental Settings
- **Fixed Parameters**:
  - `K=100` clients, with a sampling fraction `C=0.1` per round.
  - Local training: `J=4` epochs per round.
  - **Learning rate (`lr=0.001`) and momentum (`momentum=0.8`) are reused from the best configuration found in the standard IID FedAvg grid search**.
  - Additional fixed values: `weight_decay=5e-4`, `batch_size=128`.

- **Grid Search Parameters**:
  - `sparsity_ratio ∈ {0.85, 0.90, 0.95}`: proportion of weights masked out (not updated).
  - `num_calib_rounds ∈ {1, 3, 5}`: number of rounds used to progressively refine the gradient mask.
  - `sparse_ft_epochs = 1`: number of epochs used during the sparse fine-tuning phase on each client.
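
The grid above has 3 × 3 = 9 configurations. As a rough sketch of how they can be enumerated, the snippet below builds one config per pair and a run name in the same pattern the notebook uses. The helper `enumerate_grid` is hypothetical, and `round()` is used for the sparsity percentage as an assumption of this sketch (the notebook itself uses `int()`).

```python
from itertools import product

# Values copied from the experimental settings listed above.
sparsity_ratios = [0.85, 0.90, 0.95]
num_calib_rounds_list = [1, 3, 5]
sparse_ft_epochs = 1
n_rounds, lr = 50, 0.001


def enumerate_grid():
    """Return one config dict per (sparsity_ratio, num_calib_rounds) pair."""
    configs = []
    for sparsity_ratio, num_calib_rounds in product(sparsity_ratios, num_calib_rounds_list):
        # round() keeps the percentage robust to floating-point error (sketch assumption).
        sp = round(sparsity_ratio * 100)
        run_name = (f"model_editing_iid_nrounds{n_rounds}_lr{lr}_"
                    f"sp{sp}_calib{num_calib_rounds}_ftep{sparse_ft_epochs}")
        configs.append({
            "sparsity_ratio": sparsity_ratio,
            "num_calib_rounds": num_calib_rounds,
            "sparse_ft_epochs": sparse_ft_epochs,
            "run_name": run_name,
        })
    return configs


grid = enumerate_grid()
print(len(grid), "runs")
```

Deriving the run count from the grid, rather than hard-coding it, keeps progress messages correct if the grid is later extended.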

### Model Editing Logic
For each client:
- A **gradient mask** is calibrated using approximated Fisher scores computed over local batches.
- Only parameters with high sensitivity (mask = 1) are updated during sparse fine-tuning.
- Fine-tuning is conducted via SGD, with masked gradients zeroed out via `SparseSGDM`.

### Logging & Checkpointing
- Validation accuracy and loss are logged to Weights & Biases (wandb) every 2 rounds (`eval_every=2`).
- Model checkpoints are saved on evaluation rounds that are multiples of 5, and at the final round.
- The best validation accuracy and corresponding hyperparameters are saved to a JSON file for post-analysis.

This pipeline enables a rigorous evaluation of model editing techniques through sparsity-aware federated training, offering insight into the efficiency and effectiveness of selective parameter updates.
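
The masking idea described above can be illustrated with a small framework-free sketch. It is not the notebook's `SparseSGDM` or its calibration routine, which are defined elsewhere. The helper names (`fisher_diagonal`, `top_k_mask`, `masked_sgd_step`) are hypothetical, the Fisher approximation is assumed to be the mean squared gradient over batches, and the multi-round refinement controlled by `num_calib_rounds` is not reproduced here.

```python
def fisher_diagonal(per_batch_grads):
    """Approximate diagonal Fisher scores as the mean squared gradient over batches."""
    n_batches = len(per_batch_grads)
    n_params = len(per_batch_grads[0])
    return [sum(g[i] ** 2 for g in per_batch_grads) / n_batches for i in range(n_params)]


def top_k_mask(scores, sparsity_ratio):
    """Return a 0/1 mask that keeps the (1 - sparsity_ratio) highest-scoring parameters."""
    n_keep = max(1, round(len(scores) * (1.0 - sparsity_ratio)))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if i in keep else 0 for i in range(len(scores))]


def masked_sgd_step(params, grads, mask, lr):
    """Plain SGD step in which gradients of masked-out parameters are zeroed."""
    return [p - lr * g * m for p, g, m in zip(params, grads, mask)]


# Toy example: 4 parameters, gradients from 2 batches.
grads = [[1.0, -2.0, 0.1, 0.0], [1.0, 2.0, -0.1, 0.0]]
scores = fisher_diagonal(grads)
mask = top_k_mask(scores, sparsity_ratio=0.5)
new_params = masked_sgd_step([0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 1.0, 1.0], mask, lr=0.1)
print(mask, new_params)
```

With `sparsity_ratio=0.85`, the same selection would leave roughly 15% of the parameters trainable; the rest keep their pretrained values.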


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparams
K = 100
C = 0.1
J = 4
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

# Param grid
sparsity_ratios = [0.85, 0.90, 0.95]
num_calib_rounds_list = [1, 3, 5]
sparse_ft_epochs = 1

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Data (raw datasets, no transform yet; transforms are applied per subset below) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)
iid_split = iid_shard_train_val(full_train, K=K, val_split=0.2, seed=42)

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

val_indices = np.concatenate([iid_split[i]['val'] for i in range(K)])
val_set = TransformedSubset(Subset(full_train, val_indices), val_transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def eval_fn_edit(model):
    return evaluate(model, val_loader, device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']}) from {path}")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def save_best_hyperparams(acc_history, config, path):
    best_acc = max(acc_history) if acc_history else 0.0
    run_data = {"best_val_acc": best_acc}
    run_data.update(config)
    with open(path, "a") as f:
        f.write(json.dumps(run_data) + "\n")

def fit_with_wandb_and_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=None, best_json_path=None, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume and checkpoint_path:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} (SPARSE-EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            acc, loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Eval: Acc={acc:.3f} | Loss={loss:.3f}')
            wandb.log({"round": rnd, "val_acc": acc, "val_loss": loss})
            acc_history.append(acc)
            loss_history.append(loss)
            if checkpoint_path and (rnd % 5 == 0 or rnd == n_rounds):
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
    if best_json_path:
        save_best_hyperparams(acc_history, wandb.config, best_json_path)


# === GRIDSEARCH MODEL EDITING FL-READY ===
run_idx = 0
for sparsity_ratio in sparsity_ratios:
    for num_calib_rounds in num_calib_rounds_list:
        run_idx += 1
        print(f"\n=== MODEL EDITING RUN {run_idx}/9 ===\n")
        clients_edit = []
        for i in range(K):
            train_idxs = iid_split[i]['train']
            client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
            clients_edit.append(Client(i, client_train_dataset, device))

        global_model = DinoViT_CIFAR100(num_classes=100).to(device)

        run_name = (f"model_editing_iid_nrounds{n_rounds}_lr{lr}_"
                    f"sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}_ftep{sparse_ft_epochs}")
        checkpoint_path = f"{run_name}.pt"
        best_json_path = f"{run_name}.json"

        wandb.init(
            project="fl-fedavg",
            name=run_name,
            config={
                "model": "DINO ViT-S/16",
                "K": K,
                "C": C,
                "J": J,
                "n_rounds": n_rounds,
                "batch_size": batch_size,
                "lr": lr,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "sharding": "iid",
                "use_sparse": True,
                "sparsity_ratio": sparsity_ratio,
                "num_calib_rounds": num_calib_rounds,
                "sparse_ft_epochs": sparse_ft_epochs
            }
        )

        trainer_edit = FederatedTrainer(
            clients=clients_edit,
            global_model=global_model,
            device=device,
            client_fraction=C,
            local_epochs=J,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            scheduler_fn=make_scheduler,
            use_sparse=True,
            sparsity_ratio=sparsity_ratio,
            num_calib_rounds=num_calib_rounds,
            sparse_ft_epochs=sparse_ft_epochs
        )
        trainer_edit.fit = types.MethodType(fit_with_wandb_and_logs, trainer_edit)
        trainer_edit.fit(
            n_rounds,
            eval_fn=eval_fn_edit,
            eval_every=2,
            checkpoint_path=checkpoint_path,
            best_json_path=best_json_path,
            resume=True
        )
        wandb.finish()



## Model Editing FedAvg Non-IID: Search for the Best Hyperparameters

## Federated Sparse Fine-Tuning via Model Editing (Non-IID Setup)

This experiment replicates the model editing strategy with **sparse fine-tuning**, this time under **non-i.i.d. conditions** where each client receives data from only a limited number of classes (`Nc = 50`). The CIFAR-100 dataset is distributed in a label-skewed fashion to better simulate realistic statistical heterogeneity.
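
To make the label-skewed distribution concrete, here is a minimal sketch of one way to give every client exactly `Nc` distinct classes. It is not the notebook's `non_iid_shard_train_val`; the function `label_skew_split` is hypothetical, it assigns classes in a deterministic round-robin order (an assumption made to guarantee distinct classes per client), and it requires `K * Nc` to be divisible by the number of classes.

```python
import random
from collections import defaultdict


def label_skew_split(labels, K, Nc, num_classes, seed=42):
    """Give each of K clients one shard from each of Nc distinct classes."""
    assert Nc <= num_classes and (K * Nc) % num_classes == 0
    rng = random.Random(seed)
    shards_per_class = K * Nc // num_classes

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)

    # Shuffle each class and cut it into equally sized shards.
    shards = {}
    for c in range(num_classes):
        idxs = by_class[c]
        rng.shuffle(idxs)
        size = len(idxs) // shards_per_class
        shards[c] = [idxs[s * size:(s + 1) * size] for s in range(shards_per_class)]

    # Client i takes Nc consecutive classes (mod num_classes), one shard per class.
    clients = []
    for i in range(K):
        classes = [(i * Nc + j) % num_classes for j in range(Nc)]
        client_idxs = []
        for c in classes:
            client_idxs.extend(shards[c].pop())
        clients.append(client_idxs)
    return clients


# Toy data: 10 classes with 20 samples each, 5 clients with 4 classes each.
toy_labels = [c for c in range(10) for _ in range(20)]
toy_clients = label_skew_split(toy_labels, K=5, Nc=4, num_classes=10)
print([sorted({toy_labels[i] for i in idxs}) for idxs in toy_clients])
```

With `K = 100`, `Nc = 50` and 100 classes, the same scheme would cut each class into 50 shards, so every class is shared by 50 clients.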

### Objective
To evaluate how sparse fine-tuning performs in non-i.i.d. settings when clients calibrate personalized update masks based on their own local data distributions.

### Experimental Settings
- **Fixed Parameters**:
  - `K = 100` clients, with a participation fraction `C = 0.1` per round.
  - Each round consists of `J = 4` local epochs.
  - **Learning rate (`lr = 0.001`) and momentum (`momentum = 0.8`) are inherited from the best IID baseline**.
  - `Nc = 50`: each client receives samples from 50 unique classes.
  - `batch_size = 128`, `weight_decay = 5e-4`.

- **Grid Search Parameters**:
  - `sparsity_ratio ∈ {0.85, 0.90, 0.95}`: controls how many weights are frozen.
  - `num_calib_rounds ∈ {1, 3, 5}`: number of rounds used to progressively calibrate the sparse mask.
  - `sparse_ft_epochs = 1`: number of fine-tuning epochs on each client.

### Model Editing Logic
Each client:
- Computes an importance mask using approximated Fisher Information.
- Applies the mask to freeze less sensitive parameters.
- Fine-tunes only the most important subset using SGD.

A central validation set is created by aggregating all local validation sets. Performance is evaluated every two rounds.

### Logging and Saving
- Model metrics are logged to **Weights & Biases (wandb)**.
- Checkpoints are written on evaluation rounds that are multiples of 5 (rounds 10, 20, and so on with `eval_every=2`) and at the final round.
- Each run's best validation accuracy and associated hyperparameters are stored in a JSON file for analysis.
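
As a rough sketch of the post-analysis step, the snippet below writes one JSON record per line (the same append-style format the notebook's `save_best_hyperparams` produces) and then picks the record with the highest `best_val_acc` across several files. The helpers `append_run` and `load_best_run` are hypothetical, and the accuracies in the demo are placeholders, not results from these experiments.

```python
import json
import os
import tempfile


def append_run(path, best_val_acc, config):
    """Append one JSON record (best accuracy plus hyperparameters) as a single line."""
    record = {"best_val_acc": best_val_acc, **config}
    with open(path, "a") as f:
        f.write(json.dumps(record) + "\n")


def load_best_run(paths):
    """Read every JSON line from the given files and return the best record."""
    records = []
    for path in paths:
        with open(path) as f:
            records.extend(json.loads(line) for line in f if line.strip())
    return max(records, key=lambda r: r["best_val_acc"])


with tempfile.TemporaryDirectory() as tmp:
    path_a = os.path.join(tmp, "run_a.json")
    path_b = os.path.join(tmp, "run_b.json")
    # Placeholder accuracies for illustration only.
    append_run(path_a, 0.10, {"sparsity_ratio": 0.85, "num_calib_rounds": 1})
    append_run(path_b, 0.20, {"sparsity_ratio": 0.90, "num_calib_rounds": 3})
    print(load_best_run([path_a, path_b]))
```

Because the files are opened in append mode, re-running a configuration adds a new line instead of overwriting the old one, so the reader has to handle several records per file.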

This pipeline allows assessing the robustness and efficiency of sparse model editing techniques in the presence of non-i.i.d. data distributions.


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
from torch.utils.data import Subset, Dataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
J = 4
C = 0.1
Nc = 50
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4
sparsity_ratios = [0.85, 0.90, 0.95]
num_calib_rounds_list = [1, 3, 5]
sparse_ft_epochs = 1
K = 100


seed = 42
val_split = 0.2
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)

class TransformedSubset(torch.utils.data.Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Global Validation
val_indices = np.concatenate([non_iid_split[i]['val'] for i in range(K)])
val_set = TransformedSubset(torch.utils.data.Subset(full_train, val_indices), val_transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)


def eval_fn_edit(model):
    return evaluate(model, val_loader, device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def save_checkpoint(model, round_idx, acc_history, loss_history, path):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']}) from {path}")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def save_best_hyperparams(acc_history, config, path):
    best_acc = max(acc_history) if acc_history else 0.0
    run_data = {"best_val_acc": best_acc}
    run_data.update(config)
    with open(path, "a") as f:
        f.write(json.dumps(run_data) + "\n")

def fit_with_wandb_and_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=None, best_json_path=None, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume and checkpoint_path:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} (SPARSE-EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            acc, loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Eval: Acc={acc:.3f} | Loss={loss:.3f}')
            wandb.log({"round": rnd, "val_acc": acc, "val_loss": loss})
            acc_history.append(acc)
            loss_history.append(loss)
            if checkpoint_path and (rnd % 5 == 0 or rnd == n_rounds):
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
    if best_json_path:
        save_best_hyperparams(acc_history, wandb.config, best_json_path)

# --- GRIDSEARCH MODEL EDITING NON-IID ---
run_idx = 0
for sparsity_ratio in sparsity_ratios:
    for num_calib_rounds in num_calib_rounds_list:
        run_idx += 1
        print(f"\n=== MODEL EDITING NON-IID RUN {run_idx}/9 ===\n")
        # Clients FL-ready (non-iid)
        clients_edit = []
        for i in range(K):
            train_idxs = non_iid_split[i]['train']
            client_train_dataset = TransformedSubset(torch.utils.data.Subset(full_train, train_idxs), train_transform)
            clients_edit.append(Client(i, client_train_dataset, device))

        global_model = DinoViT_CIFAR100(num_classes=100).to(device)

        run_name = (f"model_editing_non_iid_J{J}_nrounds{n_rounds}_Nc{Nc}_lr{lr}_"
                    f"sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}_ftep{sparse_ft_epochs}")
        checkpoint_path = f"{run_name}.pt"
        best_json_path = f"{run_name}.json"

        wandb.init(
            project="fl-fedavg",
            name=run_name,
            config={
                "model": "DINO ViT-S/16",
                "K": K,
                "C": C,
                "J": J,
                "n_rounds": n_rounds,
                "Nc": Nc,
                "batch_size": batch_size,
                "lr": lr,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "sharding": "non-iid",
                "use_sparse": True,
                "sparsity_ratio": sparsity_ratio,
                "num_calib_rounds": num_calib_rounds,
                "sparse_ft_epochs": sparse_ft_epochs
            }
        )

        trainer_edit = FederatedTrainer(
            clients=clients_edit,
            global_model=global_model,
            device=device,
            client_fraction=C,
            local_epochs=J,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            scheduler_fn=make_scheduler,
            use_sparse=True,
            sparsity_ratio=sparsity_ratio,
            num_calib_rounds=num_calib_rounds,
            sparse_ft_epochs=sparse_ft_epochs
        )
        trainer_edit.fit = types.MethodType(fit_with_wandb_and_logs, trainer_edit)
        trainer_edit.fit(
            n_rounds,
            eval_fn=eval_fn_edit,
            eval_every=2,
            checkpoint_path=checkpoint_path,
            best_json_path=best_json_path,
            resume=True
        )
        wandb.finish()


### Reloading the same dataset with a train/test split only


## Final Dataset Freezing for Evaluation and Reuse

This cell prepares and freezes the dataset partitions for future experimental runs under both i.i.d. and non-i.i.d. settings **without any validation split** (`val_split = 0`), making it suitable for training or inference-only workflows.

### CIFAR-100 Dataset
- The full CIFAR-100 training and test sets are loaded once without transformations.
- Validation sets are omitted for this configuration, i.e., all available data is allocated to training per client.

### Sharding Strategies
- **IID Split**: Clients receive a uniformly distributed subset of all classes, ensuring a balanced and representative dataset.
- **Non-IID Split**: Clients receive examples from `Nc = 50` distinct classes, simulating partial class visibility and label skew typical in federated scenarios.

### Reusability
- The resulting splits (`iid_split` and `non_iid_split`) are stored in memory as dictionaries mapping each client ID to its training indices.
- These static splits can now be reused **consistently across multiple runs and model variants**, ensuring fair and reproducible comparisons.

This setup is especially useful for locked-in evaluation stages, test-time performance comparisons, or fixed-data ablation studies.
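
The reproducibility argument rests on the split being a deterministic function of the seed. The sketch below illustrates this with a small class-balanced split that returns the same `{client_id: {"train": ..., "val": ...}}` layout the notebook indexes into. It is a stand-in, not the notebook's `iid_shard_train_val`; the function name `iid_split_sketch` and its internals are assumptions.

```python
import random


def iid_split_sketch(labels, K, val_split=0.0, seed=42):
    """Class-balanced split of sample indices across K clients."""
    rng = random.Random(seed)

    # Group indices by class and shuffle within each class.
    by_class = {}
    for idx, y in enumerate(labels):
        by_class.setdefault(y, []).append(idx)
    ordered = []
    for y in sorted(by_class):
        rng.shuffle(by_class[y])
        ordered.extend(by_class[y])

    # Deal indices round-robin so every client sees every class.
    split = {}
    for i in range(K):
        idxs = ordered[i::K]
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_split)
        split[i] = {"train": idxs[n_val:], "val": idxs[:n_val]}
    return split


# Toy data: 4 classes with 20 samples each, 5 clients, no validation split.
toy_labels = [c for c in range(4) for _ in range(20)]
first = iid_split_sketch(toy_labels, K=5, val_split=0.0, seed=42)
second = iid_split_sketch(toy_labels, K=5, val_split=0.0, seed=42)
print(first == second, len(first[0]["train"]), len(first[0]["val"]))
```

With `val_split = 0`, the `val` list of every client is empty and all of its indices land in `train`, which matches the configuration described above.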


In [None]:
import torch
import torchvision
from torchvision import transforms
import numpy as np

# Hyperparams
K = 100
val_split = 0
seed = 42
Nc = 50  # <-- set the Nc value to use for the non-IID split here

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

iid_split = iid_shard_train_val(full_train, K=K, val_split=val_split, seed=seed)

non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)


## Test Accuracy FedAvg IID

## Final Evaluation of the FedAvg Baseline on Test Set Only

This cell performs the **final evaluation** of the global model trained using FedAvg in **IID configuration**. The focus is solely on the **test set**, in order to assess the **generalization performance** of the model on unseen data.

### Key Features of This Evaluation:
- **Client Data**: Clients are sharded in an **IID** manner with `val_split = 0`, so no samples are held out for local validation and each client trains on its full shard.
- **Model**: A `DINO ViT-S/16` backbone (pretrained), adapted for CIFAR-100 with a linear classification head.
- **Test Set**: The CIFAR-100 test set is resized and normalized, and used without any stochastic data augmentation.
- **Tracking**: On every evaluation round (every 2 rounds here), `test_acc` and `test_loss` are logged to **WandB**; performance plots are regenerated every 5 rounds and at the final round.

### Training Setup
The hyperparameters (batch size, learning rate, momentum) are those obtained from the **best-performing settings** during previous classical IID FedAvg experiments.

This cell is intended to be run **after the hyperparameter search**: it trains the global model with the selected settings and evaluates it on the test set along the way, producing a **standardized, reproducible, and trackable** final evaluation.
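
For reference, the two server-side steps of a FedAvg round can be sketched without any framework: sampling a fraction `C` of the `K` clients, then averaging their parameters weighted by local dataset size. This is the standard FedAvg rule, not a copy of the notebook's `FederatedTrainer.train_round`, which is defined elsewhere; the helpers `sample_clients` and `fedavg` are hypothetical, and parameters are represented as plain lists of floats for simplicity.

```python
import random


def sample_clients(K, C, rng):
    """Pick max(1, round(C * K)) distinct client ids for this round."""
    m = max(1, round(C * K))
    return rng.sample(range(K), m)


def fedavg(client_states, client_sizes):
    """Average client parameters, weighting each client by its number of samples."""
    total = sum(client_sizes)
    averaged = {}
    for name in client_states[0]:
        n_values = len(client_states[0][name])
        averaged[name] = [
            sum(size / total * state[name][i]
                for state, size in zip(client_states, client_sizes))
            for i in range(n_values)
        ]
    return averaged


rng = random.Random(0)
selected = sample_clients(K=100, C=0.1, rng=rng)
print(len(selected), "clients selected")

# Two toy clients holding 1 and 3 samples respectively.
states = [{"w": [1.0, 2.0]}, {"w": [3.0, 4.0]}]
print(fedavg(states, client_sizes=[1, 3]))
```

With `K = 100` and `C = 0.1`, ten clients take part in each round. In the IID split all shards have the same size, so the weighted average reduces to a plain mean.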


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, FEDAVG BASELINE ======

import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

# --- FL params ---
K = 100
C = 0.1
J = 4
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

CHECKPOINT_PATH = "fedavg_checkpoint.pt"
BEST_HYPERPARAMS_PATH = "best_run.json"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Loading CIFAR-100  ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

# --- Eval function (TEST ONLY) ---
def eval_fn_test(model):
    return evaluate(model, test_loader, device)

# --- Logging, checkpoint, plotting utils ---
def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']})")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def plot_history(acc_history, loss_history, eval_every):
    rounds = np.arange(0, len(acc_history))*eval_every + eval_every
    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1)
    plt.plot(rounds, acc_history, label='Test Acc')
    plt.xlabel('Round')
    plt.ylabel('Accuracy')
    plt.title('Test Accuracy')
    plt.subplot(1,2,2)
    plt.plot(rounds, loss_history, label='Test Loss')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.title('Test Loss')
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
            if rnd % 5 == 0 or rnd == n_rounds:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
        if rnd % 5 == 0 or rnd == n_rounds:
            plot_history(acc_history, loss_history, eval_every)


# --- WANDB init ---
wandb.init(
    project="fl-fedavg",
    name=f"fedavg_iid_baseline_test_acc_J{J}_nrounds{n_rounds}_lr{lr}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": False
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=False
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


## FedAvg Non IID Test Accuracy

## Final Evaluation of the Non-IID FedAvg Baseline on the Test Set

This experiment performs the **final evaluation** of a global model trained with the **FedAvg algorithm** under a **non-IID client distribution**. The primary goal is to assess how well the model generalizes to unseen data when local training data is significantly biased across clients.

### Key Points:
- **Non-IID Setup**: Each of the `K=100` clients is assigned data from a fixed number of `Nc` distinct classes. This configuration simulates real-world heterogeneity in federated learning systems.
- **Test-Only Evaluation**: The model is evaluated on the **official CIFAR-100 test set**, fully normalized and independent of all clients.
- **Model**: Uses a DINO ViT-S/16 backbone pretrained on ImageNet, with a lightweight classifier head adapted for 100-way classification.
- **Metrics Logged**: Test accuracy and loss are recorded every 2 rounds to **Weights & Biases**, alongside automatic plotting and checkpoint saving for reproducibility.
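
The logging cadence can be worked out with a short sketch. It mirrors the structure of the training loop in the cell below, where the checkpoint condition sits inside the evaluation branch, so a checkpoint is only written on rounds that are also evaluation rounds. The function `schedule` is a hypothetical helper, not part of the notebook.

```python
def schedule(n_rounds, eval_every=2, ckpt_every=5):
    """Return the rounds on which evaluation and checkpointing happen."""
    eval_rounds = [r for r in range(1, n_rounds + 1) if r % eval_every == 0]
    # Checkpoints are only considered on evaluation rounds.
    ckpt_rounds = [r for r in eval_rounds if r % ckpt_every == 0 or r == n_rounds]
    return eval_rounds, ckpt_rounds


eval_rounds, ckpt_rounds = schedule(n_rounds=50, eval_every=2, ckpt_every=5)
print(len(eval_rounds), "evaluations; checkpoints at", ckpt_rounds)

# The x-axis used for plotting is rebuilt from the history length.
plot_axis = [(i + 1) * 2 for i in range(len(eval_rounds))]
print(plot_axis == eval_rounds)
```

With 50 rounds and `eval_every=2`, this gives 25 evaluations and checkpoints at rounds 10, 20, 30, 40 and 50. If `n_rounds` were odd, the final round would not be an evaluation round and would therefore not be checkpointed.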



In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, FEDAVG NON-IID BASELINE ======

import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

# --- FL params ---
K = 100
C = 0.1
J = 4
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4
Nc = 50  # must match the Nc used to build non_iid_split in the dataset-freezing cell

CHECKPOINT_PATH = "fedavg_non_iid_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Loading CIFAR-100  ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: non-IID sharding, local train only (val_split = 0) ---
client_data = non_iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def eval_fn_test(model):
    return evaluate(model, test_loader, device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']})")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def plot_history(acc_history, loss_history, eval_every):
    rounds = np.arange(0, len(acc_history))*eval_every + eval_every
    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1)
    plt.plot(rounds, acc_history, label='Test Acc')
    plt.xlabel('Round')
    plt.ylabel('Accuracy')
    plt.title('Test Accuracy')
    plt.subplot(1,2,2)
    plt.plot(rounds, loss_history, label='Test Loss')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.title('Test Loss')
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
            if rnd % 5 == 0 or rnd == n_rounds:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
        if rnd % 5 == 0 or rnd == n_rounds:
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg",
    name=f"fedavg_non_iid_baseline_test_acc_J{J}_nrounds{n_rounds}_lr{lr}_Nc{Nc}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "non-iid",
        "Nc": Nc,
        "use_sparse": False
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=False
)

trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


## Test Accuracy FedAvg IID Model Editing

## Final Evaluation of Model Editing under IID Setup (Test Accuracy)

This experiment evaluates the final performance of the **Model Editing technique** under an **IID client distribution**. The model is trained with **sparse fine-tuning**, using a calibrated gradient mask and limited parameter updates per round.

### Key Features:
- **IID Distribution**: All clients are assigned random, balanced subsets of the dataset (uniform over all classes).
- **Model Editing Strategy**: A fixed sparsity mask (ratio = 85%) is calibrated over `5` rounds and used to restrict updates to a subset of parameters. Fine-tuning is done with `1` epoch locally.
- **Global Evaluation**: Performance is measured on the CIFAR-100 **official test set**, ensuring consistency with the FedAvg baseline.
- **Model**: A ViT-S/16 pretrained with DINO is used as the backbone, with a new classifier head.
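- **Metrics**: Test accuracy and loss are logged to **Weights & Biases** every 2 rounds. Plots are refreshed every 5 rounds, and checkpoints are saved on evaluation rounds that are multiples of 5 (every 10 rounds here) and after the final round.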
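
The masking idea described above can be illustrated with a minimal sketch. It is not the notebook's `FederatedTrainer` implementation: the helpers `calibrate_mask` and `masked_sgd_step` are hypothetical, the sensitivity score (accumulated squared gradients) and the choice to keep the highest-scoring entries are assumptions, and momentum and weight decay are left out. The notebook's actual criterion may differ, for example by selecting the least sensitive parameters instead.

```python
import torch
import torch.nn as nn

def calibrate_mask(model, batches, loss_fn, sparsity_ratio):
    """Hypothetical helper: build one binary mask per parameter tensor.

    Scores are accumulated squared gradients over the calibration batches;
    the top (1 - sparsity_ratio) fraction of entries is marked trainable.
    """
    scores = [torch.zeros_like(p) for p in model.parameters()]
    for x, y in batches:
        model.zero_grad()
        loss_fn(model(x), y).backward()
        for s, p in zip(scores, model.parameters()):
            if p.grad is not None:
                s += p.grad.detach() ** 2
    model.zero_grad()

    flat = torch.cat([s.flatten() for s in scores])
    n_keep = max(1, int(round((1.0 - sparsity_ratio) * flat.numel())))
    flat_mask = torch.zeros_like(flat)
    flat_mask[torch.topk(flat, n_keep).indices] = 1.0  # 1 = trainable

    # Cut the flat mask back into per-parameter shapes.
    masks, offset = [], 0
    for s in scores:
        masks.append(flat_mask[offset:offset + s.numel()].view_as(s))
        offset += s.numel()
    return masks

def masked_sgd_step(model, masks, x, y, loss_fn, lr):
    """Plain SGD step in which masked-out entries receive no update."""
    model.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    with torch.no_grad():
        for p, m in zip(model.parameters(), masks):
            if p.grad is not None:
                p -= lr * p.grad * m
    return loss.item()

# Toy usage on a small linear model (assumed stand-in for the ViT backbone).
torch.manual_seed(0)
model = nn.Linear(10, 4)
loss_fn = nn.CrossEntropyLoss()
batches = [(torch.randn(8, 10), torch.randint(0, 4, (8,))) for _ in range(3)]

masks = calibrate_mask(model, batches, loss_fn, sparsity_ratio=0.85)
before = [p.detach().clone() for p in model.parameters()]
masked_sgd_step(model, masks, *batches[0], loss_fn, lr=0.1)

trainable = int(sum(m.sum().item() for m in masks))
total = sum(m.numel() for m in masks)
print(f"trainable entries: {trainable} / {total}")
```

Because the mask is fixed after calibration, every later step touches the same subset of entries, which is what makes the update sparse from round to round.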



In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

# --- FL params ---
K = 100
C = 0.1
J = 4
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

# --- Model Editing hyperparameters (chosen values) ---
sparsity_ratio = 0.85
num_calib_rounds = 5
sparse_ft_epochs = 1

CHECKPOINT_PATH = "model_editing_iid_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split  # reuse the IID split built in an earlier cell

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def eval_fn_test(model):
    return evaluate(model, test_loader, device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']})")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def plot_history(acc_history, loss_history, eval_every):
    rounds = np.arange(0, len(acc_history))*eval_every + eval_every
    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1)
    plt.plot(rounds, acc_history, label='Test Acc')
    plt.xlabel('Round')
    plt.ylabel('Accuracy')
    plt.title('Test Accuracy')
    plt.subplot(1,2,2)
    plt.plot(rounds, loss_history, label='Test Loss')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.title('Test Loss')
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
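            # Nested under the evaluation branch: with eval_every=2 this saves
            # every 10th round, plus the final round if it is an evaluation round.
            if rnd % 5 == 0 or rnd == n_rounds:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)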
        if rnd % 5 == 0 or rnd == n_rounds:
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg",
    name=f"model_editing_iid_test_acc_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


## Test Accuracy FedAvg Non-IID Model Editing

## Final Evaluation of Model Editing under Non-IID Setup (Test Accuracy)

This experiment evaluates the final performance of the **Model Editing strategy** in a **Non-IID federated learning scenario**, where each client receives a biased subset of CIFAR-100 (Nc = 50 classes per client).

### Key Features:
- **Non-IID Distribution**: Each client receives data from only `Nc` distinct classes, simulating strong statistical heterogeneity.
- **Model Editing Technique**: Gradient masking is used to restrict training to only 10% of the model’s parameters (`sparsity_ratio = 0.90`), calibrated over `5` rounds. Each sparse fine-tuning phase runs for `1` local epoch.
- **Global Evaluation**: The performance is assessed on the full CIFAR-100 test set, allowing comparison with other baselines.
- **Backbone Model**: A DINO-pretrained ViT-S/16 with frozen feature extractor and learnable classifier head.
- **Logging & Visualization**: Test accuracy and loss are reported every 2 rounds using Weights & Biases. Checkpoints and plots help monitor convergence and performance over time.
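
To make the label-skew setting concrete, here is a minimal sketch of one way to give every client samples from exactly `Nc` classes. It is not the notebook's `non_iid_shard_train_val`: the helper `label_skew_partition` is hypothetical, it assigns classes to clients on a ring rather than by random shard shuffling, and it assumes every class has at least as many samples as it has owning clients.

```python
import numpy as np

def label_skew_partition(labels, num_clients, nc, seed=0):
    """Hypothetical helper: disjoint index lists, `nc` distinct classes per client."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    classes = np.unique(labels)
    num_classes = len(classes)
    if nc > num_classes:
        raise ValueError("nc cannot exceed the number of classes")

    # Client i takes nc consecutive classes on a ring of class ids.
    client_classes = [
        [classes[(i * nc + j) % num_classes] for j in range(nc)]
        for i in range(num_clients)
    ]

    # Record which clients own each class.
    owners = {c: [] for c in classes}
    for client, cls_list in enumerate(client_classes):
        for c in cls_list:
            owners[c].append(client)

    # Split each class's shuffled samples into one shard per owning client.
    client_indices = [[] for _ in range(num_clients)]
    for c in classes:
        if not owners[c]:
            continue  # class not assigned to any client
        idx = rng.permutation(np.flatnonzero(labels == c))
        for shard, client in zip(np.array_split(idx, len(owners[c])), owners[c]):
            client_indices[client].extend(shard.tolist())
    return client_indices

# Toy usage: 10 classes with 20 samples each, 10 clients, 5 classes per client.
labels = np.repeat(np.arange(10), 20)
parts = label_skew_partition(labels, num_clients=10, nc=5)
classes_per_client = [len(set(labels[p].tolist())) for p in parts]
print(classes_per_client)
```

With the notebook's settings (`K = 100`, `Nc = 50`, 100 classes of 500 training images), the same scheme would give each class 50 owning clients and each client 500 samples, so a client still sees only half of the label space.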



In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING NON-IID TEST ACC ======

import torch
import torchvision
from torchvision import transforms
import numpy as np
import wandb
import types
import os
import json
import matplotlib.pyplot as plt

# --- FL params ---
K = 100
C = 0.1
J = 4
n_rounds = 50
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4
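# Nc (classes per client) is defined in the earlier cell that built `non_iid_split`;
# it is reused here for logging only and must match that split.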

# --- Model Editing hyperparameters (chosen values) ---
sparsity_ratio = 0.90
num_calib_rounds = 5
sparse_ft_epochs = 1

CHECKPOINT_PATH = "model_editing_non_iid_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: NON-IID sharding + train/val local ---
client_data = non_iid_split  # Use the fixed NON-IID split built earlier

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform
    def __getitem__(self, idx):
        x, y = self.subset[idx]
        x = self.transform(x)
        return x, y
    def __len__(self):
        return len(self.subset)

clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=J)

def eval_fn_test(model):
    return evaluate(model, test_loader, device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history
    }
    torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    if os.path.exists(path):
        checkpoint = torch.load(path, map_location=device)
        model.load_state_dict(checkpoint["model_state"])
        print(f"Checkpoint loaded (round {checkpoint['round']})")
        return checkpoint["round"], checkpoint["acc_history"], checkpoint["loss_history"]
    return 0, [], []

def plot_history(acc_history, loss_history, eval_every):
    rounds = np.arange(0, len(acc_history))*eval_every + eval_every
    plt.figure(figsize=(8,4))
    plt.subplot(1,2,1)
    plt.plot(rounds, acc_history, label='Test Acc')
    plt.xlabel('Round')
    plt.ylabel('Accuracy')
    plt.title('Test Accuracy')
    plt.subplot(1,2,2)
    plt.plot(rounds, loss_history, label='Test Loss')
    plt.xlabel('Round')
    plt.ylabel('Loss')
    plt.title('Test Loss')
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    start_round, acc_history, loss_history = (0, [], [])
    if resume:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round+1, n_rounds+1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING NON-IID) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
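            # Nested under the evaluation branch: with eval_every=2 this saves
            # every 10th round, plus the final round if it is an evaluation round.
            if rnd % 5 == 0 or rnd == n_rounds:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)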
        if rnd % 5 == 0 or rnd == n_rounds:
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg",
    name=f"model_editing_non_iid_test_acc_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "non-iid",
        "Nc": Nc,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()
