<a href="https://colab.research.google.com/github/vizcayal/aha_moment/blob/main/aha_Llama_countdown_working%20(jun25).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Environment setup. %%capture hides this cell's (long) pip output, including errors.
import os
# Colab is detected by any environment-variable name containing "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    # Outside Colab: plain install; pip resolves all dependencies itself.
    !pip install unsloth vllm
else:
    # In Colab: install unsloth without dependencies and pin vLLM, then add the
    # remaining dependencies by hand (presumably so Colab's preinstalled packages,
    # e.g. numpy, are not replaced -- see the vLLM note below).
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    # Removes every already-imported module whose name contains "PIL" or "google"
    # from sys.modules so they are re-imported fresh afterwards.
    # NOTE(review): `re` is imported only in this Colab branch; later cells that
    # call re.search/re.findall need `re` imported on the non-Colab path as well.
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    # Download vLLM's common requirements, drop the transformers/numpy/xformers
    # entries, and install the rest.
    # NOTE(review): the file comes from vLLM's `main` branch while vllm is pinned
    # to 0.8.5.post1 above; the two can drift -- consider the matching release tag.
    # NOTE(review): the regex is not anchored to line start and needs at least one
    # character after the package name, so a bare `numpy` line would be kept.
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

In [None]:
import re

from unsloth import FastLanguageModel, PatchFastRL
from unsloth import is_bfloat16_supported             #check if bfloat16 is supported
from datasets import load_dataset, Dataset
from trl import GRPOConfig, GRPOTrainer
from vllm import SamplingParams

In [None]:
import torch

# PatchFastRL("GRPO", FastLanguageModel) is intentionally left disabled.

# Sequence length the model (and its vLLM engine) is loaded with.
max_seq_length = 1024
# Rank of the LoRA update matrices; reused below as lora_alpha and as the max LoRA rank.
lora_rank = 32

# Projection layers that receive LoRA adapters.
lora_target_modules = [
    "q_proj", "k_proj", "v_proj", "o_proj",
    "gate_proj", "up_proj", "down_proj",
]

# Load Llama-3.2-3B-Instruct in 4-bit with vLLM-backed fast inference.
# (Alternative: model_name = "meta-llama/meta-Llama-3.1-8B-Instruct" for the 8B model.)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length,
    max_lora_rank = lora_rank,
    load_in_4bit = True,            # set False for 16-bit LoRA
    fast_inference = True,          # enable vLLM fast inference
    gpu_memory_utilization = 0.6,   # GPU memory fraction handed to the inference engine
)

# Wrap the base model with trainable LoRA adapters (parameter-efficient fine-tuning).
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank,
    lora_alpha = lora_rank,
    target_modules = lora_target_modules,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

In [None]:
def generate_r1_prompt(numbers, target, chat_tokenizer=None):
    """
    Build one R1-style Countdown training example.

    The conversation (system instructions, the user task, and a prefilled
    assistant turn ending in "<think>") is rendered with the tokenizer's chat
    template using ``continue_final_message=True``, so generation continues
    inside the open think block.

    Args:
        numbers (list[int]): Numbers the equation must use, each exactly once.
        target (int): Value the equation must evaluate to.
        chat_tokenizer: Object providing ``apply_chat_template``. Defaults to the
            notebook-level ``tokenizer`` returned when the model was loaded.

    Returns:
        dict: ``{"prompt": str, "target": target, "nums": numbers}``.
    """
    if chat_tokenizer is None:
        chat_tokenizer = tokenizer

    # Implicit string concatenation (instead of a backslash continuation inside the
    # literal) so no stray indentation whitespace leaks into the prompt text.
    user_content = (
        f"Using the numbers {numbers}, create an equation that equals {target}. "
        "You can use only the following basic arithmetic operations (+, -, *, /) one or multiple times "
        "but each number can only be used once. "
        "Show your work in <think> </think> tags. "
        f"And return the final equation with all and only ({numbers}) in <answer> </answer> tags, "
        "for example if the numbers are 1,2, 3 <answer> (1 + 2) / 3 </answer>. "
        "Think step by step inside <think></think> tags."
    )
    r1_prefix = [
        {
            "role": "system",
            "content": "You are a math assistant, specifically creating arithmetic equations. you first thinks about the reasoning process and then provides the user with the answer.",
        },
        {
            "role": "user",
            "content": user_content,
        },
        {
            # Prefilled assistant turn: the model continues right after "<think>".
            "role": "assistant",
            "content": "Let me solve this step by step.\n<think>",
        },
    ]
    prompt = chat_tokenizer.apply_chat_template(r1_prefix, tokenize=False, continue_final_message=True)
    return {"prompt": prompt, "target": target, "nums": numbers}

def format_reward_func(completions, target, **kwargs):
    """
    Reward completions shaped like ``<think>...</think> ... <answer>...</answer>``.

    The prompt already ends with an opening ``<think>`` tag, so a synthetic
    ``<think>`` is prepended to each completion before matching.

    Args:
        completions (list[str]): Generated outputs.
        target (list): Expected answers (not used for the format check; accepted
            because the trainer passes dataset columns as keyword arguments).
        **kwargs: Other trainer-supplied columns (ignored).

    Returns:
        list[float]: Per completion, 1.0 if a closed think block is followed by a
        closed answer block, otherwise 0.0.
    """
    rewards = []

    for completion, _ in zip(completions, target):
        # Non-text completions cannot match the format.
        if not isinstance(completion, str):
            rewards.append(0.0)
            continue

        completion_with_think = "<think>" + completion

        # Step 1: a closed <think>...</think> block must be present.
        think_match = re.search(r"<think>(.*?)</think>", completion_with_think, re.DOTALL)
        if think_match is None:
            rewards.append(0.0)
            continue

        # Step 2: an <answer>...</answer> block must follow the think block.
        after_think = completion_with_think[think_match.end():]
        answer_match = re.search(r"<answer>(.*?)</answer>", after_think, re.DOTALL)
        rewards.append(0.0 if answer_match is None else 1.0)

    return rewards

# def format_reward_func(completions, target, **kwargs):
#     """
#     Format: <think>...</think><answer>...</answer>
#     Args:
#         completions (list[str]): Generated outputs
#         target (list[str]): Expected answers

#       Returns:
#           list[float]: Reward scores
#     """
#     rewards = []

#     # print("completions: ", completions)
#     # print("target: ", target)
#     i = 0

#     for completion, gt in zip(completions, target):


#       try:
#         completion = "<think>" + completion

#         # add synthetic <think> as its already part of the prompt and prefilled for the assistant to more easily match the regex
#         #completion = "<think>" + completion

#         # Check if the format is correct
#         regex = r"^<think>([^<]*(?:<(?!/?think>)[^<]*)*)<\/think>\n<answer>([\s\S]*?)<\/answer>$"

#         match = re.search(regex, completion, re.DOTALL)
#         # if the format is not correct, reward is 0
#         if match is None or len(match.groups()) != 2:
#             rewards.append(0.0)
#         else:
#             rewards.append(1.0)
#       except Exception:
#         rewards.append(0.0)
#     return rewards

def _score_equation(completion, gt, numbers, verbose=False):
    """Return 1.0 if the completion's <answer> equation uses each of `numbers` exactly once and evaluates to `gt`, else 0.0."""
    if not isinstance(completion, str):
        return 0.0
    # The prompt prefills "<think>", so re-attach it to mirror the full assistant turn.
    completion = "<think>" + completion
    if verbose:
        print(f'{completion = }')
        print(f'{gt = }')
        print(f'{numbers = }')

    # DOTALL so answers whose tags sit on separate lines are still found.
    match = re.search(r"<answer>(.*?)<\/answer>", completion, re.DOTALL)
    if match is None:
        return 0.0
    equation = match.group(1).strip()
    used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
    if verbose:
        print(f'{equation = }')
        print(f'{used_numbers = }')

    # Every provided number must be used exactly once (multiset equality).
    if sorted(used_numbers) != sorted(numbers):
        return 0.0
    # Character whitelist: digits, + - * / ( ) . and whitespace only.
    if not re.match(r'^[\d+\-*/().\s]+$', equation):
        return 0.0
    # "**" and "//" pass the whitelist but are not allowed operations; "**" can
    # also make eval() build astronomically large integers (hang / OOM).
    if "**" in equation or "//" in equation:
        return 0.0

    try:
        # Builtins are disabled; after the checks above only arithmetic remains.
        result = eval(equation, {"__builtins__": None}, {})
        correct = abs(float(result) - float(gt)) < 1e-5
    except Exception:
        # Malformed expressions, division by zero, overflow, etc. earn no reward.
        return 0.0

    if verbose:
        print(f'{result = }')
        if correct:
            print('******************we got it****************')
    return 1.0 if correct else 0.0


def equation_reward_func(completions, target, nums, verbose=False, **kwargs):
    """
    Reward mathematically correct Countdown answers.

    A completion earns 1.0 when its first <answer>...</answer> block contains an
    equation that uses every number in ``nums`` exactly once, only the operators
    + - * / with parentheses, and evaluates (within 1e-5) to the target.

    Args:
        completions (list[str]): Generated outputs (text following the prefilled "<think>").
        target (list): Expected results, one per completion.
        nums (list[list[int]]): Available numbers for each completion.
        verbose (bool): If True, print intermediate parsing/evaluation details.
        **kwargs: Other trainer-supplied columns (ignored).

    Returns:
        list[float]: 1.0 or 0.0 per completion.
    """
    return [
        _score_equation(completion, gt, numbers, verbose)
        for completion, gt, numbers in zip(completions, target, nums)
    ]



## Prepare the dataset and configure the GRPO trainer

In [None]:
# Convert the Countdown tasks into R1-style prompts, then hold out 10% as a test split.
# The split is seeded so the same examples are held out on every run.
dataset = load_dataset("Jiayi-Pan/Countdown-Tasks-3to4", split = "train")
dataset = dataset.map(lambda x: generate_r1_prompt(x["nums"], x["target"]))
train_test_split = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]

In [None]:
# GRPO hyper-parameters.
# The model and its vLLM engine were loaded with `max_seq_length` tokens, so the
# completion budget is whatever remains after the prompt budget.
max_prompt_length = 512
bf16_supported = is_bfloat16_supported()

training_args = GRPOConfig(
    use_vllm = True,
    learning_rate = 5e-7,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.03,
    lr_scheduler_type = 'cosine',
    optim = 'paged_adamw_8bit',
    logging_steps = 1,
    bf16 = bf16_supported,
    fp16 = not bf16_supported,
    per_device_train_batch_size = 4,
    gradient_accumulation_steps = 1,
    num_generations = 4,                     # completions sampled per prompt
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length,
    max_steps = 2000,
    save_steps = 250,                        # checkpoint to output_dir every 250 steps
    max_grad_norm = 1,
    seed = 42,
    report_to = 'none',
    output_dir = 'outputs',
)

In [None]:
# Train only on the training split so test_dataset stays unseen.
# Total reward = format reward + equation-correctness reward.
grpo_trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [format_reward_func, equation_reward_func],
    args = training_args,
    train_dataset = train_dataset,
)

grpo_trainer.train()

# Persist the trained LoRA adapters.
model.save_lora("model_grpo_lora")