### Libraries, random seed, wandb installation

In [1]:
%%capture
# Environment setup cell; %%capture hides the (long) installer output.
# uv is installed first and used as the installer for everything that follows.
!pip install uv
# vLLM is pinned; --no-deps keeps this step from pulling in its dependency tree.
!uv pip install --no-deps unsloth vllm==0.8.5.post1
import sys, re, requests; modules = list(sys.modules.keys())
# Evict every already-imported module whose name contains "PIL" or "google"
# from sys.modules (presumably so freshly installed versions get imported
# instead of the ones preloaded by the runtime -- TODO confirm).
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
!uv pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
!uv pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" huggingface_hub hf_transfer

# Fetch vLLM's common requirements and install them, minus the entries for
# transformers / numpy / xformers.
# NOTE(review): the list comes from the `main` branch while vLLM itself is
# pinned to 0.8.5.post1, so the two can drift apart -- confirm this is acceptable.
# NOTE(review): the HTTP status is not checked; an error page would be written
# to the requirements file as-is.
f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
with open("vllm_requirements.txt", "wb") as file:
    # The pattern is not anchored to the start of a line: it deletes from the
    # first occurrence of one of the three names through the end of that line.
    file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
!uv pip install -r vllm_requirements.txt
!uv pip install pympler

In [2]:
import json
import os
import random
import wandb
from collections import defaultdict
import re
from scipy.stats import pearsonr
from pympler import asizeof
import matplotlib.pyplot as plt
import time
import numpy as np
import torch

In [3]:
# One seed value drives every random number generator used in the notebook.
seed = 42

# NOTE(review): PYTHONHASHSEED is read when an interpreter starts, so setting it
# here should only affect processes launched afterwards, not this kernel -- confirm.
os.environ['PYTHONHASHSEED'] = str(seed)

# Python's `random`, NumPy's legacy global RNG, and torch (CPU + all CUDA devices).
for _seed_rng in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_rng(seed)

### Unsloth

Load up `Qwen 2.5 0.5B Instruct`, and set parameters

In [4]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
# Sequence-length budget passed to the model loader below.
max_seq_length = 1024 # Can increase for longer reasoning traces
# LoRA rank; reused below as max_lora_rank, r and lora_alpha.
lora_rank = 64 # Larger rank = smarter, but slower

# Load the base model and its tokenizer through Unsloth.
# (The captured log below shows the 4-bit checkpoint
# unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit being the one actually loaded.)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "Qwen/Qwen2.5-0.5B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.8, # Reduce if out of memory
)

# Wrap the model with LoRA adapters on the attention and MLP projections.
# `seed` comes from the seeding cell earlier in the notebook; `model` is rebound
# here, so re-running only this cell would wrap an already-wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = seed,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-22 07:20:29 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 07-22 07:20:29 [__init__.py:239] Automatically detected platform cuda.
Unsloth: Patching vLLM v1 graph capture
Unsloth: Patching vLLM v0 graph capture
==((====))==  Unsloth 2025.7.6: Fast Qwen2 patching. Transformers: 4.53.2. vLLM: 0.8.5.post1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit with actual GPU utilization = 79.24%
Unsloth: Your GPU has CUDA compute 

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/270 [00:00<?, ?B/s]

INFO 07-22 07:21:04 [cuda.py:240] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 07-22 07:21:04 [cuda.py:289] Using XFormers backend.
INFO 07-22 07:21:04 [parallel_state.py:1004] rank 0 in world size 1 is assigned as DP rank 0, PP rank 0, TP rank 0
INFO 07-22 07:21:04 [model_runner.py:1108] Starting to load model unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit...
INFO 07-22 07:21:05 [loader.py:1187] Loading weights with BitsAndBytes quantization. May take a while ...
INFO 07-22 07:21:05 [weight_utils.py:265] Using model weights format ['*.safetensors']


model.safetensors:   0%|          | 0.00/538M [00:00<?, ?B/s]

INFO 07-22 07:21:10 [weight_utils.py:281] Time spent downloading weights for unsloth/qwen2.5-0.5b-instruct-unsloth-bnb-4bit: 5.125827 seconds
INFO 07-22 07:21:11 [weight_utils.py:315] No model.safetensors.index.json found in remote.


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 07-22 07:21:11 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 07-22 07:21:12 [model_runner.py:1140] Model loading took 0.5747 GiB and 6.857794 seconds
INFO 07-22 07:21:20 [worker.py:287] Memory profiling takes 7.36 seconds
INFO 07-22 07:21:20 [worker.py:287] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.79) = 11.68GiB
INFO 07-22 07:21:20 [worker.py:287] model weights take 0.57GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 1.23GiB; the rest of the memory reserved for KV Cache is 9.85GiB.
INFO 07-22 07:21:20 [executor_base.py:112] # cuda blocks: 53816, # CPU blocks: 0
INFO 07-22 07:21:20 [executor_base.py:117] Maximum concurrency for 1024 tokens per request: 840.88x
INFO 07-22 07:21:20 [vllm_utils.py:669] Unsloth: Running patched vLLM v0 `capture_model`.
INFO 07-22 07:21:20 [model_runner.py:1450] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run t

Capturing CUDA graph shapes:   0%|          | 0/31 [00:00<?, ?it/s]

INFO 07-22 07:21:42 [model_runner.py:1592] Graph capturing finished in 22 secs, took 0.43 GiB
INFO 07-22 07:21:42 [vllm_utils.py:676] Unsloth: Patched vLLM v0 graph capture finished in 22 secs.
INFO 07-22 07:21:43 [llm_engine.py:437] init engine (profile, create kv cache, warmup model) took 31.18 seconds
Unsloth: Just some info: will skip parsing ['pre_feedforward_layernorm', 'q_norm', 'k_norm', 'post_feedforward_layernorm']
Unsloth: Just some info: will skip parsing ['pre_feedforward_layernorm', 'q_norm', 'k_norm', 'post_feedforward_layernorm']


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/614 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

Unsloth 2025.7.6 patched 24 layers with 24 QKV layers, 24 O layers and 24 MLP layers.


### Get eval dataset

In [5]:
from datasets import load_dataset

In [6]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped content of the last ``<answer>`` element in ``text``.

    Args:
        text: Raw model response.

    Returns:
        The text between the last ``<answer>`` tag and the next ``</answer>``
        tag (or the end of the string when the closing tag is missing, e.g. a
        truncated generation), with surrounding whitespace removed. Returns
        ``""`` when ``text`` contains no ``<answer>`` tag at all.
    """
    # The previous split-based version could never raise IndexError, so its
    # `return ""` fallback was dead code and an untagged response was returned
    # whole. rpartition reports explicitly whether the opening tag was found.
    _, opening_tag, after_tag = text.rpartition("<answer>")
    if not opening_tag:
        return ""
    return after_tag.split("</answer>")[0].strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped content of the last ``<reasoning>`` element in ``text``.

    Args:
        text: Raw model response.

    Returns:
        The text between the last ``<reasoning>`` tag and the next
        ``</reasoning>`` tag (or the end of the string when the closing tag is
        missing), with surrounding whitespace removed. Returns ``""`` when
        ``text`` contains no ``<reasoning>`` tag at all.
    """
    # The previous split-based version could never raise IndexError, so its
    # `return ""` fallback was dead code and an untagged response was returned
    # whole. rpartition reports explicitly whether the opening tag was found.
    _, opening_tag, after_tag = text.rpartition("<reasoning>")
    if not opening_tag:
        return ""
    return after_tag.split("</reasoning>")[0].strip()

In [7]:
# Only the GSM8K test split is needed for evaluation; it is read directly from
# the upstream JSONL file.
gsm8k_files = {
    # 'train': 'https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl',
    'test': 'https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl',
}
dataset = load_dataset('json', data_files=gsm8k_files)['test']

Downloading data:   0%|          | 0.00/750k [00:00<?, ?B/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [8]:
# System prompt describing the expected response layout: reasoning inside
# <reasoning>...</reasoning>, followed by the final answer inside <answer>...</answer>.
R1_STYLE_SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <reasoning> </reasoning> and "
    "<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n"
    "<answer> answer here </answer>."
)

# Appended to the system prompt (after a newline) when prompts are built.
TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

# Assistant turn of the one-shot example ("What is 2+2?") placed before each real question.
CONTENT_EXAMPLE = (
    "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n"
    "<answer>4</answer>"
)

def preprocess_dataset(dataset, chunk_size=1000):
    total_samples = len(dataset)
    print(f"Loaded {total_samples} samples")

    def extract_hash_answer(text: str) -> str | None:
        try:
            return text.split("####")[1].strip()
        except IndexError:
            return None

    def process_batch(batch):
        prompts = [[
            {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
            {'role': 'user', 'content': "What is 2+2?"},
            {'role': 'assistant', 'content': CONTENT_EXAMPLE},
            {'role': 'user', 'content': q.strip()}
        ] for q in batch['question']]

        return {
            'prompt': prompts,
            'answer': batch['answer'],
            'num_answer': [extract_hash_answer(a) for a in batch['answer']],
        }

    return dataset.map(process_batch, batched=True, batch_size=chunk_size)

# Rebind `dataset` to the mapped version carrying the 'prompt' and 'num_answer'
# columns; the mapping runs in batches of 500.
dataset = preprocess_dataset(dataset, chunk_size=500)

Loaded 1319 samples


Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [9]:
# Sanity check: pretty-print the first processed record as JSON.
print(json.dumps(dataset[0], indent=4))

{
    "question": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "answer": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\u2019s market.\n#### 18",
    "prompt": [
        {
            "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n<answer> answer here </answer>.\nThe answer must be a single integer.",
            "role": "system"
   

### Eval after training

In [12]:
# Label for this evaluation run; used as the prefix of the saved LoRA directory
# and of the results JSON file name.
run_name = "initial_model"

In [13]:
from vllm import LLM, SamplingParams
from tqdm.notebook import tqdm
from datetime import datetime

In [14]:
# Number of prompts sent to the generator per call in the evaluation loop.
batch_size=16
# Directory the LoRA adapter is saved to and later loaded from.
lora_path = run_name + "_grpo_saved_lora"

In [15]:
# Write the current LoRA adapter to `lora_path` so the evaluation loop can load it.
# NOTE(review): no training step appears earlier in this notebook, so this is
# presumably the freshly initialised adapter -- confirm.
model.save_lora(lora_path)

In [16]:
# Decoding settings: temperature 0.0, generation stopped at the tokenizer's EOS token.
# NOTE(review): max_tokens (4096) exceeds the max_seq_length of 1024 the model was
# loaded with -- confirm which limit actually applies.
sampling_params = SamplingParams(
    temperature=0.0,
    max_tokens=4096,
    stop_token_ids=[tokenizer.eos_token_id],
)

# Running totals updated by the evaluation loop.
total_samples = len(dataset)
results = []        # one dict per evaluated example
correct = 0         # responses whose extracted answer equals the reference
total = 0           # responses processed so far
len_sum = 0         # sum of response lengths in tokens
len_min = 4096      # starts at the max_tokens value so any real length lowers it
len_max = 0
correct_format = 0  # responses accepted by format_reward_func

def format_reward_func(response) -> float:
    """Return 1 if ``response`` follows the expected tag layout, otherwise 0.

    Accepted layout: ``<reasoning>...</reasoning>``, optional whitespace, then
    ``<answer>`` wrapping a single (optionally signed / decimal) number with
    optional surrounding whitespace, and nothing after ``</answer>``.

    NOTE(review): the pattern is used without re.DOTALL, so a reasoning section
    that contains a newline is rejected -- confirm this is intended.
    """
    format_regex = re.compile(
        r"^<reasoning>.*?</reasoning>\s*<answer>\s*[+-]?\d*\.?\d+\s*</answer>$"
    )
    return 1 if format_regex.match(response) else 0

# Progress bar covering the whole dataset; the postfix shows running accuracy
# and starts from a zeroed placeholder.
progress_bar = tqdm(total=total_samples, desc="Processing samples", unit="examples", dynamic_ncols=True)
progress_bar.set_postfix({'acc': '0.00%', 'correct': '0'})


# Evaluate the dataset in batches: render prompts, generate, score each response.

# The LoRA request does not depend on the batch, so it is built once here rather
# than calling model.load_lora(lora_path) again for every batch.
lora_request = model.load_lora(lora_path)

for i in range(0, total_samples, batch_size):
    # Slicing the dataset yields a dict of columns, each a list for this batch.
    batch_data = dataset[i:i + batch_size]
    current_batch_size = len(batch_data['prompt'])

    # Render each message list to a prompt string ending with the generation prompt.
    formatted_prompts = [
        tokenizer.apply_chat_template(
            p,
            tokenize=False,
            add_generation_prompt=True
        )
        for p in batch_data['prompt']
    ]

    outputs = model.fast_generate(
        formatted_prompts,
        sampling_params = sampling_params,
        lora_request = lora_request,
    )

    for j, output in enumerate(outputs):
        response = output.outputs[0].text
        # Length is measured by re-encoding the generated text with the tokenizer.
        response_len = len(tokenizer.encode(response))

        generated_answer = extract_xml_answer(response)
        true_answer = batch_data['num_answer'][j]
        # Exact string comparison against the reference answer.
        is_correct = generated_answer == true_answer

        result = {
            'question': batch_data['question'][j],
            'true_answer': true_answer,
            'generated_answer': generated_answer,
            'full_response': response,
            'correct': is_correct,
            "response_len": response_len
        }
        results.append(result)

        if is_correct:
            correct += 1
        total += 1
        len_sum += response_len
        len_min = min(len_min, response_len)
        len_max = max(len_max, response_len)
        correct_format += format_reward_func(response)

    # Refresh the bar after every batch with the running accuracy.
    progress_bar.update(current_batch_size)
    progress_bar.set_postfix({
        'acc': f'{(correct/total)*100:.2f}%',
        'correct': f'{correct}/{total}',
    })

progress_bar.close()

Processing samples:   0%|          | 0/1319 [00:00<?, ?examples/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/16 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

Processed prompts:   0%|          | 0/7 [00:00<?, ?it/s, est. speed input: 0.00 toks/s, output: 0.00 toks/s]

In [17]:
# Summarise the run, persist metrics plus per-example results, and print a report.

# Every ratio shares the same zero-guard. Previously only `accuracy` was guarded,
# so an empty run would still raise ZeroDivisionError on the other two ratios.
has_samples = total > 0
# One clock reading, used for both the stored timestamp and the file name.
finished_at = datetime.now()

accuracy = correct / total if has_samples else 0
metrics = {
    'accuracy': accuracy,
    'correct': correct,
    'format_accuracy': correct_format / total if has_samples else 0,
    'correct_format': correct_format,
    'total': total,
    'avg_len': len_sum / total if has_samples else 0,
    'len_min': len_min,
    'len_max': len_max,
    'timestamp': finished_at.isoformat()
}

# The file name carries only day-of-month, hour and minute.
save_path = f"{run_name}_eval_results_{finished_at.strftime('day_%d_time_%H_%M')}.json"
with open(save_path, 'w') as f:
    json.dump({
        'metrics': metrics,
        'results': results
    }, f, indent=4)
print(f"\nResults saved to {save_path}")

print("\nFinal Evaluation Results:")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Correct: {metrics['correct']}/{metrics['total']}")

print(f"Format accuracy: {metrics['format_accuracy']:.2%}")
print(f"Correct format: {metrics['correct_format']}/{metrics['total']}")


Results saved to initial_model_eval_results_day_22_time_07_36.json

Final Evaluation Results:
Accuracy: 9.55%
Correct: 126/1319
Format accuracy: 66.34%
Correct format: 875/1319


In [18]:
# Colab-only: hand the results JSON to the browser as a download.
from google.colab import files
files.download(save_path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

# This is auxiliary code (execution will cause an error)

### Model downloading

In [None]:
# Export the model with the adapter merged into 16-bit weights, together with
# the tokenizer, into the local "model" directory.
model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)

Found HuggingFace hub cache directory: /root/.cache/huggingface/hub
Checking cache directory for required files...
Cache check failed: model.safetensors not found in local cache.
Not all required files found in cache. Will proceed with downloading.


Unsloth: Merging weights into 16bit:   0%|          | 0/1 [00:00<?, ?it/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

Unsloth: Merging weights into 16bit: 100%|██████████| 1/1 [00:25<00:00, 25.91s/it]


In [None]:
# Archive the exported directory recursively (the log below shows that the
# model/.cache download metadata is included as well).
!zip model.zip -r model/

  adding: model/ (stored 0%)
  adding: model/tokenizer_config.json (deflated 89%)
  adding: model/vocab.json (deflated 61%)
  adding: model/tokenizer.json (deflated 81%)
  adding: model/model.safetensors (deflated 12%)
  adding: model/merges.txt (deflated 57%)
  adding: model/chat_template.jinja (deflated 71%)
  adding: model/.cache/ (stored 0%)
  adding: model/.cache/huggingface/ (stored 0%)
  adding: model/.cache/huggingface/.gitignore (stored 0%)
  adding: model/.cache/huggingface/download/ (stored 0%)
  adding: model/.cache/huggingface/download/model.safetensors.lock (stored 0%)
  adding: model/.cache/huggingface/download/model.safetensors.metadata (deflated 30%)
  adding: model/config.json (deflated 72%)
  adding: model/added_tokens.json (deflated 67%)
  adding: model/special_tokens_map.json (deflated 69%)
  adding: model/generation_config.json (deflated 19%)


In [None]:
# Name the archive after the run so downloads from different runs can be told apart.
os.rename('model.zip', run_name + '.zip')

In [None]:
# Colab-only: hand the renamed model archive to the browser as a download.
from google.colab import files
files.download(run_name + '.zip')

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

### Model upload

In [None]:
# from google.colab import drive
# drive.mount('/content/drive')

In [None]:
# !unzip /content/drive/MyDrive/model.zip -d .

### Eval (after session restart)

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from tqdm.notebook import tqdm
import numpy as np
import re
import json
from datetime import datetime
import logging

#### Get eval dataset

In [None]:
from datasets import load_dataset

In [None]:
def extract_xml_answer(text: str) -> str:
    """Return the stripped content of the last ``<answer>`` element in ``text``.

    Args:
        text: Raw model response.

    Returns:
        The text between the last ``<answer>`` tag and the next ``</answer>``
        tag (or the end of the string when the closing tag is missing, e.g. a
        truncated generation), with surrounding whitespace removed. Returns
        ``""`` when ``text`` contains no ``<answer>`` tag at all.
    """
    # The previous split-based version could never raise IndexError, so its
    # `return ""` fallback was dead code and an untagged response was returned
    # whole. rpartition reports explicitly whether the opening tag was found.
    _, opening_tag, after_tag = text.rpartition("<answer>")
    if not opening_tag:
        return ""
    return after_tag.split("</answer>")[0].strip()

def extract_xml_reasoning(text: str) -> str:
    """Return the stripped content of the last ``<reasoning>`` element in ``text``.

    Args:
        text: Raw model response.

    Returns:
        The text between the last ``<reasoning>`` tag and the next
        ``</reasoning>`` tag (or the end of the string when the closing tag is
        missing), with surrounding whitespace removed. Returns ``""`` when
        ``text`` contains no ``<reasoning>`` tag at all.
    """
    # The previous split-based version could never raise IndexError, so its
    # `return ""` fallback was dead code and an untagged response was returned
    # whole. rpartition reports explicitly whether the opening tag was found.
    _, opening_tag, after_tag = text.rpartition("<reasoning>")
    if not opening_tag:
        return ""
    return after_tag.split("</reasoning>")[0].strip()

In [None]:
# Only the GSM8K test split is needed for evaluation; it is read directly from
# the upstream JSONL file.
gsm8k_files = {
    # 'train': 'https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl',
    'test': 'https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl',
}
dataset = load_dataset('json', data_files=gsm8k_files)['test']

In [None]:
# System prompt describing the expected response layout: reasoning inside
# <reasoning>...</reasoning>, followed by the final answer inside <answer>...</answer>.
R1_STYLE_SYSTEM_PROMPT = (
    "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. "
    "The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
    "The reasoning process and answer are enclosed within <reasoning> </reasoning> and "
    "<answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n"
    "<answer> answer here </answer>."
)

# Appended to the system prompt (after a newline) when prompts are built.
TASK_SPECIFIC_INSTRUCTIONS = "The answer must be a single integer."

# Assistant turn of the one-shot example ("What is 2+2?") placed before each real question.
CONTENT_EXAMPLE = (
    "<reasoning>To calculate 2+2, we simply add the numbers together: 2 + 2 = 4.</reasoning>\n"
    "<answer>4</answer>"
)

def preprocess_dataset(dataset, chunk_size=1000):
    total_samples = len(dataset)
    print(f"Loaded {total_samples} samples")

    def extract_hash_answer(text: str) -> str | None:
        try:
            return text.split("####")[1].strip()
        except IndexError:
            return None

    def process_batch(batch):
        prompts = [[
            {'role': 'system', 'content': R1_STYLE_SYSTEM_PROMPT + "\n" + TASK_SPECIFIC_INSTRUCTIONS},
            {'role': 'user', 'content': "What is 2+2?"},
            {'role': 'assistant', 'content': CONTENT_EXAMPLE},
            {'role': 'user', 'content': q.strip()}
        ] for q in batch['question']]

        return {
            'prompt': prompts,
            'answer': batch['answer'],
            'num_answer': [extract_hash_answer(a) for a in batch['answer']],
        }

    return dataset.map(process_batch, batched=True, batch_size=chunk_size)

# Rebind `dataset` to the mapped version carrying the 'prompt' and 'num_answer'
# columns; the mapping runs in batches of 500.
dataset = preprocess_dataset(dataset, chunk_size=500)

Loaded 1319 samples


Map:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [None]:
# Sanity check: pretty-print the first processed record as JSON.
print(json.dumps(dataset[0], indent=4))

{
    "question": "Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
    "answer": "Janet sells 16 - 3 - 4 = <<16-3-4=9>>9 duck eggs a day.\nShe makes 9 * 2 = $<<9*2=18>>18 every day at the farmer\u2019s market.\n#### 18",
    "prompt": [
        {
            "content": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <reasoning> </reasoning> and <answer> </answer> tags, respectively, i.e., <reasoning> reasoning process here </reasoning>\n<answer> answer here </answer>.\nThe answer must be a single integer.",
            "role": "system"
   

#### Load the model

In [None]:
# Fraction of GPU memory handed to the vLLM engine created in the next cell.
gpu_memory_utilization=0.8
# Number of prompts sent to the generator per call in the evaluation loop.
batch_size=16

# Local directory holding the merged model (the unzipped export).
model_path="model/"

In [None]:
# Load the merged model into a standalone vLLM engine, then its tokenizer;
# the two-step bar only marks which of the two loads has finished.
# NOTE(review): the recorded output of this cell ends in a CUDA OutOfMemoryError
# on a 14.74 GiB GPU that was already almost fully in use by the process --
# presumably the run was not made from a fresh session; confirm before relying on it.
with tqdm(total=2, desc="Loading model components") as pbar:
    llm = LLM(
        model=model_path,
        dtype="half",
        gpu_memory_utilization=gpu_memory_utilization,
        max_model_len=4096,
        device="cuda:0",
        enable_chunked_prefill=True,
    )
    pbar.update(1)

    # The tokenizer is used below for chat templating and for measuring response length.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        model_max_length=4096,
        padding_side='right',
        truncation_side='right'
    )
    pbar.update(1)

Loading model components:   0%|          | 0/2 [00:00<?, ?it/s]

INFO 07-21 14:22:54 [config.py:717] This model supports multiple tasks: {'generate', 'embed', 'score', 'reward', 'classify'}. Defaulting to 'generate'.
INFO 07-21 14:22:54 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=2048.
INFO 07-21 14:22:54 [llm_engine.py:240] Initializing a V0 LLM engine (v0.8.5.post1) with config: model='model/', speculative_config=None, tokenizer='model/', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.float16, max_seq_len=4096, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=None, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model

Loading safetensors checkpoint shards:   0% Completed | 0/1 [00:00<?, ?it/s]


INFO 07-21 14:23:00 [loader.py:458] Loading weights took 3.60 seconds
INFO 07-21 14:23:01 [model_runner.py:1140] Model loading took 0.9245 GiB and 3.644791 seconds
INFO 07-21 14:23:03 [worker.py:287] Memory profiling takes 1.20 seconds
INFO 07-21 14:23:03 [worker.py:287] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.80) = 11.79GiB
INFO 07-21 14:23:03 [worker.py:287] model weights take 0.92GiB; non_torch_memory takes 0.00GiB; PyTorch activation peak memory takes 1.40GiB; the rest of the memory reserved for KV Cache is 9.47GiB.
INFO 07-21 14:23:04 [executor_base.py:112] # cuda blocks: 51725, # CPU blocks: 21845
INFO 07-21 14:23:04 [executor_base.py:117] Maximum concurrency for 4096 tokens per request: 202.05x


OutOfMemoryError: CUDA out of memory. Tried to allocate 408.00 MiB. GPU 0 has a total capacity of 14.74 GiB of which 14.12 MiB is free. Process 4955 has 14.72 GiB memory in use. Of the allocated memory 14.11 GiB is allocated by PyTorch, with 24.00 MiB allocated in private pools (e.g., CUDA Graphs), and 79.71 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
def format_reward_func(response) -> float:
    """Return 1 if ``response`` follows the expected XML answer layout, else 0.

    The expected layout is a ``<reasoning>...</reasoning>`` section followed,
    after optional whitespace, by an ``<answer>...</answer>`` section holding
    one optionally signed integer or decimal number. The check is anchored at
    the start of the text and at its end.

    Args:
        response: Generated text to check.

    Returns:
        ``1`` when the text matches the layout and ``0`` otherwise. The value
        is an ``int`` so callers can sum it into a count.
    """
    pattern = r"^<reasoning>.*?</reasoning>\s*<answer>\s*[+-]?\d*\.?\d+\s*</answer>$"
    # re.DOTALL lets ``.`` cross line breaks. Without it, any reasoning block
    # that spans more than one line can never match and is scored as 0.
    matched = re.match(pattern, response, flags=re.DOTALL) is not None
    return int(matched)

In [None]:
# Evaluation pass: generate a response for every prompt in `dataset`, then
# accumulate answer accuracy, format compliance and response-length stats.

# Single source of truth for the generation budget. It also seeds `len_min`,
# so the sentinel can never drift away from the real limit.
max_new_tokens = 4096

sampling_params = SamplingParams(
    temperature=0.0,  # temperature 0 selects greedy decoding in vLLM
    max_tokens=max_new_tokens,
    stop_token_ids=[tokenizer.eos_token_id],
)

total_samples = len(dataset)
results = []
correct = 0
total = 0
len_sum = 0
len_min = max_new_tokens
len_max = 0
correct_format = 0

# The context manager closes the bar even if generation raises part-way
# through, so an interrupted run does not leave a stale progress bar behind.
with tqdm(
    total=total_samples,
    desc="Processing samples",
    unit="examples",
    dynamic_ncols=True,
) as progress_bar:
    progress_bar.set_postfix({
        'acc': '0.00%',
        'correct': '0',
    })

    for i in range(0, total_samples, batch_size):
        # The slice is indexed by column name below.
        # NOTE(review): presumably a Hugging Face `Dataset` slice (dict of
        # column lists) -- confirm against where `dataset` is built.
        batch_data = dataset[i:i + batch_size]
        current_batch_size = len(batch_data['prompt'])

        formatted_prompts = [
            tokenizer.apply_chat_template(
                p,
                tokenize=False,
                add_generation_prompt=True
            )
            for p in batch_data['prompt']
        ]

        # Generate responses
        outputs = llm.generate(
            formatted_prompts,
            sampling_params,
        )

        # Process responses
        for j, output in enumerate(outputs):
            response = output.outputs[0].text
            # Length is measured by re-encoding the decoded text.
            response_len = len(tokenizer.encode(response))

            generated_answer = extract_xml_answer(response)
            true_answer = batch_data['num_answer'][j]
            # Computed once and reused for both the record and the counter.
            is_correct = generated_answer == true_answer

            result = {
                'question': batch_data['question'][j],
                'true_answer': true_answer,
                'generated_answer': generated_answer,
                'full_response': response,
                'correct': is_correct,
                "response_len": response_len
            }
            results.append(result)

            if is_correct:
                correct += 1
            total += 1
            len_sum += response_len
            len_min = min(len_min, response_len)
            len_max = max(len_max, response_len)
            correct_format += format_reward_func(response)

        # Update progress
        progress_bar.update(current_batch_size)
        progress_bar.set_postfix({
            'acc': f'{(correct/total)*100:.2f}%',
            'correct': f'{correct}/{total}',
        })

In [None]:
# Calculate metrics
# Every ratio shares the same guard, so an evaluation that processed no
# samples reports 0 instead of raising ZeroDivisionError.
has_samples = total > 0
accuracy = correct / total if has_samples else 0
format_accuracy = correct_format / total if has_samples else 0
avg_len = len_sum / total if has_samples else 0

# Capture the clock once so the stored timestamp and the file name agree.
finished_at = datetime.now()

metrics = {
    'accuracy': accuracy,
    'correct': correct,
    'format_accuracy': format_accuracy,
    'correct_format': correct_format,
    'total': total,
    'avg_len': avg_len,
    'len_min': len_min,
    'len_max': len_max,
    'model_path': model_path,
    'timestamp': finished_at.isoformat()
}

# Persist the aggregate metrics together with the per-sample records.
save_path = f"gsm8k_eval_results_{finished_at.strftime('day_%d_time_%H_%M')}.json"
with open(save_path, 'w') as f:
    json.dump({
        'metrics': metrics,
        'results': results
    }, f, indent=4)
print(f"\nResults saved to {save_path}")

print("\nFinal Evaluation Results:")
print(f"Accuracy: {metrics['accuracy']:.2%}")
print(f"Correct: {metrics['correct']}/{metrics['total']}")

print(f"Format accuracy: {metrics['format_accuracy']:.2%}")
print(f"Correct format: {metrics['correct_format']}/{metrics['total']}")

In [None]:
# Colab-only helper: hand the results JSON at `save_path` to the browser as a
# download. The import fails outside Google Colab.
from google.colab import files
files.download(save_path)

### Manual Inference

In [None]:
from vllm import SamplingParams

# Render one single-turn chat (user message only) and repeat it, so the batch
# holds two copies of the same prompt.
text = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": "How many r's are in strawberry?"}],
        tokenize=False,
        add_generation_prompt=True,
    )
] * 2

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)

# lora_request=None: no LoRA request is passed with this generation call.
output = model.fast_generate(text, sampling_params=sampling_params, lora_request=None)

output

In [None]:
# Display the generated text for the second of the two identical prompts.
output[1].outputs[0].text

Now generate again, this time with the LoRA adapter we just trained with GRPO:

In [None]:
# Write the LoRA to "grpo_saved_lora"; the inference cell below reloads it
# from the same path via `model.load_lora`.
model.save_lora("grpo_saved_lora")

In [None]:
from vllm import SamplingParams

# Same question as before, now preceded by the R1-style system prompt.
chat_messages = [
    {"role": "system", "content": R1_STYLE_SYSTEM_PROMPT},
    {"role": "user", "content": "How many r's are in strawberry?"},
]
text = tokenizer.apply_chat_template(
    chat_messages,
    tokenize=False,
    add_generation_prompt=True,
)

sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=1024)

# Reload the LoRA saved under "grpo_saved_lora" and pass it with the request.
grpo_lora_request = model.load_lora("grpo_saved_lora")
generations = model.fast_generate(
    text,
    sampling_params=sampling_params,
    lora_request=grpo_lora_request,
)
# Keep only the text of the first completion for the single prompt.
output = generations[0].outputs[0].text

output