<div align="center">
<a href="https://rapidfire.ai/"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/RapidFire - Blue bug -white text.svg" width="115"></a>
<a href="https://discord.gg/6vSTtncKNN"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/discord-button.svg" width="145"></a>
<a href="https://oss-docs.rapidfire.ai/"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/documentation-button.svg" width="125"></a>
<br/>
Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/RapidFireAI/rapidfireai">GitHub</a></i> ⭐
<br/>
To install RapidFire AI on your own machine, see the <a href="https://oss-docs.rapidfire.ai/en/latest/walkthrough.html">Install and Get Started</a> guide in our docs.
</div>

### RapidFire AI Tutorial Use Case: GRPO for Math Reasoning

In [None]:
from rapidfireai import Experiment
from rapidfireai.automl import List, RFGridSearch, RFModelConfig, RFLoraConfig, RFGRPOConfig

### Load Dataset and Specify Train and Eval Partitions

In [None]:
from datasets import load_dataset, Dataset

def get_gsm8k_questions(split="train") -> Dataset:
    """Load the 'main' configuration of openai/gsm8k and return one of its splits.

    Args:
        split: Name of the split to index out of the loaded dataset
            (the notebook uses "train" and "test").

    Returns:
        The selected split, unmodified.
    """
    return load_dataset('openai/gsm8k', 'main')[split]

# Demo-sized subsets: keep the first 128 train / 24 test rows, then shuffle each
# with a fixed seed so the order is reproducible
train_dataset = get_gsm8k_questions(split="train").select(range(128)).shuffle(seed=42)
eval_dataset = get_gsm8k_questions(split="test").select(range(24)).shuffle(seed=42)

### Define Data Processing Function

In [None]:
def sample_formatting_function(row):
    """Function to preprocess each example from dataset"""

    def extract_hash_answer(text: str) -> str | None:
        if "####" not in text:
            return None
        answer = text.split("####")[1].strip()
        try:
            answer = answer.replace(",", "")
        except:
            return None
        return answer
        
    SYSTEM_PROMPT = """
    Respond in the following format:
    <reasoning>
    ...
    </reasoning>
    <answer>
    ...
    </answer>
    """
    return { # Return a conversation format dictionary
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': row['question']}
        ],
        'question': row['question'],
        'answer': extract_hash_answer(row['answer'])
    }

### Initialize Experiment

In [None]:
# Every experiment instance must be uniquely named.
# This handle is what run_fit() and end() are called on later in the notebook.
experiment = Experiment(experiment_name="exp1-math-reasoning-lite")

#### Define Custom Reward Functions

In [None]:
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Reward completions whose extracted answer text equals the reference answer.

    Args:
        prompts: Prompts for the batch. Accepted for signature compatibility;
            not inspected.
        completions: One entry per sample; the generated text is read from
            ``completion[0]['content']``.
        answer: Reference answers, paired positionally with ``completions``.
        **kwargs: Ignored.

    Returns:
        One score per (completion, answer) pair: 2.0 on an exact string match,
        otherwise 0.0.
    """

    def extract_xml_answer(text: str) -> str:
        # Text after the last "<answer>" and before the following "</answer>",
        # stripped. When a tag is missing, split() leaves that side untouched.
        tail = text.split("<answer>")[-1]
        return tail.split("</answer>")[0].strip()

    extracted = [extract_xml_answer(completion[0]['content']) for completion in completions]
    return [2.0 if got == expected else 0.0 for got, expected in zip(extracted, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    """Give 0.5 to each completion whose extracted answer consists only of digits.

    The answer is the stripped text between the last "<answer>" and the next
    "</answer>" in ``completion[0]['content']``; ``str.isdigit`` decides the
    reward, so signs, decimals and empty strings score 0.0.
    """
    rewards = []
    for completion in completions:
        content = completion[0]['content']
        after_open_tag = content.split("<answer>")[-1]
        candidate = after_open_tag.split("</answer>")[0].strip()
        if candidate.isdigit():
            rewards.append(0.5)
        else:
            rewards.append(0.0)
    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that follow the exact line-by-line tag layout.

    The whole completion must be: "<reasoning>", newline, body, newline,
    "</reasoning>", newline, "<answer>", newline, body, newline, "</answer>",
    newline. Bodies may span several lines.

    Args:
        completions: One entry per sample; text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        0.5 for each completion that matches the layout, otherwise 0.0.
    """
    import re
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    # DOTALL lets "." cross newlines; without it any multi-line reasoning or
    # answer body could never match and would always score 0.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward completions that start with a reasoning block followed by an answer block.

    Looser than the strict check: only requires "<reasoning>...</reasoning>",
    optional whitespace, then "<answer>...</answer>" at the start of the
    completion. Newlines inside the tags are allowed.

    Args:
        completions: One entry per sample; text is read from
            ``completion[0]["content"]``.
        **kwargs: Ignored.

    Returns:
        0.5 for each completion that matches, otherwise 0.0.
    """
    import re
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    # DOTALL is required: the requested format puts newlines inside the tags,
    # and without the flag "." stops at the first newline so nothing matches.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """Score each completion by how many of the expected tag markers appear exactly once.

    Each of four markers contributes 0.125 when it occurs exactly once. The two
    answer-related checks also subtract 0.001 per character of text found after
    the closing answer tag (the second check discounts one character).
    """
    def count_xml(text) -> float:
        score = 0.0
        # Reasoning markers: flat bonus each, no length adjustment.
        for marker in ("<reasoning>\n", "\n</reasoning>\n"):
            if text.count(marker) == 1:
                score += 0.125
        if text.count("\n<answer>\n") == 1:
            score += 0.125
            # If the closing marker is absent, split() returns the whole text here.
            after_close = text.split("\n</answer>\n")[-1]
            score -= len(after_close)*0.001
        if text.count("\n</answer>") == 1:
            score += 0.125
            after_close = text.split("\n</answer>")[-1]
            score -= (len(after_close) - 1)*0.001
        return score

    return [count_xml(completion[0]["content"]) for completion in completions]

### Define Multi-Config Knobs for Model, LoRA, and GRPO Trainer using RapidFire AI Wrapper APIs

In [None]:
# LoRA adapter settings; this same object is passed as peft_config to both
# model configs below.
lora_config = RFLoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
    bias="none"
)

# Base GRPO trainer arguments, used as-is by the first model config.
grpo_config_base = RFGRPOConfig(
    learning_rate=5e-6,
    warmup_ratio=0.1,
    weight_decay=0.1,
    max_grad_norm=0.1,
    adam_beta1=0.9,
    adam_beta2=0.99,
    lr_scheduler_type = "linear",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_generations=8,
    optim ="adamw_8bit",
    num_train_epochs=1,
    max_prompt_length=256,
    max_completion_length=256,
    logging_steps=2,
    beta=0.0 # No reference model
)

# Copy of the base arguments with the learning rate overridden (1e-6 instead of 5e-6);
# used by the second model config.
grpo_config_2 = grpo_config_base.copy()
grpo_config_2.learning_rate = 1e-6

# Reward functions defined earlier in the notebook; the same list is given to both configs.
reward_funcs = [
    correctness_reward_func,
    int_reward_func,
    strict_format_reward_func,
    soft_format_reward_func,
    xmlcount_reward_func,
]

# List of 2 separate configs: each pairs a different base model with its own
# GRPO arguments; LoRA settings, formatting function, reward functions and
# model/tokenizer kwargs are the same in both.
config_set_lite = List([
    RFModelConfig(
        model_name="Qwen/Qwen2.5-0.5B-Instruct",
        peft_config=lora_config,
        training_args=grpo_config_base,
        formatting_func=sample_formatting_function,
        reward_funcs=reward_funcs,
        model_kwargs={"load_in_4bit": False, "device_map": "auto", "torch_dtype": "float16", "use_cache": False},
        tokenizer_kwargs={"model_max_length": 512, "padding_side": "left", "truncation": True}
    ),
    RFModelConfig(
        model_name="meta-llama/Llama-3.2-1B-Instruct",
        peft_config=lora_config,
        training_args=grpo_config_2,
        formatting_func=sample_formatting_function,
        reward_funcs=reward_funcs,
        model_kwargs={"load_in_4bit": False, "device_map": "auto", "torch_dtype": "float16", "use_cache": False},
        tokenizer_kwargs={"model_max_length": 512, "padding_side": "left", "truncation": True}
    ),
])

#### Define Model Creation Function

In [None]:
def sample_create_model(model_config):
    """Build the model and tokenizer for one config.

    Args:
        model_config: Mapping providing "model_name", "model_kwargs" and
            "tokenizer_kwargs"; the kwargs are forwarded to the respective
            ``from_pretrained`` calls.

    Returns:
        Tuple of (model, tokenizer).
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Read every required key up front so a missing key fails before any loading starts.
    name = model_config["model_name"]
    model_options = model_config["model_kwargs"]
    tokenizer_options = model_config["tokenizer_kwargs"]

    model = AutoModelForCausalLM.from_pretrained(name, **model_options)
    tokenizer = AutoTokenizer.from_pretrained(name, **tokenizer_options)
    return (model, tokenizer)

#### Generate Config Group

In [None]:
# Simple grid search across all sets of config knob values; config_set_lite holds
# 2 configs and no other multi-valued knobs, so this is 2 combinations in total
config_group = RFGridSearch(
    configs=config_set_lite,
    trainer_type="GRPO",
)

### Run Multi-Config Training

In [None]:
# Launch training of all configs in the config_group with swap granularity of 4 chunks
experiment.run_fit(config_group, sample_create_model, train_dataset, eval_dataset, num_chunks=4, seed=42)

### End Current Experiment

In [None]:
# End the experiment created earlier in this notebook
experiment.end()

<div align="center">
<a href="https://rapidfire.ai/"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/RapidFire - Blue bug -white text.svg" width="115"></a>
<a href="https://discord.gg/6vSTtncKNN"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/discord-button.svg" width="145"></a>
<a href="https://oss-docs.rapidfire.ai/"><img src="https://raw.githubusercontent.com/RapidFireAI/rapidfireai/main/images/documentation-button.svg" width="125"></a>
<br/>
Thanks for trying RapidFire AI! ⭐ <i>Star us on <a href="https://github.com/RapidFireAI/rapidfireai">GitHub</a></i> ⭐
</div>
