# Test Notebook

Implement a simple SFT and PPO training pipeline for fine-tuning the Qwen2.5-7B-Instruct model on the GSM8K dataset.

1. Load the Qwen2.5-7B-Instruct model and tokenizer.
2. Load the GSM8K dataset from `openai/gsm8k`.
3. Split the dataset into training and validation sets.
4. Implement Supervised Fine-Tuning (SFT) on the training set using peft (LoRA).
5. Implement Proximal Policy Optimization (PPO) on the SFT model using trl.

***TODO***:

1. Fix bugs
2. Save trained models to specific paths
3. Set up GPU devices

## Imports

In [1]:
%pip install -U transformers datasets evaluate peft trl bitsandbytes

Collecting datasets
  Using cached datasets-4.4.1-py3-none-any.whl.metadata (19 kB)
Collecting evaluate
  Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)
Collecting peft
  Downloading peft-0.18.0-py3-none-any.whl.metadata (14 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.48.2-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Using cached pyarrow-22.0.0.tar.gz (1.2 MB)
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[31mERROR: Could not install packages due to an OSError: [Errno 2] No such file or directory: '/home/bowenyu/miniconda3/envs/rl-lora/lib/python3.10/site-packages/certifi-2025.10.5.dist-info/METADATA'
[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [2]:
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Trainer,
    TrainingArguments,
)
from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer

  from .autonotebook import tqdm as notebook_tqdm


## Global Config

In [3]:
# Default to the CPU and switch to CUDA only when PyTorch reports a usable GPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

  return torch._C._cuda_getDeviceCount() > 0


## Model

In [None]:
MODEL_NAME = "Qwen/Qwen2.5-7B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = "right"  # during training, right padding is needed
# Only fall back to EOS as the pad token when the tokenizer ships without one.
# Unconditionally aliasing pad to EOS makes the two indistinguishable, which
# risks real end-of-sequence tokens being treated as padding later on.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# `device_map="auto"` already places the weights on the available device(s).
# Chaining `.to(device)` on top is redundant, and on a model that accelerate has
# sharded or offloaded it either undoes that placement or raises.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,
    device_map="auto",
)

# LoRA hyper-parameters: rank-16 adapters (alpha 32, dropout 0.05) on the four
# attention projection matrices, with no bias terms adapted.
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # adjust to Qwen's naming
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Attach the LoRA adapters to the loaded base model; this is the model SFT trains.
sft_model = get_peft_model(model=model, peft_config=lora_config)

Fetching 4 files:   0%|          | 0/4 [00:11<?, ?it/s]


## Dataset

In [None]:
gsm8k = load_dataset("openai/gsm8k", "main")

def format_example(ex):
    """
    Build the full SFT training text (prompt + reference solution) for one GSM8K row.

    The instruction wording is kept identical to the prompts built by
    `make_prompt_only` (PPO) and `evaluate_model`, so the model is trained on
    the same prompt it is later sampled and evaluated with.

    Args:
        ex (dict): a GSM8K row with "question" and "answer" fields.

    Returns:
        dict: {"text": prompt + answer + EOS}.
    """
    prompt = (
        "You are a helpful math tutor. Solve the following problem step by step.\n"
        "Show your reasoning clearly, and end with \"#### <answer>\".\n\n"
        f"Question:\n{ex['question']}\n\nAnswer:\n"
    )
    # GSM8K answer already ends with '#### <ans>'
    target = ex["answer"]
    # Terminate every example with EOS so the fine-tuned model learns to stop
    # after the final answer instead of generating until max_new_tokens.
    full_text = prompt + target + (tokenizer.eos_token or "")
    return {"text": full_text}

gsm8k = gsm8k.map(format_example)
train_data = gsm8k["train"]
test_data = gsm8k["test"]

def tokenize_fn(ex):
    """
    Tokenize the "text" field and mirror the token ids into "labels".

    Sequences longer than 1024 tokens are truncated. No padding is applied here.

    Args:
        ex: a row (or batch of rows) exposing a "text" field.

    Returns:
        The tokenizer output with an added "labels" entry copied from "input_ids".
    """
    encoded = tokenizer(ex["text"], max_length=1024, truncation=True)
    # Causal-LM objective: labels are a (shallow) copy of the input ids.
    encoded["labels"] = list(encoded["input_ids"])
    return encoded

# Tokenize both splits in batches, dropping the original columns so that only
# the tokenizer outputs and the labels remain.
tokenized_train = train_data.map(
    tokenize_fn,
    batched=True,
    remove_columns=train_data.column_names,
)
# tokenized_train = tokenized_train.select(range(5000))  # Optionally limit to first 5000 samples for quicker training
tokenized_test = test_data.map(
    tokenize_fn,
    batched=True,
    remove_columns=test_data.column_names,
)

## Supervised Fine-Tuning (SFT)

In [None]:
# Effective batch size per device is 4 * 8 = 32 sequences per optimizer step.
training_args = TrainingArguments(
    output_dir="./qwen-gsm8k-sft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    num_train_epochs=2,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    bf16=True,
    logging_steps=20,
    save_strategy="epoch",
    eval_strategy="epoch",  # `evaluation_strategy` was renamed to `eval_strategy`
)

# tokenize_fn does not pad, so examples have different lengths. The collator
# pads input_ids/attention_mask per batch and pads labels with -100 so that
# padded positions are ignored by the loss.
sft_collator = DataCollatorForSeq2Seq(
    tokenizer=tokenizer,
    padding=True,
    label_pad_token_id=-100,
)

trainer = Trainer(
    model=sft_model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    data_collator=sft_collator,
)
trainer.train()

In [None]:
# Write the SFT model (through the Trainer) and the tokenizer to the same
# directory so that both can be found in one place.
SFT_ADAPTER_DIR = "./qwen-gsm8k-sft-lora"
trainer.save_model(SFT_ADAPTER_DIR)
tokenizer.save_pretrained(SFT_ADAPTER_DIR)

## Proximal Policy Optimization (PPO)

In [None]:
# Setup: PPO uses two copies of the checkpoint, a trainable policy and a frozen
# reference, both loaded through the value-head wrapper.

# Loading options shared by both copies.
_value_head_load_kwargs = {"torch_dtype": torch.bfloat16, "device_map": "auto"}

# policy model with value head
policy_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    MODEL_NAME, **_value_head_load_kwargs
)

# reference model (frozen): switched to eval mode, gradients disabled on every parameter
ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(
    MODEL_NAME, **_value_head_load_kwargs
)
ref_model.eval()
ref_model.requires_grad_(False)

In [None]:
# config (TRL legacy PPO API)
ppo_config = PPOConfig(
    batch_size=64,          # queries per PPO step
    mini_batch_size=16,     # microbatching (replaces the deprecated `forward_batch_size`)
    learning_rate=1e-5,
    log_with=None,          # wandb if you want
    ppo_epochs=4,
    kl_penalty="kl",
    init_kl_coef=0.1,       # the legacy PPOConfig names the initial KL coefficient `init_kl_coef`
    target_kl=0.1,
)

# The PPOTrainer itself is constructed in the next cell, once the dataset with
# the "prompt"/"answer" columns exists. Building an extra trainer here on the
# raw `train_data` was redundant: it was immediately replaced, and it wrapped
# the policy a second time.

In [None]:
def make_prompt_only(ex):
    """
    Build the generation prompt for one GSM8K row, without the reference solution.

    Args:
        ex (dict): a row with "question" and "answer" fields.

    Returns:
        dict: {"prompt": instruction + question, "answer": the untouched reference answer}.
    """
    instructions = (
        "You are a helpful math tutor. Solve the following problem step by step.\n"
        "Show your reasoning clearly, and end with \"#### <answer>\".\n\n"
    )
    question_block = f"Question:\n{ex['question']}\n\nAnswer:\n"
    return {"prompt": instructions + question_block, "answer": ex["answer"]}

# Add "prompt" and "answer" columns to every training row via make_prompt_only.
ppo_gsm = gsm8k["train"].map(make_prompt_only)
# Build the PPO trainer around the policy/reference pair and the shared tokenizer.
# NOTE(review): the keyword names used here (`config`, `tokenizer`, `dataset`)
# look like TRL's legacy PPOTrainer signature; confirm against the installed trl version.
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=policy_model,
    ref_model=ref_model,
    tokenizer=tokenizer,
    dataset=ppo_gsm,
)

In [None]:
def extract_final_answer(text: str):
    """
    Extract the final answer (in the form #### <answer>) from the given text.

    The first "#### <number>" occurrence is used. Thousands separators
    ("1,000") and a leading dollar sign ("$18") are accepted and stripped, so
    the returned string contains only an optional sign, digits and an optional
    decimal part.

    Args:
        text (str): generated text from the model, or a GSM8K reference answer.

    Returns:
        The extracted final answer as a string, or None if not found
        (including when `text` is None or empty).
    """
    if not text:
        return None
    # Look for pattern #### <number>; TODO: handle more complex answers
    # `(?:,\d{3})*` consumes well-formed thousands groups only, so a trailing
    # comma after the number ("#### 18, which ...") is not swallowed.
    m = re.search(r"####\s*\$?\s*([-+]?\d+(?:,\d{3})*(?:\.\d+)?)", text)
    if m:
        return m.group(1).replace(",", "")
    return None

def correctness_reward(generated: str, gold_answer: str) -> float:
    """
    Compute the correctness reward based on final answers. If the final answer
    extracted from the generated text matches the gold answer, return 1.0, else 0.0.

    Answers are compared by numeric value rather than by exact string, so
    "5", "5.0" and "+5" all count as the same answer.

    Args:
        generated (str): The generated text from the model.
        gold_answer (str): The ground truth answer.

    Returns:
        float: 1.0 if answers match, else 0.0 (also 0.0 when either side has
        no extractable answer).
    """
    # TODO: Use Math-Verify for more robust evaluation
    gold = extract_final_answer(gold_answer)
    pred = extract_final_answer(generated)
    if gold is None or pred is None:
        return 0.0
    # Both strings were matched by the numeric regex above, so float() cannot fail.
    return 1.0 if float(gold) == float(pred) else 0.0

In [None]:
ppo_dataset = ppo_gsm.select(range(min(2000, len(ppo_gsm))))  # tiny subset

# Sampling settings used to roll out the policy.
generation_kwargs = {
    "max_new_tokens": 256,
    "do_sample": True,
    "top_p": 0.9,
    "temperature": 0.7,
    "pad_token_id": tokenizer.pad_token_id,
}
# Read the device off the parameters: this works for any nn.Module, including
# the value-head wrapper, which may not expose a `.device` attribute.
policy_device = next(policy_model.parameters()).device

for epoch in range(3):
    for batch_start in range(0, len(ppo_dataset), ppo_config.batch_size):
        batch = ppo_dataset[batch_start: batch_start + ppo_config.batch_size]
        queries = batch["prompt"]          # list[str]
        gold_answers = batch["answer"]     # list[str]

        # The legacy PPOTrainer.step() rejects batches whose size differs from
        # ppo_config.batch_size, so the trailing partial (or empty) batch is skipped.
        if len(queries) != ppo_config.batch_size:
            continue

        # 1. Generate responses and 2. compute rewards.
        # step() expects token-id tensors (not text) and scalar reward tensors.
        query_tensors, response_tensors, rewards = [], [], []
        for q, gold in zip(queries, gold_answers):
            inputs = tokenizer(q, return_tensors="pt").to(policy_device)
            gen = policy_model.generate(**inputs, **generation_kwargs)

            # generate() returns prompt + completion; keep only the completion so
            # PPO optimizes (and the reward scores) the newly generated tokens.
            prompt_len = inputs["input_ids"].shape[1]
            response_ids = gen[0][prompt_len:]
            resp_text = tokenizer.decode(response_ids, skip_special_tokens=True)

            query_tensors.append(inputs["input_ids"][0])
            response_tensors.append(response_ids)
            rewards.append(torch.tensor(correctness_reward(resp_text, gold)))

        # 3. PPO update
        stats = ppo_trainer.step(query_tensors, response_tensors, rewards)
        # optionally log stats
        # optionally log stats

## Evaluation

In [None]:
def evaluate_model(model, tokenizer, ds, num_samples=500):
    """
    Measure exact-answer accuracy of `model` on the first `num_samples` rows of `ds`.

    Each question is wrapped in the tutor prompt, answered with greedy decoding
    (up to 256 new tokens), and scored with `correctness_reward` against the
    row's reference answer. The model is left in eval mode.

    Args:
        model: a model exposing `eval()`, `parameters()` and `generate()`.
        tokenizer: the tokenizer matching `model`.
        ds: a dataset with "question" and "answer" fields that supports
            `len()` and `select()`.
        num_samples (int): maximum number of rows to evaluate.

    Returns:
        float: fraction of evaluated rows answered correctly, or 0.0 when
        there is nothing to evaluate.
    """
    model.eval()

    n_eval = min(num_samples, len(ds))
    if n_eval <= 0:
        # Nothing to score; avoid dividing by zero below.
        return 0.0
    subset = ds.select(range(n_eval))

    # Works for any nn.Module, including wrappers that do not expose `.device`.
    input_device = next(model.parameters()).device

    correct = 0.0
    total = 0
    for ex in tqdm(subset):
        prompt = (
            "You are a helpful math tutor. Solve the following problem step by step.\n"
            "Show your reasoning clearly, and end with \"#### <answer>\".\n\n"
            f"Question:\n{ex['question']}\n\nAnswer:\n"
        )
        gold = ex["answer"]

        inputs = tokenizer(prompt, return_tensors="pt").to(input_device)
        with torch.no_grad():
            gen = model.generate(
                **inputs,
                max_new_tokens=256,
                do_sample=False,      # greedy for eval
            )
        # Slice by token count rather than by len(prompt) characters: decoding
        # the prompt tokens is not guaranteed to reproduce the prompt string
        # exactly, which would misalign a character-based slice.
        completion_ids = gen[0][inputs["input_ids"].shape[1]:]
        pred = tokenizer.decode(completion_ids, skip_special_tokens=True)

        correct += correctness_reward(pred, gold)
        total += 1

    return correct / total

In [None]:
gsm_test = gsm8k["test"]

# Load a separate, un-adapted copy as the baseline (sft_model was built from `model`).
# `device_map="auto"` already places the weights, so no extra `.to(device)` is
# chained on: it is redundant and can break or raise on a dispatched model.
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

acc_base = evaluate_model(base_model, tokenizer, gsm_test)
acc_sft  = evaluate_model(sft_model, tokenizer, gsm_test)
acc_ppo  = evaluate_model(policy_model, tokenizer, gsm_test)         # PPO-only
# acc_sft_ppo = evaluate_model(policy_model_sft, tokenizer, gsm_test)  # SFT→PPO

# Display the three accuracies as the cell output.
{"base": acc_base, "sft": acc_sft, "ppo": acc_ppo}