# Train Qwen/TinyLlama on GSM8K using GRPO (Colab Version)

This notebook replicates the local training setup for GSM8K using TRL's GRPOTrainer.
It is self-contained and works around a `RuntimeError` by pinning the `trl` version.

In [1]:
# Install dependencies
# Pinning trl to 0.12.0 to avoid regression issues
!pip install transformers torch accelerate datasets textual
!pip install trl==0.27.0
!pip install math_verify # Assuming this is the correct package name
# !pip install git+https://github.com/huggingface/trl.git@v0.12.0  # Force install from tag if pip fails or verify version

Collecting textual
  Downloading textual-7.3.0-py3-none-any.whl.metadata (9.1 kB)
Collecting rich>=14.2.0 (from textual)
  Downloading rich-14.2.0-py3-none-any.whl.metadata (18 kB)
Downloading textual-7.3.0-py3-none-any.whl (716 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m716.4/716.4 kB[0m [31m17.4 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading rich-14.2.0-py3-none-any.whl (243 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.4/243.4 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: rich, textual
  Attempting uninstall: rich
    Found existing installation: rich 13.9.4
    Uninstalling rich-13.9.4:
      Successfully uninstalled rich-13.9.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.31.0 requires rich<14,>=12.4.4, but you have rich 14.2.0 which is inco

In [2]:
from huggingface_hub import login
# Interactive Hugging Face login (displays a widget in the notebook).
# Needed for access to gated models or for pushing to the Hub.
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
import sys
import torch
import yaml
import re
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer
from datasets import load_dataset
# Import math_verify. If this fails, please check the package name or local installation.
from math_verify import LatexExtractionConfig, parse, verify, StringExtractionConfig, ExprExtractionConfig
from abc import ABC, abstractmethod

In [4]:
# Confirm which trl version is actually installed in this runtime.
!pip show trl

Name: trl
Version: 0.27.0
Summary: Train transformer language models with reinforcement learning.
Home-page: https://github.com/huggingface/trl
Author: 
Author-email: Leandro von Werra <leandro.vonwerra@gmail.com>
License: 
Location: /usr/local/lib/python3.12/dist-packages
Requires: accelerate, datasets, packaging, transformers
Required-by: 


## Environment Definitions

In [None]:
class BaseEnvironment(ABC):
    """Abstract interface that every training environment has to implement.

    A concrete environment supplies three things: the dataset, the reward
    functions used to score completions, and the system prompt.
    """

    @abstractmethod
    def get_dataset(self, config):
        """Load and process the dataset described by ``config``.

        Returns a huggingface Dataset object.
        """

    @abstractmethod
    def get_reward_functions(self):
        """Return the list of reward functions for this environment.

        Every reward function accepts ``(completions, **kwargs)`` and
        returns a list of scores.
        """

    @abstractmethod
    def get_system_prompt(self):
        """Return the system prompt to be used for this environment."""

In [6]:
class GSM8KEnvironment(BaseEnvironment):
    """GSM8K environment: builds chat-style prompts and scores completions.

    ``config`` is the notebook's configuration mapping. This class reads
    ``config['system_prompt']``, ``config['data']['prompt_column']`` and
    ``config['data']['answer_column']`` (falling back to ``'answer'``).
    """

    # Column used for the reference solution when the config does not name one.
    DEFAULT_ANSWER_COLUMN = "answer"

    def __init__(self, config):
        self.config = config

    def get_system_prompt(self):
        """Return the configured system prompt, or an empty string if unset."""
        return self.config.get('system_prompt', "")

    def make_conversation(self, example):
        """Turn one dataset row into a ``{"prompt": [system, user]}`` record."""
        system_prompt = self.get_system_prompt()
        prompt_column = self.config['data']['prompt_column']
        return {
            "prompt": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": example[prompt_column]},
            ],
        }

    def get_dataset(self, config):
        """Load the dataset named in ``config['data']`` and add a ``prompt`` column.

        Reads ``dataset_name``, and optionally ``subset`` (default ``None``) and
        ``split`` (default ``'train'``). Returns the mapped dataset.
        """
        data_conf = config['data']
        print(f"Loading dataset: {data_conf['dataset_name']}")
        dataset = load_dataset(
            data_conf['dataset_name'],
            data_conf.get('subset', None),
            split=data_conf.get('split', 'train'),
        )
        # The bound method already carries `self`; no lambda wrapper is needed.
        return dataset.map(self.make_conversation)

    @staticmethod
    def _parse_math(text):
        """Parse ``text`` with the extraction settings shared by gold and model answers."""
        return parse(
            text,
            extraction_mode="first_match",
            extraction_config=[LatexExtractionConfig(), ExprExtractionConfig(), StringExtractionConfig()],
        )

    @staticmethod
    def _gold_answer_text(solution):
        """Return the final-answer text of a GSM8K solution (the part after ``####``)."""
        if not isinstance(solution, str):
            return str(solution)
        parts = solution.split("####")
        if len(parts) > 1:
            return parts[1].strip()
        return solution.strip()  # Fallback: no marker, use the whole string

    def format_reward(self, completions, **kwargs):
        """Reward function that checks if the completion has a specific format.

        Scores 1.0 when the completion text matches
        ``<think>...</think>`` followed by ``<answer>...</answer>`` from start to
        end, otherwise 0.0. Returns one score per completion.
        """
        pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
        completion_contents = [completion[0]["content"] for completion in completions]
        return [
            1.0 if re.match(pattern, content, re.DOTALL) else 0.0
            for content in completion_contents
        ]

    def accuracy_reward(self, completions, **kwargs):
        """Reward function that checks if the completion is the same as the ground truth (GSM8K specific).

        The reference solutions are read from ``kwargs`` under the column named by
        ``config['data']['answer_column']`` (``'answer'`` when not configured).
        Returns one score per completion:

        * 0.0 when no ``<answer>...</answer>`` tag is present,
        * 1.0 when a tag is present but the reference could not be parsed,
        * otherwise ``float(verify(gold, answer))``, or 0.0 if verification raises.

        Raises ``KeyError`` if the answer column is missing from ``kwargs``.
        """
        data_conf = self.config.get('data', {})
        answer_column = data_conf.get('answer_column') or self.DEFAULT_ANSWER_COLUMN
        solutions = kwargs[answer_column]
        completion_contents = [completion[0]["content"] for completion in completions]

        rewards = []
        for content, solution in zip(completion_contents, solutions):
            gold_parsed = self._parse_math(self._gold_answer_text(solution))

            # Extract answer from the model completion (inside <answer> tags)
            answer_match = re.search(r"<answer>(.*?)</answer>", content, re.DOTALL)
            if answer_match is None:
                rewards.append(0.0)  # No answer tag found
                continue

            answer_parsed = self._parse_math(answer_match.group(1).strip())

            if len(gold_parsed) == 0:
                # Reference could not be parsed, so there is nothing to verify against.
                rewards.append(1.0)
                continue

            try:
                # Math-Verify expects the gold value first; the comparison is order-sensitive.
                rewards.append(float(verify(gold_parsed, answer_parsed)))
            except Exception:
                rewards.append(0.0)

        return rewards

    def get_reward_functions(self):
        """Return the reward functions in the order [format, accuracy]."""
        return [self.format_reward, self.accuracy_reward]

## Configuration
Adapted from `configs/config_gsm8k.yaml` for Colab usage (GPU enabled).

In [7]:
# Run configuration: model, dataset, training and generation settings, plus the system prompt.
config = {
    "model": {
        "name_or_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "torch_dtype": "auto", # "auto" = let the model loader choose the dtype; consider torch.float16 on T4
        "device_map": "auto"   # Use GPU
    },
    "data": {
        "dataset_name": "openai/gsm8k",
        "subset": "main",
        "split": "train",
        "prompt_column": "question",
        "answer_column": "answer"
    },
    "training": {
        # NOTE(review): the directory name says Qwen2-0.5B but the model above is TinyLlama.
        # The Drive-copy cell hard-codes this same name, so rename both together.
        "output_dir": "Qwen2-0.5B-GRPO-General",
        "learning_rate": 1.0e-5,
        "gradient_accumulation_steps": 16,
        "num_train_epochs": 1,
        "bf16": False, # T4 does not support bf16 well; fp16 is not configured in this notebook, add it if needed
        "logging_steps": 1,
        "save_strategy": "steps",
        "save_steps": 10,
        # NOTE(review): presumably takes precedence over num_train_epochs (usual HF Trainer behavior) — confirm
        "max_steps": 20,
        "report_to": ["tensorboard"],
        "per_device_train_batch_size": 1
    },
    "generation": {
        "max_completion_length": 256,
        "temperature": 0.7,
        "num_generations": 4,
        "max_prompt_length": 512
    },
    # The <think>/<answer> tag layout described here is what the format and accuracy reward functions look for.
    "system_prompt": "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>\n"
}

## Training Loop

In [9]:
# End-to-end training cell: build the environment, load data and model,
# configure GRPO, train, and save the result to the configured output directory.

# Initialize Environment
print("Initializing Environment...")
# Direct instantiation since we define it inline
env = GSM8KEnvironment(config)

# Load and Process Dataset
dataset = env.get_dataset(config)

# Reward Functions
reward_funcs = env.get_reward_functions()

# Load Model
model_conf = config['model']
print(f"Loading model: {model_conf['name_or_path']}")

model = AutoModelForCausalLM.from_pretrained(
    model_conf['name_or_path'],
    # transformers warns that `torch_dtype` is deprecated in favor of `dtype`.
    # The value now comes from the config instead of a hard-coded "auto".
    dtype=model_conf.get('torch_dtype', 'auto'),
    device_map=model_conf.get('device_map', 'auto'),
    use_cache=False,
)

# Training Arguments
print("Configuring training arguments...")
training_conf = config['training']
generation_conf = config['generation']

training_args = GRPOConfig(
    # Optimization / bookkeeping settings (from config['training'])
    output_dir=training_conf['output_dir'],
    learning_rate=float(training_conf['learning_rate']),
    # Keep extra dataset columns so they reach the reward functions via **kwargs.
    remove_unused_columns=False,
    gradient_accumulation_steps=training_conf.get('gradient_accumulation_steps', 1),
    num_train_epochs=training_conf.get('num_train_epochs', 1),
    bf16=training_conf.get('bf16', True),
    report_to=training_conf.get('report_to', []),
    logging_steps=training_conf.get('logging_steps', 1),
    push_to_hub=training_conf.get('push_to_hub', False),
    save_strategy=training_conf.get('save_strategy', 'steps'),
    save_steps=training_conf.get('save_steps', 10),
    max_steps=training_conf.get('max_steps', 20),
    per_device_train_batch_size=training_conf.get('per_device_train_batch_size', 1),
    # Sampling settings (from config['generation'])
    max_completion_length=generation_conf['max_completion_length'],
    num_generations=generation_conf['num_generations'],
    max_prompt_length=generation_conf.get('max_prompt_length', 128),
    temperature=generation_conf.get('temperature', 0.7),
)

# Trainer
print("Initializing GRPOTrainer...")
trainer = GRPOTrainer(
    model=model,
    reward_funcs=reward_funcs,
    args=training_args,
    train_dataset=dataset,
)

# Train
print("Starting training...")
trainer.train()

# Save Model
trainer.save_model(config['training']['output_dir'])
print(f"Model saved to {config['training']['output_dir']}")

Initializing Environment...
Loading dataset: openai/gsm8k


Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

Loading model: TinyLlama/TinyLlama-1.1B-Chat-v1.0


config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Configuring training arguments...
Initializing GRPOTrainer...




tokenizer_config.json: 0.00B [00:00, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'pad_token_id': 2}.


Starting training...


`generation_config` default values have been modified to match model-specific defaults: {'max_length': 2048}. If this is not desired, please set these values explicitly.


Step,Training Loss
1,0.0
2,0.0
3,0.0
4,0.0
5,0.0
6,0.0
7,0.0
8,0.0
9,0.0
10,0.0


Model saved to Qwen2-0.5B-GRPO-General


In [None]:
import torch

# Report the accelerator visible to this runtime: the first CUDA device's name, or "CPU" if CUDA is unavailable.
print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

Tesla T4


In [None]:
# Show driver/CUDA versions and current GPU memory usage for this runtime.
!nvidia-smi

Fri Jan 23 23:55:46 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   70C    P0             28W /   70W |    9652MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
import shutil
from pathlib import Path

from google.colab import drive

# Directory written by the training cell; must match config['training']['output_dir'].
model_dir = Path("Qwen2-0.5B-GRPO-General")
# Destination inside the mounted Drive (adjust path as needed).
drive_target = Path("/content/drive/MyDrive") / model_dir.name

# Raises if mounting fails, so nothing below runs against an unmounted path.
drive.mount('/content/drive')

if not model_dir.is_dir():
    raise FileNotFoundError(
        f"Nothing to copy: '{model_dir}' does not exist. Run the training cell first."
    )

# Copy the output folder to your Drive. Unlike `!cp`, a failure here raises,
# so the success message below is printed only when the copy actually happened.
shutil.copytree(model_dir, drive_target, dirs_exist_ok=True)
print("Model copied to Google Drive!")

ValueError: mount failed