In [1]:
import re
import torch
import argparse
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
import numpy as np
from accelerate import Accelerator
from functools import partial

INFO 07-31 00:34:31 [__init__.py:244] Automatically detected platform cuda.


In [2]:
# System message sent with every prompt: asks the model to reply with a
# <reasoning> section followed by an <answer> section.
# The literal starts and ends with a newline; those newlines appear verbatim
# inside the rendered chat template.
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

# Template for a complete reply in the same layout.
# Fill it with XML_COT_FORMAT.format(reasoning=..., answer=...).
XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""


def extract_xml_answer(text: str) -> str:
    """Return the stripped text inside the last ``<answer>`` tag of ``text``.

    Content starts after the LAST ``<answer>`` and stops at the first
    ``</answer>`` that follows it. A missing opening tag means "start at the
    beginning"; a missing closing tag means "run to the end". If neither tag is
    present, the whole input is returned (stripped).
    """
    # rpartition yields ("", "", text) when the tag is absent, so tail falls back to text.
    _, _, tail = text.rpartition("<answer>")
    # partition yields (tail, "", "") when the closing tag is absent.
    inner, _, _ = tail.partition("</answer>")
    return inner.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip().replace(",", "").replace("$", "")


In [3]:
def _normalize_answer(text: str) -> str:
    """Drop commas and dollar signs, then surrounding whitespace ("$1,000 " -> "1000")."""
    return text.replace(",", "").replace("$", "").strip()


def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Score 2.0 for each completion whose <answer> matches the gold answer, else 0.0.

    Args:
        prompts: Prompt strings for the batch. The dataset stores chat prompts that
            are already rendered to text. Only ``prompts[0]`` is used, for the debug print.
        completions: Generated completion strings, aligned with ``answer``.
        answer: Gold answers. An entry may be None (source answer had no "####"
            marker); such entries always score 0.0.
        **kwargs: Presumably other dataset columns / extras passed by the trainer;
            ignored here.

    Returns:
        One float per (completion, answer) pair.
    """
    responses = completions
    extracted_responses = [extract_xml_answer(r) for r in responses]

    # Debug print of one sample per call. Skipped for an empty batch instead of
    # raising IndexError.
    if prompts and responses:
        print("--- Correctness Check ---")
        print("Prompt:", prompts[0])
        print("First Response:", responses[0])
        print("-------------------------")

    # Normalize both sides identically. Gold answers already have "," and "$"
    # removed by extract_hash_answer, but model answers do not. Without this,
    # a correct "$1,000" would not match "1000".
    return [
        2.0 if a is not None and _normalize_answer(r) == _normalize_answer(a) else 0.0
        for r, a in zip(extracted_responses, answer)
    ]

In [4]:
model_name = "Qwen/Qwen2.5-1.5B-Instruct"

# Output directory and run name are picked by model family (Qwen is the default).
if "Llama" not in model_name:
    output_dir = "outputs/Qwen-1.5B-GRPO"
    run_name = "Qwen-1.5B-GRPO-gsm8k"
else:
    output_dir = "outputs/Llama-1B-GRPO"
    run_name = "Llama-1B-GRPO-gsm8k"

training_args = GRPOConfig(
    # Bookkeeping and logging
    output_dir=output_dir,
    run_name=run_name,
    logging_steps=1,
    save_steps=100,
    report_to="wandb",
    log_on_each_node=False,
    # Optimizer, schedule and clipping
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    max_grad_norm=0.1,
    num_train_epochs=1,
    # Precision and batching
    bf16=True,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    # GRPO sampling: completions per prompt and length limits
    num_generations=4,
    max_prompt_length=256,
    max_completion_length=786,  # NOTE(review): possibly intended as 768 — confirm
)

In [5]:
# Load the policy model in bfloat16 with FlashAttention-2.
# device_map=None loads the weights on CPU first; the explicit move to the GPU
# follows. That is why transformers warns about Flash Attention on a model not
# initialized on GPU.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=None,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)
model = model.to("cuda")

You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


In [6]:
# Tokenizer matching the policy model; its chat template renders the prompts.
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Use the EOS token for padding (overrides whatever pad token the tokenizer ships with).
tokenizer.pad_token = tokenizer.eos_token

def get_gsm8k_questions(
    split: str = "train",
    few_shot: bool = False,
    chat_tokenizer=None,
) -> Dataset:
    """Load a GSM8K split and attach chat-formatted prompts and gold answers.

    Args:
        split: Dataset split to load (e.g. "train").
        few_shot: If True, insert one worked example between the system prompt
            and the real question: a user question plus an assistant reply in the
            XML_COT_FORMAT layout. Off by default, which gives the plain zero-shot
            system + user prompt.
        chat_tokenizer: Tokenizer whose chat template renders the prompt. Defaults
            to the notebook-level ``tokenizer``, looked up at call time.

    Returns:
        The split with ``prompt`` added and ``answer`` overwritten. ``prompt`` is
        the rendered chat string ending with the assistant generation prompt.
        ``answer`` is the text after "####" with commas and dollar signs removed,
        or None if the marker is missing. The original ``question`` column is kept.
    """
    template_tokenizer = tokenizer if chat_tokenizer is None else chat_tokenizer

    example_turns = []
    if few_shot:
        example_turns = [
            {'role': 'user', 'content': 'What is the largest single-digit prime number?'},
            {'role': 'assistant', 'content': XML_COT_FORMAT.format(
                reasoning="9 is divisible by 3 and 8 is divisible by 2, but 7 is prime.",
                answer="7",
            )},
        ]

    def _format_row(row):
        # Build one chat: system prompt, optional example turns, then the question.
        messages = [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            *example_turns,
            {'role': 'user', 'content': row['question']},
        ]
        return {
            'prompt': template_tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            ),
            'answer': extract_hash_answer(row['answer']),
        }

    data = load_dataset('openai/gsm8k', 'main', split=split)  # type: ignore
    return data.map(_format_row)  # type: ignore

# Training split. Each row carries a chat-formatted 'prompt' string and the extracted gold 'answer'.
dataset = get_gsm8k_questions()


In [7]:
class DebugGRPOTrainer(GRPOTrainer):
    """GRPOTrainer that prints the first decoded prompt of each generation batch.

    Used to verify the exact string the model is conditioned on during training,
    including the chat template and special tokens.
    """

    def _generate_and_score_completions(self, generation_batch):
        """Run the parent implementation, print one decoded prompt, return its output unchanged.

        This overrides a private TRL method. It assumes the result is a dict with
        "prompt_ids" and, optionally, "prompt_mask". Re-check after TRL upgrades.
        """
        processed_batch = super()._generate_and_score_completions(generation_batch)

        # Print from one process only; other ranks would duplicate the output.
        if not self.accelerator.is_main_process:
            return processed_batch

        prompt_input_ids = processed_batch["prompt_ids"]
        if len(prompt_input_ids) == 0:
            return processed_batch

        first_prompt_ids = prompt_input_ids[0]
        # Prompts in a batch may be padded to a common length, and pad == EOS here.
        # Keep only the unmasked positions so padding does not show up in the
        # "verified" string.
        prompt_mask = processed_batch.get("prompt_mask")
        if prompt_mask is not None:
            first_prompt_ids = first_prompt_ids[prompt_mask[0].bool()]

        # Decode only the row being printed, keeping special tokens visible.
        prompt_string = self.processing_class.decode(
            first_prompt_ids,
            skip_special_tokens=False,
        )

        print("\n--- ✅ VERIFIED FINAL FORMATTED STRING ---")
        print(prompt_string)
        print("-------------------------------------------\n")

        return processed_batch

# Build the trainer with a single reward signal (exact-match correctness) and train.
trainer = DebugGRPOTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
    reward_funcs=[correctness_reward_func],
)
trainer.train()

[2025-07-31 00:34:36,493] [INFO] [real_accelerator.py:254:get_accelerator] Setting ds_accelerator to cuda (auto detect)


[2025-07-31 00:34:36,767] [INFO] [logging.py:107:log_dist] [Rank -1] [TorchCheckpointEngine] Initialized with serialization = False


/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::runtime_error::~runtime_error()@GLIBCXX_3.4'
/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `__gxx_personality_v0@CXXABI_1.3'
/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::ostream::tellp()@GLIBCXX_3.4'
/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::chrono::_V2::steady_clock::now()@GLIBCXX_3.4.19'
/home/ubuntu/miniconda3/envs/alex/compiler_compat/ld: /usr/local/cuda/lib64/libcufile.so: undefined reference to `std::string::_M_replace_aux(unsigned long, unsigned long, unsigned long, char)@GLIBCXX_3.4'
/home/ubuntu/miniconda3/

--- Correctness Check ---
Prompt: <|im_start|>system

Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
<|im_end|>
<|im_start|>user
Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been 9 assignments and Ahmed has a 91 in the class. Emily has a 92. The final assignment is worth the same amount as all the other assignments. Emily got a 90 on the final assignment. What is the minimum grade Ahmed needs to get to beat Emily if all grades are whole numbers?<|im_end|>
<|im_start|>assistant

First Response: To find the minimum grade Ahmed needs to beat Emily, we first need to calculate the total points that both have already received and the points they need from the final assignment.

Ahmed's current total:
91 (for 9 assignments) * 10 points per assignment = 910 points

Emily's current total:
92 (for one assignment) * 10 points per assignment = 920 points

Emily's final assignment score:
920 + 90 = 1010 points


Step,Training Loss
1,0.0
2,-0.0029
3,0.0494


--- Correctness Check ---
Prompt: <|im_start|>system

Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
<|im_end|>
<|im_start|>user
Marie has 98 unread messages on her phone. She decides to clear them by reading 20 messages a day. However, she also gets 6 new messages a day. How many days will it take her to read all her unread messages?<|im_end|>
<|im_start|>assistant

First Response: To determine how many days it will take Marie to read all her unread messages, we can set up an equation. Let's denote the number of days it will take her to read all her messages as \( d \).

Each day, Marie reduces her unread messages count by \( 20 - 6 = 14 \) messages. Initially, she has 98 unread messages. After \( d \) days, she will have read \( 14d \) messages and the remaining unread messages will be \( 98 - 14d \). To have no unread messages left, the number of unread messages must be zero.

Therefore, we can set up the following equation:
\[ 98 - 14d = 0 \]



KeyboardInterrupt: 

Error in callback <bound method _WandbInit._post_run_cell_hook of <wandb.sdk.wandb_init._WandbInit object at 0x77aea6e22590>> (for post_run_cell), with arguments args (<ExecutionResult object at 77af5d6ea710, execution_count=7 error_before_exec=None error_in_exec= info=<ExecutionInfo object at 77af5d748fd0, raw_cell="class DebugGRPOTrainer(GRPOTrainer):
    def _gene.." transformed_cell="class DebugGRPOTrainer(GRPOTrainer):
    def _gene.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell://ssh-remote%2Balex_capacityblock0/home/ubuntu/alex/verifiers/test.ipynb#X13sdnNjb2RlLXJlbW90ZQ%3D%3D> result=None>,),kwargs {}:


BrokenPipeError: [Errno 32] Broken pipe

In [None]:
# Reload the tokenizer BEFORE building the data, because get_gsm8k_questions
# reads the global `tokenizer`. Re-apply the EOS-as-pad override so that
# rebinding the global does not silently change what later cells (e.g. a
# re-run of the trainer cell) receive.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token = tokenizer.eos_token
data = get_gsm8k_questions()

In [None]:
# Rebuild the first training prompt by hand. Then check it is byte-identical
# to the 'prompt' column produced by get_gsm8k_questions.
item = data[0]  # raises immediately on an empty dataset
chat = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': item["question"]},
]

formatted_prompt = tokenizer.apply_chat_template(
    chat,
    tokenize=False,
    add_generation_prompt=True,  # Adds the prompt for the assistant's turn
)
assert formatted_prompt == item["prompt"], "hand-built prompt differs from the dataset's 'prompt' column"

In [None]:
# Show the rendered prompt string, special tokens included, as the cell output.
formatted_prompt

'<|im_start|>system\n\nRespond in the following format:\n<reasoning>\n...\n</reasoning>\n<answer>\n...\n</answer>\n<|im_end|>\n<|im_start|>user\nNatalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?<|im_end|>\n<|im_start|>assistant\n'