In [1]:
# Fetch the project repository and move its contents into the Kaggle working
# directory, so the `src` package can be imported from the notebook's cwd.
#
# WARNING: the commented-out block below deletes EVERY file, symlink and
# directory in the current working directory. It is kept only as a manual
# "reset /kaggle/working" helper; leave it commented out unless that is intended.
# import os
# import shutil

# for item in os.listdir('.'):
#     if os.path.isfile(item) or os.path.islink(item):
#         os.remove(item)
#     elif os.path.isdir(item):
#         shutil.rmtree(item)
!git clone https://github.com/sathishkumar67/PODFCSSV.git
# NOTE(review): the shell glob `*` does not match dot-files, so PODFCSSV/.git
# stays behind. On a re-run `git clone` typically refuses because PODFCSSV/
# already exists, and `mv` may fail for targets that already exist — the
# notebook then silently continues with the previously moved code.
!mv PODFCSSV/* /kaggle/working

Cloning into 'PODFCSSV'...
remote: Enumerating objects: 261, done.[K
remote: Counting objects: 100% (3/3), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 261 (delta 0), reused 2 (delta 0), pack-reused 258 (from 1)[K
Receiving objects: 100% (261/261), 4.00 MiB | 18.97 MiB/s, done.
Resolving deltas: 100% (112/112), done.


In [2]:
# =============================================================================
# IMPORT MODULES
# =============================================================================
import os
import sys
import logging
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

# The repository contents were moved into the current working directory; make
# that directory importable (only once) so the `src.*` imports below resolve.
_cwd = os.getcwd()
if _cwd not in sys.path:
    sys.path.append(_cwd)
del _cwd

# Project modules used by the pipeline cells below.
from src.client import ClientManager, FederatedClient
from src.loss import GPADLoss
from src.server import (
    FederatedModelServer,
    GlobalModel,
    GlobalPrototypeBank,
    run_server_round,
)

# Send INFO-level records (this notebook's logger and the `src.*` loggers) to
# the cell output, with timestamps and logger names.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("NotebookRequest")

print("Modules imported successfully.")


Modules imported successfully.


In [3]:
# =============================================================================
# CONFIGURATION
# =============================================================================
CONFIG = {
    "num_clients": 2,          # number of simulated federated clients
    "num_rounds": 3,           # federated communication rounds
    "batch_size": 8,           # per-client DataLoader batch size
    "embedding_dim": 768,      # ViT Base dim
    # Use at most 2 GPUs (this 2-client run maps one client per GPU); falls back
    # to however many are visible when fewer than 2 exist.
    "gpu_count": min(2, torch.cuda.device_count()),

    # Loss & Proto Params
    "merge_threshold": 0.85,
    "ema_alpha": 0.1,
    "gpad_base_tau": 0.5,
    "gpad_temp_gate": 0.1,
    "k_init_prototypes": 5,

    # Reproducibility: seed for every stochastic step of the notebook.
    "seed": 42,
}

# Seed the global RNGs used downstream: NumPy (subset sampling) and torch
# (random_split, DataLoader shuffling, adapter weight initialisation).
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])

logger.info(f"Starting Pipeline Test with Config: {CONFIG}")


2026-02-16 11:18:56,984 [INFO] NotebookRequest: Starting Pipeline Test with Config: {'num_clients': 2, 'num_rounds': 3, 'batch_size': 8, 'embedding_dim': 768, 'gpu_count': 2, 'merge_threshold': 0.85, 'ema_alpha': 0.1, 'gpad_base_tau': 0.5, 'gpad_temp_gate': 0.1, 'k_init_prototypes': 5}


In [4]:
# =============================================================================
# DATA PREPARATION (CIFAR10)
# =============================================================================

# Resize CIFAR-10 images to 224x224 (the resolution fed to the ViT-MAE backbone
# below) and map pixel values from [0, 1] to [-1, 1].
# NOTE(review): the pretrained ViT-MAE preprocessing normalizes with ImageNet
# mean/std rather than 0.5/0.5 — confirm this difference is intentional.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Small random subset of the training set so the smoke test runs quickly.
subset_size = 100
num_clients = CONFIG["num_clients"]
if num_clients < 1:
    raise ValueError(f"CONFIG['num_clients'] must be >= 1, got {num_clients}")
# Use the configured seed (fixed fallback) so Restart & Run All sees the same data.
data_seed = CONFIG.get("seed", 42)

# Load Dataset
try:
    logger.info("Loading CIFAR10...")
    full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    # Seeded sampling of `subset_size` distinct training images.
    rng = np.random.default_rng(data_seed)
    indices = rng.choice(len(full_dataset), subset_size, replace=False)
    subset = torch.utils.data.Subset(full_dataset, indices.tolist())

    # Near-equal split across ALL configured clients: the first
    # `subset_size % num_clients` clients receive one extra sample.
    base, extra = divmod(subset_size, num_clients)
    lengths = [base + int(i < extra) for i in range(num_clients)]
    client_datasets = torch.utils.data.random_split(
        subset, lengths, generator=torch.Generator().manual_seed(data_seed)
    )

    # One shuffled DataLoader per client; the pipeline pairs dataloaders[i] with client i.
    dataloaders = [
        DataLoader(
            client_ds,
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            generator=torch.Generator().manual_seed(data_seed + i),
        )
        for i, client_ds in enumerate(client_datasets)
    ]

    logger.info(f"Data Ready: {num_clients} Clients with sample counts {lengths}.")

except Exception as e:
    logger.error(f"Data loading failed: {e}")
    raise


2026-02-16 11:18:57,446 [INFO] NotebookRequest: Loading CIFAR10...
100%|██████████| 170M/170M [00:01<00:00, 104MB/s]  
2026-02-16 11:19:01,511 [INFO] NotebookRequest: Data Ready: 2 Clients with 50 samples each.


In [7]:
from __future__ import annotations
import torch
import torch.nn as nn
from typing import Optional, Tuple, Union, List, Any
from transformers import PreTrainedModel, ViTMAEForPreTraining, ViTMAEModel



class IBA_Adapter(nn.Module):
    """
    Information-Bottlenecked Adapter (IBA) module.

    This module implements a bottleneck architecture (Down-project -> Activation -> Up-project)
    inserted into frozen networks to introduce trainable parameters for efficient adaptation.

    Architecture:
        Input [B, L, D] -> Linear(D, d) -> Activation -> Linear(d, D) -> Dropout -> + Residual

    Key Design Principles:
        1. **Bottleneck**: Compresses information to force the model to learn efficient features.
        2. **Identity Initialization**: The up-projection weight and bias are initialized to
           zero, so the adapter branch outputs 0 and the module as a whole starts as the
           identity (IBA_Adapter(x) == x). The pre-trained backbone's behaviour is therefore
           unchanged at the start of training.

    Attributes:
        input_dim (int): Original hidden dimension.
        bottleneck_dim (int): Compressed dimension.
        down_project (nn.Linear): Dimensionality reduction layer.
        activation (nn.Module): Non-linear activation function (private to this adapter).
        up_project (nn.Linear): Dimensionality restoration layer.
        dropout (nn.Dropout): Regularization layer.
    """

    def __init__(
        self,
        input_dim: int,
        bottleneck_dim: int = 64,
        dropout: float = 0.0,
        activation: Optional[nn.Module] = None,
    ) -> None:
        """
        Initializes the IBA Adapter.

        Args:
            input_dim (int): The hidden dimension of the backbone model (e.g., 768 for ViT-Base).
            bottleneck_dim (int): The reduced dimension for the bottleneck. Lower values
                compress information more. Defaults to 64.
            dropout (float): Dropout probability applied after the up-projection. Defaults to 0.0.
            activation (Optional[nn.Module]): Activation used between the projections.
                Defaults to a new ``nn.GELU()`` per adapter.
        """
        super().__init__()
        self.input_dim = input_dim
        self.bottleneck_dim = bottleneck_dim
        # Build a fresh activation per adapter: a module instance used as a default
        # argument is created once at definition time and would be shared (and
        # registered as a submodule) by every adapter built without an explicit one.
        self.activation = activation if activation is not None else nn.GELU()

        # Down-projection: Compress semantic information
        self.down_project = nn.Linear(input_dim, bottleneck_dim)

        # Up-projection: Reconstruct features for the next layer
        self.up_project = nn.Linear(bottleneck_dim, input_dim)
        self.dropout = nn.Dropout(dropout)

        self._init_weights()

    def _init_weights(self) -> None:
        """
        Applies specific initialization strategies to ensure stable training start.
        """
        # 1. Kaiming-normal with a linear gain (std = 1/sqrt(input_dim), fan_in mode)
        #    for the down-projection weight; its bias keeps nn.Linear's default init.
        nn.init.kaiming_normal_(self.down_project.weight, nonlinearity='linear')

        # 2. Zeros for up_project, so the adapter branch initially outputs exactly 0
        #    and forward() returns the input unchanged (residual + 0).
        nn.init.zeros_(self.up_project.weight)
        if self.up_project.bias is not None:
            nn.init.zeros_(self.up_project.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass of the adapter.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is ``input_dim``
                (e.g. [Batch_Size, Seq_Len, Hidden_Dim]).

        Returns:
            torch.Tensor: Adapted features of the same shape as input.
        """
        residual = x

        # Bottleneck compression
        x = self.down_project(x)
        x = self.activation(x)

        # Note: Variational noise injection (e.g., for Zeus/V4 methods)
        # would typically be applied here if probabilistic modeling is desired.

        # Reconstruction & Regularization
        x = self.up_project(x)
        x = self.dropout(x)

        # Residual connection preserves original features while adding adaptation
        return residual + x

    def __repr__(self) -> str:
        """Custom string representation for easier debugging."""
        return f"IBA_Adapter(in={self.input_dim}, btl={self.bottleneck_dim})"


class ViTBlockWithAdapter(nn.Module):
    """
    Wrapper class to inject an Adapter after a Hugging Face ViT-style encoder layer.

    The wrapped block runs unchanged. Its hidden-state output is passed through the
    adapter and returned in the same container type the block produced:
    a tuple (first element adapted, the rest kept), an object exposing a
    ``hidden_states`` attribute (updated in place), or a bare tensor.
    """

    def __init__(self, original_block: nn.Module, adapter: "IBA_Adapter") -> None:
        """
        Args:
            original_block (nn.Module): The original, frozen Transformer block.
            adapter (IBA_Adapter): The trainable adapter instance.
        """
        super().__init__()
        self.original_block = original_block
        self.adapter = adapter

    def forward(
        self,
        hidden_states: torch.Tensor,
        head_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        **kwargs: Any
    ) -> Union[Tuple[torch.Tensor], Tuple[torch.Tensor, Any]]:
        """
        Forward pass matching the standard Hugging Face ViTLayer signature.

        Args:
            hidden_states (torch.Tensor): Input tensor.
            head_mask (Optional[torch.Tensor]): Mask for attention heads (forwarded).
            output_attentions (bool): Accepted for signature compatibility only. It is
                NOT forwarded to the wrapped block, so attention weights are returned
                only if the block emits them on its own.
            **kwargs: Additional arguments, forwarded unchanged to the wrapped block.

        Returns:
            The adapted hidden state, in the same container type the block returned.
            If the block returned an object whose ``hidden_states`` cannot be
            reassigned, a 1-tuple ``(adapted,)`` is returned instead.
        """
        # 1. Run the original frozen block. output_attentions is deliberately not
        #    passed on, because ViTMAE layers may not accept it.
        outputs = self.original_block(hidden_states, head_mask=head_mask, **kwargs)

        # 2./3./4. Adapt the hidden state and repackage it like the block's output.
        if isinstance(outputs, tuple):
            # (hidden_states, *extras): adapt the first element, keep the rest as-is.
            return (self.adapter(outputs[0]),) + outputs[1:]

        if hasattr(outputs, "hidden_states"):
            adapted = self.adapter(outputs.hidden_states)
            try:
                outputs.hidden_states = adapted
            except (AttributeError, TypeError):
                # Read-only container (e.g. a frozen dataclass): fall back to a
                # 1-tuple. Any extra fields of the container are dropped here.
                return (adapted,)
            return outputs

        # Bare tensor in, bare tensor out.
        return self.adapter(outputs)


def inject_adapters(model: PreTrainedModel, bottleneck_dim: int = 64) -> PreTrainedModel:
    """
    Injects IBA Adapters into the Encoder of a ViTMAE (or similar) model.

    This function performs the following operations:
    1. Freezes all existing parameters in the model.
    2. Identifies the Encoder layers (``model.vit.encoder`` or ``model.encoder``).
    3. Wraps each layer with `ViTBlockWithAdapter`. Layers that are already wrapped
       (e.g. the function is called twice on the same model) are left as they are
       instead of receiving a second, nested adapter.
    4. Re-enables gradients on every adapter, so ONLY adapter parameters remain
       trainable afterwards.

    The model is modified in place and also returned.

    Args:
        model (PreTrainedModel): The Hugging Face ViTMAE model instance.
        bottleneck_dim (int): Bottleneck dimension for newly created adapters
            (adapters that already exist keep their own size).

    Returns:
        PreTrainedModel: The modified model with adapters injected.

    Raises:
        AttributeError: If the model structure does not match standard ViT hierarchies.
    """
    print(f"\n{'='*60}")
    print("[System] Starting Adapter Injection Procedure")
    print(f"{'='*60}")

    # 1. Freeze the entire model (adapters are re-enabled in step 4).
    print("[Config] Freezing original backbone parameters...")
    for param in model.parameters():
        param.requires_grad = False

    # 2. Locate the Encoder; verify the structure to fail early with a clear message.
    if hasattr(model, "vit") and hasattr(model.vit, "encoder"):
        # Standard ViTMAE structure
        encoder = model.vit.encoder
    elif hasattr(model, "encoder") and hasattr(model.encoder, "layer"):
        # Generic BERT/ViT structure fallback
        encoder = model.encoder
    else:
        raise AttributeError(
            "Could not locate 'encoder.layer'. "
            "Model structure unknown (expected 'vit.encoder' or 'encoder')."
        )
    config = model.config

    input_dim = config.hidden_size
    num_layers = len(encoder.layer)

    print(f"[Config] Model Config: Hidden Dim={input_dim}, Layers={num_layers}")
    print(f"[Config] Adapter Config: Bottleneck Dim={bottleneck_dim}")

    # 3. Iterate and Replace
    print("[Action] Injecting adapters into encoder layers...")

    already_wrapped = 0
    for i, layer in enumerate(encoder.layer):
        if isinstance(layer, ViTBlockWithAdapter):
            # Injected by an earlier call: wrapping again would stack a second
            # wrapper and a second adapter on this layer.
            already_wrapped += 1
        else:
            adapter = IBA_Adapter(input_dim=input_dim, bottleneck_dim=bottleneck_dim)

            # Match the wrapped layer's device and dtype (model may already be on GPU
            # or in FP16/BF16). A parameter-less layer leaves the adapter as built.
            ref_param = next(layer.parameters(), None)
            if ref_param is not None:
                adapter.to(device=ref_param.device, dtype=ref_param.dtype)

            # Mutate the ModuleList in-place
            encoder.layer[i] = ViTBlockWithAdapter(original_block=layer, adapter=adapter)

        # Simple progress indicator for large models
        if (i + 1) % 4 == 0 or (i + 1) == num_layers:
            print(f"  -> Processed layer {i + 1}/{num_layers}")

    if already_wrapped:
        print(f"[System] {already_wrapped} layer(s) already had adapters; left unchanged.")

    # 4. Step 1 also froze adapters from any earlier injection, so explicitly make
    #    every adapter trainable (a no-op for the ones created just now).
    for wrapped_layer in encoder.layer:
        wrapped_layer.adapter.requires_grad_(True)

    print("[System] Injection Complete. Decoder layers ignored (if present).")

    # Verification of Trainable Parameters
    count_trainable_params(model)

    return model


def count_trainable_params(model: nn.Module) -> None:
    """
    Print an audit of total, frozen and trainable parameter counts for ``model``.

    Counts are numbers of scalar elements (``numel``), summed over
    ``model.parameters()``. The trainable ratio is reported as 0.00% for a model
    without parameters.

    Args:
        model (nn.Module): The model to audit.
    """
    n_total = 0
    n_trainable = 0
    for p in model.parameters():
        n_total += p.numel()
        if p.requires_grad:
            n_trainable += p.numel()
    n_frozen = n_total - n_trainable

    pct = 0 if n_total <= 0 else n_trainable / n_total * 100

    report = [
        "\n[Stats] Parameter Audit:",
        f"  - Total Parameters:     {n_total:,}",
        f"  - Frozen Backbone:      {n_frozen:,}",
        f"  - Trainable (Adapters): {n_trainable:,}",
        f"  - Trainable Ratio:      {pct:.2f}%",
        f"{'=' * 60}\n",
    ]
    print("\n".join(report))


# =============================================================================
# Main Execution Block (For Testing)
# =============================================================================
# Smoke test for adapter injection. Inside a notebook `__name__` is "__main__",
# so this runs (and downloads/loads a model) every time the cell is executed.
if __name__ == "__main__":
    print("[Main] Loading pre-trained ViTMAE...")
    try:
        # NOTE: Requires `pip install transformers`
        model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")

        # Inject Adapters
        model = inject_adapters(model, bottleneck_dim=64)

        # Sanity Check: Forward pass
        print("[Main] Running dummy forward pass to verify graph integrity...")
        dummy_input = torch.randn(1, 3, 224, 224)

        # Move inputs to same device as model
        device = next(model.parameters()).device
        dummy_input = dummy_input.to(device)

        output = model(dummy_input)

        # `hasattr` alone is not a sufficient guard: the output object may expose
        # a `loss` attribute that is None, and `None.item()` would raise.
        loss = getattr(output, "loss", None)
        loss_val = loss.item() if loss is not None else "N/A"
        print(f"[Success] Forward pass complete. Loss: {loss_val}")

        # Actually verify that gradients reach the trainable (adapter) parameters.
        if loss is not None:
            loss.backward()
            trainable = [p for p in model.parameters() if p.requires_grad]
            n_with_grad = sum(p.grad is not None for p in trainable)
            print(f"[Check] Trainable tensors receiving gradients: {n_with_grad}/{len(trainable)}")
            if trainable and n_with_grad == 0:
                print("[Error] No gradients reached the adapter parameters.")
            model.zero_grad(set_to_none=True)

    except ImportError:
        print("[Error] 'transformers' library not found. Please install it to run this test.")
    except Exception as e:
        print(f"[Error] An error occurred during execution: {e}")

[Main] Loading pre-trained ViTMAE...

[System] Starting Adapter Injection Procedure
[Config] Freezing original backbone parameters...
[Config] Model Config: Hidden Dim=768, Layers=12
[Config] Adapter Config: Bottleneck Dim=64
[Action] Injecting adapters into encoder layers...
  -> Processed layer 4/12
  -> Processed layer 8/12
  -> Processed layer 12/12
[System] Injection Complete. Decoder layers ignored (if present).

[Stats] Parameter Audit:
  - Total Parameters:     113,097,472
  - Frozen Backbone:      111,907,840
  - Trainable (Adapters): 1,189,632
  - Trainable Ratio:      1.05%

[Main] Running dummy forward pass to verify graph integrity...
[Success] Forward pass complete. Loss: 1.000892162322998


In [8]:
# =============================================================================
# INITIALIZE COMPONENTS
# =============================================================================

# 1. Global Prototype Bank (held on CPU)
proto_bank = GlobalPrototypeBank(
    embedding_dim=CONFIG["embedding_dim"],
    merge_threshold=CONFIG["merge_threshold"],
    ema_alpha=CONFIG["ema_alpha"],
    device="cpu",
)

# 2. Global Model Server (Aggregation)
fed_server = FederatedModelServer()

# 3. Base Model (ViT-MAE with Adapters)
# The adapter bottleneck width is read from CONFIG when present so it can be tuned
# in one place; 64 reproduces the previously hard-coded value.
adapter_bottleneck_dim = CONFIG.get("adapter_bottleneck_dim", 64)
try:
    from transformers import ViTMAEForPreTraining

    logger.info("Initializing ViT-MAE Backbone...")
    base_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    base_model = inject_adapters(base_model, bottleneck_dim=adapter_bottleneck_dim)
    logger.info("Adapters Injected Successfully.")

except ImportError:
    logger.warning("Transformers not found. Logic will fail unless mocked.")
    raise

# 4. Client Manager
client_manager = ClientManager(
    base_model=base_model,
    num_clients=CONFIG["num_clients"],
    gpu_count=CONFIG["gpu_count"],
)

# 5. GPAD Loss
gpad_loss = GPADLoss(
    base_tau=CONFIG["gpad_base_tau"],
    temp_gate=CONFIG["gpad_temp_gate"],
)


2026-02-16 11:20:09,628 [INFO] NotebookRequest: Initializing ViT-MAE Backbone...
2026-02-16 11:20:09,935 [INFO] NotebookRequest: Adapters Injected Successfully.
2026-02-16 11:20:09,935 [INFO] src.client: Initializing 2 clients...
2026-02-16 11:20:09,936 [INFO] src.client: Parallel Mode: Mapped 2 clients to 2 GPUs.



[System] Starting Adapter Injection Procedure
[Config] Freezing original backbone parameters...
[Config] Model Config: Hidden Dim=768, Layers=12
[Config] Adapter Config: Bottleneck Dim=64
[Action] Injecting adapters into encoder layers...
  -> Processed layer 4/12
  -> Processed layer 8/12
  -> Processed layer 12/12
[System] Injection Complete. Decoder layers ignored (if present).

[Stats] Parameter Audit:
  - Total Parameters:     113,097,472
  - Frozen Backbone:      111,907,840
  - Trainable (Adapters): 1,189,632
  - Trainable Ratio:      1.05%



2026-02-16 11:20:10,278 [INFO] src.client: Client 0 initialized on cuda:0
2026-02-16 11:20:10,552 [INFO] src.client: Client 1 initialized on cuda:1


In [9]:
# =============================================================================
# EXECUTE PIPELINE
# =============================================================================
# Server state carried from round to round; None until the first aggregation.
global_protos = None
global_weights = None

# Client i trains on, and builds prototypes from, dataloaders[i]. Fail fast if the
# data split and the client set disagree instead of silently mis-pairing them.
assert len(dataloaders) == len(client_manager.clients), (
    f"Got {len(dataloaders)} dataloaders for {len(client_manager.clients)} clients"
)

for round_idx in range(1, CONFIG["num_rounds"] + 1):
    logger.info(f"\n--- Starting Round {round_idx} / {CONFIG['num_rounds']} ---")

    # A. Broadcast the aggregated weights from the previous round (none in round 1).
    if global_weights is not None:
        logger.info("> Broadcasting Global Weights...")
        for client_idx, client in enumerate(client_manager.clients):
            # strict=False silently skips keys the model does not have; report them,
            # since those entries of the global model were NOT loaded.
            incompatible = client.model.load_state_dict(global_weights, strict=False)
            if incompatible.unexpected_keys:
                logger.warning(
                    f"  Client {client_idx}: ignored {len(incompatible.unexpected_keys)} "
                    f"unexpected keys in global weights "
                    f"(e.g. {incompatible.unexpected_keys[0]!r})"
                )

    # B. Client Training Step
    logger.info("> Clients Training...")
    losses = client_manager.train_round(
        dataloaders,
        global_prototypes=global_protos,
        gpad_loss_fn=gpad_loss,
    )
    logger.info(f"  Mean Batch Loss per Client: {losses}")

    # C. Extract Payloads (Protos + Weights)
    client_payloads = []
    logger.info("> Extracting Prototypes and Weights...")
    for i, client in enumerate(client_manager.clients):
        # Generate Local Prototypes (K-Means)
        local_protos = client.generate_prototypes(dataloaders[i], K_init=CONFIG["k_init_prototypes"])

        # Snapshot weights on CPU. copy=True guarantees a real copy: a plain `.cpu()`
        # returns the very same tensor when the client already runs on CPU, which
        # would alias the live model parameters inside the payload.
        # NOTE(review): the full state_dict (frozen backbone included) is shipped
        # every round although only adapter parameters are trainable.
        weights = {
            k: v.detach().to("cpu", copy=True)
            for k, v in client.model.state_dict().items()
        }

        client_payloads.append({
            'client_id': f"client_{i}",
            'protos': local_protos.cpu(),
            'weights': weights,
        })

    # D. Server Aggregation
    logger.info("> Server Aggregating...")
    server_result = run_server_round(
        proto_manager=proto_bank,
        model_server=fed_server,
        client_payloads=client_payloads,
    )

    global_protos = server_result['global_prototypes']
    global_weights = server_result['global_weights']

    proto_count = global_protos.shape[0] if global_protos is not None else 0
    logger.info(f"  Round Complete. Global Prototype Bank Size: {proto_count}")

logger.info("\n*** Pipeline Execution Finished Successfully ***")


2026-02-16 11:20:18,223 [INFO] NotebookRequest: 
--- Starting Round 1 / 3 ---
2026-02-16 11:20:18,224 [INFO] NotebookRequest: > Clients Training...
2026-02-16 11:20:18,224 [INFO] src.client: Spawning 2 training threads (1 per GPU)...
2026-02-16 11:20:19,565 [INFO] src.client: Client 0 (GPU 0) finished. Loss: 0.0140
2026-02-16 11:20:19,566 [INFO] src.client: Client 1 (GPU 1) finished. Loss: 0.0136
2026-02-16 11:20:19,567 [INFO] NotebookRequest:   Mean Batch Loss per Client: [0.013959189078637533, 0.013591428553419454]
2026-02-16 11:20:19,568 [INFO] NotebookRequest: > Extracting Prototypes and Weights...
2026-02-16 11:20:21,024 [INFO] NotebookRequest: > Server Aggregating...
2026-02-16 11:20:21,405 [INFO] NotebookRequest:   Round Complete. Global Prototype Bank Size: 3
2026-02-16 11:20:21,406 [INFO] NotebookRequest: 
--- Starting Round 2 / 3 ---
2026-02-16 11:20:21,407 [INFO] NotebookRequest: > Broadcasting Global Weights...
2026-02-16 11:20:21,652 [INFO] NotebookRequest: > Clients Train