# Step 1: Mounting Google Drive and Importing Libraries

In [1]:
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/grpo-verified-reasoner
!ls

Mounted at /content/drive
/content/drive/MyDrive/grpo-verified-reasoner
data			      notebooks  unsloth_compiled_cache
huggingface_tokenizers_cache  outputs	 _unsloth_sentencepiece_temp
LICENSE			      README.md
models			      src


In [None]:
# Install UV (Faster pip)
!pip install --upgrade -qqq uv

In [None]:
!pip -q install -U evalplus

In [4]:
import os
import subprocess

In [5]:
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"

In [6]:
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

In [66]:
os.environ["WANDB_PROJECT"] = "mbpp-rl-project"

In [7]:
# Environment Logic (Colab vs Local)
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # Version Matching
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False

    # A100 gets vllm 0.10.2 (Fast), T4 gets 0.9.2 (Stable)
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")

    # Install Everything
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}

# Install TRL
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m18 packages[0m [2min 112ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 534ms[0m[0m
[2mUninstalled [1m1 package[0m [2min 345ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 47ms[0m[0m
 [31m-[39m [1mtransformers[0m[2m==4.57.3[0m
 [32m+[39m [1mtransformers[0m[2m==4.56.2[0m
[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m1 package[0m [2min 3ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 32ms[0m[0m
[2mUninstalled [1m1 package[0m [2min 1ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 5ms[0m[0m
 [31m-[39m [1mtrl[0m[2m==0.24.0[0m
 [32m+[39m [1mtrl[0m[2m==0.22.2[0m


In [63]:
import re
import ast
import torch
import wandb
import random
import evalplus
import traceback
import numpy as np
import multiprocessing as mp
from datasets import Dataset
from unsloth import FastLanguageModel
from evalplus.data import get_mbpp_plus
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

In [None]:
wandb.login()

# Step 2: Verifying GPU and Environment

In [9]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

Torch version: 2.7.0+cu126
CUDA available: True
GPU: Tesla T4


# Step 3: Loading Base Model and LoRA Adapters

In [10]:
MODEL_PATH = "models/qwen3-4b-sft"

In [None]:
# Load the model
# Loads the checkpoint stored at MODEL_PATH and returns it together with its tokenizer.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    # Sequence budget. The GRPO cells further down configure 512 prompt tokens
    # plus 2048 completion tokens (2560 in total).
    # NOTE(review): confirm the total stays within this limit if either length is raised.
    max_seq_length = 3072,
    load_in_4bit = True,        # TRUE for T4 (Crucial for memory)
    fast_inference = True,      # TRUE to test vLLM
    gpu_memory_utilization = 0.6, # Conservative for T4
)

In [12]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 2560, padding_idx=151654)
        (layers): ModuleList(
          (0-1): 2 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2560, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(


# Step 4: Sanity Check

In [13]:
# This is the same prompt that we used during SFT
system_prompt = """You are a code-generation engine.
You must output your response in the following exact format:
<START_WORKING_OUT>
Concise reasoning steps required to solve the problem.
</END_WORKING_OUT>
<SOLUTION>
Valid Python code only.
</SOLUTION>
Do not output anything outside these tags."""

In [14]:
user_prompt = "Write a Python function that returns the factorial of a number."

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)

In [15]:
# Move the dictionary to GPU manually
inputs = {k: v.to("cuda") for k, v in inputs.items()}

In [16]:
# Deterministic sanity check.
# Greedy decoding is requested with `do_sample=False`; `temperature=0.0` is not
# a valid sampling temperature and only works by accident when sampling is off.
FastLanguageModel.for_inference(model) # Temporarily enable inference mode for the test
try:
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False,
        )
finally:
    # "Temporarily" means we must switch back: leave the model in training
    # mode so later cells (and the GRPO trainer) do not inherit inference mode.
    FastLanguageModel.for_training(model)

In [17]:
decoded = tokenizer.decode(output[0], skip_special_tokens=True)

In [18]:
print("\n--- MODEL OUTPUT ---")
input_len = inputs["input_ids"].shape[1]
print(tokenizer.decode(output[0][input_len:], skip_special_tokens=True))


--- MODEL OUTPUT ---
<START_WORKING_OUT>
Define a function factorial that takes an integer n.
Handle non-positive input by returning 1 (factorial of 0 or negative is 1).
Initialize result to 1.
Multiply result by each integer i from 1 to n.
Return result.
</END_WORKING_OUT>
<SOLUTION>
def factorial(n):
    if n <= 0:
        return 1
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result
</SOLUTION>


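Note: schema checks, the solution extractor, and the reward functions must only ever be given the generated completion (`generated_text`), never the full decoded sequence. The full sequence also contains the system prompt, which itself spells out every schema tag, so checks run on it would pass regardless of what the model generated.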

# Step 6: Defining Output Schema

In [19]:
# Regular expressions for tag validation (case-insensitive)
RE_START = re.compile(r"<START_WORKING_OUT>", re.IGNORECASE)
RE_END   = re.compile(r"</END_WORKING_OUT>", re.IGNORECASE)
RE_SOL   = re.compile(r"<SOLUTION>", re.IGNORECASE)
RE_SOL_END = re.compile(r"</SOLUTION>", re.IGNORECASE)

In [20]:
def validate_schema(text: str) -> tuple[bool, str]:
    """
    Checks whether the model output follows the required tag schema.

    All four tags must be present and their first occurrences must appear in
    this order:

        <START_WORKING_OUT> ... </END_WORKING_OUT> <SOLUTION> ... </SOLUTION>

    Matching is case-insensitive (it uses the module-level RE_* patterns).
    Text outside the tags is not inspected.

    Args:
        text: The generated completion only (not the prompt).

    Returns:
        (is_valid, reason) where `reason` is a short human-readable
        explanation, "Schema valid" on success.
    """
    required_tags = (
        (RE_START, "<START_WORKING_OUT>"),
        (RE_END, "</END_WORKING_OUT>"),
        (RE_SOL, "<SOLUTION>"),
        (RE_SOL_END, "</SOLUTION>"),
    )

    # Presence check, remembering where each tag first occurs.
    positions = []
    for pattern, label in required_tags:
        found = pattern.search(text)
        if found is None:
            return False, f"Missing {label}"
        positions.append(found.start())

    start_idx, end_idx, sol_idx, sol_end_idx = positions

    # Order check. The first message is kept verbatim from the previous
    # implementation; the remaining checks close the gaps it left open
    # (closing tag before its opening tag, solution inside the reasoning block).
    if sol_idx < start_idx:
        return False, "Tag order incorrect (<SOLUTION> before reasoning block)."
    if end_idx < start_idx:
        return False, "Tag order incorrect (</END_WORKING_OUT> before <START_WORKING_OUT>)."
    if sol_idx < end_idx:
        return False, "Tag order incorrect (<SOLUTION> before </END_WORKING_OUT>)."
    if sol_end_idx < sol_idx:
        return False, "Tag order incorrect (</SOLUTION> before <SOLUTION>)."

    return True, "Schema valid"

In [21]:
# Run a sanity test on the generated completion only.
# `decoded` holds the whole sequence, including the system prompt, and the
# system prompt itself spells out every schema tag, so validating `decoded`
# would report "valid" no matter what the model actually produced.
completion_only = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)
is_valid, reason = validate_schema(completion_only)
print("Schema Check:", is_valid, "|", reason)

Schema Check: True | Schema valid


# Step 7: Solution Extraction

In [22]:
# Regex to extract the code block between <SOLUTION> ... </SOLUTION>
RE_SOLUTION = re.compile(r"<SOLUTION>\s*(.*?)\s*</SOLUTION>", re.IGNORECASE | re.DOTALL)

In [23]:
def extract_solution(text: str) -> tuple[str | None, str]:
    """
    Extracts the Python code inside <SOLUTION> tags.
    Returns (code, status) where:
        code   -> the extracted string or None if failed
        status -> textual reason (for debugging)
    """
    match = RE_SOLUTION.search(text)
    if not match:
        return None, "No <SOLUTION> block found."

    code = match.group(1).strip()
    if not code:
        return None, "Empty <SOLUTION> block."

    # Syntax check via Python's AST parser
    try:
        ast.parse(code)
    except SyntaxError as e:
        return None, f"Syntax error in code: {e}"

    return code, "Valid Python code extracted."

In [24]:
# Calculate where the prompt ends
input_len = inputs["input_ids"].shape[1]

In [25]:
# Decode ONLY the new tokens (The Assistant's reply)
generated_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

In [26]:
# Now run the check on ONLY the generated text
code, status = extract_solution(generated_text) # Use the new variable
print("Status:", status)

Status: Valid Python code extracted.


In [27]:
# Show snippet of the extracted code
if code:
    print("\n--- Extracted Python Code ---\n")
    print(code)


--- Extracted Python Code ---

def factorial(n):
    if n <= 0:
        return 1
    result = 1
    for i in range(1, n + 1):
        result *= i
    return result


# Step 8: Verifier Integration (EvalPlus MBPP+)

In [28]:
# Load MBPP+ tasks as a dict: {task_id: problem_dict}
MBPP_TASKS = get_mbpp_plus()

print(f"Loaded MBPP+ tasks: {len(MBPP_TASKS)}")

Downloading dataset from https://github.com/evalplus/mbppplus_release/releases/download/v0.2.0/MbppPlus.jsonl.gz
Loaded MBPP+ tasks: 378


In [29]:
# Quick peek at one task to confirm fields & shape
sample_task_id = next(iter(MBPP_TASKS.keys()))
sample_task = MBPP_TASKS[sample_task_id]

print("\nSample Task ID:", sample_task_id)
print("Keys:", list(sample_task.keys()))
print("\nPrompt (first 400 chars):\n", sample_task["prompt"][:400])


Sample Task ID: Mbpp/2
Keys: ['task_id', 'prompt', 'entry_point', 'canonical_solution', 'base_input', 'atol', 'plus_input', 'contract', 'assertion']

Prompt (first 400 chars):
 """
Write a function to find the shared elements from the given two lists.
assert set(similar_elements((3, 4, 5, 6),(5, 7, 4, 10))) == set((4, 5))
"""



In [30]:
# Different EvalPlus versions may store tests under slightly different keys,
# so we normalize via a helper (used later in reward function).
def _split_test_source(source: str) -> list[str]:
    """
    Splits a block of test source into one string per top-level statement.

    Statements are found with the AST, so a statement that spans several lines
    (a wrapped assert, a `for` loop, a helper `def`) stays in one piece. If the
    text cannot be parsed as a whole, it falls back to one entry per non-blank line.
    """
    source = source.strip()
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError):
        return [line for line in source.splitlines() if line.strip()]

    statements = []
    for node in tree.body:
        if getattr(node, "decorator_list", None):
            # The source span of a decorated def/class starts at the `def` /
            # `class` keyword, so regenerate the code to keep the decorators.
            segment = ast.unparse(node)
        else:
            segment = ast.get_source_segment(source, node)
        if segment:
            statements.append(segment)
    return statements


def get_tests_from_task(task: dict) -> list[str]:
    """
    Extracts the executable test statements from an MBPP task.

    Looks for a list of tests under "test_list", "tests", "plus_tests" or
    "base_tests" (first non-empty one wins), then for a multi-line source
    string under "assertion".

    Args:
        task: Task dictionary.

    Returns:
        A non-empty list of source strings, each a complete statement that can
        be passed to `exec` on its own.

    Raises:
        KeyError: if no tests can be found. An empty list is never returned,
            because a task with zero tests would be "passed" by any solution.
    """
    # Case 1: already a list of assertions
    for k in ("test_list", "tests", "plus_tests", "base_tests"):
        value = task.get(k)
        if not value:
            continue
        if isinstance(value, str):
            # A bare string must not be exploded into single characters.
            tests = _split_test_source(value)
        else:
            tests = list(value)
        if tests:
            return tests

    # Case 2: single multiline assertion string (MBPP+ common case)
    assertion = task.get("assertion")
    if assertion:
        tests = _split_test_source(assertion)
        if tests:
            return tests

    raise KeyError(f"No tests found in task keys: {list(task.keys())}")

# Step 9: Defining Helper Functions

In [31]:
def _exec_code_and_tests_worker(code: str, tests: list[str], queue: mp.Queue) -> None:
    """
    Runs model code + tests in a shared environment.
    Fixes import errors and reports specific test failures.
    """
    try:
        # Create the "Main Desk" (Environment)
        env = {"__builtins__": __builtins__}

        # Run the User's Code into 'env'
        # We pass 'env' twice so it acts as both Globals and Locals
        exec(code, env, env)

        # Run the Test Cases
        for t in tests:
            try:
                # Run the test using that same desk
                exec(t, env, env)
            except AssertionError:
                # If a test fails, tell us WHICH one
                queue.put((False, f"Failed assertion: {t}"))
                return

        # If we finish the loop, all tests passed
        queue.put((True, None))

    except Exception:
        # Catch any other crashes (syntax errors, etc.)
        queue.put((False, traceback.format_exc()))

In [32]:
def run_mbpp_tests(code: str, task: dict, timeout_s: float = 2.0) -> tuple[bool, str | None]:
    """
    Executes the MBPP tests of `task` against `code` in a forked subprocess.

    The result is read from the queue BEFORE the worker is joined. A child that
    has put data on a multiprocessing queue does not exit until that data has
    been flushed to the pipe, so joining first can stall on a large payload
    (for example a very long failing assertion) and misreport it as a timeout.

    Args:
        code: Candidate solution source.
        task: Task dictionary understood by `get_tests_from_task`.
        timeout_s: Wall-clock budget for the whole run, in seconds.

    Returns:
        (passed, error_str); `error_str` is None when the tests passed.

    Raises:
        KeyError: propagated from `get_tests_from_task` when the task has no tests.
    """
    # Local imports: neither module is imported at the top of the notebook.
    import time
    from queue import Empty  # raised by multiprocessing queues when get() times out

    tests = get_tests_from_task(task)

    ctx = mp.get_context("fork")  # Colab/Linux: fork is fastest & simplest
    q = ctx.Queue()
    p = ctx.Process(target=_exec_code_and_tests_worker, args=(code, tests, q))
    p.start()

    deadline = time.monotonic() + timeout_s
    try:
        while True:
            # Wait in short slices so that a worker which dies without
            # reporting is noticed promptly instead of after the full timeout.
            try:
                passed, err = q.get(timeout=0.05)
                return passed, err
            except Empty:
                pass

            if not p.is_alive():
                # The worker is gone; give anything still in the pipe one
                # last chance to arrive.
                try:
                    passed, err = q.get(timeout=0.1)
                    return passed, err
                except Empty:
                    return False, "No result returned from worker."

            if time.monotonic() >= deadline:
                return False, f"Timeout after {timeout_s:.1f}s"
    finally:
        if p.is_alive():
            p.terminate()
            p.join(1.0)
            if p.is_alive():
                # SIGTERM was ignored or blocked: force the process down so
                # the final join below cannot hang.
                p.kill()
        p.join()
        q.close()

# Step 10: Defining Reward Functions

In [33]:
def format_reward_func(completions, **kwargs) -> list[float]:
    """
    Format reward: a small bonus for completions that follow the tag schema.

    Args:
        completions: List of generated strings from the model.
        **kwargs: Extra columns passed by the caller; ignored.

    Returns:
        One reward per completion, in order: 0.1 if `validate_schema`
        accepts it, otherwise 0.0.
    """
    # validate_schema returns (is_valid, reason); only the flag matters here.
    return [0.1 if validate_schema(text)[0] else 0.0 for text in completions]

In [34]:
def reasoning_reward_func(completions, **kwargs) -> list[float]:
    """
    Rewards the model for writing a non-trivial reasoning block.

    The reward grows linearly with the number of characters between
    <START_WORKING_OUT> and </END_WORKING_OUT> (surrounding whitespace
    stripped) and is capped, so longer reasoning earns nothing extra:

        reward = min(0.2, length / 1000 * 0.2)

    i.e. the full 0.2 is reached at 1000 characters (500 characters earn 0.1).

    Args:
        completions: List of generated strings from the model.
        **kwargs: Extra columns passed by the caller; ignored.

    Returns:
        One reward per completion in [0.0, 0.2]; 0.0 when no complete
        reasoning block is found.
    """
    max_reward = 0.2
    full_credit_chars = 1000.0
    # Compiled once per call instead of being looked up for every completion.
    reasoning_block = re.compile(
        r"<START_WORKING_OUT>(.*?)</END_WORKING_OUT>", re.DOTALL | re.IGNORECASE
    )

    rewards = []
    for completion in completions:
        match = reasoning_block.search(completion)
        if match is None:
            rewards.append(0.0)
            continue
        length = len(match.group(1).strip())
        # Soft length reward: linear up to the cap, flat afterwards, so there
        # is no incentive to pad the reasoning indefinitely.
        rewards.append(min(max_reward, (length / full_credit_chars) * max_reward))
    return rewards

In [35]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Rewards the model for writing code that passes the task's unit tests.

    Args:
        prompts: The prompts fed to the model (only used to keep the three
            sequences aligned).
        completions: The model's generated answers.
        answer: Per-sample task dictionaries; each is passed to
            `run_mbpp_tests` and its "task_id" (if present) is used in logs.
        **kwargs: Extra columns passed by the caller; ignored.

    Returns:
        One reward per sample: 1.0 when a solution could be extracted and all
        tests passed, otherwise 0.0.

    Side effects:
        Prints a short failure report for every solution that was extracted
        but failed its tests. Long errors are shortened to their beginning and
        end, so a single huge assertion or traceback cannot flood the logs.
    """
    head_chars, tail_chars = 200, 300

    rewards = []
    for _prompt, completion, task_data in zip(prompts, completions, answer):
        code, _status = extract_solution(completion)
        if not code:
            # No usable code block (missing, empty, or not valid Python).
            rewards.append(0.0)
            continue

        passed, err = run_mbpp_tests(code, task_data)
        if passed:
            rewards.append(1.0)
            continue

        rewards.append(0.0)
        err_text = str(err)
        if len(err_text) > head_chars + tail_chars:
            # Keep both ends: a failed assertion is identified by its start,
            # a traceback by its last lines.
            err_text = (
                err_text[:head_chars]
                + "\n... [truncated] ...\n"
                + err_text[-tail_chars:]
            )
        print(f"\n[FAIL] Task: {task_data.get('task_id', 'Unknown')}")
        print(f"Error: {err_text}")
    return rewards

# Step 11: Dataset Formatting and Unit Testing

In [36]:
# Clean the Data
# The raw MBPP+ dataset has inconsistent schemas (some fields are lists, some are None).
# We fix this by extracting ONLY what we need: the test cases.
dict_data = []

In [37]:
# Start from an empty list in THIS cell. Appending to a list created in an
# earlier cell would add a second copy of every task each time the cell is re-run.
dict_data = []

for task_id, task_data in MBPP_TASKS.items():
    # Extract the test cases using our helper from Step 8
    # This handles the "messy" parsing right now, so the Dataset is clean.
    try:
        tests = get_tests_from_task(task_data)
    except KeyError:
        # If a task is broken/empty, skip it to prevent crashes
        print(f"Skipping task {task_id}: No tests found.")
        continue

    # Create a CLEAN 'answer' dictionary
    # This guarantees every row has the exact same structure.
    # This prevents the "ArrowInvalid" error.
    clean_answer = {
        "task_id": str(task_id),
        "test_list": tests  # Always a List of Strings
    }

    # Append to our list
    dict_data.append({
        "prompt": task_data["prompt"],
        "answer": clean_answer
    })

In [38]:
# Creating a Hugging Face compatible dataset
dataset = Dataset.from_list(dict_data)

In [39]:
print("Dataset Features:", dataset.features)
print("Sample Row Answer Keys:", dataset[0]["answer"].keys())

Dataset Features: {'prompt': Value('string'), 'answer': {'task_id': Value('string'), 'test_list': List(Value('string'))}}
Sample Row Answer Keys: dict_keys(['task_id', 'test_list'])


In [40]:
# Pick the row at index 2 (i.e. the third task) as the sample for the checks below
task = dataset[2]["answer"] # We grab it from our NEW dataset column
prompt = dataset[2]["prompt"]

In [41]:
prompt

'"""\nWrite a function to find the n largest integers from a given list of numbers, returned in descending order.\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65]\n"""\n'

In [42]:

# Build the prompt structure
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt},
]

In [43]:
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

In [44]:
# Generate
# Greedy decoding is requested with `do_sample=False`; `temperature=0.0` is not
# a valid sampling temperature and only works by accident when sampling is off.
FastLanguageModel.for_inference(model)
try:
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )
finally:
    # This is the last generation before training: switch the model back to
    # training mode so the GRPO trainer does not receive it in inference mode.
    FastLanguageModel.for_training(model)

In [45]:
# Slice to get only the generated text
input_len = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

In [46]:
print(generated_text)

<START_WORKING_OUT>
Problem: Find n largest integers from a list, return in descending order.
Approach: Use heapq.nlargest which is efficient for this purpose.
Parameters: List of numbers, integer n.
Return: List of n largest numbers in descending order.
Implementation: heapq.nlargest(n, numbers) returns sorted list, no need custom sorting.
</END_WORKING_OUT>
<SOLUTION>
import heapq

def heap_queue_largest(numbers, n):
    """
    Return the n largest numbers from the list in descending order.
    
    Args:
        numbers: List of numbers (integers or floats).
        n: Number of largest elements to return.
    
    Returns:
        List of n largest numbers in descending order.
    """
    return heapq.nlargest(n, numbers)
</SOLUTION>


In [47]:
# CRITICAL PART: Testing the Reward Functions
# The Reward Functions expect LISTS (Batches), so we wrap our single item in a list.
# This simulates a batch size of 1.
batch_prompts = [prompt]
batch_completions = [generated_text]
batch_answers = [task] # This is the "answer" column data

In [48]:
# 1. Test Format Reward
r_format = format_reward_func(completions=batch_completions)
print(f"Format Reward (Expect 0.1): {r_format[0]}")

Format Reward (Expect 0.1): 0.1


In [49]:
# 2. Test Reasoning Reward
r_reason = reasoning_reward_func(completions=batch_completions)
print(f"Reasoning Reward (Expect 0.0-0.2): {r_reason[0]:.4f}")

Reasoning Reward (Expect 0.0-0.2): 0.0644


In [50]:

# 3. Test Correctness Reward (The complex one)
# Note: We pass 'answer' explicitly, just like the Trainer will.
r_correct = correctness_reward_func(
    prompts=batch_prompts,
    completions=batch_completions,
    answer=batch_answers
)
print(f"Correctness Reward (Expect 1.0 or 0.0): {r_correct[0]}")

Correctness Reward (Expect 1.0 or 0.0): 1.0


In [51]:
if r_format[0] > 0 and (r_correct[0] == 0.0 or r_correct[0] == 1.0):
    print(" SUCCESS: All reward functions accepted the inputs and returned scores.")
    print(" The plumbing is connected correctly.")
else:
    print(" FAIL: Something returned an unexpected format.")

 SUCCESS: All reward functions accepted the inputs and returned scores.
 The plumbing is connected correctly.


# Step 12: Apply Chat Template

In [52]:
def apply_chat_template(row):
    """
    Replaces a dataset row's raw task prompt with the fully templated chat prompt.

    The conversation is the shared `system_prompt` followed by the row's
    original prompt as the user turn. It is rendered to plain text
    (`tokenize=False`) with the generation prompt appended.

    Args:
        row: A dataset row with a "prompt" entry; it is updated in place.

    Returns:
        The same row object, with "prompt" overwritten by the rendered text.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": row["prompt"]},
    ]

    # Render to a raw text string rather than token ids; this is the form
    # stored in the 'prompt' column that is handed to the GRPOTrainer.
    rendered = tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True,
    )
    row["prompt"] = rendered
    return row

In [None]:
# Apply it to the whole dataset
# Keep one raw prompt first so the before/after comparison below is possible.
# NOTE: this cell is not idempotent. `dataset` is overwritten with the mapped
# version, so running the cell a second time would wrap the already-templated
# prompts in the chat template again. Rebuild `dataset` before re-running it.
original_prompt = dataset[0]["prompt"]
dataset = dataset.map(apply_chat_template)

print("\n--- BEFORE ---")
print(original_prompt)
print("\n--- AFTER (What the Model Sees) ---")
print(dataset[0]["prompt"])

# Step 13: Setting up GRPO Configurations

In [56]:
# We give the model ample room so it never gets cut off
max_prompt_length = 512
max_completion_length = 2048  # doubled from T4 config

In [57]:
# Sampling settings; passed to GRPOConfig as `vllm_sampling_params` in the next cell.
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 0.95,
    top_k = -1,
    seed = 3407,  # fixed sampling seed
    temperature = 0.9, # High enough to get diverse answers for GRPO
    stop = [tokenizer.eos_token],  # stop generating at the tokenizer's EOS text
    # The stop string is kept in the returned text.
    # NOTE(review): confirm whether the completions handed to the reward
    # functions therefore end with the EOS text.
    include_stop_str_in_output = True,
)

In [None]:
# 3. The Trainer Config
# GRPO group size: number of completions sampled per prompt.
num_generations = 8  # G=8: Much better stability than G=4

training_args = GRPOConfig(
    # Integration
    vllm_sampling_params = vllm_sampling_params, # We use vLLM for speed
    # Keep the trainer's temperature identical to the one the rollouts are
    # sampled with, so log-probabilities are computed at the same temperature.
    temperature = vllm_sampling_params.temperature,
    output_dir = "outputs",
    report_to = "wandb",
    run_name = "mbpp-grpo-a100-run1",

    # Optimization
    learning_rate = 5e-6,        # Safe, stable LR
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",

    # A100 POWER SETTINGS
    # The number of completions per optimisation step must be a multiple of
    # `num_generations`. The previous value of 4 (with no accumulation) was
    # not a multiple of 8, so the batch size is tied to the group size.
    per_device_train_batch_size = num_generations,
    gradient_accumulation_steps = 1, # One full group per step, no accumulation needed
    num_generations = num_generations,

    # Lengths
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # Duration
    #num_train_epochs = 1,            # 1 Epoch is safest for RL on small data
    max_steps = 5,                   # Smoke test only; shorter than save_steps, so no checkpoint is written

    # Logging
    logging_steps = 1,
    save_steps = 50,                 # Save more frequently
    use_vllm = True,                 # Explicitly enable vLLM
)

# Step 14: Initialize and Run GRPO Trainer

In [68]:
# Select the Reward Functions we defined in Step 10
# These are the "Judges" that will score the model's outputs.
reward_functions = [
    format_reward_func,       # Did it use <START_WORKING_OUT> and <SOLUTION>? (0.1)
    reasoning_reward_func,    # Reasoning length, scaled linearly up to 0.2 at 1000 chars
    correctness_reward_func   # Did the code actually pass the tests? (1.0)
]

In [61]:
# Initialize the Trainer
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = reward_functions,
    args = training_args,         # The A100 Config we just built
    train_dataset = dataset,      # The dataset with the Chat Template applied
)

In [None]:
trainer.train()