# Chorus: Federated LoRA — IID Validation

**The core question:** 3 organizations each have private data they can't share.
Can they get a model as good as if they pooled all their data together?

| Model | Training Data | Shares data? |
|-------|--------------|--------------|
| **Baseline** (zero-shot) | None | N/A |
| **Centralized** (gold standard) | All ~4500 examples | Yes (unrealistic) |
| **Individual clients** | ~1500 examples each | N/A (trains alone) |
| **Federated (Chorus)** | ~1500 each, aggregated | No (privacy preserved) |

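**IID split:** Each client gets a random, equally sized slice of the training set, so every client's label distribution approximately mirrors the overall one (the classes need not be equally frequent, but no client is skewed toward or missing a class).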
This is the "fair" case where federation should clearly help.

**Expected outcome if Chorus works:**
- Individual clients: decent but limited (only saw 1/3 of data)
- Federated: close to centralized (combined knowledge without sharing data)
- Centralized: best possible (saw everything)

| Setting | Value |
|---------|-------|
| Model | `Qwen/Qwen2.5-0.5B` (490M, BASE) |
| Dataset | Financial sentiment — ~4.5K train, 500 test, IID split |
| LoRA | rank=16, alpha=32, 3 epochs per round |
| Rounds | 3 |
| Aggregation | FedEx-LoRA (SVD-optimal) |

In [None]:
# Install Chorus together with its `peft` extra, straight from the GitHub repository.
# NOTE(review): no commit or tag is pinned, so results may drift as the repo changes — consider pinning.
!pip install -q 'chorus[peft] @ git+https://github.com/varmabudharaju/chorus.git'
# scikit-learn supplies accuracy_score / classification_report used by the evaluation cell.
!pip install -q scikit-learn
# NOTE(review): presumably works around outdated typing/pydantic packages preinstalled on the host — confirm it is still needed.
!pip install -q --upgrade typing_extensions pydantic pydantic-core

In [None]:
import torch, gc, os, time, math, json, logging, random
import numpy as np
from collections import Counter

# Timestamped log lines make it easier to follow training and aggregation progress.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(name)s] %(message)s',
    datefmt='%H:%M:%S',
)

# ── Working directory ────────────────────────────────────────────────────
# A /content folder is treated as the sign of a Colab runtime; on any other
# host the artifacts are written under the user's home directory instead.
DATA_DIR = "/content/chorus_iid" if os.path.exists("/content") else os.path.expanduser("~/chorus_iid")
os.makedirs(DATA_DIR, exist_ok=True)
print(f"Data directory: {DATA_DIR}")

# ── Hyperparameters ──────────────────────────────────────────────────────
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
SEED = 42

# Data and federation layout
TEST_SIZE = 500     # examples held out for evaluation
NUM_CLIENTS = 3     # simulated organizations
NUM_ROUNDS = 3      # train -> aggregate cycles

# LoRA adapter shape
LORA_RANK = 16
LORA_ALPHA = 32

# Optimisation (shared by the centralized run and every client)
LEARNING_RATE = 3e-4
NUM_EPOCHS = 3      # epochs per training call, i.e. per client per round
BATCH_SIZE = 8
GRAD_ACCUM = 2
MAX_SEQ_LEN = 128

# Seed the Python, NumPy and PyTorch RNGs with the same value for repeatable runs.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(SEED)

# ── Device ───────────────────────────────────────────────────────────────
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device != "cuda":
    print("WARNING: No GPU! Runtime -> Change runtime type -> T4 GPU")
else:
    gpu = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu.name} | VRAM: {gpu.total_memory / 1e9:.1f} GB")

## Step 1: Dataset — Split Across 3 Clients + Centralized

- 500 held out for testing (stratified)
- Remaining ~4500 split equally across 3 clients (~1500 each) — **IID** (random partitions, so each client's label mix mirrors the full training set)
- Centralized model trains on all ~4500 (the "what if we could share" gold standard)

In [None]:
from datasets import load_dataset, Dataset, ClassLabel

# Load the full training split and shuffle it once with the global seed.
full_ds = load_dataset("nickmuchi/financial-classification", split="train").shuffle(seed=SEED)
print(f"Total dataset: {len(full_ds)} examples")

# Integer class id -> human-readable sentiment used in prompts and reports.
LABEL_MAP = {0: "negative", 1: "neutral", 2: "positive"}
label_counts = Counter(full_ds['labels'])
print(f"Labels: { {LABEL_MAP[k]: v for k, v in sorted(label_counts.items())} }")

# Stratified splitting requires a ClassLabel column, hence the cast.
full_ds = full_ds.cast_column("labels", ClassLabel(names=["negative", "neutral", "positive"]))
split = full_ds.train_test_split(
    test_size=TEST_SIZE,
    seed=SEED,
    stratify_by_column="labels",
)
train_ds, test_ds = split["train"], split["test"]
print(f"Train: {len(train_ds)} | Test: {len(test_ds)}")

# IID split — reshuffle, then cut into NUM_CLIENTS equal shards. The shards are
# random partitions, so each client's label mix tracks the overall distribution.
train_ds = train_ds.shuffle(seed=SEED)
client_shards = [
    train_ds.shard(num_shards=NUM_CLIENTS, index=shard_idx)
    for shard_idx in range(NUM_CLIENTS)
]

def format_example(ex):
    """Render one labelled row as the prompt-plus-answer text used for training.

    `ex` is indexed by 'text' (the sentence) and 'labels' (a key of LABEL_MAP);
    the returned string ends with the sentiment name looked up in LABEL_MAP.
    """
    sentiment = LABEL_MAP[ex['labels']]
    return f"Sentence: {ex['text']}\nSentiment: {sentiment}"

# Write one JSON file per client and remember the paths for the training rounds.
client_json_paths = []
for i, shard in enumerate(client_shards):
    path = f"{DATA_DIR}/client_{i}.json"
    texts = [format_example(row) for row in shard]
    Dataset.from_dict({"text": texts}).to_json(path)
    client_json_paths.append(path)

    # Per-client label histogram, printed as a quick check of the split.
    dist = {LABEL_MAP[k]: v for k, v in sorted(Counter(shard['labels']).items())}
    print(f"Client {i}: {len(shard)} examples | {dist}")

# The centralized baseline trains on the union of all shards, i.e. the whole train split.
central_path = f"{DATA_DIR}/centralized_train.json"
central_texts = [format_example(row) for row in train_ds]
central_ds = Dataset.from_dict({"text": central_texts})
central_ds.to_json(central_path)
print(f"\nCentralized: {len(train_ds)} examples → {central_path}")

## Step 2: Model + Evaluation Setup

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report

# Precision policy:
#   * bf16 when the GPU supports it,
#   * fp16 on other GPUs,
#   * fp32 on CPU — half-precision matmuls are slow or unimplemented on many CPU
#     builds, and the trainers in this notebook only enable fp16 when CUDA is present.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
if USE_BF16:
    DTYPE = torch.bfloat16
elif torch.cuda.is_available():
    DTYPE = torch.float16
else:
    DTYPE = torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    # Reuse EOS as padding when the tokenizer does not define a pad token.
    tokenizer.pad_token = tokenizer.eos_token

# Evaluation compares the next-token logits of the three label words. Only the
# FIRST token of each word (with its leading space) is used as the class marker.
pos_id = tokenizer.encode(" positive", add_special_tokens=False)[0]
neg_id = tokenizer.encode(" negative", add_special_tokens=False)[0]
neu_id = tokenizer.encode(" neutral", add_special_tokens=False)[0]
label_token_ids = {0: neg_id, 1: neu_id, 2: pos_id}

# If two label words shared a first token, their scores would always tie and the
# reported accuracy would be meaningless — fail loudly instead.
if len(set(label_token_ids.values())) != len(label_token_ids):
    raise ValueError(
        f"Label words do not map to distinct first tokens: {label_token_ids}. "
        "Choose label words that the tokenizer splits differently."
    )
print(f"Model: {MODEL_NAME} | dtype: {DTYPE}")


def evaluate_accuracy(model, test_data, label):
    """Classify every test example by next-token logits and print/return metrics.

    For each row the prompt "Sentence: <text>" followed by a "Sentiment:" line is
    fed to the model once; the predicted class is the one whose label token
    (see `label_token_ids`) has the highest logit at the final position.

    Args:
        model: causal LM to score; inputs are moved to the global `device`.
        test_data: iterable of rows indexed by 'text' and 'labels', where
            'labels' is one of the keys of `label_token_ids`.
        label: display name printed in front of the metrics line.

    Returns:
        Tuple `(accuracy, report)`; `report` is the dict produced by
        `classification_report(..., output_dict=True)`.
    """
    class_ids = sorted(label_token_ids)
    class_names = ["negative", "neutral", "positive"]

    model.eval()
    preds, golds = [], []
    for ex in test_data:
        prompt = f"Sentence: {ex['text']}\nSentiment:"
        # NOTE(review): if the tokenizer truncates on the right (the usual default),
        # a prompt longer than MAX_SEQ_LEN loses its trailing "Sentiment:" cue.
        # Confirm no test sentence is that long, or truncate the sentence instead.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(device)
        with torch.no_grad():
            logits = model(**inputs).logits[:, -1, :]
        scores = {lbl: logits[0, tid].item() for lbl, tid in label_token_ids.items()}
        preds.append(max(scores, key=scores.get))
        golds.append(ex["labels"])

    acc = accuracy_score(golds, preds)
    # Pass `labels` explicitly: without it sklearn infers the classes from the
    # data and raises when fewer than three appear, because `target_names`
    # would no longer match. This also keeps the report keys stable.
    report = classification_report(
        golds,
        preds,
        labels=class_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    print(f"  [{label}] Accuracy: {acc:.1%}  |  F1: neg={report['negative']['f1-score']:.2f} neu={report['neutral']['f1-score']:.2f} pos={report['positive']['f1-score']:.2f}")
    return acc, report

## Step 3: Baseline

In [None]:
# Zero-shot reference point: the untouched base model, no adapter attached.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
base_model = base_model.to(device)
baseline_acc, baseline_report = evaluate_accuracy(base_model, test_ds, "Baseline (zero-shot)")

# Drop the model and release cached memory before the next stage loads its own copy.
del base_model
gc.collect()
torch.cuda.empty_cache()

## Step 3b: Centralized Training (Gold Standard)

Train on ALL ~4500 examples — the "what if everyone could share data" upper bound.

In [None]:
from chorus.client.trainer import LoRATrainer as CentralTrainer
from peft import PeftModel

print(f"Training centralized model on all {len(train_ds)} examples, {NUM_EPOCHS} epochs...")
t0 = time.time()

centralized_dir = f"{DATA_DIR}/adapter_centralized"

# Same hyperparameters as the federated clients; only the dataset differs.
central_trainer_kwargs = dict(
    base_model=MODEL_NAME,
    dataset=central_path,
    output_dir=centralized_dir,
    lora_rank=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    learning_rate=LEARNING_RATE,
    num_epochs=NUM_EPOCHS,
    per_device_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    max_seq_length=MAX_SEQ_LEN,
    bf16=USE_BF16,
    # fp16 is only enabled on a GPU that lacks bf16 support.
    fp16=not USE_BF16 and torch.cuda.is_available(),
    dataloader_pin_memory=False,
)
central_trainer = CentralTrainer(**central_trainer_kwargs)
central_trainer.train()
print(f"Centralized training done in {time.time() - t0:.0f}s")

# Evaluate by re-loading a fresh base model and attaching the saved adapter.
central_model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device),
    centralized_dir,
)
centralized_acc, centralized_report = evaluate_accuracy(central_model, test_ds, "Centralized")

# Release the model before the federated stage starts.
del central_model
gc.collect()
torch.cuda.empty_cache()

## Step 4: Start Chorus Server

In [None]:
import requests, threading, socket

def find_free_port():
    """Return a loopback TCP port number that the OS reports as currently unused.

    Binding to port 0 lets the kernel pick the port. The probe socket is closed
    before returning, so the port is free at that moment but not reserved.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    finally:
        probe.close()
    return port

# Pick the port once so the same URL is used by the server thread and by every client.
# NOTE(review): find_free_port() has already closed its probe socket, so another
# process could take the port before the server binds it — unlikely, but possible.
SERVER_PORT = find_free_port()
SERVER_URL = f"http://127.0.0.1:{SERVER_PORT}"

from chorus.server.app import configure, app
import uvicorn

# Configure the aggregation server: FedEx-LoRA strategy, aggregating once
# NUM_CLIENTS deltas have been received.
configure(model_id=MODEL_NAME, data_dir=f"{DATA_DIR}/chorus_data", strategy="fedex-lora", min_deltas=NUM_CLIENTS)

# Filled by the server thread if uvicorn fails, so the poller below can report it.
server_error = []


def run_server():
    """Run uvicorn in the current thread and record any failure in `server_error`."""
    try:
        uvicorn.run(app, host="127.0.0.1", port=SERVER_PORT, log_level="warning")
    except (Exception, SystemExit) as e:
        # uvicorn may signal a failed startup (e.g. port already taken) by
        # exiting rather than raising, which a plain `except Exception` misses.
        server_error.append(f"{type(e).__name__}: {e}")


# Daemon thread: the server must not keep the kernel alive after the notebook ends.
threading.Thread(target=run_server, daemon=True).start()

# Poll /health for up to ~15 seconds. Connection errors are expected while the
# server is still booting, so only those are swallowed.
server_ready = False
for _ in range(15):
    time.sleep(1)
    if server_error:
        break
    try:
        r = requests.get(f"{SERVER_URL}/health", timeout=5)
    except requests.RequestException:
        continue
    if r.status_code == 200:
        print(f"Server running at {SERVER_URL} | strategy: {r.json()['strategy']}")
        server_ready = True
        break

# Fail here, with the real reason, rather than letting the training cell die
# later on an unrelated-looking connection error.
if not server_ready:
    reason = server_error[0] if server_error else "no healthy response from /health within 15s"
    print(f"ERROR: {reason}")
    raise RuntimeError(f"Chorus server did not start at {SERVER_URL}: {reason}")

## Step 5: Federated Training — 3 Rounds

Each round:
1. All 3 clients train LoRA on their own data (round 1 starts from a fresh adapter; every later round starts from the previous round's aggregated adapter)
2. Submit deltas → server aggregates with FedEx-LoRA
3. Pull the aggregated adapter → use as starting point for next round

In [None]:
from chorus.client.trainer import LoRATrainer
from chorus.client.sdk import ChorusClient
from safetensors.torch import load_file
from peft import PeftModel, LoraConfig, get_peft_model

# PEFT adapter config handed to the client when pulling the aggregated adapter,
# so that PeftModel.from_pretrained() can load the pulled directory.
# Key order and values match the LoRA settings used for training in this notebook.
ADAPTER_CONFIG = dict(
    peft_type="LORA",
    auto_mapping=None,
    base_model_name_or_path=MODEL_NAME,
    bias="none",
    fan_in_fan_out=False,
    inference_mode=True,
    init_lora_weights=True,
    lora_alpha=LORA_ALPHA,
    lora_dropout=0.0,
    r=LORA_RANK,
    target_modules=["q_proj", "v_proj"],
    task_type="CAUSAL_LM",
)

# Per-round metrics, plus the adapter every client resumes from in the next round.
round_results = []
adapter_path_for_next_round = None  # None = train from scratch

for rnd in range(NUM_ROUNDS):
    print(f"\n{'#'*70}")
    print(f"  ROUND {rnd + 1} / {NUM_ROUNDS}")
    print(f"{'#'*70}")

    round_client_accs = []
    round_client_reports = []

    # ── Local training: each client trains on its own shard only ─────────
    for i in range(NUM_CLIENTS):
        t0 = time.time()
        print(f"\n  --- Client {i} (round {rnd+1}) ---")

        output_dir = f"{DATA_DIR}/adapter_r{rnd}_client_{i}"

        trainer = LoRATrainer(
            base_model=MODEL_NAME,
            dataset=client_json_paths[i],
            output_dir=output_dir,
            lora_rank=LORA_RANK,
            lora_alpha=LORA_ALPHA,
            learning_rate=LEARNING_RATE,
            num_epochs=NUM_EPOCHS,
            per_device_batch_size=BATCH_SIZE,
            gradient_accumulation_steps=GRAD_ACCUM,
            max_seq_length=MAX_SEQ_LEN,
            bf16=USE_BF16,
            fp16=not USE_BF16 and torch.cuda.is_available(),
            adapter_path=adapter_path_for_next_round,
            dataloader_pin_memory=False,
        )

        # The return value is not needed: the adapter is read back from output_dir below.
        trainer.train()
        elapsed = time.time() - t0
        print(f"  Trained in {elapsed:.0f}s")

        # Evaluate this client's adapter on the shared held-out test set.
        single_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
        single_model = PeftModel.from_pretrained(single_model, output_dir)
        acc_i, report_i = evaluate_accuracy(single_model, test_ds, f"R{rnd+1} Client {i}")
        round_client_accs.append(acc_i)
        round_client_reports.append(report_i)
        del single_model; gc.collect(); torch.cuda.empty_cache()

        # Submit the adapter; close the client even if the request fails.
        client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME, client_id=f"client-{i}")
        try:
            result = client.submit_delta(adapter_path=output_dir, dataset_size=len(client_shards[i]))
        finally:
            client.close()

        print(f"  Submitted: {result['deltas_received']}/{result['min_deltas']}")
        if result['aggregated']:
            print(f"  >>> AGGREGATION TRIGGERED <<<")

    # ── Pull the aggregated adapter produced from this round's submissions ─
    pull_dir = f"{DATA_DIR}/adapter_federated_r{rnd}"
    client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME)
    try:
        agg_path = client.pull_latest(output_path=pull_dir, adapter_config=ADAPTER_CONFIG)
    finally:
        client.close()

    # ── Evaluate the federated model ─────────────────────────────────────
    # The aggregated rank may differ from LORA_RANK, so read it from the tensors:
    # dimension 0 of every lora_A matrix is taken as the adapter rank.
    agg_tensors = load_file(str(agg_path))
    ranks = {v.shape[0] for k, v in agg_tensors.items() if "lora_A" in k}
    if not ranks:
        raise RuntimeError(f"No lora_A tensors found in aggregated adapter at {agg_path}")
    agg_rank = max(ranks)

    fed_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
    # NOTE(review): PEFT scales the LoRA update by lora_alpha / r. If agg_rank differs
    # from LORA_RANK, the scale here differs from the one used during client
    # training — confirm this matches the server's aggregation convention.
    lora_config = LoraConfig(
        r=agg_rank,
        lora_alpha=LORA_ALPHA,
        # Reuse ADAPTER_CONFIG so the evaluated modules cannot drift from the pulled adapter's config.
        target_modules=list(ADAPTER_CONFIG["target_modules"]),
        lora_dropout=0.0,
        bias="none",
        task_type="CAUSAL_LM",
    )
    fed_model = get_peft_model(fed_model, lora_config)

    # Copy the aggregated weights into the freshly initialised PEFT model, mapping
    # the stored key names onto PEFT's "default" adapter naming.
    state_dict = fed_model.state_dict()
    unmatched_keys = []
    for k, v in agg_tensors.items():
        peft_key = "base_model.model." + k.replace(".lora_A.weight", ".lora_A.default.weight").replace(".lora_B.weight", ".lora_B.default.weight")
        if peft_key in state_dict:
            state_dict[peft_key] = v.to(state_dict[peft_key].dtype).to(device)
        else:
            unmatched_keys.append(k)

    # If nothing matched, the model would be evaluated with untouched freshly
    # initialised LoRA weights and reported as "federated" — stop instead.
    if len(unmatched_keys) == len(agg_tensors):
        raise RuntimeError(
            f"None of the {len(agg_tensors)} aggregated tensors matched the PEFT model's "
            f"parameter names (example key: {unmatched_keys[0]!r})"
        )
    if unmatched_keys:
        print(f"  WARNING: {len(unmatched_keys)}/{len(agg_tensors)} aggregated tensors were not loaded (example key: {unmatched_keys[0]!r})")
    fed_model.load_state_dict(state_dict)

    fed_acc, fed_report = evaluate_accuracy(fed_model, test_ds, f"Federated R{rnd+1}")
    del fed_model; gc.collect(); torch.cuda.empty_cache()

    round_results.append({
        "round": rnd + 1,
        "client_accs": round_client_accs,
        "client_reports": round_client_reports,
        "fed_acc": fed_acc,
        "fed_report": fed_report,
    })

    # Use this round's aggregated adapter as starting point for next round
    adapter_path_for_next_round = pull_dir

    print(f"\n  Round {rnd+1} summary: clients={[f'{a:.1%}' for a in round_client_accs]} | federated={fed_acc:.1%}")

print(f"\n{'#'*70}")
print("ALL ROUNDS COMPLETE")
print(f"{'#'*70}")

## Step 6: Final Results

In [None]:
print("\n" + "=" * 75)
print("RESULTS: Financial Sentiment (IID) — Baseline vs Centralized vs Federated")
print("=" * 75)

# Column widths: 45 (label) + 10 (accuracy) + 3 x 8 (per-class F1), single-space separated.
table_header = f"{'Model':<45} {'Accuracy':>10} {'Neg F1':>8} {'Neu F1':>8} {'Pos F1':>8}"
print(f"\n{table_header}")
# Derive the rule from the header so it always spans exactly the table width.
print("-" * len(table_header))

def pr(label, acc, report):
    """Print one row of the results table.

    Args:
        label: row name, left-aligned in a 45-character column.
        acc: accuracy as a fraction (printed as a percentage with one decimal).
        report: mapping with 'negative', 'neutral' and 'positive' entries, each
            holding an 'f1-score' value.

    Column widths (45, 10, 8, 8, 8) match the table header, so the accuracy
    and F1 values line up under their headings.
    """
    f1_cells = " ".join(
        f"{report[name]['f1-score']:>8.2f}" for name in ("negative", "neutral", "positive")
    )
    # Width 10 (not 9): the '%' sign counts toward the field width.
    print(f"{label:<45} {acc:>10.1%} {f1_cells}")

# Reference rows: the floor (zero-shot) and the target (centralized).
pr("Baseline (zero-shot)", baseline_acc, baseline_report)
pr(f"Centralized (all {len(train_ds)} examples)", centralized_acc, centralized_report)
print()

# One group per round: every client, then the aggregate, then the client mean.
for rr in round_results:
    rnd = rr['round']
    accs, reports = rr['client_accs'], rr['client_reports']
    avg_c = sum(accs) / len(accs)
    for i in range(NUM_CLIENTS):
        pr(f"  Round {rnd} — Client {i}", accs[i], reports[i])
    pr(f"  Round {rnd} — FEDERATED", rr['fed_acc'], rr['fed_report'])
    print(f"  {'':>45} avg client: {avg_c:.1%}")
    print()

# ── Progression: one column per round, one row per model family ──────────
print("=" * 75)
print("PROGRESSION ACROSS ROUNDS")
print("=" * 75)

col_w = 12
round_labels = [f"{'Round '+str(rr['round']):>{col_w}}" for rr in round_results]
header = f"{'':>20}" + "".join(round_labels)
print(f"\n{header}")
print("-" * len(header))

# Each row is assembled as a single string and printed once.
avg_cells = "".join(
    f"{sum(rr['client_accs']) / len(rr['client_accs']):>{col_w}.1%}" for rr in round_results
)
print(f"{'Avg client':>20}" + avg_cells)

fed_cells = "".join(f"{rr['fed_acc']:>{col_w}.1%}" for rr in round_results)
print(f"{'Federated':>20}" + fed_cells)

# The two references are constant, so their cell is simply repeated per round.
print(f"{'Centralized':>20}" + f"{centralized_acc:>{col_w}.1%}" * len(round_results) + "  ← target")
print(f"{'Baseline':>20}" + f"{baseline_acc:>{col_w}.1%}" * len(round_results) + "  ← floor")

# ── Verdict ──────────────────────────────────────────────────────────────
print("\n" + "=" * 75)
print("VERDICT")
print("=" * 75)

final = round_results[-1]
final_fed = final['fed_acc']
final_avg_client = sum(final['client_accs']) / len(final['client_accs'])
final_best = max(final['client_accs'])
r1_fed = round_results[0]['fed_acc']

# Share of the baseline -> centralized improvement that federation achieved;
# reported as 0 when centralized training did not beat the baseline.
gap_to_central = centralized_acc - baseline_acc
fed_improvement = final_fed - baseline_acc
recovery = fed_improvement / gap_to_central * 100 if gap_to_central > 0 else 0

print(f"\nBaseline (zero-shot):  {baseline_acc:.1%}")
print(f"Centralized (target):  {centralized_acc:.1%}")
print(f"Federated Round 1:     {r1_fed:.1%}")
print(f"Federated Round {NUM_ROUNDS}:     {final_fed:.1%}")
print(f"Avg client (final):    {final_avg_client:.1%}")
print(f"Best client (final):   {final_best:.1%}")
print(f"\nFederation recovered {recovery:.0f}% of the centralized gain over baseline")
print(f"Round 1 → Round {NUM_ROUNDS} improvement: {(final_fed - r1_fed)*100:+.1f} pp")
print(f"Federation vs avg client: {(final_fed - final_avg_client)*100:+.1f} pp")

# Strongest claim first: matching centralized, then beating every client,
# then beating the client average.
if final_fed >= centralized_acc * 0.95:
    verdict = "\n>>> FEDERATION MATCHED CENTRALIZED (within 5%). IID case works! <<<"
elif final_fed > final_best:
    verdict = "\n>>> FEDERATION BEAT EVERY CLIENT! Collaboration works. <<<"
elif final_fed > final_avg_client:
    verdict = "\n>>> Federation beat average client. Partial success. <<<"
else:
    verdict = "\n>>> Federation underperformed. <<<"
print(verdict)

In [None]:
# No explicit teardown: the server was started in a daemon thread, so it stops
# when the notebook's Python process exits.
print("Done! Chorus server will shut down when this session ends.")