In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from dotenv import load_dotenv, find_dotenv

_ = load_dotenv(find_dotenv())

os.environ["HF_ACCESS_TOKEN"] = os.environ["HUGGINGFACE_ACCESS_TOKEN"]

## Loading the model and tokenizer we're going to use

In [3]:
import multiprocessing as mp
mp.set_start_method("fork")

In [5]:
from unsloth import FastLanguageModel
import torch

max_seq_length = 2048 #Increase this for longer reasoning traces
lora_rank =32 #larger ranks are smarter but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.2-1B-Instruct",
    max_seq_length=max_seq_length,
    load_in_4bit=True,
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.8, #reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank, #8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth", #enable long context finetuning
    random_state=2025
)

==((====))==  Unsloth 2025.4.7: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A10G. Num GPUs = 1. Max memory: 22.184 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.4.7 patched 16 layers with 16 QKV layers, 16 O layers and 16 MLP layers.


## Using OpenAI's famous GSM8K dataset

In [6]:
from datasets import load_dataset
dataset = load_dataset("openai/gsm8k", "main", split = "train")
dataset

Dataset({
    features: ['question', 'answer'],
    num_rows: 7473
})

In [7]:
from IPython.display import display, Markdown

display(Markdown(dataset[0]["question"]))

Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?

In [8]:
display(Markdown(dataset[0]["answer"]))

Natalia sold 48/2 = <<48/2=24>>24 clips in May.
Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.
#### 72

In [9]:
def extract_hash_answer(text):
    if "####" not in text: return None
    return text.split("####")[1].strip()
extract_hash_answer(dataset[0]["answer"])

'72'

In [10]:
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
Place it between {reasoning_start} and {reasoning_end}.
Then, provide your solution between {solution_start}{solution_end}"""

In [11]:
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["question"]},
    ],
    "answer": extract_hash_answer(x["answer"]),
})
dataset[0]

{'question': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
 'answer': '72',
 'prompt': [{'content': 'You are given a problem.\nThink about the problem and provide your working out.\nPlace it between <start_working_out> and <end_working_out>.\nThen, provide your solution between <SOLUTION></SOLUTION>',
   'role': 'system'},
  {'content': 'Natalia sold clips to 48 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?',
   'role': 'user'}]}

In [12]:
import re

match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

In [13]:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

<re.Match object; span=(0, 71), match='<start_working_out>Let me think!<end_working_out>>

## Defining the reward functions

In [14]:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!
        score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0
        score += 0.5 if response.count(solution_start)  == 1 else -1.0
        score += 0.5 if response.count(solution_end)    == 1 else -1.0
        scores.append(score)
    return scores

def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 1.0
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.5
                else: score -= 1.5 # Penalize wrong answers
            except:
                score -= 1.5 # Penalize
        scores.append(score)
    return scores

Also sometimes it might not be 1 number as the answer, but like a sentence for example "The solution is $20" -> we extract 20.

We also remove possible commas for example as in 123,456

In [15]:
match_numbers = re.compile(
    solution_start + r".*?([\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
print(match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>"))
print(match_numbers.findall("<SOLUTION>  123,456  </SOLUTION>"))

['0.34']
['123,456']


In [16]:
global PRINTED_TIMES
PRINTED_TIMES = 0
global PRINT_EVERY_STEPS
PRINT_EVERY_STEPS = 5

def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    # Print only every few steps
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    PRINTED_TIMES += 1

    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            # Remove commas like in 123,456
            guess       = float(guess.strip().replace(",", ""))
            scores.append(1.5 if guess == true_answer else -0.5)
        except:
            scores.append(0)
            continue
    return scores

Get the max prompt length so we don't accidentally truncate it!

In [17]:
max(dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
).map(lambda x: {"length" : len(x["tokens"])})["length"])

287

## Train the model

In [18]:
max_prompt_length = 287 + 1 # + 1 just in case!

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500,
    max_grad_norm = 1.0,
    run_name="grpo_take6",
    report_to = "wandb", 
    output_dir = "grpo_outputs_take6",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


This cell will probably take around half a workday (around 4hrs) to run!

In [19]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1 | Total steps = 500
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 22,544,384/1,000,000,000 (2.25% trained)
[34m[1mwandb[0m: Currently logged in as: [33mtituslhy[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


`generation_config` default values have been modified to match model-specific defaults: {'max_length': 131072, 'top_p': 0.9, 'bos_token_id': 128000, 'eos_token_id': [128001, 128008, 128009]}. If this is not desired, please set these values explicitly.


******************** Question:
A concert ticket costs $40. Mr. Benson bought 12 tickets and received a 5% discount for every ticket bought that exceeds 10. How much did Mr. Benson pay in all? 
Answer:
476 
Response:
<start_working_out>

Mr. Benson bought 12 tickets.

He received a 5% discount for every ticket bought that exceeds 10.
- For the 11th ticket, he received a 5% discount (5% of $40 = $2) = $40 - $2 = $38.
- For the 12th ticket, he received a 5% discount (5% of $40 = $2) = $40 - $2 = $38.
- For the 13th ticket, he received a 5% discount (5% of $40 = $2) = $40 - $2 = $38.

He bought 13 tickets in total.

The total cost is:
12 tickets @ $40 = $480
11 tickets @ $38 = $418
13 tickets @ $38 = $488

The total cost is $481.

</end_working_out>

<end_working_out>
SOLUTION>

Mr. Benson paid a total of $481 for 13 tickets. 
Extracted:
None
Unsloth: Will smartly offload gradients to save VRAM!


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / match_format_exactly,rewards / match_format_approximately,rewards / check_answer,rewards / check_numbers
1,-0.0,-1.9375,1.076865,224.375,0.0,0.0,-1.9375,0.0,0.0
2,-0.0,-3.53125,0.493686,181.875,0.0,0.0,-3.53125,0.0,0.0
3,0.0,-2.59375,0.493686,192.875,0.000298,0.0,-2.59375,0.0,0.0
4,0.0,-2.5,1.09307,225.9375,0.00029,0.0,-2.5,0.0,0.0
5,0.0,-2.875,0.808013,187.625,0.000259,0.0,-2.875,0.0,0.0
6,0.0,-3.34375,0.5625,138.5625,0.000229,0.0,-3.34375,0.0,0.0
7,0.0,-2.59375,1.243686,172.25,0.000325,0.0,-2.59375,0.0,0.0
8,0.0,-2.5,1.09307,161.3125,0.000286,0.0,-2.5,0.0,0.0
9,0.0,-2.6875,0.987372,163.5,0.000257,0.0,-2.6875,0.0,0.0
10,0.0,-2.96875,1.3125,155.75,0.000332,0.0,-2.96875,0.0,0.0


******************** Question:
Rene can finish reading 30 pages in 60 minutes. Lulu can read 27 pages in 60 minutes and Cherry can read 25 pages in 60 minutes. If they have been reading for 240 minutes now, how many pages have they finished reading in total? 
Answer:
328 
Response:
To find the total number of pages read by Rene, Lulu, and Cherry, we need to first calculate the number of pages they can read in 240 minutes.

Rene can read 30 pages in 60 minutes, so in 120 minutes (240 minutes / 2), she can read 30 * 2 = 60 pages.
Lulu can read 27 pages in 60 minutes, so in 120 minutes (240 minutes / 2), she can read 27 * 2 = 54 pages.
Cherry can read 25 pages in 60 minutes, so in 120 minutes (240 minutes / 2), she can read 25 * 2 = 50 pages.

Total number of pages read = 60 + 54 + 50 = 164

So, after 240 minutes, Rene, Lulu, and Cherry have finished reading a total of 164 pages. 
Extracted:
None
******************** Question:
A jar of jellybeans has 14 blue jellybeans, 26 purple jellybea

KeyboardInterrupt: 