# Gemma 3: GRPO and RL 
This notebook uses `TRL` and `GRPOTrainer` to make Gemma3 think before it answers.
It generates a LoRA that can be adapted to Gemma3.

In [3]:
!pip install -qqq git+https://github.com/huggingface/transformers@v4.49.0-Gemma-3 \
                  git+https://github.com/huggingface/trl.git@main \
                  bitsandbytes

In [4]:
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from peft import LoraConfig, get_peft_model
import types
import re
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from transformers import pipeline

# Model checkpoint
ckpt = "../../models/gemma-3-27b-pt"

# Load the model
model = AutoModelForImageTextToText.from_pretrained(
    ckpt, 
    device_map="auto", 
    torch_dtype=torch.bfloat16, 
    attn_implementation="eager"
)

# Add LoRA Adaptation
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
print(model.print_trainable_parameters())

# Load processor and ensure tokenizer can handle <think> tokens
processor = AutoProcessor.from_pretrained(ckpt)

# Add <think> as special tokens to prevent tokenization issues
special_tokens = {"additional_special_tokens": ["<think>", "</think>", "<answer>", "</answer>"]}
processor.tokenizer.add_special_tokens(special_tokens)
model.resize_token_embeddings(len(processor.tokenizer))

Loading checkpoint shards:   0%|          | 0/12 [00:00<?, ?it/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.48, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


trainable params: 122,211,840 || all params: 27,554,618,480 || trainable%: 0.4435
None


Gemma3TextScaledWordEmbedding(262149, 5376, padding_idx=0)

In [5]:
import re
from datasets import load_dataset, Dataset

# Convert OpenAI GSM8k dataset into reasoning chains 
# https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb

SYSTEM_PROMPT = """
Respond in the following format:
<think>
...
</think>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<think>
{think}
</think>
<answer>
{answer}
</answer>
"""

# Extract final answer
def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

# Extract answer from GSM8k dataset
def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# Load OpenAI GSM8K dataset and format into <think> style
def get_gsm8k_questions(split="train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    })
    return data

# Load dataset and format labels for training
dataset = get_gsm8k_questions()
dataset = dataset.map(lambda example: {"labels": example["answer"]})

In [6]:
# Reward modelling 

# Reward functions for GRPO training
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

# Ensure tokenizer special tokens match the GRPO training format
processor.pad_token_id = 0
processor.bos_token_id = 1
processor.eos_token_id = 2

# Define chat template
chat_template = """<bos><start_of_turn>user
{user_message}<end_of_turn>
<start_of_turn>model
<think>{reasoning}</think>
<answer>{final_answer}</answer><end_of_turn>"""

# Apply chat template
processor.chat_template = chat_template

# Configure GRPO training
max_prompt_length = 256
max_seq_length = 1024

training_args = GRPOConfig(
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=1,
    num_generations=2,
    max_prompt_length=max_prompt_length,
    max_completion_length=max_seq_length - max_prompt_length,
    num_train_epochs=1,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
)

# Train the model using GRPO
trainer = GRPOTrainer(
    model=model,
    processing_class=processor,
    reward_funcs=[
        soft_format_reward_func,
        strict_format_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)

trainer.train()


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


TrainOutput(global_step=250, training_loss=0.0, metrics={'train_runtime': 27436.0421, 'train_samples_per_second': 0.018, 'train_steps_per_second': 0.009, 'total_flos': 0.0, 'train_loss': 0.0})

In [7]:
trainer.push_to_hub("tobrun/gemma3-27b-LoRA-reasoning")

adapter_model.safetensors:   0%|          | 0.00/489M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

training_args.bin:   0%|          | 0.00/5.94k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

Upload 4 LFS files:   0%|          | 0/4 [00:00<?, ?it/s]

CommitInfo(commit_url='https://huggingface.co/tobrun/trainer_output/commit/e1dd0de049dd1d5b7da426532340180dbb32a6bd', commit_message='tobrun/gemma3-27b-LoRA-reasoning', commit_description='', oid='e1dd0de049dd1d5b7da426532340180dbb32a6bd', pr_url=None, repo_url=RepoUrl('https://huggingface.co/tobrun/trainer_output', endpoint='https://huggingface.co', repo_type='model', repo_id='tobrun/trainer_output'), pr_revision=None, pr_num=None)

In [8]:
from transformers import pipeline

# Generate output from trained model
question = "You have two ropes that each take exactly one hour to burn from one end to the other. However, the ropes do not burn at a uniform rate. Using these ropes, how can you measure exactly 45 minutes?"
generator = pipeline("text-generation", model=trainer.model, tokenizer=processor.tokenizer)

# Format input properly for chat
input_text = processor.apply_chat_template([{"role": "user", "content": question}]) + "<think>"

# Generate response
output = generator(input_text, max_new_tokens=1024)
print(output)

Device set to use cuda:0
The model 'PeftModelForCausalLM' is not supported for text-generation. Supported models are ['AriaTextForCausalLM', 'BambaForCausalLM', 'BartForCausalLM', 'BertLMHeadModel', 'BertGenerationDecoder', 'BigBirdForCausalLM', 'BigBirdPegasusForCausalLM', 'BioGptForCausalLM', 'BlenderbotForCausalLM', 'BlenderbotSmallForCausalLM', 'BloomForCausalLM', 'CamembertForCausalLM', 'LlamaForCausalLM', 'CodeGenForCausalLM', 'CohereForCausalLM', 'Cohere2ForCausalLM', 'CpmAntForCausalLM', 'CTRLLMHeadModel', 'Data2VecTextForCausalLM', 'DbrxForCausalLM', 'DiffLlamaForCausalLM', 'ElectraForCausalLM', 'Emu3ForCausalLM', 'ErnieForCausalLM', 'FalconForCausalLM', 'FalconMambaForCausalLM', 'FuyuForCausalLM', 'GemmaForCausalLM', 'Gemma2ForCausalLM', 'Gemma3ForCausalLM', 'Gemma3ForCausalLM', 'GitForCausalLM', 'GlmForCausalLM', 'GotOcr2ForConditionalGeneration', 'GPT2LMHeadModel', 'GPT2LMHeadModel', 'GPTBigCodeForCausalLM', 'GPTNeoForCausalLM', 'GPTNeoXForCausalLM', 'GPTNeoXJapaneseForCaus

[{'generated_text': '<bos><start_of_turn>user\n{user_message}<end_of_turn>\n<start_of_turn>model\n<think>{reasoning}</think>\n<answer>{final_answer}</answer><end_of_turn><think>\n""")\n\n# Define the prompt template\nprompt_template = PromptTemplate(\n    input_variables=["user_message", "reasoning", "final_answer"],\n    template=prompt_template_str\n)\n\n# Create a chain\nchain = LLMChain(llm=llm, prompt=prompt_template)\n\n# Run the chain\nresponse = chain.run(user_message="What is the capital of France?", reasoning="Paris is the capital of France.", final_answer="Paris")\n\n# Print the response\nprint(response)\n'}]
