# Personal Contribution IID

## Introduction :
This notebook contains the experiments for the personal contribution: alternative gradient-masking strategies applied to model editing in the IID setting (`train_most_important`, `magnitude_most`, `magnitude_least` and `random`).

## Import + setup

In [None]:
!pip install wandb

In [None]:
# Log in to Weights & Biases before any run is created in the cells below.
import wandb
wandb.login()

In [None]:
import torch
import torch.nn as nn
import timm
import torchvision
import wandb
from torchvision import transforms, datasets
import numpy as np
from collections import defaultdict, Counter
import random
import matplotlib.pyplot as plt
from typing import List, Dict, Tuple,  Optional, Iterable
from torch.utils.data import random_split, Dataset, DataLoader
from tqdm import tqdm
from torch.optim import SGD
import types
import os
import json



In [None]:
# Report whether CUDA is usable and, if so, the name of device 0.
print("GPU available :", torch.cuda.is_available())
if torch.cuda.is_available():
    print("Nom GPU", torch.cuda.get_device_name(0))
else:
    print("Nom GPU", "not available")

## Utils

In [None]:
class TransformedSubset(Dataset):
    """Dataset view that applies `transform` to the sample of each `(sample, label)` pair.

    The wrapped `subset` is indexed lazily, so the transform runs at access time
    and the label is returned untouched.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, label = self.subset[idx]
        return self.transform(sample), label

In [None]:
def iid_shard_train_val(dataset, K, val_split=0.2, seed=42):
    """
    Stratified IID sharding: each client receives (approximately) the same number of examples for each class.
    Args:
      dataset: PyTorch Dataset (WITHOUT applied transform)
      K: number of clients (must be >= 1)
      val_split: local validation fraction (within each client)
      seed: for reproducibility
    Returns:
      client_data: {client_id: {'train': [idxs], 'val': [idxs]}}
    Raises:
      ValueError: if K < 1.
    """
    if K < 1:
        raise ValueError(f"K must be a positive number of clients, got {K}")

    rng = np.random.RandomState(seed)

    # Read labels from `dataset.targets` when it is trustworthy (same length, no
    # target transform); otherwise fall back to indexing every sample, which
    # loads each item just to read its label.
    targets = getattr(dataset, "targets", None)
    if (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    ):
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])

    # Use the label values actually present (sorted) rather than assuming ids
    # 0..n-1; for contiguous labels this visits the classes in the same order.
    classes = np.unique(labels)
    class_indices = {c: np.where(labels == c)[0] for c in classes}
    for c in classes:
        rng.shuffle(class_indices[c])

    client_indices = {i: [] for i in range(K)}
    for c in classes:
        idxs = class_indices[c]
        # The first `extra` clients receive one additional example of this class.
        base, extra = divmod(len(idxs), K)
        cursor = 0
        for i in range(K):
            take = base + (1 if i < extra else 0)
            client_indices[i].extend(idxs[cursor:cursor + take])
            cursor += take

    client_data = {}
    for i in range(K):
        idxs = np.array(client_indices[i], dtype=np.int64)
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_split)
        client_data[i] = {
            'train': idxs[n_val:].tolist(),
            'val': idxs[:n_val].tolist(),
        }
    return client_data


In [None]:
def non_iid_shard_train_val(dataset, K, Nc, val_split=0.2, seed=42):
    """
    Non-IID sharding (label-skew) + local train/val split for FL.
    Each class is cut into equally sized shards; every client then receives up to
    Nc shards, each from a different class. Shards are never shared between
    clients, but the same class can appear on several clients.
    Each client finally splits its data into train/val according to val_split.
    Args:
        dataset: PyTorch Dataset (without applied transform)
        K: number of clients (must be >= 1)
        Nc: number of different classes per client (must be >= 1)
        val_split: local fraction for validation
        seed: seed for reproducibility
    Returns:
      client_data: {client_id: {'train': [idxs], 'val': [idxs]}}
    Raises:
      ValueError: if K < 1 or Nc < 1.
    """
    if K < 1 or Nc < 1:
        raise ValueError(f"K and Nc must be positive, got K={K}, Nc={Nc}")

    rng = np.random.RandomState(seed)

    # Read labels from `dataset.targets` when it is trustworthy (same length, no
    # target transform); otherwise fall back to indexing every sample.
    targets = getattr(dataset, "targets", None)
    if (
        targets is not None
        and getattr(dataset, "target_transform", None) is None
        and len(targets) == len(dataset)
    ):
        labels = np.asarray(targets)
    else:
        labels = np.array([dataset[i][1] for i in range(len(dataset))])

    classes = np.unique(labels)
    n_classes = len(classes)
    class_indices = {c: rng.permutation(np.where(labels == c)[0]).tolist() for c in classes}

    # Generate shards by class. Ceil division guarantees at least K * Nc shards
    # overall; floor division left the last clients empty whenever K * Nc was
    # not a multiple of n_classes (and divided by zero when K * Nc < n_classes).
    # When K * Nc is a multiple of n_classes both roundings agree.
    shards_per_class = -(-(K * Nc) // n_classes) if n_classes else 0
    shards = []
    for c in classes:
        idxs = class_indices[c]
        shard_size = len(idxs) // shards_per_class
        for i in range(shards_per_class):
            shard = idxs[i * shard_size:(i + 1) * shard_size]
            if len(shard) > 0:
                shards.append((c, shard))
    rng.shuffle(shards)

    # Assign Nc shards, each from a different class, to each client (greedy scan
    # over the shuffled shard list; a shard is consumed once assigned).
    client_shards = {i: [] for i in range(K)}
    used = set()
    for i in range(K):
        chosen = []
        class_seen = set()
        for s_idx, (c, _shard) in enumerate(shards):
            if c not in class_seen and s_idx not in used:
                chosen.append(s_idx)
                class_seen.add(c)
            if len(chosen) == Nc:
                break
        for s_idx in chosen:
            used.add(s_idx)
            client_shards[i].extend(shards[s_idx][1])

    # Local split train/val for each client
    client_data = {}
    for i in range(K):
        idxs = np.array(client_shards[i], dtype=np.int64)
        rng.shuffle(idxs)
        n_val = int(len(idxs) * val_split)
        client_data[i] = {
            'train': idxs[n_val:].tolist(),
            'val': idxs[:n_val].tolist(),
        }
    return client_data


In [None]:
class DinoViT_CIFAR100(nn.Module):
    """DINO ViT-S/16 backbone (loaded through timm) followed by a linear classifier.

    The backbone's own head is replaced by `nn.Identity`, so the backbone output
    is fed directly to `classifier`, which maps 384 features to `num_classes` logits.
    """

    def __init__(self, num_classes=100):
        super().__init__()
        backbone = timm.create_model('vit_small_patch16_224.dino', pretrained=True)
        # Drop the DINO head: the backbone then returns the features themselves.
        backbone.head = nn.Identity()
        self.backbone = backbone
        # 384 is the feature dimension of ViT-S/16 (see timm/models/vit.py).
        self.classifier = nn.Linear(384, num_classes)

    def forward(self, x):
        # Backbone features are (batch, 384) once the head is an identity.
        return self.classifier(self.backbone(x))

In [None]:
def _num_total_params(mask: Dict[str, torch.Tensor]) -> int:
    """Returns the total number of parameters (elements) across all tensors in the mask."""
    return sum(t.numel() for t in mask.values())

def _num_zero_params(mask: Dict[str, torch.Tensor]) -> int:
    """Returns the number of parameters set to zero in the mask (i.e., masked out)."""
    return sum((t == 0).sum().item() for t in mask.values())

def _compute_sparsity(mask: Dict[str, torch.Tensor]) -> float:
    """Returns the sparsity, i.e., the fraction of parameters that are masked (value in [0, 1]).

    An empty mask (no tensors, or only empty tensors) has nothing masked, so its
    sparsity is reported as 0.0 instead of raising ZeroDivisionError.
    """
    total = _num_total_params(mask)
    if total == 0:
        return 0.0
    return _num_zero_params(mask) / total


def _compute_approximated_fisher_scores(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.Module,
    device: torch.device,
    num_batches: Optional[int] = None,
    mask: Optional[Dict[str, torch.Tensor]] = None
):
    """
    Approximate the diagonal of the Fisher Information Matrix via empirical average.

    For every processed batch the loss is back-propagated once and the squared
    gradient of each trainable parameter is accumulated; the sum is then divided
    by the number of batches that were actually processed.

    Args:
        model: torch.nn.Module (switched to eval mode by this function)
        dataloader: DataLoader (local client data)
        loss_fn: torch.nn loss function (e.g. nn.CrossEntropyLoss())
        device: torch.device
        num_batches: maximum number of batches to use (None = the whole loader)
        mask: optional {param_name: tensor}; scores are multiplied by it, so
            entries where the mask is 0 get a score of 0. Parameters without a
            mask entry are left unmasked.
    Returns:
        Dict {param_name: tensor of Fisher diagonal}
    """
    model.eval()
    fisher_diag = {
        name: torch.zeros_like(param, device=device)
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    # Only used as the progress-bar length; the normalisation below relies on
    # the real batch count because the loader may hold fewer than `num_batches`.
    expected_batches = len(dataloader) if num_batches is None else num_batches

    processed_batches = 0
    for batch_idx, (inputs, targets) in enumerate(
        tqdm(dataloader, total=expected_batches, desc="Computing Fisher")
    ):
        if num_batches is not None and batch_idx >= num_batches:
            break

        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        outputs = model(inputs)
        loss = loss_fn(outputs, targets)
        loss.backward()

        for name, param in model.named_parameters():
            if param.grad is not None and name in fisher_diag:
                fisher_diag[name] += param.grad.detach() ** 2
        processed_batches += 1

    # Do not leave the calibration gradients attached to the model.
    model.zero_grad()

    if processed_batches > 0:
        for name in fisher_diag:
            fisher_diag[name] /= processed_batches

    # Applying the mask once at the end gives the same result as re-applying it
    # after every batch when the mask is binary.
    if mask is not None:
        for name in fisher_diag:
            if name in mask:
                fisher_diag[name] *= mask[name]

    return fisher_diag


## Adjusted strategies

In [None]:
def calibrate_gradient_mask_progressive(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    sparsity: float = 0.9,
    rounds: int = 5,
    num_batches: Optional[int] = None,
    loss_fn: nn.Module = nn.CrossEntropyLoss(),
    approximate_fisher: bool = True,
    strategy: str = "train_least_important",
) -> Dict[str, torch.Tensor]:
    """
    Calibrate a binary gradient mask (1 = trainable entry, 0 = frozen entry).

    Strategies:
      - "train_least_important" / "train_most_important": progressive pruning on
        approximate Fisher scores. At round r the globally lowest / highest
        scoring k_r entries among the still-active ones are kept, with
        k_r = int((1 - sparsity) ** (r / rounds) * total_params).
      - "random", "magnitude_most", "magnitude_least": per-tensor selection of
        int((1 - sparsity) * numel) entries (random, largest |w|, smallest |w|).
        These do not depend on Fisher scores or on earlier rounds, so the mask
        is computed once and neither `dataloader` nor `loss_fn` is used.

    Args:
        model: model whose trainable parameters are masked (moved to `device`).
        dataloader: calibration data (Fisher strategies only).
        device: device on which the scores and mask are computed.
        sparsity: target fraction of masked-out entries.
        rounds: number of progressive rounds; with rounds < 1 the all-ones mask
            is returned unchanged.
        num_batches: maximum number of batches per Fisher estimation.
        loss_fn: loss used for the Fisher estimation.
        approximate_fisher: must be True for the Fisher strategies.
        strategy: one of the names listed above.
    Returns:
        dict {param_name: binary mask tensor} for parameters with requires_grad.
    Raises:
        ValueError: unknown strategy (checked before any computation).
        NotImplementedError: Fisher strategy requested with approximate_fisher=False.
    """
    fisher_strategies = ("train_least_important", "train_most_important")
    one_shot_strategies = ("random", "magnitude_most", "magnitude_least")
    if strategy not in fisher_strategies + one_shot_strategies:
        raise ValueError(f"Unknown strategy: {strategy}")

    print("*" * 50)
    print(f"Progressive Mask Calibration - Strategy: {strategy}")
    print("*" * 50)

    model.to(device)

    mask = {
        n: torch.ones_like(p, device=device)
        for n, p in model.named_parameters()
        if p.requires_grad
    }

    if strategy in one_shot_strategies:
        if rounds >= 1:
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                keep = int((1 - sparsity) * param.numel())
                if strategy == "random":
                    values, largest = torch.rand_like(param), False
                else:
                    values = param.detach().abs()
                    largest = strategy == "magnitude_most"
                mask[name] = _keep_k_entries(values, keep, largest)
            print(f"Mask sparsity: {_mask_sparsity_report(mask)}")
            print()
        print("Progressive Mask Calibration completed.")
        return mask

    for r in range(1, rounds + 1):
        print(f"[Round {r}]")

        # Score computation (only approximate Fisher supported here)
        if not approximate_fisher:
            raise NotImplementedError("Only approximate Fisher is implemented.")
        scores = _compute_approximated_fisher_scores(
            model=model,
            dataloader=dataloader,
            loss_fn=loss_fn,
            num_batches=num_batches,
            device=device,
            mask=mask,
        )

        # Only entries that are still active (mask == 1) compete for selection.
        active_scores = torch.cat([
            score[mask[name] != 0].flatten()
            for name, score in scores.items()
        ])
        total_params = sum(score.numel() for score in scores.values())
        total_active_params = active_scores.numel()
        if total_active_params == 0:
            print("No active parameters left; stopping calibration early.")
            break

        # Exponentially decrease keep_fraction for progressive pruning
        keep_fraction = (1 - sparsity) ** (r / rounds)
        # k can never exceed the number of currently active parameters.
        k = max(1, min(int(keep_fraction * total_params), total_active_params))
        print(f"Current keep fraction: {keep_fraction:.4f} | Keeping only top k: {k}")

        if strategy == "train_least_important":
            # k-th smallest active score: everything at or below it is kept.
            threshold, _ = torch.kthvalue(active_scores, k)
            print("Threshold (below which params are kept):", threshold)
            for name, score in scores.items():
                mask[name] = mask[name] * (score <= threshold).float()
        else:  # "train_most_important"
            # k-th largest active score: everything at or above it is kept.
            threshold, _ = torch.kthvalue(active_scores, total_active_params - k + 1)
            print("Threshold (above which params are kept):", threshold)
            for name, score in scores.items():
                # Multiplying by the previous mask keeps already-frozen entries frozen.
                mask[name] = mask[name] * (score >= threshold).float()

        print(f"After round {r} mask sparsity: {_mask_sparsity_report(mask)}")
        print()

    print("Progressive Mask Calibration completed.")
    return mask


def _keep_k_entries(values: torch.Tensor, keep: int, largest: bool) -> torch.Tensor:
    """Return a float 0/1 mask shaped like `values` keeping its `keep` largest (or smallest) entries; ties at the threshold are all kept."""
    if keep <= 0:
        return torch.zeros_like(values, dtype=torch.float)
    flat = values.flatten()
    if keep >= flat.numel():
        return torch.ones_like(values, dtype=torch.float)
    if largest:
        threshold, _ = torch.kthvalue(flat, flat.numel() - keep + 1)
        return (values >= threshold).float()
    threshold, _ = torch.kthvalue(flat, keep)
    return (values <= threshold).float()


def _mask_sparsity_report(mask: Dict[str, torch.Tensor]) -> str:
    """Format the mask sparsity as '<ratio> (<zeroed>/<total> zeroed params)'."""
    total = sum(t.numel() for t in mask.values())
    zeroed = sum((t == 0).sum().item() for t in mask.values())
    ratio = zeroed / total if total else 0.0
    return f"{ratio:.4f} ({zeroed}/{total} zeroed params)"


## Sparsity

In [None]:
class SparseSGDM(SGD):
    """SGD with momentum that only updates the entries selected by a binary mask.

    Args:
        params: parameters to optimize.
        named_params: {name: parameter}, used to find the mask of each parameter.
        lr: learning rate.
        momentum: momentum factor.
        weight_decay: L2 penalty coefficient.
        mask: {param_name: tensor}; entries equal to 0 are frozen. Parameters
            without a mask entry (or `mask=None`) are updated normally.
    """

    def __init__(
        self,
        params: Iterable[torch.nn.Parameter],
        named_params: Dict[str, torch.nn.Parameter],
        lr: float,
        momentum: float = 0.0,
        weight_decay: float = 0.0,
        mask: Optional[Dict[str, torch.Tensor]] = None,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
        self.mask = mask  # Dict {param_name: mask_tensor}
        self.named_params = named_params  # Dict {name: param}
        self.param_id_to_name = {id(p): n for n, p in named_params.items()}

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one masked update and return the closure's loss (None without closure).

        The closure is evaluated exactly once, before masking: handing it to
        `SGD.step` as well would recompute the gradients and undo the mask.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        mask = self.mask if self.mask is not None else {}
        saved_decay = [group["weight_decay"] for group in self.param_groups]
        try:
            for group in self.param_groups:
                decay = group["weight_decay"]
                for p in group["params"]:
                    if p.grad is None:
                        continue
                    # Fold the L2 term into the gradient *before* masking so that
                    # frozen entries receive no decay either. This matches SGD's
                    # own `grad + weight_decay * p` for the default
                    # (non-maximizing) configuration this constructor creates.
                    if decay != 0:
                        p.grad.add_(p, alpha=decay)
                    name = self.param_id_to_name.get(id(p))
                    if name in mask:
                        p.grad.mul_(mask[name])
                # The decay is already in the gradients; keep SGD from adding it again.
                group["weight_decay"] = 0
            super().step()
        finally:
            for group, decay in zip(self.param_groups, saved_decay):
                group["weight_decay"] = decay
        return loss


In [None]:
def sparse_fine_tune(
    model: nn.Module,
    dataloader,
    device,
    mask,
    lr=1e-3,
    epochs=1,
    momentum=0.9,
    weight_decay=5e-4,
):
    """
    Sparse fine-tuning of a model using a fixed binary mask.
    Only parameters where mask == 1 are updated (others are frozen).

    Tensors whose mask contains no 1 (or that have no mask entry) are frozen via
    requires_grad; inside the remaining tensors the element-wise masking is
    delegated to SparseSGDM. The requires_grad flags found on entry are restored
    on exit, even if training raises.

    Args:
        model: model to fine-tune in place (moved to `device`).
        dataloader: iterable of (inputs, targets) batches.
        device: device used for training.
        mask: {param_name: binary tensor}.
        lr, momentum, weight_decay: optimizer hyper-parameters.
        epochs: number of passes over `dataloader`.
    """
    model.to(device)
    # Remember the caller's flags so parameters frozen beforehand stay frozen.
    previous_flags = {name: param.requires_grad for name, param in model.named_parameters()}
    try:
        # Set requires_grad according to the mask
        for name, param in model.named_parameters():
            if name in mask:
                param.requires_grad = (mask[name] == 1).any().item()
            else:
                param.requires_grad = False

        # Prepare params to optimize (only those with requires_grad)
        named_params = dict(model.named_parameters())
        params = [p for p in named_params.values() if p.requires_grad]
        if not params:
            # Nothing is selected by the mask: there is nothing to train.
            return

        # Masked SGD: gradients are multiplied by the mask before every update.
        optimizer = SparseSGDM(
            params=params,
            named_params=named_params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            mask=mask,
        )

        criterion = torch.nn.CrossEntropyLoss()
        model.train()
        for _ in range(epochs):
            for inputs, targets in dataloader:
                inputs, targets = inputs.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
    finally:
        for name, param in model.named_parameters():
            param.requires_grad = previous_flags[name]



## Adjusted Client Class

In [None]:
class Client:
    """
    Client class representing a federated learning participant.

    Each client holds a private dataset and performs local training or sparse fine-tuning
    on a copy of the global model. It supports standard local updates as well as
    gradient-masked sparse updates using a calibrated mask.
    """
    def __init__(self, client_id, dataset, device):
        self.client_id = client_id
        self.dataset = dataset
        self.device = device
        self.last_mask = None  # Most recent mask returned by calibrate_mask

    def calibrate_mask(
        self,
        model,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        batch_size=128,
        num_batches: Optional[int] = None,
        loss_fn=None,
        strategy: str = "train_least_important",
    ):
        """Calibrate a binary mask based on importance strategy (Fisher, magnitude, or random).

        The mask is computed on this client's dataset, stored in `last_mask` and returned.
        """
        dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
        if loss_fn is None:
            loss_fn = nn.CrossEntropyLoss()
        mask = calibrate_gradient_mask_progressive(
            model=model,
            dataloader=dataloader,
            device=self.device,
            sparsity=sparsity_ratio,
            rounds=num_calib_rounds,
            num_batches=num_batches,
            loss_fn=loss_fn,
            approximate_fisher=True,
            strategy=strategy,
        )
        self.last_mask = mask
        return mask

    def apply_mask_requires_grad(self, model, mask):
        """
        Tensor-level freezing: requires_grad=True for parameters whose mask contains
        at least one 1, False for all others (including parameters without a mask).
        """
        for name, param in model.named_parameters():
            if name in mask:
                param.requires_grad = (mask[name] == 1).any().item()
            else:
                param.requires_grad = False

    def sparse_fine_tune(
        self,
        model,
        mask,
        lr=1e-3,
        epochs=1,
        batch_size=128,
        momentum=0.9,
        weight_decay=5e-4,
    ):
        """
        Sparse fine-tuning: only entries where mask == 1 are updated.

        Freezing whole tensors through requires_grad is not enough, because a
        tensor with a single 1 in its mask would be trained entirely. Gradients
        are therefore also multiplied element-wise by the mask before every
        optimizer step. Weight decay is added to the gradient before masking so
        that frozen entries do not decay. The requires_grad flags found on entry
        are restored on exit.
        """
        previous_flags = {name: p.requires_grad for name, p in model.named_parameters()}
        self.apply_mask_requires_grad(model, mask)
        try:
            trainable = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
            if not trainable:
                # The mask selects nothing: there is nothing to train.
                return

            # Masks aligned with their parameters (device and dtype) once, up front.
            masks = {name: mask[name].to(device=p.device, dtype=p.dtype) for name, p in trainable}
            dataloader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)

            # weight_decay is applied manually below (before masking), hence 0 here.
            optimizer = torch.optim.SGD(
                [p for _, p in trainable],
                lr=lr,
                momentum=momentum,
                weight_decay=0.0,
            )
            criterion = nn.CrossEntropyLoss()
            model.train()
            for _ in range(epochs):
                for inputs, targets in dataloader:
                    inputs, targets = inputs.to(self.device), targets.to(self.device)
                    optimizer.zero_grad()
                    outputs = model(inputs)
                    loss = criterion(outputs, targets)
                    loss.backward()
                    with torch.no_grad():
                        for name, p in trainable:
                            if p.grad is None:
                                continue
                            if weight_decay:
                                p.grad.add_(p, alpha=weight_decay)
                            p.grad.mul_(masks[name])
                    optimizer.step()
        finally:
            for name, p in model.named_parameters():
                p.requires_grad = previous_flags[name]

    def local_train(
        self,
        global_model,
        epochs,
        batch_size,
        lr,
        momentum,
        weight_decay,
        scheduler_fn=None,
        use_sparse=False,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        num_batches: Optional[int] = None,
        sparse_ft_epochs=1,
        strategy: str = "train_least_important",
    ):
        """
        Performs standard local training or (if use_sparse) sparse fine-tuning
        on a private copy of `global_model` and returns the copy's state_dict.

        In the sparse path the mask is calibrated first, then the copy is
        fine-tuned for `sparse_ft_epochs` epochs; `epochs` and `scheduler_fn`
        are only used by the standard path.
        """
        import copy

        # Copy the global model instead of re-instantiating the architecture:
        # this works for any nn.Module and avoids rebuilding the pretrained
        # backbone for every client in every round. `global_model` is untouched.
        model = copy.deepcopy(global_model).to(self.device)

        if use_sparse:
            mask = self.calibrate_mask(
                model,
                sparsity_ratio=sparsity_ratio,
                num_calib_rounds=num_calib_rounds,
                batch_size=batch_size,
                num_batches=num_batches,
                strategy=strategy,
            )
            self.sparse_fine_tune(
                model,
                mask,
                lr=lr,
                epochs=sparse_ft_epochs,
                batch_size=batch_size,
                momentum=momentum,
                weight_decay=weight_decay,
            )
        else:
            # Standard local training
            loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
            optimizer = torch.optim.SGD(
                filter(lambda p: p.requires_grad, model.parameters()),
                lr=lr,
                momentum=momentum,
                weight_decay=weight_decay,
            )
            scheduler = scheduler_fn(optimizer) if scheduler_fn else None
            criterion = nn.CrossEntropyLoss()
            model.train()
            for _ in range(epochs):
                for X, y in loader:
                    X, y = X.to(self.device), y.to(self.device)
                    optimizer.zero_grad()
                    loss = criterion(model(X), y)
                    loss.backward()
                    optimizer.step()
                # The scheduler advances once per local epoch.
                if scheduler:
                    scheduler.step()

        return model.state_dict()


## Adjusted Federated Trainer

In [None]:
class FederatedTrainer:
    """
    FedAvg orchestrator with an optional sparse model-editing mode.

    Every round a fraction of the clients trains locally on the current global
    model; their returned weights are averaged, weighted by dataset size, and
    loaded back into the global model.
    """

    def __init__(
        self,
        clients,
        global_model,
        device,
        client_fraction,
        local_epochs,
        batch_size,
        lr,
        momentum,
        weight_decay,
        scheduler_fn=None,
        use_sparse=False,
        sparsity_ratio=0.9,
        num_calib_rounds=5,
        num_batches: Optional[int] = None,
        sparse_ft_epochs=1,
        strategy: str = "train_least_important",
    ):
        # Federation setup
        self.clients, self.global_model, self.device = clients, global_model, device
        self.client_fraction = client_fraction

        # Local optimisation settings forwarded to every client
        self.local_epochs, self.batch_size = local_epochs, batch_size
        self.lr, self.momentum, self.weight_decay = lr, momentum, weight_decay
        self.scheduler_fn = scheduler_fn

        # Sparse-editing settings forwarded to every client
        self.use_sparse = use_sparse
        self.sparsity_ratio, self.num_calib_rounds = sparsity_ratio, num_calib_rounds
        self.num_batches, self.sparse_ft_epochs = num_batches, sparse_ft_epochs
        self.strategy = strategy

    def aggregate_weights(self, client_states, client_sizes):
        """
        Sample-weighted mean of the client state_dicts (FedAvg).
        client_states: list of state_dicts (one per client)
        client_sizes: list of int (number of samples per client)
        Returns a dict with one float tensor per key of the first state_dict.
        """
        denominator = sum(client_sizes)
        averaged = {}
        for key in client_states[0]:
            accumulated = 0
            for state, size in zip(client_states, client_sizes):
                accumulated = accumulated + state[key].float() * size
            averaged[key] = accumulated / denominator
        return averaged

    def train_round(self):
        """
        One FedAvg round: sample clients without replacement, train each of them
        locally (optionally with sparse editing), then load the size-weighted
        average of their weights into the global model.
        """
        n_selected = max(int(self.client_fraction * len(self.clients)), 1)
        participants = np.random.choice(self.clients, n_selected, replace=False)

        local_kwargs = dict(
            global_model=self.global_model,
            epochs=self.local_epochs,
            batch_size=self.batch_size,
            lr=self.lr,
            momentum=self.momentum,
            weight_decay=self.weight_decay,
            scheduler_fn=self.scheduler_fn,
            use_sparse=self.use_sparse,
            sparsity_ratio=self.sparsity_ratio,
            num_calib_rounds=self.num_calib_rounds,
            num_batches=self.num_batches,
            sparse_ft_epochs=self.sparse_ft_epochs,
            strategy=self.strategy,
        )

        states, sizes = [], []
        for participant in participants:
            states.append(participant.local_train(**local_kwargs))
            sizes.append(len(participant.dataset))

        self.global_model.load_state_dict(self.aggregate_weights(states, sizes))

    def fit(self, n_rounds, eval_fn=None, eval_every=1):
        """Run `n_rounds` rounds, evaluating every `eval_every` rounds and after the last one."""
        for rnd in range(1, n_rounds + 1):
            print(f'---- FedAvg Round {rnd} {"(SPARSE-EDITING)" if self.use_sparse else ""} ----')
            self.train_round()
            if not eval_fn:
                continue
            if rnd % eval_every != 0 and rnd != n_rounds:
                continue
            acc, loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Eval: Acc={acc:.3f} | Loss={loss:.3f}')


## Other Utils

In [None]:
def evaluate(model, dataloader, device):
    """
    Score a classifier on every batch yielded by `dataloader`.

    The model is moved to `device` and put in eval mode; no gradients are tracked.
    Returns (accuracy, average_loss), both averaged over samples, or (0.0, 0.0)
    when the dataloader yields no sample.
    """
    model = model.to(device)
    model.eval()
    criterion = nn.CrossEntropyLoss()

    n_correct = 0
    n_seen = 0
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            logits = model(X)
            batch_size = X.size(0)
            # The criterion returns a batch mean; weight it by the batch size.
            loss_sum += criterion(logits, y).item() * batch_size
            predictions = logits.max(1)[1]
            n_correct += (predictions == y).sum().item()
            n_seen += batch_size

    if n_seen == 0:
        return 0.0, 0.0
    return n_correct / n_seen, loss_sum / n_seen


## Hyperparameters Research

In [None]:
# Hyperparams
K = 100          # number of clients
val_split = 0.2  # local validation fraction inside each client
seed = 42        # seed shared by both sharding calls
Nc = 50          # number of different classes per client (non-IID sharding only)

# --- Dataset (no transform here)
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- Freeze the IID sharding (client split and local train/val split)
iid_split = iid_shard_train_val(full_train, K=K, val_split=val_split, seed=seed)

# --- Freeze the NON-IID sharding (client split and local train/val split)
non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparams
K = 100              # number of clients
C = 0.1              # presumably the fraction of clients sampled per round - confirm where the trainer is built
J = 4                # used as T_max of the cosine scheduler; presumably the number of local epochs - confirm
n_rounds = 10
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

# Param grid
sparsity_ratios = [0.85, 0.90, 0.95]
num_calib_rounds_list = [1]
sparse_ft_epochs = 1
strategy = "magnitude_least"  # or "train_least_important", "train_most_important", "random", "magnitude_most"


# --- Transforms (ImageNet style for ViT/DINO) ---
# Training: resize to 224x224, padded random crop and horizontal flip, then normalisation.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
# Validation: deterministic resize and normalisation only.
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Data (no split/transform here) ---
# Reloaded and re-sharded here with the same K / val_split / seed values as the cell above.
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)
iid_split = iid_shard_train_val(full_train, K=K, val_split=0.2, seed=42)

# --- Wrapper for transforms after split ---
from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Wrap an indexable collection of `(sample, label)` pairs and transform samples on access."""

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        # Labels pass through unchanged; only the sample is transformed.
        raw, label = self.subset[idx]
        transformed = self.transform(raw)
        return transformed, label

# --- Global validation set ---
# Union of every client's local validation indices, with the validation transform.
val_indices = np.concatenate([iid_split[i]['val'] for i in range(K)])
val_set = TransformedSubset(Subset(full_train, val_indices), val_transform)
val_loader = torch.utils.data.DataLoader(val_set, batch_size=128, shuffle=False)

def make_scheduler(optimizer):
    """Return a cosine-annealing learning-rate scheduler for `optimizer` with T_max set to the global J."""
    cosine_schedule = torch.optim.lr_scheduler.CosineAnnealingLR
    return cosine_schedule(optimizer, T_max=J)

def eval_fn_edit(model):
    """Evaluate `model` on the global validation loader; returns what `evaluate` returns."""
    return evaluate(model=model, dataloader=val_loader, device=device)

def save_checkpoint(model, round_idx, acc_history, loss_history, path):
    """Write the round index, the model weights and both metric histories to `path` with torch.save."""
    payload = dict(
        round=round_idx,
        model_state=model.state_dict(),
        acc_history=acc_history,
        loss_history=loss_history,
    )
    torch.save(payload, path)

def load_checkpoint(model, path):
    """Restore `model` from the checkpoint at `path`, if that file exists.

    Returns (round, acc_history, loss_history) read from the checkpoint, or
    (0, [], []) when there is no file at `path` (the model is then left untouched).
    """
    if not os.path.exists(path):
        return 0, [], []

    saved = torch.load(path, map_location=device)
    model.load_state_dict(saved["model_state"])
    print(f"Checkpoint loaded (round {saved['round']}) from {path}")
    return saved["round"], saved["acc_history"], saved["loss_history"]

def save_best_hyperparams(acc_history, config, path):
    """Append one JSON line to `path`: the best accuracy of the run followed by the entries of `config`.

    The best accuracy is 0.0 when `acc_history` is empty.
    """
    record = {"best_val_acc": 0.0}
    if acc_history:
        record["best_val_acc"] = max(acc_history)
    record.update(config)
    with open(path, "a") as handle:
        handle.write(json.dumps(record) + "\n")

def fit_with_wandb_and_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=None, best_json_path=None, resume=False):
    """Run federated rounds with W&B logging, checkpoints and a best-accuracy log.

    Meant to be bound to a trainer with ``types.MethodType``; ``self`` must
    provide ``train_round()`` and ``global_model``.

    Args:
        n_rounds: Index of the last round to run (rounds are 1-based).
        eval_fn: Optional callable ``model -> (acc, loss)``.
        eval_every: Evaluate on rounds that are multiples of this value.
        checkpoint_path: Checkpoint file; falsy disables saving and resuming.
        best_json_path: If set, one JSON line with the best accuracy and the
            W&B config is appended there at the end.
        resume: When true and ``checkpoint_path`` is set, restart from the
            round stored in the checkpoint.
    """
    start_round, acc_history, loss_history = 0, [], []
    if resume and checkpoint_path:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round + 1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (SPARSE-EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            acc, loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Eval: Acc={acc:.3f} | Loss={loss:.3f}')
            wandb.log({"round": rnd, "val_acc": acc, "val_loss": loss})
            acc_history.append(acc)
            loss_history.append(loss)
        # Checkpoint on its own schedule: when this was nested in the eval
        # branch, rounds such as 5 and 15 were skipped with eval_every=2 and
        # an odd final round was never saved.
        if checkpoint_path and (rnd % 5 == 0 or rnd == n_rounds):
            save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
    if best_json_path:
        save_best_hyperparams(acc_history, wandb.config, best_json_path)

# === GRIDSEARCH MODEL EDITING FL-READY ===
# One full federated run per (sparsity_ratio, num_calib_rounds) pair.
# The banner total is derived from the grid instead of a hard-coded "/3"
# (assumes both grids are sized sequences such as lists).
total_runs = len(sparsity_ratios) * len(num_calib_rounds_list)
run_idx = 0
for sparsity_ratio in sparsity_ratios:
    for num_calib_rounds in num_calib_rounds_list:
        run_idx += 1
        print(f"\n=== MODEL EDITING RUN {run_idx}/{total_runs} ===\n")
        # Fresh clients and a fresh global model for every grid point.
        clients_edit = []
        for i in range(K):
            train_idxs = iid_split[i]['train']
            client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
            clients_edit.append(Client(i, client_train_dataset, device))

        global_model = DinoViT_CIFAR100(num_classes=100).to(device)

        # The run name doubles as the stem of the checkpoint and JSON files,
        # so each grid point resumes from and logs to its own files.
        run_name = (f"model_editing_iid_{strategy}_nrounds{n_rounds}_lr{lr}_"
                    f"sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}_ftep{sparse_ft_epochs}")
        checkpoint_path = f"{run_name}.pt"
        best_json_path = f"{run_name}.json"

        wandb.init(
            project="fl-fedavg-personnal-contribution",
            name=run_name,
            config={
                "model": "DINO ViT-S/16",
                "K": K,
                "C": C,
                "J": J,
                "n_rounds": n_rounds,
                "batch_size": batch_size,
                "lr": lr,
                "momentum": momentum,
                "weight_decay": weight_decay,
                "sharding": "iid",
                "use_sparse": True,
                "sparsity_ratio": sparsity_ratio,
                "num_calib_rounds": num_calib_rounds,
                "sparse_ft_epochs": sparse_ft_epochs,
                "strategy" : strategy
            }
        )

        trainer_edit = FederatedTrainer(
            clients=clients_edit,
            global_model=global_model,
            device=device,
            client_fraction=C,
            local_epochs=J,
            batch_size=batch_size,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            scheduler_fn=make_scheduler,
            use_sparse=True,
            sparsity_ratio=sparsity_ratio,
            num_calib_rounds=num_calib_rounds,
            sparse_ft_epochs=sparse_ft_epochs,
            strategy=strategy
        )
        # Replace the trainer's fit with the logging/checkpointing variant.
        trainer_edit.fit = types.MethodType(fit_with_wandb_and_logs, trainer_edit)
        trainer_edit.fit(
            n_rounds,
            eval_fn=eval_fn_edit,
            eval_every=2,
            checkpoint_path=checkpoint_path,
            best_json_path=best_json_path,
            resume=True
        )
        wandb.finish()



## Test Accuracy runs

In [None]:
# Hyperparams
K = 100  # number of clients
val_split = 0  # local validation fraction per client (none kept here)
seed = 42  # sharding seed, for reproducibility
Nc = 50  # forwarded to non_iid_shard_train_val; presumably classes per client -- TODO confirm

# --- Dataset (no transform here)
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- Freeze sharding IID (split clients and local train/val = 0 here)
# Rebinds iid_split; later cells in this section read it as client_data.
iid_split = iid_shard_train_val(full_train, K=K, val_split=val_split, seed=seed)

# --- freeze sharding NON-IID (split clients and local train/val = 0 here)
non_iid_split = non_iid_shard_train_val(full_train, K=K, Nc=Nc, val_split=val_split, seed=seed)


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

# --- FL params ---
K = 100  # number of clients
C = 0.1  # passed to the trainer as client_fraction
J = 4  # passed to the trainer as local_epochs; also T_max of the cosine scheduler
n_rounds = 20
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

strategy = "train_least_important"
# --- Model Editing HP chosen ---
sparsity_ratio = 0.85
num_calib_rounds = 3
sparse_ft_epochs = 1

# NOTE(review): later cells that do not redefine CHECKPOINT_PATH write to this
# same file, and this cell resumes from it -- confirm the file belongs to this
# strategy before re-running.
CHECKPOINT_PATH = "model_editing_iid_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + local train ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Wrap ``subset`` so that ``transform`` is applied to each item's input.

    Items of ``subset`` are ``(x, y)`` pairs; ``y`` is returned unchanged.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        image, label = self.subset[idx]
        return self.transform(image), label

# ---Instantiate clients prepared for FL (local training only) ---
clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

# --- Official test set ---
test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    """Return a cosine-annealing scheduler for ``optimizer``.

    The period ``T_max`` is the notebook-level ``J``, looked up at call time.
    """
    scheduler_cls = torch.optim.lr_scheduler.CosineAnnealingLR
    return scheduler_cls(optimizer, T_max=J)

def eval_fn_test(model):
    """Evaluate ``model`` on the official test loader.

    Returns whatever ``evaluate`` returns; the fit loop unpacks it as
    ``(test_acc, test_loss)``.
    """
    metrics = evaluate(model, test_loader, device)
    return metrics

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    """Write the model weights and metric histories to ``path``.

    The stored dict has the keys ``round``, ``model_state``, ``acc_history``
    and ``loss_history``. ``path`` defaults to the ``CHECKPOINT_PATH`` value
    in effect when this function was defined.
    """
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history,
    }
    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint at ``path``.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations cannot be renamed; write directly.
        torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    """Restore ``model`` from ``path`` if that file exists.

    Returns:
        ``(round, acc_history, loss_history)`` from the checkpoint, or
        ``(0, [], [])`` when there is no file at ``path``.
    """
    if not os.path.exists(path):
        return 0, [], []
    state = torch.load(path, map_location=device)
    model.load_state_dict(state["model_state"])
    print(f"Checkpoint loaded (round {state['round']})")
    return state["round"], state["acc_history"], state["loss_history"]

def plot_history(acc_history, loss_history, eval_every):
    """Show test accuracy and test loss side by side against the round index.

    The x axis places the i-th recorded value at round ``(i + 1) * eval_every``.
    """
    rounds = np.arange(len(acc_history)) * eval_every + eval_every
    # (subplot position, series, line label, y-axis label, title)
    panels = (
        (1, acc_history, 'Test Acc', 'Accuracy', 'Test Accuracy'),
        (2, loss_history, 'Test Loss', 'Loss', 'Test Loss'),
    )
    plt.figure(figsize=(8, 4))
    for position, series, line_label, y_label, title in panels:
        plt.subplot(1, 2, position)
        plt.plot(rounds, series, label=line_label)
        plt.xlabel('Round')
        plt.ylabel(y_label)
        plt.title(title)
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    """Run federated rounds with test logging, checkpoints and plots.

    Meant to be bound to a trainer with ``types.MethodType``; ``self`` must
    provide ``train_round()`` and ``global_model``.

    Args:
        n_rounds: Index of the last round to run (rounds are 1-based).
        eval_fn: Optional callable ``model -> (test_acc, test_loss)``.
        eval_every: Evaluate on rounds that are multiples of this value.
        checkpoint_path: Checkpoint file; a falsy value disables saving.
        resume: When true, restart from the round stored at ``checkpoint_path``.
    """
    start_round, acc_history, loss_history = 0, [], []
    if resume:
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
    for rnd in range(start_round + 1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
        if rnd % 5 == 0 or rnd == n_rounds:
            # Checkpoint independently of the evaluation schedule: nested in
            # the eval branch, rounds 5 and 15 were skipped with eval_every=2.
            if checkpoint_path:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg-personnal-contribution",
    name=f"model_editing_iid_test_acc_{strategy}_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs,
        "strategy" : strategy
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs,
    strategy = strategy
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

# --- FL params ---
K = 100  # number of clients
C = 0.1  # passed to the trainer as client_fraction
J = 4  # passed to the trainer as local_epochs; also T_max of the cosine scheduler
n_rounds = 20
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

strategy = "train_most_important"
# --- Model Editing HP chosen ---
sparsity_ratio = 0.90
num_calib_rounds = 5
sparse_ft_epochs = 1

# Per-strategy checkpoint file: this cell no longer depends on a
# CHECKPOINT_PATH left over from another cell, and does not overwrite another
# strategy's checkpoint.
CHECKPOINT_PATH = f"model_editing_iid_test_acc_{strategy}_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Expose ``subset`` with ``transform`` applied to the input of each item.

    Each item of ``subset`` is an ``(x, y)`` pair; ``y`` is left as is.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        raw, target = self.subset[idx]
        return self.transform(raw), target

# --- Instantiate FL-ready clients (local training only) ---
clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

# --- Official test set (wrapped so it gets the same normalisation as validation) ---
test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    """Create the cosine-annealing schedule attached to ``optimizer``.

    ``T_max`` comes from the notebook-level ``J``, read when called.
    """
    annealing = torch.optim.lr_scheduler.CosineAnnealingLR
    return annealing(optimizer, T_max=J)

def eval_fn_test(model):
    """Score ``model`` on the official test loader via ``evaluate``.

    The result is returned as is; the fit loop unpacks it as
    ``(test_acc, test_loss)``.
    """
    outcome = evaluate(model, test_loader, device)
    return outcome

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    """Write the model weights and metric histories to ``path``.

    The stored dict has the keys ``round``, ``model_state``, ``acc_history``
    and ``loss_history``. ``path`` defaults to the ``CHECKPOINT_PATH`` value
    in effect when this function was defined.
    """
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history,
    }
    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint at ``path``.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations cannot be renamed; write directly.
        torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    """Load the checkpoint at ``path`` into ``model`` when the file exists.

    Returns:
        ``(round, acc_history, loss_history)`` from the checkpoint, or
        ``(0, [], [])`` when there is no file at ``path``.
    """
    if not os.path.exists(path):
        return 0, [], []
    saved = torch.load(path, map_location=device)
    model.load_state_dict(saved["model_state"])
    print(f"Checkpoint loaded (round {saved['round']})")
    return saved["round"], saved["acc_history"], saved["loss_history"]

def plot_history(acc_history, loss_history, eval_every):
    """Draw test accuracy and test loss curves in a two-panel figure.

    The i-th recorded value is plotted at round ``(i + 1) * eval_every``.
    """
    rounds = np.arange(len(acc_history)) * eval_every + eval_every
    # (subplot position, series, line label, y-axis label, title)
    layout = (
        (1, acc_history, 'Test Acc', 'Accuracy', 'Test Accuracy'),
        (2, loss_history, 'Test Loss', 'Loss', 'Test Loss'),
    )
    plt.figure(figsize=(8, 4))
    for slot, values, line_label, y_label, heading in layout:
        plt.subplot(1, 2, slot)
        plt.plot(rounds, values, label=line_label)
        plt.xlabel('Round')
        plt.ylabel(y_label)
        plt.title(heading)
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    """Run ``n_rounds`` federated rounds with test logging, checkpoints and plots.

    Meant to be bound to a trainer with ``types.MethodType``; ``self`` must
    provide ``train_round()`` and ``global_model``.

    Args:
        n_rounds: Number of rounds to run (rounds are 1-based).
        eval_fn: Optional callable ``model -> (test_acc, test_loss)``.
        eval_every: Evaluate on rounds that are multiples of this value.
        checkpoint_path: Checkpoint file; a falsy value disables saving.
        resume: Accepted for signature compatibility but not used here;
            training always starts from round 1.
    """
    acc_history, loss_history = [], []
    for rnd in range(1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
        if rnd % 5 == 0 or rnd == n_rounds:
            # Checkpoint independently of the evaluation schedule: nested in
            # the eval branch, rounds 5 and 15 were skipped with eval_every=2.
            if checkpoint_path:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg-personnal-contribution",
    name=f"model_editing_iid_test_acc_{strategy}_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs,
        "strategy" : strategy
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs,
    strategy = strategy
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

# --- FL params ---
K = 100  # number of clients
C = 0.1  # passed to the trainer as client_fraction
J = 4  # passed to the trainer as local_epochs; also T_max of the cosine scheduler
n_rounds = 20
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

strategy = "random"
# --- Model Editing HP chosen ---
sparsity_ratio = 0.90
num_calib_rounds = 1
sparse_ft_epochs = 1

# Per-strategy checkpoint file: this cell no longer depends on a
# CHECKPOINT_PATH left over from another cell, and does not overwrite another
# strategy's checkpoint.
CHECKPOINT_PATH = f"model_editing_iid_test_acc_{strategy}_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Dataset view applying ``transform`` to the input half of each item.

    ``subset`` yields ``(x, y)`` pairs; only ``x`` goes through ``transform``.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        features, label = self.subset[idx]
        return self.transform(features), label

# ---Instantiate clients prepared for FL (local training only) ---
clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

# --- Official test set ---
test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    """Return a cosine-annealing LR scheduler bound to ``optimizer``.

    Uses the notebook-level ``J`` (read at call time) as ``T_max``.
    """
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR
    return cosine(optimizer, T_max=J)

def eval_fn_test(model):
    """Run ``evaluate`` for ``model`` on the official test loader.

    The result is handed back untouched; the fit loop unpacks it as
    ``(test_acc, test_loss)``.
    """
    scores = evaluate(model, test_loader, device)
    return scores

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    """Write the model weights and metric histories to ``path``.

    The stored dict has the keys ``round``, ``model_state``, ``acc_history``
    and ``loss_history``. ``path`` defaults to the ``CHECKPOINT_PATH`` value
    in effect when this function was defined.
    """
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history,
    }
    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint at ``path``.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations cannot be renamed; write directly.
        torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    """Restore ``model`` from the checkpoint at ``path`` if it exists.

    Returns:
        ``(round, acc_history, loss_history)`` from the checkpoint, or
        ``(0, [], [])`` when there is no file at ``path``.
    """
    if not os.path.exists(path):
        return 0, [], []
    payload = torch.load(path, map_location=device)
    model.load_state_dict(payload["model_state"])
    print(f"Checkpoint loaded (round {payload['round']})")
    return payload["round"], payload["acc_history"], payload["loss_history"]

def plot_history(acc_history, loss_history, eval_every):
    """Plot test accuracy (left) and test loss (right) against the round index.

    The i-th recorded value is plotted at round ``(i + 1) * eval_every``.
    """
    rounds = np.arange(len(acc_history)) * eval_every + eval_every
    # (subplot position, series, line label, y-axis label, title)
    charts = (
        (1, acc_history, 'Test Acc', 'Accuracy', 'Test Accuracy'),
        (2, loss_history, 'Test Loss', 'Loss', 'Test Loss'),
    )
    plt.figure(figsize=(8, 4))
    for column, history, line_label, y_label, caption in charts:
        plt.subplot(1, 2, column)
        plt.plot(rounds, history, label=line_label)
        plt.xlabel('Round')
        plt.ylabel(y_label)
        plt.title(caption)
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    """Run ``n_rounds`` federated rounds with test logging, checkpoints and plots.

    Meant to be bound to a trainer with ``types.MethodType``; ``self`` must
    provide ``train_round()`` and ``global_model``.

    Args:
        n_rounds: Number of rounds to run (rounds are 1-based).
        eval_fn: Optional callable ``model -> (test_acc, test_loss)``.
        eval_every: Evaluate on rounds that are multiples of this value.
        checkpoint_path: Checkpoint file; a falsy value disables saving.
        resume: Accepted for signature compatibility but not used here;
            training always starts from round 1.
    """
    acc_history, loss_history = [], []
    for rnd in range(1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
        if rnd % 5 == 0 or rnd == n_rounds:
            # Checkpoint independently of the evaluation schedule: nested in
            # the eval branch, rounds 5 and 15 were skipped with eval_every=2.
            if checkpoint_path:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg-personnal-contribution",
    name=f"model_editing_iid_test_acc_{strategy}_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs,
        "strategy" : strategy
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs,
    strategy = strategy
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

# --- FL params ---
K = 100  # number of clients
C = 0.1  # passed to the trainer as client_fraction
J = 4  # passed to the trainer as local_epochs; also T_max of the cosine scheduler
n_rounds = 20
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

strategy = "magnitude_most"
# --- Model Editing HP chosen ---
sparsity_ratio = 0.85
num_calib_rounds = 1
sparse_ft_epochs = 1

# Per-strategy checkpoint file: this cell no longer depends on a
# CHECKPOINT_PATH left over from another cell, and does not overwrite another
# strategy's checkpoint.
CHECKPOINT_PATH = f"model_editing_iid_test_acc_{strategy}_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Apply ``transform`` to the input of every ``(x, y)`` item of ``subset``.

    Labels are returned unchanged.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        inputs, label = self.subset[idx]
        return self.transform(inputs), label

# ---Instantiate clients prepared for FL (local training only) ---
clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

# --- Official test set ---
test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    """Build the cosine-annealing scheduler used for ``optimizer``.

    ``T_max`` is the notebook-level ``J``, resolved at call time.
    """
    cosine_annealing = torch.optim.lr_scheduler.CosineAnnealingLR
    return cosine_annealing(optimizer, T_max=J)

def eval_fn_test(model):
    """Evaluate ``model`` against the official test loader.

    Passes the value from ``evaluate`` straight through; the fit loop unpacks
    it as ``(test_acc, test_loss)``.
    """
    evaluation = evaluate(model, test_loader, device)
    return evaluation

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    """Write the model weights and metric histories to ``path``.

    The stored dict has the keys ``round``, ``model_state``, ``acc_history``
    and ``loss_history``. ``path`` defaults to the ``CHECKPOINT_PATH`` value
    in effect when this function was defined.
    """
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history,
    }
    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint at ``path``.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations cannot be renamed; write directly.
        torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    """Restore ``model`` from ``path`` when a checkpoint file is present.

    Returns:
        ``(round, acc_history, loss_history)`` from the checkpoint, or
        ``(0, [], [])`` when there is no file at ``path``.
    """
    if not os.path.exists(path):
        return 0, [], []
    snapshot = torch.load(path, map_location=device)
    model.load_state_dict(snapshot["model_state"])
    print(f"Checkpoint loaded (round {snapshot['round']})")
    return snapshot["round"], snapshot["acc_history"], snapshot["loss_history"]

def plot_history(acc_history, loss_history, eval_every):
    """Show a two-panel figure: test accuracy and test loss per round.

    The i-th recorded value is plotted at round ``(i + 1) * eval_every``.
    """
    rounds = np.arange(len(acc_history)) * eval_every + eval_every
    # (subplot position, series, line label, y-axis label, title)
    specs = (
        (1, acc_history, 'Test Acc', 'Accuracy', 'Test Accuracy'),
        (2, loss_history, 'Test Loss', 'Loss', 'Test Loss'),
    )
    plt.figure(figsize=(8, 4))
    for index, data, line_label, y_label, title in specs:
        plt.subplot(1, 2, index)
        plt.plot(rounds, data, label=line_label)
        plt.xlabel('Round')
        plt.ylabel(y_label)
        plt.title(title)
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    """Run ``n_rounds`` federated rounds with test logging, checkpoints and plots.

    Meant to be bound to a trainer with ``types.MethodType``; ``self`` must
    provide ``train_round()`` and ``global_model``.

    Args:
        n_rounds: Number of rounds to run (rounds are 1-based).
        eval_fn: Optional callable ``model -> (test_acc, test_loss)``.
        eval_every: Evaluate on rounds that are multiples of this value.
        checkpoint_path: Checkpoint file; a falsy value disables saving.
        resume: Accepted for signature compatibility but not used here;
            training always starts from round 1.
    """
    acc_history, loss_history = [], []
    for rnd in range(1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()
        if eval_fn and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)
        if rnd % 5 == 0 or rnd == n_rounds:
            # Checkpoint independently of the evaluation schedule: nested in
            # the eval branch, rounds 5 and 15 were skipped with eval_every=2.
            if checkpoint_path:
                save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
wandb.init(
    project="fl-fedavg-personnal-contribution",
    name=f"model_editing_iid_test_acc_{strategy}_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs,
        "strategy" : strategy
    }
)

# --- Federated trainer FL-ready ---
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs,
    strategy = strategy
)

# --- Patch and run ---
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
wandb.finish()


In [None]:
# ====== FL READY: DATA, CLIENTS, WANDB, MODEL EDITING IID TEST ACC ======

# --- FL params ---
K = 100  # number of clients
C = 0.1  # passed to the trainer as client_fraction
J = 4  # passed to the trainer as local_epochs; also T_max of the cosine scheduler
n_rounds = 20
batch_size = 128
lr = 0.001
momentum = 0.8
weight_decay = 5e-4

strategy = "magnitude_least"
# --- Model Editing HP chosen ---
sparsity_ratio = 0.90
num_calib_rounds = 1
sparse_ft_epochs = 1

# Per-strategy checkpoint file: this cell no longer depends on a
# CHECKPOINT_PATH left over from another cell, and does not overwrite another
# strategy's checkpoint.
CHECKPOINT_PATH = f"model_editing_iid_test_acc_{strategy}_checkpoint.pt"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Transforms (ImageNet style for ViT/DINO) ---
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomCrop(224, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])
val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225])
])

# --- Load raw CIFAR-100 (no transform here) ---
full_train = torchvision.datasets.CIFAR100(root="./data", train=True, download=True)
test_set = torchvision.datasets.CIFAR100(root="./data", train=False, download=True)

# --- FL split: IID sharding + train local ---
client_data = iid_split

from torch.utils.data import Subset, Dataset

class TransformedSubset(Dataset):
    """Dataset wrapper that runs ``transform`` on the input of each item.

    ``subset`` items are ``(x, y)`` pairs; ``y`` is passed through as is.
    """

    def __init__(self, subset, transform):
        self.subset, self.transform = subset, transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        example, target = self.subset[idx]
        return self.transform(example), target

# ---Instantiate clients prepared for FL (local training only) ---
clients = []
for i in range(K):
    train_idxs = client_data[i]['train']
    client_train_dataset = TransformedSubset(Subset(full_train, train_idxs), train_transform)
    clients.append(Client(i, client_train_dataset, device))

# --- Official test set ---
test_loader = torch.utils.data.DataLoader(
    TransformedSubset(test_set, val_transform), batch_size=128, shuffle=False
)

# --- Global model ViT-S/16 DINO CIFAR-100 ---
global_model = DinoViT_CIFAR100(num_classes=100).to(device)

def make_scheduler(optimizer):
    """Attach a cosine-annealing LR schedule to ``optimizer`` and return it.

    The notebook-level ``J`` (read when called) is used as ``T_max``.
    """
    schedule_cls = torch.optim.lr_scheduler.CosineAnnealingLR
    return schedule_cls(optimizer, T_max=J)

def eval_fn_test(model):
    """Evaluate ``model`` on the official test loader with ``evaluate``.

    The returned value is forwarded unchanged; the fit loop unpacks it as
    ``(test_acc, test_loss)``.
    """
    report = evaluate(model, test_loader, device)
    return report

def save_checkpoint(model, round_idx, acc_history, loss_history, path=CHECKPOINT_PATH):
    """Write the model weights and metric histories to ``path``.

    The stored dict has the keys ``round``, ``model_state``, ``acc_history``
    and ``loss_history``. ``path`` defaults to the ``CHECKPOINT_PATH`` value
    in effect when this function was defined.
    """
    checkpoint = {
        "round": round_idx,
        "model_state": model.state_dict(),
        "acc_history": acc_history,
        "loss_history": loss_history,
    }
    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and rename it into place, so an
        # interrupted save never leaves a truncated checkpoint at ``path``.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations cannot be renamed; write directly.
        torch.save(checkpoint, path)

def load_checkpoint(model, path=CHECKPOINT_PATH):
    """Load ``path`` into ``model`` if a checkpoint file exists there.

    Returns:
        ``(round, acc_history, loss_history)`` from the checkpoint, or
        ``(0, [], [])`` when there is no file at ``path``.
    """
    if not os.path.exists(path):
        return 0, [], []
    restored = torch.load(path, map_location=device)
    model.load_state_dict(restored["model_state"])
    print(f"Checkpoint loaded (round {restored['round']})")
    return restored["round"], restored["acc_history"], restored["loss_history"]

def plot_history(acc_history, loss_history, eval_every):
    """Render test accuracy and test loss curves next to each other.

    The i-th recorded value is plotted at round ``(i + 1) * eval_every``.
    """
    rounds = np.arange(len(acc_history)) * eval_every + eval_every
    # (subplot position, series, line label, y-axis label, title)
    subplots = (
        (1, acc_history, 'Test Acc', 'Accuracy', 'Test Accuracy'),
        (2, loss_history, 'Test Loss', 'Loss', 'Test Loss'),
    )
    plt.figure(figsize=(8, 4))
    for place, curve, line_label, y_label, title in subplots:
        plt.subplot(1, 2, place)
        plt.plot(rounds, curve, label=line_label)
        plt.xlabel('Round')
        plt.ylabel(y_label)
        plt.title(title)
    plt.tight_layout()
    plt.show()

def fit_with_all_logs(self, n_rounds, eval_fn=None, eval_every=1, checkpoint_path=CHECKPOINT_PATH, resume=False):
    """Run federated rounds with periodic evaluation, W&B logging, checkpoints and plots.

    Meant to be bound to a trainer instance: it calls ``self.train_round()`` and
    reads ``self.global_model``.

    Args:
        n_rounds: Number of the last round to run (rounds are numbered from 1).
        eval_fn: Optional callable ``eval_fn(model) -> (acc, loss)``. When given,
            it is called on every round that is a multiple of ``eval_every``.
        eval_every: Evaluation period, in rounds.
        checkpoint_path: File used to save checkpoints and, when resuming, to
            load one.
        resume: When true, restore the global model weights and both histories
            from ``checkpoint_path`` (if that file exists) and continue with the
            round following the saved one. Only what ``save_checkpoint`` stores
            is restored. When false, training starts from round 1 with empty
            histories.
    """
    if resume:
        # load_checkpoint returns (0, [], []) when no file exists, i.e. a fresh start.
        start_round, acc_history, loss_history = load_checkpoint(self.global_model, checkpoint_path)
        if start_round >= n_rounds:
            # Make the "no rounds left" case visible instead of silently doing nothing.
            print(f'Checkpoint is already at round {start_round} (n_rounds={n_rounds}): nothing left to train.')
    else:
        start_round, acc_history, loss_history = 0, [], []

    for rnd in range(start_round + 1, n_rounds + 1):
        print(f'---- FedAvg Round {rnd} (MODEL EDITING) ----')
        self.train_round()

        if eval_fn is not None and rnd % eval_every == 0:
            test_acc, test_loss = eval_fn(self.global_model)
            print(f'[Round {rnd}] Test Acc={test_acc:.3f} | Test Loss={test_loss:.3f}')
            wandb.log({"round": rnd, "test_acc": test_acc, "test_loss": test_loss})
            acc_history.append(test_acc)
            loss_history.append(test_loss)

        # Checkpoint and plot every 5 rounds and on the last round. This is done
        # independently of the evaluation schedule so that the final model is
        # always persisted, even when the round is not an evaluation round.
        if rnd % 5 == 0 or rnd == n_rounds:
            save_checkpoint(self.global_model, rnd, acc_history, loss_history, checkpoint_path)
            plot_history(acc_history, loss_history, eval_every)

# --- WANDB init ---
# The run name encodes the masking strategy, J, the number of rounds, the
# learning rate, the sparsity ratio (as an integer percentage) and the number
# of calibration rounds.
wandb.init(
    project="fl-fedavg-personnal-contribution",
    name=f"model_editing_iid_test_acc_{strategy}_J{J}_nrounds{n_rounds}_lr{lr}_sp{int(sparsity_ratio*100)}_calib{num_calib_rounds}",
    # Hyperparameters recorded with the run; they mirror the values passed to
    # FederatedTrainer below.
    config={
        "model": "DINO ViT-S/16",
        "K": K,
        "C": C,
        "J": J,
        "n_rounds": n_rounds,
        "batch_size": batch_size,
        "lr": lr,
        "momentum": momentum,
        "weight_decay": weight_decay,
        "sharding": "iid",
        # NOTE(review): presumably the classes-per-client setting of a non-IID
        # split, left unset for IID sharding - confirm.
        "Nc": None,
        "use_sparse": True,
        "sparsity_ratio": sparsity_ratio,
        "num_calib_rounds": num_calib_rounds,
        "sparse_ft_epochs": sparse_ft_epochs,
        "strategy" : strategy
    }
)

# --- Federated trainer ---
# C is passed as the client fraction and J as the number of local epochs;
# sparse training is switched on (use_sparse=True) with the selected strategy.
trainer = FederatedTrainer(
    clients=clients,
    global_model=global_model,
    device=device,
    client_fraction=C,
    local_epochs=J,
    batch_size=batch_size,
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay,
    scheduler_fn=make_scheduler,
    use_sparse=True,
    sparsity_ratio=sparsity_ratio,
    num_calib_rounds=num_calib_rounds,
    sparse_ft_epochs=sparse_ft_epochs,
    strategy = strategy
)

# --- Patch and run ---
# Replace this trainer's `fit` with fit_with_all_logs, bound to the instance, so
# that training goes through the logging/checkpointing loop; the test set is
# evaluated every 2 rounds.
trainer.fit = types.MethodType(fit_with_all_logs, trainer)
trainer.fit(n_rounds, eval_fn=eval_fn_test, eval_every=2, checkpoint_path=CHECKPOINT_PATH, resume=True)
# Close the W&B run once training has finished.
wandb.finish()
