# Chorus: End-to-End Federated LoRA Test

**What this tests:** Three clients each train a LoRA adapter on a private slice of the data and submit it to a Chorus server, which aggregates the adapters using FedEx-LoRA. We then compare the federated model's held-out loss against the untrained baseline and against each client trained alone.

| Setting | Value |
|---------|-------|
| Model | `Qwen/Qwen2.5-0.5B-Instruct` (490M params) |
| Dataset | `databricks/databricks-dolly-15k` (~5K per client) |
| LoRA rank | 16 |
| Clients | 3 (sequential training, shared GPU) |
| Aggregation | FedEx-LoRA (SVD-optimal) |
| Runtime | ~30-45 min on T4 |

## Step 0: Install Chorus

In [None]:
# Install Chorus with its `peft` extra directly from GitHub, then bitsandbytes.
!pip install -q 'chorus[peft] @ git+https://github.com/varmabudharaju/chorus.git'
!pip install -q bitsandbytes

In [None]:
import torch, gc, os, time, math, subprocess, logging

# Timestamped log lines tagged with the name of the emitting logger.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(name)s] %(message)s',
    datefmt='%H:%M:%S',
)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

if device != "cuda":
    print("WARNING: No GPU! Go to Runtime -> Change runtime type -> T4 GPU")
else:
    # Report the properties of GPU 0, the only device this notebook uses.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu_props.name}")
    print(f"VRAM: {gpu_props.total_memory / 1e9:.1f} GB")
    print(f"bf16 supported: {torch.cuda.is_bf16_supported()}")

## Step 1: Prepare Dataset — Split Dolly-15K Across 3 Clients

In [None]:
from datasets import load_dataset

# Load the train split of Dolly-15K and summarise it.
ds = load_dataset("databricks/databricks-dolly-15k", split="train")
print(f"Total examples: {len(ds)}")
print(f"Categories: {sorted(ds.unique('category'))}")

# Preview the first record, clipping each field to 100 characters.
first = ds[0]
print("\nSample:")
print(f"  instruction: {first['instruction'][:100]}")
print(f"  context:     {first['context'][:100]}")
print(f"  response:    {first['response'][:100]}")

# Shuffle once with a fixed seed so the split below is reproducible.
ds = ds.shuffle(seed=42)

# Hold out 20 examples for evaluation BEFORE partitioning, so that no
# evaluation example also appears in any client's training data. Selecting
# the eval rows from the same dataset that is sharded across clients would
# leak them into training and make every reported loss optimistic.
NUM_EVAL = 20
eval_ds = ds.select(range(NUM_EVAL))
train_ds = ds.select(range(NUM_EVAL, len(ds)))

# Split the remaining rows into 3 disjoint client partitions.
client_datasets = [train_ds.shard(num_shards=3, index=i) for i in range(3)]

# Save each client's data as JSON
for i, cds in enumerate(client_datasets):
    path = f"/content/client_{i}.json"
    cds.to_json(path)
    print(f"Client {i}: {len(cds)} examples -> {path}")

print(f"\nHeld out {len(eval_ds)} examples for evaluation")

## Step 2: Baseline — Measure Loss Before Any Training

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"

# Compute dtype used for every model load in this notebook:
#   - bfloat16 when a CUDA device reports bf16 support,
#   - float16 on any other CUDA device,
#   - float32 on CPU. Half precision is a GPU optimisation; the trainer
#     settings in Step 4 likewise turn fp16 off when CUDA is unavailable.
USE_BF16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
if USE_BF16:
    DTYPE = torch.bfloat16
elif torch.cuda.is_available():
    DTYPE = torch.float16
else:
    DTYPE = torch.float32

print(f"Loading {MODEL_NAME} (dtype={DTYPE})...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to the EOS token when the tokenizer defines no padding token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def measure_loss(model, examples, label):
    """Measure average cross-entropy loss on held-out examples.

    Each example is rendered as its instruction, context and response joined
    by newlines, tokenized (truncated to 512 tokens) and scored with the
    model's own loss, using the input ids as labels. Per-example losses are
    averaged with equal weight per example, and the result is printed.

    Args:
        model: Model that is callable with the tokenizer output plus a
            ``labels`` keyword and returns an object exposing ``.loss``.
            It is switched to eval mode.
        examples: Sized iterable of mappings with ``instruction`` and
            ``response`` keys and an optional ``context`` key.
        label: Name printed next to the result.

    Returns:
        Tuple ``(avg_loss, perplexity)``, where perplexity is
        ``exp(avg_loss)``.

    Raises:
        ValueError: If ``examples`` is empty.
    """
    num_examples = len(examples)
    if num_examples == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # after the model has already been touched.
        raise ValueError(f"[{label}] cannot measure loss on an empty example set")

    model.eval()
    total_loss = 0.0
    for ex in examples:
        text = f"{ex['instruction']}\n{ex.get('context', '')}\n{ex['response']}"
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to(device)
        with torch.no_grad():
            # Labels are the inputs themselves: the whole sequence is scored.
            loss = model(**inputs, labels=inputs["input_ids"]).loss
        total_loss += loss.item()
    avg_loss = total_loss / num_examples
    ppl = math.exp(avg_loss)
    print(f"  [{label}] Loss: {avg_loss:.4f} | Perplexity: {ppl:.2f}")
    return avg_loss, ppl


# Score the untouched base model first so later results have a reference point.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
base_model = base_model.to(device)
baseline_loss, baseline_ppl = measure_loss(base_model, eval_ds, "Baseline (no training)")

# Drop the model and release cached GPU memory before the next step.
del base_model
gc.collect()
torch.cuda.empty_cache()
print("Baseline measured. Model unloaded.")

## Step 3: Start the Chorus Server

In [None]:
import requests
import threading
import socket

# Find a free port automatically
def find_free_port():
    """Return a TCP port number the OS reports as free on 127.0.0.1.

    Binds a throwaway socket to port 0 so the OS picks the port, reads the
    chosen number back, and closes the socket again before returning.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        probe.bind(("127.0.0.1", 0))
        _host, port = probe.getsockname()
        return port
    finally:
        probe.close()

# Pick the port once and derive the base URL used for all requests below.
SERVER_PORT = find_free_port()
SERVER_URL = f"http://127.0.0.1:{SERVER_PORT}"
print(f"Using port: {SERVER_PORT}")

# Start Chorus server in a background thread
from chorus.server.app import configure, app
import uvicorn

# Configure the server before it is started. min_deltas=3 matches the three
# clients trained in Step 4 (presumably the number of submissions required
# before aggregation runs -- confirm against the Chorus documentation).
configure(
    model_id=MODEL_NAME,
    data_dir="/content/chorus_data",
    strategy="fedex-lora",
    min_deltas=3,
)

# Filled in by run_server() when the server fails; polled by the startup loop.
server_error = []

def run_server():
    """Run the Chorus app under uvicorn, recording any failure in ``server_error``.

    Meant to be used as a thread target: nothing is returned and failures are
    not re-raised, because an exception escaping a background thread would
    never reach the notebook cell that is waiting for the server to come up.
    """
    try:
        uvicorn.run(app, host="127.0.0.1", port=SERVER_PORT, log_level="warning")
    except SystemExit as e:
        # uvicorn reports startup failures (for example a port that is
        # already in use) through sys.exit(). SystemExit does not derive
        # from Exception, so without this handler the failure would end the
        # thread silently and never be recorded.
        server_error.append(f"uvicorn exited with code {e.code}")
    except Exception as e:
        server_error.append(str(e))

# Daemon thread, so it does not keep the process alive once the notebook stops.
server_thread = threading.Thread(target=run_server, daemon=True)
server_thread.start()

print("Starting Chorus server...")
# Poll /health roughly once per second, for at most 15 attempts.
for i in range(15):
    time.sleep(1)
    if server_error or not server_thread.is_alive():
        # Either run_server() recorded a failure, or the thread ended without
        # recording one; in both cases further polling cannot succeed.
        reason = server_error[0] if server_error else "server thread exited unexpectedly"
        print(f"ERROR: Server crashed: {reason}")
        break
    try:
        # Bounded wait: a server that accepts the connection but never
        # answers must not hang the cell forever.
        r = requests.get(f"{SERVER_URL}/health", timeout=2)
    except (requests.ConnectionError, requests.Timeout):
        print(f"  Waiting... ({i+1}s)")
        continue
    if r.status_code == 200:
        health = r.json()
        print(f"Server is running at {SERVER_URL}")
        print(f"  Model:    {health['model_id']}")
        print(f"  Strategy: {health['strategy']}")
        break
else:
    # Only reached when every attempt finished without a break.
    if not server_error:
        print("ERROR: Server failed to start after 15 seconds")

## Step 4: Train 3 Clients and Submit Deltas

Each client:
1. Loads Qwen2.5-0.5B
2. Trains LoRA (rank 16) on its ~5K examples
3. Submits the delta to the Chorus server
4. Unloads the model to free VRAM for the next client

The 3rd submission triggers FedEx-LoRA aggregation automatically.

In [None]:
from chorus.client.trainer import LoRATrainer
from chorus.client.sdk import ChorusClient

# SERVER_URL is set in Step 3 (dynamic port)
from peft import PeftModel  # imported once here rather than on every loop iteration

NUM_CLIENTS = 3
client_losses = []  # one (loss, perplexity) tuple per client, in client order

for i in range(NUM_CLIENTS):
    print(f"\n{'='*60}")
    print(f"CLIENT {i}: Training on {len(client_datasets[i])} examples")
    print(f"{'='*60}")

    output_dir = f"/content/adapter_client_{i}"

    # Train
    trainer = LoRATrainer(
        base_model=MODEL_NAME,
        dataset=f"/content/client_{i}.json",
        output_dir=output_dir,
        lora_rank=16,
        lora_alpha=32,
        learning_rate=2e-4,
        num_epochs=1,
        per_device_batch_size=4,
        gradient_accumulation_steps=4,
        max_seq_length=512,
        bf16=USE_BF16,
        # Use fp16 only on a GPU that lacks bf16; never on CPU.
        fp16=not USE_BF16 and torch.cuda.is_available(),
        dataloader_pin_memory=False,
    )

    adapter_path = trainer.train()
    print(f"  Adapter saved to: {adapter_path}")

    # Release the trainer before loading another copy of the base model, in
    # case it still references the trained model (NOTE(review): whether
    # LoRATrainer keeps the model alive is not visible here -- this is a
    # harmless no-op if it does not).
    del trainer
    gc.collect()
    torch.cuda.empty_cache()

    # Measure this single client's loss
    single_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)
    single_model = PeftModel.from_pretrained(single_model, output_dir)
    loss_i, ppl_i = measure_loss(single_model, eval_ds, f"Client {i} alone")
    client_losses.append((loss_i, ppl_i))
    del single_model
    gc.collect()
    torch.cuda.empty_cache()

    # Submit to Chorus server
    client = ChorusClient(
        server=SERVER_URL,
        model_id=MODEL_NAME,
        client_id=f"client-{i}",
    )
    try:
        result = client.submit_delta(
            adapter_path=output_dir,
            dataset_size=len(client_datasets[i]),
        )
    finally:
        # Always close the client, even when the submission fails.
        client.close()

    print(f"  Submitted! Deltas: {result['deltas_received']}/{result['min_deltas']}")
    if result['aggregated']:
        print(f"  >>> FedEx-LoRA AGGREGATION TRIGGERED! <<<")

print(f"\nAll {NUM_CLIENTS} clients done.")

## Step 5: Pull Aggregated Adapter and Evaluate

In [None]:
from safetensors.torch import load_file
from peft import LoraConfig, get_peft_model

# Pull the aggregated adapter from the server
client = ChorusClient(server=SERVER_URL, model_id=MODEL_NAME)
try:
    agg_path = client.pull_latest(output_path="/content/adapter_federated")
finally:
    # Always close the client, even when the download fails.
    client.close()

print(f"Aggregated adapter: {agg_path}")
print(f"File size: {os.path.getsize(agg_path) / 1e6:.2f} MB")

# Inspect tensors: show the first four names (alphabetically) with their shapes.
agg_tensors = load_file(str(agg_path))
print(f"Tensors: {len(agg_tensors)}")
# Sort the names only, so tensors are never involved in the comparison.
for k in sorted(agg_tensors)[:4]:
    print(f"  {k}: {agg_tensors[k].shape}")
print("  ...")

In [None]:
# Load base model + inject aggregated adapter
fed_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device)

# Apply a LoRA config matching the aggregated adapter's rank
# (the first dimension of each lora_A tensor).
ranks = {v.shape[0] for k, v in agg_tensors.items() if "lora_A" in k}
if not ranks:
    raise RuntimeError("Aggregated adapter contains no lora_A tensors; cannot infer its rank")
agg_rank = max(ranks)
print(f"Aggregated adapter rank: {agg_rank}")

# NOTE(review): lora_alpha == r gives a LoRA scale of 1.0, whereas the clients
# were trained with lora_alpha=32 / lora_rank=16 (scale 2.0). This is only
# correct if the server folds the client scaling into the aggregated weights
# -- confirm against the Chorus FedEx-LoRA implementation.
# NOTE(review): target_modules is hard-coded and must match the modules the
# clients adapted; the loaded-tensor check below catches a mismatch.
lora_config = LoraConfig(
    r=agg_rank, lora_alpha=agg_rank,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.0, bias="none", task_type="CAUSAL_LM",
)
fed_model = get_peft_model(fed_model, lora_config)

# Load aggregated weights into the PEFT model. The saved names lack the
# "base_model.model." prefix and the ".default" adapter segment that the
# PEFT-wrapped model's state dict uses, so each key is rewritten first.
state_dict = fed_model.state_dict()
loaded = 0
for k, v in agg_tensors.items():
    peft_key = (
        "base_model.model." +
        k.replace(".lora_A.weight", ".lora_A.default.weight")
         .replace(".lora_B.weight", ".lora_B.default.weight")
    )
    if peft_key in state_dict:
        state_dict[peft_key] = v.to(state_dict[peft_key].dtype).to(device)
        loaded += 1
fed_model.load_state_dict(state_dict)
print(f"Loaded {loaded}/{len(agg_tensors)} tensors into model")

if loaded == 0:
    # Nothing matched: the model still holds freshly initialised LoRA weights,
    # so evaluating it would silently report a meaningless "federated" score.
    raise RuntimeError(
        "No aggregated tensors matched the PEFT model's parameter names; "
        "check target_modules and the adapter key format"
    )
if loaded < len(agg_tensors):
    print(f"WARNING: {len(agg_tensors) - loaded} aggregated tensors were not loaded")

# Evaluate
fed_loss, fed_ppl = measure_loss(fed_model, eval_ds, "Federated (FedEx-LoRA)")
del fed_model
gc.collect()
torch.cuda.empty_cache()

## Step 6: Results

In [None]:
# Results table: baseline, each client alone, then the federated model.
print("\n" + "=" * 65)
print("RESULTS")
print("=" * 65)
print(f"\n{'Model':<35} {'Loss':>8} {'Perplexity':>12}")
print("-" * 58)

print(f"{'Baseline (no training)':<35} {baseline_loss:>8.4f} {baseline_ppl:>12.2f}")
for i, (loss, ppl) in enumerate(client_losses):
    print(f"{'Client ' + str(i) + ' alone':<35} {loss:>8.4f} {ppl:>12.2f}")
print(f"{'Federated (FedEx-LoRA, 3 clients)':<35} {fed_loss:>8.4f} {fed_ppl:>12.2f}")

print("\n" + "-" * 58)
client_only_losses = [loss for loss, _ in client_losses]
avg_client_loss = sum(client_only_losses) / len(client_only_losses)
best_client_loss = min(client_only_losses)
# Positive percentages mean the federated model has the lower loss.
print(f"\nBaseline -> Federated:  {(baseline_loss - fed_loss) / baseline_loss * 100:+.1f}% loss reduction")
print(f"Avg client -> Federated: {(avg_client_loss - fed_loss) / avg_client_loss * 100:+.1f}% loss reduction")
print(f"Best client -> Federated: {(best_client_loss - fed_loss) / best_client_loss * 100:+.1f}% loss reduction")

# Claim an improvement over the individual clients only when the federated
# model beats the best of them; beating just the average is reported
# separately so the verdict does not overstate the result.
if fed_loss < best_client_loss:
    print("\n>>> Federation improved over individual clients! Chorus works. <<<")
elif fed_loss < avg_client_loss:
    print("\n>>> Federation beat the average client but not the best one. May need more rounds or tuning. <<<")
else:
    print("\n>>> Federation did not beat avg client. May need more rounds or tuning. <<<")

## Step 7: Test Generation Quality

In [None]:
# Quick generation comparison: base vs federated.
# Three fixed prompts; each one is answered by both models in the comparison
# loop further down in this cell.
test_prompts = [
    "Explain the difference between machine learning and deep learning.",
    "What are the main causes of climate change?",
    "Summarize the plot of Romeo and Juliet in 3 sentences.",
]


def generate(model, prompt):
    """Return the model's completion for ``prompt`` as stripped text.

    The prompt is tokenized as-is and up to 150 new tokens are generated with
    sampling disabled. Only the tokens after the prompt are decoded, with
    special tokens skipped.
    """
    encoded = tokenizer(prompt, return_tensors="pt").to(device)
    prompt_len = encoded["input_ids"].shape[1]
    with torch.no_grad():
        sequences = model.generate(
            **encoded,
            max_new_tokens=150,
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Slice off the echoed prompt so only newly generated tokens are decoded.
    completion_ids = sequences[0][prompt_len:]
    completion = tokenizer.decode(completion_ids, skip_special_tokens=True)
    return completion.strip()


# Load both models for comparison
print("Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
base_model = base_model.to(device)

print("Loading federated model...")
fed_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
fed_model = get_peft_model(fed_model.to(device), lora_config)

# Copy the aggregated tensors into the PEFT model, rewriting each saved name
# to the prefixed, ".default"-qualified name used by the model's state dict.
state_dict = fed_model.state_dict()
for k, v in agg_tensors.items():
    renamed = k.replace(".lora_A.weight", ".lora_A.default.weight")
    renamed = renamed.replace(".lora_B.weight", ".lora_B.default.weight")
    peft_key = "base_model.model." + renamed
    if peft_key not in state_dict:
        continue
    target_dtype = state_dict[peft_key].dtype
    state_dict[peft_key] = v.to(target_dtype).to(device)
fed_model.load_state_dict(state_dict)

# Print both answers per prompt, each clipped to 300 characters.
print("\n" + "=" * 65)
for prompt in test_prompts:
    print(f"\nQ: {prompt}")
    base_answer = generate(base_model, prompt)
    print(f"\n  [Base]:      {base_answer[:300]}")
    fed_answer = generate(fed_model, prompt)
    print(f"\n  [Federated]: {fed_answer[:300]}")
    print("-" * 65)

# Release both models and any cached GPU memory.
del base_model, fed_model
gc.collect()
torch.cuda.empty_cache()

## Cleanup

In [None]:
# Nothing to tear down explicitly: the server thread was started with
# daemon=True in Step 3, so it stops when the notebook process stops.
print("Done! The Chorus server will shut down when this Colab session ends.")
print("All results are above.")