# In [None]:
#!/usr/bin/env python3
"""
train_cls_head_only.py

Train only the classification head on top of a frozen BERT encoder
using CLS token pooling, then evaluate on a test set.

CSV schema: question_id, turn, turn_1_query, turn_1_answer, turn_2_query, winner
"""

import os
import random
import argparse
import pandas as pd
import torch
from torch import nn
import numpy as np
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    DataCollatorWithPadding,
)

# -------------------------
# Argument parsing
# -------------------------
def parse_args():
    """Parse the command-line options for head-only training.

    Returns:
        argparse.Namespace with ``csv``, ``output_dir``, ``model_name``,
        ``batch_size``, ``epochs``, ``lr``, ``seed`` and ``max_length``.
    """
    parser = argparse.ArgumentParser(description="Train only classification head on frozen BERT")
    # (flag, type, default, help) -- a help of None is argparse's own default.
    option_table = (
        ("--csv", str, "mt_bench_training.csv", "Input CSV file path"),
        ("--output_dir", str, "./bert_cls_head", "Where to save trained model"),
        ("--model_name", str, "bert-base-uncased", "Pretrained BERT model"),
        ("--batch_size", int, 8, None),
        ("--epochs", int, 3, None),
        ("--lr", float, 5e-5, None),
        ("--seed", int, 42, None),
        ("--max_length", int, 256, None),
    )
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return parser.parse_args()

# -------------------------
# Utilities
# -------------------------
def _cell_text(row, key):
    """Return ``row[key]`` as stripped text; absent, None or NaN cells become ""."""
    value = row.get(key, "")
    # pd.read_csv stores empty cells as NaN by default; str() would turn that
    # into the literal "nan" and feed it to the model as if it were content.
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return ""
    return str(value).strip()


def build_text_from_row(row):
    """Format one CSV row into the natural-language text fed to the classifier.

    Rows with ``turn == 2`` include the first query, its answer and the
    follow-up query; any other turn value yields only the first query.
    Missing, None or NaN cells are rendered as empty strings.

    Args:
        row: mapping-like row (e.g. a pandas Series) with a ``turn`` field and
            optional ``turn_1_query``, ``turn_1_answer``, ``turn_2_query`` fields.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: if ``row`` has no ``turn`` field.
        ValueError: if ``turn`` cannot be converted to ``int``.
    """
    turn = int(row["turn"])
    q1 = _cell_text(row, "turn_1_query")
    if turn != 2:
        return f"Turn 1 query: {q1}"
    ans = _cell_text(row, "turn_1_answer")
    q2 = _cell_text(row, "turn_2_query")
    return f"Turn 1 query: {q1}, Turn 1 answer: {ans}, Turn 2 query: {q2}"

def preprocess_dataset(df, tokenizer, label2id, args):
    """Tokenize the rows of ``df`` and pair them with integer class labels.

    Args:
        df: DataFrame with the columns ``build_text_from_row`` reads plus ``winner``.
        tokenizer: tokenizer called on the list of prompts with truncation at
            ``args.max_length`` and no padding (padding is applied per batch
            by the data collator later).
        label2id: mapping from ``winner`` values to integer ids.
        args: parsed CLI namespace; only ``max_length`` is read.

    Returns:
        A ``datasets.Dataset`` holding every tokenizer output field plus ``labels``.

    Raises:
        KeyError: if a ``winner`` value is missing from ``label2id``.
    """
    prompts = [build_text_from_row(record) for _, record in df.iterrows()]
    label_ids = [label2id[winner] for winner in df["winner"]]
    tokenized = tokenizer(
        prompts,
        max_length=args.max_length,
        truncation=True,
        padding=False,
    )
    columns = dict(tokenized)
    columns["labels"] = label_ids
    return Dataset.from_dict(columns)

# -------------------------
# Main
# -------------------------
def _seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch with the same value."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


def _accepts_kwarg(func, name):
    """Return True if ``func``'s signature declares a parameter called ``name``.

    Lets the script choose between renamed ``transformers`` keyword arguments
    instead of only working on one library version.
    """
    import inspect

    try:
        return name in inspect.signature(func).parameters
    except (TypeError, ValueError):  # no introspectable signature
        return False


def _freeze_encoder(model):
    """Freeze ``model.base_model`` and make every remaining parameter trainable.

    The head is identified by exclusion (anything outside the base model), so
    this does not depend on the head being named ``classifier``.
    """
    encoder_param_ids = {id(p) for p in model.base_model.parameters()}
    for param in model.parameters():
        param.requires_grad = id(param) not in encoder_param_ids


def _compute_metrics(eval_preds):
    """Compute accuracy from a Trainer ``(predictions, label_ids)`` pair."""
    logits, labels = eval_preds
    # Guard for predictions delivered as a tuple; assumes the logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": float((preds == labels).mean())}


def main():
    """Train only the classification head of a frozen encoder and report test metrics.

    Reads ``args.csv``, makes stratified train/validation/test splits
    (10% test, then 10% of the remainder for validation), trains with
    ``Trainer`` keeping the best epoch by validation accuracy, evaluates on
    the test split, and saves the model and tokenizer to ``args.output_dir``.

    Raises:
        FileNotFoundError: if the CSV path does not exist.
    """
    args = parse_args()
    _seed_everything(args.seed)

    # Load dataset
    if not os.path.exists(args.csv):
        raise FileNotFoundError(f"CSV not found: {args.csv}")
    df = pd.read_csv(args.csv)
    unique_labels = sorted(df["winner"].unique().tolist())
    print(f"Unique labels: {unique_labels}")

    label2id = {label: i for i, label in enumerate(unique_labels)}
    id2label = {i: label for i, label in enumerate(unique_labels)}

    # Stratified split: roughly 81% train / 9% validation / 10% test.
    train_df, test_df = train_test_split(
        df, test_size=0.1, random_state=args.seed, stratify=df["winner"]
    )
    train_df, val_df = train_test_split(
        train_df, test_size=0.1, random_state=args.seed, stratify=train_df["winner"]
    )

    tokenizer = AutoTokenizer.from_pretrained(args.model_name)

    train_ds = preprocess_dataset(train_df, tokenizer, label2id, args)
    val_ds = preprocess_dataset(val_df, tokenizer, label2id, args)
    test_ds = preprocess_dataset(test_df, tokenizer, label2id, args)

    model = AutoModelForSequenceClassification.from_pretrained(
        args.model_name,
        num_labels=len(unique_labels),
        id2label=id2label,
        label2id=label2id,
    )

    # Freeze the pretrained encoder; only the task head is optimized.
    _freeze_encoder(model)

    print("Trainable parameters:")
    for name, param in model.named_parameters():
        if param.requires_grad:
            print(name)

    # Pads each batch dynamically (the datasets are tokenized without padding).
    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

    # ``evaluation_strategy`` was renamed ``eval_strategy`` and the old name was
    # later removed; pass whichever one the installed transformers accepts.
    strategy_kw = (
        "eval_strategy"
        if _accepts_kwarg(TrainingArguments, "eval_strategy")
        else "evaluation_strategy"
    )
    training_args = TrainingArguments(
        output_dir=args.output_dir,
        save_strategy="epoch",
        per_device_train_batch_size=args.batch_size,
        per_device_eval_batch_size=args.batch_size,
        num_train_epochs=args.epochs,
        learning_rate=args.lr,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        logging_dir=os.path.join(args.output_dir, "logs"),
        seed=args.seed,
        **{strategy_kw: "epoch"},
    )

    # Trainer's ``tokenizer`` argument has been superseded by ``processing_class``.
    tokenizer_kw = (
        "processing_class"
        if _accepts_kwarg(Trainer.__init__, "processing_class")
        else "tokenizer"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=data_collator,
        compute_metrics=_compute_metrics,
        **{tokenizer_kw: tokenizer},
    )

    print("--- Training classification head ---")
    trainer.train()

    print("--- Evaluating on test set ---")
    # A distinct prefix keeps test metrics from being reported as "eval_*"
    # alongside the validation metrics.
    results = trainer.evaluate(test_ds, metric_key_prefix="test")
    print(results)

    # With load_best_model_at_end=True, trainer.model carries the best epoch's weights.
    trainer.model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")

if __name__ == "__main__":
    # Entry point: train and evaluate only when executed as a script, not on import.
    main()