# Phase 3 — Reranker Training: LoRA + Knowledge Distillation

This notebook trains a two-stage reranker pipeline:
1. **Teacher**: CodeBERT fine-tuned with LoRA on assert-review pairs
2. **Student**: CodeT5-small distilled from teacher logits (2× smaller, 5× faster)

Output checkpoint saved to `ml/models/reranker/`.

## 1. Setup

In [None]:
!pip install -q transformers peft datasets accelerate torch scikit-learn

In [None]:
import json
import os
import sys
from pathlib import Path

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset

# The project root is two directories above the current working directory;
# putting it on sys.path makes project packages importable from the notebook.
PROJECT_ROOT = Path("../..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

# Training pairs are read from, and the checkpoint written to, <root>/ml/...
DATA_PATH = PROJECT_ROOT.joinpath("ml", "data", "reranker_pairs.jsonl")
CHECKPOINT_DIR = PROJECT_ROOT.joinpath("ml", "models", "reranker")
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Use the GPU whenever PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

print(f"Device: {DEVICE}")
print(f"Data path: {DATA_PATH}")

## 2. Data — Load JSONL and Build RerankerDataset

Expected JSONL format (one record per line):
```json
{"text": "<file>src/auth/jwt.py ...diff...", "label": 1}
```
`label=1` means the diff is important for review; `label=0` means it can be deprioritized.

In [None]:
def load_jsonl(path: Path) -> list[dict]:
    """Read a JSON Lines file into a list of decoded records.

    Args:
        path: Location of the file; every non-blank line must hold one JSON document.

    Returns:
        The decoded objects in file order. Blank and whitespace-only lines are skipped.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Decode as UTF-8 explicitly: the platform default encoding varies, and
    # diff text may contain non-ASCII characters.
    with open(path, encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [json.loads(line) for line in stripped_lines if line]


class RerankerDataset(Dataset):
    """Map-style dataset that tokenises one record per lookup for binary importance scoring.

    Every record is a dict with a ``"text"`` entry (passed to the tokenizer) and a
    ``"label"`` entry (returned as a float tensor). Tokenisation happens lazily in
    ``__getitem__``; each item is padded/truncated to ``max_length``.
    """

    # Keys copied from the tokenizer output into each item, in output order.
    _ENCODING_KEYS = ("input_ids", "attention_mask")

    def __init__(self, records: list[dict], tokenizer, max_length: int = 512):
        self.records = records
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self) -> int:
        return len(self.records)

    def __getitem__(self, idx: int) -> dict:
        record = self.records[idx]
        encoded = self.tokenizer(
            record["text"],
            return_tensors="pt",
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
        )
        # The tokenizer returns tensors with a leading batch axis of size one;
        # drop it so the DataLoader can stack items itself.
        item = {key: encoded[key].squeeze(0) for key in self._ENCODING_KEYS}
        item["labels"] = torch.tensor(record["label"], dtype=torch.float)
        return item


# Fall back to a small synthetic corpus when the real pairs file is absent,
# so the rest of the notebook can still run as a demo.
if not DATA_PATH.exists():
    print(f"WARNING: {DATA_PATH} not found. Generating synthetic demo data.")
    DATA_PATH.parent.mkdir(parents=True, exist_ok=True)
    # Three positive and three negative examples, repeated 20 times (120 records).
    samples = 20 * [
        {"text": "<file>src/auth/jwt.py\n+def validate_token(token): ...", "label": 1},
        {"text": "<file>README.md\n+## Updated documentation", "label": 0},
        {"text": "<file>src/core/db.py\n+def execute_query(sql, params): ...", "label": 1},
        {"text": "<file>tests/test_utils.py\n+def test_helper(): ...", "label": 0},
        {"text": "<file>src/crypto/hash.py\n+def sha256(data): ...", "label": 1},
        {"text": "<file>docs/changelog.md\n+v1.2.0 release notes", "label": 0},
    ]
    with open(DATA_PATH, "w") as f:
        f.writelines(json.dumps(sample) + "\n" for sample in samples)

records = load_jsonl(DATA_PATH)
num_positive = sum(r["label"] for r in records)
print(f"Loaded {len(records)} records")
print(f"Label distribution: {num_positive} positive / {len(records)} total")

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 15% of the records for validation; the fixed seed keeps the split reproducible.
train_records, val_records = train_test_split(
    records,
    test_size=0.15,
    random_state=42,
)
print("Train: {}  Val: {}".format(len(train_records), len(val_records)))

## 3. Teacher Training — CodeBERT + LoRA

We apply LoRA adapters only to the query/value projection matrices, keeping the base CodeBERT weights frozen. This reduces trainable parameters by ~98% while retaining most accuracy.

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from peft import LoraConfig, get_peft_model, TaskType

TEACHER_BASE = "microsoft/codebert-base"

# Tokenizer plus a sequence-classification model with a single output logit.
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_BASE)
teacher_base = AutoModelForSequenceClassification.from_pretrained(TEACHER_BASE, num_labels=1)

# Rank-8 LoRA adapters on the modules named "query" and "value"; no bias terms are adapted.
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
    target_modules=["query", "value"],
    task_type=TaskType.SEQ_CLS,
)

# Wrap the base model with the adapters and report how many parameters will train.
teacher_model = get_peft_model(teacher_base, lora_config)
teacher_model.print_trainable_parameters()

In [None]:
TEACHER_EPOCHS = 3
TEACHER_LR = 2e-4
BATCH_SIZE = 8

# Both splits are tokenised with the teacher's tokenizer; only training batches are shuffled.
train_ds = RerankerDataset(train_records, teacher_tokenizer)
val_ds = RerankerDataset(val_records, teacher_tokenizer)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)

teacher_model.to(DEVICE)
optimizer = torch.optim.AdamW(teacher_model.parameters(), lr=TEACHER_LR)
loss_fn = nn.BCEWithLogitsLoss()


def _teacher_logits(batch):
    """Run the teacher on one collated batch and return its logits with the last axis squeezed."""
    output = teacher_model(
        input_ids=batch["input_ids"].to(DEVICE),
        attention_mask=batch["attention_mask"].to(DEVICE),
    )
    return output.logits.squeeze(-1)


for epoch in range(TEACHER_EPOCHS):
    # --- optimisation pass over the training split ---
    teacher_model.train()
    total_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        loss = loss_fn(_teacher_logits(batch), batch["labels"].to(DEVICE))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # --- gradient-free pass over the validation split ---
    teacher_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += loss_fn(_teacher_logits(batch), batch["labels"].to(DEVICE)).item()

    # Report per-batch averages for the epoch.
    mean_train_loss = total_loss / len(train_loader)
    mean_val_loss = val_loss / len(val_loader)
    print(
        f"Epoch {epoch + 1}/{TEACHER_EPOCHS} | "
        f"train_loss={mean_train_loss:.4f} | "
        f"val_loss={mean_val_loss:.4f}"
    )

print("Teacher training complete.")

## 4. Distillation — CodeT5-small Student

We use **soft-label distillation**: the student minimises a weighted sum of
- BCE loss against hard labels (ground truth)
- MSE loss between the student's and the teacher's sigmoid probabilities (a simple stand-in for KL divergence in this single-logit setup)

Temperature `T=4` softens the probabilities used in the soft-label loss so the student learns a richer signal.

In [None]:
STUDENT_BASE = "Salesforce/codet5-small"

# Load the student's tokenizer and a single-logit classification model,
# then move the model to the training device.
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_BASE)
student_model = AutoModelForSequenceClassification.from_pretrained(STUDENT_BASE, num_labels=1)
student_model.to(DEVICE)

# Total parameter count across all tensors, printed with thousands separators.
student_param_count = sum(param.numel() for param in student_model.parameters())
print(f"Student params: {student_param_count:,}")

In [None]:
class DistillDataset(Dataset):
    """Dataset yielding student encodings together with teacher soft labels and hard labels.

    The teacher's sigmoid probabilities are computed once, at construction time,
    and stored in ``self.soft_labels`` (one float per record, in record order).
    Student tokenisation happens lazily in ``__getitem__``.

    Args:
        records: Sequence of dicts, each with a ``"text"`` string and a ``"label"`` value.
        student_tok: Tokenizer callable applied to a single text in ``__getitem__``.
        teacher_model: Callable as ``teacher_model(**encoding)``; its result must expose
            ``.logits`` holding one logit per input text.
        teacher_tok: Tokenizer callable applied to a list of texts; its result must
            support ``.to(device)``.
        device: Device the teacher encodings are moved to before the forward pass.
        max_length: Truncation length used for both tokenizers.
        teacher_batch_size: Number of texts per teacher forward pass.

    Raises:
        ValueError: If ``teacher_batch_size`` is below 1, or the teacher does not
            produce exactly one logit per record.
    """

    def __init__(
        self,
        records,
        student_tok,
        teacher_model,
        teacher_tok,
        device,
        max_length=512,
        teacher_batch_size=16,
    ):
        if teacher_batch_size < 1:
            raise ValueError("teacher_batch_size must be at least 1")
        self.records = records
        self.student_tok = student_tok
        self.max_length = max_length
        self.teacher_batch_size = teacher_batch_size
        # Pre-compute teacher soft labels so training never has to run the teacher.
        self.soft_labels = self._compute_soft_labels(teacher_model, teacher_tok, device)

    def _compute_soft_labels(self, model, tokenizer, device):
        """Return the teacher's sigmoid probability for every record, in order."""
        model.eval()
        texts = [r["text"] for r in self.records]
        step = self.teacher_batch_size
        soft = []
        with torch.no_grad():
            for start in range(0, len(texts), step):
                enc = tokenizer(
                    texts[start : start + step],
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                ).to(device)
                logits = model(**enc).logits
                # reshape(-1) gives a 1-D tensor whatever the batch size, so
                # tolist() always yields a list (also for a final batch of one).
                soft.extend(torch.sigmoid(logits.reshape(-1)).cpu().tolist())
        if len(soft) != len(texts):
            raise ValueError(
                f"teacher produced {len(soft)} logits for {len(texts)} records; "
                "expected exactly one logit per record"
            )
        return soft

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        rec = self.records[idx]
        enc = self.student_tok(
            rec["text"],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            # squeeze(0) drops the leading batch axis added by return_tensors="pt".
            "input_ids": enc["input_ids"].squeeze(0),
            "attention_mask": enc["attention_mask"].squeeze(0),
            "hard_label": torch.tensor(rec["label"], dtype=torch.float),
            "soft_label": torch.tensor(self.soft_labels[idx], dtype=torch.float),
        }


print("Building distillation dataset (computing teacher soft labels) ...")
# Constructing the dataset runs the teacher over every training record once.
distill_ds = DistillDataset(
    records=train_records,
    student_tok=student_tokenizer,
    teacher_model=teacher_model,
    teacher_tok=teacher_tokenizer,
    device=DEVICE,
)
distill_loader = DataLoader(distill_ds, shuffle=True, batch_size=BATCH_SIZE)
print(f"Distillation dataset ready: {len(distill_ds)} samples")

In [None]:
DISTILL_EPOCHS = 4
DISTILL_LR = 3e-4
ALPHA = 0.5   # weight for hard-label BCE loss; (1 - ALPHA) weights the soft-label loss
TEMPERATURE = 4.0  # values > 1 pull sigmoid outputs toward 0.5 on both sides of the soft loss

distill_optimizer = torch.optim.AdamW(student_model.parameters(), lr=DISTILL_LR)
bce = nn.BCEWithLogitsLoss()

for epoch in range(DISTILL_EPOCHS):
    student_model.train()
    total_loss = 0.0
    for batch in distill_loader:
        distill_optimizer.zero_grad()

        out = student_model(
            input_ids=batch["input_ids"].to(DEVICE),
            attention_mask=batch["attention_mask"].to(DEVICE),
        )
        student_logits = out.logits.squeeze(-1)

        hard_labels = batch["hard_label"].to(DEVICE)
        soft_labels = batch["soft_label"].to(DEVICE)

        # Hard-label loss: plain BCE on the untempered student logits.
        hard_loss = bce(student_logits, hard_labels)

        # Soft-label distillation loss (MSE between tempered sigmoid probabilities).
        # The stored soft labels are teacher probabilities at T=1, so recover the
        # teacher logits and re-apply the sigmoid at the same temperature as the
        # student. Comparing a tempered student against an untempered teacher
        # would push the student's logits toward TEMPERATURE times the teacher's.
        # eps keeps logit() finite when a teacher probability saturates at 0 or 1.
        teacher_logits = torch.logit(soft_labels, eps=1e-6)
        teacher_probs = torch.sigmoid(teacher_logits / TEMPERATURE)
        student_probs = torch.sigmoid(student_logits / TEMPERATURE)
        soft_loss = nn.functional.mse_loss(student_probs, teacher_probs)

        loss = ALPHA * hard_loss + (1 - ALPHA) * soft_loss
        loss.backward()
        distill_optimizer.step()
        total_loss += loss.item()

    print(f"Distill Epoch {epoch+1}/{DISTILL_EPOCHS} | loss={total_loss/len(distill_loader):.4f}")

print("Distillation complete.")

In [None]:
# Persist the distilled student and its tokenizer into the same directory
# (used by ml/models/reranker.py at inference time).
checkpoint_path = str(CHECKPOINT_DIR)
for artifact in (student_model, student_tokenizer):
    artifact.save_pretrained(checkpoint_path)
print(f"Student checkpoint saved to {CHECKPOINT_DIR}")

## 5. Benchmark — Teacher vs Student

Evaluate both models on the held-out validation set using:
- **AUC-ROC** — ranking quality
- **Accuracy** at threshold 0.5
- **Latency** (ms/sample)

In [None]:
import time
from sklearn.metrics import roc_auc_score, accuracy_score


def evaluate(model, tokenizer, records, device, label="model", batch_size=16, max_length=512):
    """Score ``records`` with ``model`` and report AUC, accuracy and latency.

    Args:
        model: Callable as ``model(**encoding)``; its result must expose ``.logits``
            holding one logit per input text.
        tokenizer: Callable applied to a list of texts; its result must support ``.to(device)``.
        records: Non-empty sequence of dicts with ``"text"`` and ``"label"`` entries.
        device: Device the encodings are moved to before the forward pass.
        label: Name shown in the printed summary line.
        batch_size: Number of texts per forward pass.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        Dict with ``"auc"``, ``"accuracy"`` (threshold 0.5) and ``"ms_per_sample"``.
        ``"auc"`` is NaN when ``records`` contains only one label value, because
        ROC AUC is undefined in that case.

    Raises:
        ValueError: If ``records`` is empty.
    """
    if not records:
        raise ValueError("evaluate() needs at least one record")

    model.eval()
    texts = [r["text"] for r in records]
    labels = [r["label"] for r in records]
    all_probs = []

    # The timed region covers tokenisation, the forward pass and the copy back to CPU.
    t0 = time.perf_counter()
    with torch.no_grad():
        for i in range(0, len(texts), batch_size):
            enc = tokenizer(
                texts[i : i + batch_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            out = model(**enc)
            # reshape(-1) gives a 1-D tensor whatever the batch size, so
            # tolist() always yields a list (also for a final batch of one).
            all_probs.extend(torch.sigmoid(out.logits.reshape(-1)).cpu().tolist())
    elapsed = time.perf_counter() - t0

    if len(set(labels)) < 2:
        # A small held-out split can end up with a single class; report that
        # instead of letting the AUC computation abort the whole benchmark.
        print(f"[{label}] WARNING: only one class present in labels; AUC is undefined.")
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, all_probs)
    preds = [1 if p >= 0.5 else 0 for p in all_probs]
    acc = accuracy_score(labels, preds)
    ms_per_sample = (elapsed / len(texts)) * 1000

    print(f"[{label}] AUC={auc:.4f} | Acc={acc:.4f} | Latency={ms_per_sample:.2f} ms/sample")
    return {"auc": auc, "accuracy": acc, "ms_per_sample": ms_per_sample}


print("=== Validation Benchmark ===")
# Score both models on the same held-out records, each with its own tokenizer.
teacher_metrics = evaluate(
    teacher_model,
    teacher_tokenizer,
    val_records,
    DEVICE,
    label="Teacher (CodeBERT+LoRA)",
)
student_metrics = evaluate(
    student_model,
    student_tokenizer,
    val_records,
    DEVICE,
    label="Student (CodeT5-small)",
)

In [None]:
# Latency ratio (teacher / student); the small floor avoids dividing by zero.
student_ms = max(student_metrics["ms_per_sample"], 1e-9)
speedup = teacher_metrics["ms_per_sample"] / student_ms
auc_gap = teacher_metrics["auc"] - student_metrics["auc"]

print(f"\nSpeedup: {speedup:.1f}x faster")
print(f"AUC gap (teacher - student): {auc_gap:.4f}")

# Targets: at least 2x faster while losing no more than 0.05 AUC.
targets_met = speedup >= 2.0 and auc_gap <= 0.05
print(
    "PASS: Student meets latency and quality targets."
    if targets_met
    else "WARNING: Targets not met — consider more distillation epochs or temperature tuning."
)