In [1]:
!pip install trl peft

Collecting trl
  Downloading trl-0.23.1-py3-none-any.whl.metadata (11 kB)
Downloading trl-0.23.1-py3-none-any.whl (564 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m564.6/564.6 kB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: trl
Successfully installed trl-0.23.1


In [2]:
import re
import torch
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig
from trl import GRPOConfig, GRPOTrainer
import os
from dataclasses import dataclass
from typing import Optional, List, Dict

In [3]:
@dataclass
class TrainingConfig:
    """Single place for every knob this notebook reads.

    The fields are consumed by the dataset builder, the reward functions,
    the model/LoRA setup helpers, the `GRPOConfig` construction and the
    wandb project setup further down in the notebook.
    """

    # Checkpoint name handed to AutoModelForCausalLM / AutoTokenizer.from_pretrained.
    model_name: str = "HuggingFaceTB/SmolLM2-135M-Instruct"
    # Passed to GRPOConfig as output_dir; the final model is saved in a subfolder of it.
    output_dir: str = "outputs/default-GRPO"
    # Exported as the WANDB_PROJECT environment variable.
    project_name: str = "grpo-xml-8k-smolLM2"
    # When True, a worked user/assistant example is inserted before each question.
    use_one_shot: bool = True
    # Key used to pick a split out of the loaded GSM8K dataset.
    dataset_split: str = "train"
    # Keep only the first N rows when set; None (or 0) keeps the whole split.
    dataset_size: Optional[int] = None  # Limit dataset size for testing

    # Training hyperparameters (forwarded unchanged to GRPOConfig)
    learning_rate: float = 5e-6
    per_device_train_batch_size: int = 8
    gradient_accumulation_steps: int = 4
    num_generations: int = 4
    num_train_epochs: int = 1
    max_grad_norm: float = 0.1

    # LoRA config (use_lora=False makes create_lora_config return None)
    use_lora: bool = True
    lora_r: int = 16
    lora_alpha: int = 64
    lora_dropout: float = 0.05

    # Reward weights (for balancing different objectives)
    # Value granted when the extracted answer equals the ground truth.
    correctness_reward: float = 2.0
    # Value granted when the completion matches the reasoning/answer tag layout.
    format_reward: float = 0.5
    # Value granted when the extracted answer parses as a number.
    int_reward: float = 0.5
    # Multiplier applied to the raw tag-count score.
    xmlcount_weight: float = 0.5

In [4]:
# System message placed at the start of every prompt; it spells out the
# <reasoning>/<answer> layout that the format-related reward functions look for.
SYSTEM_PROMPT = """You are a helpful math assistant. Respond in the following format:

<reasoning>
Explain your step-by-step reasoning here.
</reasoning>
<answer>
Provide only the final numerical answer here.
</answer>"""

# Template with {reasoning} and {answer} placeholders, used to render the
# one-shot assistant example in the same layout the system prompt requests.
XML_COT_FORMAT = """<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>"""

In [5]:
def extract_xml_answer(text: str) -> str:
  """Return the stripped content of the first <answer>...</answer> pair.

  The match spans newlines (DOTALL) and is non-greedy, so when several
  answer blocks are present only the first one is returned. An empty
  string is returned when no complete pair exists.
  """
  found = re.search(r'<answer>\s*(.*?)\s*</answer>', text, re.DOTALL)
  return found.group(1).strip() if found else ""

def extract_hash_answer(text: str) -> Optional[str]:
  """Return the text following the first "####" marker, stripped.

  Only the segment up to the next "####" (if any) is kept. Returns None
  when the marker does not occur in `text`.
  """
  _, marker, tail = text.partition("####")
  if not marker:
    return None

  # Stop at a second marker so the result matches "segment 1" of a full split.
  return tail.split("####", 1)[0].strip()

def normalize_number(text: str) -> str:
  """Canonicalise a numeric string so equal values compare equal as text.

  Commas and dollar signs are removed and surrounding whitespace trimmed.
  If the result parses as a float, whole values are rendered without a
  decimal part ("72.0" -> "72") and other values via str(float). Text that
  does not parse is returned in its cleaned form.
  """
  cleaned = text.replace(",", "").replace("$", "").strip()

  try:
    value = float(cleaned)
  except ValueError:
    return cleaned

  return str(int(value)) if value.is_integer() else str(value)

In [6]:
def get_gsm8k_questions(config: TrainingConfig) -> Dataset:
  """Load GSM8K and turn each row into a chat prompt plus a normalised answer.

  Args:
    config: supplies the split key, the optional row limit and the
      one-shot toggle.

  Returns:
    The mapped dataset. Each row gains a "prompt" (list of chat messages)
    and has its "answer" replaced by the normalised text after "####"
    (None when the marker is missing).
  """
  all_splits = load_dataset("openai/gsm8k", "main")
  data = all_splits[config.dataset_split]

  # A falsy dataset_size (None or 0) means "use everything".
  if config.dataset_size:
    row_count = min(config.dataset_size, len(data))
    data = data.select(range(row_count))

  def format_example(x: Dict) -> Dict:
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]

    if config.use_one_shot:
      # Worked example demonstrating the expected reasoning/answer layout.
      messages.append(
          {'role': 'user', 'content': 'What is the largest single-digit prime number?'}
      )
      messages.append(
          {'role': 'assistant', 'content': XML_COT_FORMAT.format(
              reasoning="Let me check each single-digit number from 9 down to 2.\n9 is divisible by 3 (9 = 3 × 3).\n8 is divisible by 2 (8 = 2 × 4).\n7 is only divisible by 1 and 7, so it is prime.",
              answer="7"
          )}
      )

    messages.append({"role": "user", "content": x["question"]})

    gold = extract_hash_answer(x["answer"])
    # Leave None / empty results untouched; only real text is normalised.
    gold = normalize_number(gold) if gold else gold

    return {"prompt": messages, "answer": gold}

  return data.map(format_example)

In [7]:
def correctness_reward_func(prompts: List[str], completions: List[str], answer: list, config: TrainingConfig, **kwargs) -> List[float]:
  """Reward completions whose extracted answer equals the ground truth.

  Args:
    prompts: one chat prompt per completion; each is indexed as
      prompts[i][-1]["content"] to obtain the question text.
    completions: one entry per sample; each is indexed as
      completion[0]["content"] to obtain the generated text.
    answer: ground-truth answers aligned with `completions`. Entries may be
      None when the dataset row had no "####" marker.
    config: provides `correctness_reward`, the value granted on a match.
    **kwargs: extra columns forwarded by the trainer; ignored.

  Returns:
    One float per completion: `config.correctness_reward` on an exact match
    of the normalised strings, otherwise 0.0. A missing (None or empty)
    ground truth never matches.
  """
  responses = [completion[0]["content"] for completion in completions]
  normalized_responses = [normalize_number(extract_xml_answer(r)) for r in responses]

  # extract_hash_answer can yield None; normalising None would raise, so keep it as-is.
  normalized_answers = [normalize_number(a) if a is not None else None for a in answer]

  # Debug trace of the first sample only; skipped for an empty batch so the
  # [0] lookups cannot raise IndexError.
  if prompts and normalized_responses and normalized_answers:
    q = prompts[0][-1]["content"]
    print(f"\n{'='*50}\nQuestion: {q[:100]}...\nGround Truth: {normalized_answers[0]}\nExtracted: {normalized_responses[0]}\nMatch: {normalized_responses[0] == normalized_answers[0]}\n{'='*50}")

  # `a` is falsy only for None or "", so an empty extraction can never be
  # rewarded for "matching" an absent ground truth ("0" is still truthy).
  rewards = [
      config.correctness_reward if (a and r == a) else 0.0
      for r, a in zip(normalized_responses, normalized_answers)
  ]

  return rewards


def int_reward_func(completions: List[str], config: TrainingConfig, prompts: List[str], **kwargs) -> List[float]:
  """Reward completions whose extracted answer parses as a number.

  Despite the name, any value accepted by float() counts (after commas,
  dollar signs and surrounding whitespace are removed), not just integers.

  Returns:
    `config.int_reward` for each numeric answer, 0.0 otherwise.
  """
  def is_numeric(text: str) -> bool:
    cleaned = text.replace(",", "").replace("$", "").strip()
    try:
      float(cleaned)
    except ValueError:
      return False
    return True

  rewards = []
  for completion in completions:
    extracted = extract_xml_answer(completion[0]["content"])
    rewards.append(config.int_reward if is_numeric(extracted) else 0.0)
  return rewards


def format_reward_func(completions: List[str], config: TrainingConfig, strict:bool=False, **kwargs) -> List[float]:
  """Reward completions that follow the <reasoning>/<answer> layout.

  Args:
    completions: one entry per sample, read as completion[0]["content"].
    config: provides `format_reward`, the value granted on a match.
    strict: when True the whole text must be exactly the two blocks with
      the tags on their own lines; when False the two blocks only need to
      appear in order somewhere in the text.

  Returns:
    `config.format_reward` for each matching completion, 0.0 otherwise.
  """
  pattern = (
      r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n?$"
      if strict
      else r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
  )
  # DOTALL lets the block contents span multiple lines.
  matcher = re.compile(pattern, re.DOTALL)

  rewards = []
  for completion in completions:
    content = completion[0]["content"]
    rewards.append(config.format_reward if matcher.search(content) else 0.0)
  return rewards


def xmlcount_reward_func(completions: List[str], config: TrainingConfig, **kwargs) -> List[float]:
  """Score partial adherence to the tag layout, scaled by `xmlcount_weight`.

  Per completion the raw score is built as follows:
    * +0.125 for each of the four tags that occurs exactly once;
    * minus 0.001 per non-whitespace-trimmed character after the first
      "</answer>", capped at 0.1;
    * -0.2 if either reasoning tag is duplicated;
    * -0.2 if either answer tag is duplicated.

  Returns:
    The raw score of each completion multiplied by `config.xmlcount_weight`.
  """
  def count_xml(text: str) -> float:
    score = 0.0

    # Credit every opening/closing tag that appears exactly once.
    for tag in ("<reasoning>", "</reasoning>", "<answer>", "</answer>"):
      if text.count(tag) == 1:
        score += 0.125

    # Small, capped penalty for anything written after the answer is closed.
    _, closing, trailing = text.partition("</answer>")
    if closing:
      score -= min(len(trailing.strip()) * 0.001, 0.1)

    # Repeated tags are penalised once per tag family.
    if any(text.count(tag) > 1 for tag in ("<reasoning>", "</reasoning>")):
      score -= 0.2
    if any(text.count(tag) > 1 for tag in ("<answer>", "</answer>")):
      score -= 0.2

    return score

  return [
      count_xml(completion[0]["content"]) * config.xmlcount_weight
      for completion in completions
  ]


In [8]:
def setup_model_and_tokenizer(config: TrainingConfig):
  """Load the causal LM and its tokenizer named by `config.model_name`.

  The tokenizer's pad token falls back to its EOS token when none is
  defined.

  Returns:
    A `(model, tokenizer)` tuple.
  """
  checkpoint = config.model_name

  # NOTE(review): trust_remote_code=True lets the checkpoint repo run its own
  # Python on load; kept because some models might need this.
  model = AutoModelForCausalLM.from_pretrained(
      checkpoint,
      device_map="auto",
      trust_remote_code=True,
  )
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

  if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

  return model, tokenizer

def create_lora_config(config: TrainingConfig) -> Optional[LoraConfig]:
  """Build the LoRA adapter configuration, or None when LoRA is disabled.

  Rank, alpha and dropout come from `config`; the list of target module
  names is fixed.
  """
  if config.use_lora:
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "up_proj", "down_proj", "gate_proj",
    ]
    return LoraConfig(
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        target_modules=target_modules,
        task_type="CAUSAL_LM",
        lora_dropout=config.lora_dropout,
    )

  return None

In [9]:
# Instantiate the default settings, then build everything that depends on them.
# `config` is also read as a global by the reward-function wrappers defined later.
config = TrainingConfig()
# Formatted GSM8K rows ("prompt" + normalised "answer").
dataset = get_gsm8k_questions(config)

# Downloads/loads the checkpoint named by config.model_name.
model, tokenizer = setup_model_and_tokenizer(config)

# None when config.use_lora is False.
peft_config = create_lora_config(config)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md: 0.00B [00:00, ?B/s]

main/train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

main/test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

config.json:   0%|          | 0.00/861 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/269M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/132 [00:00<?, ?B/s]

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/655 [00:00<?, ?B/s]

In [10]:
# Trainer arguments: values a reader is expected to tune come from `config`;
# the optimizer/scheduler/logging settings below are fixed in this cell.
training_args = GRPOConfig(
    output_dir=config.output_dir,
    learning_rate=config.learning_rate,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type='cosine',
    logging_steps=1,
    per_device_train_batch_size=config.per_device_train_batch_size,
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    num_generations=config.num_generations,
    # Length limits are currently disabled, so the library defaults apply.
    # max_prompt_length=256,
    # max_completion_length=786,
    num_train_epochs=config.num_train_epochs,
    save_steps=100,
    save_total_limit=3,  # Only keep 3 most recent checkpoints
    max_grad_norm=config.max_grad_norm,
    report_to="wandb",  # project name is taken from the WANDB_PROJECT env var set below
    bf16=False,
)

def _bind_config(reward_func):
  """Wrap `reward_func` so the trainer can call it with keyword arguments only.

  The wrapper injects the notebook-level `config` (looked up at call time)
  and keeps the wrapped function's name.
  """
  def wrapper(**kwargs):
    return reward_func(config=config, **kwargs)

  # GRPOTrainer labels each callable reward by its __name__ when logging.
  # Anonymous lambdas would all be called "<lambda>", making the four reward
  # curves indistinguishable, so carry the real name across.
  wrapper.__name__ = reward_func.__name__
  return wrapper


# Order matches the original list: tag count, format, numeric answer, correctness.
reward_funcs = [
    _bind_config(xmlcount_reward_func),
    _bind_config(format_reward_func),
    _bind_config(int_reward_func),
    _bind_config(correctness_reward_func),
]

In [11]:
# `os` is already imported in the imports cell; this re-import is redundant but harmless.
import os
# Select the wandb project before the trainer is built and training starts.
os.environ["WANDB_PROJECT"] = config.project_name

In [12]:
# Assemble the trainer from the pieces built in the earlier cells.
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=reward_funcs,  # list of callables, each returning one float per completion
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config  # None when LoRA is disabled in the config
)

In [None]:
# Run training, then write the final model to <output_dir>/final_model,
# separate from the periodic checkpoints saved during training.
trainer.train()
final_output_dir = os.path.join(config.output_dir, "final_model")
trainer.save_model(final_output_dir)

  | |_| | '_ \/ _` / _` |  _/ -_)


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33myashwantherukulla[0m ([33myashwantherukulla-vellore-institute-of-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


[34m[1mwandb[0m: Detected [huggingface_hub.inference, openai] in use.
[34m[1mwandb[0m: Use W&B Weave for improved LLM call tracing. Install Weave with `pip install weave` then add `import weave` to the top of your script.
[34m[1mwandb[0m: For more information, check out the docs at: https://weave-docs.wandb.ai/



Question: Ahmed and Emily are having a contest to see who can get the best grade in the class. There have been...
Ground Truth: 100
Extracted: 
Match: False


Step,Training Loss
1,0.0439
2,0.0454
3,0.0303
4,0.1083
5,0.1368
6,0.0979
7,0.079
8,0.1059
9,0.0912
10,0.0764



Question: Marie has 98 unread messages on her phone. She decides to clear them by reading 20 messages a day. H...
Ground Truth: 7
Extracted: 
Match: False

Question: Mary bought a packet of 1500 stickers. She shared them between Susan, Andrew and Sam in the ratio 1:...
Ground Truth: 900
Extracted: 
Match: False

Question: A thirsty traveler found an oasis in the desert. He drank 32 ounces of water. His camel drank seven ...
Ground Truth: 2
Extracted: 
Match: False

Question: Frank invites his friends over to play video games. He bakes a pan of brownies before he arrives. He...
Ground Truth: 3
Extracted: 
Match: False

Question: Cathy and Chris got summer jobs at the cake shop and were supposed to work 20 hours per week each fo...
Ground Truth: 180
Extracted: 
Match: False

Question: John uses the bathroom every 50 minutes.  How many times does he use the bathroom during a 2.5-hour ...
Ground Truth: 3
Extracted: 14 bathroom use
Match: False

Question: John went on a mission that was su

Step,Training Loss
1,0.0439
2,0.0454
3,0.0303
4,0.1083
5,0.1368
6,0.0979
7,0.079
8,0.1059
9,0.0912
10,0.0764



Question: Jack went to a supermarket with $100 and bought 4 bottles of water. Then his mother called him and a...
Ground Truth: 71
Extracted: 51
Match: False

Question: Magdalena has an apple tree on their farm, producing very few apples each year for a while now. Howe...
Ground Truth: 20
Extracted: 
Match: False

Question: John has taken 10 pictures every day for the past 3 years.  He saves them in raw format so each memo...
Ground Truth: 13140
Extracted: John purchases 50 worth of memories and 2.50 worth of resources all for 115.
Match: False

Question: Robert, a sales agent, earns a basic salary of $1250 per month and,  10% commission on his monthly s...
Ground Truth: 2888
Extracted: 
Match: False

Question: Meena bakes 5 dozen cookies for the school’s bake sale.  She sells 2 dozen cookies to her biology te...
Ground Truth: 15
Extracted: 
Match: False

Question: Mr. Wells has a garden of flowers with 50 rows. If each row has 400 flowers and Mr. Wells cuts 60% o...
Ground Truth: 800