In [1]:
# Environment setup cell. The last two lines are IPython/Jupyter shell magics
# (`!cmd`), so this cell only runs inside a notebook kernel, not as plain Python.
#
# The commented-out block below, if re-enabled, removes every file, symlink and
# directory found in the current working directory. It is left disabled.
# import os
# import shutil

# for item in os.listdir('.'):
#     if os.path.isfile(item) or os.path.islink(item):
#         os.remove(item)
#     elif os.path.isdir(item):
#         shutil.rmtree(item)
# Clone the project repository, then move its top-level contents into
# /kaggle/working.
# NOTE(review): the `*` glob presumably skips dotfiles (default shell behaviour)
# -- confirm if hidden files from the repository are needed.
!git clone https://github.com/sathishkumar67/PODFCSSV.git
!mv PODFCSSV/* /kaggle/working

Cloning into 'PODFCSSV'...
remote: Enumerating objects: 261, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 261 (delta 0), reused 2 (delta 0), pack-reused 258 (from 1)[K
Receiving objects: 100% (261/261), 4.00 MiB | 18.97 MiB/s, done.
Resolving deltas: 100% (112/112), done.


In [11]:
from __future__ import annotations
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List, Any
from transformers import PreTrainedModel, ViTMAEForPreTraining

class IBA_Adapter(nn.Module):
    """
    Information-Bottlenecked Adapter (IBA) for parameter-efficient fine-tuning.

    Overview
    --------
    A lightweight module meant to be inserted into a frozen pre-trained
    backbone. It adds a small number of trainable parameters so the model can
    be adapted without updating the backbone weights.

    Structure
    ---------
        Input (D) -> Down-Projection (d) -> Non-Linearity -> Up-Projection (D)
                  -> Dropout -> + Residual

    Key Design Principles
    ---------------------
    1.  **Information Bottleneck**: features of size D are projected down to a
        smaller size d and back, which keeps the parameter count low.

    2.  **Identity Initialization**: the Up-Projection weight and bias are
        initialised to zero, so at step 0 the adapter branch outputs exactly 0
        and `forward(x) == x`. The wrapped backbone therefore behaves exactly
        as before injection until the adapter starts learning.

    Attributes
    ----------
    input_dim : int
        Dimensionality of the input features (hidden size of the backbone).
    bottleneck_dim : int
        Dimensionality of the compressed bottleneck space.
    down_project : nn.Linear
        Linear layer mapping `input_dim` -> `bottleneck_dim`.
    activation : nn.Module
        Non-linearity applied inside the bottleneck.
    up_project : nn.Linear
        Linear layer mapping `bottleneck_dim` -> `input_dim`.
    dropout : nn.Dropout
        Dropout applied to the adapter branch before the residual addition.
    """

    def __init__(
        self,
        input_dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.0,
        activation: Optional[nn.Module] = None,
    ) -> None:
        """
        Initializes the IBA Adapter with the specified configuration.

        Parameters
        ----------
        input_dim : int
            The hidden size of the pre-trained model (e.g., 768 for ViT-Base).
        bottleneck_dim : int, optional
            The size of the bottleneck. Defaults to 64.
        dropout : float, optional
            Dropout probability applied to the adapter branch. Defaults to 0.0.
        activation : nn.Module, optional
            The activation module used within the bottleneck. When omitted
            (``None``), a fresh ``nn.GELU()`` is created for this adapter.
        """
        super().__init__()
        self.input_dim = input_dim
        self.bottleneck_dim = bottleneck_dim

        # A default of `nn.GELU()` in the signature would be evaluated once at
        # definition time and shared by every adapter; build one per instance.
        self.activation = nn.GELU() if activation is None else activation

        # 1. Down-Projection Layer (D -> d).
        self.down_project = nn.Linear(input_dim, bottleneck_dim, bias=True)

        # 2. Up-Projection Layer (d -> D).
        self.up_project = nn.Linear(bottleneck_dim, input_dim, bias=True)

        # 3. Regularization
        self.dropout = nn.Dropout(dropout)

        # 4. Weight Initialization
        self._init_weights()

    def _init_weights(self) -> None:
        """
        Applies the initialization scheme for the adapter layers.

        Initialization Strategy
        -----------------------
        1.  **Down-Projection**:
            -   **Weights**: Kaiming Normal with `mode='fan_out'` and
                `nonlinearity='relu'`. Per the PyTorch documentation,
                `fan_out` preserves magnitudes in the backward pass
                (`fan_in` is the forward-pass variant).
            -   **Bias**: zeros.

        2.  **Up-Projection**:
            -   **Weights & Bias**: zeros, so the adapter branch contributes
                nothing at the start of training and the module is an exact
                identity mapping.
        """
        # A. Down-Projection Initialization
        nn.init.kaiming_normal_(self.down_project.weight, mode='fan_out', nonlinearity='relu')
        if self.down_project.bias is not None:
            nn.init.zeros_(self.down_project.bias)

        # B. Up-Projection Initialization (identity init)
        nn.init.zeros_(self.up_project.weight)
        if self.up_project.bias is not None:
            nn.init.zeros_(self.up_project.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Executes the forward pass of the adapter module.

        The flow is:
        Input -> [Down Project] -> [Activation] -> [Up Project] -> [Dropout] -> + Input (Residual)

        Parameters
        ----------
        x : torch.Tensor
            Input features whose last dimension equals `input_dim`
            (e.g. [Batch_Size, Sequence_Length, Hidden_Dimension]).

        Returns
        -------
        torch.Tensor
            The adapted features, with the same shape as the input.
        """
        # Keep the untouched input for the residual connection.
        residual = x

        # 1. Compression: project down to the bottleneck dimension.
        x = self.down_project(x)

        # 2. Non-linearity.
        x = self.activation(x)

        # 3. Reconstruction: project back up to the original dimension.
        x = self.up_project(x)

        # 4. Regularization.
        x = self.dropout(x)

        # 5. Residual connection: add the learned delta to the original features.
        return residual + x

    def __repr__(self) -> str:
        """
        Returns a compact one-line representation of the module.
        """
        return f"IBA_Adapter(in_features={self.input_dim}, bottleneck={self.bottleneck_dim})"


class ViTBlockWithAdapter(nn.Module):
    """
    Wrapper Module to Inject an Adapter into a Frozen Transformer Block.

    Purpose
    -------
    Wraps an existing transformer block, lets the original block process the
    input, and then applies the adapter to the block's output hidden states.

    The wrapper mirrors whatever container the wrapped block returned (tuple,
    object exposing `hidden_states`, or raw tensor) so that the surrounding
    model pipeline keeps working unchanged.
    """

    def __init__(self, original_block: nn.Module, adapter: "IBA_Adapter") -> None:
        """
        Wraps a transformer block with an adapter.

        Parameters
        ----------
        original_block : nn.Module
            The original Transformer block (e.g., `ViTLayer`).
        adapter : IBA_Adapter
            The trainable adapter instance to be applied after the block.
        """
        super().__init__()
        self.original_block = original_block
        self.adapter = adapter

    def forward(
        self,
        hidden_states: torch.Tensor,
        *args,
        **kwargs
    ) -> Union[torch.Tensor, Tuple[Any, ...], Any]:
        """
        Runs the wrapped block, then the adapter, and repackages the result.

        Parameters
        ----------
        hidden_states : torch.Tensor
            Input tensor passed to the wrapped block.
        args : tuple
            Accepted for signature compatibility; NOT forwarded to the block.
        kwargs : dict
            Accepted for signature compatibility; NOT forwarded to the block.

        Returns
        -------
        Union[torch.Tensor, Tuple[Any, ...], Any]
            -   tuple in  -> tuple out, with element 0 replaced by the adapted
                hidden states and the remaining elements preserved;
            -   object with `hidden_states` -> the same object updated in
                place, or a 1-tuple `(adapted,)` if the attribute cannot be
                assigned;
            -   anything else (treated as a raw tensor) -> the adapted tensor.
        """
        # 1. Execute the original block.
        # *args / **kwargs (e.g. head_mask) are deliberately dropped because,
        # per the original author's note, ViTMAE layers typically reject them.
        outputs = self.original_block(hidden_states)

        # 2a. Tuple output: (hidden_states, *extras).
        if isinstance(outputs, tuple):
            return (self.adapter(outputs[0]),) + outputs[1:]

        # 2b. Output object exposing `hidden_states`.
        if hasattr(outputs, "hidden_states"):
            adapted = self.adapter(outputs.hidden_states)
            try:
                outputs.hidden_states = adapted
            except (AttributeError, TypeError):
                # Read-only / frozen output objects refuse attribute
                # assignment; fall back to a plain tuple. Only assignment
                # failures are caught so that KeyboardInterrupt, SystemExit
                # and unrelated errors still propagate.
                return (adapted,)
            return outputs

        # 2c. Raw tensor output.
        return self.adapter(outputs)


def inject_adapters(model: PreTrainedModel, bottleneck_dim: int = 64) -> PreTrainedModel:
    """
    Core Utility: Injects IBA Adapters into the Encoder of a Pre-trained Model.

    Procedure
    ---------
    1.  **Freeze Backbone**: sets `requires_grad=False` on every parameter the
        model has at call time.
    2.  **Locate Encoder**: looks for `model.vit.encoder.layer` first, then
        `model.encoder.layer`.
    3.  **Inject Adapters**: for each encoder layer, creates an `IBA_Adapter`
        sized to `model.config.hidden_size`, wraps the layer in
        `ViTBlockWithAdapter`, and writes the wrapper back into the layer list.
    4.  **Audit**: prints the trainable/frozen parameter counts. The freshly
        created adapter parameters are the only ones left trainable.

    Parameters
    ----------
    model : PreTrainedModel
        The Hugging Face model instance (e.g., `ViTMAEForPreTraining`).
        It is modified in place.
    bottleneck_dim : int, optional
        The dimension of the adapter bottleneck. Defaults to 64.

    Returns
    -------
    PreTrainedModel
        The same model instance, with adapters injected and backbone frozen.

    Raises
    ------
    AttributeError
        If neither `model.vit.encoder.layer` nor `model.encoder.layer` exists.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("[System] Starting Adapter Injection Procedure")
    print(f"{banner}")

    # 1. Freeze the entire model backbone.
    print("[Config] Freezing original backbone parameters...")
    for param in model.parameters():
        param.requires_grad = False

    # 2. Locate the encoder. Both candidates must actually expose `.layer`,
    # otherwise the documented AttributeError below is raised instead of an
    # incidental failure further down.
    vit = getattr(model, "vit", None)
    candidates = (
        getattr(vit, "encoder", None),    # ViTMAE-style: model.vit.encoder
        getattr(model, "encoder", None),  # generic fallback: model.encoder
    )
    encoder = next(
        (enc for enc in candidates if enc is not None and hasattr(enc, "layer")),
        None,
    )
    if encoder is None:
        raise AttributeError(
            "Could not locate 'encoder.layer'. "
            "Model structure unknown (expected 'vit.encoder' or 'encoder')."
        )

    config = model.config
    input_dim = config.hidden_size
    num_layers = len(encoder.layer)

    print(f"[Config] Model Config: Hidden Dim={input_dim}, Layers={num_layers}")
    print(f"[Config] Adapter Config: Bottleneck Dim={bottleneck_dim}")

    # 3. Iterate and replace.
    print("[Action] Injecting adapters into encoder layers...")

    for i, layer in enumerate(encoder.layer):
        adapter = IBA_Adapter(input_dim=input_dim, bottleneck_dim=bottleneck_dim)

        # Match the adapter to the layer's device/dtype, using the layer's
        # first parameter as the reference. A layer without parameters gives
        # no reference, so the adapter keeps PyTorch's defaults in that case
        # (instead of failing with a bare StopIteration).
        ref_param = next(layer.parameters(), None)
        if ref_param is not None:
            adapter.to(device=ref_param.device, dtype=ref_param.dtype)

        # Replace the layer in the ModuleList with the adapter-enabled wrapper.
        encoder.layer[i] = ViTBlockWithAdapter(original_block=layer, adapter=adapter)

        # Progress logging every 4 layers and on the final layer.
        if (i + 1) % 4 == 0 or (i + 1) == num_layers:
            print(f"  -> Processed layer {i + 1}/{num_layers}")

    print("[System] Injection Complete. Decoder layers ignored (if present).")

    # 4. Verification: summary of trainable vs frozen parameters.
    count_trainable_params(model)

    return model


def count_trainable_params(model: nn.Module) -> None:
    """
    Audit Utility: prints how many parameters are trainable vs frozen.

    Every parameter of `model` is counted once; those with
    `requires_grad=True` are reported as trainable, the rest as frozen,
    together with the trainable share as a percentage.

    Parameters
    ----------
    model : nn.Module
        The model to inspect.
    """
    n_all = 0
    n_trainable = 0
    # Single pass: tally element counts and the trainable subset together.
    for tensor in model.parameters():
        size = tensor.numel()
        n_all += size
        if tensor.requires_grad:
            n_trainable += size
    n_frozen = n_all - n_trainable

    # Guard the division for parameter-free models.
    if n_all > 0:
        pct = (n_trainable / n_all) * 100
    else:
        pct = 0

    report = [
        f"\n[Stats] Parameter Audit:",
        f"  - Total Parameters:     {n_all:,}",
        f"  - Frozen Backbone:      {n_frozen:,}",
        f"  - Trainable (Adapters): {n_trainable:,}",
        f"  - Trainable Ratio:      {pct:.2f}%",
        f"{'='*60}\n",
    ]
    for line in report:
        print(line)


# =============================================================================
# Main Execution Block (Integration Test)
# =============================================================================
if __name__ == "__main__":
    # Integration smoke test for the adapter-injection pipeline:
    #   1. Load a ViTMAE checkpoint from Hugging Face.
    #   2. Inject adapters.
    #   3. Run one dummy forward pass to surface shape or runtime errors.
    # The names bound here (model, dummy_input, ...) stay at module level and
    # are reused by the inspection cells that follow.
    print("[Main] Loading pre-trained ViTMAE...")
    try:
        # NOTE: Requires `pip install transformers`
        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

        # Freeze the backbone and wrap every encoder layer with an adapter.
        model = inject_adapters(model, bottleneck_dim=64)

        print("[Main] Running dummy forward pass to verify graph integrity...")

        # One random 3x224x224 image, placed on the model's device.
        device = next(model.parameters()).device
        dummy_input = torch.randn(1, 3, 224, 224).to(device)

        output = model(dummy_input)

        # Report the loss when the output exposes one.
        if hasattr(output, "loss"):
            loss_val = output.loss.item()
        else:
            loss_val = "N/A"
        print(f"[Success] Standard Forward pass complete. Loss: {loss_val}")

    except Exception as exc:
        print(f"[Error] An error occurred during execution: {exc}")

[Main] Loading pre-trained ViTMAE...

[System] Starting Adapter Injection Procedure
[Config] Freezing original backbone parameters...
[Config] Model Config: Hidden Dim=768, Layers=12
[Config] Adapter Config: Bottleneck Dim=64
[Action] Injecting adapters into encoder layers...
  -> Processed layer 4/12
  -> Processed layer 8/12
  -> Processed layer 12/12
[System] Injection Complete. Decoder layers ignored (if present).

[Stats] Parameter Audit:
  - Total Parameters:     113,097,472
  - Frozen Backbone:      111,907,840
  - Trainable (Adapters): 1,189,632
  - Trainable Ratio:      1.05%

[Main] Running dummy forward pass to verify graph integrity...
[Success] Standard Forward pass complete. Loss: 0.998232364654541


In [15]:
# Mean-pool the encoder output over dim=1 and show the resulting shape.
model.vit(dummy_input).last_hidden_state.mean(dim=1).shape

torch.Size([1, 768])

In [None]:
# Show the shape of the encoder's un-pooled output.
# NOTE(review): the output recorded under this cell is a tensor dump rather than
# a torch.Size, so it presumably comes from an earlier version of the cell -- re-run to confirm.
model.vit(dummy_input).last_hidden_state.shape

tensor([[[-0.1229,  0.0684,  0.1409,  ..., -0.1243, -0.1670, -0.0609],
         [-0.3416, -0.0233,  0.0788,  ..., -0.0027, -0.3302, -0.1642],
         [-0.1514, -0.0708,  0.2829,  ..., -0.1047, -0.2944, -0.5037],
         ...,
         [-0.4514, -0.0598,  0.3576,  ..., -0.0868, -0.2820, -0.7138],
         [ 0.5742,  0.0094, -0.0291,  ...,  0.0154, -0.3651,  0.0233],
         [ 0.4702, -0.0291,  0.1678,  ..., -0.0907, -0.2683, -0.6697]]],
       grad_fn=<NativeLayerNormBackward0>)

In [16]:
# Display the encoder's module tree to check that each layer is now a ViTBlockWithAdapter.
model.vit.encoder

ViTMAEEncoder(
  (layer): ModuleList(
    (0-11): 12 x ViTBlockWithAdapter(
      (original_block): ViTMAELayer(
        (attention): ViTMAEAttention(
          (attention): ViTMAESelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTMAESelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTMAEIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTMAEOutput(
          (dense): Linear(in_features=3072, out_features=768, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
        (layernorm_before): LayerNorm((768,

In [None]:

from __future__ import annotations
import copy
import logging
from typing import List, Dict, Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger(__name__)


class FederatedClient:
    """
    Simulates a Local Client in the Federated Network.

    Each client owns a private copy of the model and trains it on its local
    data stream.

    Key Responsibilities:
    1.  **Phased Training** (`train_epoch`):
        -   Without global prototypes: Loss = L_mae (the model's own loss).
        -   With global prototypes and a GPAD loss module: Loss = L_mae + L_gpad.
        -   When local prototypes exist, they are refined online with an EMA
            update from the current batch's embeddings.

    2.  **Prototype Generation** (`generate_prototypes`):
        -   Extracts mean-pooled encoder features for the local data.
        -   Clusters them with a spherical K-Means to obtain 'Local Prototypes'.
        -   Only these prototype vectors are returned; raw inputs are not.
    """

    def __init__(
        self,
        client_id: int,
        model: nn.Module,
        device: torch.device,
        dtype: torch.dtype,
        optimizer_cls: type = optim.AdamW,
        optimizer_kwargs: Optional[Dict[str, Any]] = None,
        local_update_threshold: float = 0.7,
        local_ema_alpha: float = 0.1,
    ) -> None:
        """
        Initialize the Client.

        Args:
            client_id (int): Unique identifier.
            model (nn.Module): The base model; it is deep-copied, so the
                instance passed in is never trained by this client. It must
                expose a `.vit` sub-module whose output has `last_hidden_state`.
            device (torch.device): Device the local model copy is moved to.
            dtype (torch.dtype): Dtype every input batch is cast to.
            optimizer_cls (type): Optimizer class (default: AdamW).
            optimizer_kwargs (Dict, optional): Keyword arguments for the
                optimizer. Defaults to `{"lr": 1e-3}` when omitted or empty.
            local_update_threshold (float): Minimum cosine similarity between
                a sample and its best local prototype for an EMA update.
            local_ema_alpha (float): EMA weight given to the new sample.
        """
        self.client_id = client_id
        self.device = device
        self.local_update_threshold = local_update_threshold
        self.local_ema_alpha = local_ema_alpha
        self.dtype = dtype

        # Populated by `generate_prototypes`, refined by `train_epoch`.
        self.local_prototypes: Optional[torch.Tensor] = None

        # 1. Independent model copy, so training here does not touch the
        # base model or other clients.
        self.model = copy.deepcopy(model).to(self.device)

        # 2. Independent optimizer.
        opt_kwargs = optimizer_kwargs or {"lr": 1e-3}
        self.optimizer = optimizer_cls(self.model.parameters(), **opt_kwargs)

        logger.info(f"Client {self.client_id} initialized on {self.device}")

    def _prepare_inputs(self, batch: Any) -> torch.Tensor:
        """Cast a batch to the client's dtype and move it to its device.

        A list/tuple batch is assumed to be `(inputs, *extras)` and only the
        first element is used -- NOTE(review): confirm against the dataset.
        """
        if isinstance(batch, (list, tuple)):
            batch = batch[0]
        return batch.to(self.dtype).to(self.device)

    def _extract_embeddings(self, inputs: torch.Tensor) -> torch.Tensor:
        """Mean-pool the encoder's `last_hidden_state` over dim=1."""
        return self.model.vit(inputs).last_hidden_state.mean(dim=1)

    def _update_local_prototypes(self, embeddings: torch.Tensor) -> None:
        """EMA-update each sample's best-matching local prototype in place.

        A sample contributes only when its cosine similarity to that
        prototype exceeds `local_update_threshold`. Similarities are computed
        once, before any update in this batch.
        """
        if self.local_prototypes.device != self.device:
            self.local_prototypes = self.local_prototypes.to(self.device)

        with torch.inference_mode():
            # Normalize for cosine similarity.
            z_norm = F.normalize(embeddings, p=2, dim=1)
            p_norm = F.normalize(self.local_prototypes, p=2, dim=1)

            # Similarity matrix: (B, K_local); best prototype per sample.
            sims = torch.mm(z_norm, p_norm.t())
            max_sim, best_idx = sims.max(dim=1)

            # Samples are applied one by one because several samples may
            # update the same prototype, and each update builds on the last.
            for idx in torch.where(max_sim > self.local_update_threshold)[0]:
                proto_idx = best_idx[idx]
                old_proto = self.local_prototypes[proto_idx]
                # EMA: new = (1 - a) * old + a * sample
                self.local_prototypes[proto_idx] = (
                    (1 - self.local_ema_alpha) * old_proto
                    + self.local_ema_alpha * z_norm[idx]
                )

    def train_epoch(
        self,
        dataloader: DataLoader,
        global_prototypes: torch.Tensor = None,
        gpad_loss_fn: nn.Module = None
    ) -> float:
        """
        Executes one epoch of local training.

        The loss depends on the availability of global prototypes:
        -   **No prototypes (or no loss module)**: Loss = L_mae
        -   **Prototypes and loss module given**: Loss = L_mae + L_gpad

        Args:
            dataloader (DataLoader): Local data stream.
            global_prototypes (Tensor, optional): Global concepts from Server.
            gpad_loss_fn (nn.Module, optional): Callable
                `(embeddings, prototypes) -> loss`.

        Returns:
            float: Average loss across the epoch (0.0 for an empty dataloader).
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        use_gpad = global_prototypes is not None and gpad_loss_fn is not None

        for batch in dataloader:
            inputs = self._prepare_inputs(batch)

            # Forward pass.
            outputs = self.model(inputs)

            # 1. Base MAE loss.
            mae_loss = getattr(outputs, "loss", None)
            if mae_loss is None:
                # Fallback if the model does not compute a loss internally.
                mae_loss = torch.tensor(0.0, dtype=self.dtype, device=self.device, requires_grad=True)

            final_loss = mae_loss

            # Encoder embeddings of THIS batch, shared by GPAD and the local
            # prototype update. The extra encoder pass is skipped when
            # neither consumer is active.
            embeddings = None
            if use_gpad or self.local_prototypes is not None:
                embeddings = self._extract_embeddings(inputs)

            # 2. GPAD loss (if applicable).
            if use_gpad:
                protos_device = global_prototypes.to(self.device)
                final_loss = final_loss + gpad_loss_fn(embeddings, protos_device)

            # 3. Online local prototype update (does not affect the loss).
            if self.local_prototypes is not None:
                self._update_local_prototypes(embeddings)

            # Backward pass.
            self.optimizer.zero_grad()
            final_loss.backward()
            self.optimizer.step()

            total_loss += final_loss.item()
            num_batches += 1

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        return avg_loss

    @torch.no_grad()
    def generate_prototypes(self, dataloader: DataLoader, K_init: int = 10) -> torch.Tensor:
        """
        Generates 'Local Prototypes' by clustering the client's feature space.

        Process:
        1.  Inference: run the local encoder on all local data and mean-pool
            `last_hidden_state` over dim=1.
        2.  L2-normalize the embeddings.
        3.  Clustering: run K-Means to find the centroids, which are also
            stored in `self.local_prototypes` for later online updates.

        Args:
            dataloader (DataLoader): Local data.
            K_init (int): Number of prototypes to generate.

        Returns:
            torch.Tensor: Local prototypes [K, Dim], with K = min(K_init, N)
            where N is the number of samples seen.

        Raises:
            RuntimeError: If the dataloader yields no batches.
        """
        self.model.eval()
        all_features = []

        # 1. Feature extraction.
        for batch in dataloader:
            inputs = self._prepare_inputs(batch)
            with torch.inference_mode():
                features = self._extract_embeddings(inputs)
            all_features.append(features)

        if not all_features:
            raise RuntimeError(
                "generate_prototypes received an empty dataloader; "
                "cannot cluster zero samples."
            )

        # Concatenate all features: (N_samples, D).
        embeddings = torch.cat(all_features, dim=0)

        # 2. Normalization.
        embeddings = F.normalize(embeddings, p=2, dim=1)

        # 3. K-Means clustering.
        centroids = self._kmeans(embeddings, K=K_init)

        # Save for the next round's online updates.
        self.local_prototypes = centroids.detach().clone()

        return centroids

    def _kmeans(self, X: torch.Tensor, K: int, max_iters: int = 100) -> torch.Tensor:
        """
        Spherical K-Means in PyTorch (cosine-similarity assignment).

        Args:
            X (torch.Tensor): Row vectors [N, D]; expected to be L2-normalized.
            K (int): Requested number of clusters; clamped to N when K > N.
            max_iters (int): Maximum number of refinement iterations.

        Returns:
            torch.Tensor: Unit-norm centroids [min(K, N), D].

        Raises:
            ValueError: If X has no rows or K is not positive.
        """
        N, D = X.shape
        if N == 0 or K <= 0:
            raise ValueError(
                f"K-Means needs at least one sample and K >= 1 (got N={N}, K={K})."
            )
        # More clusters than samples cannot be seeded from distinct points.
        K = min(K, N)

        # Initialize centroids from K distinct random samples.
        indices = torch.randperm(N)[:K]
        centroids = F.normalize(X[indices].clone(), p=2, dim=1)

        for _ in range(max_iters):
            # For unit vectors, maximizing cosine similarity is equivalent to
            # minimizing Euclidean distance. Similarity matrix: (N, K).
            sims = torch.mm(X, centroids.t())
            _, labels = sims.max(dim=1)

            # Update centroids.
            new_centroids = torch.zeros_like(centroids)
            for k in range(K):
                cluster_mask = (labels == k)
                if cluster_mask.sum() > 0:
                    new_centroids[k] = X[cluster_mask].mean(dim=0)
                else:
                    # Re-seed an empty cluster from a random sample.
                    new_idx = torch.randint(0, N, (1,)).item()
                    new_centroids[k] = X[new_idx]

            # Project back onto the unit sphere BEFORE measuring the shift, so
            # both sides of the comparison live in the same space and the
            # early-stopping test can actually trigger.
            new_centroids = F.normalize(new_centroids, p=2, dim=1)

            center_shift = torch.norm(new_centroids - centroids)
            centroids = new_centroids
            if center_shift < 1e-4:
                break

        return centroids


class ClientManager:
    """
    Simulates the Orchestration of Multiple Clients.

    In a real FL system, this would be distributed across devices. Here, it manages
    a list of `FederatedClient` objects and orchestrates their training, effectively
    simulating the "edge" layer.

    Execution Modes:
    ----------------
    1.  **Parallel (GPU)**: If `gpu_count > 0`, it enforces a strict 1:1 mapping
        (Client i -> `cuda:i`) and runs each client's training in its own thread.

    2.  **Sequential (CPU)**: If `gpu_count` is not positive, it runs clients one
        after another on the CPU.
    """

    # Optimizer settings applied to every client when the caller supplies none.
    DEFAULT_OPTIMIZER_KWARGS = {"lr": 1e-4, "weight_decay": 0.05}

    def __init__(
        self,
        base_model: nn.Module,
        num_clients: int,
        gpu_count: int = 0,
        optimizer_kwargs: Optional[dict] = None,
    ) -> None:
        """
        Initializes the Client Manager and spawns the clients.

        Args:
            base_model: The initial global model template.
            num_clients: Total number of clients to simulate.
            gpu_count: Number of available GPUs. When positive it must equal
                `num_clients` (strict 1:1 mapping).
            optimizer_kwargs: Keyword arguments forwarded to every client as
                its optimizer configuration. Defaults to
                `DEFAULT_OPTIMIZER_KWARGS` when omitted.

        Raises:
            ValueError: If `gpu_count > 0` and `num_clients != gpu_count`.
        """
        self.clients: List[FederatedClient] = []
        self.num_clients = num_clients
        self.gpu_count = gpu_count
        # Copy so later mutation of the caller's dict (or of the class-level
        # default) cannot silently change this manager's configuration.
        self.optimizer_kwargs = dict(
            self.DEFAULT_OPTIMIZER_KWARGS if optimizer_kwargs is None else optimizer_kwargs
        )

        self._initialize_clients(base_model)

    def _initialize_clients(self, base_model: nn.Module) -> None:
        """
        Internal helper to spawn clients on appropriate devices.

        Args:
            base_model: Model instance handed to every `FederatedClient`.

        Raises:
            ValueError: If GPUs are requested and the client count differs
                from the GPU count.
        """
        logger.info(f"Initializing {self.num_clients} clients...")

        use_gpu = self.gpu_count > 0

        # Enforce 1:1 Mapping rule if GPUs are available
        if use_gpu:
            if self.num_clients != self.gpu_count:
                raise ValueError(
                    f"Strict 1:1 Client-GPU mapping required. "
                    f"Requested {self.num_clients} clients but found {self.gpu_count} GPUs."
                )
            logger.info(f"Parallel Mode: Mapped {self.num_clients} clients to {self.gpu_count} GPUs.")
        else:
            logger.info(f"Sequential Mode: Running {self.num_clients} clients on CPU.")

        for i in range(self.num_clients):
            # 1:1 Mapping: Client i -> GPU i; otherwise everything runs on CPU.
            device = torch.device(f"cuda:{i}") if use_gpu else torch.device("cpu")

            # NOTE(review): the same `base_model` object is passed to every
            # client; presumably FederatedClient makes its own copy — confirm.
            client = FederatedClient(
                client_id=i,
                model=base_model,
                device=device,
                # Fresh dict per client so clients never share a mutable config.
                optimizer_kwargs=dict(self.optimizer_kwargs),
            )
            self.clients.append(client)

    def train_round(
        self,
        dataloaders: List[DataLoader],
        global_prototypes: Optional[torch.Tensor] = None,
        gpad_loss_fn: Optional[nn.Module] = None,
    ) -> List[float]:
        """
        Triggers one round of local training for ALL clients.

        Dispatch Logic:
        -   **GPU mode** (`gpu_count > 0`): Uses `ThreadPoolExecutor` to launch one
            thread per client so the per-GPU workloads can overlap.
        -   **CPU mode**: Iterates over the clients sequentially.

        A client whose `train_epoch` raises does not abort the round: the error
        is logged with its traceback and that client's loss is reported as NaN.

        Args:
            dataloaders: List of DataLoaders (must match num_clients); entry `i`
                is given to client `i`.
            global_prototypes: The current global prototype bank (for GPAD),
                forwarded unchanged to each client. May be None.
            gpad_loss_fn: The loss function instance, forwarded unchanged to
                each client. May be None.

        Returns:
            List[float]: The value returned by each client's `train_epoch`,
            indexed by client; NaN for clients that failed.

        Raises:
            ValueError: If the number of dataloaders differs from `num_clients`.
        """
        if len(dataloaders) != self.num_clients:
            raise ValueError(
                f"Dataloader count ({len(dataloaders)}) does not match "
                f"client count ({self.num_clients})"
            )

        round_losses = [0.0] * self.num_clients

        if self.gpu_count > 0:
            # Parallel Execution for GPUs
            from concurrent.futures import ThreadPoolExecutor

            with ThreadPoolExecutor(max_workers=self.num_clients) as executor:
                logger.info(f"Spawning {self.num_clients} training threads (1 per GPU)...")
                futures = {
                    executor.submit(
                        client.train_epoch,
                        dataloaders[i],
                        global_prototypes=global_prototypes,
                        gpad_loss_fn=gpad_loss_fn,
                    ): i
                    for i, client in enumerate(self.clients)
                }

                # Collect in submission order; result() blocks until done.
                for future, client_idx in futures.items():
                    try:
                        loss = future.result()
                        round_losses[client_idx] = loss
                        logger.info(f"Client {client_idx} (GPU {client_idx}) finished. Loss: {loss:.4f}")
                    except Exception as e:
                        # Keep the round alive, but preserve the traceback so
                        # the failure can actually be diagnosed.
                        logger.error(f"Client {client_idx} failed: {e}", exc_info=True)
                        round_losses[client_idx] = float('nan')
        else:
            # Sequential Execution for CPU
            logger.info(f"Running sequential training on CPU for {self.num_clients} clients...")
            for i, client in enumerate(self.clients):
                try:
                    loss = client.train_epoch(
                        dataloaders[i],
                        global_prototypes=global_prototypes,
                        gpad_loss_fn=gpad_loss_fn,
                    )
                    round_losses[i] = loss
                    logger.info(f"Client {i} (CPU) finished. Loss: {loss:.4f}")
                except Exception as e:
                    # Keep the round alive, but preserve the traceback so
                    # the failure can actually be diagnosed.
                    logger.error(f"Client {i} failed: {e}", exc_info=True)
                    round_losses[i] = float('nan')

        return round_losses