# Finetune LLMs with GRPO

This notebook shows how to finetune an LLM with GRPO, using the `trl` library.

It's by [Ben Burtenshaw](https://huggingface.co/burtenshaw) and [Maxime Labonne](https://huggingface.co/mlabonne).

This is a minimal example. For a complete example, refer to the GRPO chapter in the [course](https://huggingface.co/course/en/chapter12/1).

## Install dependencies

In [1]:
!pip install -qqq datasets==3.2.0 transformers==4.47.1 trl==0.14.0 peft==0.14.0 accelerate==1.2.1 bitsandbytes==0.45.2 wandb==0.19.7 --progress-bar off
!pip install -qqq flash-attn --no-build-isolation --progress-bar off

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.
  Preparing metadata (setup.py) ... done
  Building wheel for flash-attn (setup.py) ... done


## Load Dataset

In [2]:
import torch
import wandb
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import GRPOConfig, GRPOTrainer

# Log to Weights & Biases
wandb.login()

# Load dataset
dataset_id = "mlabonne/smoltldr"
dataset = load_dataset(dataset_id)
print(dataset)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33meagle0504[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


README.md:   0%|          | 0.00/981 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.44M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/151k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/151k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2000 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/200 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/200 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
    test: Dataset({
        features: ['prompt', 'completion'],
        num_rows: 200
    })
})


## Load Model

In [3]:
# Load model
model_id = "HuggingFaceTB/SmolLM-135M-Instruct"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    attn_implementation="flash_attention_2",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Attach LoRA adapters (rank 16) to every linear layer
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    target_modules="all-linear",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

trainable params: 4,884,480 || all params: 139,399,488 || trainable%: 3.5039
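As a sanity check on the `trainable params` figure: LoRA adds two low-rank matrices, A (r × d_in) and B (d_out × r), to each targeted linear layer, i.e. r · (d_in + d_out) parameters per layer. The sketch below assumes the SmolLM-135M layer shapes from its published config (hidden size 576, MLP size 1536, 30 decoder layers, 9 query heads and 3 key/value heads of dimension 64) and assumes `target_modules="all-linear"` covers the seven projections of each decoder layer but not the output head.

```python
# Count LoRA parameters for r=16 on every linear projection of a Llama-style
# decoder. The shapes are ASSUMPTIONS taken from the SmolLM-135M config,
# not read from the loaded model.
R = 16
HIDDEN = 576          # hidden_size
INTERMEDIATE = 1536   # MLP intermediate_size
NUM_LAYERS = 30       # num_hidden_layers
HEAD_DIM = 64         # hidden_size / num_attention_heads (576 / 9)
NUM_KV_HEADS = 3      # grouped-query attention: 3 key/value heads
KV_DIM = HEAD_DIM * NUM_KV_HEADS  # 192

# (in_features, out_features) of each linear layer in one decoder block
projections = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(d_in: int, d_out: int, r: int) -> int:
    """A has shape (r, d_in) and B has shape (d_out, r)."""
    return r * d_in + d_out * r


per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in projections.values())
total = per_layer * NUM_LAYERS
print(f"{total:,}")
```

Under these assumptions the count comes out to 4,884,480, the same number `print_trainable_parameters()` reports above, which suggests the assumed shapes and target modules are right.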


## Define Reward Function

In [4]:
# Reward completions whose length is close to 50 characters (0 is the best score)
def reward_len(completions, **kwargs):
    return [-abs(50 - len(completion)) for completion in completions]
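`reward_len` measures length in characters, not tokens. It gives its maximum of 0 to a completion of exactly 50 characters, and every character above or below 50 costs one point. A quick check on made-up strings (the inputs are illustrative only):

```python
def reward_len(completions, **kwargs):
    # Same function as in the cell above: 0 at exactly 50 characters,
    # minus one point per character away from 50
    return [-abs(50 - len(completion)) for completion in completions]


samples = ["", "x" * 50, "x" * 80, "A cat is a small domesticated carnivorous mammal."]
for text, reward in zip(samples, reward_len(samples)):
    print(f"{len(text):3d} chars -> reward {reward}")
```

The reward is unbounded below, so a long, rambling completion is penalized far more heavily than a slightly short one.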

In [5]:
import re


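def reward_format(completions, **kwargs):
    # 1.0 if the whole completion is "<think>...</think><answer>...</answer>", else 0.0.
    # re.DOTALL lets "." match newlines inside the tags.
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    return [1.0 if re.match(pattern, c, flags=re.DOTALL) else 0.0 for c in completions]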
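A detail of the format pattern worth checking: without `re.DOTALL`, `.` does not match newline characters, so a completion that breaks a line inside `<think>` scores 0 even though it follows the format. The sketch compares both behaviours on two made-up completions. Note also that this notebook does not add any instruction asking the model to produce these tags, so for a small model this reward may rarely be non-zero.

```python
import re

PATTERN = r"^<think>.*?</think><answer>.*?</answer>$"

# Made-up completions for illustration
single_line = "<think>short reasoning</think><answer>A cat is a small pet.</answer>"
multi_line = "<think>first thought\nsecond thought</think><answer>A cat is a small pet.</answer>"


def matches(text: str, flags: int = 0) -> bool:
    # re.match anchors at the start of the string; "$" anchors the end
    return re.match(PATTERN, text, flags) is not None


for name, text in [("single_line", single_line), ("multi_line", multi_line)]:
    print(name, "default:", matches(text), "DOTALL:", matches(text, re.DOTALL))
```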

## Define Training Arguments

In [6]:
%%time

# Training arguments
training_args = GRPOConfig(
    output_dir="GRPO",  # Directory where the model outputs will be saved.
    learning_rate=2e-5,  # Sets the learning rate for training.
    per_device_train_batch_size=2,  # Number of prompts per device per step (trl 0.14 samples num_generations completions for each).
    gradient_accumulation_steps=1,  # Number of steps to accumulate gradients before updating model weights.
    max_prompt_length=512,  # Maximum number of tokens for the prompts.
    max_completion_length=96,  # Maximum number of tokens for the completions.
    num_generations=8,  # Number of completions sampled per prompt (the GRPO group size).
    optim="adamw_8bit",  # AdamW with 8-bit optimizer states (via bitsandbytes).
    num_train_epochs=1,  # Total number of training epochs.
    bf16=True,  # Use bfloat16 mixed precision training.
    report_to=["wandb"],  # Log metrics to Weights & Biases.
    remove_unused_columns=False,  # Keep dataset columns that the model's forward pass does not use.
    logging_steps=1,  # Log training metrics every step.
)

# Trainer
trainer = GRPOTrainer(
    model=model,  # The model to be trained.
    reward_funcs=[
        reward_len,  # Rewards completions close to 50 characters long.
        reward_format,  # Rewards the <think>...</think><answer>...</answer> format.
    ],  # Reward functions; their scores are combined into one reward per completion.
    args=training_args,  # Training arguments defined above.
    train_dataset=dataset["train"],  # Training dataset.
)

# Initialize Weights and Biases for tracking experiments
wandb.init(project="GRPO")

# Start training the model
trainer.train()



| Step | Training Loss |
|------|---------------|
| 1 | -0.0 |
| 2 | 0.0 |
| 3 | 0.0001 |
| 4 | 0.0 |
| 5 | 0.0 |
| 6 | 0.0001 |
| 7 | 0.0001 |
| 8 | 0.0001 |
| 9 | 0.0 |
| 10 | 0.0001 |


CPU times: user 1h 16min 45s, sys: 7.54 s, total: 1h 16min 52s
Wall time: 1h 16min 43s


TrainOutput(global_step=1000, training_loss=0.028103687996044756, metrics={'train_runtime': 4601.9877, 'train_samples_per_second': 0.435, 'train_steps_per_second': 0.217, 'total_flos': 0.0, 'train_loss': 0.028103687996044756})
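For intuition on what `num_generations=8` does: GRPO samples a group of completions for each prompt, scores every completion, and uses each completion's reward *relative to its group* as the advantage. The simplified sketch below assumes that trl 0.14 sums the reward functions per completion and then normalizes within each group by the mean and the sample standard deviation plus a small epsilon. The reward values are hypothetical.

```python
import statistics

NUM_GENERATIONS = 8
EPS = 1e-4  # ASSUMED small constant that avoids dividing by zero when all rewards in a group are equal


def group_advantages(rewards_per_func, num_generations=NUM_GENERATIONS):
    """rewards_per_func: one list of rewards per reward function, aligned by completion."""
    # ASSUMPTION: reward functions are combined by plain summation
    totals = [sum(scores) for scores in zip(*rewards_per_func)]
    advantages = []
    for start in range(0, len(totals), num_generations):
        group = totals[start:start + num_generations]
        mean = statistics.mean(group)
        std = statistics.stdev(group)  # sample standard deviation
        advantages.extend((r - mean) / (std + EPS) for r in group)
    return advantages


# Hypothetical rewards for one prompt with 8 sampled completions
length_rewards = [-5, -40, 0, -12, -30, -2, -60, -8]
format_rewards = [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
adv = group_advantages([length_rewards, format_rewards])
print([round(a, 2) for a in adv])
```

Completions better than their group average get a positive advantage and are reinforced, while worse ones are pushed down. A group where every completion scores the same contributes no signal. Because `reward_len` spans tens of points while `reward_format` is only 0 or 1, the summed reward in this notebook is dominated by length; rescaling one of the functions is one way to rebalance them.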

## Push Model to Hub

In [7]:
dataset_id = dataset_id.replace("/", "-")
model_id = model_id.replace("/", "-")

print(dataset_id)
print(model_id)

mlabonne-smoltldr
HuggingFaceTB-SmolLM-135M-Instruct


In [8]:
# Save model
hf_username = "eagle0504"  # Replace with your own Hugging Face username
repo_name = f"{hf_username}/finetune-{dataset_id}-using-{model_id}"

# Merge the LoRA adapter weights into the base model and remove the PEFT wrappers
merged_model = trainer.model.merge_and_unload()

# Push the merged model to the Hugging Face Hub (requires a write token, e.g. via `huggingface-cli login`)
merged_model.push_to_hub(repo_name, private=False)

# Push the tokenizer to the same repository on the Hub
tokenizer.push_to_hub(repo_name, private=False)

# Construct the URL to the model on Hugging Face Models Hub
model_url = f"https://huggingface.co/{repo_name}"

# Print the URL to access the model
print(f"Model merged and pushed to hub. Access it at {model_url}")


No files have been modified since last commit. Skipping to prevent empty commit.


Model merged and pushed to hub. Access it at https://huggingface.co/eagle0504/finetune-mlabonne-smoltldr-using-HuggingFaceTB-SmolLM-135M-Instruct
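`merge_and_unload()` can fold the adapter into the base weights because a LoRA layer computes W x + (α / r) · B A x, which equals (W + (α / r) · B A) x. With this notebook's `lora_alpha=32` and `r=16`, the scale is 2. The toy sketch below uses tiny hand-picked matrices (hypothetical values, rank 1 instead of 16, not the real model weights) to check that the merged weight produces the same output as the base-plus-adapter path.

```python
def matmul(a, b):
    # Plain-Python product of a (n x k) and b (k x m)
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def add(a, b, scale=1.0):
    # Element-wise a + scale * b
    return [[x + scale * y for x, y in zip(row_a, row_b)] for row_a, row_b in zip(a, b)]


SCALE = 32 / 16  # lora_alpha / r from the LoraConfig above

# Hypothetical toy matrices: d_out = d_in = 3, adapter rank 1
W = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0], [3.0, 1.0, 1.0]]  # frozen base weight (d_out x d_in)
A = [[0.5, -1.0, 0.25]]                                   # LoRA A (rank x d_in)
B = [[1.0], [0.0], [-2.0]]                                # LoRA B (d_out x rank)
x = [[1.0], [2.0], [3.0]]                                 # input column vector

# Unmerged: base path plus the scaled low-rank path
unmerged = add(matmul(W, x), matmul(B, matmul(A, x)), SCALE)
# Merged: fold the update into the weight once, then a single matmul
merged = matmul(add(W, matmul(B, A), SCALE), x)
print(unmerged, merged)
```

After merging, inference needs only the single dense matmul, which is why the pushed model can be loaded as a plain `transformers` model without PEFT.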


## Generate Text

In [None]:
prompt = """
# A long document about the Cat

The cat (Felis catus), also referred to as the domestic cat or house cat, is a small
domesticated carnivorous mammal. It is the only domesticated species of the family Felidae.
Advances in archaeology and genetics have shown that the domestication of the cat occurred
in the Near East around 7500 BC. It is commonly kept as a pet and farm cat, but also ranges
freely as a feral cat avoiding human contact. It is valued by humans for companionship and
its ability to kill vermin. Its retractable claws are adapted to killing small prey species
such as mice and rats. It has a strong, flexible body, quick reflexes, and sharp teeth,
and its night vision and sense of smell are well developed. It is a social species,
but a solitary hunter and a crepuscular predator. Cat communication includes
vocalizations—including meowing, purring, trilling, hissing, growling, and grunting—as
well as body language. It can hear sounds too faint or too high in frequency for human ears,
such as those made by small mammals. It secretes and perceives pheromones.
"""

messages = [
    {"role": "user", "content": prompt},
]

In [None]:
# Generate text
from transformers import pipeline

generator = pipeline("text-generation", model="eagle0504/finetune-mlabonne-smoltldr-using-HuggingFaceTB-SmolLM-135M-Instruct")

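# Or use the merged model and tokenizer from the cells above
# generator = pipeline("text-generation", model=merged_model, tokenizer=tokenizer)

generate_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "temperature": 0.5,
    "min_p": 0.1,
}

# Pass the sampling settings to the pipeline; otherwise they are ignored
generated_text = generator(messages, **generate_kwargs)

# With chat-style input the pipeline returns the whole conversation; the last message is the model's reply
print(generated_text[0]["generated_text"][-1]["content"])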
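The `generate_kwargs` above enable sampling with `temperature=0.5` and `min_p=0.1`. As a sketch of the min-p idea, assuming (as in `transformers`) that temperature scaling is applied to the logits before the min-p filter: after the softmax, any token whose probability is below `min_p` times the top token's probability is dropped, and sampling happens over the renormalized remainder. The four-token logits are made up.

```python
import math


def min_p_filter(logits, temperature=0.5, min_p=0.1):
    """Return {token_index: probability} for the tokens that survive min-p filtering."""
    scaled = [z / temperature for z in logits]       # temperature scaling
    top = max(scaled)
    exps = [math.exp(s - top) for s in scaled]       # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    threshold = min_p * max(probs)                   # cutoff scales with the top probability
    kept = {i: p for i, p in enumerate(probs) if p >= threshold}
    norm = sum(kept.values())
    return {i: p / norm for i, p in kept.items()}    # renormalize over the survivors


logits = [2.0, 1.0, 0.0, -3.0]  # hypothetical next-token logits
print(min_p_filter(logits))
```

Because the cutoff is relative to the most likely token, min-p prunes aggressively when the model is confident and keeps more candidates when the distribution is flat; raising the temperature flattens the distribution and lets more tokens through.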