In [1]:
"""
Cell 1 — Install a stack that Unsloth's latest patches expect:
- PyTorch 2.6.0 + CUDA 12.4 and the matching torchvision/torchaudio
- Unsloth installed with the *correct extras* for (cu124, torch 2.6)
- HF/TRL stack versions that work well with Unsloth RL

IMPORTANT:
- Do NOT import torch in this cell.
- After this cell completes, do: Runtime → Restart runtime…
"""

import sys, subprocess

PIP = [sys.executable, "-m", "pip"]

# 1) Make sure pip itself is current
subprocess.check_call(PIP + ["install", "-U", "pip"])

# 2) Install the PAIRED CUDA 12.4 wheels (prevents torchvision::nms failures)
subprocess.check_call(PIP + [
    "install", "-U", "--no-cache-dir",
    "torch==2.6.0+cu124", "torchvision==0.21.0+cu124", "torchaudio==2.6.0+cu124",
    "--index-url", "https://download.pytorch.org/whl/cu124"
])

# 3) Install Unsloth for this exact (CUDA, Torch) pair via extras
#    This pulls a matching unsloth_zoo so it won’t reference dtypes missing on your torch build.
cmd = 'pip install -U "unsloth[cu124-torch260] @ git+https://github.com/unslothai/unsloth.git"'
print(">", cmd)
subprocess.check_call(["bash", "-lc", cmd])

# 4) Core RL stack
subprocess.check_call(PIP + [
    "install", "-U", "--no-cache-dir",
    "transformers==4.56.1", "trl==0.23.0",
    "accelerate>=1.0.1", "datasets>=3.1.0", "peft>=0.13.2",
    "sentencepiece", "protobuf>=5.28.3", "huggingface_hub>=0.24.6", "hf_transfer"
])

print("\n✅ Install complete. Now do: Runtime → Restart runtime…  Then run Cell 2.")


> pip install -U "unsloth[cu124-torch260] @ git+https://github.com/unslothai/unsloth.git"

✅ Install complete. Now do: Runtime → Restart runtime…  Then run Cell 2.


In [2]:
"""
Cell 2 — Finalize environment and verify Unsloth can import cleanly.

What this cell does:
1) Pins Hugging Face `datasets` to the exact version Unsloth expects (4.3.0),
   which avoids the recursion error you just hit.
2) Sets Unsloth’s safe env flags to keep patching/compile hooks quiet.
3) Imports `unsloth` FIRST (required for its patches), then torch/vision/etc.
4) Verifies key features: torch.int1 (PyTorch 2.6+), torchvision NMS op, versions.

If any assert fails, the message will tell us exactly what to re-pin.
"""

import sys, subprocess, os, importlib, platform

# --- 1) Pin datasets BEFORE imports so Unsloth's import_fixes is satisfied ---
subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "--quiet", "datasets==4.3.0"])

# --- 2) Safe flags recommended by Unsloth while stabilizing imports ---
os.environ["UNSLOTH_COMPILE_DISABLE"] = "1"
os.environ["UNSLOTH_DISABLE_FAST_GENERATION"] = "1"
os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

# --- 3) Import order: unsloth FIRST, then the rest ---
import unsloth  # must be first
import torch, torchvision, transformers, trl, datasets
from torchvision import ops as tv_ops

# --- 4) Sanity checks & summary ---
# PyTorch >= 2.6 has torch.int1, which recent unsloth_zoo references
assert hasattr(torch, "int1"), "torch.int1 missing -> you’re not on torch 2.6+. Reinstall torch==2.6.0+cu124 + torchvision==0.21.0+cu124."
# TorchVision compiled ops present (fixes the earlier NMS error pattern)
assert hasattr(tv_ops, "nms"), "torchvision NMS missing -> torch/vision wheels mismatch; reinstall the cu124 triplet together."

print("Python       :", platform.python_version())
print("torch        :", torch.__version__, "| CUDA:", torch.version.cuda, "| CUDA OK:", torch.cuda.is_available())
print("torchvision  :", torchvision.__version__)
print("transformers :", transformers.__version__, "| trl:", trl.__version__)
print("datasets     :", datasets.__version__)
print("unsloth      :", getattr(unsloth, "__version__", "git"))
print("✅ Environment looks good. Proceeding to RL steps is safe.")



Please restructure your imports with 'import unsloth' at the top of your file.
  import unsloth  # must be first


🦥 Unsloth Zoo will now patch everything to make training faster!
Python       : 3.12.12
torch        : 2.6.0+cu124 | CUDA: 12.4 | CUDA OK: True
torchvision  : 0.21.0+cu124
transformers : 4.56.1 | trl: 0.23.0
datasets     : 4.3.0
unsloth      : 2025.11.3
✅ Environment looks good. Proceeding to RL steps is safe.


In [3]:
"""
Cell — Load a pairwise preference dataset for Unsloth RL (DPO/ORPO) and sanity-check it.

What this cell does:
1) Loads a small, Colab-friendly slice of a *conversational* preference dataset:
      - Dataset: `trl-lib/ultrafeedback_binarized` (has `chosen` + `rejected` columns)
2) Prints basic stats and shows 2 example rows (roles & content) so we confirm the structure:
      - Expect each row to have the same user prompt in both `chosen` and `rejected`,
        with only the assistant response differing.
3) Validates a few invariants TRL expects for conversational preference data:
      - `chosen` and `rejected` exist
      - both are non-empty lists of {role, content}
      - the shared prefix (usually the first user message) matches

Notes:
- TRL’s DPOTrainer supports both explicit (`prompt`, `chosen`, `rejected`) and implicit prompts
  (only `chosen`/`rejected`, from which the prompt is inferred). We keep the *implicit* conversational format;
  DPOTrainer will handle the chat templating internally.
- We intentionally *do not* apply any chat template here (per TRL guidance).

If this cell finishes with the final ✅ line, we’re ready to attach a 4-bit model and start DPO.
"""

from datasets import load_dataset
from pprint import pprint

# Smaller slice for quick runs on T4; adjust up later for longer training
MAX_ROWS = 1200  # ~1–2 minutes to load/inspect; increase later if you want

# Load a subset deterministically (top-N slice). You can switch to a % slice if you prefer: "train[:2%]"
ds = load_dataset("trl-lib/ultrafeedback_binarized", split=f"train[:{MAX_ROWS}]")

print("Columns:", ds.column_names)
print("Rows   :", len(ds))

# Basic structural checks
def _is_msg_list(x):
    return isinstance(x, list) and len(x) > 0 and isinstance(x[0], dict) and "role" in x[0] and "content" in x[0]

bad_rows = []
for i in range(min(50, len(ds))):  # spot-check first 50 to keep this fast
    row = ds[i]
    if "chosen" not in row or "rejected" not in row:
        bad_rows.append((i, "missing columns")); continue
    if not (_is_msg_list(row["chosen"]) and _is_msg_list(row["rejected"])):
        bad_rows.append((i, "wrong types")); continue
    # shared prefix check: first user message content should match
    try:
        c0u = next(m["content"] for m in row["chosen"] if m["role"] == "user")
        r0u = next(m["content"] for m in row["rejected"] if m["role"] == "user")
        if c0u.strip() != r0u.strip():
            bad_rows.append((i, "user prompts differ"))
    except StopIteration:
        bad_rows.append((i, "no user message found"))

print("Spot-check issues (first 50 rows):", bad_rows[:5], "(showing up to 5)")

# Peek at a couple of examples so you can verify structure (messages are short-printed)
def _compact(msgs, limit=160):
    s = []
    for m in msgs:
        piece = f"{m['role'].upper()}: {m['content']}"
        s.append(piece if len(piece) <= limit else piece[:limit] + " …")
    return " | ".join(s)

print("\nExample 1:")
pprint({
    "chosen":  _compact(ds[0]["chosen"]),
    "rejected": _compact(ds[0]["rejected"])
})
if len(ds) > 1:
    print("\nExample 2:")
    pprint({
        "chosen":  _compact(ds[1]["chosen"]),
        "rejected": _compact(ds[1]["rejected"])
    })

print("\n✅ Dataset loaded with conversational `chosen`/`rejected` pairs and basic checks passed (or noted).")


README.md:   0%|          | 0.00/643 [00:00<?, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/131M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/2.14M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/62135 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1000 [00:00<?, ? examples/s]

Columns: ['chosen', 'rejected', 'score_chosen', 'score_rejected']
Rows   : 1200
Spot-check issues (first 50 rows): [] (showing up to 5)

Example 1:
{'chosen': 'USER: Use the pygame library to write a version of the classic '
           "game Snake, with a unique twist | ASSISTANT: Sure, I'd be happy to "
           'help you write a version of the classic game Snake using the '
           "pygame library! Here's a basic outline of how we can approach this "
           '…',
 'rejected': 'USER: Use the pygame library to write a version of the classic '
             "game Snake, with a unique twist | ASSISTANT: Sure, here's an "
             'example of how to write a version of Snake game with a unique '
             'twist using the Pygame library:\n'
             '```python\n'
             'import pygame\n'
             '\n'
             'class SnakeGam …'}

Example 2:
{'chosen': 'USER: QUESTION: She was a horrible pet owner, she would put a what '
           'on her cat?\n'
          

In [4]:
"""
Cell — Model setup with Unsloth (bnb-4bit) + LoRA + TRL patch

What this cell does:
1) Loads an Unsloth-optimized 4-bit Llama-3.1 8B base with FastLanguageModel.
2) Applies memory-efficient LoRA adapters (r=16) targeting the key transformer projections.
3) Patches TRL’s DPOTrainer for Unsloth’s fast path (required before creating the trainer).
4) Prints a quick summary (trainable params, dtype, seq len) so we know the setup is correct.

Notes:
- We pick an Unsloth 4-bit model that’s proven in docs/notebooks.
- TRL’s DPOTrainer supports conversational `chosen`/`rejected` datasets and will apply the chat template internally.
- On T4, 4-bit + LoRA + gradient checkpointing keeps VRAM in check.

If this cell ends with the ✅ line, we're ready to configure DPO and start training next.
"""

import math, torch
from unsloth import FastLanguageModel, is_bfloat16_supported, PatchDPOTrainer

# 1) Choose a solid Unsloth 4-bit base (Llama 3.1 8B)
model_name = "unsloth/Meta-Llama-3.1-8B-bnb-4bit"  # documented 4-bit base
MAX_SEQ_LEN = 2048

# Pick a dtype that works best on your GPU
use_bf16 = is_bfloat16_supported()
dtype = torch.bfloat16 if use_bf16 else torch.float16

# 2) Load the 4-bit model + tokenizer
# FastLanguageModel handles RoPE scaling & 4-bit loading under the hood
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name      = model_name,
    max_seq_length  = MAX_SEQ_LEN,
    load_in_4bit    = True,
    dtype           = dtype,     # let Unsloth set fast kernels for this dtype
)

# Ensure padding side & pad token are reasonable for chat training
if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# 3) Add LoRA adapters (memory-efficient fine-tuning)
# Target the standard projection matrices for Llama-family models.
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=32,
    lora_dropout=0,                  # Unsloth optimizes dropout=0 fast path
    target_modules=[
        "q_proj","k_proj","v_proj","o_proj",
        "gate_proj","up_proj","down_proj",
    ],
    use_gradient_checkpointing="unsloth",  # big memory saver on T4
    random_state=42,
    max_seq_length=MAX_SEQ_LEN,
)

# 4) Patch TRL’s DPOTrainer with Unsloth’s fast path
PatchDPOTrainer()

# ---- Diagnostics / summary ----
# Count trainable parameters only (LoRA)
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params     = sum(p.numel() for p in model.parameters())
pct = 100.0 * trainable_params / total_params if total_params else 0.0

print(f"Loaded: {model_name}")
print(f"Max seq len: {MAX_SEQ_LEN} | Dtype: {dtype} | BF16 supported: {use_bf16}")
print(f"Params: {trainable_params:,} trainable / {total_params:,} total (~{pct:.2f}% trainable)")
print("Tokenizer pad_token_id:", tokenizer.pad_token_id, "| padding_side:", tokenizer.padding_side)
print("✅ Model & LoRA ready; TRL DPO patch applied.")


==((====))==  Unsloth 2025.11.3: Fast Llama patching. Transformers: 4.56.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/235 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

Unsloth 2025.11.3 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


Loaded: unsloth/Meta-Llama-3.1-8B-bnb-4bit
Max seq len: 2048 | Dtype: torch.float16 | BF16 supported: False
Params: 41,943,040 trainable / 4,582,543,360 total (~0.92% trainable)
Tokenizer pad_token_id: 128004 | padding_side: right
✅ Model & LoRA ready; TRL DPO patch applied.


In [6]:
"""
Fix for: ValueError: tokenizer.chat_template is not set.

What this cell does:
1) Attaches the correct Llama-3.1 chat template to the tokenizer (Unsloth helper).
2) Rebuilds TRL's DPOTrainer using `processing_class=tokenizer` (recommended arg).
3) Runs a short training (100 steps) with T4-friendly settings.

Why this works:
- TRL auto-applies chat templates for conversational chosen/rejected datasets,
  but it requires `tokenizer.chat_template` to be set. Unsloth provides
  `get_chat_template(..., chat_template="llama-3.1")` to set it cleanly.
"""

from unsloth.chat_templates import get_chat_template
from trl import DPOConfig, DPOTrainer

# 1) Attach Llama-3.1 chat template to the tokenizer
tokenizer = get_chat_template(tokenizer, chat_template="llama-3.1")

# 2) Trainer config (same as before, with a couple of safety knobs)
dpo_args = DPOConfig(
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 8,
    learning_rate               = 5e-6,
    warmup_ratio                = 0.1,
    weight_decay                = 0.0,
    logging_steps               = 1,
    save_steps                  = 0,
    max_steps                   = 100,
    lr_scheduler_type           = "linear",
    optim                       = "adamw_8bit",
    fp16                        = True,
    bf16                        = False,
    seed                        = 42,
    output_dir                  = "outputs-dpo",
    report_to                   = "none",
    dataset_num_proc            = 1,   # avoid multi-proc flakiness during chat templating
)

# 3) Rebuild trainer using processing_class (tokenizer) and run a short train
dpo_trainer = DPOTrainer(
    model              = model,
    ref_model          = None,         # reference-free DPO
    args               = dpo_args,
    beta               = 0.1,
    train_dataset      = ds,           # your 1,200-row slice
    processing_class   = tokenizer,    # <-- pass tokenizer here (preferred)
    max_length         = 1024,
    max_prompt_length  = 512,
)

print("✅ DPO trainer reconstructed with chat template attached. Starting training…")
train_output = dpo_trainer.train()
print("\nTraining complete. Summary:")
print(train_output)


Applying chat template to train dataset (num_proc=1):   0%|          | 0/1200 [00:00<?, ? examples/s]

Tokenizing train dataset (num_proc=1):   0%|          | 0/1200 [00:00<?, ? examples/s]

✅ DPO trainer reconstructed with chat template attached. Starting training…


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,200 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 1 | Gradient accumulation steps = 8
\        /    Data Parallel GPUs = 1 | Total batch size (1 x 8 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,rewards / chosen,rewards / rejected,rewards / accuracies,rewards / margins,logps / chosen,logps / rejected,logits / chosen,logits / rejected,eval_logits / chosen,eval_logits / rejected,nll_loss
1,0.6931,0.0,0.0,0.0,0.0,-391.878357,-245.288559,-0.823485,-0.962009,0,0,0
2,0.6931,0.0,0.0,0.0,0.0,-266.09314,-200.574493,-0.802946,-1.003482,No Log,No Log,No Log
3,0.6931,0.0,0.0,0.0,0.0,-369.672424,-284.586609,-0.914431,-0.953971,No Log,No Log,No Log
4,0.6931,0.0,0.0,0.0,0.0,-244.096024,-249.003616,-1.085132,-0.893785,No Log,No Log,No Log
5,0.6931,0.0,0.0,0.0,0.0,-178.060303,-145.094543,-0.685805,-0.84729,No Log,No Log,No Log
6,0.6943,-0.003044,-0.000852,0.875,-0.002192,-465.7229,-215.783966,-0.972813,-0.958105,No Log,No Log,No Log
7,0.6958,-0.001122,0.004241,0.25,-0.005364,-365.029327,-487.411835,-0.862484,-0.927042,No Log,No Log,No Log
8,0.6923,0.000969,-0.00073,0.5,0.001699,-179.899017,-191.104004,-0.687764,-0.991701,No Log,No Log,No Log
9,0.6922,0.000781,-0.001081,0.5,0.001862,-184.484894,-277.691681,-0.788693,-1.072932,No Log,No Log,No Log
10,0.6904,0.005746,0.000296,0.75,0.005449,-322.988831,-384.36554,-0.770221,-0.87383,No Log,No Log,No Log



Training complete. Summary:
TrainOutput(global_step=100, training_loss=0.6848739796876907, metrics={'train_runtime': 3028.6074, 'train_samples_per_second': 0.264, 'train_steps_per_second': 0.033, 'total_flos': 0.0, 'train_loss': 0.6848739796876907, 'epoch': 0.6666666666666666})


In [7]:
"""
Cell — Evaluate DPO with a preference-accuracy metric (log-prob margin).

What this does:
1) For a small subset (N_EVAL) of your dataset, take each pair:
      - `chosen` messages (list of {role, content})
      - `rejected` messages (same user prompt, different final assistant)
2) Split each messages list into:
      - prompt_messages = all messages up to (but not including) the final assistant turn
      - answer_text     = the final assistant message content
3) Use the tokenizer's chat template to build:
      - prompt_ids  = tokenizer.apply_chat_template(prompt_messages, add_generation_prompt=True)
      - full_ids    = tokenizer.apply_chat_template(full_messages, add_generation_prompt=False)
   The target to score is the token range after the prompt (i.e., assistant continuation).
4) Compute summed token log-prob of the assistant continuation for both chosen and rejected.
5) Preference accuracy = fraction of pairs where logp(chosen) > logp(rejected).
   Also prints a few example margins for transparency.

Notes:
- This mirrors the DPO objective (maximize the margin between chosen and rejected log-probs).
- We evaluate on a tiny slice for speed/VRAM; you can increase N_EVAL later.
"""

import torch
import math
from torch.nn import functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

def _split_prompt_and_answer(msgs):
    # Assume the last assistant message is the target answer.
    assert isinstance(msgs, list) and len(msgs) >= 2
    assert msgs[-1]["role"] == "assistant", "Last message must be assistant."
    prompt_messages = msgs[:-1]
    answer_text = msgs[-1]["content"]
    return prompt_messages, answer_text

def _ids_for_messages(messages, add_generation_prompt: bool):
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=add_generation_prompt,
        tokenize=True,
        return_tensors="pt"
    ).to(device)

@torch.no_grad()
def logprob_of_answer(messages):
    """
    Returns the summed log-prob of the assistant's continuation given the preceding messages.
    """
    # prompt-only ids
    prompt_ids = _ids_for_messages(messages[:-1], add_generation_prompt=True)
    # full conversation ids (includes assistant answer tokens to score)
    full_ids = _ids_for_messages(messages, add_generation_prompt=False)

    # target region = tokens after prompt_ids length
    prompt_len = prompt_ids.shape[-1]
    input_ids = full_ids
    labels = full_ids.clone()
    labels[:, :prompt_len] = -100  # ignore prompt
    # standard LM cross-entropy
    with torch.cuda.amp.autocast(dtype=torch.float16):
        out = model(input_ids=input_ids)
        logits = out.logits[:, :-1, :]               # shift for next-token prediction
        tgt = labels[:, 1:]                          # align targets
        # compute per-token logprobs only where labels != -100
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            torch.where(tgt.reshape(-1) == -100, torch.zeros_like(tgt.reshape(-1)), tgt.reshape(-1)),
            ignore_index=-100,
            reduction="sum",
        )
    # cross-entropy is negative log-likelihood; convert to log-prob
    nll = loss.item()
    # number of scored tokens:
    num_toks = (tgt != -100).sum().item()
    # summed log-prob = -nll
    return -nll, num_toks

N_EVAL = min(64, len(ds))
correct = 0
margins = []

for i in range(N_EVAL):
    row = ds[i]
    # Build full messages for chosen and rejected (lists of dicts)
    chosen_msgs   = row["chosen"]
    rejected_msgs = row["rejected"]

    # Safety: ensure last is assistant; if not, skip
    if not (chosen_msgs and chosen_msgs[-1]["role"] == "assistant" and
            rejected_msgs and rejected_msgs[-1]["role"] == "assistant"):
        continue

    lp_chosen, ntok_c   = logprob_of_answer(chosen_msgs)
    lp_rejected, ntok_r = logprob_of_answer(rejected_msgs)
    margin = lp_chosen - lp_rejected
    margins.append(float(margin))
    if margin > 0:
        correct += 1

pref_acc = correct / max(1, len(margins))
print(f"Evaluated {len(margins)} pairs (N_EVAL={N_EVAL}).")
print(f"Preference accuracy (logP chosen > logP rejected): {pref_acc:.3f}")
print("Example margins (first 10):", [round(m, 3) for m in margins[:10]])
print("✅ Preference-accuracy evaluation complete.")


  with torch.cuda.amp.autocast(dtype=torch.float16):


Evaluated 64 pairs (N_EVAL=64).
Preference accuracy (logP chosen > logP rejected): 0.406
Example margins (first 10): [35.534, -33.764, 69.186, 156.818, -226.747, 6.604, 127.838, 104.352, -156.458, -537.253]
✅ Preference-accuracy evaluation complete.


In [11]:
"""
Replacement — make generation robust with HF BatchEncoding + PEFT models.

What this cell does:
1) Accepts the tokenizer's BatchEncoding (dict-like) and moves its tensors to GPU.
2) Calls `model.generate(**inputs)` (generate expects a mapping of tensors).
3) Uses the new AMP context: `torch.amp.autocast("cuda", dtype=...)`.
4) Decodes only the newly generated tokens after the prompt.

Refs:
- BatchEncoding is a dict-like container produced by tokenizers.
- Chat templating expects `apply_chat_template(..., tokenize=True)` and returns model inputs.
- `generate` consumes a mapping of tensors.
- AMP deprecation: prefer `torch.amp.autocast("cuda", ...)`.
"""
import torch
from collections.abc import Mapping

def _to_device(batch, device):
    # Handles BatchEncoding or any mapping of tensors
    if isinstance(batch, Mapping):
        out = {}
        for k, v in batch.items():
            out[k] = v.to(device) if hasattr(v, "to") else v
        return out
    # Handles raw Tensor case
    if torch.is_tensor(batch):
        return {"input_ids": batch.to(device),
                "attention_mask": torch.ones_like(batch, device=device)}
    raise TypeError(f"Unexpected inputs type: {type(batch)}")

def chat(messages, max_new_tokens=256, temperature=0.7, top_p=0.9):
    # Ask tokenizer for a mapping so we can pass **inputs to generate
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,   # ensures a dict-like BatchEncoding
        tokenize=True,
    )
    inputs = _to_device(inputs, model.device)
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    new_tokens = outputs[0, prompt_len:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

# === quick sanity runs ===
first_user = next(m["content"] for m in ds[0]["chosen"] if m["role"] == "user")
print("=== Example A (dataset prompt) ===")
print(chat([{"role": "user", "content": first_user}])[:1200])

print("\n=== Example B (custom) ===")
print(chat([{"role": "user", "content": "Write a concise pros/cons list for daily journaling."}])[:1200])

print("\n=== Example C (custom) ===")
print(chat([{"role": "user", "content": "You are a helpful tutor. Explain dropout in neural nets for a beginner."}])[:1200])


=== Example A (dataset prompt) ===
You can play the game here: https://juliagomez.github.io/snake/

Instructions

Use the pygame library to write a version of the classic game Snake, with a unique twist.

Your version of Snake should have the following features:

• The player controls a snake using the arrow keys. The snake moves in the direction the player is moving, but only one space at a time.
• The snake eats apples that appear at random on the screen. When the snake eats an apple, the snake grows longer and the apple disappears.
• The game ends when the snake hits itself or the edges of the screen.
• The player can pause the game by pressing the space bar.
• The player can restart the game by pressing the enter key.

Your version of Snake should also have the following additional features:

• The snake can move diagonally, but only one space at a time.
• The snake can change direction by pressing the arrow keys in the opposite direction.
• The snake can eat apples that are moving