# Chorus: Federated LoRA — End-to-End Validation

**Goal:** Test whether federating LoRA adapters via Chorus produces a model that beats every single client (Step 7 reports the outcome either way).

| Setting | Value |
|---------|-------|
| Model | `facebook/opt-125m` (125M params, BASE — not instruction-tuned) |
| Dataset | `nickmuchi/financial-classification` — financial sentiment sentences (exact count printed when loaded) |
| Task | 3-class sentiment: positive / neutral / negative |
| Clients | 3 (non-IID split — each sees different label distributions) |
| Rounds | 1 |
| LoRA rank | 8 |
| Aggregation | FedEx-LoRA (SVD-optimal) |
| Metric | **Accuracy** (not just loss) |
| Runtime | **~15 min total** on T4 |

## Step 0: Install

In [None]:
!pip install -q 'chorus[peft] @ git+https://github.com/varmabudharaju/chorus.git'
!pip install -q scikit-learn

In [None]:
import torch, gc, os, time, math, json, logging
import numpy as np
from collections import Counter

logging.basicConfig(level=logging.INFO, format='%(asctime)s [%(name)s] %(message)s', datefmt='%H:%M:%S')

# Pick the compute device once; every later cell moves models/inputs to `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device != "cuda":
    print("WARNING: No GPU! Go to Runtime -> Change runtime type -> T4 GPU")
else:
    gpu = torch.cuda.get_device_properties(0)
    vram_gb = gpu.total_memory / 1e9
    print(f"GPU: {gpu.name} | VRAM: {vram_gb:.1f} GB")

In [None]:
from datasets import load_dataset, Dataset

# Load financial sentiment dataset (parquet format — no custom scripts)
ds = load_dataset("nickmuchi/financial-classification", split="train")
print(f"Total examples: {len(ds)}")

# Label mapping: 0=negative, 1=neutral, 2=positive
LABEL_MAP = {0: "negative", 1: "neutral", 2: "positive"}
label_counts = Counter(ds["labels"])
print(f"Label distribution: { {LABEL_MAP[k]: v for k, v in sorted(label_counts.items())} }")
print(f"Sample: '{ds[0]['text'][:80]}...' -> {LABEL_MAP[ds[0]['labels']]}")

# Shuffle with seed
ds = ds.shuffle(seed=42)

# Hold out 20% for testing, stratified so the test set keeps the overall label mix.
split = ds.train_test_split(test_size=0.2, seed=42, stratify_by_column="labels")
train_ds = split["train"]
test_ds = split["test"]
print(f"\nTrain: {len(train_ds)} | Test: {len(test_ds)}")

# Non-IID split: group training-row indices by label, then hand each client
# uneven slices of the label pools. Reading the label column directly avoids
# materialising every row.
by_label = {0: [], 1: [], 2: []}  # neg, neu, pos
for i, lbl in enumerate(train_ds["labels"]):
    by_label[lbl].append(i)

# Percentages are fractions of each label's pool (not of the client's data):
#   Client 0: first 70% of neutral   + first 30% of positive
#   Client 1: remaining 70% of positive + first 40% of negative
#   Client 2: remaining 60% of negative + remaining 30% of neutral
# Every pool is split exhaustively, so each training row goes to exactly one
# client (previously the last 10% of positive rows were assigned to nobody).
import random
random.seed(42)

neg_idx = by_label[0].copy(); random.shuffle(neg_idx)
neu_idx = by_label[1].copy(); random.shuffle(neu_idx)
pos_idx = by_label[2].copy(); random.shuffle(pos_idx)

n_neg, n_neu, n_pos = len(neg_idx), len(neu_idx), len(pos_idx)

client_indices = [
    neu_idx[:int(n_neu*0.7)] + pos_idx[:int(n_pos*0.3)],
    pos_idx[int(n_pos*0.3):] + neg_idx[:int(n_neg*0.4)],
    neg_idx[int(n_neg*0.4):] + neu_idx[int(n_neu*0.7):],
]

# Sanity check: the partition must cover the training split exactly once.
n_assigned = sum(len(idx) for idx in client_indices)
if n_assigned != len(train_ds) or len(set().union(*client_indices)) != n_assigned:
    raise RuntimeError(
        f"Client partition is not exhaustive/disjoint: {n_assigned} assigned "
        f"vs {len(train_ds)} training rows"
    )

# Format as text for causal LM training
def format_example(ex):
    return f"Sentence: {ex['text']}\nSentiment: {LABEL_MAP[ex['labels']]}"

client_datasets = []
for i, indices in enumerate(client_indices):
    # client_indices were built by concatenating label pools; shuffle so each
    # client's examples are interleaved rather than grouped by label.
    random.shuffle(indices)
    subset = train_ds.select(indices)
    labels_in_client = Counter(subset["labels"])

    texts = [format_example(row) for row in subset]
    cds = Dataset.from_dict({"text": texts})
    path = f"/content/client_{i}.json"
    cds.to_json(path)

    dist = {LABEL_MAP[lbl]: count for lbl, count in sorted(labels_in_client.items())}
    print(f"Client {i}: {len(cds)} examples | {dist}")
    client_datasets.append(cds)

print(f"\nTest set: {len(test_ds)} examples (shared across all evaluations)")

## Step 2: Evaluation Setup — Accuracy on Sentiment Classification

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report

MODEL_NAME = "facebook/opt-125m"
# Half precision only on GPU: fp16 inference on CPU is unsupported for some ops
# (or very slow) in many PyTorch builds, and this notebook only warns and keeps
# going when no GPU is present.
DTYPE = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Token IDs for our 3 labels. Evaluation compares the next-token logits of these
# ids, so each label word (with the leading space it has after "Sentiment:")
# should be a single token, and the three ids must be distinct.
_label_pieces = {
    word: tokenizer.encode(f" {word}", add_special_tokens=False)
    for word in ("positive", "negative", "neutral")
}
for word, ids in _label_pieces.items():
    if len(ids) != 1:
        print(f"WARNING: ' {word}' tokenizes to {len(ids)} tokens; only the first is scored.")
pos_id = _label_pieces["positive"][0]
neg_id = _label_pieces["negative"][0]
neu_id = _label_pieces["neutral"][0]
if len({pos_id, neg_id, neu_id}) != 3:
    raise ValueError(
        f"Label words share a first token (pos={pos_id}, neu={neu_id}, neg={neg_id}); "
        "next-token scoring cannot distinguish them."
    )
label_token_ids = {0: neg_id, 1: neu_id, 2: pos_id}
print(f"Label tokens: positive={pos_id}, neutral={neu_id}, negative={neg_id}")


def evaluate_accuracy(model, test_data, label):
    """Measure 3-class sentiment accuracy by next-token scoring.

    Each row is turned into the prompt "Sentence: <text>\\nSentiment:". The
    prediction is the class whose token in ``label_token_ids`` gets the highest
    next-token logit.

    Args:
        model: causal LM (plain or PEFT-wrapped); it is switched to eval mode.
        test_data: iterable of rows with a "text" string and an integer
            "labels" field (0=negative, 1=neutral, 2=positive).
        label: display name used in the printed summary.

    Returns:
        (accuracy, per_class_report), where the report is sklearn's
        ``classification_report`` dict keyed by class name.
    """
    model.eval()
    preds, golds = [], []

    for ex in test_data:
        # Column names must match the dataset ("text"/"labels"); the earlier
        # "sentence"/"label" keys raised KeyError on this dataset.
        prompt = f"Sentence: {ex['text']}\nSentiment:"
        # NOTE(review): right-side truncation would cut the trailing "Sentiment:"
        # cue from prompts longer than 128 tokens; confirm none exceed it.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits[:, -1, :]  # logits for next token

        # Pick highest prob among our 3 label tokens
        scores = {lbl: logits[0, tid].item() for lbl, tid in label_token_ids.items()}
        pred = max(scores, key=scores.get)
        preds.append(pred)
        golds.append(ex["labels"])

    acc = accuracy_score(golds, preds)
    # Pin the label set so target_names always lines up, even if some class
    # never appears in golds or preds (e.g. a collapsed model).
    report = classification_report(
        golds, preds, labels=[0, 1, 2], target_names=["negative", "neutral", "positive"],
        output_dict=True, zero_division=0
    )
    print(f"  [{label}] Accuracy: {acc:.1%}")
    print(f"    Per-class F1: neg={report['negative']['f1-score']:.2f}  neu={report['neutral']['f1-score']:.2f}  pos={report['positive']['f1-score']:.2f}")
    return acc, report


print("Evaluation functions ready.")

In [ ]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report

MODEL_NAME = "facebook/opt-125m"
# Half precision only on GPU: fp16 inference on CPU is unsupported for some ops
# (or very slow) in many PyTorch builds, and this notebook only warns and keeps
# going when no GPU is present.
DTYPE = torch.float16 if device == "cuda" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Token IDs for our 3 labels. Evaluation compares the next-token logits of these
# ids, so each label word (with the leading space it has after "Sentiment:")
# should be a single token, and the three ids must be distinct.
_label_pieces = {
    word: tokenizer.encode(f" {word}", add_special_tokens=False)
    for word in ("positive", "negative", "neutral")
}
for word, ids in _label_pieces.items():
    if len(ids) != 1:
        print(f"WARNING: ' {word}' tokenizes to {len(ids)} tokens; only the first is scored.")
pos_id = _label_pieces["positive"][0]
neg_id = _label_pieces["negative"][0]
neu_id = _label_pieces["neutral"][0]
if len({pos_id, neg_id, neu_id}) != 3:
    raise ValueError(
        f"Label words share a first token (pos={pos_id}, neu={neu_id}, neg={neg_id}); "
        "next-token scoring cannot distinguish them."
    )
label_token_ids = {0: neg_id, 1: neu_id, 2: pos_id}
print(f"Label tokens: positive={pos_id}, neutral={neu_id}, negative={neg_id}")


def evaluate_accuracy(model, test_data, label):
    """Measure 3-class sentiment accuracy by next-token scoring.

    Each row is turned into the prompt "Sentence: <text>\\nSentiment:". The
    prediction is the class whose token in ``label_token_ids`` gets the highest
    next-token logit.

    Args:
        model: causal LM (plain or PEFT-wrapped); it is switched to eval mode.
        test_data: iterable of rows with a "text" string and an integer
            "labels" field (0=negative, 1=neutral, 2=positive).
        label: display name used in the printed summary.

    Returns:
        (accuracy, per_class_report), where the report is sklearn's
        ``classification_report`` dict keyed by class name.
    """
    model.eval()
    preds, golds = [], []

    for ex in test_data:
        prompt = f"Sentence: {ex['text']}\nSentiment:"
        # NOTE(review): right-side truncation would cut the trailing "Sentiment:"
        # cue from prompts longer than 128 tokens; confirm none exceed it.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=128).to(device)

        with torch.no_grad():
            logits = model(**inputs).logits[:, -1, :]  # logits for next token

        # Pick highest prob among our 3 label tokens
        scores = {lbl: logits[0, tid].item() for lbl, tid in label_token_ids.items()}
        pred = max(scores, key=scores.get)
        preds.append(pred)
        golds.append(ex["labels"])

    acc = accuracy_score(golds, preds)
    # Pin the label set so target_names always lines up, even if some class
    # never appears in golds or preds (e.g. a collapsed model).
    report = classification_report(
        golds, preds, labels=[0, 1, 2], target_names=["negative", "neutral", "positive"],
        output_dict=True, zero_division=0
    )
    print(f"  [{label}] Accuracy: {acc:.1%}")
    print(f"    Per-class F1: neg={report['negative']['f1-score']:.2f}  neu={report['neutral']['f1-score']:.2f}  pos={report['positive']['f1-score']:.2f}")
    return acc, report


print("Evaluation functions ready.")

In [None]:
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)

baseline_acc, baseline_report = evaluate_accuracy(base_model, test_ds, "Baseline (zero-shot)")

del base_model; gc.collect(); torch.cuda.empty_cache()

# An untrained LM rarely guesses uniformly; it often favours one label token.
# So the useful reference points are both uniform chance (1/3) and the accuracy
# of always predicting the test set's most common class.
majority_label, majority_count = Counter(test_ds["labels"]).most_common(1)[0]
majority_rate = majority_count / len(test_ds)
print(f"\nBaseline: {baseline_acc:.1%} (uniform chance = 33.3%; always "
      f"'{LABEL_MAP[majority_label]}' = {majority_rate:.1%})")

## Step 4: Start the Chorus Server

In [None]:
import requests, threading, socket

def find_free_port():
    """Return a TCP port on 127.0.0.1 that the OS reports as currently unused.

    Binding to port 0 lets the kernel choose a free port. The socket is closed
    before returning, so another process could still claim the port before the
    caller binds it.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
    return port

SERVER_PORT = find_free_port()
SERVER_URL = f"http://127.0.0.1:{SERVER_PORT}"

from chorus.server.app import configure, app
import uvicorn

configure(
    model_id=MODEL_NAME,
    data_dir="/content/chorus_data",
    strategy="fedex-lora",
    min_deltas=3,
)

server_error = []


def run_server():
    """Serve the Chorus app in the current (background) thread.

    Any failure is recorded in ``server_error`` so the polling loop below can
    report it immediately instead of waiting for the timeout.
    """
    try:
        uvicorn.run(app, host="127.0.0.1", port=SERVER_PORT, log_level="warning")
    except (Exception, SystemExit) as e:
        # uvicorn aborts a failed startup (e.g. port cannot be bound) with
        # sys.exit(). SystemExit is not an Exception subclass, so it must be
        # caught explicitly or the thread dies without recording anything.
        server_error.append(f"{type(e).__name__}: {e}")


threading.Thread(target=run_server, daemon=True).start()

print(f"Starting Chorus server on port {SERVER_PORT}...")
server_ready = False
for _ in range(15):
    time.sleep(1)
    if server_error:
        print(f"ERROR: {server_error[0]}")
        break
    try:
        # Bounded timeout so a half-started server cannot hang the cell.
        r = requests.get(f"{SERVER_URL}/health", timeout=2)
    except requests.RequestException:
        continue
    if r.status_code == 200:
        print(f"Server running at {SERVER_URL}")
        print(f"  Strategy: {r.json()['strategy']}")
        server_ready = True
        break

if not server_ready:
    if not server_error:
        print("ERROR: Server failed to start")
    # Every later cell talks to the server; stop here instead of failing later
    # with a confusing connection error.
    reason = server_error[0] if server_error else "no healthy /health response within 15s"
    raise RuntimeError(f"Chorus server at {SERVER_URL} did not start: {reason}")

## Step 5: Train 3 Clients → Submit → Aggregate

Each client trains LoRA on its own non-IID partition. Partition sizes differ; see the counts printed by the data cell. Training takes roughly 2–4 min per client on a T4.

In [None]:
from chorus.client.trainer import LoRATrainer
from chorus.client.sdk import ChorusClient
from peft import PeftModel

NUM_CLIENTS = 3
client_accs = []
client_reports = []

for i in range(NUM_CLIENTS):
    t0 = time.time()
    print(f"\n{'='*60}")
    print(f"CLIENT {i}: {len(client_datasets[i])} examples")
    print(f"{'='*60}")

    output_dir = f"/content/adapter_client_{i}"

    trainer = LoRATrainer(
        base_model=MODEL_NAME,
        dataset=f"/content/client_{i}.json",
        output_dir=output_dir,
        lora_rank=8,
        lora_alpha=16,
        learning_rate=3e-4,
        num_epochs=3,
        per_device_batch_size=8,
        gradient_accumulation_steps=2,
        max_seq_length=128,
        bf16=False,
        fp16=True,
        dataloader_pin_memory=False,
    )

    # NOTE(review): train()'s return value is kept, but the adapter is loaded and
    # submitted from output_dir below; confirm both refer to the same adapter.
    adapter_path = trainer.train()
    elapsed = time.time() - t0
    print(f"  Trained in {elapsed:.0f}s")
    # Drop the trainer (and any model/optimizer state it may hold) before loading
    # a fresh evaluation model, to keep peak GPU memory down.
    del trainer; gc.collect(); torch.cuda.empty_cache()

    # Evaluate this client alone
    single_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
    single_model = PeftModel.from_pretrained(single_model, output_dir)
    acc_i, report_i = evaluate_accuracy(single_model, test_ds, f"Client {i}")
    client_accs.append(acc_i)
    client_reports.append(report_i)
    del single_model; gc.collect(); torch.cuda.empty_cache()

    # Submit to Chorus; always close the client, even if the upload fails.
    client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME, client_id=f"client-{i}")
    try:
        result = client.submit_delta(adapter_path=output_dir, dataset_size=len(client_datasets[i]))
    finally:
        client.close()

    print(f"  Submitted: {result['deltas_received']}/{result['min_deltas']}")
    if result['aggregated']:
        print("  >>> FedEx-LoRA AGGREGATION TRIGGERED! <<<")

print("\nAll clients done.")

## Step 6: Pull Federated Adapter → Evaluate

In [None]:
from safetensors.torch import load_file
from peft import LoraConfig, get_peft_model

# Pull aggregated adapter; close the client even if the pull fails.
client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME)
try:
    agg_path = client.pull_latest(output_path="/content/adapter_federated")
finally:
    client.close()

agg_tensors = load_file(str(agg_path))
ranks = {v.shape[0] for k, v in agg_tensors.items() if "lora_A" in k}
if not ranks:
    raise RuntimeError(
        f"No lora_A tensors found in {agg_path}; first keys: {list(agg_tensors)[:5]}"
    )
agg_rank = max(ranks)
print(f"Aggregated adapter: {len(agg_tensors)} tensors, rank {agg_rank}")
print(f"File size: {os.path.getsize(agg_path) / 1e6:.2f} MB")

# Load into model
fed_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
# NOTE(review): lora_alpha == r gives a LoRA scaling of 1, whereas the clients
# trained with lora_alpha=16, r=8 (scaling 2). That is only equivalent if the
# server folds the client scaling into the aggregated A/B factors; confirm in Chorus.
# NOTE(review): target_modules is assumed to match what LoRATrainer adapted.
lora_config = LoraConfig(
    r=agg_rank, lora_alpha=agg_rank,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0, bias="none", task_type="CAUSAL_LM",
)
fed_model = get_peft_model(fed_model, lora_config)

state_dict = fed_model.state_dict()
loaded = 0
unmatched = []
for k, v in agg_tensors.items():
    # PEFT model state-dict keys carry the "base_model.model." wrapper prefix and
    # the adapter name ("default"). Accept server keys with or without the prefix.
    key = k if k.startswith("base_model.model.") else "base_model.model." + k
    peft_key = (
        key.replace(".lora_A.weight", ".lora_A.default.weight")
           .replace(".lora_B.weight", ".lora_B.default.weight")
    )
    if peft_key in state_dict:
        state_dict[peft_key] = v.to(state_dict[peft_key].dtype).to(device)
        loaded += 1
    else:
        unmatched.append(k)
fed_model.load_state_dict(state_dict)
print(f"Loaded {loaded}/{len(agg_tensors)} tensors")

# A LoRA tensor that was not copied leaves that module at PEFT's fresh init, so
# the "federated" score would silently be measured on a partly-untrained adapter.
unmatched_lora = [k for k in unmatched if "lora_" in k]
if unmatched_lora:
    raise RuntimeError(
        f"{len(unmatched_lora)} aggregated LoRA tensors did not match the PEFT model "
        f"(e.g. {unmatched_lora[:3]}); check key naming / target_modules."
    )
if unmatched:
    print(f"WARNING: ignored {len(unmatched)} non-LoRA tensors, e.g. {unmatched[:3]}")

# Evaluate
fed_acc, fed_report = evaluate_accuracy(fed_model, test_ds, "Federated (FedEx-LoRA)")
del fed_model; gc.collect(); torch.cuda.empty_cache()

## Step 7: Results

In [None]:
print("\n" + "=" * 70)
print("RESULTS: Financial Sentiment Classification")
print("=" * 70)

# Accuracy table. Column widths must match print_row, which renders accuracy
# as `{acc:>9.1%}`; a 10-wide header shifted every following column by one.
print(f"\n{'Model':<40} {'Accuracy':>9} {'Neg F1':>8} {'Neu F1':>8} {'Pos F1':>8}")
print("-" * 76)

def print_row(label, acc, report):
    """Print one results-table row: model label, accuracy, then neg/neu/pos F1."""
    f1_cells = " ".join(
        f"{report[cls]['f1-score']:>8.2f}" for cls in ("negative", "neutral", "positive")
    )
    print(f"{label:<40} {acc:>9.1%} " + f1_cells)

print_row("Baseline (zero-shot, no training)", baseline_acc, baseline_report)
for i in range(NUM_CLIENTS):
    print_row(f"Client {i} alone", client_accs[i], client_reports[i])
print("-" * 76)
print_row(f"Federated (FedEx-LoRA, {NUM_CLIENTS} clients)", fed_acc, fed_report)

# Summary
avg_client_acc = sum(client_accs) / len(client_accs)
best_client_acc = max(client_accs)
best_client_idx = client_accs.index(best_client_acc)

print(f"\n" + "-" * 76)
print(f"\nBaseline accuracy:       {baseline_acc:.1%}")
print(f"Best single client:      {best_client_acc:.1%} (Client {best_client_idx})")
print(f"Average client:          {avg_client_acc:.1%}")
print(f"Federated:               {fed_acc:.1%}")
# `+` format spec prints the sign for both gains and losses; a hard-coded "+"
# rendered a drop below baseline as "+-x".
print(f"\nFederation vs baseline:     {(fed_acc - baseline_acc)*100:+.1f} percentage points")
print(f"Federation vs best client:  {(fed_acc - best_client_acc)*100:+.1f} percentage points")
print(f"Federation vs avg client:   {(fed_acc - avg_client_acc)*100:+.1f} percentage points")

if fed_acc > best_client_acc:
    print(f"\n>>> FEDERATION BEAT EVERY INDIVIDUAL CLIENT! Chorus works. <<<")
elif fed_acc > avg_client_acc:
    print(f"\n>>> Federation beat the average client but not the best. Partial success. <<<")
else:
    print(f"\n>>> Federation did not beat avg client. Needs more rounds or tuning. <<<")

## Step 8: Per-Class F1 Comparison

In [None]:
# Per-class F1 table: one row per model, one column per sentiment class.
classes = ["negative", "neutral", "positive"]


def _f1_line(name, rep):
    """Format a row: model name (25 wide) followed by each class's F1 (10 wide)."""
    cells = "".join(f" {rep[c]['f1-score']:>10.2f}" for c in classes)
    return f"{name:<25}{cells}"


print("\nPer-Class F1 Scores (where federation shines):")
print(f"\n{'':>25} {'Negative':>10} {'Neutral':>10} {'Positive':>10}")
print("-" * 58)
print(_f1_line("Baseline", baseline_report))
for i in range(NUM_CLIENTS):
    r = client_reports[i]
    print(_f1_line("Client " + str(i), r))
print("-" * 58)
print(_f1_line("Federated", fed_report))

print("\nKey insight: Each client is weak on classes they saw less of.")
print("The federated model should be strong across ALL classes.")

## Cleanup

In [None]:
# Closing summary. The server runs in a daemon thread, so it stops with the kernel.
for message in (
    "Done! All results above.",
    f"\nTotal test examples: {len(test_ds)}",
    "Chorus server will shut down when this session ends.",
):
    print(message)