# Chorus: Federated LoRA on Medical Data

This notebook demonstrates **federated fine-tuning** of TinyLlama on medical Q&A data using the actual Chorus framework — server, clients, and all.

**Scenario:** Two hospitals each have private medical flashcard data. Neither can share its data. Instead, each trains a LoRA adapter locally and submits it to a Chorus server, which aggregates the adapters using **FedExLoRA** (exact aggregation with heterogeneous rank support).

We compare:
1. **Base model** (no fine-tuning) — how TinyLlama answers medical questions out of the box
2. **Single hospital** (trained on only half the data)
3. **Federated model** (both hospitals' adapters merged by Chorus server)

**Runtime:** ~15 min on Colab T4 GPU (free tier)

**Architecture being tested:**
```
Hospital A (rank 8)  ──submit──►  Chorus Server  ◄──submit──  Hospital B (rank 16)
                                       │
                                  FedExLoRA aggregation
                                  (exact, heterogeneous)
                                       │
                              Aggregated adapter (rank 16)
                                       │
                                  ◄──pull──
```
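
Each client's update is the product `B @ A`, whose shape does not depend on the LoRA rank, so adapters of different ranks can be combined in that product space. The sketch below illustrates this with a rank-1 and a rank-2 adapter on a toy 2×2 weight. It is plain Python for illustration only, not Chorus's implementation; `matmul` and `weighted_delta_average` are hypothetical helpers, and weighting by dataset size is an assumption.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def weighted_delta_average(adapters, sizes):
    """Average the full updates B @ A, weighted by dataset size (assumed weighting)."""
    total = sum(sizes)
    deltas = [matmul(B, A) for B, A in adapters]
    rows, cols = len(deltas[0]), len(deltas[0][0])
    return [
        [sum(n / total * d[i][j] for d, n in zip(deltas, sizes)) for j in range(cols)]
        for i in range(rows)
    ]


# Hospital A: rank 1 (B is 2x1, A is 1x2)
adapter_a = ([[1.0], [0.0]], [[2.0, 0.0]])
# Hospital B: rank 2 (B is 2x2, A is 2x2)
adapter_b = ([[1.0, 0.0], [0.0, 1.0]], [[0.0, 0.0], [0.0, 4.0]])

avg = weighted_delta_average([adapter_a, adapter_b], sizes=[100, 100])
print(avg)  # prints [[1.0, 0.0], [0.0, 2.0]]
```

Note that the average of a rank-8 and a rank-16 update can have rank up to 24, so how the result is stored in a rank-16 adapter depends on the server's aggregation strategy.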

## 0. Install Chorus + Dependencies

In [None]:
# Clone Chorus and install with PEFT training support
!git clone https://github.com/varmabudharaju/chorus.git /content/chorus 2>/dev/null || echo 'Already cloned'
!pip install -q -e '/content/chorus[peft]'
!pip install -q bitsandbytes  # For efficient GPU training

In [None]:
import torch
import gc
import os
import math
import time
import subprocess
import signal

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
if device == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("Go to Runtime → Change runtime type → T4 GPU")

## 1. Prepare Medical Datasets for Two Hospitals

We use [medalpaca/medical_meadow_medical_flashcards](https://huggingface.co/datasets/medalpaca/medical_meadow_medical_flashcards) — 33K medical Q&A pairs from clinical flashcards.

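We shuffle it and split it in half to simulate two hospitals. Because the split is random, both halves follow the same distribution; real hospitals would typically see **different patient populations**.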
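
A random 50/50 split gives both hospitals statistically similar data. If a more skewed (non-IID) setup is wanted, one option is to partition by topic. The sketch below is a hypothetical helper operating on plain dictionaries with the dataset's `input` and `output` fields; the keyword list and toy examples are made up for illustration and are not used elsewhere in this notebook.

```python
def split_by_keyword(examples, keywords):
    """Send examples mentioning any keyword to hospital A, the rest to hospital B."""
    hospital_a, hospital_b = [], []
    for ex in examples:
        text = (ex["input"] + " " + ex["output"]).lower()
        target = hospital_a if any(k in text for k in keywords) else hospital_b
        target.append(ex)
    return hospital_a, hospital_b


# Toy examples (invented for this sketch, not taken from the dataset)
toy = [
    {"input": "What causes heart failure?", "output": "Reduced cardiac output."},
    {"input": "What is insulin?", "output": "A hormone that lowers blood glucose."},
    {"input": "Define cardiac tamponade.", "output": "Fluid compressing the heart."},
]

cardio, other = split_by_keyword(toy, keywords=["heart", "cardiac"])
print(len(cardio), len(other))  # prints 2 1
```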

In [None]:
from datasets import load_dataset

ds = load_dataset("medalpaca/medical_meadow_medical_flashcards", split="train")
print(f"Total medical flashcards: {len(ds)}")
print(f"\nSample:")
print(f"  Instruction: {ds[0]['instruction']}")
print(f"  Question:    {ds[0]['input'][:120]}")
print(f"  Answer:      {ds[0]['output'][:120]}")

# Shuffle and split between two hospitals
ds = ds.shuffle(seed=42)
split = len(ds) // 2
hospital_a_data = ds.select(range(split))
hospital_b_data = ds.select(range(split, len(ds)))

# Hold out eval examples (10 from each hospital's data) and drop them from the
# training splits, so evaluation never sees examples the adapters trained on
eval_idx = set(range(0, 50, 5))
eval_examples = [
    {"q": part[i]["input"], "a": part[i]["output"]}
    for part in (hospital_a_data, hospital_b_data)
    for i in sorted(eval_idx)
]
hospital_a_data = hospital_a_data.select(
    [i for i in range(len(hospital_a_data)) if i not in eval_idx]
)
hospital_b_data = hospital_b_data.select(
    [i for i in range(len(hospital_b_data)) if i not in eval_idx]
)
print(f"Held out {len(eval_examples)} examples for evaluation")

# Save as JSON for Chorus LoRATrainer (it accepts HF dataset names or local files)
hospital_a_data.to_json("/content/hospital_a.json")
hospital_b_data.to_json("/content/hospital_b.json")

print(f"\nHospital A: {len(hospital_a_data)} examples → /content/hospital_a.json")
print(f"Hospital B: {len(hospital_b_data)} examples → /content/hospital_b.json")

## 2. Evaluation Setup

We measure **medical QA perplexity**, the exponential of the model's average loss on question/answer pairs (lower = the model assigns higher probability to the reference text), and compare generated answers qualitatively.
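
As a worked example of the metric, perplexity is the exponential of the mean cross-entropy loss. The sketch below uses made-up loss values and a hypothetical helper name; it mirrors the notebook's approach of averaging one loss per example, which weights short and long examples equally.

```python
import math


def perplexity(per_example_losses):
    """exp of the mean loss, with one loss value per evaluation example."""
    return math.exp(sum(per_example_losses) / len(per_example_losses))


losses = [2.0, 2.0, 2.0]  # illustrative values, not measured results
ppl = perplexity(losses)
print(f"{ppl:.2f}")  # prints 7.39
```

For intuition, a model that guesses uniformly over a vocabulary of V tokens has loss ln V and therefore perplexity V.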

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model, TaskType
from safetensors.torch import load_file

MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

MEDICAL_QUESTIONS = [
    "What are the classic symptoms of diabetes mellitus?",
    "What is the first-line treatment for hypertension?",
    "What does an elevated troponin level indicate?",
    "What is the difference between Type 1 and Type 2 diabetes?",
    "What are the side effects of metformin?",
]


def format_prompt(question):
    return (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
        f"### Instruction:\nAnswer this medical question truthfully: {question}\n\n"
        "### Response:\n"
    )


def evaluate_model(model, label, eval_data=None, gen_questions=None):
    """Evaluate a model: perplexity on medical QA + qualitative generation."""
    model.eval()
    results = {"label": label}

    if eval_data:
        total_loss = 0.0
        for ex in eval_data:
            text = f"Question: {ex['q']}\nAnswer: {ex['a']}"
            inputs = tokenizer(
                text, return_tensors="pt", truncation=True, max_length=256
            ).to(device)
            with torch.no_grad():
                loss = model(**inputs, labels=inputs["input_ids"]).loss
            total_loss += loss.item()
        results["perplexity"] = math.exp(total_loss / len(eval_data))
        print(f"  [{label}] Medical QA perplexity: {results['perplexity']:.2f}")

    if gen_questions:
        results["generations"] = []
        for q in gen_questions:
            prompt = format_prompt(q)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            with torch.no_grad():
                out = model.generate(
                    **inputs, max_new_tokens=150, do_sample=False,
                    pad_token_id=tokenizer.eos_token_id,
                )
            answer = tokenizer.decode(
                out[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True
            ).strip()
            results["generations"].append({"q": q, "a": answer})

    return results


def load_adapter_into_model(adapter_tensors, rank):
    """Load base model + inject aggregated adapter tensors from Chorus."""
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, torch_dtype=torch.float16
    ).to(device)
    config = LoraConfig(
        r=rank, lora_alpha=rank, target_modules=["q_proj", "v_proj"],
        lora_dropout=0.0, bias="none", task_type=TaskType.CAUSAL_LM,
    )
    model = get_peft_model(model, config)

    state_dict = model.state_dict()
    loaded = 0
    for k, v in adapter_tensors.items():
        peft_key = (
            "base_model.model." +
            k.replace(".lora_A.weight", ".lora_A.default.weight")
             .replace(".lora_B.weight", ".lora_B.default.weight")
        )
        if peft_key in state_dict:
            state_dict[peft_key] = v.to(state_dict[peft_key].dtype).to(device)
            loaded += 1
    model.load_state_dict(state_dict)
    print(f"  Loaded {loaded}/{len(adapter_tensors)} tensors into model")
    return model


print("Evaluation functions ready.")

## 3. Baseline: Evaluate Base Model (No Medical Training)

In [None]:
print("=" * 60)
print("BASELINE: TinyLlama-Chat (no medical fine-tuning)")
print("=" * 60)

base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16
).to(device)

baseline_results = evaluate_model(
    base_model, "Base Model",
    eval_data=eval_examples,
    gen_questions=MEDICAL_QUESTIONS,
)

print("\nSample answers from BASE model:")
for g in baseline_results["generations"][:3]:
    print(f"  Q: {g['q']}")
    print(f"  A: {g['a'][:200]}")
    print()

del base_model
gc.collect()
torch.cuda.empty_cache()

## 4. Start the Chorus Server

We start the actual Chorus aggregation server in the background. It will:
- Listen for LoRA adapter submissions from clients
- Aggregate with **FedExLoRA** once 2 deltas arrive
- Support **heterogeneous ranks** (Hospital A = rank 8, Hospital B = rank 16)
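
Because the server starts in a background process, the notebook has to wait until it accepts connections. A minimal sketch of that polling pattern follows; `wait_until` is a hypothetical helper, and the fake probe stands in for a real health check so the example runs without a server.

```python
import time


def wait_until(probe, attempts=30, delay=1.0):
    """Call probe() until it returns True or the attempts run out."""
    for _ in range(attempts):
        try:
            if probe():
                return True
        except OSError:
            pass  # connection refused: the server is not listening yet
        time.sleep(delay)
    return False


# Fake probe: fails twice with a connection error, then succeeds
state = {"calls": 0}


def fake_probe():
    state["calls"] += 1
    if state["calls"] < 3:
        raise ConnectionError("not up yet")
    return True


print(wait_until(fake_probe, attempts=5, delay=0))  # prints True
```

With `requests`, the probe could be `lambda: requests.get("http://localhost:8080/health", timeout=2).status_code == 200`; `requests` exceptions derive from `OSError`, so the same `except` clause applies.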

In [None]:
import requests

# Start Chorus server in background; log to a file so that verbose output
# can never fill a pipe and block the server
server_log = open("/content/chorus_server.log", "w")
server_proc = subprocess.Popen(
    [
        "chorus", "server",
        "--model", MODEL_NAME,
        "--strategy", "fedex-lora",
        "--min-deltas", "2",
        "-v",
    ],
    stdout=server_log,
    stderr=subprocess.STDOUT,
    text=True,
)

# Wait for server to be ready
print("Starting Chorus server...")
ready = False
for _ in range(30):
    time.sleep(1)
    if server_proc.poll() is not None:  # process exited early
        break
    try:
        r = requests.get("http://localhost:8080/health", timeout=2)
        if r.status_code == 200:
            ready = True
            break
    except requests.RequestException:
        pass

if ready:
    print("Chorus server is running!")
    print(f"  Health: {r.json()}")
else:
    print("ERROR: Server failed to start. Check logs:")
    print(open("/content/chorus_server.log").read())

## 5. Hospital A: Train LoRA Adapter (Rank 8)

Using Chorus's `LoRATrainer` — the same component a real client would use.

Hospital A has a smaller compute budget → **rank 8** LoRA.
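
To make the compute trade-off concrete, the sketch below counts trainable LoRA parameters for one projection layer. The 2048×2048 layer shape is an assumption chosen for illustration, and `lora_param_count` is a hypothetical helper.

```python
def lora_param_count(rank, d_in, d_out):
    """A has shape (rank x d_in) and B has shape (d_out x rank)."""
    return rank * d_in + d_out * rank


# Assumed layer shape: a square 2048 x 2048 projection
r8 = lora_param_count(8, 2048, 2048)
r16 = lora_param_count(16, 2048, 2048)
full = 2048 * 2048

print(r8, r16, full)  # prints 32768 65536 4194304
```

Doubling the rank doubles the adapter's parameter count, while both remain a small fraction of the full weight matrix.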

In [None]:
from chorus.client.trainer import LoRATrainer

print("=" * 60)
print("HOSPITAL A: Training LoRA adapter (rank 8)")
print("=" * 60)

trainer_a = LoRATrainer(
    base_model=MODEL_NAME,
    dataset="/content/hospital_a.json",
    output_dir="/content/adapter_hospital_a",
    lora_rank=8,
    lora_alpha=8,
    learning_rate=2e-4,
    num_epochs=1,
    per_device_batch_size=8,
    gradient_accumulation_steps=2,
    max_seq_length=256,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    dataloader_pin_memory=False,
)

adapter_path_a = trainer_a.train()
print(f"\nHospital A adapter saved to: {adapter_path_a}")
print(f"Dataset size: {trainer_a.get_dataset_size()} examples")

In [None]:
# Evaluate Hospital A's adapter (single hospital, no federation)
from peft import PeftModel

print("Evaluating Hospital A's adapter (single hospital, no federation):")
model_a = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16
).to(device)
model_a = PeftModel.from_pretrained(model_a, "/content/adapter_hospital_a")
model_a.eval()

single_a_results = evaluate_model(
    model_a, "Hospital A only (rank 8)",
    eval_data=eval_examples,
    gen_questions=MEDICAL_QUESTIONS,
)

del model_a
gc.collect()
torch.cuda.empty_cache()

## 6. Hospital A Submits to Chorus Server

Using the actual `ChorusClient.submit_delta()` — sends the adapter over HTTP to the running Chorus server.

In [None]:
from chorus.client.sdk import ChorusClient

# Hospital A submits its adapter
client_a = ChorusClient(
    server="http://localhost:8080",
    model_id=MODEL_NAME,
    client_id="hospital-a",
)

result_a = client_a.submit_delta(
    adapter_path="/content/adapter_hospital_a",
    dataset_size=len(hospital_a_data),
)

print(f"Hospital A submitted!")
print(f"  Round: {result_a['round_id']}")
print(f"  Deltas received: {result_a['deltas_received']}/{result_a['min_deltas']}")
print(f"  Aggregated: {result_a['aggregated']}")
print(f"\n  (Waiting for Hospital B before aggregation triggers...)")

client_a.close()

## 7. Hospital B: Train LoRA Adapter (Rank 16)

Hospital B has more compute → **rank 16** LoRA.

This tests Chorus's **heterogeneous rank** support — the server must handle rank 8 + rank 16 adapters.
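
One way to see that mixed ranks are compatible: a low-rank adapter can be zero-padded to a higher rank without changing the update `B @ A`. The sketch below shows this on a toy rank-1 adapter padded to rank 2. Whether Chorus pads, or works directly on the products, is not something this notebook shows; `matmul` and `pad_rank` are hypothetical helpers.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def pad_rank(B, A, target_rank):
    """Zero-pad B (d_out x r) with columns and A (r x d_in) with rows."""
    extra = target_rank - len(A)
    B_pad = [row + [0.0] * extra for row in B]
    A_pad = A + [[0.0] * len(A[0]) for _ in range(extra)]
    return B_pad, A_pad


B = [[1.0], [2.0]]  # 2 x 1
A = [[3.0, 4.0]]    # 1 x 2
B_pad, A_pad = pad_rank(B, A, target_rank=2)

print(matmul(B, A))          # prints [[3.0, 4.0], [6.0, 8.0]]
print(matmul(B_pad, A_pad))  # prints [[3.0, 4.0], [6.0, 8.0]]
```

PEFT scales the update by `lora_alpha / r`; with `lora_alpha` equal to the rank, as in this notebook, that factor is 1 for both hospitals, so padding leaves the effective update unchanged.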

In [None]:
print("=" * 60)
print("HOSPITAL B: Training LoRA adapter (rank 16)")
print("=" * 60)

trainer_b = LoRATrainer(
    base_model=MODEL_NAME,
    dataset="/content/hospital_b.json",
    output_dir="/content/adapter_hospital_b",
    lora_rank=16,
    lora_alpha=16,
    learning_rate=2e-4,
    num_epochs=1,
    per_device_batch_size=8,
    gradient_accumulation_steps=2,
    max_seq_length=256,
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),
    dataloader_pin_memory=False,
)

adapter_path_b = trainer_b.train()
print(f"\nHospital B adapter saved to: {adapter_path_b}")
print(f"Dataset size: {trainer_b.get_dataset_size()} examples")

In [None]:
# Evaluate Hospital B's adapter (single hospital, no federation)
print("Evaluating Hospital B's adapter (single hospital, no federation):")
model_b = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, torch_dtype=torch.float16
).to(device)
model_b = PeftModel.from_pretrained(model_b, "/content/adapter_hospital_b")
model_b.eval()

single_b_results = evaluate_model(
    model_b, "Hospital B only (rank 16)",
    eval_data=eval_examples,
    gen_questions=MEDICAL_QUESTIONS,
)

del model_b
gc.collect()
torch.cuda.empty_cache()

## 8. Hospital B Submits → Triggers Aggregation!

With `min_deltas=2`, the server will **automatically aggregate** both adapters using FedExLoRA as soon as Hospital B submits.
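
The trigger logic can be pictured as a buffer that aggregates once enough deltas have arrived. The sketch below is a hypothetical stand-in, not the Chorus server's code: the `RoundBuffer` class, the round-increment behavior, and the overwrite-on-resubmit behavior are all assumptions. Only the response field names mirror what the notebook prints.

```python
class RoundBuffer:
    """Toy model of a server-side round that fires at min_deltas submissions."""

    def __init__(self, min_deltas):
        self.min_deltas = min_deltas
        self.round_id = 0
        self.pending = {}  # client_id -> delta

    def submit(self, client_id, delta):
        self.pending[client_id] = delta  # a resubmission overwrites (assumed)
        received = len(self.pending)
        aggregated = received >= self.min_deltas
        response = {
            "round_id": self.round_id,
            "deltas_received": received,
            "min_deltas": self.min_deltas,
            "aggregated": aggregated,
        }
        if aggregated:
            # A real server would aggregate here; this sketch only resets state
            self.pending = {}
            self.round_id += 1
        return response


buf = RoundBuffer(min_deltas=2)
first = buf.submit("hospital-a", delta="adapter-a")
second = buf.submit("hospital-b", delta="adapter-b")
print(first["aggregated"], second["aggregated"])  # prints False True
```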

In [None]:
# Hospital B submits — this triggers aggregation!
client_b = ChorusClient(
    server="http://localhost:8080",
    model_id=MODEL_NAME,
    client_id="hospital-b",
)

result_b = client_b.submit_delta(
    adapter_path="/content/adapter_hospital_b",
    dataset_size=len(hospital_b_data),
)

print(f"Hospital B submitted!")
print(f"  Round: {result_b['round_id']}")
print(f"  Deltas received: {result_b['deltas_received']}/{result_b['min_deltas']}")
print(f"  Aggregated: {result_b['aggregated']}")

if result_b['aggregated']:
    print(f"\n  FedExLoRA aggregation complete!")
    print(f"  The server merged rank-8 + rank-16 adapters using SVD-optimal aggregation.")

## 9. Pull Aggregated Adapter from Chorus Server

Now we pull the federated model — the merged adapter that contains knowledge from **both** hospitals.

In [None]:
# Pull the aggregated adapter
agg_path = client_b.pull_latest(output_path="/content/adapter_federated")
client_b.close()

print(f"Aggregated adapter pulled to: {agg_path}")
print(f"File size: {os.path.getsize(agg_path) / 1e6:.1f} MB")

# Inspect the adapter
agg_tensors = load_file(str(agg_path))
ranks = set(v.shape[0] for k, v in agg_tensors.items() if "lora_A" in k)
print(f"Tensors: {len(agg_tensors)}")
print(f"Output rank: {ranks} (max of input ranks 8 and 16)")

# Check server status
r = requests.get("http://localhost:8080/health")
print(f"\nServer health: {r.json()}")

## 10. Evaluate the Federated Model

In [None]:
print("=" * 60)
print("FEDERATED MODEL: FedExLoRA aggregation of both hospitals")
print("=" * 60)

fed_model = load_adapter_into_model(agg_tensors, rank=max(ranks))

federated_results = evaluate_model(
    fed_model, "Federated (FedExLoRA)",
    eval_data=eval_examples,
    gen_questions=MEDICAL_QUESTIONS,
)

del fed_model
gc.collect()
torch.cuda.empty_cache()

## 11. Final Comparison

The moment of truth — did federated learning actually help?
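
The comparison reports each model's perplexity change relative to the base model. As a quick worked example of that formula, with made-up numbers rather than results from this notebook (`improvement_pct` is a hypothetical helper):

```python
def improvement_pct(base_ppl, ppl):
    """Relative perplexity reduction versus the base model, in percent."""
    return (base_ppl - ppl) / base_ppl * 100


# Illustrative values only
print(f"{improvement_pct(20.0, 15.0):+.1f}%")  # prints +25.0%
print(f"{improvement_pct(20.0, 25.0):+.1f}%")  # prints -25.0%
```

A positive value means lower perplexity than the base model; a negative value means the adapter made things worse on this evaluation set.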

In [None]:
print("\n" + "=" * 70)
print("RESULTS: Medical QA Perplexity (lower = better)")
print("=" * 70 + "\n")

all_results = [
    baseline_results,
    single_a_results,
    single_b_results,
    federated_results,
]

# Find best and worst perplexity for highlighting and bar scaling
best_ppl = min(r["perplexity"] for r in all_results)
worst_ppl = max(r["perplexity"] for r in all_results)

for r in all_results:
    ppl = r["perplexity"]
    bar_len = int(ppl / worst_ppl * 30)
    bar = "█" * bar_len
    marker = " ◄ BEST" if ppl == best_ppl else ""
    imp = ""
    if r["label"] != "Base Model":
        pct = (baseline_results["perplexity"] - ppl) / baseline_results["perplexity"] * 100
        imp = f"  ({pct:+.1f}% vs base)"
    print(f"  {r['label']:30s}  {ppl:8.2f}  {bar}{imp}{marker}")

print("\n" + "=" * 70)
print("GENERATION COMPARISON")
print("=" * 70)

for i, q in enumerate(MEDICAL_QUESTIONS):
    print(f"\n{'─' * 70}")
    print(f"Q: {q}")
    print(f"{'─' * 70}")
    for r in all_results:
        if "generations" in r:
            ans = r["generations"][i]["a"][:250]
            label = r["label"]
            print(f"\n  [{label}]:")
            print(f"  {ans}")

In [None]:
# Summary
print("\n" + "=" * 70)
print("SUMMARY")
print("=" * 70)
print(f"""
  Perplexity (lower = better):
    Base model (no training):       {baseline_results['perplexity']:.2f}
    Hospital A only (rank 8):       {single_a_results['perplexity']:.2f}
    Hospital B only (rank 16):      {single_b_results['perplexity']:.2f}
    Federated — FedExLoRA (A + B):  {federated_results['perplexity']:.2f}

  What happened:
    1. Two hospitals trained LoRA adapters on their PRIVATE medical data
    2. They submitted only the small adapters (~4-8 MB) to the Chorus server
    3. The server aggregated them using FedExLoRA (exact, heterogeneous ranks)
    4. The federated model learned from BOTH hospitals' data
    5. Neither hospital ever shared a single patient record

  Technical details:
    - Hospital A: rank 8, {len(hospital_a_data)} examples
    - Hospital B: rank 16, {len(hospital_b_data)} examples
    - Aggregated adapter: rank {max(ranks)}, {len(agg_tensors)} tensors
    - Aggregation: FedExLoRA (SVD-optimal, mathematically exact)
    - Server: Chorus v0.1.0, strategy=fedex-lora
""")

In [None]:
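# Cleanup: stop the Chorus server
server_proc.terminate()
try:
    server_proc.wait(timeout=5)
except subprocess.TimeoutExpired:
    server_proc.kill()  # force-stop if the server ignores SIGTERM
    server_proc.wait()
print("Chorus server stopped.")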