# PRM-Math: Process Reward Model for Mathematical Reasoning

This notebook implements a **Process Reward Model (PRM)** using the "Generative Verifier" paradigm. It fine-tunes **Qwen-2.5-Math-1.5B-Instruct** to classify intermediate reasoning steps as correct (+) or incorrect (-) in mathematical problem-solving.

## Key Features
- **Paradigm**: Decoder-Only Generative Verifier
- **Base Model**: Qwen/Qwen2.5-Math-1.5B-Instruct (1.5B parameters)
- **Training Method**: QLoRA (4-bit quantization) via Unsloth + TRL
- **Inference Engine**: Best-of-N search with step-wise verification
- **Data Source**: Math-Shepherd dataset

---

**Important**: Make sure to select a GPU runtime!
- Go to `Runtime` → `Change runtime type` → Select `T4 GPU` (or better)

## 1. Check GPU and Environment

In [None]:
# Check GPU availability
!nvidia-smi

import torch
print(f"\nPyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Wed Dec 10 22:13:31 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA L4                      Off |   00000000:00:03.0 Off |                    0 |
| N/A   34C    P8             16W /   72W |       0MiB /  23034MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

## 2. Install Dependencies

**IMPORTANT: Follow these steps carefully!**

1. Run the installation cell below
2. **Restart the runtime**: `Runtime` → `Restart runtime`
3. After restart, **skip the installation cell** and continue from "Verify installations"

This is required because NumPy needs to be downgraded and the runtime must reload the correct version.

In [None]:
# IMPORTANT: Run this cell, then restart runtime (Runtime -> Restart runtime)
# After restart, skip this cell and continue from the next one

# First, fix NumPy version (must be done before other installs)
!pip uninstall numpy -y
!pip install "numpy<2.0.0"

# Install Unsloth for Colab (handles CUDA compatibility automatically)
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"

# Install other dependencies
!pip install transformers>=4.36.0 datasets>=2.14.0 accelerate>=0.25.0
!pip install trl>=0.7.0 peft>=0.7.0 bitsandbytes>=0.41.0
!pip install pyyaml>=6.0 tqdm>=4.66.0

# Note: Skipping vLLM as it can cause conflicts. Using transformers for inference instead.
# !pip install vllm>=0.2.0

print("\n" + "="*60)
print("IMPORTANT: Now restart the runtime!")
print("Go to: Runtime -> Restart runtime")
print("Then skip this cell and run the next cells.")
print("="*60)

Found existing installation: numpy 2.0.2
Uninstalling numpy-2.0.2:
  Successfully uninstalled numpy-2.0.2
Collecting numpy<2.0.0
  Downloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m3.4 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m18.0/18.0 MB[0m [31m133.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: numpy
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
shap 0.50.0 requires numpy>=2, but you have numpy 1.26.4 which is incomp

Collecting unsloth@ git+https://github.com/unslothai/unsloth.git (from unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Cloning https://github.com/unslothai/unsloth.git to /tmp/pip-install-n2_48v03/unsloth_27f0b89e3c2d4769a186b0f0d24535bd
  Running command git clone --filter=blob:none --quiet https://github.com/unslothai/unsloth.git /tmp/pip-install-n2_48v03/unsloth_27f0b89e3c2d4769a186b0f0d24535bd
  Resolved https://github.com/unslothai/unsloth.git to commit aa063de198f44822b2a7e7d0d9b97a4bc5e705c7
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting unsloth_zoo>=2025.12.3 (from unsloth@ git+https://github.com/unslothai/unsloth.git->unsloth[colab-new]@ git+https://github.com/unslothai/unsloth.git)
  Downloading unsloth_zoo-2025.12.3-py3-none-any.whl.metadata (32 kB)
Collecting tyro (from unsloth@ git+https://github.com/unslothai/unsloth.gi

In [None]:
# Verify installations (run this AFTER restarting runtime)
import numpy as np
print(f"NumPy version: {np.__version__}")

# Check NumPy version
if np.__version__.startswith("2"):
    print("\n⚠️  WARNING: NumPy 2.x detected!")
    print("Please run the installation cell above, then restart runtime.")
    raise RuntimeError("NumPy version must be < 2.0.0")

import transformers
import datasets
import peft
import trl
from unsloth import FastLanguageModel

print(f"Transformers: {transformers.__version__}")
print(f"Datasets: {datasets.__version__}")
print(f"PEFT: {peft.__version__}")
print(f"TRL: {trl.__version__}")
print("Unsloth: Installed successfully!")
print("\n✓ All dependencies verified!")

NumPy version: 1.26.4



Please restructure your imports with 'import unsloth' at the top of your file.
  from unsloth import FastLanguageModel


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


  import trl.experimental.openenv.utils as openenv_utils


Transformers: 4.57.3
Datasets: 4.3.0
PEFT: 0.18.0
TRL: 0.24.0
Unsloth: Installed successfully!

✓ All dependencies verified!


## 3. Mount Google Drive & Configuration

Mount Google Drive for persistent storage of checkpoints and models.

In [None]:
# Mount Google Drive for persistent storage
from google.colab import drive
drive.mount('/content/drive')

# Create project directory in Google Drive
import os
DRIVE_PROJECT_PATH = "/content/drive/MyDrive/Colab Notebooks/PRM-Math"
os.makedirs(DRIVE_PROJECT_PATH, exist_ok=True)
os.makedirs(f"{DRIVE_PROJECT_PATH}/checkpoints", exist_ok=True)
os.makedirs(f"{DRIVE_PROJECT_PATH}/logs", exist_ok=True)

print(f"Google Drive mounted!")
print(f"Project path: {DRIVE_PROJECT_PATH}")

# Check for existing checkpoints (for resume training)
checkpoint_dir = f"{DRIVE_PROJECT_PATH}/checkpoints"
existing_checkpoints = []
if os.path.exists(checkpoint_dir):
    for item in os.listdir(checkpoint_dir):
        item_path = os.path.join(checkpoint_dir, item)
        if os.path.isdir(item_path) and item.startswith("checkpoint-"):
            existing_checkpoints.append(item_path)
    existing_checkpoints.sort(key=lambda x: int(x.split("-")[-1]))

if existing_checkpoints:
    print(f"\nFound {len(existing_checkpoints)} existing checkpoint(s):")
    for cp in existing_checkpoints[-3:]:  # Show last 3
        print(f"  - {cp}")
    print(f"\nLatest: {existing_checkpoints[-1]}")
    RESUME_FROM_CHECKPOINT = existing_checkpoints[-1]
else:
    print("\nNo existing checkpoints found. Will start fresh training.")
    RESUME_FROM_CHECKPOINT = None

# Check for existing merged model
merged_model_path = f"{DRIVE_PROJECT_PATH}/checkpoints/merged_model"
if os.path.exists(merged_model_path):
    print(f"\nFound existing merged model at: {merged_model_path}")
    print("You can skip training and go directly to evaluation.")

In [None]:
# Configuration - using Google Drive paths for persistence
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Config:
    # Project settings
    project_name: str = "qwen-prm-math"
    seed: int = 42
    output_dir: str = f"{DRIVE_PROJECT_PATH}/checkpoints"  # Save to Google Drive
    logging_dir: str = f"{DRIVE_PROJECT_PATH}/logs"

    # Data settings
    dataset_name: str = "peiyi9979/Math-Shepherd"
    max_samples: int = 30000  # Adjust based on time/memory
    balance_positives: bool = True
    validation_split: float = 0.1

    # Model settings - 1.5B model
    base_model: str = "Qwen/Qwen2.5-Math-1.5B-Instruct"
    max_seq_length: int = 2048
    load_in_4bit: bool = True

    # LoRA settings
    lora_r: int = 16
    lora_alpha: int = 16
    lora_dropout: float = 0.05
    target_modules: List[str] = field(default_factory=lambda: [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ])

    # Training settings
    batch_size: int = 8  # For L4 GPU with 1.5B model
    gradient_accumulation_steps: int = 4
    learning_rate: float = 1e-4
    num_train_epochs: int = 3
    warmup_ratio: float = 0.03
    save_steps: int = 100  # Save checkpoint every 100 steps
    logging_steps: int = 10
    response_template: str = "<|verify|>"
    
    # Resume training
    resume_from_checkpoint: str = RESUME_FROM_CHECKPOINT  # Auto-detected above

    # Inference settings
    n_candidates: int = 16
    temperature: float = 0.7
    max_new_tokens: int = 512

# Create config instance
config = Config()

print("Configuration loaded!")
print(f"  Base model: {config.base_model}")
print(f"  Max samples: {config.max_samples}")
print(f"  Batch size: {config.batch_size}")
print(f"  Learning rate: {config.learning_rate}")
print(f"  Output dir: {config.output_dir}")
if config.resume_from_checkpoint:
    print(f"  Resume from: {config.resume_from_checkpoint}")

## 4. Load and Process Dataset

Load the Math-Shepherd dataset and process it for PRM training.

In [None]:
from datasets import load_dataset
import re
from typing import Dict, Any, List
import random

print("Loading Math-Shepherd dataset...")
raw_dataset = load_dataset(config.dataset_name, split="train")
print(f"Loaded {len(raw_dataset)} examples")

# Inspect the dataset format
print("\n" + "="*50)
print("Dataset columns:", raw_dataset.column_names)
print("="*50)
print("\nSample item (label field - last 200 chars):")
sample = raw_dataset[0]
print(repr(sample['label'][-200:]))
print("="*50)


def parse_math_shepherd_item(item: Dict) -> List[Dict[str, Any]]:
    """
    Parse Math-Shepherd dataset format.

    The 'label' field contains: "Problem text Step 1: ... +\nStep 2: ... -\n..."
    Each step ends with a space and +/- label.
    """
    label_text = item.get("label", "")
    if not label_text:
        return []

    # Find where Step 1 starts to separate problem from steps
    step1_match = re.search(r'Step\s*1\s*[:\.]', label_text)
    if not step1_match:
        return []

    problem = label_text[:step1_match.start()].strip()
    steps_text = label_text[step1_match.start():]

    if not problem:
        return []

    # Parse steps using regex to find "Step N: content +/-"
    # The pattern captures: step header, content, and label
    step_pattern = r'(Step\s*\d+\s*[:\.])\s*(.+?)\s+([+-])(?=\s*Step\s*\d+|\s*$)'
    matches = re.findall(step_pattern, steps_text, re.DOTALL)

    if not matches:
        # Try alternative: split by newlines and look for +/- at end
        lines = steps_text.strip().split('\n')
        steps = []
        for line in lines:
            line = line.strip()
            if not line:
                continue
            # Check for " +" or " -" at the end
            match = re.match(r'(.+?)\s+([+-])\s*$', line)
            if match:
                steps.append({"text": match.group(1).strip(), "label": match.group(2)})

        if not steps:
            return []
    else:
        steps = []
        for header, content, label in matches:
            step_text = f"{header} {content}".strip()
            steps.append({"text": step_text, "label": label})

    if not steps:
        return []

    # Create training examples with cumulative context
    examples = []
    context = f"Problem: {problem}\n\nSolution:"

    for step in steps:
        example = {
            "context": context,
            "step": step["text"],
            "label": step["label"]
        }
        examples.append(example)
        context = f"{context}\n{step['text']}"

    return examples


def format_for_training(example: Dict[str, Any], response_template: str) -> str:
    """Format example for generative verifier training."""
    text = f"{example['context']}\n{example['step']}\n{response_template} {example['label']}"
    return text


# Process dataset
print("\nProcessing dataset into PRM format...")
all_examples = []
parsed_count = 0
failed_count = 0

# Test parsing on first item
test_result = parse_math_shepherd_item(raw_dataset[0])
print(f"\nTest parsing first item: {len(test_result)} steps found")
if test_result:
    print(f"  First step: {test_result[0]['step'][:80]}... [{test_result[0]['label']}]")

for i, item in enumerate(raw_dataset):
    if len(all_examples) >= config.max_samples * 2:
        break

    examples = parse_math_shepherd_item(item)

    if examples:
        all_examples.extend(examples)
        parsed_count += 1
    else:
        failed_count += 1
        if failed_count <= 2:
            print(f"\nCould not parse item {i}:")
            print(f"  Label (last 200 chars): {repr(item.get('label', 'N/A')[-200:])}")

print(f"\nParsing results:")
print(f"  Successfully parsed: {parsed_count} items")
print(f"  Failed to parse: {failed_count} items")
print(f"  Total step-level examples: {len(all_examples)}")

if len(all_examples) == 0:
    raise ValueError("Could not create any training examples. Please check the dataset format.")

# Balance positive and negative examples
positives = [ex for ex in all_examples if ex["label"] == "+"]
negatives = [ex for ex in all_examples if ex["label"] == "-"]

print(f"\nLabel distribution:")
print(f"  Positive examples: {len(positives)}")
print(f"  Negative examples: {len(negatives)}")

if config.balance_positives and len(negatives) > 0 and len(positives) > 0:
    min_count = min(len(positives), len(negatives), config.max_samples // 2)
    random.seed(config.seed)
    balanced_examples = (
        random.sample(positives, min_count) +
        random.sample(negatives, min_count)
    )
    random.shuffle(balanced_examples)
    all_examples = balanced_examples
    print(f"  Balanced to: {len(all_examples)} examples")
else:
    random.seed(config.seed)
    random.shuffle(all_examples)
    all_examples = all_examples[:config.max_samples]
    print(f"  Limited to: {len(all_examples)} examples")

# Format for training
formatted_texts = [
    format_for_training(ex, config.response_template)
    for ex in all_examples
]

# Create HuggingFace dataset
from datasets import Dataset
train_dataset = Dataset.from_dict({"text": formatted_texts})

# Train/validation split
split_dataset = train_dataset.train_test_split(
    test_size=config.validation_split,
    seed=config.seed
)

print(f"\nFinal dataset:")
print(f"  Training: {len(split_dataset['train'])} examples")
print(f"  Validation: {len(split_dataset['test'])} examples")

# Show sample
print("\n" + "="*50)
print("Sample training example:")
print("="*50)
sample_text = formatted_texts[0]
print(sample_text[:800] + "..." if len(sample_text) > 800 else sample_text)

Loading Math-Shepherd dataset...


README.md: 0.00B [00:00, ?B/s]

math-shepherd.jsonl:   0%|          | 0.00/793M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/444655 [00:00<?, ? examples/s]

Loaded 444655 examples

Dataset columns: ['input', 'label', 'task']

Sample item (label field - last 200 chars):
'sons per week. +\nStep 4: Janet spends 120 + 140 = <<120+140=260>>260 on music lessons per week. +\nStep 5: She spends 260 * 52 = <<260*52=13520>>13520 on music lessons in a year. The answer is: 13520 -'

Processing dataset into PRM format...

Test parsing first item: 5 steps found
  First step: Step 1: Janet spends 3 hours + 5 hours = <<3+5=8>>8 hours per week on music less... [+]

Could not parse item 926:
  Label (last 200 chars): "y dollars did Jerusha earn? Use L to represent Lottie's earnings. Jerusha earned 4L. +\nJerusha and Lottie earned 4L+85=<<4L+85=99>>99 together. +\nJerusha earned 99-L=<<99-L=85>>85. The answer is: 85 -"

Parsing results:
  Successfully parsed: 7988 items
  Failed to parse: 1 items
  Total step-level examples: 30001

Label distribution:
  Positive examples: 8683
  Negative examples: 21318
  Balanced to: 15000 examples

Final dataset:
  Trainin

## 5. Load Model with Unsloth

Load the base model with 4-bit quantization and add LoRA adapters.

In [None]:
from unsloth import FastLanguageModel
import torch

print(f"Loading model: {config.base_model}")
print("This may take a few minutes...\n")

# Load model with Unsloth (handles 4-bit quantization automatically)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=config.base_model,
    max_seq_length=config.max_seq_length,
    dtype=None,  # Auto-detect
    load_in_4bit=config.load_in_4bit,
)

print(f"Model loaded!")
print(f"  Parameters: {model.num_parameters():,}")
print(f"  Max sequence length: {config.max_seq_length}")

# Add LoRA adapters
model = FastLanguageModel.get_peft_model(
    model,
    r=config.lora_r,
    target_modules=config.target_modules,
    lora_alpha=config.lora_alpha,
    lora_dropout=config.lora_dropout,
    bias="none",
    use_gradient_checkpointing="unsloth",
    random_state=config.seed,
)

# Ensure padding token is set
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

print(f"\nLoRA adapters added!")
print(f"  LoRA rank: {config.lora_r}")
print(f"  Target modules: {config.target_modules}")

Loading model: Qwen/Qwen2.5-Math-1.5B-Instruct
This may take a few minutes...

==((====))==  Unsloth 2025.12.4: Fast Qwen2 patching. Transformers: 4.57.3.
   \\   /|    NVIDIA L4. Num GPUs = 1. Max memory: 22.161 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu126. CUDA: 8.9. CUDA Toolkit: 12.6. Triton: 3.5.0
\        /    Bfloat16 = TRUE. FA [Xformers = None. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/1.14G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/161 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/632 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

Unsloth: Dropout = 0 is supported for fast patching. You are using dropout = 0.05.
Unsloth will patch all other layers, except LoRA matrices, causing a performance hit.


Model loaded!
  Parameters: 1,543,714,304
  Max sequence length: 2048


Unsloth 2025.12.4 patched 28 layers with 0 QKV layers, 0 O layers and 0 MLP layers.



LoRA adapters added!
  LoRA rank: 16
  Target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']


## 6. Training

Fine-tune the model using TRL's SFTTrainer with completion-only loss.

In [None]:
from transformers import TrainingArguments
import trl
import torch

print(f"TRL version: {trl.__version__}")

# Fix Unsloth + TRL 0.24 compatibility issue
try:
    import unsloth.trainer
    if not hasattr(unsloth.trainer, 'PADDING_FREE_BLOCKLIST'):
        unsloth.trainer.PADDING_FREE_BLOCKLIST = []
except:
    pass

from trl import SFTTrainer

# Custom data collator that only computes loss on tokens after <|verify|>
class DataCollatorForCompletionOnlyLM:
    """
    Custom collator that masks labels before the response template.
    Only computes loss on tokens after the response template.
    """
    def __init__(self, response_template, tokenizer):
        self.response_template = response_template  # List of token IDs
        self.tokenizer = tokenizer

    def __call__(self, examples):
        # Tokenize if needed
        if isinstance(examples[0], dict) and "text" in examples[0]:
            texts = [ex["text"] for ex in examples]
            batch = self.tokenizer(
                texts,
                padding=True,
                truncation=True,
                max_length=2048,
                return_tensors="pt"
            )
        else:
            batch = self.tokenizer.pad(examples, return_tensors="pt")

        # Create labels (copy of input_ids)
        labels = batch["input_ids"].clone()

        # Mask everything before response template
        for i, input_ids in enumerate(batch["input_ids"]):
            input_list = input_ids.tolist()
            response_start = None

            # Find where response template starts
            template_len = len(self.response_template)
            for j in range(len(input_list) - template_len + 1):
                if input_list[j:j + template_len] == self.response_template:
                    response_start = j + template_len
                    break

            if response_start is not None:
                # Mask everything before and including the template
                labels[i, :response_start] = -100
            else:
                # If template not found, mask everything (no loss)
                labels[i, :] = -100

        # Also mask padding tokens
        labels[labels == self.tokenizer.pad_token_id] = -100

        batch["labels"] = labels
        return batch

# Get response template token IDs
response_template_ids = tokenizer.encode(
    config.response_template,
    add_special_tokens=False
)
print(f"Response template token IDs: {response_template_ids}")

# Create collator
collator = DataCollatorForCompletionOnlyLM(
    response_template=response_template_ids,
    tokenizer=tokenizer,
)

# Training arguments
training_args = TrainingArguments(
    output_dir=config.output_dir,
    per_device_train_batch_size=config.batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    learning_rate=config.learning_rate,
    num_train_epochs=config.num_train_epochs,
    warmup_ratio=config.warmup_ratio,
    lr_scheduler_type="cosine",
    logging_steps=config.logging_steps,
    save_steps=config.save_steps,
    save_total_limit=3,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    optim="adamw_8bit",
    seed=config.seed,
    report_to="none",  # Disable wandb/tensorboard in Colab
)

# Create trainer
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=split_dataset["train"],
    eval_dataset=split_dataset["test"],
    data_collator=collator,
    args=training_args,
    max_seq_length=config.max_seq_length,
)

print("
Trainer configured!")
print(f"  Effective batch size: {config.batch_size * config.gradient_accumulation_steps}")
print(f"  Training examples: {len(split_dataset['train'])}")
print(f"  Steps per epoch: {len(split_dataset['train']) // (config.batch_size * config.gradient_accumulation_steps)}")

In [None]:
# Start training (with resume support)

# Check that trainer is defined (from previous cell)
if 'trainer' not in dir():
    raise NameError(
        "trainer is not defined!\n"
        "Please run the previous cell (Trainer Configuration) first.\n"
        "Make sure to run cells in order: Config → Dataset → Model → Trainer → Train"
    )

print("Starting training...")
print("="*50)

if config.resume_from_checkpoint:
    print(f"Resuming from checkpoint: {config.resume_from_checkpoint}")
    trainer_stats = trainer.train(resume_from_checkpoint=config.resume_from_checkpoint)
else:
    print("Starting fresh training...")
    trainer_stats = trainer.train()

print("\n" + "="*50)
print("Training complete!")
print(f"  Total steps: {trainer_stats.global_step}")
print(f"  Final loss: {trainer_stats.training_loss:.4f}")
print(f"  Checkpoints saved to: {config.output_dir}")

In [None]:
# Save the merged model to Google Drive
merged_model_path = f"{config.output_dir}/merged_model"

print(f"Saving merged model to {merged_model_path}...")
print("(This saves directly to Google Drive for persistence)")

# Save in 16-bit for inference
model.save_pretrained_merged(
    merged_model_path,
    tokenizer,
    save_method="merged_16bit",
)

print("\nModel saved successfully!")
print(f"Location: {merged_model_path}")
print("\nYou can now:")
print("1. Restart runtime for evaluation")
print("2. Or disconnect and reconnect later - your model is saved!")

In [None]:
# Save the model
merged_model_path = f"{config.output_dir}/merged_model"

print(f"Saving merged model to {merged_model_path}...")

# Save in 16-bit for inference
model.save_pretrained_merged(
    merged_model_path,
    tokenizer,
    save_method="merged_16bit",
)

print("Model saved successfully!")

Saving merged model to ./checkpoints/merged_model...


config.json:   0%|          | 0.00/761 [00:00<?, ?B/s]

Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.
Checking cache directory for required files...
Cache check failed: tokenizer.model not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Preparing safetensor model files:   0%|          | 0/1 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/3.09G [00:00<?, ?B/s]

Unsloth: Preparing safetensor model files: 100%|██████████| 1/1 [00:10<00:00, 10.34s/it]


Note: tokenizer.model not found (this is OK for non-SentencePiece models)


Unsloth: Merging weights into 16bit: 100%|██████████| 1/1 [00:10<00:00, 10.31s/it]


Unsloth: Merge process complete. Saved to `/content/checkpoints/merged_model`
Model saved successfully!


## 7. Inference: Best-of-N with PRM Scoring

Use the trained PRM to score and rank multiple solution candidates.

In [None]:
import torch
import torch.nn.functional as F
from unsloth import FastLanguageModel

# First, prepare the trained model for inference
print("Preparing model for inference...")
FastLanguageModel.for_inference(model)

class PRMVerifier:
    """
    Process Reward Model verifier for scoring mathematical reasoning steps.
    Uses the trained Unsloth model in inference mode.
    """

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.device = device
        self.response_template = config.response_template
        self.model = model
        self.tokenizer = tokenizer

        # Get token IDs for + and -
        self.pos_token_id = self.tokenizer.encode("+", add_special_tokens=False)[0]
        self.neg_token_id = self.tokenizer.encode("-", add_special_tokens=False)[0]

        print("Verifier initialized!")
        print(f"  + token ID: {self.pos_token_id}")
        print(f"  - token ID: {self.neg_token_id}")

    def score_step(self, context: str, step: str) -> float:
        """
        Score a single reasoning step.

        Returns probability that the step is correct.
        """
        # Format input
        text = f"{context}\n{step}\n{self.response_template}"

        # Tokenize
        inputs = self.tokenizer(text, return_tensors="pt").to(self.device)

        # Get logits for next token
        with torch.no_grad():
            outputs = self.model(**inputs)
            next_token_logits = outputs.logits[0, -1, :]

        # Get probabilities for + and -
        probs = F.softmax(next_token_logits, dim=-1)
        pos_prob = probs[self.pos_token_id].item()
        neg_prob = probs[self.neg_token_id].item()

        # Normalize to get P(correct)
        score = pos_prob / (pos_prob + neg_prob) if (pos_prob + neg_prob) > 0 else 0.5

        return score

    def score_solution(self, problem: str, solution: str) -> dict:
        """
        Score an entire solution by scoring each step.

        Uses "Weakest Link" aggregation (min of step scores).
        """
        # Split solution into steps
        steps = [s.strip() for s in solution.split("\n") if s.strip()]

        if not steps:
            return {"score": 0.0, "step_scores": [], "steps": []}

        context = f"Problem: {problem}\n\nSolution:"
        step_scores = []

        for step in steps:
            score = self.score_step(context, step)
            step_scores.append(score)
            context = f"{context}\n{step}"

        # Aggregate using min (weakest link)
        final_score = min(step_scores) if step_scores else 0.0

        return {
            "score": final_score,
            "step_scores": step_scores,
            "steps": steps
        }


# Create verifier using the trained model (already in inference mode)
verifier = PRMVerifier(model, tokenizer)
print("\nVerifier ready!")

Preparing model for inference...
Verifier initialized!
  + token ID: 10
  - token ID: 12

Verifier ready!


In [None]:
# Solution generator using the same trained model (already in inference mode)
class SolutionGenerator:
    """
    Generate multiple solution candidates for a math problem.
    Uses the trained Unsloth model in inference mode.
    """

    def __init__(self, model, tokenizer, device: str = "cuda"):
        self.device = device
        self.model = model
        self.tokenizer = tokenizer

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        print("Generator initialized!")

    def generate_solutions(
        self,
        problem: str,
        n_candidates: int = 16,
        temperature: float = 0.7,
        max_new_tokens: int = 512
    ) -> list:
        """
        Generate multiple solution candidates.
        """
        prompt = f"Problem: {problem}\n\nSolution:\n"

        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.device)

        solutions = []
        for i in range(n_candidates):
            with torch.no_grad():
                outputs = self.model.generate(
                    **inputs,
                    max_new_tokens=max_new_tokens,
                    temperature=temperature,
                    do_sample=True,
                    pad_token_id=self.tokenizer.pad_token_id,
                    eos_token_id=self.tokenizer.eos_token_id,
                    use_cache=True,
                )

            generated = self.tokenizer.decode(
                outputs[0][inputs["input_ids"].shape[1]:],
                skip_special_tokens=True
            )
            solutions.append(generated.strip())
            print(f"  Generated candidate {i+1}/{n_candidates}")

        return solutions


# Create generator using the same trained model
generator = SolutionGenerator(model, tokenizer)
print("\nGenerator ready!")
print("\nBoth verifier and generator are ready for inference!")

Generator initialized!

Generator ready!

Both verifier and generator are ready for inference!


In [ ]:
# ============================================================
# EVALUATION (Post-Restart) - Run this AFTER restarting runtime
# ============================================================
# DO NOT run any Unsloth cells before this!
# Just run this cell directly after restart.

# First, mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import re
import os

print("="*60)
print("EVALUATION MODE (No Unsloth)")
print("="*60)

# Configuration - use Google Drive paths
DRIVE_PROJECT_PATH = "/content/drive/MyDrive/Colab Notebooks/PRM-Math"
PRM_MODEL_PATH = f"{DRIVE_PROJECT_PATH}/checkpoints/merged_model"
BASE_MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B-Instruct"  # For generation
VERIFY_TOKEN = "<|verify|>"

# Check if PRM model exists
if not os.path.exists(PRM_MODEL_PATH):
    print(f"ERROR: PRM model not found at {PRM_MODEL_PATH}")
    print("Please make sure training completed and model was saved.")
    print("\nAvailable files in checkpoints:")
    checkpoint_dir = f"{DRIVE_PROJECT_PATH}/checkpoints"
    if os.path.exists(checkpoint_dir):
        for item in os.listdir(checkpoint_dir):
            print(f"  - {item}")
else:
    print(f"PRM model found at: {PRM_MODEL_PATH}")
    print(f"Base model: {BASE_MODEL_NAME}")
    print("="*60)

In [None]:
# ============================================================
# EVALUATION (Post-Restart) - Run this AFTER restarting runtime
# ============================================================
# DO NOT run any Unsloth cells before this!
# Just run this cell directly after restart.

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
from tqdm import tqdm
import re

print("="*60)
print("EVALUATION MODE (No Unsloth)")
print("="*60)

# Configuration
PRM_MODEL_PATH = "./checkpoints/merged_model"  # Your trained PRM
BASE_MODEL_NAME = "Qwen/Qwen2.5-Math-1.5B-Instruct"  # For generation
VERIFY_TOKEN = "<|verify|>"

# Check if PRM model exists
import os
if not os.path.exists(PRM_MODEL_PATH):
    print(f"ERROR: PRM model not found at {PRM_MODEL_PATH}")
    print("Please make sure training completed and model was saved.")
else:
    print(f"PRM model found at: {PRM_MODEL_PATH}")
    print(f"Base model: {BASE_MODEL_NAME}")
    print("="*60)

EVALUATION MODE (No Unsloth)
PRM model found at: ./checkpoints/merged_model
Base model: Qwen/Qwen2.5-Math-1.5B-Instruct


In [None]:
# Load BASE model for generation (clean, no Unsloth patches)
print("Loading BASE model for generation...")
base_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
)
base_model.eval()

if base_tokenizer.pad_token is None:
    base_tokenizer.pad_token = base_tokenizer.eos_token

print(f"Base model loaded: {BASE_MODEL_NAME}")

# Load PRM model for scoring
print("\nLoading PRM model for scoring...")
prm_tokenizer = AutoTokenizer.from_pretrained(PRM_MODEL_PATH)
prm_model = AutoModelForCausalLM.from_pretrained(
    PRM_MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto",
)
prm_model.eval()

if prm_tokenizer.pad_token is None:
    prm_tokenizer.pad_token = prm_tokenizer.eos_token

# FIXED: Use correct token IDs (with leading space)
pos_token_id = prm_tokenizer.encode(" +", add_special_tokens=False)[0]  # ' +' not '+'
neg_token_id = prm_tokenizer.encode(" -", add_special_tokens=False)[0]  # ' -' not '-'

print(f"PRM model loaded from: {PRM_MODEL_PATH}")
print(f"Token IDs: ' +' = {pos_token_id}, ' -' = {neg_token_id}")
print("\nBoth models loaded successfully!")

Loading BASE model for generation...
Base model loaded: Qwen/Qwen2.5-Math-1.5B-Instruct

Loading PRM model for scoring...
PRM model loaded from: ./checkpoints/merged_model
Token IDs: ' +' = 488, ' -' = 481

Both models loaded successfully!


In [None]:
# Helper functions for evaluation

def generate_solution(problem, temperature=0.7, max_new_tokens=512):
    """Generate a solution using the BASE model."""
    messages = [
        {"role": "system", "content": "You are a helpful math assistant. Solve the problem step by step."},
        {"role": "user", "content": problem}
    ]
    prompt = base_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = base_tokenizer(prompt, return_tensors="pt").to("cuda")

    with torch.no_grad():
        # Handle temperature=0 (greedy) vs temperature>0 (sampling)
        if temperature == 0 or temperature is None:
            outputs = base_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,  # Greedy decoding
                pad_token_id=base_tokenizer.pad_token_id,
            )
        else:
            outputs = base_model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                do_sample=True,
                pad_token_id=base_tokenizer.pad_token_id,
            )

    generated = base_tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return generated.strip()


def score_solution(problem, solution):
    """Score a solution using the PRM model with product aggregation."""
    steps = [s.strip() for s in solution.split("\n") if s.strip()]

    if not steps:
        return 0.0, []

    context = f"Problem: {problem}\n\nSolution:"
    step_scores = []

    for step in steps:
        text = f"{context}\n{step}\n{VERIFY_TOKEN}"
        inputs = prm_tokenizer(text, return_tensors="pt").to("cuda")

        with torch.no_grad():
            outputs = prm_model(**inputs)
            logits = outputs.logits[0, -1, :]

        probs = F.softmax(logits, dim=-1)
        pos_prob = probs[pos_token_id].item()
        neg_prob = probs[neg_token_id].item()

        if pos_prob + neg_prob > 0:
            score = pos_prob / (pos_prob + neg_prob)
        else:
            score = 0.5

        step_scores.append(score)
        context = f"{context}\n{step}"

    # Product aggregation
    import math
    final_score = math.prod(step_scores) if step_scores else 0.0
    return final_score, step_scores


def extract_answer(text):
    """Extract numerical answer from solution text."""
    if not text:
        return None
    
    # Try \boxed{} format
    boxed_match = re.search(r"\\boxed\{([^}]+)\}", text)
    if boxed_match:
        return boxed_match.group(1).strip()
    
    # Try #### format (GSM8K style)
    if "####" in text:
        answer = text.split("####")[-1].strip()
        # Clean up the answer
        answer = re.sub(r"[^\d\.\-]", "", answer)
        if answer:
            try:
                return str(float(answer))
            except:
                return answer
    
    # Try "The answer is X" format
    answer_match = re.search(r"[Tt]he (?:final )?answer is[:\s]*([\d\.\-,]+)", text)
    if answer_match:
        answer = answer_match.group(1).replace(",", "")
        try:
            return str(float(answer))
        except:
            return answer
    
    # Try to find last number
    numbers = re.findall(r"[-]?\d+\.?\d*", text)
    if numbers:
        try:
            return str(float(numbers[-1]))
        except:
            return numbers[-1]
    
    return None


print("Helper functions defined:")
print("- generate_solution(problem, temperature=0.7) - handles temp=0 correctly")
print("- score_solution(problem, solution) - PRM scoring with product aggregation")
print("- extract_answer(text) - extracts numerical answers")

Helper functions defined!


In [None]:
def evaluate_gsm8k(n_problems=20, n_candidates=4, temperature=0.7):
    """Evaluate with multiple methods including PRM-weighted majority."""

    print("Loading GSM8K dataset...")
    gsm8k = load_dataset("gsm8k", "main", split="test")
    problems = list(gsm8k)[:n_problems]

    results = {"pass_1": 0, "majority": 0, "prm_rerank": 0, "prm_weighted": 0, "total": 0}

    print(f"\nEvaluating {n_problems} problems with {n_candidates} candidates...")
    print("BASE model for generation, PRM for scoring\n")

    for idx, item in enumerate(tqdm(problems)):
        question = item["question"]
        gt_match = re.search(r"####\s*([-\d,\.]+)", item["answer"])
        if not gt_match:
            continue
        ground_truth = str(float(gt_match.group(1).replace(",", "")))

        # Generate N candidates
        solutions = [generate_solution(question, temperature) for _ in range(n_candidates)]

        # Score all solutions
        scored = [(sol, score_solution(question, sol)[0]) for sol in solutions]
        answers = [extract_answer(sol) for sol in solutions]
        scores = [s[1] for s in scored]

        # Pass@1
        if answers[0] == ground_truth:
            results["pass_1"] += 1

        # Pure Majority vote
        from collections import Counter
        valid = [a for a in answers if a]
        if valid:
            majority = Counter(valid).most_common(1)[0][0]
            if majority == ground_truth:
                results["majority"] += 1

        # PRM Rerank (best score)
        scored.sort(key=lambda x: x[1], reverse=True)
        best_answer = extract_answer(scored[0][0])
        if best_answer == ground_truth:
            results["prm_rerank"] += 1

        # PRM-Weighted Majority (NEW)
        answer_weights = {}
        for ans, score in zip(answers, scores):
            if ans:
                answer_weights[ans] = answer_weights.get(ans, 0) + score
        if answer_weights:
            weighted_best = max(answer_weights, key=answer_weights.get)
            if weighted_best == ground_truth:
                results["prm_weighted"] += 1

        results["total"] += 1

        # Debug first problem
        if idx == 0:
            print(f"\n--- First Problem Debug ---")
            print(f"Q: {question[:80]}...")
            print(f"GT: {ground_truth}")
            print(f"Answers: {answers}")
            print(f"PRM scores: {[f'{s:.3f}' for s in scores]}")
            print(f"Answer weights: {answer_weights}")
            print(f"Weighted best: {weighted_best if answer_weights else 'N/A'}")

    # Results
    total = results["total"]
    print("\n" + "="*50)
    print("EVALUATION RESULTS")
    print("="*50)
    print(f"Problems: {total}")
    print(f"Candidates per problem: {n_candidates}")
    print("-"*50)
    print(f"Pass@1:           {results['pass_1']}/{total} = {results['pass_1']/total*100:.1f}%")
    print(f"Majority@{n_candidates}:        {results['majority']}/{total} = {results['majority']/total*100:.1f}%")
    print(f"PRM Rerank@{n_candidates}:      {results['prm_rerank']}/{total} = {results['prm_rerank']/total*100:.1f}%")
    print(f"PRM-Weighted@{n_candidates}:    {results['prm_weighted']}/{total} = {results['prm_weighted']/total*100:.1f}%")
    print("="*50)

    return results

# Run evaluation
eval_results = evaluate_gsm8k(n_problems=50, n_candidates=8, temperature=0.8)

Loading GSM8K dataset...

Evaluating 50 problems with 8 candidates...
BASE model for generation, PRM for scoring



  2%|▏         | 1/50 [01:32<1:15:55, 92.96s/it]


--- First Problem Debug ---
Q: Janet’s ducks lay 16 eggs per day. She eats three for breakfast every morning an...
GT: 18.0
Answers: ['18.0', '18.0', '18.0', '18.0', '14.0', '18.0', '18.0', '18.0']
PRM scores: ['0.032', '0.029', '0.279', '0.060', '0.015', '0.080', '0.101', '0.137']
Answer weights: {'18.0': 0.7175781479270358, '14.0': 0.014636016999102351}
Weighted best: 18.0


100%|██████████| 50/50 [1:19:25<00:00, 95.32s/it]


EVALUATION RESULTS
Problems: 50
Candidates per problem: 8
--------------------------------------------------
Pass@1:           40/50 = 80.0%
Majority@8:        45/50 = 90.0%
PRM Rerank@8:      42/50 = 84.0%
PRM-Weighted@8:    43/50 = 86.0%





In [None]:
# MATH-500 Dataset Evaluation (Competition-level math problems)
import re
import math

def normalize_math_answer(answer):
    """Normalize MATH dataset answers for comparison."""
    if answer is None:
        return None

    answer = str(answer).strip()

    # Remove LaTeX formatting
    answer = answer.replace("\\$", "").replace("$", "")
    answer = answer.replace("\\%", "%").replace("\\!", "")

    # Handle common LaTeX commands
    answer = re.sub(r'\\text\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\textbf\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\mathrm\{([^}]*)\}', r'\1', answer)
    answer = re.sub(r'\\left|\\right', '', answer)

    # Handle fractions: \frac{a}{b}
    frac_match = re.search(r'\\d?frac\{([^}]*)\}\{([^}]*)\}', answer)
    if frac_match:
        try:
            num = float(frac_match.group(1))
            den = float(frac_match.group(2))
            if den != 0:
                answer = str(num / den)
        except:
            pass

    # Handle sqrt
    sqrt_match = re.search(r'\\sqrt\{([^}]*)\}', answer)
    if sqrt_match:
        try:
            val = float(sqrt_match.group(1))
            answer = str(math.sqrt(val))
        except:
            pass

    # Remove remaining LaTeX commands
    answer = re.sub(r'\\[a-zA-Z]+', '', answer)
    answer = answer.replace("{", "").replace("}", "").replace(" ", "").strip()

    try:
        return str(float(answer))
    except:
        return answer.lower()


def extract_math_answer(text):
    """Extract answer from MATH-style solutions."""
    if not text:
        return None

    # Try \boxed{answer}
    boxed_patterns = [
        r'\\boxed\{([^{}]*(?:\{[^{}]*\}[^{}]*)*)\}',
        r'\\boxed\{([^}]+)\}',
    ]
    for pattern in boxed_patterns:
        matches = re.findall(pattern, text)
        if matches:
            return normalize_math_answer(matches[-1])

    # Try "the answer is X" patterns
    answer_patterns = [
        r'[Tt]he\s+(?:final\s+)?answer\s+is[:\s]*\$?([^\$\n]+)\$?',
        r'[Aa]nswer[:\s]*\$?([^\$\n]+)\$?',
        r'=\s*\$?([^\$\n]+)\$?\s*$',
    ]
    for pattern in answer_patterns:
        match = re.search(pattern, text)
        if match:
            return normalize_math_answer(match.group(1))

    # Fallback: last number
    numbers = re.findall(r'([+-]?\d+\.?\d*)', text)
    if numbers:
        return normalize_math_answer(numbers[-1])
    return None


def check_math_answer(predicted, ground_truth):
    """Check if predicted answer matches ground truth."""
    if predicted is None or ground_truth is None:
        return False

    pred_norm = normalize_math_answer(predicted)
    gt_norm = normalize_math_answer(ground_truth)

    if pred_norm is None or gt_norm is None:
        return False

    if pred_norm == gt_norm:
        return True

    try:
        return abs(float(pred_norm) - float(gt_norm)) < 1e-4
    except:
        return False


def evaluate_math500(n_problems=50, n_candidates=8, temperature=0.8):
    """Evaluate on MATH-500 dataset (competition-level problems)."""

    print("Loading MATH-500 dataset...")
    try:
        math_dataset = load_dataset("HuggingFaceH4/MATH-500", split="test")
        print(f"Loaded {len(math_dataset)} problems")
        print(f"Columns: {math_dataset.column_names}")
    except Exception as e:
        print(f"Error loading MATH-500: {e}")
        print("\nTrying alternative: openai/gsm8k hard subset...")
        # Fallback to harder GSM8K problems
        math_dataset = load_dataset("gsm8k", "main", split="test")
        # Take later problems (tend to be harder)
        math_dataset = math_dataset.select(range(len(math_dataset)-500, len(math_dataset)))
        print(f"Using last 500 GSM8K problems (harder subset)")

    problems = list(math_dataset)[:n_problems]

    results = {"pass_1": 0, "majority": 0, "prm_rerank": 0, "prm_weighted": 0, "total": 0}

    print(f"\nEvaluating {len(problems)} problems with {n_candidates} candidates...")
    print("BASE model for generation, PRM for scoring\n")

    for idx, item in enumerate(tqdm(problems)):
        # Handle different column names
        problem = item.get("problem", item.get("question", ""))

        # Get ground truth
        if "answer" in item:
            gt_answer = normalize_math_answer(item["answer"])
        elif "solution" in item:
            gt_answer = extract_math_answer(item["solution"])
        else:
            # GSM8K format
            gt_match = re.search(r"####\s*([-\d,\.]+)", item.get("answer", ""))
            gt_answer = str(float(gt_match.group(1).replace(",", ""))) if gt_match else None

        if not problem or gt_answer is None:
            continue

        # Generate candidates
        solutions = [generate_solution(problem, temperature) for _ in range(n_candidates)]
        scored = [(sol, score_solution(problem, sol)[0]) for sol in solutions]
        answers = [extract_math_answer(sol) for sol in solutions]
        scores = [s[1] for s in scored]

        # Pass@1
        if check_math_answer(answers[0], gt_answer):
            results["pass_1"] += 1

        # Majority vote
        from collections import Counter
        valid = [a for a in answers if a]
        if valid:
            majority = Counter(valid).most_common(1)[0][0]
            if check_math_answer(majority, gt_answer):
                results["majority"] += 1

        # PRM Rerank
        scored.sort(key=lambda x: x[1], reverse=True)
        best_answer = extract_math_answer(scored[0][0])
        if check_math_answer(best_answer, gt_answer):
            results["prm_rerank"] += 1

        # PRM-Weighted Majority
        answer_weights = {}
        for ans, score in zip(answers, scores):
            if ans:
                found = False
                for existing in answer_weights:
                    if check_math_answer(ans, existing):
                        answer_weights[existing] += score
                        found = True
                        break
                if not found:
                    answer_weights[ans] = score

        if answer_weights:
            weighted_best = max(answer_weights, key=answer_weights.get)
            if check_math_answer(weighted_best, gt_answer):
                results["prm_weighted"] += 1

        results["total"] += 1

        # Debug first 2 problems
        if idx < 2:
            print(f"\n--- Problem {idx+1} ---")
            print(f"Q: {problem[:100]}...")
            print(f"GT: {gt_answer}")
            print(f"Answers: {answers[:4]}")
            print(f"Scores: {[f'{s:.3f}' for s in scores[:4]]}")
            print(f"Best: {best_answer} | Correct: {check_math_answer(best_answer, gt_answer)}")

    # Results
    total = results["total"]
    print("\n" + "="*55)
    print("MATH-500 EVALUATION RESULTS")
    print("="*55)
    print(f"Problems: {total} | Candidates: {n_candidates}")
    print("-"*55)
    print(f"Pass@1:           {results['pass_1']}/{total} = {results['pass_1']/total*100:.1f}%")
    print(f"Majority@{n_candidates}:        {results['majority']}/{total} = {results['majority']/total*100:.1f}%")
    print(f"PRM Rerank@{n_candidates}:      {results['prm_rerank']}/{total} = {results['prm_rerank']/total*100:.1f}%")
    print(f"PRM-Weighted@{n_candidates}:    {results['prm_weighted']}/{total} = {results['prm_weighted']/total*100:.1f}%")
    print("="*55)
    print(f"\nPRM Rerank improvement: {(results['prm_rerank']-results['pass_1'])/total*100:+.1f}%")

    return results

print("MATH-500 evaluation ready!")
print("This dataset contains competition-level problems (AMC, AIME style)")

# Run MATH-500 Evaluation
# Competition-level problems - expect lower accuracy but more PRM benefit

# Standard evaluation (~2 hours)
math_results = evaluate_math500(n_problems=50, n_candidates=8, temperature=0.8)

# Quick test (~30 min)
# math_results = evaluate_math500(n_problems=20, n_candidates=4, temperature=0.8)

MATH-500 evaluation ready!
This dataset contains competition-level problems (AMC, AIME style)
Loading MATH-500 dataset...


README.md:   0%|          | 0.00/412 [00:00<?, ?B/s]

test.jsonl: 0.00B [00:00, ?B/s]

Generating test split:   0%|          | 0/500 [00:00<?, ? examples/s]

Loaded 500 problems
Columns: ['problem', 'solution', 'answer', 'subject', 'level', 'unique_id']

Evaluating 50 problems with 8 candidates...
BASE model for generation, PRM for scoring



  2%|▏         | 1/50 [01:55<1:34:03, 115.17s/it]


--- Problem 1 ---
Q: Convert the point $(0,3)$ in rectangular coordinates to polar coordinates.  Enter your answer in the...
GT: (3,2)
Answers: ['(3,2)', 'math.atan(y/x', '(3,2)', '(3,2)']
Scores: ['0.020', '0.013', '0.037', '0.045']
Best: (3,2) | Correct: True


  4%|▍         | 2/50 [04:23<1:47:51, 134.81s/it]


--- Problem 2 ---
Q: Define
\[p = \sum_{k = 1}^\infty \frac{1}{k^2} \quad \text{and} \quad q = \sum_{k = 1}^\infty \frac{...
GT: p-q
Answers: ['1^1n^2=p', '2^1n^2=p-1and_', '1^1n^2=pand', '2^(1n^2-1n^3)=(']
Scores: ['0.079', '0.024', '0.202', '0.014']
Best: 1^1n^2=pand | Correct: False


100%|██████████| 50/50 [1:41:20<00:00, 121.61s/it]


MATH-500 EVALUATION RESULTS
Problems: 50 | Candidates: 8
-------------------------------------------------------
Pass@1:           24/50 = 48.0%
Majority@8:        27/50 = 54.0%
PRM Rerank@8:      27/50 = 54.0%
PRM-Weighted@8:    27/50 = 54.0%

PRM Rerank improvement: +6.0%





## 8. MCTS with PRM Value Function

Monte Carlo Tree Search using the trained PRM as the value function.
- **Prior**: LLM generation probability (temperature sampling confidence)
- **Value**: PRM score for the current solution state
- **Selection**: UCB with exploration bonus
- **Expansion**: Generate candidate next steps
- **Evaluation**: PRM scores the partial solution
- **Backpropagation**: Update visit counts and values

In [None]:
import math
import numpy as np
from collections import Counter
import time

# ============================================================
# MCTS with PRM Value Function and Logprob Priors (Optimized)
# ============================================================

class MCTSNode:
    """Node in the MCTS tree."""
    def __init__(self, state, parent=None, action=None, prior=1.0):
        self.state = state
        self.parent = parent
        self.action = action
        self.children = []
        self.visits = 0
        self.value_sum = 0.0
        self.prior = prior
        self._cached_value = None
    
    @property
    def value(self):
        return self.value_sum / (self.visits + 1e-8)
    
    def is_fully_expanded(self):
        return len(self.children) > 0
    
    def is_terminal(self):
        if self.action is None:
            return False
        return "\\boxed" in self.action or "boxed{" in self.action or "####" in self.action


class MCTSSearchPRM:
    """MCTS using generation logprobs as prior and trained PRM as value function."""
    
    def __init__(self, base_model, base_tokenizer, prm_model, prm_tokenizer, 
                 pos_token_id, neg_token_id, config=None):
        self.base_model = base_model
        self.base_tokenizer = base_tokenizer
        self.prm_model = prm_model
        self.prm_tokenizer = prm_tokenizer
        self.pos_token_id = pos_token_id
        self.neg_token_id = neg_token_id
        
        self.config = config or {}
        self.c_puct = self.config.get("c_puct", 1.5)
        self.n_expand = self.config.get("n_expand", 3)
        self.temperature = self.config.get("temperature", 0.8)
        self.max_depth = self.config.get("max_depth", 10)
    
    def search_with_checkpoints(self, problem, max_simulations=50, checkpoints=[1, 5, 10, 20, 50]):
        """
        Run MCTS once and record best solution at each checkpoint.
        
        Returns: dict mapping simulation count to best solution at that point
        Example: {1: "solution after 1 sim", 5: "solution after 5 sims", ...}
        """
        # Create root node
        messages = [
            {"role": "system", "content": "You are a helpful math assistant. Solve the problem step by step."},
            {"role": "user", "content": problem}
        ]
        root_state = self.base_tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        root = MCTSNode(state=root_state, prior=1.0)
        
        # Sort checkpoints and ensure max_simulations covers all
        checkpoints = sorted([c for c in checkpoints if c <= max_simulations])
        if not checkpoints:
            checkpoints = [max_simulations]
        actual_max = max(checkpoints)
        
        results = {}
        checkpoint_idx = 0
        
        for sim in range(1, actual_max + 1):
            node = root
            
            # Selection
            depth = 0
            while node.is_fully_expanded() and not node.is_terminal() and depth < self.max_depth:
                node = self._select_child(node)
                depth += 1
            
            # Expansion
            if not node.is_terminal() and depth < self.max_depth:
                node = self._expand(node, problem)
            
            # Evaluation
            value = self._evaluate_prm(node, problem)
            
            # Backpropagation
            self._backpropagate(node, value)
            
            # Check if we hit a checkpoint
            if checkpoint_idx < len(checkpoints) and sim == checkpoints[checkpoint_idx]:
                results[sim] = self._get_best_solution(root, problem)
                checkpoint_idx += 1
        
        return results
    
    def search(self, problem, simulations=10):
        """Single search returning best solution (backward compatible)."""
        results = self.search_with_checkpoints(problem, max_simulations=simulations, checkpoints=[simulations])
        return results.get(simulations, "")
    
    def _select_child(self, node):
        sqrt_n = math.sqrt(max(1, node.visits))
        
        def ucb_score(child):
            exploitation = child.value
            exploration = self.c_puct * child.prior * sqrt_n / (1 + child.visits)
            return exploitation + exploration
        
        return max(node.children, key=ucb_score)
    
    def _expand(self, node, problem):
        candidates = self._generate_steps_with_logprobs(node.state, n=self.n_expand)
        
        if not candidates:
            return node
        
        priors = np.array([c[1] for c in candidates])
        priors = priors / (priors.sum() + 1e-8)
        
        for i, (step_text, _) in enumerate(candidates):
            child_state = node.state + step_text
            child = MCTSNode(child_state, parent=node, action=step_text, prior=priors[i])
            node.children.append(child)
        
        return max(node.children, key=lambda c: c.prior)
    
    def _generate_steps_with_logprobs(self, state, n=3):
        """Generate n candidate next steps with proper logprob priors."""
        candidates = []
        
        inputs = self.base_tokenizer(state, return_tensors="pt").to("cuda")
        input_length = inputs["input_ids"].shape[1]
        
        with torch.no_grad():
            for _ in range(n):
                outputs = self.base_model.generate(
                    **inputs,
                    max_new_tokens=150,
                    temperature=self.temperature,
                    do_sample=True,
                    pad_token_id=self.base_tokenizer.pad_token_id,
                    eos_token_id=self.base_tokenizer.eos_token_id,
                    output_scores=True,
                    return_dict_in_generate=True,
                )
                
                generated_ids = outputs.sequences[0, input_length:]
                scores = outputs.scores
                
                if len(scores) > 0 and len(generated_ids) > 0:
                    total_logprob = 0.0
                    num_tokens = 0
                    
                    for i, (score, token_id) in enumerate(zip(scores, generated_ids)):
                        if token_id == self.base_tokenizer.eos_token_id:
                            break
                        if token_id == self.base_tokenizer.pad_token_id:
                            continue
                        
                        probs = F.softmax(score[0] / self.temperature, dim=-1)
                        token_prob = probs[token_id].item()
                        
                        if token_prob > 0:
                            total_logprob += math.log(token_prob)
                            num_tokens += 1
                    
                    if num_tokens > 0:
                        avg_logprob = total_logprob / num_tokens
                        prior = math.exp(avg_logprob)
                    else:
                        prior = 0.5
                else:
                    prior = 0.5
                
                generated_text = self.base_tokenizer.decode(generated_ids, skip_special_tokens=True)
                step = generated_text.strip()
                
                if step:
                    candidates.append((step, prior))
        
        if not candidates:
            candidates = [("Let me solve this step by step.", 0.5)]
        
        return candidates
    
    def _evaluate_prm(self, node, problem):
        if node._cached_value is not None:
            return node._cached_value
        
        solution = node.state
        if "<|im_start|>assistant" in solution:
            solution = solution.split("<|im_start|>assistant")[-1]
        
        score = self._prm_score(problem, solution.strip())
        node._cached_value = score
        return score
    
    def _prm_score(self, problem, solution):
        if not solution:
            return 0.0
        
        steps = [s.strip() for s in solution.split("\n") if s.strip()]
        if not steps:
            return 0.0
        
        context = f"Problem: {problem}\n\nSolution:"
        step_scores = []
        
        for step in steps:
            text = f"{context}\n{step}\n<|verify|>"
            inputs = self.prm_tokenizer(text, return_tensors="pt").to("cuda")
            
            with torch.no_grad():
                outputs = self.prm_model(**inputs)
                logits = outputs.logits[0, -1, :]
            
            probs = F.softmax(logits, dim=-1)
            pos_prob = probs[self.pos_token_id].item()
            neg_prob = probs[self.neg_token_id].item()
            
            if pos_prob + neg_prob > 0:
                score = pos_prob / (pos_prob + neg_prob)
            else:
                score = 0.5
            
            step_scores.append(score)
            context = f"{context}\n{step}"
        
        if step_scores:
            return math.prod(step_scores)
        return 0.0
    
    def _backpropagate(self, node, value):
        while node is not None:
            node.visits += 1
            node.value_sum += value
            node = node.parent
    
    def _get_best_solution(self, root, problem):
        node = root
        solution_text = ""
        
        while node.children:
            node = max(node.children, key=lambda c: c.visits)
            if node.action:
                solution_text += node.action
            
            if node.is_terminal():
                break
        
        if not node.is_terminal() and solution_text:
            completion = self._greedy_complete(node.state)
            solution_text += completion
        
        return solution_text
    
    def _greedy_complete(self, state):
        inputs = self.base_tokenizer(state, return_tensors="pt").to("cuda")
        
        with torch.no_grad():
            outputs = self.base_model.generate(
                **inputs,
                max_new_tokens=256,
                temperature=0.0,
                do_sample=False,
                pad_token_id=self.base_tokenizer.pad_token_id,
            )
        
        generated = self.base_tokenizer.decode(
            outputs[0][inputs["input_ids"].shape[1]:], 
            skip_special_tokens=True
        )
        return generated


print("MCTS with checkpoints defined!")
print("- search_with_checkpoints(): Build tree ONCE, get results at multiple points")
print("- ~5x faster than running separate searches")

In [None]:
def evaluate_mcts_gsm8k(n_problems=50, n_candidates=16, simulations_list=[1, 5, 10, 20, 50], temperature=0.8):
    """
    Optimized evaluation with live progress indicators.
    """
    
    print("="*60)
    print("MCTS vs Other Methods - GSM8K (Optimized)")
    print("="*60)
    
    max_sims = max(simulations_list)
    print(f"\nOptimization: MCTS tree built once to {max_sims} sims")
    print(f"Checkpoints: {simulations_list}")
    
    # Initialize MCTS
    mcts_config = {"c_puct": 1.5, "n_expand": 3, "temperature": temperature, "max_depth": 10}
    mcts = MCTSSearchPRM(
        base_model, base_tokenizer, prm_model, prm_tokenizer,
        pos_token_id, neg_token_id, config=mcts_config
    )
    
    # Load dataset
    print("\nLoading GSM8K dataset...")
    gsm8k = load_dataset("gsm8k", "main", split="test")
    problems = list(gsm8k)[:n_problems]
    
    # Results storage
    all_results = {
        "pass_1": {"correct": 0, "total": 0},
        f"majority_{n_candidates}": {"correct": 0, "total": 0},
        f"prm_rerank_{n_candidates}": {"correct": 0, "total": 0},
        f"prm_weighted_{n_candidates}": {"correct": 0, "total": 0},
    }
    for sims in simulations_list:
        all_results[f"mcts_{sims}"] = {"correct": 0, "total": 0}
    
    print(f"\nEvaluating {n_problems} problems with {n_candidates} candidates...")
    print()
    
    start_time = time.time()
    
    # Progress bar with live stats
    pbar = tqdm(problems, desc="Evaluating")
    
    for idx, item in enumerate(pbar):
        question = item["question"]
        gt_match = re.search(r"####\s*([-\d,\.]+)", item["answer"])
        if not gt_match:
            continue
        ground_truth = str(float(gt_match.group(1).replace(",", "")))
        
        # === Pass@1 ===
        sol_1 = generate_solution(question, temperature=0.0)
        ans_1 = extract_answer(sol_1)
        if ans_1 == ground_truth:
            all_results["pass_1"]["correct"] += 1
        all_results["pass_1"]["total"] += 1
        
        # === Generate candidates ONCE ===
        solutions = [generate_solution(question, temperature=0.7) for _ in range(n_candidates)]
        answers = [extract_answer(s) for s in solutions]
        scores = [score_solution(question, s)[0] for s in solutions]
        
        # Majority
        valid_answers = [a for a in answers if a]
        if valid_answers:
            majority_ans = Counter(valid_answers).most_common(1)[0][0]
            if majority_ans == ground_truth:
                all_results[f"majority_{n_candidates}"]["correct"] += 1
        all_results[f"majority_{n_candidates}"]["total"] += 1
        
        # PRM Rerank
        best_idx = np.argmax(scores)
        if answers[best_idx] == ground_truth:
            all_results[f"prm_rerank_{n_candidates}"]["correct"] += 1
        all_results[f"prm_rerank_{n_candidates}"]["total"] += 1
        
        # PRM-Weighted
        answer_weights = {}
        for ans, score in zip(answers, scores):
            if ans:
                answer_weights[ans] = answer_weights.get(ans, 0) + score
        if answer_weights:
            weighted_best = max(answer_weights, key=answer_weights.get)
            if weighted_best == ground_truth:
                all_results[f"prm_weighted_{n_candidates}"]["correct"] += 1
        all_results[f"prm_weighted_{n_candidates}"]["total"] += 1
        
        # === MCTS with checkpoints ===
        mcts_solutions = mcts.search_with_checkpoints(question, max_sims, simulations_list)
        
        for sims in simulations_list:
            mcts_answer = extract_answer(mcts_solutions.get(sims, ""))
            if mcts_answer == ground_truth:
                all_results[f"mcts_{sims}"]["correct"] += 1
            all_results[f"mcts_{sims}"]["total"] += 1
        
        # Update progress bar with live accuracies
        if all_results["pass_1"]["total"] > 0:
            p1 = all_results["pass_1"]["correct"] / all_results["pass_1"]["total"] * 100
            maj = all_results[f"majority_{n_candidates}"]["correct"] / all_results[f"majority_{n_candidates}"]["total"] * 100
            prm = all_results[f"prm_rerank_{n_candidates}"]["correct"] / all_results[f"prm_rerank_{n_candidates}"]["total"] * 100
            best_mcts_sims = max(simulations_list)
            mcts_acc = all_results[f"mcts_{best_mcts_sims}"]["correct"] / max(all_results[f"mcts_{best_mcts_sims}"]["total"], 1) * 100
            
            pbar.set_postfix({
                "P@1": f"{p1:.0f}%",
                "Maj": f"{maj:.0f}%",
                "PRM": f"{prm:.0f}%",
                f"MCTS@{best_mcts_sims}": f"{mcts_acc:.0f}%"
            })
    
    pbar.close()
    elapsed = time.time() - start_time
    
    # Print final results
    print("\n" + "="*60)
    print(f"FINAL RESULTS - GSM8K ({elapsed/60:.1f} minutes)")
    print("="*60)
    print(f"{'Method':<25} {'Correct':>8} {'Accuracy':>10}")
    print("-"*45)
    
    for method, data in all_results.items():
        if data["total"] > 0:
            acc = data["correct"] / data["total"] * 100
            print(f"{method:<25} {data['correct']:>3}/{data['total']:<3}    {acc:>6.1f}%")
    
    # Summary
    print("\n" + "-"*45)
    maj_acc = all_results[f"majority_{n_candidates}"]["correct"] / max(all_results[f"majority_{n_candidates}"]["total"], 1) * 100
    prm_acc = all_results[f"prm_rerank_{n_candidates}"]["correct"] / max(all_results[f"prm_rerank_{n_candidates}"]["total"], 1) * 100
    
    mcts_items = [(k, v) for k, v in all_results.items() if k.startswith("mcts_")]
    if mcts_items:
        best_mcts = max(mcts_items, key=lambda x: x[1]["correct"] / max(x[1]["total"], 1))
        best_mcts_acc = best_mcts[1]["correct"] / max(best_mcts[1]["total"], 1) * 100
        print(f"Best MCTS: {best_mcts[0]} ({best_mcts_acc:.1f}%)")
        print(f"  vs Majority@{n_candidates}:    {best_mcts_acc - maj_acc:+.1f}%")
        print(f"  vs PRM Rerank@{n_candidates}:  {best_mcts_acc - prm_acc:+.1f}%")
    
    return all_results


print("GSM8K evaluation with live progress indicators defined!")

In [None]:
def evaluate_mcts_math500(n_problems=50, n_candidates=16, simulations_list=[1, 5, 10, 20, 50], temperature=0.8):
    """
    Optimized MATH-500 evaluation with live progress indicators.
    """
    
    print("="*60)
    print("MCTS vs Other Methods - MATH-500 (Optimized)")
    print("="*60)
    
    max_sims = max(simulations_list)
    print(f"\nOptimization: MCTS tree built once to {max_sims} sims")
    
    # Initialize MCTS
    mcts_config = {"c_puct": 1.5, "n_expand": 3, "temperature": temperature, "max_depth": 10}
    mcts = MCTSSearchPRM(
        base_model, base_tokenizer, prm_model, prm_tokenizer,
        pos_token_id, neg_token_id, config=mcts_config
    )
    
    # Load dataset
    print("\nLoading MATH-500 dataset...")
    try:
        math500 = load_dataset("HuggingFaceH4/MATH-500", split="test")
        problems = list(math500)[:n_problems]
        print(f"Loaded {len(problems)} problems from MATH-500")
    except Exception as e:
        print(f"Error: {e}\nUsing GSM8K hard subset...")
        gsm8k = load_dataset("gsm8k", "main", split="test")
        problems = list(gsm8k)[-n_problems:]
    
    # Results
    all_results = {
        "pass_1": {"correct": 0, "total": 0},
        f"majority_{n_candidates}": {"correct": 0, "total": 0},
        f"prm_rerank_{n_candidates}": {"correct": 0, "total": 0},
        f"prm_weighted_{n_candidates}": {"correct": 0, "total": 0},
    }
    for sims in simulations_list:
        all_results[f"mcts_{sims}"] = {"correct": 0, "total": 0}
    
    print(f"\nEvaluating {len(problems)} problems...")
    print()
    
    start_time = time.time()
    pbar = tqdm(problems, desc="Evaluating")
    
    for idx, item in enumerate(pbar):
        if "problem" in item:
            question = item["problem"]
            ground_truth = item.get("answer", "")
        else:
            question = item["question"]
            gt_match = re.search(r"####\s*([-\d,\.]+)", item["answer"])
            ground_truth = gt_match.group(1).replace(",", "") if gt_match else ""
        
        if not ground_truth:
            continue
        
        gt_normalized = normalize_math_answer(ground_truth)
        
        # Pass@1
        sol_1 = generate_solution(question, temperature=0.0)
        ans_1 = extract_answer(sol_1)
        if compare_math_answers(ans_1, gt_normalized):
            all_results["pass_1"]["correct"] += 1
        all_results["pass_1"]["total"] += 1
        
        # Generate candidates
        solutions = [generate_solution(question, temperature=0.7) for _ in range(n_candidates)]
        answers = [extract_answer(s) for s in solutions]
        scores = [score_solution(question, s)[0] for s in solutions]
        
        # Majority
        valid_answers = [a for a in answers if a]
        if valid_answers:
            majority_ans = Counter(valid_answers).most_common(1)[0][0]
            if compare_math_answers(majority_ans, gt_normalized):
                all_results[f"majority_{n_candidates}"]["correct"] += 1
        all_results[f"majority_{n_candidates}"]["total"] += 1
        
        # PRM Rerank
        best_idx = np.argmax(scores)
        if compare_math_answers(answers[best_idx], gt_normalized):
            all_results[f"prm_rerank_{n_candidates}"]["correct"] += 1
        all_results[f"prm_rerank_{n_candidates}"]["total"] += 1
        
        # PRM-Weighted
        answer_weights = {}
        for ans, score in zip(answers, scores):
            if ans:
                answer_weights[ans] = answer_weights.get(ans, 0) + score
        if answer_weights:
            weighted_best = max(answer_weights, key=answer_weights.get)
            if compare_math_answers(weighted_best, gt_normalized):
                all_results[f"prm_weighted_{n_candidates}"]["correct"] += 1
        all_results[f"prm_weighted_{n_candidates}"]["total"] += 1
        
        # MCTS
        mcts_solutions = mcts.search_with_checkpoints(question, max_sims, simulations_list)
        for sims in simulations_list:
            mcts_answer = extract_answer(mcts_solutions.get(sims, ""))
            if compare_math_answers(mcts_answer, gt_normalized):
                all_results[f"mcts_{sims}"]["correct"] += 1
            all_results[f"mcts_{sims}"]["total"] += 1
        
        # Update progress
        if all_results["pass_1"]["total"] > 0:
            p1 = all_results["pass_1"]["correct"] / all_results["pass_1"]["total"] * 100
            maj = all_results[f"majority_{n_candidates}"]["correct"] / all_results[f"majority_{n_candidates}"]["total"] * 100
            prm = all_results[f"prm_rerank_{n_candidates}"]["correct"] / all_results[f"prm_rerank_{n_candidates}"]["total"] * 100
            best_mcts_sims = max(simulations_list)
            mcts_acc = all_results[f"mcts_{best_mcts_sims}"]["correct"] / max(all_results[f"mcts_{best_mcts_sims}"]["total"], 1) * 100
            
            pbar.set_postfix({
                "P@1": f"{p1:.0f}%",
                "Maj": f"{maj:.0f}%",
                "PRM": f"{prm:.0f}%",
                f"MCTS@{best_mcts_sims}": f"{mcts_acc:.0f}%"
            })
    
    pbar.close()
    elapsed = time.time() - start_time
    
    # Final results
    print("\n" + "="*60)
    print(f"FINAL RESULTS - MATH-500 ({elapsed/60:.1f} minutes)")
    print("="*60)
    print(f"{'Method':<25} {'Correct':>8} {'Accuracy':>10}")
    print("-"*45)
    
    for method, data in all_results.items():
        if data["total"] > 0:
            acc = data["correct"] / data["total"] * 100
            print(f"{method:<25} {data['correct']:>3}/{data['total']:<3}    {acc:>6.1f}%")
    
    # Summary
    print("\n" + "-"*45)
    maj_acc = all_results[f"majority_{n_candidates}"]["correct"] / max(all_results[f"majority_{n_candidates}"]["total"], 1) * 100
    prm_acc = all_results[f"prm_rerank_{n_candidates}"]["correct"] / max(all_results[f"prm_rerank_{n_candidates}"]["total"], 1) * 100
    
    mcts_items = [(k, v) for k, v in all_results.items() if k.startswith("mcts_")]
    if mcts_items:
        best_mcts = max(mcts_items, key=lambda x: x[1]["correct"] / max(x[1]["total"], 1))
        best_mcts_acc = best_mcts[1]["correct"] / max(best_mcts[1]["total"], 1) * 100
        print(f"Best MCTS: {best_mcts[0]} ({best_mcts_acc:.1f}%)")
        print(f"  vs Majority@{n_candidates}:    {best_mcts_acc - maj_acc:+.1f}%")
        print(f"  vs PRM Rerank@{n_candidates}:  {best_mcts_acc - prm_acc:+.1f}%")
    
    return all_results


# Helper functions
def normalize_math_answer(answer):
    if answer is None:
        return None
    s = str(answer).strip()
    s = re.sub(r"\\(frac|dfrac)\{([^}]+)\}\{([^}]+)\}", lambda m: f"({m.group(2)})/({m.group(3)})", s)
    s = re.sub(r"\\sqrt\{([^}]+)\}", lambda m: f"sqrt({m.group(1)})", s)
    s = re.sub(r"\\[a-zA-Z]+", "", s)
    s = s.replace("{", "").replace("}", "").replace("$", "").replace(",", "").strip()
    try:
        import math as math_module
        return float(eval(s.replace("sqrt", "math_module.sqrt")))
    except:
        return s

def compare_math_answers(pred, truth):
    if pred is None or truth is None:
        return False
    pred_n = normalize_math_answer(pred)
    truth_n = normalize_math_answer(truth) if not isinstance(truth, (int, float)) else truth
    try:
        return abs(float(pred_n) - float(truth_n)) < 1e-4
    except:
        pass
    return str(pred_n).strip().lower() == str(truth_n).strip().lower()


print("MATH-500 evaluation with live progress indicators defined!")

In [None]:
def run_full_comparison(n_problems=50, n_candidates=16, simulations_list=[1, 5, 10, 20, 50]):
    """Run optimized comparison on both datasets."""
    import matplotlib.pyplot as plt
    
    print("="*60)
    print("FULL COMPARISON (Optimized)")
    print("="*60)
    print(f"Problems: {n_problems}, Candidates: {n_candidates}")
    print(f"MCTS checkpoints: {simulations_list}")
    print()
    
    # Run evaluations
    print(">>> GSM8K Evaluation")
    gsm8k_results = evaluate_mcts_gsm8k(n_problems, n_candidates, simulations_list)
    
    print("\n>>> MATH-500 Evaluation")
    math500_results = evaluate_mcts_math500(n_problems, n_candidates, simulations_list)
    
    # Visualization
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    
    # Bar charts
    for ax, (title, results) in zip(axes[:2], [("GSM8K", gsm8k_results), ("MATH-500", math500_results)]):
        methods = []
        accuracies = []
        colors = []
        
        color_map = {
            "pass_1": "#bdc3c7",
            f"majority_{n_candidates}": "#3498db", 
            f"prm_rerank_{n_candidates}": "#e74c3c",
            f"prm_weighted_{n_candidates}": "#9b59b6",
        }
        for sims in simulations_list:
            color_map[f"mcts_{sims}"] = plt.cm.Greens(0.3 + 0.7 * simulations_list.index(sims) / len(simulations_list))
        
        for method, data in results.items():
            if data["total"] > 0:
                label = method.replace("_", "@").replace(f"@{n_candidates}", f"@{n_candidates}")
                methods.append(label)
                accuracies.append(data["correct"] / data["total"] * 100)
                colors.append(color_map.get(method, "#95a5a6"))
        
        bars = ax.bar(methods, accuracies, color=colors, edgecolor="black", linewidth=1)
        ax.set_ylabel("Accuracy (%)", fontsize=11)
        ax.set_title(f"{title}", fontsize=13, fontweight="bold")
        ax.set_ylim(0, 100)
        ax.tick_params(axis="x", rotation=55, labelsize=8)
        
        for bar, acc in zip(bars, accuracies):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 1, 
                   f"{acc:.0f}%", ha="center", fontsize=8)
    
    # MCTS Scaling plot
    ax3 = axes[2]
    for dataset_name, results, color, marker in [
        ("GSM8K", gsm8k_results, "#3498db", "o"),
        ("MATH-500", math500_results, "#e74c3c", "s")
    ]:
        mcts_data = [(int(k.split("_")[1]), v["correct"]/max(v["total"],1)*100) 
                     for k, v in results.items() if k.startswith("mcts_")]
        if mcts_data:
            mcts_data.sort()
            sims, accs = zip(*mcts_data)
            ax3.plot(sims, accs, marker=marker, label=f"MCTS ({dataset_name})", 
                    color=color, linewidth=2, markersize=8)
            
            maj_acc = results[f"majority_{n_candidates}"]["correct"] / max(results[f"majority_{n_candidates}"]["total"], 1) * 100
            ax3.axhline(y=maj_acc, color=color, linestyle="--", alpha=0.5)
    
    ax3.set_xlabel("MCTS Simulations", fontsize=11)
    ax3.set_ylabel("Accuracy (%)", fontsize=11)
    ax3.set_title("MCTS Scaling", fontsize=13, fontweight="bold")
    ax3.legend(fontsize=9)
    ax3.set_xscale("log")
    ax3.set_xticks(simulations_list)
    ax3.set_xticklabels(simulations_list)
    ax3.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig("mcts_comparison.png", dpi=150, bbox_inches="tight")
    plt.show()
    
    return gsm8k_results, math500_results


print("\n" + "="*60)
print("USAGE EXAMPLES")
print("="*60)
print("# Quick test (~30-60 min)")
print("results = evaluate_mcts_gsm8k(n_problems=30, n_candidates=8, simulations_list=[1, 5, 10])")
print()
print("# Standard evaluation (~2-3 hours)")
print("gsm8k, math500 = run_full_comparison(n_problems=50, n_candidates=16, simulations_list=[1, 5, 10, 20, 50])")
print()
print("# Thorough evaluation (~4-6 hours)")
print("gsm8k, math500 = run_full_comparison(n_problems=100, n_candidates=16, simulations_list=[1, 5, 10, 20, 50])")

In [None]:
from unsloth import FastLanguageModel
# Debug: Check what the trained model predicts
test_texts = [
    "Problem: What is 2+2?\n\nSolution:\nStep 1: 2+2=4\n<|verify|>",  # Correct
    "Problem: What is 2+2?\n\nSolution:\nStep 1: 2+2=5\n<|verify|>",  # Wrong
]

FastLanguageModel.for_inference(prm_model)

for text in test_texts:
    inputs = prm_tokenizer(text, return_tensors="pt").to("cuda")
    with torch.no_grad():
        outputs = prm_model(**inputs)
        logits = outputs.logits[0, -1, :]

    probs = torch.softmax(logits, dim=-1)
    pos_id = prm_tokenizer.encode("+", add_special_tokens=False)[0]
    neg_id = prm_tokenizer.encode("-", add_special_tokens=False)[0]

    print(f"Text: ...{text[-40:]}")
    print(f"  P(+)={probs[pos_id]:.4f}, P(-)={probs[neg_id]:.4f}")
    print(f"  Score: {probs[pos_id]/(probs[pos_id]+probs[neg_id]):.4f}")
    print()

Text: ...2+2?

Solution:
Step 1: 2+2=4
<|verify|>
  P(+)=0.0000, P(-)=0.0000
  Score: nan

Text: ...2+2?

Solution:
Step 1: 2+2=5
<|verify|>
  P(+)=0.0000, P(-)=0.0000
  Score: 1.0000



In [None]:
# Debug: What tokens is the model actually predicting?
test_text = "Problem: What is 2+2?\n\nSolution:\nStep 1: 2+2=4\n<|verify|>"

inputs = prm_tokenizer(test_text, return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = prm_model(**inputs)
    logits = outputs.logits[0, -1, :]

probs = torch.softmax(logits, dim=-1)

# Top 10 predicted tokens
top_probs, top_ids = torch.topk(probs, 10)
print("Top 10 predicted tokens after <|verify|>:")
for prob, tok_id in zip(top_probs, top_ids):
    token = prm_tokenizer.decode([tok_id])
    print(f"  '{token}' (id={tok_id.item()}): {prob.item():.4f}")

# Check our assumed token IDs
print(f"\nOur assumed token IDs:")
pos_id = prm_tokenizer.encode("+", add_special_tokens=False)[0]
neg_id = prm_tokenizer.encode("-", add_special_tokens=False)[0]
print(f"  '+' token ID: {pos_id}, prob: {probs[pos_id].item():.6f}")
print(f"  '-' token ID: {neg_id}, prob: {probs[neg_id].item():.6f}")

# Also check alternative encodings
print(f"\nAlternative '+' encodings:")
for text in ["+", " +", "+ ", " + "]:
    ids = prm_tokenizer.encode(text, add_special_tokens=False)
    print(f"  '{text}' -> {ids}")

Top 10 predicted tokens after <|verify|>:
  ' +' (id=488): 0.9971
  ' -' (id=481): 0.0031
  '_plus' (id=28043): 0.0000
  ' The' (id=576): 0.0000
  ' This' (id=1096): 0.0000
  ' x' (id=856): 0.0000
  ' *' (id=353): 0.0000
  ' by' (id=553): 0.0000
  ' //' (id=442): 0.0000
  ' (' (id=320): 0.0000

Our assumed token IDs:
  '+' token ID: 10, prob: 0.000000
  '-' token ID: 12, prob: 0.000000

Alternative '+' encodings:
  '+' -> [10]
  ' +' -> [488]
  '+ ' -> [10, 220]
  ' + ' -> [488, 220]


In [None]:
# Interactive Demo (Post-Restart)
# Run this AFTER cells 20-22 to test specific problems

#@title Enter your math problem
problem = "If a train travels at 60 mph for 2.5 hours, how far does it travel?" #@param {type:"string"}
n_candidates = 4 #@param {type:"slider", min:2, max:8, step:1}

print(f"Problem: {problem}\n")
print(f"Generating {n_candidates} solutions with BASE model...")

solutions = []
for i in range(n_candidates):
    sol = generate_solution(problem, temperature=0.7)
    solutions.append(sol)
    print(f"  Generated candidate {i+1}/{n_candidates}")

print("\nScoring with PRM model...")
scored = []
for i, sol in enumerate(solutions):
    score, step_scores = score_solution(problem, sol)
    scored.append((sol, score, step_scores))
    print(f"  Scored candidate {i+1}/{n_candidates}: {score:.3f}")

# Sort by score
scored.sort(key=lambda x: x[1], reverse=True)

print("\n" + "="*60)
print("BEST SOLUTION (Rank #1)")
print("="*60)
best_sol, best_score, best_step_scores = scored[0]
print(f"PRM Score: {best_score:.4f}")
print(f"Extracted Answer: {extract_answer(best_sol)}")
print("\nSolution:")
print("-"*40)
print(best_sol)
print("-"*40)

print("\nStep scores:")
steps = [s.strip() for s in best_sol.split("\n") if s.strip()]
for i, (step, score) in enumerate(zip(steps, best_step_scores)):
    status = "✓" if score > 0.5 else "✗"
    preview = step[:60] + "..." if len(step) > 60 else step
    print(f"  {status} Step {i+1} ({score:.3f}): {preview}")

---

## Alternative: Quick Test WITHOUT Restarting (Lower Quality)

The cells below use the fine-tuned model for BOTH generation and scoring. This works without restart but produces lower quality results because the PRM was trained for verification, not generation.

**Recommended**: Use the "Evaluation (Post-Restart)" section above for proper results.

In [None]:
# Quick test function (works without restart, but lower quality)
# Uses the fine-tuned model for both generation and scoring

def best_of_n_search_quick(problem, n_candidates=4, temperature=0.7):
    """
    Quick Best-of-N search using the trained model for both generation and scoring.
    Note: This is suboptimal - for best results, restart runtime and use base model.
    """
    print(f"Generating {n_candidates} solutions...")
    solutions = generator.generate_solutions(problem, n_candidates=n_candidates, temperature=temperature)

    print("Scoring solutions with PRM...")
    scored = []
    for i, sol in enumerate(solutions):
        result = verifier.score_solution(problem, sol)
        scored.append({
            "solution": sol,
            "score": result["score"],
            "step_scores": result["step_scores"],
            "steps": result["steps"]
        })

    scored.sort(key=lambda x: x["score"], reverse=True)

    return {
        "best_solution": scored[0],
        "all_solutions": scored
    }

def extract_answer_simple(text):
    """Extract numerical answer from solution."""
    import re
    patterns = [
        r"[Tt]he answer is[:\s]*([-\d,\.]+)",
        r"=\s*([-\d,\.]+)\s*$",
        r"([-\d,\.]+)\s*$"
    ]
    for pattern in patterns:
        match = re.search(pattern, text)
        if match:
            try:
                return str(float(match.group(1).replace(",", "")))
            except:
                continue
    return ""

print("Quick test functions defined!")
print("Note: For best results, restart runtime and use the post-restart evaluation cells.")

## Note: Google Drive Integration

Your model and checkpoints are **automatically saved to Google Drive** at:
```
/content/drive/MyDrive/Colab Notebooks/PRM-Math/checkpoints/
```

Features:
- **Checkpoints**: Saved every 100 steps during training
- **Resume Training**: Automatically detected and resumed on next run
- **Merged Model**: Saved after training completes
- **Persistent**: Survives runtime disconnects and restarts

In [ ]:
# Utility: List saved checkpoints and models
import os

DRIVE_PROJECT_PATH = "/content/drive/MyDrive/Colab Notebooks/PRM-Math"
checkpoint_dir = f"{DRIVE_PROJECT_PATH}/checkpoints"

print("Saved files in Google Drive:")
print("="*50)

if os.path.exists(checkpoint_dir):
    for item in sorted(os.listdir(checkpoint_dir)):
        item_path = os.path.join(checkpoint_dir, item)
        if os.path.isdir(item_path):
            # Get size
            size = sum(os.path.getsize(os.path.join(item_path, f)) 
                      for f in os.listdir(item_path) if os.path.isfile(os.path.join(item_path, f)))
            print(f"  📁 {item} ({size/1e9:.2f} GB)")
        else:
            size = os.path.getsize(item_path)
            print(f"  📄 {item} ({size/1e6:.2f} MB)")
else:
    print("  No checkpoints found yet.")

print("\nTo clear old checkpoints (keep only merged model):")
print("  !rm -rf '/content/drive/MyDrive/Colab Notebooks/PRM-Math/checkpoints/checkpoint-*'")

In [None]:
from google.colab import drive
import shutil

# Mount Google Drive
drive.mount('/content/drive')

# Create destination directory
drive_path = "/content/drive/MyDrive/PRM-Math-Models"
!mkdir -p "{drive_path}"

# Copy model
print(f"Copying model to {drive_path}...")
shutil.copytree(merged_model_path, f"{drive_path}/merged_model", dirs_exist_ok=True)

print("Model saved to Google Drive!")

## Summary

This notebook demonstrated:

1. **Training a Process Reward Model (PRM)** using the Math-Shepherd dataset
2. **Generative verification** where the model predicts "+" or "-" tokens
3. **Best-of-N search** with step-wise scoring
4. **Evaluation** on GSM8K benchmark

### Important: Correct PRM Architecture
- **BASE model** (Qwen-Math) → Generates solution candidates
- **Fine-tuned PRM** → Scores/ranks the candidates
- This requires **restarting runtime** after training to load both models cleanly

### Recommended Workflow:
1. Run cells 1-15 for training
2. Save model (cell 15)
3. **Restart runtime** (Runtime → Restart runtime)
4. Run cells 20-23 for evaluation (these load models without Unsloth)

### Hyperparameter Recommendations (L4 GPU, 24GB):
| Setting | Quick Test | Good Results | Best Results |
|---------|-----------|--------------|--------------|
| `max_samples` | 5,000 | 15,000 | 30,000+ |
| `num_train_epochs` | 1 | 2 | 2-3 |
| `batch_size` | 8 | 8 | 8 |
| `gradient_accumulation_steps` | 4 | 4 | 4 |
| Estimated time | ~30 min | ~2-3 hours | ~6+ hours |

### Next Steps:
- Train with more data and epochs for better PRM
- Experiment with different aggregation strategies (product, mean vs min)
- Evaluate on MATH dataset for harder problems