# In [None]:
# Install Deps
# !pip install -r requirements.txt

from data_manager import DataManager
from core import FederatedClient, Server, AlignmentManager
import torch

# 1. Setup Data
dm = DataManager()
datasets = dm.setup_mini_drake()

# 2. Setup Server
# Single source of truth for the federation size. The literal 10 used to be
# repeated in Server(...) and range(...), which could silently drift apart.
NUM_CLIENTS = 10
# Clients with an id below this threshold get the "small" model, the rest "large".
NUM_SMALL_CLIENTS = NUM_CLIENTS // 2
server = Server(num_clients=NUM_CLIENTS)

# 3. Setup Clients (Heterogeneous)
# Clients 0 .. NUM_SMALL_CLIENTS-1: small model
# Clients NUM_SMALL_CLIENTS .. NUM_CLIENTS-1: large model
# NOTE(review): assumes `datasets` is indexable by client id with at least
# NUM_CLIENTS entries — confirm against DataManager.setup_mini_drake.
clients = []
for i in range(NUM_CLIENTS):
    model_type = "small" if i < NUM_SMALL_CLIENTS else "large"
    client = FederatedClient(client_id=i, model_type=model_type, dataset=datasets[i])
    clients.append(client)
    server.register_client(client)

# 4. Alignment Phase (One-time)
# NOTE(review): align_models() is called without arguments, so the aligner is
# not handed the clients/models created above — confirm it does not need them.
aligner = AlignmentManager()
aligner.align_models()

# 5. Federated Training Loop
NUM_ROUNDS = 3
# Per-round evaluation results; stays empty until evaluation is wired in below.
history = []

print(f"\n--- Starting Simulation for {NUM_ROUNDS} Rounds ---")
for r in range(NUM_ROUNDS):
    print(f"\n>>> Round {r+1} <<<")

    # Train each client sequentially (to save RAM), handing every client the
    # current global parameters. Aggregation happens only after all clients
    # have trained, so every client in a round sees the same globals.
    updates = [
        client.train(server.global_P, server.global_Q) for client in clients
    ]

    # Server Aggregation
    server.aggregate(updates)

    # TODO: evaluate the freshly aggregated global model and record it, e.g.
    #     history.append(evaluate_on_unseen())
    # (evaluate_on_unseen is not defined in this file yet.)

print("\nSimulation Complete.")