To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), & [how to save it](#Save)


### News

**Read our [blog post](https://unsloth.ai/blog/r1-reasoning) for guidance on how to train reasoning models.**

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [1]:
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow

### Unsloth

Load up `Phi-4 14B`, and set parameters

In [3]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch
max_seq_length = 512 # Can increase for longer reasoning traces
lora_rank = 16 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Phi-4",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["gate_proj", "up_proj", "down_proj",],
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

INFO 03-12 22:11:07 __init__.py:207] Automatically detected platform cuda.
Unsloth: Switching from Unsloth dynamic quant to normal quant since
we do not yet support fast inference for unsloth/phi-4-unsloth-bnb-4bit
==((====))==  Unsloth 2025.3.9: Fast Llama patching. Transformers: 4.48.3. vLLM: 0.7.3.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


config.json:   0%|          | 0.00/1.40k [00:00<?, ?B/s]

Unsloth: vLLM loading unsloth/phi-4-bnb-4bit with actual GPU utilization = 69.34%
Unsloth: Your GPU has CUDA compute capability 7.5 with VRAM = 14.74 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 512. Num Sequences = 128.
Unsloth: vLLM's KV Cache can use up to 0.25 GB. Also swap space = 2 GB.
INFO 03-12 22:11:39 config.py:549] This model supports multiple tasks: {'classify', 'reward', 'embed', 'score', 'generate'}. Defaulting to 'generate'.
Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'float16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection'], 'llm_int8_threshold': 6.0}
INFO 03-12 22:11:39 llm_engine.py:234] Initializing a V0 LLM engine (v0.7.3) with config: model='uns

tokenizer_config.json:   0%|          | 0.00/18.0k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/917k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.15M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/170 [00:00<?, ?B/s]

INFO 03-12 22:11:41 cuda.py:178] Cannot use FlashAttention-2 backend for Volta and Turing GPUs.
INFO 03-12 22:11:41 cuda.py:226] Using XFormers backend.
INFO 03-12 22:11:42 model_runner.py:1110] Starting to load model unsloth/phi-4-bnb-4bit...
INFO 03-12 22:11:43 loader.py:1089] Loading weights with BitsAndBytes quantization.  May take a while ...
INFO 03-12 22:11:43 weight_utils.py:254] Using model weights format ['*.safetensors']


model-00001-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.11G [00:00<?, ?B/s]

INFO 03-12 22:13:52 weight_utils.py:270] Time spent downloading weights for unsloth/phi-4-bnb-4bit: 129.053992 seconds


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 03-12 22:15:20 model_runner.py:1115] Loading model weights took 8.4920 GB
INFO 03-12 22:15:20 punica_selector.py:18] Using PunicaWrapperGPU.
INFO 03-12 22:15:34 worker.py:267] Memory profiling takes 14.12 seconds
INFO 03-12 22:15:34 worker.py:267] the current vLLM instance can use total_gpu_memory (14.74GiB) x gpu_memory_utilization (0.69) = 10.22GiB
INFO 03-12 22:15:34 worker.py:267] model weights take 8.49GiB; non_torch_memory takes 0.03GiB; PyTorch activation peak memory takes 0.47GiB; the rest of the memory reserved for KV Cache is 1.23GiB.
INFO 03-12 22:15:35 executor_base.py:111] # cuda blocks: 403, # CPU blocks: 655
INFO 03-12 22:15:35 executor_base.py:116] Maximum concurrency for 512 tokens per request: 12.59x
INFO 03-12 22:15:37 model_runner.py:1434] Capturing cudagraphs for decoding. This may lead to unexpected consequences if the model is not static. To run the model in eager mode, set 'enforce_eager=True' or use '--enforce-eager' in the CLI. If out-of-memory error occu

Capturing CUDA graph shapes: 100%|██████████| 19/19 [00:46<00:00,  2.45s/it]

INFO 03-12 22:16:23 model_runner.py:1562] Graph capturing finished in 47 secs, took 0.65 GiB
INFO 03-12 22:16:23 llm_engine.py:436] init engine (profile, create kv cache, warmup model) took 63.71 seconds





tokenizer_config.json:   0%|          | 0.00/18.0k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.61M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/917k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/7.15M [00:00<?, ?B/s]

Not an error, but Unsloth cannot patch Attention layers with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Not an error, but Unsloth cannot patch O projection layer with our manual autograd engine since either LoRA adapters
are not enabled or a bias term (like in Qwen) is used.
Unsloth 2025.3.9 patched 40 layers with 0 QKV layers, 0 O layers and 40 MLP layers.


### Data Prep
<a name="Data"></a>

We directly leverage [@willccbb](https://gist.github.com/willccbb/4676755236bb08cab5f4e54a0475d6fb) for data prep and all reward functions. You are free to create your own!

In [24]:
import re
from datasets import load_dataset, Dataset

# Keep the same system prompt and XML format
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

#########################################################################
# IMPROVED EXTRACTION FUNCTIONS (used by reward functions)
#########################################################################
dataset_natural_reasoning = load_dataset('facebook/natural_reasoning')

def get_natural_reasoning_questions(split="train") -> Dataset:
    data = dataset_natural_reasoning[split]
    data = data.map(lambda x: {
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': x['responses']
    })
    return data

def robust_extract_xml_answer(text: str) -> str:
    """More robust extraction of answer from XML tags."""
    try:
        # Standard format
        match = re.search(r"<answer>(.*?)</answer>", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Handle case variants and formatting issues
        match = re.search(r"<Answer>(.*?)</Answer>", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Last resort attempt
        if "<answer>" in text.lower():
            answer_text = text.lower().split("<answer>")[1]
            if "</answer>" in answer_text:
                answer_text = answer_text.split("</answer>")[0]
                return answer_text.strip()
    except:
        pass

    return ""

def robust_extract_xml_reasoning(text: str) -> str:
    """More robust extraction of reasoning from XML tags."""
    try:
        # Standard format
        match = re.search(r"<reasoning>(.*?)</reasoning>", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Handle case variants and formatting issues
        match = re.search(r"<Reasoning>(.*?)</Reasoning>", text, re.DOTALL)
        if match:
            return match.group(1).strip()

        # Last resort attempt
        if "<reasoning>" in text.lower():
            reasoning_text = text.lower().split("<reasoning>")[1]
            if "</reasoning>" in reasoning_text:
                reasoning_text = reasoning_text.split("</reasoning>")[0]
                return reasoning_text.strip()
    except:
        pass

    return ""

def extract_hash_answer(text: str) -> str | None:
    """Extract answer from GSM8K-style #### format."""
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

#########################################################################
# IMPROVED VERSIONS OF ORIGINAL REWARD FUNCTIONS
#########################################################################

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    """
    Improved XML tag formatting reward function that checks for XML tags
    and gives partial credit for tag presence and ordering.
    """
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content in contents:
        score = 0.0

        # Check for opening tags with proper formatting
        if "<reasoning>" in content:
            score += 0.15
        if "</reasoning>" in content:
            score += 0.15
        if "<answer>" in content:
            score += 0.15
        if "</answer>" in content:
            score += 0.15

        # Check for proper tag ordering and nesting
        reasoning_start = content.find("<reasoning>")
        reasoning_end = content.find("</reasoning>")
        answer_start = content.find("<answer>")
        answer_end = content.find("</answer>")

        if reasoning_start != -1 and reasoning_end != -1 and reasoning_start < reasoning_end:
            score += 0.1  # Proper reasoning tag ordering

        if answer_start != -1 and answer_end != -1 and answer_start < answer_end:
            score += 0.1  # Proper answer tag ordering

        if reasoning_start != -1 and reasoning_end != -1 and answer_start != -1 and answer_end != -1:
            if reasoning_start < reasoning_end < answer_start < answer_end:
                score += 0.2  # Perfect ordering of all tags

        rewards.append(score)

    return rewards

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Improved soft format reward function that checks if the completion
    has the required XML structure with more generous pattern matching.
    """
    # More flexible pattern that allows for various whitespace and formatting
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]

    rewards = []
    for response in responses:
        # Check for the basic structure with relaxed constraints
        if re.search(pattern, response, re.DOTALL):
            rewards.append(0.5)
        else:
            # Partial credit for having both tags but not in perfect format
            has_reasoning_tags = "<reasoning>" in response and "</reasoning>" in response
            has_answer_tags = "<answer>" in response and "</answer>" in response

            if has_reasoning_tags and has_answer_tags:
                rewards.append(0.3)  # Partial credit
            elif has_reasoning_tags or has_answer_tags:
                rewards.append(0.1)  # Minimal credit
            else:
                rewards.append(0.0)  # No credit

    return rewards

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """
    Improved strict format reward function that checks if the completion
    has a precise and well-structured format.
    """
    # Strict pattern requiring proper line breaks and structure
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]

    rewards = []
    for response in responses:
        matches = re.match(pattern, response, re.DOTALL)

        if matches:
            rewards.append(0.5)  # Full credit for perfect formatting
        else:
            # Check for a close match with minor formatting differences
            close_pattern = r"<reasoning>[\s\n].*?[\s\n]</reasoning>[\s\n]<answer>[\s\n].*?[\s\n]</answer>"
            close_match = re.search(close_pattern, response, re.DOTALL)

            if close_match:
                rewards.append(0.25)  # Partial credit for close formatting
            else:
                rewards.append(0.0)  # No credit

    return rewards

def int_reward_func(completions, **kwargs) -> list[float]:
    """
    Enhanced integer reward function that gives credit for numeric answers
    and partial credit for answers containing numbers.
    """
    responses = [completion[0]["content"] for completion in completions]
    extracted_responses = [robust_extract_xml_answer(r) for r in responses]

    rewards = []
    for response in extracted_responses:
        # Full credit for pure integer answers
        if response.isdigit():
            rewards.append(0.5)
        # Partial credit for answers containing numbers
        elif any(char.isdigit() for char in response):
            # Extract all numbers from the answer
            numbers = re.findall(r'\d+', response)
            if numbers:
                # More numbers or longer numbers get slightly more credit
                digit_count = sum(len(num) for num in numbers)
                rewards.append(min(0.3, 0.1 + 0.02 * digit_count))  # Cap at 0.3
            else:
                rewards.append(0.0)
        else:
            rewards.append(0.0)

    return rewards

def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """
    Improved correctness reward function with better answer extraction
    and partial credit for close numeric answers.
    """
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content'] if prompts and len(prompts[0]) > 0 else ""

    # Use robust extraction to handle more formats
    extracted_responses = [robust_extract_xml_answer(r) for r in responses]

    rewards = []
    for extracted in extracted_responses:
        # Check for exact match first
        exact_match = extracted == answer[0] if isinstance(answer, list) else extracted == answer

        if exact_match:
            rewards.append(2.0)
            continue

        # If not an exact match, try numerical comparison for math problems
        try:
            # Extract numbers from both the response and answer
            extracted_numbers = re.findall(r'[-+]?\d*\.?\d+', extracted)
            correct_numbers = re.findall(r'[-+]?\d*\.?\d+', answer[0] if isinstance(answer, list) else str(answer))

            # If both contain numbers, compare them
            if extracted_numbers and correct_numbers:
                extracted_val = float(extracted_numbers[0])
                correct_val = float(correct_numbers[0])

                # Check for close match (within 1%)
                rel_error = abs((extracted_val - correct_val) / correct_val) if correct_val != 0 else abs(extracted_val)

                if abs(extracted_val - correct_val) < 1e-6:  # Exact numeric match
                    rewards.append(2.0)
                elif rel_error < 0.01:  # Very close (within 1%)
                    rewards.append(1.5)  # Partial credit
                elif rel_error < 0.05:  # Somewhat close (within 5%)
                    rewards.append(0.5)  # Minimal credit
                else:
                    rewards.append(0.0)
                continue
        except:
            pass

        # If we get here, there was no match
        rewards.append(0.0)

    # Print debug info the same way the original function did
    if len(completions) > 0:
        print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}",
              f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")

    return rewards

#########################################################################
# NEW REWARD FUNCTIONS TO ADD TO THE TRAINER
#########################################################################

def reasoning_coherence_reward(completions, **kwargs) -> list[float]:
    """
    Rewards logical coherence and flow in reasoning.

    Looks for step-by-step progression, logical connectors, and a clear
    path from premises to conclusion.
    """
    responses = [completion[0]["content"] for completion in completions]
    reasoning_parts = [robust_extract_xml_reasoning(response) for response in responses]

    rewards = []

    for reasoning in reasoning_parts:
        if not reasoning:
            rewards.append(0.0)
            continue

        score = 0.0

        # Check for step indicators and sequence markers
        step_indicators = ["first", "second", "third", "then", "next", "finally"]
        step_count = sum(1 for indicator in step_indicators if indicator in reasoning.lower())
        score += min(step_count / 3, 0.2)  # Max 0.2 for step indicators

        # Check for logical connectors
        logical_connectors = ["if", "then", "because", "therefore", "thus", "so", "as a result"]
        connector_count = sum(1 for connector in logical_connectors if connector in reasoning.lower())
        score += min(connector_count / 3, 0.2)  # Max 0.2 for logical connectors

        # Check for numerical computations/operations
        has_calculations = len(re.findall(r"\d+\s*[+\-*/]\s*\d+", reasoning.lower())) > 0
        score += 0.2 if has_calculations else 0.0  # 0.2 for calculations

        # Check for reasoning length and substance
        lines = reasoning.split("\n")
        meaningful_lines = [line for line in lines if len(line.strip()) > 15]  # Lines with substance
        score += min(len(meaningful_lines) / 5, 0.2)  # Max 0.2 for reasoning length/substance

        # Check for a conclusion that connects reasoning to answer
        if "therefore" in reasoning.lower() or "in conclusion" in reasoning.lower() or "so the answer is" in reasoning.lower():
            score += 0.2  # 0.2 for explicit conclusion

        rewards.append(min(score, 1.0))  # Cap at 1.0

    return rewards

def faithful_reasoning_reward(completions, prompts, **kwargs) -> list[float]:
    """
    Rewards faithful reasoning (absence of unfaithful patterns).

    Checks for self-contradictions, post-hoc rationalization,
    and other patterns indicating unfaithful reasoning.
    """
    responses = [completion[0]["content"] for completion in completions]
    reasoning_parts = [robust_extract_xml_reasoning(response) for response in responses]
    answer_parts = [robust_extract_xml_answer(response) for response in responses]

    rewards = []

    for i, (reasoning, answer) in enumerate(zip(reasoning_parts, answer_parts)):
        if not reasoning or not answer:
            rewards.append(0.0)
            continue

        # Start with maximum score
        score = 1.0

        # Check for contradictory statements
        sentences = re.split(r'[.!?]', reasoning)
        sentences = [s.strip() for s in sentences if s.strip()]

        contradiction_penalty = 0.0

        # Look for direct contradictions like "X is greater than Y" and later "Y is greater than X"
        comparisons = {}
        for sentence in sentences:
            # Match patterns like "X is greater than Y"
            match = re.search(r"(\w+)\s+is\s+(greater|larger|bigger|more|less|smaller|fewer)\s+than\s+(\w+)", sentence, re.IGNORECASE)
            if match:
                item1, relation, item2 = match.groups()
                relation_type = "greater" if relation in ["greater", "larger", "bigger", "more"] else "less"

                # Check if we've seen a comparison between these items before
                key = tuple(sorted([item1.lower(), item2.lower()]))
                if key in comparisons:
                    prev_item1, prev_relation, prev_item2 = comparisons[key]

                    # Check if this is a contradiction
                    if ((item1.lower() == prev_item1.lower() and relation_type != prev_relation) or
                        (item1.lower() == prev_item2.lower() and relation_type == prev_relation)):
                        contradiction_penalty += 0.3

                comparisons[key] = (item1.lower(), relation_type, item2.lower())

        # Check for restoration errors (silent correction of mistakes)
        restoration_penalty = 0.0
        calc_pattern = r'(\d+\s*[+\-*/]\s*\d+\s*=\s*\d+)'
        calculations = re.findall(calc_pattern, reasoning)

        for i in range(len(calculations) - 1):
            # Extract the result of the first calculation
            first_result = re.search(r'=\s*(\d+)', calculations[i])
            if not first_result:
                continue

            # Check if the result is used in the next calculation but gets a different result
            result_val = first_result.group(1)
            second_calc = calculations[i+1]

            if result_val in second_calc.split('=')[0] and "correction" not in reasoning.lower() and "mistake" not in reasoning.lower():
                # Extract the result of the second calculation
                second_result = re.search(r'=\s*(\d+)', second_calc)
                if second_result and second_result.group(1) != result_val:
                    restoration_penalty += 0.2

        # Apply penalties
        score -= (contradiction_penalty + restoration_penalty)

        rewards.append(max(score, 0.0))  # Ensure non-negative

    return rewards

def math_reasoning_quality(completions, **kwargs) -> list[float]:
    """
    Rewards high-quality mathematical reasoning.

    Checks for proper calculations, equation setup, variable definitions,
    and correct intermediate steps.
    """
    responses = [completion[0]["content"] for completion in completions]
    reasoning_parts = [robust_extract_xml_reasoning(response) for response in responses]

    rewards = []

    for reasoning in reasoning_parts:
        if not reasoning:
            rewards.append(0.0)
            continue

        score = 0.0

        # Check for explicit calculations
        calculations = re.findall(r'\d+\s*[+\-*/]\s*\d+\s*=\s*\d+', reasoning)
        score += min(len(calculations) * 0.1, 0.3)  # Up to 0.3 for explicit calculations

        # Check for variable definitions
        var_definitions = re.findall(r'(?:let|set|define)\s+([a-zA-Z])\s*=', reasoning.lower())
        score += min(len(var_definitions) * 0.05, 0.15)  # Up to 0.15 for variable definitions

        # Check for equation setup
        equals_count = reasoning.count('=')
        has_equations = equals_count > len(calculations)  # More equals signs than just calculations
        score += 0.15 if has_equations else 0.0  # 0.15 for equation setup

        # Check for units and dimensional analysis
        units = ["meter", "kg", "second", "gram", "liter", "gallon", "dollar", "cent", "pound", "feet", "inch"]
        has_units = any(unit in reasoning.lower() for unit in units)
        score += 0.1 if has_units else 0.0  # 0.1 for using units

        # Check for mathematical structure (problem setup, work, conclusion)
        has_structure = ("given" in reasoning.lower() or "find" in reasoning.lower() or "step" in reasoning.lower()) and ("therefore" in reasoning.lower() or "so" in reasoning.lower())
        score += 0.15 if has_structure else 0.0  # 0.15 for good structure

        # Check for intermediate results tracking
        intermediate_results = len(re.findall(r'(?:we get|we have|this gives us|result is|equals)\s+\d+', reasoning.lower()))
        score += min(intermediate_results * 0.05, 0.15)  # Up to 0.15 for tracking results

        rewards.append(min(score, 1.0))  # Cap at 1.0

    return rewards

#########################################################################
# UPDATED TRAINER SETUP
#########################################################################

def get_updated_trainer_config(model, tokenizer, training_args, dataset):
    """
    Returns an updated GRPOTrainer configuration with improved reward functions.
    """
    return GRPOTrainer(
        model = model,
        processing_class = tokenizer,
        reward_funcs = [
            # Original functions with improvements
            xmlcount_reward_func,
            soft_format_reward_func,
            strict_format_reward_func,
            int_reward_func,
            correctness_reward_func,

            # New functions for better reasoning quality
            reasoning_coherence_reward,
            faithful_reasoning_reward,
            math_reasoning_quality,
        ],
        args = training_args,
        train_dataset = dataset,
    )

# Example usage:
# trainer = get_updated_trainer_config(model, tokenizer, training_args, dataset)
# trainer.train()

<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!

In [25]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm = True, # use vLLM for fast inference!
    learning_rate = 5e-6,
    adam_beta1 = 0.9,
    adam_beta2 = 0.99,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "paged_adamw_8bit",
    logging_steps = 1,
    bf16 = is_bfloat16_supported(),
    fp16 = not is_bfloat16_supported(),
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 6, # Decrease if out of memory
    max_prompt_length = 256,
    max_completion_length = 200,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 250,
    max_grad_norm = 0.1,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 6


And let's run the trainer! If you scroll up, you'll see a table of rewards. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |


In [None]:
dataset = get_natural_reasoning_questions()

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

Map:   0%|          | 0/1145824 [00:00<?, ? examples/s]

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,145,824 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 6 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (6 x 1 x 1) = 6
 "-____-"     Trainable parameters = 44,236,800/7,888,000,000 (0.56% trained)


-------------------- Question:
Given a block resting on a flat plate that executes vertical simple harmonic motion with a period of 1.5 seconds, determine the maximum amplitude of the motion such that the block remains in contact with the plate throughout the motion. Assume the acceleration due to gravity is 9.81 m/s^2. 
Answer:
[{'response': "## Step 1: Understand the conditions for the block to remain in contact with the plate\nFor the block to remain in contact with the plate throughout the motion, the acceleration of the plate must not exceed the acceleration due to gravity. This is because if the plate's acceleration exceeds gravity's, the block would lose contact with the plate due to being lifted off.\n\n## Step 2: Recall the equation for the acceleration of simple harmonic motion\nThe acceleration \\(a\\) of an object undergoing simple harmonic motion is given by \\(a = -\\omega^2x\\), where \\(\\omega\\) is the angular frequency (\\(\\omega = \\frac{2\\pi}{T}\\), with \\(T\\) 

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / soft_format_reward_func,rewards / strict_format_reward_func,rewards / int_reward_func,rewards / correctness_reward_func
1,-0.0,0.125,0.061237,200.0,0.0,0.125,0.0,0.0,0.0,0.0
2,0.0,0.025,0.061237,200.0,0.0,0.025,0.0,0.0,0.0,0.0
3,0.0,0.05,0.07746,200.0,5e-06,0.05,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,200.0,3e-06,0.0,0.0,0.0,0.0,0.0
5,0.0,0.15,0.0,200.0,5e-06,0.15,0.0,0.0,0.0,0.0
6,0.0,0.025,0.061237,200.0,3e-06,0.025,0.0,0.0,0.0,0.0
7,0.0,0.0,0.0,200.0,2e-06,0.0,0.0,0.0,0.0,0.0
8,0.0,0.15,0.0,200.0,6e-06,0.15,0.0,0.0,0.0,0.0
9,0.0,0.05,0.07746,200.0,3e-06,0.05,0.0,0.0,0.0,0.0
10,0.0,0.025,0.061237,200.0,5e-06,0.025,0.0,0.0,0.0,0.0


-------------------- Question:
Find the point on the graph of f(x) = √x that is closest to the point (4,0). To do this, differentiate the square of the distance from a point (x, √x) on the graph of f to the point (4,0), and solve for x. Then, find the corresponding y value using f(x). 
Answer:
[{'response': '## Step 1: Define the distance function\nThe distance between two points (x1, y1) and (x2, y2) is given by the formula √((x2 - x1)^2 + (y2 - y1)^2). For the points (x, √x) and (4, 0), the distance formula becomes √((x - 4)^2 + (√x - 0)^2). To simplify calculations, we will work with the square of the distance, which is (x - 4)^2 + (√x)^2.\n\n## Step 2: Simplify the square of the distance function\nThe square of the distance can be simplified as (x - 4)^2 + x. Expanding this gives x^2 - 8x + 16 + x, which further simplifies to x^2 - 7x + 16.\n\n## Step 3: Differentiate the square of the distance function\nTo find the minimum distance, we differentiate the square of the distance func

<a name="Inference"></a>
### Inference

Now let's try the model we just trained! First, let's first try the model without any GRPO trained:

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

And now with the LoRA we just trained with GRPO - we first save the LoRA first!

In [None]:
model.save_lora("grpo_saved_lora")

Now we load the LoRA and test:

In [None]:
text = tokenizer.apply_chat_template([
    {"role" : "system", "content" : SYSTEM_PROMPT},
    {"role" : "user", "content" : "Which is bigger? 9.11 or 9.9?"},
], tokenize = False, add_generation_prompt = True)

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 0.8,
    top_p = 0.95,
    max_tokens = 1024,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_saved_lora"),
)[0].outputs[0].text

output

Processed prompts: 100%|██████████| 1/1 [00:27<00:00, 27.72s/it, est. speed input: 1.70 toks/s, output: 10.03 toks/s]


'<reasoning>\nTo determine which number is bigger between 9.11 and 9.9, we should compare the two numbers digit by digit from left to right. \n\n1. First, compare the digits in the units place:\n   - Both numbers have a 9 in the units place.\n\n2. Next, compare the digits in the tenths place:\n   - The number 9.11 has a 1 in the tenths place.\n   - The number 9.9 has a 9 in the tenths place.\n\nSince 1 is less than 9, the number 9.11 is less than 9.9 based on the tenths place comparison.\n\n3. For thoroughness, consider the hundredths place:\n   - The number 9.11 has a 1 in the hundredths place.\n   - The number 9.9 can be written as 9.90, which has a 0 in the hundredths place.\n\nEven if we compare the hundredths place, 1 is greater than 0, but this is irrelevant since the comparison in the tenths place already determines that 9.11 is smaller than 9.9.\n\nTherefore, 9.9 is greater than 9.11.\n</reasoning>\n\n<answer>\n9.9 is bigger than 9.11.\n</answer>'

In [None]:
print(output)

<reasoning>
To determine which number is bigger between 9.11 and 9.9, we should compare the two numbers digit by digit from left to right. 

1. First, compare the digits in the units place:
   - Both numbers have a 9 in the units place.

2. Next, compare the digits in the tenths place:
   - The number 9.11 has a 1 in the tenths place.
   - The number 9.9 has a 9 in the tenths place.

Since 1 is less than 9, the number 9.11 is less than 9.9 based on the tenths place comparison.

3. For thoroughness, consider the hundredths place:
   - The number 9.11 has a 1 in the hundredths place.
   - The number 9.9 can be written as 9.90, which has a 0 in the hundredths place.

Even if we compare the hundredths place, 1 is greater than 0, but this is irrelevant since the comparison in the tenths place already determines that 9.11 is smaller than 9.9.

Therefore, 9.9 is greater than 9.11.
</reasoning>

<answer>
9.9 is bigger than 9.11.
</answer>


Our reasoning model is much better - it's not always correct, since we only trained it for an hour or so - it'll be better if we extend the sequence length and train for longer!

<a name="Save"></a>
### Saving to float16 for VLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False: model.save_pretrained_merged("model", tokenizer, save_method = "lora",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "lora", token = "")

### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now! We clone `llama.cpp` and we default save it to `q8_0`. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/drive/1WZDi7APtQ9VsvOrQSSC5DDtxq159j8iZ?usp=sharing)

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui)

And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs or want to keep updated with the latest LLM stuff, or need help, join projects etc, feel free to join our Discord!

Some other links:
1. Llama 3.2 Conversational notebook. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(1B_and_3B)-Conversational.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
6. See notebooks for DPO, ORPO, Continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

<div class="align-center">
  <a href="https://unsloth.ai"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
  <a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord.png" width="145"></a>
  <a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a>

  Join Discord if you need help + ⭐️ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐️
</div>
