# 🧠 **Colab 4: Reinforcement Learning with GRPO – Training a Reasoning Model**

In this notebook, we explore **Group Relative Policy Optimization (GRPO)** — a reinforcement learning method designed to train **reasoning-capable large language models (LLMs)** using only *problems and model-generated answers*.  

Unlike DPO (which relies on labeled “preferred vs. rejected” responses), GRPO lets the model **learn directly from its own generated outputs**, guided by a reward signal that measures reasoning quality or task success.  

### 🎯 **Learning Objectives**
- Understand the concept of **self-improvement in LLMs** using GRPO.  
- Set up a **reasoning dataset** where prompts are problems and responses are model-generated answers.  
- Implement **fine-tuning with Unsloth.ai’s GRPO pipeline** for efficient reasoning training.  
- Evaluate model performance on structured reasoning tasks (e.g., math or logic problems).


In [None]:
!pip install unsloth datasets transformers accelerate bitsandbytes wandb huggingface_hub

Collecting unsloth
  Downloading unsloth-2025.11.2-py3-none-any.whl.metadata (61 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/61.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.8/61.8 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting unsloth_zoo>=2025.11.3 (from unsloth)
  Downloading unsloth_zoo-2025.11.3-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth)
  Downloading tyro-0.9.35-py3-none-any.whl.metadata (12 kB)
Collecting xformers>=0.0.27.post2 (from unsloth)
  Downloading xformers-0.0.33-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (1.2 kB)
Collecting datasets
  Downloading datasets-4.3.0-py3-none-any.whl.metadata (18 kB)
Collecting trl!=0.19.0,<=0.23.0,>=0.18.2 (from unsloth)
  Downloading trl-0.23.0-py3-none-any.whl.metadata (11 kB)
Collecting pyarrow>=21.0.0 (fr

In [None]:
# Unsloth asks to be imported before trl / transformers / peft so that its
# patches are applied to them.
import unsloth
from unsloth import FastLanguageModel

import os
import pkgutil, sys

import torch
import trl, transformers, accelerate, peft
import transformers as tf
import accelerate as ac
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer  # sanity check import like your code (we won't use it here)



Please restructure your imports with 'import unsloth' at the top of your file.
  import unsloth


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


    PyTorch 2.9.0+cu130 with CUDA 1300 (you have 2.9.0+cu128)
    Python  3.10.19 (you have 3.12.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


Switching to PyTorch attention since your Xformers is broken.

Unsloth: Xformers was not installed correctly.
Please install xformers separately first.
Then confirm if it's correctly installed by running:
python -m xformers.info

Longer error message:
xFormers can't load C++/CUDA extensions. xFormers was built for:
    PyTorch 2.9.0+cu130 with CUDA 1300 (you have 2.9.0+cu128)
    Python  3.10.19 (you have 3.12.12)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
# ===============================================================
# 🔐 Auth via Colab secrets + 🧠 Model Load (4-bit)
# ===============================================================
import os, gc, time, random
import torch
from google.colab import userdata
from huggingface_hub import login
import wandb

# Unsloth import (ensure patches apply early)
import unsloth
from unsloth import FastLanguageModel

# -----------------------------
# 🎫 Fetch tokens from Colab secrets
# -----------------------------
def _read_colab_secret(key):
    """Return the Colab secret stored under `key`, or None when it is unavailable.

    `userdata.get` raises when the secret does not exist or when this notebook
    has not been granted access to it; both cases are reported as None so the
    caller can decide whether the secret is required.
    """
    try:
        return userdata.get(key)
    except (userdata.SecretNotFoundError, userdata.NotebookAccessError):
        return None


hf_token = _read_colab_secret("HGFaceApi")   # REQUIRED
wb_token = _read_colab_secret("wb_token")    # OPTIONAL

if not hf_token:
    # Secrets cannot be written from code; they are managed in the Colab UI.
    raise ValueError(
        "Hugging Face token not found in Colab secrets under key 'HGFaceApi'. "
        "Add it in the Colab sidebar (Secrets panel) and enable notebook access for it."
    )

# Authenticate against the Hugging Face Hub with the required token.
login(token=hf_token)

# W&B credentials are optional: without them the run below is started anonymously.
if not wb_token:
    print("⚠️ W&B token ('wb_token') not found. Proceeding with anonymous logging.")
else:
    wandb.login(relogin=True, key=wb_token)
    print("W&B login successful.")

# Open the tracking run; anonymous mode is allowed for the no-token case.
run = wandb.init(
    anonymous="allow",
    job_type="training",
    project="LoRA-Finetuning-SmolLM2-135M",
)

# -----------------------------
# ⚙️ Repro + config
# -----------------------------
SEED = 3407
random.seed(SEED)
torch.manual_seed(SEED)

MODEL_NAME = "HuggingFaceTB/SmolLM2-135M-Instruct"
MAX_LEN    = 1024
FAST_STEPS = 300

print("Config:", dict(MODEL_NAME=MODEL_NAME, MAX_LEN=MAX_LEN, FAST_STEPS=FAST_STEPS))

# GPU info (if any)
!nvidia-smi || echo "No GPU visible"

print("bf16 support:", torch.cuda.is_bf16_supported())
print("Loading model + tokenizer...")
t0 = time.time()

# -----------------------------
# 🧩 Load model/tokenizer in 4-bit via Unsloth
# -----------------------------
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name     = MODEL_NAME,
    max_seq_length = MAX_LEN,
    dtype          = None,
    load_in_4bit   = True,
)

print(f"Loaded in {time.time()-t0:.2f}s")


  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mshubhamjaysukhbhai-kothiya[0m ([33mshubhamjaysukhbhai-kothiya-san-jose-state-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


W&B login successful.


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/


Config: {'MODEL_NAME': 'HuggingFaceTB/SmolLM2-135M-Instruct', 'MAX_LEN': 1024, 'FAST_STEPS': 300}
Thu Nov 13 11:59:11 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   64C    P0             30W /   70W |     102MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------

model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

HuggingFaceTB/SmolLM2-135M-Instruct does not have a padding token! Will use pad_token = <|endoftext|>.
Loaded in 16.85s


In [None]:
# We’ll use GSM8K (problems with numeric final answers).
# SYSTEM_PROMPT is the system message that to_messages() places in front of
# every question. It asks for a <reasoning> block followed by an <answer>
# block; these are the same tags the reward function checks for.
SYSTEM_PROMPT = """Respond ONLY in this XML format:
<reasoning>
step-by-step reasoning here
</reasoning>
<answer>
final numeric answer only
</answer>
"""

def extract_gold(s: str):
    """Pull the reference answer out of a GSM8K solution string.

    The reference answer follows a ``####`` marker (e.g. ``"...\\n#### 4"``).

    Returns:
        The stripped text after the last ``####``, or ``None`` when the string
        contains no marker.
    """
    _, marker, tail = s.rpartition("####")
    if not marker:
        return None
    return tail.strip()

def to_messages(question: str):
    """Wrap a question in the two-message chat prompt used for training.

    Returns:
        A list holding the shared system message (SYSTEM_PROMPT) followed by
        the whitespace-stripped question as the user message.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {"role": "user", "content": question.strip()}
    return [system_turn, user_turn]

# Seeded shuffle of the GSM8K train split, keeping the first 800 rows.
raw = (
    load_dataset("openai/gsm8k", "main")["train"]
    .shuffle(SEED)
    .select(range(800))
)

# Add the chat-formatted prompt and the reference answer to every row.
train_data = raw.map(
    lambda row: {
        "prompt": to_messages(row["question"]),
        "gold": extract_gold(row["answer"]),
    }
)

# Lookup used by the reward function: stripped question text -> reference answer.
Q2GOLD = {
    str(row["question"]).strip(): extract_gold(row["answer"])
    for row in raw
}

print(train_data[0])
print("Examples:", len(train_data))


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/800 [00:00<?, ? examples/s]

{'question': "Ahmed is 11 years old and Fouad is 26 years old. In how many years will Fouad's age be double Ahmed's current age?", 'answer': "Let X be the number of years before Fouad's age doubles Ahmed's age.\nSo (X+11)*2 = X+26.\nSo X*2 + 22 = X + 26.\nSo X = 26 - 22 = <<26-22=4>>4 years.\n#### 4", 'prompt': [{'content': 'Respond ONLY in this XML format:\n<reasoning>\nstep-by-step reasoning here\n</reasoning>\n<answer>\nfinal numeric answer only\n</answer>\n', 'role': 'system'}, {'content': "Ahmed is 11 years old and Fouad is 26 years old. In how many years will Fouad's age be double Ahmed's current age?", 'role': 'user'}], 'gold': '4'}
Examples: 800


In [None]:
import re
# Captures whatever sits between <answer> and </answer>, trimmed of surrounding
# whitespace; matching is case-insensitive and the content may span lines.
ANS_TAG = re.compile(r"<answer>\s*(.*?)\s*</answer>", re.DOTALL | re.IGNORECASE)

def _num_norm(s: str):
    if s is None:
        return None
    s = re.sub(r"[^\d\.\-]", "", s).strip(".")
    return s

def _extract_ans(text: str):
    """Return the stripped contents of the first <answer>...</answer> block.

    A falsy `text` (None or empty) is searched as the empty string. Returns
    None when ANS_TAG finds no match.
    """
    found = ANS_TAG.search("" if not text else text)
    if found is None:
        return None
    return found.group(1).strip()

def _question_from_messages(msgs):
    try:
        for m in msgs:
            if m.get("role") == "user":
                return str(m.get("content", "")).strip()
    except Exception:
        pass
    return ""

def _completion_texts(completion):
    """Return the text(s) carried by one completion entry, whatever its shape.

    Accepts a plain string, a single chat message dict, or a list of chat
    message dicts / strings; always returns a list of strings.
    """
    if isinstance(completion, str):
        return [completion]
    if isinstance(completion, dict):
        return [completion["content"]]
    return [m if isinstance(m, str) else m["content"] for m in completion]


# IMPORTANT: return a FLAT list with one reward per generated completion
def reward_fn(*, prompts=None, completions=None, completion_ids=None, **kwargs):
    """Score sampled completions: answer correctness plus small format terms.

    Each completion text receives
      * +1.0 if its normalised <answer> equals the normalised reference answer,
      * +0.1 if both <reasoning> tags are present,
      * +0.1 if both <answer> tags are present,
      * -0.05 if it has more than 256 whitespace-separated words.

    Args:
        prompts: prompts aligned with `completions`. Used to look the reference
            answer up in Q2GOLD when no `gold` values are forwarded.
        completions: one entry per sampled completion; each entry may be a list
            of chat messages, a single message dict, or a plain string. None is
            treated as an empty batch.
        completion_ids: accepted for trainer compatibility; not used.
        **kwargs: extra keyword arguments from the trainer. When a `gold`
            sequence aligned with `completions` is present (the training
            dataset built above has such a column), it is used as the
            reference answers instead of the Q2GOLD lookup.

    Returns:
        A flat list of floats, one per completion text, in input order.
    """
    golds = kwargs.get("gold")
    flat_rewards = []  # <- 1D list expected by Unsloth GRPO

    for i, completion in enumerate(completions or []):
        if golds is not None and i < len(golds):
            raw_gold = golds[i]
        else:
            q = _question_from_messages(prompts[i]) if prompts is not None else ""
            raw_gold = Q2GOLD.get(q)
        gold = _num_norm(raw_gold)

        for txt in _completion_texts(completion):
            pred = _num_norm(_extract_ans(txt))

            # `gold and ...` keeps a missing/empty reference from matching an
            # equally empty prediction.
            score = 1.0 if (gold and pred == gold) else 0.0
            # format bonuses
            if "<reasoning>" in txt and "</reasoning>" in txt:
                score += 0.1
            if "<answer>" in txt and "</answer>" in txt:
                score += 0.1
            # mild penalty for rambling outputs
            if len(txt.split()) > 256:
                score -= 0.05

            flat_rewards.append(float(score))

    return flat_rewards


In [None]:
# === Attach LoRA adapters to the 4-bit base model for GRPO ===
print("Configuring LoRA adapters for GRPO...")

model = FastLanguageModel.get_peft_model(
    model,
    # attention projections + MLP projections
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    r=16,
    lora_alpha=32,
    # NOTE: Unsloth's log for this cell warns that a non-zero dropout prevents
    # its fast patching of the LoRA matrices (performance hit).
    lora_dropout=0.05,
    bias="none",
    random_state=SEED,
    use_gradient_checkpointing=True,
)

# Report how many parameters the adapters make trainable.
model.print_trainable_parameters()
print("=== LoRA adapters attached ===")


Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Configuring LoRA adapters for GRPO...


Unsloth 2025.11.2 patched 30 layers with 0 QKV layers, 0 O layers and 0 MLP layers.


trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
=== LoRA adapters attached ===


In [None]:
from trl import GRPOTrainer, GRPOConfig

cfg = GRPOConfig(
    # --- bookkeeping ---
    output_dir="smollm2_135m_grpo_reasoning",
    seed=SEED,
    logging_steps=10,
    max_steps=FAST_STEPS,
    save_steps=FAST_STEPS,            # same value as max_steps
    # --- optimisation ---
    learning_rate=5e-6,
    per_device_train_batch_size=4,    # multiple of num_generations
    gradient_accumulation_steps=4,
    loss_type="dr_grpo",
    # --- precision: bf16 when the GPU supports it, otherwise fp16 ---
    bf16=torch.cuda.is_bf16_supported(),
    fp16=not torch.cuda.is_bf16_supported(),
    # --- sampling of the completion group ---
    num_generations=4,
    temperature=0.7,
    top_p=0.95,
    max_prompt_length=256,
    max_completion_length=256,
    mask_truncated_completions=True,  # comment out if your build complains
)

trainer = GRPOTrainer(
    args=cfg,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_data,
    reward_funcs=reward_fn,
    # NOTE(review): the next two arguments look like SFT-style options;
    # confirm that the GRPO trainer in use actually consumes them.
    formatting_func=None,
    dataset_kwargs={"prompts_key": "prompt"},
)

# Run training; the TrainOutput is left as the cell's last expression so it is displayed.
train_output = trainer.train()
train_output


The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 800 | Num Epochs = 2 | Total steps = 300
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 4,884,480 of 139,399,488 (3.50% trained)


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / reward_fn / mean,rewards / reward_fn / std
10,0.0083,0.001875,0.00375,216.24375,78.2,256.0,0.61875,157.621432,78.2,215.6,0,0,0,0,0,1.8e-05,0.001875,0.005439
20,0.0,0.000625,0.00125,210.04375,85.0,256.0,0.58125,149.479964,85.0,221.6,No Log,No Log,No Log,No Log,No Log,2e-05,0.000625,0.0025
30,0.0,0.0,0.0,202.25625,67.0,256.0,0.55,134.421236,67.0,209.9,No Log,No Log,No Log,No Log,No Log,2.1e-05,0.0,0.0
40,0.0,0.000625,0.00125,216.775,70.3,256.0,0.675,139.170003,70.3,221.1,No Log,No Log,No Log,No Log,No Log,2e-05,0.000625,0.0025
50,0.0,0.00125,0.0025,216.80625,66.3,256.0,0.65625,141.215955,66.3,213.4,No Log,No Log,No Log,No Log,No Log,2.1e-05,0.00125,0.005


Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completions / mean_length,completions / min_length,completions / max_length,completions / clipped_ratio,completions / mean_terminated_length,completions / min_terminated_length,completions / max_terminated_length,sampling / sampling_logp_difference / mean,sampling / sampling_logp_difference / max,sampling / importance_sampling_ratio / min,sampling / importance_sampling_ratio / mean,sampling / importance_sampling_ratio / max,kl,rewards / reward_fn / mean,rewards / reward_fn / std
10,0.0083,0.001875,0.00375,216.24375,78.2,256.0,0.61875,157.621432,78.2,215.6,0,0,0,0,0,1.8e-05,0.001875,0.005439
20,0.0,0.000625,0.00125,210.04375,85.0,256.0,0.58125,149.479964,85.0,221.6,No Log,No Log,No Log,No Log,No Log,2e-05,0.000625,0.0025
30,0.0,0.0,0.0,202.25625,67.0,256.0,0.55,134.421236,67.0,209.9,No Log,No Log,No Log,No Log,No Log,2.1e-05,0.0,0.0
40,0.0,0.000625,0.00125,216.775,70.3,256.0,0.675,139.170003,70.3,221.1,No Log,No Log,No Log,No Log,No Log,2e-05,0.000625,0.0025
50,0.0,0.00125,0.0025,216.80625,66.3,256.0,0.65625,141.215955,66.3,213.4,No Log,No Log,No Log,No Log,No Log,2.1e-05,0.00125,0.005
60,-0.0059,0.00125,0.0025,210.2125,74.4,256.0,0.58125,151.272027,74.4,227.2,No Log,No Log,No Log,No Log,No Log,2.4e-05,0.00125,0.003416
70,0.0104,0.00375,0.006144,204.81875,69.0,256.0,0.54375,146.202299,69.0,219.8,No Log,No Log,No Log,No Log,No Log,3.7e-05,0.00375,0.012939
80,-0.0044,0.00125,0.0025,215.73125,87.2,256.0,0.61875,152.71421,87.2,223.5,No Log,No Log,No Log,No Log,No Log,3.1e-05,0.00125,0.005
90,0.0,0.0,0.0,211.74375,84.6,256.0,0.5875,147.71417,84.6,225.0,No Log,No Log,No Log,No Log,No Log,3.3e-05,0.0,0.0
100,-0.0038,0.00125,0.0025,209.25625,73.8,256.0,0.5875,141.485837,73.8,218.8,No Log,No Log,No Log,No Log,No Log,6.2e-05,0.00125,0.005


TrainOutput(global_step=300, training_loss=0.006460942580061299, metrics={'train_runtime': 7274.4992, 'train_samples_per_second': 0.66, 'train_steps_per_second': 0.041, 'total_flos': 0.0, 'train_loss': 0.006460942580061299})

In [None]:
# --- Save the model with the LoRA weights merged in, as 16-bit weights ---
# The tokenizer is passed along so the merged folder is written together with it.
trainer.model.save_pretrained_merged(
    "smollm2_135m_grpo_reasoning/merged_16bit",
    tokenizer,
    save_method="merged_16bit",
)

# Also write the tokenizer files into the merged folder explicitly. This runs
# after the merge on purpose; the original note says it addresses the
# "tokenizer.model not found" message seen during the merge.
tokenizer.save_pretrained("smollm2_135m_grpo_reasoning/merged_16bit")

# --- Save only the LoRA adapters via the PEFT save_pretrained API ---
# (the original note says save_lora() is not available on this model object)
# trainer.model is a PeftModelForCausalLM — save the adapter weights + config:
adapter_dir = "smollm2_135m_grpo_reasoning/lora_adapters"
trainer.model.save_pretrained(adapter_dir)   # writes adapter_model.safetensors + adapter_config.json
tokenizer.save_pretrained(adapter_dir)       # optional convenience: tokenizer next to the adapters

print("Saved merged model and LoRA adapters under smollm2_135m_grpo_reasoning/.")


Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...


Unsloth: Copying 1 files from cache to `smollm2_135m_grpo_reasoning/merged_16bit`: 100%|██████████| 1/1 [00:00<00:00,  1.14it/s]


Successfully copied all 1 files from cache to `smollm2_135m_grpo_reasoning/merged_16bit`
Checking cache directory for required files...
Cache check failed: tokenizer.model not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files: 100%|██████████| 1/1 [00:00<00:00, 4609.13it/s]
Unsloth: Merging weights into 16bit: 100%|██████████| 1/1 [00:01<00:00,  1.49s/it]


Unsloth: Merge process complete. Saved to `/content/smollm2_135m_grpo_reasoning/merged_16bit`
Saved merged model and LoRA adapters under smollm2_135m_grpo_reasoning/.


In [None]:
import re, math, torch
from transformers import StoppingCriteria, StoppingCriteriaList

# System message used at inference time by _build_messages(). It is stricter
# than the training prompt: it spells out the XML layout, requires a bare
# number inside <answer>, and lists the unit-cost and speed/time formulas.
STRICT_SYSTEM_PROMPT = """You must answer ONLY in this exact XML format:
<reasoning>
Step-by-step numeric reasoning with explicit formulas.
</reasoning>
<answer>
FINAL_NUMERIC_ANSWER_ONLY
</answer>

Rules:
- <answer> MUST be a single plain number (no units, no $).
- Use exact arithmetic for proportional/ratio/rate problems:
  • Unit cost = total_cost / quantity; New total = new_quantity * unit_cost
  • Speed = distance / time; Time = distance / speed; Distance = speed * time
- Do NOT output anything outside the two XML blocks.
"""

# First few-shot example (unit-cost problem): the question and the exact
# assistant reply shown to the model as a demonstration of the format.
FS1_Q = "If 4 pencils cost $8, how much do 9 pencils cost?"
FS1_A = """<reasoning>
Unit cost = 8/4 = 2. For 9 pencils: 9*2 = 18.
</reasoning>
<answer>
18
</answer>"""

# Second few-shot example (speed/time problem). The reply must close the block
# with </reasoning> so the demonstration matches the format the system prompt
# demands; a mismatched closing tag would teach the model malformed output.
FS2_Q = "A train travels 120 km in 2 hours. At the same speed, how long for 300 km?"
FS2_A = """<reasoning>
Speed = 120/2 = 60. Time = 300/60 = 5.
</reasoning>
<answer>
5
</answer>"""

class TagStopper(StoppingCriteria):
    """Stop generation as soon as the sequence ends with `stop_str`.

    Only the first sequence of the batch is inspected, so this is meant for
    single-prompt generation.

    Two checks are made on every step:
      1. the last token ids equal the ids obtained by tokenising `stop_str`
         on its own;
      2. failing that, the decoded text of the last few tokens ends with
         `stop_str`. This catches the tag when the model spells it with a
         different token split than the standalone tokenisation.

    Args:
        tokenizer: tokenizer of the generating model.
        stop_str: text that terminates generation.
    """

    def __init__(self, tokenizer, stop_str="</answer>"):
        self.tokenizer = tokenizer
        self.stop_str = stop_str
        self.stop_ids = tokenizer(stop_str, add_special_tokens=False, return_tensors="pt").input_ids[0]

    def __call__(self, input_ids, scores, **kwargs) -> bool:
        seq = input_ids[0].tolist()
        s = self.stop_ids.tolist()
        if len(seq) >= len(s) and seq[-len(s):] == s:
            return True
        if not self.stop_str:
            return False
        # Decode a window slightly longer than the tag so that a token which
        # merges part of the tag with neighbouring text is still covered.
        tail = self.tokenizer.decode(seq[-(len(s) + 2):], skip_special_tokens=True)
        return tail.rstrip().endswith(self.stop_str)

ANS_TAG = re.compile(r"<answer>\s*([\-+]?\d+(?:\.\d+)?)\s*</answer>", re.IGNORECASE|re.DOTALL)

def _extract_answer_num(xml_text: str):
    m = ANS_TAG.search(xml_text or "")
    return float(m.group(1)) if m else None

# pattern 1: "If A items cost $B, how much do C items cost?"
# Capture groups, in the order compute_expected() reads them:
#   (1) A = item count (integer), (2) B = total price ("$" optional, decimals
#   allowed), (3) C = new item count (integer). Matching ignores case.
COST_PAT = re.compile(
    r"If\s+(\d+)\s+\w+\s+cost\s*\$?\s*([0-9]+(?:\.[0-9]+)?)\s*,?\s*how\s+much\s+do\s+(\d+)\s+\w+\s+cost",
    re.IGNORECASE,
)

# pattern 2: "travels D km in T hours ... how long for N km?"
# Capture groups, in the order compute_expected() reads them:
#   (1) D = distance covered, (2) T = time taken in hours, (3) N = target
#   distance. Distances accept km/kilometers/kms; decimals allowed; case-insensitive.
TIME_PAT = re.compile(
    r"travels\s+([0-9]+(?:\.[0-9]+)?)\s*(?:km|kilometers|kms)\s+in\s+([0-9]+(?:\.[0-9]+)?)\s*(?:h|hr|hour|hours)\b.*?\bhow\s+long\s+.*?\b([0-9]+(?:\.[0-9]+)?)\s*(?:km|kilometers|kms)",
    re.IGNORECASE,
)

def compute_expected(q: str):
    """Compute the answer locally for the two question shapes we can parse.

    Args:
        q: question text.

    Returns:
        ``(expected, meta)`` where `expected` is the answer rounded to 6
        decimals and `meta` is ``("cost", (A, B, C, unit))`` or
        ``("time", (D, T, N, speed))``; ``(None, None)`` when the question
        matches neither pattern or the parsed numbers would require dividing
        by zero (zero item count, zero distance or zero time).
    """
    m = COST_PAT.search(q)
    if m:
        A, B, C = (float(g) for g in m.groups())
        if A == 0:
            return None, None  # unit cost undefined
        unit = B / A
        return round(C * unit, 6), ("cost", (A, B, C, unit))

    m = TIME_PAT.search(q)
    if m:
        D, T, N = (float(g) for g in m.groups())
        if D == 0 or T == 0:
            return None, None  # speed undefined or zero
        speed = D / T
        return round(N / speed, 6), ("time", (D, T, N, speed))

    return None, None

def _build_messages(q, corrective_hint=None):
    """Assemble the few-shot chat prompt for question `q`.

    Layout: strict system prompt, the two worked examples as user/assistant
    pairs, then `q` as the final user message. A truthy `corrective_hint` is
    appended as one more user message.
    """
    msgs = [{"role":"system","content":STRICT_SYSTEM_PROMPT}]
    for shot_q, shot_a in ((FS1_Q, FS1_A), (FS2_Q, FS2_A)):
        msgs.append({"role":"user","content":shot_q})
        msgs.append({"role":"assistant","content":shot_a})
    msgs.append({"role":"user","content":q})

    if not corrective_hint:
        return msgs
    msgs.append({"role":"user","content":corrective_hint})
    return msgs

# `gen_tok` / `gen_model` are used from here on but are not defined in any of
# the cells visible above. Fall back to the tokenizer and the model trained in
# this notebook so the cell survives Restart & Run All; bindings that already
# exist are left untouched.
if "gen_tok" not in globals():
    gen_tok = tokenizer
if "gen_model" not in globals():
    gen_model = trainer.model
    FastLanguageModel.for_inference(gen_model)  # put the Unsloth model in inference mode

# Precompute the token sequences generate() must never emit (units, currency).
# NOTE: bad_words_ids is applied to the whole generation, not only to the text
# inside <answer>, so these tokens are also unavailable in the reasoning part.
BAD_TOKENS = ["$", "dollar", "dollars", "km", "km/h", "hours", "hrs", "minutes", "mins", "pen", "pens"]
bad_words_ids = [
    ids
    for ids in (gen_tok(t, add_special_tokens=False).input_ids for t in BAD_TOKENS)
    if len(ids) > 0  # skip strings that tokenise to nothing
]

def _render_verified_xml(expected, meta):
    """Format the XML reply for a question whose answer was computed locally.

    `meta` is the ``(kind, values)`` pair produced by compute_expected().
    Returns None for a kind this helper does not know.
    """
    kind, values = meta
    if kind == "cost":
        A, B, C, unit = values
        return f"""<reasoning>
Unit cost = {B}/{A} = {B/A:.6g}. New total = {C} * {B/A:.6g} = {expected:g}.
</reasoning>
<answer>
{expected:g}
</answer>"""
    if kind == "time":
        D, T, N, speed = values
        return f"""<reasoning>
Speed = {D}/{T} = {D/T:.6g}. Time = {N} / {D/T:.6g} = {expected:g}.
</reasoning>
<answer>
{expected:g}
</answer>"""
    return None


def ask_verified(q, max_new_tokens=160, show=True):
    """Ask the model `q` and return its XML reply, corrected when we can check it.

    The model is queried greedily with the few-shot prompt and generation stops
    at ``</answer>``. If compute_expected() can solve `q` and the model's
    numeric answer is missing or differs from it (tolerance 1e-3), the returned
    text is a locally computed reply, not the model's output.

    Args:
        q: question text.
        max_new_tokens: generation budget.
        show: print the final text when True.

    Returns:
        The final XML text.
    """
    expected, meta = compute_expected(q)

    # 1) deterministic generation
    ids = gen_tok.apply_chat_template(
        _build_messages(q),
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(gen_model.device)
    generation_kwargs = dict(
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=0.0,
        top_p=1.0,
        no_repeat_ngram_size=6,
        repetition_penalty=1.02,
        bad_words_ids=bad_words_ids,
        eos_token_id=gen_tok.eos_token_id,
        pad_token_id=gen_tok.eos_token_id,
        stopping_criteria=StoppingCriteriaList([TagStopper(gen_tok, "</answer>")]),
        use_cache=True,
    )
    out = gen_model.generate(ids, **generation_kwargs)

    # keep only the newly generated tokens
    prompt_len = ids.shape[1]
    txt = gen_tok.decode(out[0][prompt_len:], skip_special_tokens=True)
    pred = _extract_answer_num(txt)

    # 2) when the answer is checkable and wrong/missing, substitute the computed reply
    if expected is not None:
        agrees = pred is not None and math.isclose(pred, expected, rel_tol=1e-3, abs_tol=1e-3)
        if not agrees:
            replacement = _render_verified_xml(expected, meta)
            if replacement is not None:
                txt = replacement

    if show:
        print(txt)
    return txt


In [None]:
# This question matches neither COST_PAT (no "If <number> ... cost") nor
# TIME_PAT (it says "travel", not "travels"), so compute_expected() returns
# (None, None) and the text shown is the model's own, unverified output.
ask_verified("If a car uses 8 liters of fuel to travel 100 km, how much fuel is needed for 250 km?")



<reasoning>
Fuel consumption rate = 8 ÷ 100 = 0.08 L/km.
For 250 km, fuel needed = 250 × 0.08 = 20.0 liters.
</reasoning>
<answer>
20.0
</answer>
