In [None]:
# =============================================================================
# IMPORT MODULES
# =============================================================================
import logging
import os
import random
import sys

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets

# Ensure local modules (the project's `src` package) are importable.
# NOTE(review): assumes the notebook is started from the project root — confirm.
if os.getcwd() not in sys.path:
    sys.path.append(os.getcwd())

# Explicit Imports from Project
from src.client import ClientManager, FederatedClient
from src.loss import GPADLoss
from src.server import GlobalPrototypeBank, FederatedModelServer, run_server_round, GlobalModel

# Configure Logging.
# basicConfig is a silent no-op when the root logger already has a handler
# (common inside Jupyter kernels or after some library imports), in which case
# none of the INFO messages below would be shown. force=True (Python 3.8+)
# replaces any existing root handlers, and also makes re-running this cell
# idempotent instead of stacking configuration.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    force=True,
)
logger = logging.getLogger("NotebookRequest")

print("Modules imported successfully.")


In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
CONFIG = {
    "num_clients": 2,
    "num_rounds": 3,
    "batch_size": 8,
    "embedding_dim": 768,      # ViT Base dim
    # Use at most 2 GPUs (fewer if fewer are available; 0 on CPU-only hosts).
    "gpu_count": min(2, torch.cuda.device_count()),

    # Loss & Proto Params
    "merge_threshold": 0.85,
    "ema_alpha": 0.1,
    "gpad_base_tau": 0.5,
    "gpad_temp_gate": 0.1,
    "k_init_prototypes": 5,

    # Reproducibility: single seed for every RNG the pipeline touches.
    "seed": 42,
}

# Seed all RNGs up front so Restart Kernel -> Run All reproduces the same
# subset, shuffling order, and weight/prototype initialisation.
random.seed(CONFIG["seed"])
np.random.seed(CONFIG["seed"])
torch.manual_seed(CONFIG["seed"])  # seeds the CPU and all CUDA devices

logger.info(f"Starting Pipeline Test with Config: {CONFIG}")


In [None]:
# =============================================================================
# DATA PREPARATION (CIFAR10)
# =============================================================================

# Define Transforms (Resize to 224 for ViT)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Use a small subset for speed; increase for a more realistic run.
SUBSET_SIZE = 100
# Seed for subset sampling, the client split and per-client shuffling, so the
# data partition is reproducible. Falls back to 42 if CONFIG has no "seed".
DATA_SEED = CONFIG.get("seed", 42)

# Load Dataset
try:
    logger.info("Loading CIFAR10...")
    full_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

    num_clients = CONFIG["num_clients"]
    if not 0 < num_clients <= SUBSET_SIZE:
        raise ValueError(
            f"num_clients must be in [1, {SUBSET_SIZE}], got {num_clients}"
        )

    # Seeded sampling of the subset indices (plain ints for Subset).
    rng = np.random.default_rng(DATA_SEED)
    indices = rng.choice(len(full_dataset), SUBSET_SIZE, replace=False).tolist()
    subset = torch.utils.data.Subset(full_dataset, indices)

    # Split as evenly as possible into one shard per client (the first
    # `remainder` shards get one extra sample), so the number of dataloaders
    # always matches CONFIG["num_clients"].
    base, remainder = divmod(SUBSET_SIZE, num_clients)
    lengths = [base + (1 if idx < remainder else 0) for idx in range(num_clients)]
    client_datasets = torch.utils.data.random_split(
        subset, lengths, generator=torch.Generator().manual_seed(DATA_SEED)
    )

    # A dedicated generator per loader makes each client's shuffle order
    # reproducible and independent of other global-RNG consumers.
    dataloaders = [
        DataLoader(
            client_ds,
            batch_size=CONFIG["batch_size"],
            shuffle=True,
            generator=torch.Generator().manual_seed(DATA_SEED + idx),
        )
        for idx, client_ds in enumerate(client_datasets)
    ]

    logger.info(
        f"Data Ready: {num_clients} Clients with sample counts "
        f"{[len(client_ds) for client_ds in client_datasets]}."
    )

except Exception as e:
    logger.error(f"Data loading failed: {e}")
    raise


In [None]:
# =============================================================================
# INITIALIZE COMPONENTS
# =============================================================================

# 1. Global Prototype Bank (kept on CPU)
proto_bank = GlobalPrototypeBank(
    embedding_dim=CONFIG["embedding_dim"],
    merge_threshold=CONFIG["merge_threshold"],
    ema_alpha=CONFIG["ema_alpha"],
    device="cpu"
)

# 2. Global Model Server (Aggregation)
fed_server = FederatedModelServer()

# 3. Base Model (ViT-MAE with Adapters)
try:
    from transformers import ViTMAEForPreTraining
    from src.mae_with_adapter import inject_adapters

    logger.info("Initializing ViT-MAE Backbone...")
    base_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    base_model = inject_adapters(base_model, bottleneck_dim=64)
    logger.info("Adapters Injected Successfully.")

except ImportError as exc:
    # Either `transformers` or the project's adapter module is missing; the
    # pipeline cannot continue without a backbone, so report the actual cause
    # and re-raise.
    logger.error(
        f"Could not import the ViT-MAE backbone or adapter injection ({exc}). "
        "Install `transformers` and make sure `src/mae_with_adapter.py` is importable."
    )
    raise

# 4. Client Manager
client_manager = ClientManager(
    base_model=base_model,
    num_clients=CONFIG["num_clients"],
    gpu_count=CONFIG["gpu_count"]
)

# 5. GPAD Loss
gpad_loss = GPADLoss(
    base_tau=CONFIG["gpad_base_tau"],
    temp_gate=CONFIG["gpad_temp_gate"]
)


In [None]:
# =============================================================================
# EXECUTE PIPELINE
# =============================================================================
# Client i trains on, and builds prototypes from, the loader at the same index;
# fail fast if the two lists are out of step instead of crashing mid-round.
if len(dataloaders) != len(client_manager.clients):
    raise ValueError(
        f"Expected one dataloader per client, got {len(dataloaders)} dataloaders "
        f"for {len(client_manager.clients)} clients."
    )

global_protos = None
global_weights = None

for round_idx in range(1, CONFIG["num_rounds"] + 1):
    logger.info(f"\n--- Starting Round {round_idx} / {CONFIG['num_rounds']} ---")

    # A. Broadcast Global Weights (if exists)
    if global_weights is not None:
        logger.info("> Broadcasting Global Weights...")
        for i, client in enumerate(client_manager.clients):
            # strict=False tolerates key mismatches; surface them so a partially
            # applied global model does not go unnoticed.
            load_result = client.model.load_state_dict(global_weights, strict=False)
            if load_result is not None and (load_result.missing_keys or load_result.unexpected_keys):
                logger.warning(
                    f"  client_{i}: {len(load_result.missing_keys)} missing / "
                    f"{len(load_result.unexpected_keys)} unexpected keys when loading global weights"
                )

    # B. Client Training Step
    logger.info("> Clients Training...")
    losses = client_manager.train_round(
        dataloaders,
        global_prototypes=global_protos,
        gpad_loss_fn=gpad_loss
    )
    logger.info(f"  Mean Batch Loss per Client: {losses}")

    # C. Extract Payloads (Protos + Weights)
    client_payloads = []
    logger.info("> Extracting Prototypes and Weights...")
    for i, (client, loader) in enumerate(zip(client_manager.clients, dataloaders)):
        # Generate Local Prototypes (K-Means)
        local_protos = client.generate_prototypes(loader, K_init=CONFIG["k_init_prototypes"])

        # Get Weights (moved to CPU for aggregation)
        weights = {k: v.cpu() for k, v in client.model.state_dict().items()}

        client_payloads.append({
            'client_id': f"client_{i}",
            'protos': local_protos.cpu(),
            'weights': weights
        })

    # D. Server Aggregation
    logger.info("> Server Aggregating...")
    server_result = run_server_round(
        proto_manager=proto_bank,
        model_server=fed_server,
        client_payloads=client_payloads
    )

    global_protos = server_result['global_prototypes']
    global_weights = server_result['global_weights']

    proto_count = global_protos.shape[0] if global_protos is not None else 0
    logger.info(f"  Round Complete. Global Prototype Bank Size: {proto_count}")

logger.info("\n*** Pipeline Execution Finished Successfully ***")
