# Chorus: Centralized Baseline + Zero-Shot

**Run this notebook in parallel with `chorus_centralized_vs_federated.ipynb`.**

This notebook trains the centralized (gold standard) model on ALL 6K training examples
and evaluates both the zero-shot baseline and the centralized model. Results are saved to
`/content/centralized_results.json` for the federated notebook to load.

| Setting | Value |
|---------|-------|
| Model | `Qwen/Qwen2.5-0.5B` |
| Dataset | AG News — 6K train, 2K test |
| LoRA | rank=16, alpha=32, 2 epochs |

In [None]:
!pip install -q 'chorus[peft] @ git+https://github.com/varmabudharaju/chorus.git'
!pip install -q scikit-learn

In [None]:
import torch, gc, os, time, json, logging, random
import numpy as np
from collections import Counter

# Configure root logging at INFO level with a short timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s [%(name)s] %(message)s',
    datefmt='%H:%M:%S',
)

# ── Experiment configuration: keep identical to the federated notebook ──
MODEL_NAME = "Qwen/Qwen2.5-0.5B"
DATASET_SIZE = 6000      # training rows selected after shuffling
TEST_SIZE = 2000         # test rows selected after shuffling
LORA_RANK = 16
LORA_ALPHA = 32
LEARNING_RATE = 3e-4
NUM_EPOCHS = 2
BATCH_SIZE = 8           # per-device batch size handed to the trainer
GRAD_ACCUM = 2           # gradient-accumulation steps handed to the trainer
MAX_SEQ_LEN = 128        # token cap used for training and for evaluation prompts
SEED = 42

# Seed the PyTorch, NumPy and Python RNGs with the same value.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Pick the compute device and report what was found.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")
if device != "cuda":
    print("WARNING: No GPU!")
else:
    gpu = torch.cuda.get_device_properties(0)
    print(f"GPU: {gpu.name} | VRAM: {gpu.total_memory / 1e9:.1f} GB")

## Step 1: Dataset

In [None]:
from datasets import load_dataset, Dataset

# Class names indexed by the dataset's integer label; LABEL_MAP is the
# id -> name lookup built from the same list.
LABEL_NAMES = ["world", "sports", "business", "scitech"]
LABEL_MAP = dict(enumerate(LABEL_NAMES))

# Shuffle each split with the shared seed, then keep the first N rows.
raw_train = load_dataset("ag_news", split="train")
raw_train = raw_train.shuffle(seed=SEED)
raw_test = load_dataset("ag_news", split="test")
raw_test = raw_test.shuffle(seed=SEED)

train_ds = raw_train.select(range(DATASET_SIZE))
test_ds = raw_test.select(range(TEST_SIZE))

# Report subset sizes and the per-class counts in ascending label-id order.
train_label_counts = Counter(train_ds['label'])
test_label_counts = Counter(test_ds['label'])
print(f"Train: {len(train_ds)} | Test: {len(test_ds)}")
print(f"Train labels: { {k: train_label_counts[k] for k in sorted(train_label_counts)} }")
print(f"Test labels:  { {k: test_label_counts[k] for k in sorted(test_label_counts)} }")

def format_example(ex):
    """Render one dataset row as a two-line training string.

    Args:
        ex: Row mapping with a ``text`` field and a ``label`` field that is
            a key of ``LABEL_MAP``.

    Returns:
        A ``Sentence: ...`` line followed by a ``Category: ...`` line, where
        the category is the label's name from ``LABEL_MAP``.
    """
    category = LABEL_MAP[ex["label"]]
    return "Sentence: {}\nCategory: {}".format(ex["text"], category)

# Render every training row in the "Sentence/Category" text format and write
# the result to disk; the trainer later receives this file by path.
central_path = "/content/centralized_train.json"
central_texts = [format_example(row) for row in train_ds]
central_ds = Dataset.from_dict({"text": central_texts})
central_ds.to_json(central_path)
print(f"Saved {len(train_ds)} training examples to {central_path}")

## Step 2: Model + Evaluation Setup

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from sklearn.metrics import accuracy_score, classification_report

# Use bfloat16 when CUDA is present and reports bf16 support; otherwise float16.
USE_BF16 = torch.cuda.is_bf16_supported() if torch.cuda.is_available() else False
# NOTE(review): a CPU-only runtime also ends up with float16 here — confirm
# that is intended before relying on CPU runs.
DTYPE = torch.bfloat16 if USE_BF16 else torch.float16

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Each class is represented by the FIRST token of " <name>" (with the leading
# space that follows "Category:" in the prompt). The table is built from
# LABEL_MAP so evaluation labels cannot drift from the training labels.
label_token_ids = {
    label_id: tokenizer.encode(f" {name}", add_special_tokens=False)[0]
    for label_id, name in LABEL_MAP.items()
}
# Only the first token is compared, so two classes sharing it would be
# indistinguishable and one of them could never be predicted. Fail loudly.
if len(set(label_token_ids.values())) != len(label_token_ids):
    raise ValueError(f"Label names share a first token: {label_token_ids}")
print(f"Label tokens: {label_token_ids}")


def evaluate_accuracy(model, test_data, label):
    """Score ``model`` on ``test_data`` by comparing label-token logits.

    For every example the prompt ``"Sentence: <text>\\nCategory:"`` is fed to
    the model, and the class whose token in ``label_token_ids`` has the
    highest logit at the last position is taken as the prediction.

    Args:
        model: Model whose forward pass returns an object with ``.logits``;
            inputs are moved to ``device`` before the call.
        test_data: Iterable of rows with ``text`` and integer ``label`` fields.
        label: Display name printed in the summary line.

    Returns:
        ``(accuracy, report)`` where ``report`` is the dict produced by
        ``classification_report(..., output_dict=True)``.
    """
    model.eval()
    # Pin the class ids so the report always has one row per entry of
    # LABEL_NAMES. Without this, a class missing from both golds and preds
    # makes the inferred label set shorter than target_names and
    # classification_report raises.
    class_ids = list(range(len(LABEL_NAMES)))
    preds, golds = [], []
    for ex in test_data:
        prompt = f"Sentence: {ex['text']}\nCategory:"
        # NOTE(review): if the tokenizer truncates on the right (the usual
        # default), prompts longer than MAX_SEQ_LEN lose the trailing
        # "Category:" cue. Left as-is so scores stay comparable with the
        # federated notebook — confirm and fix in both places together.
        inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=MAX_SEQ_LEN).to(device)
        with torch.no_grad():
            # Only the logits at the final position are needed.
            logits = model(**inputs).logits[:, -1, :]
        scores = {lbl: logits[0, tid].item() for lbl, tid in label_token_ids.items()}
        preds.append(max(scores, key=scores.get))
        golds.append(ex["label"])
    acc = accuracy_score(golds, preds)
    report = classification_report(
        golds, preds, labels=class_ids, target_names=LABEL_NAMES,
        output_dict=True, zero_division=0
    )
    print(
        f"  [{label}] Accuracy: {acc:.1%}  |  F1: "
        f"world={report['world']['f1-score']:.2f} "
        f"sports={report['sports']['f1-score']:.2f} "
        f"business={report['business']['f1-score']:.2f} "
        f"scitech={report['scitech']['f1-score']:.2f}"
    )
    return acc, report

## Step 3: Baseline (Zero-Shot)

In [None]:
# Zero-shot reference: score the base model as loaded, then drop it and
# release cached GPU memory before the next model is created.
base_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
base_model = base_model.to(device)
baseline_acc, baseline_report = evaluate_accuracy(
    base_model, test_ds, "Baseline (zero-shot)"
)
del base_model
gc.collect()
torch.cuda.empty_cache()

## Step 4: Centralized Training (Gold Standard)

In [None]:
from chorus.client.trainer import LoRATrainer
from peft import PeftModel

print(f"Training centralized model on all {DATASET_SIZE} examples, {NUM_EPOCHS} epochs...")
t0 = time.time()

centralized_dir = "/content/adapter_centralized"
# Trainer settings gathered in one place. Mixed precision is bf16 when
# USE_BF16 is set, otherwise fp16 when CUDA is available, otherwise neither.
training_args = dict(
    base_model=MODEL_NAME,
    dataset=central_path,
    output_dir=centralized_dir,
    lora_rank=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    learning_rate=LEARNING_RATE,
    num_epochs=NUM_EPOCHS,
    per_device_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRAD_ACCUM,
    max_seq_length=MAX_SEQ_LEN,
    bf16=USE_BF16,
    fp16=not USE_BF16 and torch.cuda.is_available(),
    dataloader_pin_memory=False,
)
trainer = LoRATrainer(**training_args)
trainer.train()
elapsed = time.time() - t0
print(f"Centralized training done in {elapsed:.0f}s")

# Load a fresh base model, attach the adapter from centralized_dir, and
# score it on the same test subset as the baseline.
central_model = PeftModel.from_pretrained(
    AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=DTYPE).to(device),
    centralized_dir,
)
centralized_acc, centralized_report = evaluate_accuracy(
    central_model, test_ds, "Centralized"
)
del central_model
gc.collect()
torch.cuda.empty_cache()

## Step 5: Save Results

In [None]:
# Everything the federated notebook needs for its comparison.
results = {
    "baseline_acc": baseline_acc,
    "baseline_report": baseline_report,
    "centralized_acc": centralized_acc,
    "centralized_report": centralized_report,
    "dataset_size": DATASET_SIZE,
    "test_size": TEST_SIZE,
    "num_epochs": NUM_EPOCHS,
}

output_path = "/content/centralized_results.json"
# The federated notebook is meant to run in parallel and load this file, so
# write to a sibling temp file first and move it into place afterwards; a
# reader then never sees a half-written JSON document.
tmp_path = output_path + ".tmp"
with open(tmp_path, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2)
os.replace(tmp_path, output_path)

print(f"Results saved to {output_path}")
print(f"\nBaseline:    {baseline_acc:.1%}")
print(f"Centralized: {centralized_acc:.1%}")
print("\nOpen the federated notebook to see the full comparison.")