In [None]:
#!/usr/bin/env python3
"""
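few_shot_bert_icL.py

Use a BERT masked-language model in few-shot in-context learning (ICL) mode
(n shots) to predict which model won each comparison (the CSV's "winner" column).

Pipeline:
1. Randomly choose a row from the CSV -> this is the target "query" whose winner we want to predict.
2. Randomly select n other distinct rows as few-shot demonstrations.
3. Build a prompt with one line per demonstration, followed by the target:

   {demo_1 text}, winner model: {winner_1}
   ...
   {demo_n text}, winner model: {winner_n}
   {target text}, winner model: [MASK]

   Each "text" is "Turn 1 query: {q1}" for turn-1 rows, or
   "Turn 1 query: {q1}, Turn 1 answer: {a1}, Turn 2 query: {q2}" for turn-2 rows.

4. Pass this prompt through a BERT MLM head and pick the candidate label that scores highest at the [MASK] position.
5. Repeat for --num_queries targets (default 100) and save the results to CSV.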
"""

import os
import random
import argparse
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM
import numpy as np

# -------------------------
# Argument parsing
# -------------------------
def parse_args(argv=None):
    """Parse the command-line options for the few-shot ICL run.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, as before. Pass an explicit list
            (e.g. ``[]``) when calling from a notebook or a test, where
            ``sys.argv`` holds the host process's own flags.

    Returns:
        argparse.Namespace with ``csv``, ``output_csv``, ``model_name``,
        ``shots``, ``num_queries`` and ``seed``.

    Raises:
        SystemExit: on unknown options, or when ``--shots`` / ``--num_queries``
            is negative (standard argparse error handling).
    """
    p = argparse.ArgumentParser(description="Few-shot ICL with BERT for model classification")
    p.add_argument("--csv", type=str, default="mt_bench_training.csv", help="Input CSV file path")
    p.add_argument("--output_csv", type=str, default="bert_few_shot_predictions.csv", help="Output CSV file path")
    p.add_argument("--model_name", type=str, default="bert-base-uncased", help="Pretrained BERT model")
    p.add_argument("--shots", type=int, default=1, help="Number of shots (examples) for few-shot ICL")
    p.add_argument("--num_queries", type=int, default=100, help="Number of queries to evaluate")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args(argv)
    # Reject negative counts here with a usage message. Otherwise random.sample
    # fails later with an opaque error, or range() silently does nothing.
    if args.shots < 0:
        p.error("--shots must be >= 0")
    if args.num_queries < 0:
        p.error("--num_queries must be >= 0")
    return args

# -------------------------
# Utilities
# -------------------------
def _cell_text(row, key):
    """Return ``row[key]`` as a stripped string, or "" when the field is absent, None, or NaN."""
    value = row.get(key, "")
    # pandas reads empty CSV cells as NaN; str(NaN) would put the literal "nan" into the prompt.
    if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
        return ""
    return str(value).strip()


def build_text_from_row(row):
    """Format a dataset row as the natural-language text used in the prompt.

    Args:
        row: Mapping-like record (a pandas Series or dict) with a ``turn`` field
            and optional ``turn_1_query`` / ``turn_1_answer`` / ``turn_2_query`` fields.

    Returns:
        For ``turn == 2``: "Turn 1 query: {q1}, Turn 1 answer: {a1}, Turn 2 query: {q2}".
        For any other turn value: "Turn 1 query: {q1}".
        Absent or missing (None/NaN) text fields render as empty strings.

    Raises:
        KeyError: if ``row`` has no ``turn`` field.
        ValueError/TypeError: if ``turn`` cannot be converted with ``int()``.
    """
    turn = int(row["turn"])
    q1 = _cell_text(row, "turn_1_query")
    if turn == 2:
        ans = _cell_text(row, "turn_1_answer")
        q2 = _cell_text(row, "turn_2_query")
        return f"Turn 1 query: {q1}, Turn 1 answer: {ans}, Turn 2 query: {q2}"
    # Turn 1 and any unexpected turn value fall back to the first query only.
    return f"Turn 1 query: {q1}"

def build_prompt(few_shots, target_row):
    """Assemble the few-shot prompt, one example per line.

    Each demonstration row becomes "<row text>, winner model: <winner>". The
    final line is the target row's text followed by "winner model: [MASK]",
    so the masked LM only has to fill in the last winner.

    Args:
        few_shots: Iterable of demonstration rows, each indexable by 'winner'.
        target_row: The row whose winner is hidden behind [MASK].

    Returns:
        The prompt string, with lines joined by newline characters.
    """
    demo_lines = [
        f"{build_text_from_row(example)}, winner model: {example['winner']}"
        for example in few_shots
    ]
    target_line = f"{build_text_from_row(target_row)}, winner model: [MASK]"
    return "\n".join([*demo_lines, target_line])

def _max_input_length(tokenizer, model):
    """Return the smallest input-length limit declared by tokenizer or model config, or None if neither declares one."""
    config = getattr(model, "config", None)
    candidates = (
        getattr(tokenizer, "model_max_length", None),
        getattr(config, "max_position_embeddings", None),
    )
    # Some tokenizers report a very large placeholder when no limit is configured.
    # Taking the minimum with the model's position-embedding count neutralises it.
    limits = [n for n in candidates if isinstance(n, int) and n > 1]
    return min(limits) if limits else None


def _keep_tail(inputs, max_len):
    """Trim each (1, L) tensor in *inputs* to *max_len* tokens: the first token plus the last ``max_len - 1``."""
    return {
        key: torch.cat([tensor[:, :1], tensor[:, -(max_len - 1):]], dim=1)
        for key, tensor in inputs.items()
    }


def predict_masked(prompt, tokenizer, model, candidate_labels):
    """Score candidate labels at the last mask token of *prompt* and return the best one.

    Args:
        prompt: Text containing the tokenizer's mask token; the LAST occurrence is scored.
        tokenizer: Tokenizer paired with *model* (called with ``return_tensors="pt"``).
        model: Masked-LM model returning ``.logits``. Inputs are moved to
            ``model.device`` when the model exposes that attribute.
        candidate_labels: Iterable of labels to score (stringified before tokenizing).

    Returns:
        ``(pred, candidate_scores)``: the highest-scoring label, and a dict mapping
        each label to its score. A score is the raw logit of the label's token at
        the mask position, or the mean of its sub-token logits for multi-token
        labels. Labels that tokenize to no tokens score ``-inf``.

    Raises:
        ValueError: if the prompt contains no mask token, or *candidate_labels* is empty.
    """
    inputs = dict(tokenizer(prompt, return_tensors="pt"))

    max_len = _max_input_length(tokenizer, model)
    if max_len is not None and inputs["input_ids"].shape[1] > max_len:
        # Long prompts (e.g. turn-2 rows that include full answers) would overflow
        # the model's position embeddings. The target [MASK] is at the end of the
        # prompt, so drop the oldest context rather than the tail.
        inputs = _keep_tail(inputs, max_len)

    mask_positions = torch.nonzero(
        inputs["input_ids"][0] == tokenizer.mask_token_id, as_tuple=True
    )[0]
    if mask_positions.numel() == 0:
        raise ValueError("prompt contains no mask token")
    # The target's [MASK] ends the prompt. Any earlier mask tokens would come from
    # literal "[MASK]" text inside the demonstrations, so score the last one.
    mask_pos = mask_positions[-1].item()

    device = getattr(model, "device", None)
    if device is not None:
        inputs = {key: tensor.to(device) for key, tensor in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
    mask_logits = logits[0, mask_pos, :]

    candidate_scores = {}
    for label in candidate_labels:
        label_ids = tokenizer(str(label), add_special_tokens=False)["input_ids"]
        if not label_ids:
            score = float("-inf")
        else:
            # One mask slot can only predict one token. Multi-token labels are
            # approximated by averaging their sub-tokens' raw logits at that slot.
            score = mask_logits[label_ids].mean().item()
        candidate_scores[label] = score

    if not candidate_scores:
        raise ValueError("candidate_labels must not be empty")

    # Pick the highest-scoring candidate
    pred = max(candidate_scores, key=candidate_scores.get)
    return pred, candidate_scores

# -------------------------
# Main
# -------------------------
def main():
    """Run the few-shot ICL evaluation end to end and write predictions to CSV.

    Steps:
    1. Parse CLI args and seed the Python, NumPy and torch RNGs.
    2. Load the CSV, dropping rows whose ``winner`` is missing.
    3. Load the masked-LM.
    4. For each of ``num_queries`` randomly drawn target rows, build a prompt
       from ``shots`` other distinct rows and predict the masked winner.
    5. Save one result row per query to ``output_csv`` and print the accuracy.

    Raises:
        FileNotFoundError: if the input CSV does not exist.
        ValueError: if required columns are missing or no labelled rows remain.
    """
    args = parse_args()
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # Load dataset
    if not os.path.exists(args.csv):
        raise FileNotFoundError(f"CSV not found: {args.csv}")
    df = pd.read_csv(args.csv)
    missing_cols = [c for c in ("question_id", "turn", "winner") if c not in df.columns]
    if missing_cols:
        raise ValueError(f"CSV is missing required column(s): {missing_cols}")
    # Rows without a winner cannot act as demonstrations or targets. A NaN mixed
    # with string labels would also make sorted() raise TypeError.
    df = df.dropna(subset=["winner"]).reset_index(drop=True)
    if df.empty:
        raise ValueError(f"No rows with a 'winner' value in {args.csv}")
    unique_labels = sorted(df["winner"].unique().tolist())
    print(f"Unique labels: {unique_labels}")

    # Load BERT MLM model
    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
    model = AutoModelForMaskedLM.from_pretrained(args.model_name)
    model.eval()

    n_rows = len(df)
    k = min(args.shots, n_rows - 1)
    results = []

    for i in range(args.num_queries):
        # 1. Pick a target row
        target_idx = random.randint(0, n_rows - 1)
        target_row = df.iloc[target_idx]

        # 2. Select few-shot examples distinct from the target. Sample positions
        #    from the n_rows - 1 non-target rows without materialising that list:
        #    position p is row p, or row p + 1 once past the target. random.sample
        #    picks positions identically for any sequence of the same length, so
        #    this equals sampling from [j for j in range(n_rows) if j != target_idx].
        few_shot_indices = [
            p if p < target_idx else p + 1
            for p in random.sample(range(n_rows - 1), k=k)
        ]
        few_shots = [df.iloc[j] for j in few_shot_indices]

        # 3. Build prompt
        prompt = build_prompt(few_shots, target_row)

        # 4. Predict [MASK]
        pred, scores = predict_masked(prompt, tokenizer, model, unique_labels)

        # Record result
        results.append({
            "query_id": target_row["question_id"],
            "turn": target_row["turn"],
            "true_winner": target_row["winner"],
            "predicted_winner": pred,
            "prompt": prompt,
            "scores": scores,
        })

        print(f"[{i+1}/{args.num_queries}] True: {target_row['winner']} | Pred: {pred}")

    # Save results to CSV
    out_df = pd.DataFrame(results)
    out_df.to_csv(args.output_csv, index=False)
    print(f"Saved predictions to {args.output_csv}")

    if results:
        accuracy = (out_df["true_winner"] == out_df["predicted_winner"]).mean()
        print(f"Accuracy: {accuracy:.3f} over {len(results)} queries")

if __name__ == "__main__":
    main()