# In [None]:
#!/usr/bin/env python3
"""
few_shot_cls_head_per_query.py

Few-shot classification with frozen BERT encoder + randomly initialized classification head per query.

Pipeline:
1. Randomly choose a row from the CSV -> this is the query we want to classify.
2. Randomly select n other distinct rows as "shots" (support set).
3. Randomly initialize the classification head weights.
4. Train the classification head on the n shots (with frozen BERT).
5. Infer on the target query using the temporarily learned head.
6. Repeat steps 1-5 num_queries times, then print the accuracy and save the predictions to a CSV.
"""

import os
import random
import argparse
import pandas as pd
import torch
from torch import nn, optim
from transformers import AutoTokenizer, AutoModel
import numpy as np
from tqdm import trange

# -------------------------
# Argument parsing
# -------------------------
def parse_args(argv=None):
    """Parse command-line options for the few-shot experiment.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before. Passing a list (e.g. ``[]``)
            lets the parser be driven from tests or a notebook.

    Returns:
        argparse.Namespace holding the options defined below.

    ``--shots`` and ``--num_queries`` must be >= 1. An empty support set cannot
    be encoded, and zero queries would make the accuracy a division by zero.
    Violations are reported through ``parser.error``, which prints usage and
    raises ``SystemExit``.
    """
    p = argparse.ArgumentParser(description="Few-shot classification with reinitialized head per query")
    p.add_argument("--csv", type=str, default="mt_bench_training.csv", help="Input CSV file path")
    p.add_argument("--output_csv", type=str, default="bert_few_shot_cls_predictions.csv", help="Output CSV file path")
    p.add_argument("--model_name", type=str, default="bert-base-uncased", help="Pretrained BERT encoder")
    p.add_argument("--shots", type=int, default=1, help="Number of support examples per query")
    p.add_argument("--num_queries", type=int, default=100, help="Number of queries to evaluate")
    p.add_argument("--epochs", type=int, default=10, help="Training epochs for head on support set")
    p.add_argument("--lr", type=float, default=1e-3, help="Learning rate for head training")
    p.add_argument("--max_length", type=int, default=256, help="Max token length")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args(argv)

    if args.shots < 1:
        p.error(f"--shots must be >= 1 (got {args.shots})")
    if args.num_queries < 1:
        p.error(f"--num_queries must be >= 1 (got {args.num_queries})")
    return args

# -------------------------
# Utilities
# -------------------------
def _field_text(row, key):
    """Return ``row[key]`` as a stripped string, mapping absent, None or NaN cells to ''."""
    value = row.get(key, "")
    # pandas.read_csv stores empty cells as float NaN; str() would turn that
    # into the literal text "nan" and feed it to the encoder.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def build_text_from_row(row):
    """Format a dataset row into the natural-language text fed to the encoder.

    Args:
        row: A mapping with ``.get`` and ``[]`` access (a pandas Series row or a
            dict). It must contain ``"turn"``. It may contain ``"turn_1_query"``,
            ``"turn_1_answer"`` and ``"turn_2_query"``. Missing or NaN fields
            are rendered as empty strings.

    Returns:
        For turn 2, the first query, the first answer and the second query.
        For turn 1 or any other turn value, only the first query.

    Raises:
        KeyError: if ``"turn"`` is absent.
        ValueError: if ``"turn"`` cannot be converted with ``int`` (e.g. NaN).
    """
    turn = int(row["turn"])
    q1 = _field_text(row, "turn_1_query")
    if turn == 2:
        ans = _field_text(row, "turn_1_answer")
        q2 = _field_text(row, "turn_2_query")
        return f"Turn 1 query: {q1}, Turn 1 answer: {ans}, Turn 2 query: {q2}"
    # Turn 1 and any unexpected turn value fall back to the first query only.
    return f"Turn 1 query: {q1}"

class ClassificationHead(nn.Module):
    """Single affine layer that maps an encoder embedding to class logits.

    Args:
        hidden_size: Width of the input embedding (the last dimension of ``x``).
        num_labels: Number of output logits, one per class.
    """

    def __init__(self, hidden_size, num_labels):
        super(ClassificationHead, self).__init__()
        projection = nn.Linear(in_features=hidden_size, out_features=num_labels)
        self.linear = projection

    def forward(self, x):
        """Apply the linear projection to the last dimension of ``x`` and return raw logits."""
        logits = self.linear(x)
        return logits

def encode_texts(texts, tokenizer, model, max_length, device):
    """Embed a batch of texts with the frozen encoder and return first-token vectors.

    Args:
        texts: List of strings to encode as one padded, truncated batch.
        tokenizer: Callable tokenizer returning an object with ``.to(device)``
            whose items are passed to ``model`` as keyword arguments.
        model: Encoder whose output exposes ``last_hidden_state``.
        max_length: Truncation length passed to the tokenizer.
        device: Device the tokenized batch is moved to.

    Returns:
        ``last_hidden_state[:, 0, :]``, the position-0 ([CLS]) hidden state of
        each text, computed without gradient tracking.
    """
    batch = tokenizer(
        texts,
        max_length=max_length,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )
    batch = batch.to(device)
    with torch.no_grad():
        hidden = model(**batch).last_hidden_state
        first_token = hidden[:, 0, :]
    return first_token

# -------------------------
# Main
# -------------------------
def _sample_support_indices(num_rows, target_idx, shots):
    """Draw min(shots, num_rows - 1) distinct row indices from range(num_rows), excluding target_idx."""
    k = min(shots, num_rows - 1)
    # Sampling positions from range(num_rows - 1) and shifting those at/after the
    # target is equivalent to sampling from the explicit "all rows but target"
    # list (same RNG draws, same indices), without building an O(N) list per query.
    return [j if j < target_idx else j + 1 for j in random.sample(range(num_rows - 1), k)]


def _fit_head(embeds, labels, hidden_size, num_labels, epochs, lr, device):
    """Train a freshly initialized ClassificationHead on (embeds, labels) with full-batch AdamW steps."""
    head = ClassificationHead(hidden_size, num_labels).to(device)
    optimizer = optim.AdamW(head.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    head.train()
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = loss_fn(head(embeds), labels)
        loss.backward()
        optimizer.step()
    return head


def _predict_probs(head, embed):
    """Return softmax class probabilities for the first row of ``embed`` as a 1-D numpy array."""
    head.eval()
    with torch.no_grad():
        logits = head(embed)
        # Index the batch row instead of calling squeeze(). With a single label,
        # squeeze() would give a 0-d array and break len()/enumerate() later.
        return torch.softmax(logits, dim=-1)[0].cpu().numpy()


def main():
    """Run the few-shot evaluation loop described in the module docstring.

    Reads options via ``parse_args()``. For each of ``--num_queries`` random
    target rows it samples a support set, trains a fresh head on the frozen
    encoder's [CLS] embeddings, and predicts the target's ``winner``. It then
    prints the accuracy and writes one row per query to ``--output_csv``.

    Raises:
        FileNotFoundError: if ``--csv`` does not exist.
        ValueError: if ``--shots`` < 1, or fewer than two labelled rows remain
            (one target plus at least one support example is required).
    """
    args = parse_args()
    if args.shots < 1:
        raise ValueError(f"--shots must be >= 1, got {args.shots}")
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load dataset
    if not os.path.exists(args.csv):
        raise FileNotFoundError(f"CSV not found: {args.csv}")
    df = pd.read_csv(args.csv)
    # Rows without a gold label can be neither support examples nor scored
    # targets. A NaN mixed with string labels would also break sorted()/label2id.
    n_missing = int(df["winner"].isna().sum())
    if n_missing:
        print(f"Dropping {n_missing} rows with missing 'winner' label")
        df = df.dropna(subset=["winner"]).reset_index(drop=True)
    if len(df) < 2:
        raise ValueError(
            f"Need at least 2 labelled rows (1 target + >=1 support), got {len(df)}"
        )
    unique_labels = sorted(df["winner"].unique().tolist())
    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for i, label in enumerate(unique_labels)}
    print(f"Unique labels: {unique_labels}")

    # Load frozen BERT encoder
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    bert = AutoModel.from_pretrained(args.model_name).to(device)
    bert.eval()
    for param in bert.parameters():
        param.requires_grad = False
    hidden_size = bert.config.hidden_size
    num_labels = len(unique_labels)

    results = []
    correct = 0

    progress = trange(args.num_queries, desc="Queries")
    for i in progress:
        # 1. Pick the target row, then a support set that excludes it.
        target_idx = random.randint(0, len(df) - 1)
        target_row = df.iloc[target_idx]
        support_indices = _sample_support_indices(len(df), target_idx, args.shots)
        support_rows = [df.iloc[j] for j in support_indices]

        # 2. Encode the support set with the frozen encoder.
        support_texts = [build_text_from_row(r) for r in support_rows]
        support_labels = torch.tensor(
            [label2id[r["winner"]] for r in support_rows], dtype=torch.long, device=device
        )
        support_embeds = encode_texts(support_texts, tokenizer, bert, args.max_length, device)

        # 3. Fit a freshly initialized head on the support set only.
        head = _fit_head(
            support_embeds, support_labels, hidden_size, num_labels, args.epochs, args.lr, device
        )

        # 4. Encode the target and predict with the temporary head.
        target_text = build_text_from_row(target_row)
        target_embed = encode_texts([target_text], tokenizer, bert, args.max_length, device)
        probs = _predict_probs(head, target_embed)
        pred_label = id2label[int(np.argmax(probs))]

        if pred_label == target_row["winner"]:
            correct += 1

        results.append({
            "query_id": target_row["question_id"],
            "turn": target_row["turn"],
            "true_winner": target_row["winner"],
            "predicted_winner": pred_label,
            "support_set": [r["question_id"] for r in support_rows],
            "probs": {id2label[j]: float(p) for j, p in enumerate(probs)},
        })

        # tqdm.write keeps the progress bar intact, unlike a bare print().
        progress.write(f"[{i+1}/{args.num_queries}] True: {target_row['winner']} | Pred: {pred_label}")

    # Compute accuracy. Guard against a zero denominator when nothing was evaluated.
    if results:
        accuracy = correct / len(results)
        print(f"\nFinal Accuracy over {args.num_queries} queries: {accuracy:.4f}")
    else:
        print("\nNo queries evaluated; accuracy is undefined.")

    # Save predictions
    out_df = pd.DataFrame(results)
    out_df.to_csv(args.output_csv, index=False)
    print(f"Saved predictions to {args.output_csv}")


if __name__ == "__main__":
    main()