# Notebook: GRPO Reinforcement Learning (Run 2 – Stabilized Partial-Credit Training)

This notebook performs the **second GRPO reinforcement learning run** on top of the **SFT warm-start model**: `models/qwen3-4b-sft`.

The focus of this run is **controlled performance improvement** on MBPP-style coding tasks while explicitly preventing reward hacking, verbosity collapse, and masking failures observed in the initial GRPO attempt.

---

## Objective

The goal of this stage is to **improve pass@1 correctness** beyond the SFT baseline by applying **GRPO with dense, verifiable rewards**, while maintaining:

- Concise, efficient reasoning traces  
- Stable output formatting  
- Controlled policy drift (no degeneration into rambling or schema exploitation)

---

## Key Changes from the First GRPO Run

### 1. Dense Correctness Reward (Partial Credit)
- Replaced binary pass/fail correctness with **fractional credit**:
  - Reward = `passed_tests / total_tests`
- Added a small **victory bonus** for passing all tests to preserve a global optimum.
- Prevents flat reward landscapes where near-miss solutions receive no gradient.
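
To see why fractional credit matters for GRPO, it helps to look at the group-relative advantage. The sketch below assumes the commonly described normalization (each reward minus its group mean, divided by the group standard deviation). `group_advantages` and `eps` are illustrative names and do not mirror the trainer's internals.

```python
from statistics import mean, pstdev

def group_advantages(rewards, eps=1e-4):
    """Normalize rewards within one group of rollouts for the same prompt."""
    mu = mean(rewards)
    sigma = pstdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# Four rollouts on a task with 3 tests; none of them passes every test.
binary = [0.0, 0.0, 0.0, 0.0]            # pass/fail reward
partial = [0 / 3, 1 / 3, 2 / 3, 2 / 3]   # passed_tests / total_tests

print(group_advantages(binary))    # identical rewards give zero advantage everywhere
print(group_advantages(partial))   # rollouts that pass more tests rank higher
```

With the binary reward, a group of near-misses carries no learning signal at all, whereas the fractional reward still separates the rollouts.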

### 2. Anti-Filibuster Reasoning Reward
- Replaced linear length-based reward with a **capped + penalized profile**:
  - Linear ascent up to 400 characters
  - Flat reward plateau from 400–800 characters
  - Aggressive negative slope beyond 800 characters
- Explicitly disincentivizes infinite rambling while still encouraging meaningful reasoning.

### 3. Format Reward Demotion
- Reduced format reward to a **small hygiene incentive** (0.02 max).
- Ensures schema compliance without allowing formatting to dominate learning.

### 4. Stop Condition Correction
- Reverted to `stop = [tokenizer.eos_token]`.
- Removed string-based stop conditions that caused masking failures, padding leakage, and clipped-ratio explosions.
- Restores correct termination detection and loss masking.

### 5. KL Term Removal (Default GRPO Behavior)
- Explicitly **did not use a KL penalty** (`beta` omitted / default 0.0).
- Aligns with modern GRPO practice and avoids unnecessary instability.
- KL metrics are monitored diagnostically only.

### 6. Stability-Oriented Training Setup
- Higher exploration via `num_generations = 16`
- Gradient accumulation for variance reduction
- Two-epoch training with early-stop awareness
- Careful learning-rate selection for short-horizon RL

---

## Training Procedure Summary

- Load the **SFT warm-start model**: `models/qwen3-4b-sft`
- Apply ChatML-style prompt formatting
- Generate multiple rollouts per prompt via vLLM
- Compute rewards using:
  - Schema validation
  - Reasoning length shaping
  - Partial-credit unit test execution
- Optimize policy using GRPO with clipped ratios
- Monitor reward decomposition, KL drift, and generation length
- Save checkpoints for post-hoc evaluation

# Step 1: Mounting Google Drive and Importing Libraries


In [1]:
from google.colab import drive
drive.mount("/content/drive")
%cd /content/drive/MyDrive/grpo-verified-reasoner
!ls

Mounted at /content/drive
/content/drive/MyDrive/grpo-verified-reasoner
data			      LICENSE	 outputs    unsloth_compiled_cache
grpo_trainer_lora_model       models	 README.md  _unsloth_sentencepiece_temp
huggingface_tokenizers_cache  notebooks  src	    wandb


In [None]:

# Install UV (Faster pip)
!pip install --upgrade -qqq uv

In [None]:

!pip -q install -U evalplus

In [4]:
import os
import subprocess

In [5]:

os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:False"

In [6]:
os.environ["UNSLOTH_VLLM_STANDBY"] = "1"

In [7]:
os.environ["WANDB_PROJECT"] = "mbpp-rl-project"

In [8]:
# Environment Logic (Colab vs Local)
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # Version Matching
    try: import numpy, PIL; get_numpy = f"numpy=={numpy.__version__}"; get_pil = f"pillow=={PIL.__version__}"
    except: get_numpy = "numpy"; get_pil = "pillow"
    try: is_t4 = "Tesla T4" in str(subprocess.check_output(["nvidia-smi"]))
    except: is_t4 = False

    # A100 gets vllm 0.10.2 (Fast), T4 gets 0.9.2 (Stable)
    get_vllm, get_triton = ("vllm==0.9.2", "triton==3.2.0") if is_t4 else ("vllm==0.10.2", "triton")

    # Install Everything
    !uv pip install -qqq --upgrade \
        unsloth {get_vllm} {get_numpy} {get_pil} torchvision bitsandbytes xformers
    !uv pip install -qqq {get_triton}

# Install TRL
!uv pip install transformers==4.56.2
!uv pip install --no-deps trl==0.22.2

Using Python 3.12.12 environment at: /usr
Resolved 18 packages in 20ms
Prepared 1 package in 296ms
Uninstalled 1 package in 188ms
Installed 1 package in 36ms
 - transformers==4.57.3
 + transformers==4.56.2
Using Python 3.12.12 environment at: /usr
Resolved 1 package in 1ms
Prepared 1 package in 24ms
Uninstalled 1 package in 1ms
Installed 1 package in 5ms
 - trl==0.24.0
 + trl==0.22.2


In [9]:
import re
import ast
import torch
import wandb
import random
import evalplus
import traceback
import numpy as np
import multiprocessing as mp
from datasets import Dataset
from unsloth import FastLanguageModel
from evalplus.data import get_mbpp_plus
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
INFO 01-05 02:12:48 [__init__.py:216] Automatically detected platform cuda.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [None]:
wandb.login()

# Step 2: Verifying GPU and Environment

In [11]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
if torch.cuda.is_available():
    print("GPU:", torch.cuda.get_device_name(0))

Torch version: 2.8.0+cu128
CUDA available: True
GPU: NVIDIA H100 80GB HBM3


# Step 3: Loading Base Model and LoRA Adapters

In [12]:

MODEL_PATH = "models/qwen3-4b-sft"

In [None]:
# Load the model
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = MODEL_PATH,
    max_seq_length = 3072,      # Aligned with GRPO + schema
    load_in_4bit = False,       # No 4-bit quantization, for RL stability
    fast_inference = True,      # Required for vLLM
    gpu_memory_utilization = 0.8,
)

In [14]:
trainable = 0
total = 0
trainable_names = []
for name, p in model.named_parameters():
    n = p.numel()
    total += n
    if p.requires_grad:
        trainable += n
        trainable_names.append(name)

print(f"Trainable params: {trainable:,} / {total:,} = {100*trainable/total:.4f}%")
print("Example trainable params:", trainable_names[:20])

Trainable params: 66,060,288 / 4,088,528,384 = 1.6157%
Example trainable params: ['base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight', 'base_model.model.model.layers.0.self_attn.q_proj.lora_B.default.weight', 'base_model.model.model.layers.0.self_attn.k_proj.lora_A.default.weight', 'base_model.model.model.layers.0.self_attn.k_proj.lora_B.default.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_A.default.weight', 'base_model.model.model.layers.0.self_attn.v_proj.lora_B.default.weight', 'base_model.model.model.layers.0.self_attn.o_proj.lora_A.default.weight', 'base_model.model.model.layers.0.self_attn.o_proj.lora_B.default.weight', 'base_model.model.model.layers.0.mlp.gate_proj.lora_A.default.weight', 'base_model.model.model.layers.0.mlp.gate_proj.lora_B.default.weight', 'base_model.model.model.layers.0.mlp.up_proj.lora_A.default.weight', 'base_model.model.model.layers.0.mlp.up_proj.lora_B.default.weight', 'base_model.model.model.layers.0.mlp.down_proj

In [15]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 2560, padding_idx=151654)
        (layers): ModuleList(
          (0-35): 36 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2560, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear

# Step 4: Sanity Check

In [16]:
# This is the same prompt that we used during SFT
system_prompt = """You are a code-generation engine.
You must output your response in the following exact format:
<START_WORKING_OUT>
Concise reasoning steps required to solve the problem.
</END_WORKING_OUT>
<SOLUTION>
Valid Python code only.
</SOLUTION>
Do not output anything outside these tags."""

In [17]:
user_prompt = "Write a Python function that returns the factorial of a number."

messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_prompt},
]
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)

In [18]:
# Move the dictionary to GPU manually
inputs = {k: v.to("cuda") for k, v in inputs.items()}

In [19]:
FastLanguageModel.for_inference(model) # Temporarily enable inference mode for the test
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=False, # Greedy decoding for a deterministic check
    )

In [20]:

decoded = tokenizer.decode(output[0], skip_special_tokens=True)

In [21]:
print("\n--- MODEL OUTPUT ---")
input_len = inputs["input_ids"].shape[1]
print(tokenizer.decode(output[0][input_len:], skip_special_tokens=True))


--- MODEL OUTPUT ---
<START_WORKING_OUT>
Define a function factorial(n) that calculates the product of all positive integers up to n.
Handle non-positive integers by returning 1 (factorial of 0 and negative is 1).
Implement iterative approach for efficiency.
Return the computed factorial.
</END_WORKING_OUT>
<SOLUTION>
def factorial(n):
    if n <= 0:
        return 1
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result
</SOLUTION>


Comment: Schema checks, the extractor, and the reward functions should only ever see generated_text (the tokens after the prompt), never the full decoded sequence, because the system prompt itself contains the schema tags.


# Step 6: Defining Output Schema

In [22]:
# Regular expressions for tag validation (case-insensitive)
RE_START = re.compile(r"<START_WORKING_OUT>", re.IGNORECASE)
RE_END   = re.compile(r"</END_WORKING_OUT>", re.IGNORECASE)
RE_SOL   = re.compile(r"<SOLUTION>", re.IGNORECASE)
RE_SOL_END = re.compile(r"</SOLUTION>", re.IGNORECASE)

In [23]:
def validate_schema(text: str) -> tuple[bool, str]:
    """
    Checks whether the model output follows the exact required schema.
    Returns (is_valid, reason).
    """
    if not RE_START.search(text):
        return False, "Missing <START_WORKING_OUT>"
    if not RE_END.search(text):
        return False, "Missing </END_WORKING_OUT>"
    if not RE_SOL.search(text):
        return False, "Missing <SOLUTION>"
    if not RE_SOL_END.search(text):
        return False, "Missing </SOLUTION>"

    # Optional: check order consistency
    start_idx = RE_START.search(text).start()
    sol_idx   = RE_SOL.search(text).start()
    if sol_idx < start_idx:
        return False, "Tag order incorrect (<SOLUTION> before reasoning block)."

    return True, "Schema valid"
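
The validator above checks that each tag is present and that `<SOLUTION>` does not precede `<START_WORKING_OUT>`. If stricter enforcement is wanted, one option is to require each tag exactly once, in order, with nothing outside the two blocks. This is a sketch of an alternative and not what this run used; `TAGS`, `RE_STRICT`, and `validate_schema_strict` are hypothetical names.

```python
import re

TAGS = ("<START_WORKING_OUT>", "</END_WORKING_OUT>", "<SOLUTION>", "</SOLUTION>")

# Whole-string match: reasoning block, then solution block, whitespace only around them.
RE_STRICT = re.compile(
    r"\A\s*<START_WORKING_OUT>.*?</END_WORKING_OUT>\s*"
    r"<SOLUTION>.*?</SOLUTION>\s*\Z",
    re.IGNORECASE | re.DOTALL,
)

def validate_schema_strict(text: str) -> bool:
    """True only if every tag appears once and the blocks are in order."""
    lowered = text.lower()
    if any(lowered.count(tag.lower()) != 1 for tag in TAGS):
        return False
    return RE_STRICT.match(text) is not None

good = "<START_WORKING_OUT>\nthink\n</END_WORKING_OUT>\n<SOLUTION>\nx = 1\n</SOLUTION>\n"
swapped = "<SOLUTION>x = 1</SOLUTION><START_WORKING_OUT>think</END_WORKING_OUT>"
print(validate_schema_strict(good), validate_schema_strict(swapped))
```

A stricter check changes which completions earn the format reward, so it would need to be re-validated against the SFT model's outputs before being swapped in.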

In [24]:
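# Run a sanity test on the generated tokens only. The prompt itself contains
# the schema tags, so validating the full decoded sequence would pass trivially.
is_valid, reason = validate_schema(tokenizer.decode(output[0][input_len:], skip_special_tokens=True))
print("Schema Check:", is_valid, "|", reason)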

Schema Check: True | Schema valid


# Step 7: Solution Extraction

In [25]:
# Regex to extract the code block between <SOLUTION> ... </SOLUTION>
RE_SOLUTION = re.compile(r"<SOLUTION>\s*(.*?)\s*</SOLUTION>", re.IGNORECASE | re.DOTALL)

In [26]:
def extract_solution(text: str) -> tuple[str | None, str]:
    """
    Extracts the Python code inside <SOLUTION> tags.
    Returns (code, status) where:
        code   -> the extracted string or None if failed
        status -> textual reason (for debugging)
    """
    match = RE_SOLUTION.search(text)
    if not match:
        return None, "No <SOLUTION> block found."

    code = match.group(1).strip()
    if not code:
        return None, "Empty <SOLUTION> block."

    # Syntax check via Python's AST parser
    try:
        ast.parse(code)
    except SyntaxError as e:
        return None, f"Syntax error in code: {e}"

    return code, "Valid Python code extracted."

In [27]:
# Calculate where the prompt ends
input_len = inputs["input_ids"].shape[1]

In [28]:
# Decode ONLY the new tokens (The Assistant's reply)
generated_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

In [29]:
# Now run the check on ONLY the generated text
code, status = extract_solution(generated_text) # Use the new variable
print("Status:", status)

Status: Valid Python code extracted.


In [30]:

# Show snippet of the extracted code
if code:
    print("\n--- Extracted Python Code ---\n")
    print(code)


--- Extracted Python Code ---

def factorial(n):
    if n <= 0:
        return 1
    result = 1
    for i in range(2, n + 1):
        result *= i
    return result


# Step 8: Verifier Integration (EvalPlus MBPP+)

In [31]:
# Load MBPP+ tasks as a dict: {task_id: problem_dict}
MBPP_TASKS = get_mbpp_plus()

print(f"Loaded MBPP+ tasks: {len(MBPP_TASKS)}")

Downloading dataset from https://github.com/evalplus/mbppplus_release/releases/download/v0.2.0/MbppPlus.jsonl.gz
Loaded MBPP+ tasks: 378


In [32]:
# Quick peek at one task to confirm fields & shape
sample_task_id = next(iter(MBPP_TASKS.keys()))
sample_task = MBPP_TASKS[sample_task_id]

print("\nSample Task ID:", sample_task_id)
print("Keys:", list(sample_task.keys()))
print("\nPrompt (first 400 chars):\n", sample_task["prompt"][:400])


Sample Task ID: Mbpp/2
Keys: ['task_id', 'prompt', 'entry_point', 'canonical_solution', 'base_input', 'atol', 'plus_input', 'contract', 'assertion']

Prompt (first 400 chars):
 """
Write a function to find the shared elements from the given two lists.
assert set(similar_elements((3, 4, 5, 6),(5, 7, 4, 10))) == set((4, 5))
"""



In [33]:
# Different EvalPlus versions may store tests under slightly different keys,
# so we normalize via a helper (used later in reward function).
def get_tests_from_task(task: dict) -> list[str]:
    """
    Extracts MBPP test assertions from a task.
    Supports both list-based and string-based formats.
    """
    # Case 1: already a list of assertions
    for k in ("test_list", "tests", "plus_tests", "base_tests"):
        if k in task and task[k]:
            return list(task[k])

    # Case 2: single multiline assertion string (MBPP+ common case)
    if "assertion" in task and task["assertion"]:
        lines = task["assertion"].strip().splitlines()
        return [line for line in lines if line.strip()]

    raise KeyError(f"No tests found in task keys: {list(task.keys())}")

# Step 9: Defining Helper Functions

In [34]:
def _exec_code_and_tests_worker(code: str, tests: list[str], queue: mp.Queue) -> None:
    """
    Runs model code + tests.
    CRITICAL FEATURES:
    1. Runs ALL tests (Partial Credit).
    2. Catches ALL exceptions (Robustness).
    3. Truncates error logs (IPC Safety).
    """
    try:
        # Create the "Main Desk" (Environment)
        env = {"__builtins__": __builtins__}

        # Run the User's Code into 'env'
        exec(code, env, env)

        passed_count = 0
        total_tests = len(tests)
        first_error = None

        # Run ALL Test Cases
        for t in tests:
            try:
                exec(t, env, env)
                passed_count += 1
            except Exception:
                # Capture only the FIRST error to save bandwidth
                if first_error is None:
                    # Get the full traceback
                    tb = traceback.format_exc()
                    # SOPHIA'S FIX: Truncate to 500 chars to prevent IPC deadlock
                    first_error = tb[:500] + "\n...[TRUNCATED]..." if len(tb) > 500 else tb
                # Continue to the next test!
                continue

        # Mission Complete: Return the score
        queue.put((passed_count, total_tests, first_error))

    except Exception:
        # Catch syntax errors or crashes in the main code body
        tb = traceback.format_exc()
        truncated_error = tb[:500] + "\n...[TRUNCATED]..." if len(tb) > 500 else tb
        queue.put((0, len(tests), truncated_error))

In [35]:
def run_mbpp_tests(code: str, task: dict, timeout_s: float = 2.0) -> tuple[int, int, str | None]:
    """
    Executes tests and returns (passed_count, total_count, first_error).
    """
    tests = get_tests_from_task(task)
    if not tests:
        return 0, 0, "No tests found."

    ctx = mp.get_context("fork")
    q = ctx.Queue()
    p = ctx.Process(target=_exec_code_and_tests_worker, args=(code, tests, q))
    p.start()
    p.join(timeout_s)

    if p.is_alive():
        p.terminate()
        p.join()
        return 0, len(tests), f"Timeout after {timeout_s:.1f}s"

    if q.empty():
        return 0, len(tests), "No result returned from worker."

    passed_count, total_count, err = q.get()
    return passed_count, total_count, err

# Step 10: Defining Reward Functions

In [36]:
def format_reward_func(completions, **kwargs) -> list[float]:
    """
    Rewards the model for strictly following the XML schema.
    Args:
        completions: List of generated strings from the model.
    Returns:
        List of rewards (0.02 for valid schema, 0.0 for invalid).
    """
    rewards = []
    for completion in completions:
        # Uses your existing validator from Step 6
        is_valid, _ = validate_schema(completion)
        rewards.append(0.02 if is_valid else 0.0)
    return rewards

In [37]:
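def reasoning_reward_func(completions, **kwargs) -> list[float]:
    """
    Piecewise reward on the length (in characters) of the reasoning block:
    - Missing block or fewer than 50 chars: 0.0.
    - Ascent: length / 400 * 0.1 for 50-400 chars.
    - Plateau: 400-800 chars (max reward 0.1).
    - Penalty: beyond 800 chars, lose 0.03 per 100 extra chars.
    - Zero point: reward hits 0.0 at ~1133 chars; floor of -0.1 at ~1467 chars.
    """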
    rewards = []
    for completion in completions:
        match = re.search(r"<START_WORKING_OUT>(.*?)</END_WORKING_OUT>", completion, re.DOTALL | re.IGNORECASE)

        if match:
            content = match.group(1).strip()
            length = len(content)

            if length < 50:
                rewards.append(0.0)

            elif length <= 400:
                # LINEAR ASCENT
                score = (length / 400.0) * 0.1
                rewards.append(score)

            elif length <= 800:
                # PLATEAU: 0.1 (The "Thinking Room")
                rewards.append(0.1)

            else:
                # AGGRESSIVE PENALTY (Slope 0.03)
                overage = length - 800
                penalty = (overage / 100.0) * 0.03
                score = 0.1 - penalty
                rewards.append(max(-0.1, score))

        else:
            rewards.append(0.0)

    return rewards
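
The breakpoints of this profile can be checked without any model output. The sketch below restates the same piecewise logic as a function of length only; `length_to_reward` is a hypothetical helper that is not used by the trainer.

```python
def length_to_reward(length: int) -> float:
    """Same piecewise profile as reasoning_reward_func, keyed on length alone."""
    if length < 50:
        return 0.0
    if length <= 400:
        return (length / 400.0) * 0.1          # linear ascent
    if length <= 800:
        return 0.1                             # plateau
    penalty = ((length - 800) / 100.0) * 0.03  # 0.03 per 100 extra chars
    return max(-0.1, 0.1 - penalty)            # floor at -0.1

for n in (49, 50, 200, 400, 800, 1000, 1133, 1134, 1467, 3000):
    print(f"{n:>5} chars -> {length_to_reward(n):+.4f}")
```

Note the small jump at 50 characters (from 0.0 to 0.0125) and that the sign flips between 1133 and 1134 characters.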

In [38]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Rewards the model based on the PERCENTAGE of tests passed.
    Includes a "Clean Sweep Bonus" for 100% completion.
    """
    rewards = []
    for prompt, completion, task_data in zip(prompts, completions, answer):
        code, status = extract_solution(completion)
        if not code:
            rewards.append(0.0)
            continue

        passed, total, err = run_mbpp_tests(code, task_data)

        if total == 0:
            rewards.append(0.0)
            continue

        # CALCULATE SCORE: Fraction of tests passed
        score = passed / total

        # VICTORY BONUS: If 100% passed, add +0.1 bonus
        # This differentiates "perfect" from "lucky" and prevents settling
        if passed == total:
            score += 0.1

        rewards.append(score)

    return rewards

# Step 11: Dataset Formatting and Unit Testing

In [39]:
# Clean the Data
# The raw MBPP+ dataset has inconsistent schemas (some fields are lists, some are None).
# We fix this by extracting ONLY what we need: the test cases.
dict_data = []

In [40]:
for task_id, task_data in MBPP_TASKS.items():
    # Extract the test cases using our helper from Step 8
    # This handles the "messy" parsing right now, so the Dataset is clean.
    try:
        tests = get_tests_from_task(task_data)
    except KeyError:
        # If a task is broken/empty, skip it to prevent crashes
        print(f"Skipping task {task_id}: No tests found.")
        continue

    # Create a CLEAN 'answer' dictionary
    # This guarantees every row has the exact same structure.
    # This prevents the "ArrowInvalid" error.
    clean_answer = {
        "task_id": str(task_id),
        "test_list": tests  # Always a List of Strings
    }

    # Append to our list
    dict_data.append({
        "prompt": task_data["prompt"],
        "answer": clean_answer
    })

In [41]:
# Creating a Hugging Face compatible dataset
dataset = Dataset.from_list(dict_data)

In [42]:
print("Dataset Features:", dataset.features)
print("Sample Row Answer Keys:", dataset[0]["answer"].keys())

Dataset Features: {'prompt': Value('string'), 'answer': {'task_id': Value('string'), 'test_list': List(Value('string'))}}
Sample Row Answer Keys: dict_keys(['task_id', 'test_list'])


In [43]:

# Pick the third task (index 2) for the reward-function test
task = dataset[2]["answer"] # Taken from the new 'answer' column
prompt = dataset[2]["prompt"]

In [44]:
prompt

'"""\nWrite a function to find the n largest integers from a given list of numbers, returned in descending order.\nassert heap_queue_largest( [25, 35, 22, 85, 14, 65, 75, 22, 58],3)==[85, 75, 65]\n"""\n'

In [45]:
# Build the prompt structure
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": prompt},
]

In [46]:
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

In [47]:
# Generate
FastLanguageModel.for_inference(model)
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=False,
    )

In [48]:
# Slice to get only the generated text
input_len = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

In [49]:
print(generated_text)

<START_WORKING_OUT>
Problem: Find n largest integers from a list, return in descending order.
Approach: Use heapq.nlargest which returns n largest elements in order from largest to smallest.
Parameters: List of numbers, integer n.
Return: List of n largest numbers in descending order.
</END_WORKING_OUT>
<SOLUTION>
import heapq

def heap_queue_largest(nums, n):
    """
    Return the n largest numbers from nums in descending order.
    
    Args:
        nums: List of numbers (integers or floats)
        n: Number of largest elements to return
        
    Returns:
        List of n largest numbers in descending order
    """
    if n <= 0:
        return []
    if n >= len(nums):
        nums_sorted = sorted(nums, reverse=True)
        return nums_sorted[:n]
    return heapq.nlargest(n, nums)
</SOLUTION>


In [50]:
# CRITICAL PART: Testing the Reward Functions
# The Reward Functions expect LISTS (Batches), so we wrap our single item in a list.
# This simulates a batch size of 1.
batch_prompts = [prompt]
batch_completions = [generated_text]
batch_answers = [task] # This is the "answer" column data

In [51]:
# 1. Test Format Reward
r_format = format_reward_func(completions=batch_completions)
print(f"Format Reward (Expect 0.02): {r_format[0]}")

Format Reward (Expect 0.02): 0.02


In [52]:
# 2. Test Reasoning Reward
r_reason = reasoning_reward_func(completions=batch_completions)
print(f"Reasoning Reward (Expect -0.1 to 0.1): {r_reason[0]:.4f}")

Reasoning Reward (Expect -0.1 to 0.1): 0.0663


In [53]:

# 3. Test Correctness Reward (The complex one)
# Note: We pass 'answer' explicitly, just like the Trainer will.
r_correct = correctness_reward_func(
    prompts=batch_prompts,
    completions=batch_completions,
    answer=batch_answers
)
print(f"Correctness Reward (Expect 0.0 to 1.1): {r_correct[0]}")

Correctness Reward (Expect 0.0 to 1.1): 1.1


In [54]:
if r_format[0] > 0 and 0.0 <= r_correct[0] <= 1.1:
    print(" SUCCESS: All reward functions accepted the inputs and returned scores.")
    print(" The plumbing is connected correctly.")
else:
    print(" FAIL: Something returned an unexpected format.")

 SUCCESS: All reward functions accepted the inputs and returned scores.
 The plumbing is connected correctly.


# Step 12: Apply Chat Template

In [55]:
def apply_chat_template(row):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": row["prompt"]}
    ]

    # "tokenize=False" gives us the raw text string (e.g. "<|im_start|>system ... <|im_start|>user ...")
    # This is exactly what the GRPOTrainer expects in the 'prompt' column.
    row["prompt"] = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    return row

In [None]:
# Apply it to the whole dataset
original_prompt = dataset[0]["prompt"]
dataset = dataset.map(apply_chat_template)

print("\n--- BEFORE ---")
print(original_prompt)
print("\n--- AFTER (What the Model Sees) ---")
print(dataset[0]["prompt"])

# Step 13: Setting up GRPO Configurations

In [57]:
# We give the model ample room so it never gets cut off
max_prompt_length = 512
max_completion_length = 2048  # doubled from T4 config

In [58]:
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 0.95,
    top_k = -1,
    seed = 3407,
    temperature = 0.8, # High enough to get diverse answers for GRPO
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

In [59]:
# 3. The Trainer Config
training_args = GRPOConfig(
    # Integration
    vllm_sampling_params = vllm_sampling_params, # We use vLLM for speed
    output_dir = "outputs",
    report_to = "wandb",
    run_name = "mbpp-grpo-h100-run4-full",

    # Optimization
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",

    # Batch settings (this run used an H100)
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4,
    num_generations = 16,             # G=16: Much better stability than G=4

    # Lengths
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,

    # Duration
    num_train_epochs = 2,            # 2 epochs is safest for RL on small data
    #max_steps = 5,

    # Logging
    logging_steps = 5,
    save_steps = 30,                 # Save more frequently
    use_vllm = True,                 # Explicitly enable vLLM
)

Unsloth: We now expect `per_device_train_batch_size` * `gradient_accumulation_steps` * `world_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 16
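
The step count that the trainer reports when training starts (188) can be reproduced by hand. This sketch assumes that each optimizer step consumes `batch size × gradient accumulation steps` completions, that every prompt contributes `num_generations` completions, and that a partial batch at the end of an epoch is dropped. These are inferences from the logged numbers, not a description of the trainer's internals.

```python
num_prompts = 378          # rows in the MBPP+ training dataset
num_generations = 16
per_device_batch = 16      # raised from 1 by Unsloth, per the message above
grad_accum = 4
epochs = 2

completions_per_step = per_device_batch * grad_accum                        # 64
prompts_per_step = completions_per_step // num_generations                  # 4
steps_per_epoch = (num_prompts * num_generations) // completions_per_step   # 94
total_steps = steps_per_epoch * epochs                                      # 188

print(completions_per_step, prompts_per_step, steps_per_epoch, total_steps)
```

Under these assumptions each optimizer step sees only four distinct prompts, which is worth keeping in mind when reading the step-level reward variance.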


# Step 14: Initialize and Run GRPO Trainer

In [60]:
# Select the Reward Functions we defined in Step 10
# These are the "Judges" that will score the model's outputs.
reward_functions = [
    format_reward_func,       # Did it use <START_WORKING_OUT> and <SOLUTION>? (max 0.02)
    reasoning_reward_func,    # Is the reasoning 400-800 chars long? (max 0.1, min -0.1)
    correctness_reward_func   # What fraction of the tests passed? (max 1.1 with bonus)
]

In [61]:
# Initialize the Trainer
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = reward_functions,
    args = training_args,         # The GRPO config we just built
    train_dataset = dataset,      # The dataset with the Chat Template applied
)

In [62]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 378 | Num Epochs = 2 | Total steps = 188
O^O/ \_/ \    Batch size per device = 16 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (16 x 4 x 1) = 64
 "-____-"     Trainable parameters = 66,060,288 of 4,088,528,384 (1.62% trained)




Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss | reward | reward_std | completions / mean_length | completions / min_length | completions / max_length | completions / clipped_ratio | completions / mean_terminated_length | completions / min_terminated_length | completions / max_terminated_length | kl | rewards / format_reward_func / mean | rewards / format_reward_func / std | rewards / reasoning_reward_func / mean | rewards / reasoning_reward_func / std | rewards / correctness_reward_func / mean | rewards / correctness_reward_func / std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 5 | 0.0083 | 0.917053 | 0.175341 | 145.7375 | 59.0 | 739.8 | 0.003125 | 139.805557 | 59.0 | 399.2 | 8.317562 | 0.019937 | 0.0005 | 0.06024 | 0.024401 | 0.836875 | 0.407541 |
| 10 | 0.0039 | 1.014248 | 0.223273 | 193.525 | 69.2 | 1280.2 | 0.009375 | 176.250284 | 69.2 | 780.4 | 3.859137 | 0.019875 | 0.000701 | 0.065154 | 0.0233 | 0.929219 | 0.36826 |
| 15 | 0.0005 | 0.901061 | 0.330024 | 183.2125 | 79.6 | 892.2 | 0.003125 | 177.288196 | 79.6 | 557.4 | 0.542674 | 0.02 | 0.0 | 0.068978 | 0.026858 | 0.812083 | 0.442199 |
| 20 | 0.0006 | 0.91016 | 0.2613 | 244.290625 | 63.6 | 1946.0 | 0.0125 | 221.132184 | 63.6 | 1173.8 | 0.637628 | 0.019812 | 0.001201 | 0.066649 | 0.034955 | 0.823698 | 0.43716 |
| 25 | 0.0011 | 1.048846 | 0.212669 | 197.840625 | 59.2 | 1055.4 | 0.009375 | 180.495294 | 59.2 | 522.0 | 1.120589 | 0.019937 | 0.0005 | 0.071929 | 0.02315 | 0.956979 | 0.299423 |
| 30 | 0.0006 | 0.752737 | 0.221269 | 256.50625 | 68.8 | 1627.8 | 0.0125 | 233.955322 | 68.8 | 1139.0 | 0.649474 | 0.01975 | 0.001352 | 0.071581 | 0.036045 | 0.661406 | 0.447654 |
| 35 | 0.045 | 0.903331 | 0.243252 | 190.984375 | 63.6 | 1123.4 | 0.003125 | 185.278128 | 63.6 | 907.2 | 45.043564 | 0.02 | 0.0 | 0.073904 | 0.02742 | 0.809427 | 0.451101 |
| 40 | 0.0057 | 0.713682 | 0.328387 | 274.953125 | 73.6 | 1341.6 | 0.015625 | 247.035092 | 73.6 | 1007.0 | 5.70399 | 0.01975 | 0.001352 | 0.07487 | 0.037124 | 0.619063 | 0.497398 |
| 45 | 0.0009 | 0.891254 | 0.210045 | 262.625 | 75.0 | 1604.2 | 0.0125 | 240.103711 | 75.0 | 998.0 | 0.860802 | 0.019875 | 0.001 | 0.073462 | 0.035464 | 0.797917 | 0.463865 |
| 50 | 0.0004 | 0.999469 | 0.299818 | 237.096875 | 90.8 | 1051.2 | 0.003125 | 231.536609 | 90.8 | 772.4 | 0.425752 | 0.019937 | 0.0005 | 0.08099 | 0.03289 | 0.898542 | 0.357206 |




Run history (W&B sparkline charts omitted) was logged for: profiling/Time taken for `UnslothGRPOTrainer._calculate_rewards`, `_prepare_inputs`, `correctness_reward_func`, `format_reward_func`, `reasoning_reward_func`, and `vLLM.generate`; train/completion_length; train/completions/clipped_ratio; train/completions/max_length; train/completions/max_terminated_length.

metric,value
profiling/Time taken: UnslothGRPOTrainer._calculate_rewards,5.18347
profiling/Time taken: UnslothGRPOTrainer._prepare_inputs,1e-05
profiling/Time taken: UnslothGRPOTrainer.correctness_reward_func,5.18033
profiling/Time taken: UnslothGRPOTrainer.format_reward_func,0.00032
profiling/Time taken: UnslothGRPOTrainer.reasoning_reward_func,0.0006
profiling/Time taken: UnslothGRPOTrainer.vLLM.generate,11.72052
total_flos,0
train/completion_length,317.44271
train/completions/clipped_ratio,0.01562
train/completions/max_length,1581.66667


TrainOutput(global_step=188, training_loss=0.0022916886509653737, metrics={'train_runtime': 5441.8689, 'train_samples_per_second': 0.139, 'train_steps_per_second': 0.035, 'total_flos': 0.0, 'train_loss': 0.0022916886509653737})
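
The figures above can be cross-checked with a few lines of arithmetic. The sketch below is illustrative only: it copies `global_step` and `train_runtime` from the `TrainOutput` line and the profiling timings from the run summary, and it assumes those summary timings are single last-logged values rather than averages over the run. The variable names are introduced here for illustration.

```python
# Values copied from the TrainOutput line above.
global_step = 188
train_runtime_s = 5441.8689

seconds_per_step = train_runtime_s / global_step
steps_per_second = global_step / train_runtime_s

# Profiling timings (seconds) copied from the run summary.
# Assumption: these are last-logged values, not run-wide averages.
calculate_rewards_s = 5.18347
correctness_reward_s = 5.18033
vllm_generate_s = 11.72052

# Fraction of reward-computation time spent executing test cases.
correctness_share = correctness_reward_s / calculate_rewards_s

print(f"{seconds_per_step:.1f} s/step, {steps_per_second:.3f} steps/s")
print(f"correctness reward share of reward time: {correctness_share:.1%}")
print(f"generation vs. reward time: {vllm_generate_s / calculate_rewards_s:.2f}x")
```

Under that assumption, almost all of the reward time is spent in `correctness_reward_func`, which executes the generated code against the tests; the format and reasoning rewards are negligible by comparison.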

# Step 15: Sanity Check

Let us now check the model we just trained by generating a solution for a single prompt and inspecting both the reasoning block and the final code.

In [63]:
# Switch to Inference Mode
FastLanguageModel.for_inference(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Qwen3ForCausalLM(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 2560, padding_idx=151654)
        (layers): ModuleList(
          (0-35): 36 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=2560, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=2560, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear

In [64]:
test_question = "Write a function to find the volume of a sphere given its radius."

In [65]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": test_question},
]

In [66]:
# Tokenize (Exactly as before)
inputs = tokenizer.apply_chat_template(
    messages,
    tokenize=True,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
)
inputs = {k: v.to("cuda") for k, v in inputs.items()}

In [67]:
# Generate
with torch.no_grad():
    output = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,   # make sampling explicit; temperature has no effect under greedy decoding
        temperature=0.8,  # Slight creativity to encourage reasoning
    )

In [68]:
# Decode only the newly generated tokens (slice off the prompt, just like before)
input_len = inputs["input_ids"].shape[1]
generated_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True)

print("\n=== FINAL MODEL OUTPUT ===")
print(generated_text)


=== FINAL MODEL OUTPUT ===
<START_WORKING_OUT>
The formula for the volume of a sphere is V = (4/3) * π * r³.
We need a function that takes the radius r as input and returns the volume.
We can use the math module for the value of π.
The function should handle positive radius values.
We'll return the calculated volume as a float.
</END_WORKING_OUT>
<SOLUTION>
import math

def sphere_volume(radius):
    """
    Calculate the volume of a sphere given its radius.
    
    Parameters:
    radius (float): The radius of the sphere.
    
    Returns:
    float: The volume of the sphere.
    """
    if radius < 0:
        raise ValueError("Radius must be non-negative")
    return (4/3) * math.pi * (radius ** 3)
</SOLUTION>
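
Reading the output by eye confirms the format, but the sanity check can also be made executable. The following is a minimal sketch, not the notebook's reward code: `extract_solution` is a hypothetical helper introduced here, and `completion` is a shortened copy of the output above (docstring and most of the reasoning trimmed). It pulls the code out of the `<SOLUTION>` tags, runs it, and compares the result with the closed-form volume.

```python
import math
import re

# Shortened copy of the completion shown above (docstring trimmed for brevity).
completion = '''<START_WORKING_OUT>
The formula for the volume of a sphere is V = (4/3) * π * r³.
</END_WORKING_OUT>
<SOLUTION>
import math

def sphere_volume(radius):
    if radius < 0:
        raise ValueError("Radius must be non-negative")
    return (4/3) * math.pi * (radius ** 3)
</SOLUTION>'''


def extract_solution(text):
    """Return the code between <SOLUTION> tags, or None if the tags are missing."""
    match = re.search(r"<SOLUTION>\s*(.*?)\s*</SOLUTION>", text, re.DOTALL)
    return match.group(1) if match else None


code = extract_solution(completion)
namespace = {}
# Caution: exec runs arbitrary code; only do this in an isolated environment.
exec(code, namespace)

volume = namespace["sphere_volume"](10)
assert math.isclose(volume, 4 / 3 * math.pi * 10**3)
print(volume)
```

In the notebook, `completion` would be replaced by `generated_text`. Running untrusted model output with `exec` is only reasonable inside a sandbox or a throwaway environment.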


# Step 16: Saving the Model

In [69]:
MODEL_OUT = "models/qwen3-4b-grpo-final-2"

In [70]:
model.save_lora(MODEL_OUT)

In [71]:
tokenizer.save_pretrained(MODEL_OUT)

('models/qwen3-4b-grpo-final-2/tokenizer_config.json',
 'models/qwen3-4b-grpo-final-2/special_tokens_map.json',
 'models/qwen3-4b-grpo-final-2/chat_template.jinja',
 'models/qwen3-4b-grpo-final-2/vocab.json',
 'models/qwen3-4b-grpo-final-2/merges.txt',
 'models/qwen3-4b-grpo-final-2/added_tokens.json',
 'models/qwen3-4b-grpo-final-2/tokenizer.json')
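
Before moving on to evaluation, it can be worth confirming that the output directory is complete. The sketch below is one way to do that with the standard library. The tokenizer file names are taken from the output above; the adapter file names are an assumption (the usual PEFT layout) and may differ depending on library versions. `missing_files` is a hypothetical helper.

```python
from pathlib import Path

# Tokenizer files listed in the save_pretrained output above.
TOKENIZER_FILES = [
    "tokenizer_config.json",
    "special_tokens_map.json",
    "chat_template.jinja",
    "vocab.json",
    "merges.txt",
    "added_tokens.json",
    "tokenizer.json",
]
# Assumption: the LoRA adapter is written in the usual PEFT layout.
ADAPTER_FILES = ["adapter_config.json", "adapter_model.safetensors"]


def missing_files(directory, expected):
    """Return the expected file names that are not present in `directory`."""
    root = Path(directory)
    return [name for name in expected if not (root / name).is_file()]


missing = missing_files("models/qwen3-4b-grpo-final-2", TOKENIZER_FILES + ADAPTER_FILES)
print("missing:", missing or "none")
```

An empty list means every expected file was found; any names printed point to a save step that did not complete or to a different on-disk layout than assumed here.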