<a href="https://colab.research.google.com/github/samipn/unsloth.ai_demo/blob/main/colab4_grpo_gsm8k_gemma1b.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Colab 4: GRPO (self-play reasoning)
Train a reasoning model with **GRPO** (Group Relative Policy Optimization) on GSM8K prompts. Model: `unsloth/gemma-3-1b-it-bnb-4bit`. The reward combines a format score (the `<reasoning>`/`<answer>` tags) with a final-answer accuracy score.
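For intuition: GRPO samples a group of completions for each prompt (`num_generations` in the config below) and scores each completion relative to its own group, so no separate value model is needed. Below is a minimal sketch of that group-relative advantage in plain Python, assuming mean/standard-deviation normalization with a small epsilon. TRL's exact formula (for example whether it divides by the std, and which epsilon it uses) can vary by version, and `group_advantages` is a hypothetical helper, not a TRL function.

```python
from statistics import mean, stdev

def group_advantages(rewards, eps=1e-4):
    """Group-relative advantages for one prompt's completions (sketch).

    Each reward is centered on the group mean and scaled by the group's
    sample standard deviation, so only differences within the group matter.
    """
    if len(rewards) < 2:
        raise ValueError("need at least two completions per group")
    mu = mean(rewards)
    sigma = stdev(rewards)
    return [(r - mu) / (sigma + eps) for r in rewards]

# One completion rewarded, one not: positive vs. negative advantage.
print(group_advantages([1.0, 0.0]))
# Identical rewards: every advantage is zero, so the group gives no learning signal.
print(group_advantages([0.0, 0.0]))
```

With a group size of 2, a prompt only contributes a gradient signal when its two completions receive different rewards. In this run's training log, `reward` is 0.0 at every logged step, so by this formula every group's advantage would be zero.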

In [1]:
#@title Install Unsloth + deps (Colab-safe)
%pip -q install --upgrade pip
%pip -q install unsloth datasets trl transformers accelerate bitsandbytes peft --no-cache-dir
import torch, platform
print("PyTorch:", torch.__version__, "CUDA:", torch.version.cuda, "Python:", platform.python_version())


PyTorch: 2.8.0+cu126 CUDA: 12.6 Python: 3.12.12


In [2]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch

MODEL_NAME = "unsloth/gemma-3-1b-it-bnb-4bit"
max_seq_length = 2048
dtype = torch.bfloat16 if is_bfloat16_supported() else torch.float16

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name=MODEL_NAME,
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=True,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=16, lora_alpha=16, lora_dropout=0.0,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    use_gradient_checkpointing="unsloth",
    random_state=3407, max_seq_length=max_seq_length,
)


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.11.2: Fast Gemma3 patching. Transformers: 4.57.1.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.318 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.8.0+cu126. CUDA: 8.0. CUDA Toolkit: 12.6. Triton: 3.4.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.32.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Gemma3 does not support SDPA - switching to fast eager.



Unsloth: Making `model.base_model.model.model` require gradients
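The trainer banner further down reports `Trainable parameters = 13,045,760`. That figure follows from LoRA's size formula: each adapted `d_in × d_out` projection gains `r · (d_in + d_out)` parameters (the low-rank matrices A and B). Here is a quick check in plain Python, assuming Gemma 3 1B's config values (hidden size 1152, MLP size 6912, 26 layers, 4 query heads, 1 KV head, head dim 256). These numbers are assumptions; read them from `model.config` to confirm on your checkpoint.

```python
# Assumed Gemma 3 1B config values; check model.config on your checkpoint.
HIDDEN_SIZE = 1152
INTERMEDIATE_SIZE = 6912
NUM_LAYERS = 26
NUM_HEADS = 4
NUM_KV_HEADS = 1
HEAD_DIM = 256
RANK = 16  # r=16, as passed to get_peft_model above

def lora_params(d_in, d_out, r=RANK):
    """LoRA adds A (r x d_in) and B (d_out x r) to one linear layer."""
    return r * (d_in + d_out)

q_dim = NUM_HEADS * HEAD_DIM      # 1024
kv_dim = NUM_KV_HEADS * HEAD_DIM  # 256

# One entry per name in target_modules.
per_layer = {
    "q_proj": lora_params(HIDDEN_SIZE, q_dim),
    "k_proj": lora_params(HIDDEN_SIZE, kv_dim),
    "v_proj": lora_params(HIDDEN_SIZE, kv_dim),
    "o_proj": lora_params(q_dim, HIDDEN_SIZE),
    "gate_proj": lora_params(HIDDEN_SIZE, INTERMEDIATE_SIZE),
    "up_proj": lora_params(HIDDEN_SIZE, INTERMEDIATE_SIZE),
    "down_proj": lora_params(INTERMEDIATE_SIZE, HIDDEN_SIZE),
}
total = NUM_LAYERS * sum(per_layer.values())
print(f"{total:,}")  # 13,045,760
```

Note that `lora_alpha` only scales the adapter output and does not change the count; `r` and `target_modules` do.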


In [13]:
# Prepare GSM8K prompts
from datasets import load_dataset

gsm8k = load_dataset("openai/gsm8k", "main", split="train[:200]")  # small slice for speed

SYSTEM = "You are a step-by-step math tutor. Think aloud inside <reasoning> tags, then give the final boxed answer in <answer> tags."
def to_prompt(ex):
    q = ex["question"].strip()
    messages = [{"role":"system","content":SYSTEM},{"role":"user","content":q}]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    # keep the full GSM8K solution (a rationale ending in '#### <answer>') for the accuracy reward
    return {"prompt": prompt, "solution": ex["answer"]}

ds = gsm8k.map(to_prompt, remove_columns=gsm8k.column_names)

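GSM8K's `answer` field is a worked rationale whose final answer follows a `####` marker, and `to_prompt` stores it unchanged in `solution`. The sketch below shows why the marker matters for the accuracy reward. If you strip every non-numeric character from the whole rationale, all intermediate numbers get concatenated. Splitting on `####` isolates the final answer instead. The sample string is illustrative GSM8K-style text, and `gsm8k_final_answer` is a hypothetical helper.

```python
import re

# Illustrative GSM8K-style answer: rationale lines, then "#### <final answer>".
sample = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)

def gsm8k_final_answer(solution):
    """Return the text after the last '####', without thousands separators."""
    return solution.split("####")[-1].strip().replace(",", "")

# Keeping only digits, '.', and '-' across the whole rationale merges every number.
naive = re.sub(r"[^0-9.-]", "", sample)
print(naive)
print(gsm8k_final_answer(sample))
print(gsm8k_final_answer("They paid 600 * 2 = 1,200 dollars.\n#### 1,200"))
```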

In [14]:
import re, torch
from trl import GRPOTrainer, GRPOConfig

def extract_answer(text):
    # look for \boxed{...} or <answer>...</answer>
    m = re.search(r"\\boxed\{([^}]+)\}|<answer>(.*?)</answer>", text, flags=re.S|re.I)
    if not m: return None
    return m.group(1) if m.group(1) is not None else m.group(2)

def format_reward(text):
    return 1.0 if ("<reasoning>" in text and "</reasoning>" in text and "<answer>" in text and "</answer>" in text) else 0.0

def accuracy_reward(text, target):
    pred = extract_answer(text)
    if pred is None: return 0.0
    # GSM8K solutions end with "#### <final answer>"; compare against that part only
    target = target.split("####")[-1]
    # crude normalization: keep only digits, '.', and '-'
    P = re.sub(r"[^0-9.-]", "", pred)
    T = re.sub(r"[^0-9.-]", "", target)
    return 1.0 if P == T and P != "" else 0.0

def reward_fn(prompts, completions, solution, **kwargs):
    # TRL calls reward functions with keyword arguments:
    #   prompts, completions: list[str] (this dataset uses plain-text prompts)
    #   solution: list[str], forwarded from the dataset's "solution" column
    #   **kwargs: anything else TRL passes (e.g. completion_ids, trainer_state)
    rewards = []
    for c, s in zip(completions, solution):
        # weight: 60% tag format, 40% final-answer accuracy
        r = 0.6 * format_reward(c) + 0.4 * accuracy_reward(c, s)
        rewards.append(r)
    return rewards

cfg = GRPOConfig(
    output_dir="outputs_grpo_gemma1b",
    per_device_train_batch_size=1,   # Unsloth raises this to a multiple of num_generations (see the log below)
    gradient_accumulation_steps=2,
    num_train_epochs=1,
    learning_rate=5e-6,
    max_completion_length=256,
    logging_steps=10,
    save_steps=200,
    bf16=torch.cuda.is_bf16_supported(),
    fp16=not torch.cuda.is_bf16_supported(),
    num_generations=2,   # group size
)

trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,   # TRL's name for the tokenizer argument
    train_dataset=ds,             # TRL reads the "prompt" column; other columns ("solution") reach reward functions as kwargs
    reward_funcs=[reward_fn],
    args=cfg,
)
trainer.train()
trainer.save_model("gemma1b_grpo_adapter")
tokenizer.save_pretrained("gemma1b_grpo_adapter")

The model is already on multiple devices. Skipping the move to device specified in `args`.


Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 2
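The banner below reports `Num examples = 200` and `Total steps = 100`, and you can reproduce those numbers by hand. Each prompt is sampled `num_generations` times, and one optimizer step consumes `per_device_train_batch_size × gradient_accumulation_steps × num_gpus` completions. With the adjusted batch size of 2, that gives 200 × 2 / (2 × 2 × 1) = 100 steps, or 2 unique prompts per step. This is a back-of-envelope sketch: `grpo_steps` and `unique_prompts_per_step` are hypothetical helpers, and the trainer's own rounding may differ in edge cases.

```python
import math

def grpo_steps(num_prompts, num_generations, per_device_bs, grad_accum, num_gpus=1, epochs=1):
    """Approximate optimizer steps for a GRPO run (sketch)."""
    # Every prompt is expanded into num_generations completions.
    completions_per_epoch = num_prompts * num_generations
    # One optimizer step consumes this many completions across devices and accumulation.
    completions_per_step = per_device_bs * grad_accum * num_gpus
    return math.ceil(completions_per_epoch / completions_per_step) * epochs

def unique_prompts_per_step(num_generations, per_device_bs, grad_accum, num_gpus=1):
    """Distinct prompts behind one optimizer step."""
    return per_device_bs * grad_accum * num_gpus // num_generations

print(grpo_steps(200, 2, 2, 2))         # values from the banner
print(unique_prompts_per_step(2, 2, 2))
```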


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 200 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 2
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 2 x 1) = 4
 "-____-"     Trainable parameters = 13,045,760 of 1,012,931,712 (1.29% trained)


| Step | Training Loss | reward | reward_std | completions / mean_length | completions / min_length | completions / max_length | completions / clipped_ratio | completions / mean_terminated_length | completions / min_terminated_length | completions / max_terminated_length | sampling / sampling_logp_difference / mean | sampling / sampling_logp_difference / max | sampling / importance_sampling_ratio / min | sampling / importance_sampling_ratio / mean | sampling / importance_sampling_ratio / max | kl | rewards / reward_fn / mean | rewards / reward_fn / std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 10 | 0.0 | 0.0 | 0.0 | 245.975 | 226.5 | 256.0 | 0.75 | 132.05 | 124.1 | 140.0 | 0 | 0 | 0 | 0 | 0 | 0.001483 | 0.0 | 0.0 |
| 20 | 0.0 | 0.0 | 0.0 | 231.9 | 203.0 | 255.2 | 0.525 | 164.783336 | 151.8 | 174.4 | No Log | No Log | No Log | No Log | No Log | 0.001859 | 0.0 | 0.0 |
| 30 | 0.0 | 0.0 | 0.0 | 243.525 | 220.5 | 256.0 | 0.725 | 147.6 | 143.7 | 151.5 | No Log | No Log | No Log | No Log | No Log | 0.001595 | 0.0 | 0.0 |
| 40 | 0.0 | 0.0 | 0.0 | 238.875 | 205.5 | 255.7 | 0.625 | 144.941669 | 128.7 | 155.4 | No Log | No Log | No Log | No Log | No Log | 0.001758 | 0.0 | 0.0 |
| 50 | 0.0 | 0.0 | 0.0 | 239.075 | 217.6 | 256.0 | 0.7 | 121.233334 | 115.2 | 127.2 | No Log | No Log | No Log | No Log | No Log | 0.001506 | 0.0 | 0.0 |
| 60 | 0.0 | 0.0 | 0.0 | 241.15 | 219.5 | 256.0 | 0.75 | 122.1 | 117.1 | 126.6 | No Log | No Log | No Log | No Log | No Log | 0.001435 | 0.0 | 0.0 |
| 70 | 0.0 | 0.0 | 0.0 | 236.45 | 205.0 | 256.0 | 0.65 | 191.866667 | 179.4 | 204.2 | No Log | No Log | No Log | No Log | No Log | 0.002333 | 0.0 | 0.0 |
| 80 | 0.0 | 0.0 | 0.0 | 248.1 | 230.6 | 256.0 | 0.7 | 136.366667 | 128.2 | 144.5 | No Log | No Log | No Log | No Log | No Log | 0.001558 | 0.0 | 0.0 |
| 90 | 0.0 | 0.0 | 0.0 | 237.0 | 207.4 | 256.0 | 0.725 | 134.85 | 130.6 | 139.1 | No Log | No Log | No Log | No Log | No Log | 0.001505 | 0.0 | 0.0 |
| 100 | 0.0 | 0.0 | 0.0 | 234.425 | 195.9 | 255.7 | 0.625 | 157.5 | 144.7 | 166.8 | No Log | No Log | No Log | No Log | No Log | 0.001531 | 0.0 | 0.0 |


('gemma1b_grpo_adapter/tokenizer_config.json',
 'gemma1b_grpo_adapter/special_tokens_map.json',
 'gemma1b_grpo_adapter/chat_template.jinja',
 'gemma1b_grpo_adapter/tokenizer.model',
 'gemma1b_grpo_adapter/added_tokens.json',
 'gemma1b_grpo_adapter/tokenizer.json')
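The accuracy reward compares answers as digit strings. Under that rule, `72.0` vs `72` scores 0, and so does an answer that contains more than one number, even when the intended value is right. A hedged alternative is sketched below: pull the last number out of each string and compare numerically. `last_number` and `answers_match` are hypothetical helpers. The regex assumes plain decimal numbers (no fractions or scientific notation), and removing commas assumes they are thousands separators.

```python
import re

NUMBER = re.compile(r"-?\d+(?:\.\d+)?")

def last_number(text):
    """Return the last number in text as a float, or None if there is none."""
    # Drop '$' and thousands separators before matching.
    cleaned = text.replace(",", "").replace("$", "")
    found = NUMBER.findall(cleaned)
    return float(found[-1]) if found else None

def answers_match(pred, target, tol=1e-6):
    """True when both strings contain a number and their last numbers agree within tol."""
    p, t = last_number(pred), last_number(target)
    return p is not None and t is not None and abs(p - t) <= tol

print(answers_match("72.0", "72"))
print(answers_match("$1,200", "#### 1200"))
print(answers_match("The answer is 72 clips.", "72"))
print(answers_match("71", "72"))
```

A GSM8K target's last number is the one after `####`, so `answers_match` also works when handed the full rationale.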

In [15]:
# Quick inference helper
from unsloth import FastLanguageModel
import torch

FastLanguageModel.for_inference(model)  # switch Unsloth into inference mode for faster generation

system_prompt = "You are a helpful assistant."

def chat(prompt, history=None, max_new_tokens=128):
    # history: optional list of {"role": ..., "content": ...} messages from earlier turns
    if history is None: history = []
    messages = [{"role": "system", "content": system_prompt}] + history + [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer([text], return_tensors="pt").to(model.device)
    with torch.no_grad():
        out = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # decodes the full sequence (prompt + reply), so the prompt is echoed in the output
    print(tokenizer.decode(out[0], skip_special_tokens=True))

chat("Say hi in one sentence.")


user
You are a helpful assistant.

Say hi in one sentence.
model
Hello there! How can I help you today?
