# Chorus: Federated Training (Non-IID) — Part 2

**Run `chorus_centralized_baseline.ipynb` in parallel** — it trains the centralized gold standard.
This notebook runs the federated loop with non-IID label-skewed clients, then loads centralized
results at the end for the full comparison.

| Model | Training Data | Data Distribution | Purpose |
|-------|--------------|-------------------|----------|
| **Client-only** (×3) | ~670 on average (shard sizes vary) | Non-IID (label skewed) | What each org gets alone |
| **Federated** (3 rounds) | ~670 on average per client, aggregated | Non-IID (label skewed) | Chorus + FedEx-LoRA |

| Setting | Value |
|---------|-------|
| Model | `Qwen/Qwen2.5-0.5B` (490M params, BASE) |
| Dataset | AG News — 2K train, 500 test, 4 classes |
| Non-IID | Dirichlet(α=0.3) label skew across 3 clients |
| LoRA | rank=16, alpha=32, 1 epoch |
| Aggregation | FedEx-LoRA (SVD-optimal) |

In [None]:
!pip install -q pyarrow==17.0.0
!pip install -q --upgrade typing_extensions pydantic pydantic-core
!pip install -q 'chorus[peft] @ git+https://github.com/varmabudharaju/chorus.git'
!pip install -q scikit-learn

In [None]:
import torch, gc, os, time, json, logging, random
import numpy as np
from collections import Counter

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(name)s] %(message)s",
    datefmt="%H:%M:%S",
)

# Working directory for shards, adapters and results. Colab exposes /content;
# anywhere else (SageMaker / local) falls back to the user's home directory.
DATA_DIR = (
    "/content/chorus_exp"
    if os.path.exists("/content")
    else os.path.expanduser("~/chorus_exp")
)
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Data directory: {DATA_DIR}")

# Hyperparameters. Keep these identical to the centralized notebook so the
# two runs stay comparable.
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
DATASET_SIZE = 2000         # training examples kept (120K available)
TEST_SIZE = 500             # test examples kept for evaluation
NUM_CLIENTS = 3
NUM_ROUNDS = 3
LORA_RANK = 16
LORA_ALPHA = 32
LEARNING_RATE = 3e-4
NUM_EPOCHS = 1
BATCH_SIZE = 8
GRAD_ACCUM = 2
MAX_SEQ_LEN = 128
DIRICHLET_ALPHA = 0.3       # smaller value = stronger label skew
SEED = 42

# Seed every RNG the notebook touches with the same value.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device != "cuda":
    print("WARNING: No GPU! Runtime → Change runtime type → T4 GPU")
else:
    gpu = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu.name} | VRAM: {gpu.total_memory / 1e9:.1f} GB")

## Step 1: Dataset — AG News

- 4 classes: World (0), Sports (1), Business (2), Sci/Tech (3)
- Subsample 2K from the 120K training set (same subset as the centralized notebook)
- Subsample 500 from the 7.6K test split for evaluation

In [None]:
from datasets import load_dataset, Dataset

# AG News class names listed in label-id order; LABEL_MAP is the id → name lookup.
LABEL_NAMES = ["world", "sports", "business", "scitech"]
LABEL_MAP = dict(enumerate(LABEL_NAMES))

# Shuffle with the shared seed, then keep the first N rows of each split.
raw_train = load_dataset("ag_news", split="train").shuffle(seed=SEED)
raw_test = load_dataset("ag_news", split="test").shuffle(seed=SEED)

train_ds = raw_train.select(range(DATASET_SIZE))
test_ds = raw_test.select(range(TEST_SIZE))

# Per-label counts, ordered by label id, for a quick balance check.
train_counts = Counter(train_ds["label"])
test_counts = Counter(test_ds["label"])
train_hist = {lbl: train_counts[lbl] for lbl in sorted(train_counts)}
test_hist = {lbl: test_counts[lbl] for lbl in sorted(test_counts)}

print(f"Train: {len(train_ds)} | Test: {len(test_ds)}")
print(f"Train label distribution: {train_hist}")
print(f"Test label distribution:  {test_hist}")

## Step 2: Non-IID Split — Dirichlet(α=0.3)

Each client gets a skewed subset where some classes dominate.
Standard FL benchmark approach for label skew (Dirichlet partitioning, Hsu et al., 2019).

In [None]:
def dirichlet_split(dataset, num_clients, alpha, seed=42):
    """Split ``dataset`` into ``num_clients`` label-skewed shards.

    For every distinct label value, the indices of that label are shuffled and
    dealt out to the clients according to a proportion vector drawn from a
    symmetric Dirichlet(``alpha``). Each client's index list is shuffled once
    more at the end so examples are not grouped by class.

    Args:
        dataset: Object supporting ``dataset["label"]`` (a sequence of labels)
            and ``dataset.select(indices)``.
        num_clients: Number of shards to produce; must be at least 1.
        alpha: Dirichlet concentration; smaller values give stronger skew.
        seed: Seed for the NumPy generator, making the split reproducible.

    Returns:
        A list of ``num_clients`` objects, each the result of
        ``dataset.select`` on that client's indices. A client may be empty.

    Raises:
        ValueError: If ``num_clients`` is smaller than 1.
    """
    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    rng = np.random.default_rng(seed)
    labels = np.array(dataset["label"])
    client_indices = [[] for _ in range(num_clients)]

    # Walk the label values that actually occur. Looping over
    # range(number_of_distinct_labels) would silently drop examples whenever
    # the label ids are not exactly 0..K-1 (e.g. a class missing from a small
    # subsample). For contiguous ids the iteration order is unchanged.
    for c in np.unique(labels):
        class_idx = np.where(labels == c)[0]
        rng.shuffle(class_idx)
        proportions = rng.dirichlet(np.repeat(alpha, num_clients))
        counts = (proportions * len(class_idx)).astype(int)
        # Flooring can leave a remainder; the last client absorbs it so every
        # example of this class is assigned exactly once.
        counts[-1] = len(class_idx) - counts[:-1].sum()
        start = 0
        for i in range(num_clients):
            client_indices[i].extend(class_idx[start:start + counts[i]].tolist())
            start += counts[i]

    for idx_list in client_indices:
        rng.shuffle(idx_list)

    return [dataset.select(indices) for indices in client_indices]


# One label-skewed shard per client, reproducible through the shared seed.
client_shards = dirichlet_split(
    dataset=train_ds,
    num_clients=NUM_CLIENTS,
    alpha=DIRICHLET_ALPHA,
    seed=SEED,
)

def format_example(ex):
    """Render one record as the "Sentence / Category" text used for training.

    ``ex`` must provide ``"text"`` and an integer ``"label"`` that is a key of
    ``LABEL_MAP``; the label is written out as its class name.
    """
    sentence = ex["text"]
    category = LABEL_MAP[ex["label"]]
    return f"Sentence: {sentence}\nCategory: {category}"

# Per-client class histogram, followed by one JSON file of formatted text per client.
print(f"{'Client':<10} {'Total':>6} {'World':>7} {'Sports':>7} {'Business':>9} {'SciTech':>8}")
print("-" * 50)

client_json_paths = []
for client_no, client_shard in enumerate(client_shards):
    per_class = Counter(client_shard["label"])  # a Counter yields 0 for absent classes
    print(
        f"Client {client_no:<3} {len(client_shard):>6} "
        f"{per_class[0]:>7} {per_class[1]:>7} {per_class[2]:>9} {per_class[3]:>8}"
    )

    # The trainer is later pointed at this file, which holds a single "text" column.
    formatted = Dataset.from_dict(
        {"text": [format_example(row) for row in client_shard]}
    )
    shard_path = f"{DATA_DIR}/client_{client_no}.json"
    formatted.to_json(shard_path)
    client_json_paths.append(shard_path)

## Step 3: Model + Evaluation Setup

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report

# Prefer bfloat16 when a CUDA device supports it; otherwise fall back to float16.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
DTYPE = torch.bfloat16 if USE_BF16 else torch.float16

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # No dedicated pad token: reuse end-of-sequence so padding is possible.
    tokenizer.pad_token = tokenizer.eos_token

# Class id → id of the FIRST token of " <class name>" (with the leading space).
# NOTE(review): only the first sub-token is kept, so evaluation assumes the four
# ids printed below are distinct — confirm from the output.
label_token_ids = {
    class_id: tokenizer.encode(f" {word}", add_special_tokens=False)[0]
    for class_id, word in enumerate(("world", "sports", "business", "scitech"))
}
print(f"Label tokens: {label_token_ids}")
print(f"Model: {MODEL_NAME} | dtype: {DTYPE}")


def evaluate_accuracy(model, test_data, label):
    """Score ``model`` on ``test_data`` by comparing next-token logits of the label tokens.

    Each example is turned into a "Sentence: ...\\nCategory:" prompt. The
    predicted class is the one whose entry in ``label_token_ids`` has the
    largest logit at the final prompt position.

    Args:
        model: Causal LM to evaluate; it is switched to eval mode.
        test_data: Iterable of records with ``"text"`` and ``"label"``.
        label: Tag shown in the printed summary line.

    Returns:
        ``(accuracy, report)`` where ``report`` is the dict produced by
        ``classification_report(..., output_dict=True)``.

    NOTE(review): the prompt is truncated to ``MAX_SEQ_LEN`` tokens. If the
    tokenizer truncates on the right, a very long sentence would lose its
    trailing "Category:" — confirm this is acceptable, and keep it identical
    to the centralized notebook's evaluation.
    """
    model.eval()
    golds = []
    preds = []
    for example in test_data:
        prompt = f"Sentence: {example['text']}\nCategory:"
        encoded = tokenizer(
            prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN
        ).to(device)
        with torch.no_grad():
            last_logits = model(**encoded).logits[:, -1, :]
        # max() over the dict keys keeps the first class on ties.
        best_class = max(
            label_token_ids,
            key=lambda cls: last_logits[0, label_token_ids[cls]].item(),
        )
        preds.append(best_class)
        golds.append(example["label"])

    acc = accuracy_score(golds, preds)
    report = classification_report(
        golds,
        preds,
        target_names=LABEL_NAMES,
        output_dict=True,
        zero_division=0,
    )
    f1_summary = " ".join(
        f"{name}={report[name]['f1-score']:.2f}"
        for name in ("world", "sports", "business", "scitech")
    )
    print(f"  [{label}] Accuracy: {acc:.1%}  |  F1: {f1_summary}")
    return acc, report

## Step 4: Start Chorus Server

In [None]:
import requests, threading, socket

def find_free_port():
    """Return a TCP port number the OS reports as free on 127.0.0.1.

    Binding to port 0 lets the OS choose the port. The probe socket is closed
    before returning, so the port is only *probably* still free when the
    caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
        return port
    finally:
        probe.close()

# Loopback-only endpoint for the in-notebook Chorus server.
SERVER_PORT = find_free_port()
SERVER_URL = f"http://127.0.0.1:{SERVER_PORT}"

from chorus.server.app import configure, app
import uvicorn

# Server settings: same base model as the clients, FedEx-LoRA aggregation, and
# one delta from every client required per round.
server_settings = dict(
    model_id=MODEL_NAME,
    data_dir=f"{DATA_DIR}/chorus_data",
    strategy="fedex-lora",
    min_deltas=NUM_CLIENTS,
)
configure(**server_settings)

# Filled by the server thread if uvicorn fails; read by the start-up polling loop.
server_error = []
def run_server():
    """Run the Chorus app under uvicorn, recording any failure in ``server_error``.

    This is meant to run in a background thread. An ordinary exception is
    stored as its message. ``SystemExit`` is caught as well, because uvicorn
    signals a failed start-up (for example a port that is already taken) by
    calling ``sys.exit``, which ``except Exception`` would not see.
    """
    try:
        uvicorn.run(app, host="127.0.0.1", port=SERVER_PORT, log_level="warning")
    except SystemExit as e:
        server_error.append(f"uvicorn exited during start-up (exit code {e.code})")
    except Exception as e:
        server_error.append(str(e))

threading.Thread(target=run_server, daemon=True).start()

# Poll /health once per second, for at most 15 attempts, until the server answers.
# The cell fails loudly if the server thread reported an error or never became
# healthy, because every later cell depends on a running server.
for _attempt in range(15):
    time.sleep(1)
    if server_error:
        raise RuntimeError(f"Chorus server failed to start: {server_error[0]}")
    try:
        r = requests.get(f"{SERVER_URL}/health", timeout=2)
        healthy = r.status_code == 200
        strategy = r.json()["strategy"] if healthy else None
    except (requests.exceptions.RequestException, ValueError, KeyError):
        # Not accepting connections yet, or the reply was not the expected JSON: retry.
        continue
    if healthy:
        print(f"Server running at {SERVER_URL} | strategy: {strategy}")
        break
else:
    raise RuntimeError(
        f"Chorus server at {SERVER_URL} did not become healthy within 15 attempts"
    )

## Step 5: Federated Training — 3 Rounds × 3 Non-IID Clients

Each round:
1. All 3 clients train LoRA on their skewed data (starting from aggregated adapter after round 1)
2. Submit deltas → server aggregates with FedEx-LoRA
3. Pull aggregated adapter → evaluate → use as starting point for next round

In [None]:
from chorus.client.trainer import LoRATrainer
from chorus.client.sdk import ChorusClient
from safetensors.torch import load_file
from peft import PeftModel, LoraConfig, get_peft_model

# PEFT-style adapter config handed to the client when it pulls the aggregated
# adapter. Rank and alpha mirror the training hyperparameters.
ADAPTER_CONFIG = dict(
    peft_type="LORA",
    auto_mapping=None,
    base_model_name_or_path=MODEL_NAME,
    bias="none",
    fan_in_fan_out=False,
    inference_mode=True,
    init_lora_weights=True,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.0,
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Per-round metrics, consumed by the result tables in the following cells.
round_results = []
# Directory of the most recently pulled aggregated adapter. It is None in the
# first round, when no aggregated adapter exists yet.
adapter_path_for_next_round = None

for rnd in range(NUM_ROUNDS):
    print(f"\n{'#'*70}")
    print(f"  ROUND {rnd + 1} / {NUM_ROUNDS}")
    print(f"{'#'*70}")

    round_client_accs = []
    round_client_reports = []

    for i in range(NUM_CLIENTS):
        t0 = time.time()
        print(f"\n  --- Client {i} (round {rnd+1}) ---")

        output_dir = f"{DATA_DIR}/adapter_r{rnd}_client_{i}"

        trainer = LoRATrainer(
            base_model=MODEL_NAME,
            dataset=client_json_paths[i],
            output_dir=output_dir,
            lora_rank=LORA_RANK,
            lora_alpha=LORA_ALPHA,
            learning_rate=LEARNING_RATE,
            num_epochs=NUM_EPOCHS,
            per_device_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM,
            max_seq_length=MAX_SEQ_LEN,
            bf16=USE_BF16,
            fp16=not USE_BF16 and torch.cuda.is_available(),
            adapter_path=adapter_path_for_next_round,
            dataloader_pin_memory=False,
        )

        adapter_path = trainer.train()
        elapsed = time.time() - t0
        print(f"  Trained in {elapsed:.0f}s")

        # Evaluate this client's adapter on the shared test set, then free the
        # model before the next one is loaded.
        single_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
        single_model = PeftModel.from_pretrained(single_model, output_dir)
        acc_i, report_i = evaluate_accuracy(single_model, test_ds, f"R{rnd+1} Client {i}")
        round_client_accs.append(acc_i)
        round_client_reports.append(report_i)
        del single_model; gc.collect(); torch.cuda.empty_cache()

        # Submit the delta. try/finally closes the client even when the
        # submission raises.
        client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME, client_id=f"client-{i}")
        try:
            result = client.submit_delta(adapter_path=output_dir, dataset_size=len(client_shards[i]))
        finally:
            client.close()

        print(f"  Submitted: {result['deltas_received']}/{result['min_deltas']}")
        if result["aggregated"]:
            print(f"  >>> AGGREGATION TRIGGERED <<<")

    # Pull the aggregated adapter produced from this round's deltas.
    pull_dir = f"{DATA_DIR}/adapter_federated_r{rnd}"
    client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME)
    try:
        agg_path = client.pull_latest(output_path=pull_dir, adapter_config=ADAPTER_CONFIG)
    finally:
        client.close()

    # Evaluate the federated model. The aggregated adapter's rank is read from
    # the tensors themselves (first dimension of the lora_A matrices), not
    # assumed to equal LORA_RANK.
    agg_tensors = load_file(str(agg_path))
    ranks = set(v.shape[0] for k, v in agg_tensors.items() if "lora_A" in k)
    if not ranks:
        raise RuntimeError(f"No lora_A tensors found in aggregated adapter {agg_path}")
    agg_rank = max(ranks)

    fed_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
    # NOTE(review): PEFT scales the LoRA update by lora_alpha / r. If agg_rank
    # differs from LORA_RANK, the effective scale differs from training —
    # confirm this matches how the server stores the aggregated factors.
    lora_config = LoraConfig(
        r=agg_rank, lora_alpha=LORA_ALPHA,
        target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0, bias="none", task_type="CAUSAL_LM",
    )
    fed_model = get_peft_model(fed_model, lora_config)

    # Copy the aggregated tensors into the PEFT model: prefix each key with
    # "base_model.model." and insert the "default" adapter name.
    state_dict = fed_model.state_dict()
    loaded_keys = 0
    for k, v in agg_tensors.items():
        peft_key = "base_model.model." + k.replace(
            ".lora_A.weight", ".lora_A.default.weight"
        ).replace(
            ".lora_B.weight", ".lora_B.default.weight"
        )
        if peft_key in state_dict:
            state_dict[peft_key] = v.to(state_dict[peft_key].dtype).to(device)
            loaded_keys += 1
    # If nothing matched, the "federated" score would really be that of a model
    # with untouched LoRA weights, so stop here.
    if loaded_keys == 0:
        raise RuntimeError(
            f"None of the {len(agg_tensors)} aggregated tensors matched a PEFT "
            f"parameter name; refusing to evaluate an adapter that was not loaded"
        )
    if loaded_keys < len(agg_tensors):
        print(f"  WARNING: loaded {loaded_keys}/{len(agg_tensors)} aggregated tensors; the rest had no matching parameter")
    fed_model.load_state_dict(state_dict)

    fed_acc, fed_report = evaluate_accuracy(fed_model, test_ds, f"Federated R{rnd+1}")
    del fed_model; gc.collect(); torch.cuda.empty_cache()

    round_results.append({
        "round": rnd + 1,
        "client_accs": round_client_accs,
        "client_reports": round_client_reports,
        "fed_acc": fed_acc,
        "fed_report": fed_report,
    })

    # The next round's clients are given this aggregated adapter as their starting point.
    adapter_path_for_next_round = pull_dir

    print(f"\n  Round {rnd+1} summary: clients={[f'{a:.1%}' for a in round_client_accs]} | federated={fed_acc:.1%}")

print(f"\n{'#'*70}")
print("ALL ROUNDS COMPLETE")
print(f"{'#'*70}")

## Step 6: Load Centralized Results

Load baseline and centralized results from the parallel notebook.

In [None]:
results_path = f"{DATA_DIR}/centralized_results.json"

# Start from "no comparison available". The four values are assigned together
# only once the file has been read completely, so a failure part-way through
# cannot leave them half-populated.
baseline_acc = None
baseline_report = None
centralized_acc = None
centralized_report = None

try:
    with open(results_path, encoding="utf-8") as f:
        central_results = json.load(f)
    loaded_results = tuple(
        central_results[key]
        for key in ("baseline_acc", "baseline_report", "centralized_acc", "centralized_report")
    )
except FileNotFoundError:
    print(f"ERROR: {results_path} not found!")
    print("Run chorus_centralized_baseline.ipynb first (or in parallel).")
    print("Proceeding without centralized comparison...")
except (json.JSONDecodeError, KeyError, TypeError) as err:
    # The other notebook may run in parallel, so the file can exist but be
    # half-written or lack some of the expected entries.
    print(f"ERROR: {results_path} is incomplete or malformed ({err!r}).")
    print("Re-run this cell once chorus_centralized_baseline.ipynb has finished.")
    print("Proceeding without centralized comparison...")
else:
    baseline_acc, baseline_report, centralized_acc, centralized_report = loaded_results
    print(f"Loaded centralized results from {results_path}")
    print(f"  Baseline:    {baseline_acc:.1%}")
    print(f"  Centralized: {centralized_acc:.1%}")

## Step 7: Results Table

In [None]:
banner = "=" * 85
print("\n" + banner)
print("RESULTS: AG News Classification — Centralized vs Federated (Non-IID)")
print(banner)

# (title, format spec) for each column of the results table.
columns = (
    ("Model", "<40"),
    ("Accuracy", ">10"),
    ("World F1", ">9"),
    ("Sports F1", ">10"),
    ("Biz F1", ">8"),
    ("Sci F1", ">8"),
)
header = " ".join(format(title, spec) for title, spec in columns)
print("\n" + header)
print("-" * len(header))

def pr(label, acc, report):
    """Print one results-table row: label, accuracy, and the four per-class F1 scores.

    The column widths (40 / 10 / 9 / 10 / 8 / 8) match the table header so the
    numbers line up under their titles. The accuracy column used to be one
    character narrower than its header, which shifted every numeric column left.

    Args:
        label: Row title, left-aligned.
        acc: Accuracy as a fraction; rendered as a percentage with one decimal.
        report: Mapping with "world", "sports", "business" and "scitech"
            entries, each holding an "f1-score" value.
    """
    print(
        f"{label:<40} {acc:>10.1%} "
        f"{report['world']['f1-score']:>9.2f} "
        f"{report['sports']['f1-score']:>10.2f} "
        f"{report['business']['f1-score']:>8.2f} "
        f"{report['scitech']['f1-score']:>8.2f}"
    )

# Reference rows come from the centralized notebook; each is skipped when its
# accuracy is unavailable.
reference_rows = (
    ("Baseline (zero-shot)", baseline_acc, baseline_report),
    (f"Centralized (all {DATASET_SIZE}, IID)", centralized_acc, centralized_report),
)
for ref_label, ref_acc, ref_report in reference_rows:
    if ref_acc is not None:
        pr(ref_label, ref_acc, ref_report)
print()

# One group per round: every client, then the federated model, then the client mean.
for entry in round_results:
    round_no = entry["round"]
    accs = entry["client_accs"]
    reports = entry["client_reports"]
    for client_no in range(NUM_CLIENTS):
        pr(f"  Round {round_no} — Client {client_no}", accs[client_no], reports[client_no])
    pr(f"  Round {round_no} — FEDERATED", entry["fed_acc"], entry["fed_report"])
    mean_client_acc = sum(accs) / len(accs)
    print(f"  {'':>40} avg client: {mean_client_acc:.1%}")
    print()

## Step 8: Progression Across Rounds

In [None]:
rule = "=" * 65
print(rule)
print("PROGRESSION ACROSS ROUNDS")
print(rule)

# One right-aligned column per round, after a 20-character row-title column.
col_w = 12
round_titles = "".join(f"{'Round '+str(rr['round']):>{col_w}}" for rr in round_results)
header = f"{'':>20}" + round_titles
print(f"\n{header}")
print("-" * len(header))

avg_cells = "".join(
    f"{sum(rr['client_accs']) / len(rr['client_accs']):>{col_w}.1%}"
    for rr in round_results
)
print(f"{'Avg client':>20}{avg_cells}")

fed_cells = "".join(f"{rr['fed_acc']:>{col_w}.1%}" for rr in round_results)
print(f"{'Federated':>20}{fed_cells}")

# The reference values are constant across rounds, so the same cell is
# repeated once per round column.
if centralized_acc is not None:
    centralized_cells = f"{centralized_acc:>{col_w}.1%}" * len(round_results)
    print(f"{'Centralized':>20}{centralized_cells}  ← target")

if baseline_acc is not None:
    baseline_cells = f"{baseline_acc:>{col_w}.1%}" * len(round_results)
    print(f"{'Baseline':>20}{baseline_cells}  ← floor")

## Step 9: Verdict

In [None]:
rule = "=" * 65
print("\n" + rule)
print("VERDICT")
print(rule)

# Headline numbers: last-round federated and client accuracies, plus round 1
# as the starting point.
final = round_results[-1]
final_fed = final["fed_acc"]
final_client_accs = final["client_accs"]
final_avg_client = sum(final_client_accs) / len(final_client_accs)
final_best_client = max(final_client_accs)
r1_fed = round_results[0]["fed_acc"]

have_baseline = baseline_acc is not None
have_centralized = centralized_acc is not None

if have_baseline:
    print(f"\nBaseline (zero-shot):  {baseline_acc:.1%}")
if have_centralized:
    print(f"Centralized (target):  {centralized_acc:.1%}")
print(f"Federated Round 1:     {r1_fed:.1%}")
print(f"Federated Round {NUM_ROUNDS}:     {final_fed:.1%}")
print(f"Avg client (final):    {final_avg_client:.1%}")
print(f"Best client (final):   {final_best_client:.1%}")

if have_centralized and have_baseline:
    # Share of the centralized-over-baseline gain that federation reached;
    # reported as 0 when centralized training did not beat the baseline.
    gap_to_central = centralized_acc - baseline_acc
    fed_improvement = final_fed - baseline_acc
    recovery = fed_improvement / gap_to_central * 100 if gap_to_central > 0 else 0
    print(f"\nFederation recovered {recovery:.0f}% of the centralized gain over baseline")

print(f"Round 1 → Round {NUM_ROUNDS} improvement: {(final_fed - r1_fed)*100:+.1f} pp")
print(f"Federation vs avg client: {(final_fed - final_avg_client)*100:+.1f} pp")

# Strongest claim first: matches centralized, beats every client, beats the
# average client, otherwise underperformed.
if have_centralized and final_fed >= centralized_acc * 0.95:
    verdict = "\n>>> FEDERATION MATCHED CENTRALIZED (within 5%). Non-IID challenge overcome! <<<"
elif final_fed > final_best_client:
    verdict = "\n>>> FEDERATION BEAT EVERY CLIENT. Collaboration works even with skewed data. <<<"
elif final_fed > final_avg_client:
    verdict = "\n>>> Federation beat average client. Partial success — non-IID hurts but doesn't break it. <<<"
else:
    verdict = "\n>>> Federation underperformed. Non-IID skew was too severe. <<<"
print(verdict)

In [None]:
# The server thread was started with daemon=True, so it ends with the kernel process.
print("Done! Chorus server will shut down when this session ends.")