In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from torchvision.models import vit_b_16, ViT_B_16_Weights
import numpy as np
import logging
from typing import List, Dict, Any
import copy

# --- Fix Imports ---
import sys
import os
sys.path.append(os.getcwd())  # Ensure local modules are found

from src.server import GlobalPrototypeBank, FederatedModelServer, run_server_round, GlobalModel
from src.client import ClientManager, FederatedClient
from src.loss import GPADLoss

# Logging setup: INFO-level records with timestamp, level and logger name,
# emitted through a logger dedicated to this notebook test.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
)
logger = logging.getLogger("NotebookTest")

# =============================================================================
# 1. SETUP & CONFIGURATION
# =============================================================================
# Single source of settings for the whole test run. Key order is kept stable
# because the dict is printed verbatim in the log line below.
CONFIG = dict(
    num_clients=2,        # number of simulated clients handed to ClientManager
    num_rounds=3,         # how many federated rounds the pipeline loop executes
    batch_size=4,         # DataLoader batch size (kept small for a quick test)
    embedding_dim=768,    # ViT Base hidden size, passed to the prototype bank
    # Use at most two GPUs; fewer (possibly zero) when the machine has fewer.
    gpu_count=min(2, torch.cuda.device_count()),

    # Loss & prototype parameters (forwarded to GlobalPrototypeBank, GPADLoss
    # and generate_prototypes respectively).
    merge_threshold=0.85,
    ema_alpha=0.1,
    gpad_base_tau=0.5,
    gpad_temp_gate=0.1,
    k_init_prototypes=5,
)

logger.info(f"Running Test with Config: {CONFIG}")

# =============================================================================
# 2. DATA PREPARATION (Using Torchvision / Mock)
# =============================================================================
# Use a tiny random subset of CIFAR10 so the test runs quickly even when the download is slow,
# and fall back to randomly generated fake data if CIFAR10 cannot be loaded.

class FakeImageDataset(Dataset):
    """Synthetic dataset of random images with random integer labels.

    All samples are generated once at construction time:

    * ``data``    -- tensor of shape ``(size, 3, dim, dim)`` drawn from a
      standard normal distribution.
    * ``targets`` -- tensor of shape ``(size,)`` with integer labels drawn
      uniformly from ``[0, 10)``.

    Args:
        size: Number of samples to generate.
        dim: Height and width of each square, 3-channel image.
    """

    def __init__(self, size=100, dim=224):
        image_shape = (size, 3, dim, dim)
        # Images are drawn before labels; keep this order so seeded runs
        # reproduce the same tensors.
        self.data = torch.randn(*image_shape)
        self.targets = torch.randint(low=0, high=10, size=(size,))

    def __len__(self):
        """Return the number of generated samples."""
        return self.data.shape[0]

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at position ``idx``."""
        image = self.data[idx]
        label = self.targets[idx]
        return image, label

# Prefer real CIFAR10 images; on any failure (e.g. the download cannot
# complete) log a warning and substitute random fake images so the pipeline
# test can still run.
try:
    logger.info("Attempting to load CIFAR10 for testing...")
    # Preprocessing: upscale to the 224x224 input ViT expects, convert to a
    # tensor, then normalize with mean 0.5 / std 0.5.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,)),
        ]
    )
    # download=True fetches the dataset into ./data when it is not already there.
    full_dataset = datasets.CIFAR10(
        root='./data',
        train=True,
        download=True,
        transform=transform,
    )
    # Draw 40 distinct samples in total for speed; this pool is what gets
    # divided between the clients afterwards.
    indices = np.random.choice(a=len(full_dataset), size=40, replace=False)
    subset = torch.utils.data.Subset(dataset=full_dataset, indices=indices)
    logger.info("CIFAR10 loaded successfully.")

except Exception as exc:
    logger.warning(f"Could not load CIFAR10 ({exc}). Using Fake Data.")
    subset = FakeImageDataset(size=40)

# Split the sampled data into one shard (and one DataLoader) per configured
# client. The shard count follows CONFIG["num_clients"] so that `dataloaders`
# always lines up with the clients created from the same setting; a hard-coded
# two-way split would leave later `dataloaders[i]` lookups out of range for any
# other client count.
num_clients = CONFIG["num_clients"]
if num_clients < 1 or len(subset) < num_clients:
    raise ValueError(
        f"Cannot split {len(subset)} samples across {num_clients} clients."
    )

# Shards are as even as possible; any remainder goes to the trailing shards,
# so the two-client case yields [n // 2, n - n // 2].
shard_size, leftover = divmod(len(subset), num_clients)
split_sizes = [
    shard_size + (1 if shard_idx >= num_clients - leftover else 0)
    for shard_idx in range(num_clients)
]
client_datasets = torch.utils.data.random_split(subset, split_sizes)

# One shuffling loader per client, in client order.
dataloaders = [
    DataLoader(client_ds, batch_size=CONFIG["batch_size"], shuffle=True)
    for client_ds in client_datasets
]

logger.info(f"Data Prepared: {num_clients} Clients with shard sizes {split_sizes}.")

# =============================================================================
# 3. INITIALIZE COMPONENTS
# =============================================================================

# A. Global Prototype Bank
# Constructed on the CPU; the embedding width and the merge / EMA settings are
# read from CONFIG under the same names the constructor expects.
proto_bank = GlobalPrototypeBank(
    device="cpu",
    **{
        setting: CONFIG[setting]
        for setting in ("embedding_dim", "merge_threshold", "ema_alpha")
    },
)

# B. Server Aggregator (built with its default arguments)
fed_server = FederatedModelServer()

# C. Base model
# Prefer the real pretrained ViT-MAE with adapters injected by
# `src.mae_with_adapter`. If either import is unavailable, fall back to a
# small mock that exposes the same attributes this script relies on
# (`config.hidden_size`, `.loss`, `.hidden_states`).
# NOTE(review): only ImportError triggers the fallback; a failed weight
# download in `from_pretrained` still propagates -- confirm that is intended.
try:
    from transformers import ViTMAEForPreTraining
    # Adapter injection logic lives in `src.mae_with_adapter`.
    from src.mae_with_adapter import inject_adapters

    logger.info("Loading Base ViT-MAE Model...")
    base_model = ViTMAEForPreTraining.from_pretrained("facebook/vit-mae-base")
    base_model = inject_adapters(base_model, bottleneck_dim=64)
    logger.info("Base Model Ready with Adapters.")

except ImportError:
    logger.error("Transformers library missing! Can't load ViT-MAE. Using Mock fallback.")
    from types import SimpleNamespace

    class MockViTMAE(nn.Module):
        """Minimal stand-in for ViT-MAE used when the real model is unavailable.

        The input image content is ignored: each forward pass draws random
        features of width ``dim`` and passes them through a single linear
        layer, so the returned loss and hidden states depend on trainable
        parameters and can be back-propagated.

        Args:
            dim: Feature width, also exposed as ``config.hidden_size``.
        """

        def __init__(self, dim=768):
            super().__init__()
            self.encoder = nn.Linear(dim, dim)
            self.config = SimpleNamespace(hidden_size=dim)

        def forward(self, x, output_hidden_states=False, **kwargs):
            """Return an object with ``loss`` (scalar) and ``hidden_states``.

            Only the batch size and device of ``x`` are used.
            ``hidden_states`` is a one-element list holding a
            ``(batch, 1, dim)`` tensor.
            """
            batch_size = x.size(0)
            # Use the configured width rather than a fixed 768, and create the
            # noise directly on the input's device.
            noise = torch.randn(batch_size, self.config.hidden_size, device=x.device)
            # Route through the encoder so the outputs carry gradients; a loss
            # built from raw noise has no grad_fn and cannot be back-propagated.
            feat = self.encoder(noise)
            return SimpleNamespace(
                loss=feat.abs().mean(),
                hidden_states=[feat.unsqueeze(1)],
            )

    base_model = MockViTMAE()

# D. Client Manager
# Receives the shared base model together with the client count and GPU count
# from CONFIG.
client_manager = ClientManager(
    base_model=base_model,
    **{setting: CONFIG[setting] for setting in ("num_clients", "gpu_count")},
)

# E. Loss Function
# GPAD loss configured from the `gpad_*` entries of CONFIG.
gpad_loss = GPADLoss(base_tau=CONFIG["gpad_base_tau"], temp_gate=CONFIG["gpad_temp_gate"])

# =============================================================================
# 4. RUN PIPELINE (3 Rounds)
# =============================================================================
# Federated simulation driver. Each round:
#   1. pushes the previous round's aggregated weights to every client,
#   2. runs one round of local training on all clients,
#   3. collects each client's prototypes and model weights,
#   4. hands the payloads to the server for aggregation.
# Both globals start as None, so round 1 trains without any global state.
global_protos = None
global_weights = None

for round_idx in range(1, CONFIG["num_rounds"] + 1):
    logger.info(f"\n--- Starting Round {round_idx} ---")

    # 1. Update Clients with Global State (if available)
    if global_weights is not None:
        logger.info("Broadcasting Global Weights to Clients...")
        # In simulation: manually load into each client.
        for client in client_manager.clients:
            # strict=False tolerates key mismatches between the aggregated
            # state dict and the client model instead of raising.
            client.model.load_state_dict(global_weights, strict=False)

    # 2. Train
    logger.info("Clients Training...")
    losses = client_manager.train_round(
        dataloaders,
        global_prototypes=global_protos,
        gpad_loss_fn=gpad_loss
    )
    logger.info(f"Round Losses: {losses}")

    # 3. Extract Prototypes & Weights
    client_payloads = []
    logger.info("Extracting Client Payloads...")
    for i, client in enumerate(client_manager.clients):
        # Prototypes are generated from the same loader the client trained on.
        local_protos = client.generate_prototypes(dataloaders[i], K_init=CONFIG["k_init_prototypes"])

        # Weights (CPU snapshot for aggregation). `.cpu()` is a no-op for
        # tensors already on the CPU and would hand the server the client's
        # live parameter storage; `copy=True` always produces an independent
        # tensor regardless of the source device.
        weights = {
            k: v.detach().to("cpu", copy=True)
            for k, v in client.model.state_dict().items()
        }

        client_payloads.append({
            'client_id': f"client_{i}",
            'protos': local_protos.cpu(),
            'weights': weights
        })

    # 4. Server Aggregation
    logger.info("Server Aggregating...")
    server_result = run_server_round(
        proto_manager=proto_bank,
        model_server=fed_server,
        client_payloads=client_payloads
    )

    # Carry the aggregated state into the next round.
    global_protos = server_result['global_prototypes']
    global_weights = server_result['global_weights']

    # The server may return no prototypes; report that as a count of zero.
    if global_protos is not None:
        logger.info(f"Round {round_idx} Complete. Global Prototypes: {global_protos.shape[0]}")
    else:
        logger.info(f"Round {round_idx} Complete. Global Prototypes: 0")

logger.info("\nPipeline Test Finished Successfully!")
