In [29]:
# =============================================================================
# 0. ENVIRONMENT DIAGNOSTICS & SETUP
# =============================================================================
import os
import sys

# Ensure the current directory is in SYS.PATH
cwd = os.getcwd()
if cwd not in sys.path:
    sys.path.append(cwd)

print(f"Current Working Directory: {cwd}")

# Verify 'src' folder exists
if os.path.isdir('src'):
    print("'src' directory found.")
    print(f"Contents: {os.listdir('src')}")
else:
    print("ERROR: 'src' directory NOT found! Make sure you run this notebook from the project root.")

# Autoreload (Useful if editing files)
%load_ext autoreload
%autoreload 2

print("Environment setup complete.")


Current Working Directory: /kaggle/working
'src' directory found.
Contents: ['__init__.py', 'mae_with_adapter.py', 'server.py', 'loss.py', 'client.py']


ModuleNotFoundError: No module named 'imp'

In [30]:
# =============================================================================
# 1. IMPORTS (Explicit & Verbose)
# =============================================================================
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
import numpy as np
import logging

# Local Imports
try:
    from src.client import ClientManager, FederatedClient
    print("SUCCESS: Imported src.client")
    
    from src.loss import GPADLoss
    print("SUCCESS: Imported src.loss")
    
    from src.server import GlobalPrototypeBank, FederatedModelServer, run_server_round, GlobalModel
    print("SUCCESS: Imported src.server (GlobalPrototypeBank found)")
    
except ImportError as e:
    print(f"CRITICAL IMPORT ERROR: {e}")
    print("Please check that 'src' is a valid package (contains __init__.py) and in the path.")
    # Attempt to debug src package location
    try:
        import src
        print(f"src package location: {src.__file__}")
    except:
        pass


SUCCESS: Imported src.client
SUCCESS: Imported src.loss
CRITICAL IMPORT ERROR: cannot import name 'GlobalPrototypeBank' from 'src.server' (/kaggle/working/src/server.py)
Please check that 'src' is a valid package (contains __init__.py) and in the path.
src package location: /kaggle/working/src/__init__.py


In [32]:
# =============================================================================
# 2. CONFIGURATION & LOGGING
# =============================================================================
# Configure Logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s")
logger = logging.getLogger("NotebookTest")

CONFIG = {
    "num_clients": 2,          
    "num_rounds": 3,           
    "batch_size": 4,           
    "embedding_dim": 768,      # ViT Base dim
    "gpu_count": 2 if torch.cuda.device_count() >= 2 else torch.cuda.device_count(),
    
    # Loss & Proto Params
    "merge_threshold": 0.85,
    "ema_alpha": 0.1,
    "gpad_base_tau": 0.5,
    "gpad_temp_gate": 0.1,
    "k_init_prototypes": 5,
}

logger.info(f"Running Test with Config: {CONFIG}")


2026-02-16 10:09:25,040 [INFO] NotebookTest: Running Test with Config: {'num_clients': 2, 'num_rounds': 3, 'batch_size': 4, 'embedding_dim': 768, 'gpu_count': 2, 'merge_threshold': 0.85, 'ema_alpha': 0.1, 'gpad_base_tau': 0.5, 'gpad_temp_gate': 0.1, 'k_init_prototypes': 5}


In [33]:
# =============================================================================
# 3. DATA PREPARATION (Using Torchvision / Mock)
# =============================================================================
class FakeImageDataset(Dataset):
    def __init__(self, size=100, dim=224):
        self.data = torch.randn(size, 3, dim, dim)
        self.targets = torch.randint(0, 10, (size,))
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.targets[idx]

# Try real data, fallback to fake
try:
    logger.info("Attempting to load CIFAR10 for testing...")
    transform = transforms.Compose([
        transforms.Resize((224, 224)), # ViT expects 224
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])
    # subset
    # Use download=True if strict internet access. 
    full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    # Take just 40 samples per client for speed
    indices = np.random.choice(len(full_dataset), 40, replace=False)
    subset = torch.utils.data.Subset(full_dataset, indices)
    logger.info("CIFAR10 loaded successfully.")
    
except Exception as e:
    logger.warning(f"Could not load CIFAR10 ({e}). Using Fake Data.")
    subset = FakeImageDataset(size=40)

# Split data for 2 clients
cutoff = len(subset) // 2
ds1, ds2 = torch.utils.data.random_split(subset, [cutoff, len(subset)-cutoff])

dl1 = DataLoader(ds1, batch_size=CONFIG["batch_size"], shuffle=True)
dl2 = DataLoader(ds2, batch_size=CONFIG["batch_size"], shuffle=True)
dataloaders = [dl1, dl2]

logger.info(f"Data Prepared: 2 Clients with {len(ds1)} samples each.")


2026-02-16 10:09:28,427 [INFO] NotebookTest: Attempting to load CIFAR10 for testing...
2026-02-16 10:09:29,378 [INFO] NotebookTest: CIFAR10 loaded successfully.
2026-02-16 10:09:29,381 [INFO] NotebookTest: Data Prepared: 2 Clients with 20 samples each.


In [34]:
# =============================================================================
# 4. COMPONENT INITIALIZATION
# =============================================================================

# A. Global Prototype Bank
# Ensure GlobalPrototypeBank is defined in current scope (from imports cell)
if 'GlobalPrototypeBank' not in locals():
    raise NameError("GlobalPrototypeBank not found! Did the Imports cell run successfully?")

proto_bank = GlobalPrototypeBank(
    embedding_dim=CONFIG["embedding_dim"],
    merge_threshold=CONFIG["merge_threshold"],
    ema_alpha=CONFIG["ema_alpha"],
    device="cpu" 
)

# B. Server Aggregator
fed_server = FederatedModelServer()

# C. Components - Model (Using real ViT if possible, handled by library import check in code)
try:
    from transformers import ViTMAEForPreTraining
    from src.mae_with_adapter import inject_adapters
    
    logger.info("Loading Base ViT-MAE Model...")
    base_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    base_model = inject_adapters(base_model, bottleneck_dim=64)
    logger.info("Base Model Ready with Adapters.")
    
except ImportError:
    logger.error("Transformers library missing! Can't load ViT-MAE. Using Mock fallback.")
    class MockViTMAE(nn.Module):
        def __init__(self, dim=768):
            super().__init__()
            self.encoder = nn.Linear(dim, dim) 
            self.config = type('Config', (), {'hidden_size': dim})()
        def forward(self, x, output_hidden_states=False, **kwargs):
            B = x.size(0)
            feat = torch.randn(B, 768).to(x.device)
            class Output: pass
            out = Output()
            out.loss = feat.abs().mean()
            out.hidden_states = [feat.unsqueeze(1)]
            return out
    base_model = MockViTMAE()

# D. Client Manager
client_manager = ClientManager(
    base_model=base_model,
    num_clients=CONFIG["num_clients"],
    gpu_count=CONFIG["gpu_count"]
)

# E. Loss Function
gpad_loss = GPADLoss(
    base_tau=CONFIG["gpad_base_tau"],
    temp_gate=CONFIG["gpad_temp_gate"]
)


NameError: GlobalPrototypeBank not found! Did the Imports cell run successfully?

In [None]:
# =============================================================================
# 5. PIPELINE EXECUTION
# =============================================================================
global_protos = None
global_weights = None

for round_idx in range(1, CONFIG["num_rounds"] + 1):
    logger.info(f"\n--- Starting Round {round_idx} ---")
    
    # 1. Update Clients with Global State (if available)
    if global_weights is not None:
        logger.info("Broadcasting Global Weights to Clients...")
        for client in client_manager.clients:
            client.model.load_state_dict(global_weights, strict=False)
            
    # 2. Train
    logger.info("Clients Training...")
    losses = client_manager.train_round(
        dataloaders,
        global_prototypes=global_protos,
        gpad_loss_fn=gpad_loss
    )
    logger.info(f"Round Losses: {losses}")
    
    # 3. Extract Prototypes & Weights
    client_payloads = []
    logger.info("Extracting Client Payloads...")
    for i, client in enumerate(client_manager.clients):
        local_protos = client.generate_prototypes(dataloaders[i], K_init=CONFIG["k_init_prototypes"])
        weights = {k: v.cpu() for k, v in client.model.state_dict().items()}
        
        client_payloads.append({
            'client_id': f"client_{i}",
            'protos': local_protos.cpu(),
            'weights': weights
        })
        
    # 4. Server Aggregation
    logger.info("Server Aggregating...")
    server_result = run_server_round(
        proto_manager=proto_bank,
        model_server=fed_server,
        client_payloads=client_payloads
    )
    
    global_protos = server_result['global_prototypes']
    global_weights = server_result['global_weights']
    
    count = global_protos.shape[0] if global_protos is not None else 0
    logger.info(f"Round {round_idx} Complete. Global Prototypes: {count}")

logger.info("\nPipeline Test Finished Successfully!")
