In [None]:
# MT-Bench response generation using the FT-DPO model

import os, json, random
import pandas as pd
from tqdm.auto import tqdm
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel


# ---------------------------------------------------------------------------
# Run configuration
# ---------------------------------------------------------------------------

# Checkpoints: Hub id of the base model and local directory of the DPO adapter.
BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.2"
DPO_ADAPTER_PATH = "./dpo_arena55k_0830_dulcet_glade_12"

# Directory that receives the generated answers.
OUT_DIR = "mtbench_runs"

# Decoding settings forwarded to `generate`.
MAX_NEW_TOKENS = 1000
TEMPERATURE = 0.2
TOP_P = 0.9

# Seed applied to both Python's and PyTorch's RNGs below.
SEED = 25

# Hardware: use the GPU when one is visible; bf16 is only considered on CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_BF16 = DEVICE == "cuda" and torch.cuda.is_bf16_supported()

# Seed every RNG this notebook relies on.
torch.manual_seed(SEED)
random.seed(SEED)

# MT-Bench prompts: pull the "train" split and print the first record as a
# quick look at the schema.
mtb = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
print(mtb[0])

# Tokenizer of the base checkpoint, configured for left padding.
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
tok.padding_side = "left"
if tok.pad_token is None:
    # No dedicated pad token: fall back to EOS so `pad_token_id` is defined.
    tok.pad_token = tok.eos_token


def load_base():
    """Load the base causal LM (``BASE_MODEL``) ready for inference.

    Weights are loaded in bfloat16 when ``USE_BF16`` is true and in float16
    otherwise; device placement is delegated to ``device_map="balanced"``.

    Returns:
        The model in eval mode, with ``use_cache`` enabled and
        ``pad_token_id`` copied from the shared tokenizer ``tok``.
    """
    if USE_BF16:
        weights_dtype = torch.bfloat16
    else:
        weights_dtype = torch.float16

    base_lm = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map="balanced",
        torch_dtype=weights_dtype,
    )
    base_lm.eval()

    # Inference-time config: enable the cache flag and align the pad id with
    # the tokenizer.
    cfg = base_lm.config
    cfg.use_cache = True
    cfg.pad_token_id = tok.pad_token_id
    return base_lm


def load_dpo():
    """Build the DPO model: a fresh base model wrapped with the DPO adapter.

    The adapter stored at ``DPO_ADAPTER_PATH`` is attached to the model
    returned by ``load_base()`` via ``PeftModel.from_pretrained``.

    Returns:
        The PEFT-wrapped model in eval mode, with ``pad_token_id`` copied from
        the shared tokenizer ``tok``.
    """
    policy = PeftModel.from_pretrained(
        load_base(),
        DPO_ADAPTER_PATH,
        device_map="auto",
    )
    policy.eval()
    policy.config.pad_token_id = tok.pad_token_id
    return policy


# Instantiate the DPO model once; the generation cell reuses this instance.
dpo_model = load_dpo()

{'category': 'writing', 'prompt': ['Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences and must-see attractions.', 'Rewrite your previous response. Start every sentence with the letter A.'], 'reference': [], 'prompt_id': 44067482}


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
# Generation helper

def chat_gen(model, turns):
    """Generate assistant replies for a (multi-turn) conversation.

    Each user turn is appended to the running conversation, rendered with the
    tokenizer's chat template and answered with ``model.generate``. The reply
    is then added to the conversation as assistant context for the next turn.

    Args:
        model: object exposing ``generate`` (the loaded causal LM or its PEFT
            wrapper).
        turns: list of user turns (strings), in order. MT-Bench supplies one
            or two; any number is accepted. Processing stops at the first
            empty follow-up turn.

    Returns:
        List of assistant replies (whitespace-stripped strings), one per
        processed turn. An empty ``turns`` yields ``[]``.
    """
    if not turns:
        # Nothing to answer; avoid an IndexError on turns[0].
        return []

    # Decoding arguments shared by every turn. Sampling-only knobs are passed
    # only when sampling is on: with do_sample=False (TEMPERATURE == 0) they
    # are not valid generation flags.
    sampling = bool(TEMPERATURE > 0)
    gen_kwargs = {
        "max_new_tokens": MAX_NEW_TOKENS,
        "do_sample": sampling,
        "pad_token_id": tok.pad_token_id,
        "eos_token_id": tok.eos_token_id,
    }
    if sampling:
        gen_kwargs["temperature"] = TEMPERATURE
        gen_kwargs["top_p"] = TOP_P

    def run_once(msgs):
        """Return the model's reply to the conversation ``msgs``."""
        # 1) Render the conversation into a single prompt string.
        prompt_text = tok.apply_chat_template(
            msgs,
            add_generation_prompt=True,
            tokenize=False,
        )
        # 2) Tokenize separately so an attention_mask is available. Special
        #    tokens are not added here (presumably the template already emits
        #    them -- confirm against the tokenizer's chat template).
        batch = tok(
            prompt_text,
            return_tensors="pt",
            padding=False,
            add_special_tokens=False,
        ).to(DEVICE)

        with torch.no_grad():
            out = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                **gen_kwargs,
            )

        # Decode only the newly generated tail, not the echoed prompt.
        prompt_len = batch["input_ids"].shape[-1]
        return tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()

    replies = []
    messages = []
    for idx, user_turn in enumerate(turns):
        # The first turn is always answered; an empty follow-up ends the chat.
        if idx > 0 and not user_turn:
            break
        messages.append({"role": "user", "content": user_turn})
        reply = run_once(messages)
        replies.append(reply)
        messages.append({"role": "assistant", "content": reply})

    return replies


def run_model_on_mtbench(model, tag):
    """Generate answers for every MT-Bench example and tabulate them.

    Args:
        model: model handed to ``chat_gen`` for each example.
        tag: label stored in the ``model_id`` column and shown in the
            progress bar.

    Returns:
        A DataFrame with one row per example and the columns ``index``,
        ``category``, ``turns``, ``replies`` and ``model_id``.
    """

    def _answer(user_turns):
        # Best-effort: a failing example is recorded as a single error string
        # so one bad prompt does not abort the whole run.
        try:
            return chat_gen(model, user_turns)
        except Exception as err:
            return [f"[GENERATION ERROR] {err}"]

    records = []
    progress = tqdm(mtb, desc=f"Generating: {tag}")
    for idx, example in enumerate(progress):
        # Keep only non-blank string prompts.
        user_turns = [
            text
            for text in example["prompt"]
            if isinstance(text, str) and text.strip()
        ]
        answers = _answer(user_turns)
        records.append(
            {
                "index": idx,
                "category": example.get("category"),
                "turns": user_turns,
                "replies": answers,
                "model_id": tag,
            }
        )

    return pd.DataFrame(records)


# Generate & Save

# Create the output directory up front so an invalid or unwritable OUT_DIR
# fails immediately instead of after the long generation run.
os.makedirs(OUT_DIR, exist_ok=True)

df_dpo = run_model_on_mtbench(dpo_model, "dpo")

# NOTE: lines=True writes JSON Lines (one record per line) even though the
# file is named *.json; read it back with pd.read_json(path, lines=True).
df_dpo.to_json(f"{OUT_DIR}/dpo_generations.json", orient="records", lines=True, force_ascii=False)

print("Saved generations to", OUT_DIR)