To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">GitHub</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions in our docs [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News

Unsloth now supports Text-to-Speech (TTS) models. Read our [guide here](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning).

Read our **[Gemma 3N Guide](https://docs.unsloth.ai/basics/gemma-3n-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants, which outperform other quantization methods!

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    !pip install --no-deps unsloth vllm==0.8.5.post1

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    !pip install --no-deps unsloth vllm==0.8.5.post1
    # [NOTE] Do the below ONLY in Colab! Use [[pip install unsloth vllm]]
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt

### Unsloth

Goal: Fine-tune the reasoning model `DeepSeek-R1-0528-Qwen3-8B` with GRPO on Open R1's math dataset.

We also use `langid` for language detection. Our main goal is to force the model to generate reasoning traces in Indonesian, and we create a reward function using `langid` to check this.

In [None]:
!pip install langid -qq

  Preparing metadata (setup.py) ... done
  Building wheel for langid (setup.py) ... done


In [4]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/DeepSeek-R1-0528-Qwen3-8B",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.7, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = lora_rank*2, # *2 speeds up training
    use_gradient_checkpointing = "unsloth", # Reduces memory usage
    random_state = 3407,
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-10 13:18:14 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 07-10 13:18:14 [__init__.py:239] Automatically detected platform cuda.


Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'attn_factor'}


==((====))==  Unsloth 2025.7.1: Fast Qwen3 patching. Transformers: 4.53.1. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-40GB. Num GPUs = 1. Max memory: 39.557 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!



Unsloth: vLLM loading unsloth/deepseek-r1-0528-qwen3-8b-unsloth-bnb-4bit with actual GPU utilization = 69.2%
Unsloth: Your GPU has CUDA compute capability 8.0 with VRAM = 39.56 GB.
Unsloth: Using conservativeness = 1.0. Chunked prefill tokens = 1024. Num Sequences = 288.
Unsloth: vLLM's KV Cache can use up to 20.85 GB. Also swap space = 6 GB.



INFO 07-10 13:18:48 [config.py:717] This model supports multiple tasks: {'generate', 'reward', 'classify', 'embed', 'score'}. Defaulting to 'generate'.
INFO 07-10 13:18:48 [config.py:2003] Chunked prefill is enabled with max_num_batched_tokens=1024.
Unsloth: vLLM Bitsandbytes config using kwargs = {'load_in_8bit': False, 'load_in_4bit': True, 'bnb_4bit_compute_dtype': 'bfloat16', 'bnb_4bit_quant_storage': 'uint8', 'bnb_4bit_quant_type': 'nf4', 'bnb_4bit_use_double_quant': True, 'llm_int8_enable_fp32_cpu_offload': False, 'llm_int8_has_fp16_weight': False, 'llm_int8_skip_modules': ['lm_head', 'multi_modal_projector', 'merger', 'modality_projection', 'model.layers.33.self_attn', 'model.layers.34.self_attn', 'model.layers.1.self_attn', 'model.layers.6.self_attn', 'model.layers.34.mlp', 'model.layers.4.mlp', 'model.layers.2.mlp', 'model.layers.5.mlp', 'model.layers.6.mlp'], 'llm_int8_threshold': 6.0}


INFO 07-10 13:18:50 [core.py:58] Initializing a V1 LLM engine (v0.8.5.post1) with config: model='unsloth/deepseek-r1-0528-qwen3-8b-unsloth-bnb-4bit', speculative_config=None, tokenizer='unsloth/deepseek-r1-0528-qwen3-8b-unsloth-bnb-4bit', skip_tokenizer_init=False, tokenizer_mode=auto, revision=None, override_neuron_config=None, tokenizer_revision=None, trust_remote_code=False, dtype=torch.bfloat16, max_seq_len=1024, download_dir=None, load_format=LoadFormat.BITSANDBYTES, tensor_parallel_size=1, pipeline_parallel_size=1, disable_custom_all_reduce=False, quantization=bitsandbytes, enforce_eager=False, kv_cache_dtype=auto,  device_config=cuda:0, decoding_config=DecodingConfig(guided_decoding_backend='auto', reasoning_backend=None), observability_config=ObservabilityConfig(show_hidden_metrics=False, otlp_traces_endpoint=None, collect_model_forward_time=False, collect_model_execute_time=False), seed=0, served_model_name=unsloth/deepseek-r1-0528-qwen3-8b-unsloth-bnb-4bit, num_scheduler_step

INFO 07-10 13:19:18 [weight_utils.py:281] Time spent downloading weights for unsloth/deepseek-r1-0528-qwen3-8b-unsloth-bnb-4bit: 24.093990 seconds


INFO 07-10 13:19:23 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 07-10 13:19:24 [gpu_model_runner.py:1347] Model loading took 7.1825 GiB and 30.284388 seconds
INFO 07-10 13:19:48 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/039a8575a4/rank_0_0 for vLLM's torch.compile
INFO 07-10 13:19:48 [backends.py:430] Dynamo bytecode transform time: 23.64 s


INFO 07-10 13:19:54 [backends.py:136] Cache the graph of shape None for later use



INFO 07-10 13:21:05 [backends.py:148] Compiling a graph for general shape takes 75.03 s





INFO 07-10 13:23:37 [monitor.py:33] torch.compile takes 98.67 s in total
INFO 07-10 13:23:40 [kv_cache_utils.py:634] GPU KV cache size: 131,440 tokens
INFO 07-10 13:23:40 [kv_cache_utils.py:637] Maximum concurrency for 1,024 tokens per request: 128.36x
INFO 07-10 13:25:23 [gpu_model_runner.py:1686] Graph capturing finished in 102 secs, took 1.56 GiB
INFO 07-10 13:25:23 [core.py:159] init engine (profile, create kv cache, warmup model) took 359.18 seconds
Unsloth: Just some info: will skip parsing ['pre_feedforward_layernorm', 'post_feedforward_layernorm']


Unsloth 2025.7.1 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


### GRPO Chat Template

The DeepSeek-R1 distilled Qwen3 model has its own chat template that formats the model's inputs and outputs as a conversation, including the reasoning step. We have to use this chat template because the model was trained with it.

Let's see how our chat template behaves on an example:

In [5]:
reasoning_start = None
reasoning_end = None
user_token = None
assistant_token = None

for token in tokenizer.get_added_vocab().keys():
    if "think" in token and "/" in token:
        reasoning_end = token
    elif "think" in token:
        reasoning_start = token
    elif "user" in token:
        user_token = token
    elif "assistant" in token:
        assistant_token = token

system_prompt = \
f"""You are given a problem.
Think about the problem and provide your working out.
You must think in Bahasa Indonesia."""
system_prompt

'You are given a problem.\nThink about the problem and provide your working out.\nYou must think in Bahasa Indonesia.'

In [6]:
print(tokenizer.apply_chat_template([
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"<think>I think it's 2.2</think>2"},
    {"role" : "user", "content" : "What is 1+1?"},
    {"role" : "assistant", "content" : f"<think>I think it's 2.2</think>2"},
], tokenize = False, add_generation_prompt = True))

<｜begin▁of▁sentence｜><｜User｜>What is 1+1?<｜Assistant｜>2<｜end▁of▁sentence｜><｜User｜>What is 1+1?<｜Assistant｜>2<｜end▁of▁sentence｜><｜Assistant｜>
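In the output above, the `<think>...</think>` part of each earlier assistant turn is gone and only the final answer `2` is kept in the history. A minimal sketch of that behavior, assuming the template keeps only the text after the last `</think>` (which is what the printed output suggests); `strip_reasoning` is a hypothetical helper, not a tokenizer method:

```python
def strip_reasoning(content: str, end_tag: str = "</think>") -> str:
    """Keep only the text after the last closing reasoning tag (hypothetical helper)."""
    # If the tag is absent, split returns [content] and the text is kept unchanged
    return content.split(end_tag)[-1]

history = [
    {"role": "user", "content": "What is 1+1?"},
    {"role": "assistant", "content": "<think>I think it's 2.2</think>2"},
]
# Only assistant turns carry reasoning, so only they are rewritten
cleaned = [
    {**m, "content": strip_reasoning(m["content"])} if m["role"] == "assistant" else m
    for m in history
]
print(cleaned[1]["content"])  # 2
```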


### Data Prep
<a name="Data"></a>

We're using Hugging Face's [Open R1 Math dataset](https://huggingface.co/datasets/open-r1/DAPO-Math-17k-Processed). You can also use OpenAI's famous [GSM8K dataset](https://huggingface.co/datasets/openai/gsm8k).

In [7]:
from datasets import load_dataset
dataset = load_dataset("open-r1/DAPO-Math-17k-Processed", "en", split = "train")
dataset

Dataset({
    features: ['prompt', 'solution', 'data_source', 'source_prompt', 'ability', 'reward_model', 'extra_info'],
    num_rows: 14116
})

Let's look at the first row:

In [8]:
dataset[0]["prompt"]

'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.'

In [9]:
dataset[0]["solution"]

'34'

In GSM8K, every answer ends with a `####` marker followed by the final answer, so we extract the text after it. The Open R1 dataset's `solution` field already contains just the answer, so we can skip that step here.

In [10]:
def extract_hash_answer(text):
    # if "####" not in text: return None
    # return text.split("####")[1].strip()
    return text
extract_hash_answer(dataset[0]["solution"])

'34'
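For GSM8K, the two commented-out lines would be the active path. A sketch of that variant, renamed to the hypothetical `extract_gsm8k_answer` so it doesn't overwrite `extract_hash_answer`; the sample string is made up in GSM8K's `#### answer` style:

```python
def extract_gsm8k_answer(text):
    # GSM8K-style answers end with "#### <final answer>"; return None when the marker is missing
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

sample_answer = "Tom has 3 apples and buys 4 more.\n3 + 4 = 7\n#### 7"
print(extract_gsm8k_answer(sample_answer))  # 7
print(extract_gsm8k_answer("34"))           # None
```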

Let's map the dataset and look at the first row:

In [11]:
dataset = dataset.map(lambda x: {
    "prompt" : [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": x["prompt"]},
    ],
    "answer": extract_hash_answer(x["solution"]),
})
dataset[0]

{'prompt': [{'content': 'You are given a problem.\nThink about the problem and provide your working out.\nYou must think in Bahasa Indonesia.',
   'role': 'system'},
  {'content': 'In triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angle BAD = \\angle DAC$ and $\\angle BDC = 90^\\circ$. Suppose that $AD = 1$ and that $\\frac{BD}{CD} = \\frac{3}{2}$. If $AB + AC$ can be expressed in the form $\\frac{a\\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.',
   'role': 'user'}],
 'solution': '34',
 'data_source': 'math_dapo',
 'source_prompt': [{'content': 'Solve the following math problem step by step. The last line of your response should be of the form Answer: $Answer (without quotes) where $Answer is the answer to the problem.\n\nIn triangle $ABC$, $\\sin \\angle A = \\frac{4}{5}$ and $\\angle A < 90^\\circ$. Let $D$ be a point outside triangle $ABC$ such that $\\angl

We create a regex that captures everything after the closing reasoning tag, i.e. the final answer:

In [12]:
import re

# Capture everything after the closing reasoning tag (the final answer)
solution_end_regex = rf"{re.escape(reasoning_end)}(.*)"

match_format = re.compile(solution_end_regex, re.DOTALL)
match_format

re.compile(r'</think>(.*)', re.DOTALL|re.UNICODE)

We verify it works:

In [13]:
match_format.findall(
    "Let me think!</think>"\
    f"Hence, the solution is 2.",
)

['Hence, the solution is 2.']

In [14]:
match_format.findall(
    "<think>Let me think!</think>"\
    f"\n\nHence, the solution is 2",
)

['\n\nHence, the solution is 2']

We now create a reward function that checks the format: it awards 3 points when the completion contains the closing reasoning tag, so that `match_format` finds an answer section after it:

In [15]:
# Checks each completion for the required output format (a closing reasoning tag followed by the answer).
# Returns a list of scores: 3.0 points if the format is matched, otherwise 0.
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores
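With a conversational dataset, the GRPO trainer passes each reward function a list of completions, where each completion is itself a list of chat messages; this is why the functions read `completion[0]["content"]`. A standalone sketch of the same scoring on two toy completions, one finished and one cut off before the closing tag (names are illustrative, and the tag is assumed to be `</think>`):

```python
import re

# Standalone copy of the pattern built above, assuming the closing tag is "</think>"
answer_after_reasoning = re.compile(re.escape("</think>") + r"(.*)", re.DOTALL)

def format_reward_sketch(completions):
    # Each completion is a list of chat messages; the generated text is in [0]["content"]
    return [
        3.0 if answer_after_reasoning.search(c[0]["content"]) else 0.0
        for c in completions
    ]

toy_completions = [
    [{"role": "assistant", "content": "<think>1 + 1 = 2</think>The answer is 2."}],
    [{"role": "assistant", "content": "<think>Still thinking, ran out of tokens"}],
]
print(format_reward_sketch(toy_completions))  # [3.0, 0.0]
```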

If it fails, we still want to reward the model for following the format partially, by counting each reasoning tag:

In [16]:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Count how many keywords are seen - we penalize if too many!
        # If we see 1, then plus some points!

        # Reward exactly one opening and one closing reasoning tag; penalize missing or repeated tags
        score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0
        scores.append(score)
    return scores
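To see how these tiers play out, here is a standalone sketch of the same scoring (the helper name is hypothetical). A completion cut off before `</think>` scores -0.5, which is consistent with the constant `match_format_approximately` mean of -0.5 in the training log, where every early completion hits the 843-token limit (`clipped_ratio` 1.0):

```python
def tag_count_reward(text, start_tag="<think>", end_tag="</think>"):
    # +0.5 for exactly one occurrence of each tag, -1.0 otherwise (same tiers as above)
    score = 0.0
    score += 0.5 if text.count(start_tag) == 1 else -1.0
    score += 0.5 if text.count(end_tag) == 1 else -1.0
    return score

print(tag_count_reward("<think>reasoning</think>answer"))          # 1.0
print(tag_count_reward("<think>reasoning cut off by the limit"))   # -0.5
print(tag_count_reward("no tags at all"))                          # -2.0
```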

We want to extract the generated answer and reward or penalize it! We also give partial credit based on how close the answer is to the true one, via ratios:

In [17]:
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(-2.0)
            continue
        # Correct answer gets 5 points!
        if guess == true_answer:
            score += 5.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 3.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 2.0
                elif ratio >= 0.8 and ratio <= 1.2: score += 1.5
                else: score -= 2.5 # Penalize wrong answers
            except (ValueError, ZeroDivisionError):
                score -= 4.5 # Penalize
        scores.append(score)
    return scores
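The ratio tiers in the `else` branch can be isolated into a small, hypothetical helper to check the boundaries:

```python
def ratio_partial_credit(guess: str, true_answer: str) -> float:
    # Same tiers as check_answer's fallback branch: within 10% -> +2.0, within 20% -> +1.5,
    # any other number -> -2.5, non-numeric guess (or zero true answer) -> -4.5
    try:
        ratio = float(guess) / float(true_answer)
    except (ValueError, ZeroDivisionError):
        return -4.5
    if 0.9 <= ratio <= 1.1:
        return 2.0
    if 0.8 <= ratio <= 1.2:
        return 1.5
    return -2.5

for g in ["19", "17", "5", "twenty"]:
    print(g, ratio_partial_credit(g, "20"))
```

Because `check_answer` compares the entire text after `</think>`, a free-form answer such as "The answer is 34." will usually fail `float()` and land in the -4.5 branch.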

Sometimes the answer is not a bare number but a sentence, for example "The solution is $20"; in that case we extract 20.

We also handle commas, as in 123,456: they are stripped before the text is converted to a number.

In [18]:
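# Capture the first number (optionally negative) that starts with a digit,
# allowing "." and "," inside it, e.g. "The solution is $20" -> "20"
match_numbers = re.compile(
    r".*?[\s]{0,}([-]?\d[\d\.\,]*)",
    flags = re.MULTILINE | re.DOTALL
)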
print(match_numbers.findall("  0.34  "))
print(match_numbers.findall("  123,456  "))
print(match_numbers.findall("  -0.234  "))
print(match_numbers.findall("17"))

['0.34']
['123,456']
['-0.234']
['17']
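If you want the extraction and comma stripping in one place, a hypothetical helper could look like this. Requiring the match to start with a digit keeps stray punctuation, such as the comma in "Hence,", from being captured as the number:

```python
import re

# Hypothetical helper: the first number that starts with a digit (optionally negative),
# allowing thousands separators and a decimal part
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")

def extract_number(text):
    match = NUMBER.search(text)
    if match is None:
        return None
    # Strip thousands separators before converting
    return float(match.group(0).replace(",", ""))

print(extract_number("The solution is $20"))           # 20.0
print(extract_number("Hence, the total is 123,456."))  # 123456.0
print(extract_number("It is -0.5 degrees"))            # -0.5
print(extract_number("No digits here"))                # None
```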


Finally, we try to enforce that the thinking process is in Bahasa Indonesia. This is a simple version of the `language consistency reward` used in the DeepSeek-R1 paper.

In [19]:
import langid

def get_lang(text: str) -> str:
    if not text:
        return "und"
    lang, _ = langid.classify(text)
    return lang


print(get_lang("Hello, How are you")) # This should return en
print(get_lang("Aku berpikir kalau aku adalah kamu")) # This should return id
print(get_lang("我在这里")) # This should return zh

en
id
zh


In [20]:
import re

# The function below evaluates a list of completions (model outputs) and assigns a reward score based on the detected language of each completion.
# It is designed to encourage completions written in Bahasa Indonesia ('id'), and penalize those in English ('en'), Chinese ('zh'), or any other language.
# If a completion is malformed (e.g., missing expected structure), it assigns a strong penalty and prints a warning.

def format_and_language_reward_func(completions, **kwargs):
    scores = []

    for completion_item in completions:
        # Check if the completion item is well-formed: it should be a non-empty list, with the first element a dict containing a "content" key.
        if not completion_item or not isinstance(completion_item[0], dict) or "content" not in completion_item[0]:
            scores.append(-5.0)  # Assign a strong penalty for malformed input
            print(f"Warning: Malformed completion item, assigning default low score: {completion_item}")
            continue

        content = completion_item[0]["content"]  # Extract the text content from the completion

        lang = get_lang(content)  # Detect the language of the content using the get_lang function

        # Assign scores based on detected language:
        # Bahasa Indonesia ('id') gets the highest reward,
        # English ('en') and Chinese ('zh') get a penalty,
        # Any other language gets a stronger penalty.
        if lang == 'id':
            score = 5.0
        elif lang == 'en':
            score = -3.0
        elif lang == 'zh':
            score = -3.0
        else:
            score = -5.0

        scores.append(score)  # Add the score for this completion

    return scores  # Return the list of scores for all completions

In [21]:
prompts = [
    [{"role": "user", "content": "What is the result of (1 + 2) * 4?"}],
    [{"role": "user", "content": "What is the result of (3 + 1) * 2?"}],
]
completions = [
    [{"role": "assistant", "content": "<think>The sum of 1 and 2 is 3, which we multiply by 4 to get 12.</think><answer>(1 + 2) * 4 = 12</answer>"}],
    [{"role": "assistant", "content": "The sum of 3 and 1 is 4, which we multiply by 2 to get 8. So (3 + 1) * 2 = 8."}],
]
format_and_language_reward_func(prompts=prompts, completions=completions)

[-3.0, -3.0]
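`format_and_language_reward_func` classifies the whole completion, answer included. The DeepSeek-R1 paper describes its language consistency reward as the proportion of target-language words in the chain of thought. A sketch of that idea, restricted to the text before `</think>`; the word-level detector is a toy stand-in (a real one might wrap `langid.classify` or a lexicon, though single words are hard to classify), and all names here are hypothetical:

```python
from typing import Callable

def reasoning_language_share(
    text: str,
    is_target_word: Callable[[str], bool],
    end_tag: str = "</think>",
) -> float:
    """Fraction of words in the reasoning span judged to be in the target language."""
    # Score only what comes before the closing tag (a truncated completion is all reasoning)
    reasoning = text.split(end_tag)[0].replace("<think>", "")
    words = reasoning.split()
    if not words:
        return 0.0
    return sum(is_target_word(w) for w in words) / len(words)

# Toy word-level detector for illustration only
INDONESIAN_HINTS = {"saya", "perlu", "mencari", "jadi", "adalah", "yang", "dan"}

def toy_is_indonesian(word: str) -> bool:
    return word.lower().strip(".,!?") in INDONESIAN_HINTS

sample = "<think>Saya perlu mencari jawaban</think>The answer is 2."
print(reasoning_language_share(sample, toy_is_indonesian))  # 0.75
```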

We now prepare our main function, which prints the generated responses alongside the true answer, together with another reward function that extracts a number from the response, converts it via `float`, and checks whether it equals the true answer.

In [22]:
# `check_numbers` evaluates whether the model's generated responses contain the correct numerical answer.
# Every PRINT_EVERY_STEPS calls it also prints the question, the true answer, the model's response,
# and the number extracted from the response, for debugging.

# Module-level counters that control how often information is printed
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def check_numbers(prompts, completions, answer, **kwargs):
    """
    Evaluates model completions by extracting numbers from the responses and comparing them to the true answer.
    Prints debug information every PRINT_EVERY_STEPS calls.

    Args:
        prompts: List of prompt messages (each a list of dicts with 'content').
        completions: List of model completions (each a list of dicts with 'content').
        answer: List of true answers (as strings or numbers).
        **kwargs: Additional arguments (unused).

    Returns:
        scores: List of float scores for each completion.
    """
    # Get the question text from the last message in the first prompt
    question = prompts[0][-1]["content"]
    # Get the response text from each completion
    responses = [completion[0]["content"] for completion in completions]

    # Try to extract a number from each response using a regex (match_numbers)
    # For each response, if a number is found, extract it; otherwise, set to None
    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    # Print debug info every PRINT_EVERY_STEPS calls
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print(
            '*'*20 + f"Question:\n{question}",
            f"\nAnswer:\n{answer[0]}",
            f"\nResponse:\n{responses[0]}",
            f"\nExtracted:\n{extracted_responses[0]}"
        )
    PRINTED_TIMES += 1

    # For each extracted guess and true answer, compare numerically
    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            # If no number could be extracted, assign a penalty
            scores.append(-2.5)
            continue
        try:
            # Convert both guess and true answer to float for comparison
            true_answer = float(true_answer.strip())
            # Remove commas from guess (e.g., "1,234" -> "1234")
            guess = float(guess.strip().replace(",", ""))
            # Assign a high score if correct, else a penalty
            scores.append(3.5 if guess == true_answer else -1.5)
        except (ValueError, AttributeError):
            # If conversion fails, assign a neutral score
            scores.append(0)
            continue
    return scores
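`check_numbers` compares with exact `==` after converting to `float`. If you want to tolerate floating-point noise, one option (our assumption, not what the notebook does) is `math.isclose`; a small sketch that also strips thousands separators from both sides:

```python
import math

def numbers_match(guess: str, true_answer: str, rel_tol: float = 1e-9) -> bool:
    """Compare two numeric strings after removing thousands separators (sketch)."""
    try:
        g = float(guess.strip().replace(",", ""))
        t = float(true_answer.strip().replace(",", ""))
    except ValueError:
        return False
    # math.isclose tolerates float rounding such as 0.1 + 0.2 vs 0.3
    return math.isclose(g, t, rel_tol=rel_tol)

print(numbers_match("1,234", "1234"))               # True
print(numbers_match("0.30000000000000004", "0.3"))  # True
print(numbers_match("12", "13"))                    # False
print(numbers_match("abc", "13"))                   # False
```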

We compute the 90th percentile of the tokenized prompt lengths and keep only prompts at or below it. In other words, we drop the longest 10% of prompts so that none of the remaining ones get truncated!

In [23]:
tokenized = dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
)
print(tokenizer.decode(tokenized[0]["tokens"]))
tokenized = tokenized.map(lambda x: {"L" : len(x["tokens"])})

import numpy as np
maximum_length = int(np.quantile(tokenized["L"], 0.9))
print("Max Length = ", maximum_length)

# Filter only samples smaller than 90% max length
dataset = dataset.select(np.where(np.array(tokenized["L"]) <= maximum_length)[0])
del tokenized

<｜begin▁of▁sentence｜>You are given a problem.
Think about the problem and provide your working out.
You must think in Bahasa Indonesia.<｜User｜>In triangle $ABC$, $\sin \angle A = \frac{4}{5}$ and $\angle A < 90^\circ$. Let $D$ be a point outside triangle $ABC$ such that $\angle BAD = \angle DAC$ and $\angle BDC = 90^\circ$. Suppose that $AD = 1$ and that $\frac{BD}{CD} = \frac{3}{2}$. If $AB + AC$ can be expressed in the form $\frac{a\sqrt{b}}{c}$ where $a, b, c$ are pairwise relatively prime integers, find $a + b + c$.<｜Assistant｜>


Max Length =  180
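The cutoff above comes from linear interpolation between sorted lengths, which is NumPy's default `np.quantile` method. A dependency-free sketch on made-up lengths, showing how one unusually long prompt is filtered out (`quantile_linear` is a hypothetical name):

```python
def quantile_linear(values, q):
    """Linearly interpolated quantile, like NumPy's default method (sketch)."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q          # fractional index into the sorted list
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)

# Toy prompt lengths in tokens (made up for illustration)
prompt_lengths = [12, 40, 35, 180, 22, 57, 95, 61, 48, 30, 75]
cutoff = int(quantile_linear(prompt_lengths, 0.9))
kept = [n for n in prompt_lengths if n <= cutoff]
print(cutoff, len(kept), len(prompt_lengths))
```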


<a name="Train"></a>
### Train the model

Now set up GRPO Trainer and all configurations!
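The sequence budget is split between the prompt and the completion. Worked through with the `Max Length =  180` printed above (the variable names here are illustrative, chosen so they don't overwrite the notebook's):

```python
total_tokens = 1024        # max_seq_length
p90_prompt_tokens = 180    # "Max Length" printed above

prompt_budget = p90_prompt_tokens + 1              # max_prompt_length (+1 just in case)
completion_budget = total_tokens - prompt_budget   # max_completion_length
print(prompt_budget, completion_budget)  # 181 843
```

The 843-token completion budget is the `completions / max_length` value that appears in the training log, where every early completion runs to the limit.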

In [24]:
max_prompt_length = maximum_length + 1 # + 1 just in case!
max_completion_length = max_seq_length - max_prompt_length

from vllm import SamplingParams
vllm_sampling_params = SamplingParams(
    min_p = 0.1,
    top_p = 1.0,
    top_k = -1,
    seed = 3407,
    stop = [tokenizer.eos_token],
    include_stop_str_in_output = True,
)

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    vllm_sampling_params = vllm_sampling_params,
    temperature = 1.0,
    learning_rate = 5e-6,
    weight_decay = 0.01,
    warmup_ratio = 0.1,
    lr_scheduler_type = "linear",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 1, # Increase to 4 for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_completion_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 100,
    save_steps = 100,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",

    # For optional training + evaluation
    # fp16_full_eval = True,
    # per_device_eval_batch_size = 4,
    # eval_accumulation_steps = 1,
    # eval_strategy = "steps",
    # eval_steps = 1,
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 4


And let's run the trainer! Once training starts, you'll see a table of rewards. The goal is to see the `reward` column increase!

You might have to wait 150 to 200 steps for any action. You'll probably get 0 reward for the first 100 steps. Please be patient!

| Step | Training Loss | reward    | reward_std | completion_length | kl       |
|------|---------------|-----------|------------|-------------------|----------|
| 1    | 0.000000      | 0.125000  | 0.000000   | 200.000000        | 0.000000 |
| 2    | 0.000000      | 0.072375  | 0.248112   | 200.000000        | 0.000000 |
| 3    | 0.000000      | -0.079000 | 0.163776   | 182.500000        | 0.000005 |
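GRPO scores each generation relative to the other generations sampled for the same prompt. A sketch of that normalization, assuming mean and sample-standard-deviation scaling with a small epsilon; TRL's `GRPOTrainer` implementation (its epsilon and options such as turning off the std scaling) may differ by version:

```python
import statistics

def group_relative_advantages(rewards, eps=1e-4):
    """Scale each reward against its own group's mean and spread (GRPO-style sketch)."""
    mean = statistics.mean(rewards)
    std = statistics.stdev(rewards)  # sample standard deviation (n - 1 denominator)
    return [(r - mean) / (std + eps) for r in rewards]

# Four generations sampled for one prompt, as with num_generations = 4
print(group_relative_advantages([0.125, 0.125, 0.125, 0.125]))  # identical rewards -> all 0.0
print(group_relative_advantages([5.0, 5.0, 5.0, -3.0]))         # the outlier gets a large negative advantage
```

When every generation in a group gets the same reward (as in step 1 of the example table above, where `reward_std` is 0.0), every advantage is zero and that group contributes no policy-gradient signal, which is one reason early steps can look flat.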


In [None]:
# For optional training + evaluation
# new_dataset = dataset.train_test_split(test_size = 0.01)

trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
        format_and_language_reward_func,
    ],
    args = training_args,
    train_dataset = dataset,

    # For optional training + evaluation
    # train_dataset = new_dataset["train"],
    # eval_dataset = new_dataset["test"],
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 12,728 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 1 x 1) = 4
 "-____-"     Trainable parameters = 87,293,952 of 8,278,029,312 (1.05% trained)


********************Question:
In the diagram, each of the three identical circles touch the other two.  The circumference of each circle is 36.  What is the perimeter of the shaded region? [asy]

defaultpen(1);

path p = (1, 0){down}..{-dir(30)}dir(-60){dir(30)}..{dir(-30)}((2, 0) + dir(-120)){-dir(-30)}..{up}(1, 0)--cycle;
fill(p, gray(0.75));

draw(unitcircle);
draw(shift(2 * dir(-60)) * unitcircle);
draw(shift(2) * unitcircle);
[/asy] 
Answer:
18 
Response:
<think>
Saya perlu mencari keliling daerah yang diarsir dalam diagram. Ada tiga lingkaran identik, masing-masing dengan keliling 36. Jadi, untuk setiap lingkaran, keliling adalah 36, yang berarti panjang jari-jarinya adalah \( r = \frac{36}{2\pi} = \frac{18}{\pi} \).

Ketiga lingkaran saling bersentuhan. Jadi, ketika dua lingkaran bersentuhan, pusat mereka adalah satu sama lain. Jarak antar pusat harus sama dengan diameter lingkaran karena bersentuhan secara eksternal. Jika dua lingkaran bersentuhan, jarak antar pusat adalah \( 2


Unsloth: Will smartly offload gradients to save VRAM!


| Step | Training Loss | reward | reward_std | completions / mean_length | completions / min_length | completions / max_length | completions / clipped_ratio | completions / mean_terminated_length | completions / min_terminated_length | completions / max_terminated_length | kl | rewards / match_format_exactly / mean | rewards / match_format_exactly / std | rewards / match_format_approximately / mean | rewards / match_format_approximately / std | rewards / check_answer / mean | rewards / check_answer / std | rewards / check_numbers / mean | rewards / check_numbers / std | rewards / format_and_language_reward_func / mean | rewards / format_and_language_reward_func / std |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | -0.0 | 0.5 | 4.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | 0.0 | 0.0 | 3.0 | 4.0 |
| 2 | -0.0 | 0.5 | 4.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | 0.0 | 0.0 | 3.0 | 4.0 |
| 3 | 0.0 | -2.25 | 5.484828 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000418 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -0.75 | 0.866025 | 1.0 | 4.618802 |
| 4 | 0.0 | -5.875 | 0.75 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.001017 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -0.375 | 0.75 | -3.0 | 0.0 |
| 5 | 0.0 | 1.75 | 0.866025 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000442 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -0.75 | 0.866025 | 5.0 | 0.0 |
| 6 | 0.0 | -7.0 | 0.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.001041 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -1.5 | 0.0 | -3.0 | 0.0 |
| 7 | 0.0 | -7.0 | 0.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000855 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -1.5 | 0.0 | -3.0 | 0.0 |
| 8 | 0.0 | 1.0 | 0.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000717 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -1.5 | 0.0 | 5.0 | 0.0 |
| 9 | 0.0 | -7.0 | 0.0 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000435 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -1.5 | 0.0 | -3.0 | 0.0 |
| 10 | 0.0 | -0.625 | 4.308422 | 843.0 | 843.0 | 843.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.000617 | 0.0 | 0.0 | -0.5 | 0.0 | -2.0 | 0.0 | -1.125 | 0.75 | 3.0 | 4.0 |
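
Reading the training log: in every row, the `reward` column equals the sum of the five per-function `/ mean` columns. This fits the trainer combining reward functions with equal weights. TRL's `GRPOTrainer`, whose metric names this log matches, uses equal weights by default when no reward weights are given. Also, `completions / clipped_ratio` is 1.0 in these first 10 steps, meaning every completion was cut off at the 843-token limit. The sketch below checks the summation on rows copied from the log. The helper `total_reward` is ours, for illustration only.

```python
# Rows copied from the training log above: (step, logged reward, per-function reward means).
# Order of the means: match_format_exactly, match_format_approximately,
# check_answer, check_numbers, format_and_language_reward_func.
LOG_ROWS = [
    (1, 0.5, [0.0, -0.5, -2.0, 0.0, 3.0]),
    (3, -2.25, [0.0, -0.5, -2.0, -0.75, 1.0]),
    (4, -5.875, [0.0, -0.5, -2.0, -0.375, -3.0]),
    (5, 1.75, [0.0, -0.5, -2.0, -0.75, 5.0]),
    (6, -7.0, [0.0, -0.5, -2.0, -1.5, -3.0]),
    (8, 1.0, [0.0, -0.5, -2.0, -1.5, 5.0]),
    (10, -0.625, [0.0, -0.5, -2.0, -1.125, 3.0]),
]

def total_reward(means, weights=None):
    """Weighted sum of per-function mean rewards (weight 1.0 for each function if none given)."""
    if weights is None:
        weights = [1.0] * len(means)
    return sum(w * m for w, m in zip(weights, means))

for step, logged, means in LOG_ROWS:
    print(f"step {step:2d}: logged={logged:+.3f}  sum of means={total_reward(means):+.3f}")
```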


********************Question:
In a class of 20 students, all but 4 of the students put their names on a typed assignment. If the teacher randomly guesses, what is the probability that she correctly guesses which paper belongs to each of the four remaining students? Express your answer as a common fraction.The answer is in the form \frac{m}{n}, where gcd(m, n) = 1. Please provide the value of m + n. 
Answer:
25 
Response:
<think>
The problem states there are 20 students, and all but 4 put their names on their assignments. So, 16 students have their names on their papers, and 4 do not. The teacher guesses randomly, and I need to find the probability that she correctly identifies which paper belongs to each of the four students who didn't put their names on.

First, I need to understand what's being asked. The teacher is guessing which paper is whose for these four students. But there are 20 papers in total, and she's trying to assign names to the papers that don't have names, I suppose.
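
You can check the reference answer of 25 directly. The 16 named papers are already identified, so only the 4 unnamed papers need to be matched to the 4 remaining students. A random guess picks one of 4! = 24 equally likely matchings, so the probability is 1/24 and m + n = 1 + 24 = 25. Below is a small sketch of that arithmetic. The helper name `prob_all_correct` is ours, for illustration.

```python
from fractions import Fraction
from math import factorial

def prob_all_correct(unnamed: int) -> Fraction:
    """Probability that a uniformly random matching of `unnamed` papers to
    `unnamed` students is entirely correct: exactly 1 of n! matchings works."""
    return Fraction(1, factorial(unnamed))

p = prob_all_correct(4)
m, n = p.numerator, p.denominator  # Fraction is kept in lowest terms, so gcd(m, n) = 1
print(p, m + n)  # 1/24 25
```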



<a name="Inference"></a>
### Inference
Now let's try the model we just trained! First, let's try the model without any GRPO training:

In [None]:
text = "What is the sqrt of 101?"

from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 1024,
)
output = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

' | Socratic\nWhat is the sqrt of 101?\nAlgebra\nQuestion\nVincent A.\nAnswer\n10.049875, or as an irrational number, it is not integer, I see what you mean though.\nExplanation:\nsqrt(100) = 10, sqrt(121)=11, so sqrt(101) is between 10 and 11, and there is not integer between 10 and 11, so it is irrational.\nWe can leave it as sqrt(101) or give the approximate value.\nThe question asks "what is the sqrt of 101", and it is a math problem, so probably they want to know if it is integer or not.\nOr perhaps do some calculations.\nI think the intended answer is that it is irrational and not an integer, but since 10^2=100 and 11^2=121, and 10^2 <101<11^2, so no integer square.\nI could also use calculator, but I think since it\'s a math problem, perhaps we need to show'
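
The generation cells in this section use `SamplingParams(temperature = 1.0, top_k = 50, ...)`. Conceptually, top-k sampling does four things. It keeps only the k most likely next tokens, divides their logits by the temperature, renormalizes them with a softmax, and draws one token. The pure-Python sketch below illustrates this on a toy logit list. It is an assumption-level illustration, not vLLM's implementation, which works on batched GPU tensors and supports many more options. The function name `sample_top_k` is hypothetical.

```python
import math
import random

def sample_top_k(logits, top_k=50, temperature=1.0, rng=None):
    """Return the index of one sampled token using temperature + top-k sampling.

    Illustrative only: real engines such as vLLM do this on GPU tensors.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0 for sampling")
    rng = rng or random.Random()
    # Keep the indices of the k largest logits (all of them if k >= vocab size).
    k = min(top_k, len(logits))
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature scaling: < 1.0 sharpens, > 1.0 flattens the distribution.
    scaled = [logits[i] / temperature for i in top]
    # Numerically stable softmax over the surviving candidates.
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]
    total = sum(weights)
    probs = [w / total for w in weights]
    return rng.choices(top, weights=probs, k=1)[0]

toy_logits = [2.0, 0.5, 3.1, -1.0, 1.7]
rng = random.Random(3407)
draws = [sample_top_k(toy_logits, top_k=2, temperature=1.0, rng=rng) for _ in range(1000)]
print({i: draws.count(i) for i in sorted(set(draws))})  # keys are only 0 and 2: the two highest logits
```

With `temperature = 1.0` the logits are used unchanged, so `top_k = 50` is what limits the candidates. This is one reason repeated runs of the same prompt can produce different answers.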

And now let's try the LoRA we just trained with GRPO. First, we save the LoRA!

In [None]:
model.save_lora("grpo_lora")

Unrecognized keys in `rope_scaling` for 'rope_type'='yarn': {'attn_factor'}


Let's verify the LoRA was actually trained!

In [None]:
from safetensors import safe_open

with safe_open("grpo_lora/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero: LoRA initializes B to zeros,
    # so an all-zero tensor means that weight was never updated
    for key in f.keys():
        tensor = f.get_tensor(key)
        n_zeros = (tensor == 0).sum()
        assert n_zeros.item() != tensor.numel(), f"{key} is all zeros"
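
Here is why this check works. In the usual LoRA initialization, which is the default in Hugging Face PEFT, the `A` matrix is random and `B` starts at zero. The adapter's update ΔW = (alpha / r) · B·A is therefore exactly zero before training, so an adapter that GRPO never updated would still have all-zero `B` tensors. The toy sketch below uses tiny hand-written matrices stored as plain Python lists. The values and helper names are hypothetical. It shows the update vanishing while `B` is zero and appearing once `B` changes.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]

def lora_delta(B, A, alpha, r):
    """LoRA weight update: delta_W = (alpha / r) * B @ A."""
    scale = alpha / r
    return [[scale * v for v in row] for row in matmul(B, A)]

def is_all_zero(M):
    """True if every entry of the matrix is zero."""
    return all(v == 0 for row in M for v in row)

r, alpha = 2, 4
A = [[0.3, -0.1, 0.2], [0.05, 0.4, -0.2]]  # r x d_in, random-looking (hypothetical values)
B_init = [[0.0, 0.0], [0.0, 0.0]]          # d_out x r, zero at initialization
B_trained = [[0.1, 0.0], [0.0, -0.2]]      # hypothetical values after training

print(is_all_zero(lora_delta(B_init, A, alpha, r)))     # True: untrained adapter is a no-op
print(is_all_zero(lora_delta(B_trained, A, alpha, r)))  # False: training changed the model
```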

Now we load the LoRA and test it. We first test without our custom system prompt, so the LoRA should have little or no effect on the model's original reasoning ability:

In [None]:
messages = [
    {"role": "user",   "content": "Solve (x + 2)^2 = 0"},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 2048,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_lora"),
)[0].outputs[0].text

output

"<think>\nI have this equation: (x + 2)² = 0. It looks simple, but I need to solve for x. Since it's a squared term equal to zero, that means the thing inside the parentheses must be zero because only zero squared is zero.\n\nSo, if (x + 2)² = 0, then x + 2 must be equal to zero. Because if x + 2 were anything else, say 1, squared is 1, which is not zero. Or -1, squared is also 1, not zero. So only when x + 2 is zero, the square is zero.\n\nSo, x + 2 = 0, which means x = -2.\n\nI think that's it. Let me verify by plugging it back into the equation.\n\nx = -2, so ( -2 + 2 )² = (0)² = 0, which equals 0. Perfect.\n\nI recall that in algebra, this is related to the zero product property or something. Basically, if a product is zero, then one of the factors must be zero. Here, it's not a product, but"

Next, let's test with our system prompt, which should make the model reason in the new language (Indonesian):

In [None]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "Solve (x + 2)^2 = 0"},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 2048,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = model.load_lora("grpo_lora"),
)[0].outputs[0].text

output

"<think>\nI need to solve the equation (x + 2)^2 = 0. It's a quadratic equation, but it's squared, so it might be simpler. Let me think about this in Bahasa Indonesia first.\n\nPermulaan: Saya diberikan persamaan (x + 2)^2 = 0, dan saya harus menyelesaikannya. Saya perlu mencari nilai x yang memenuhi persamaan ini.\n\nSaya mulai dengan merasakan apa itu persamaan. Ini adalah persamaan kuadrat karena ada eksponen 2, tapi tidak semua persamaan kuadrat memiliki dua solusi; beberapa mungkin memiliki solusi ganda.\n\nSaya tahu bahwa jika sesuatu dikalikan dengan dirinya sendiri dan hasilnya nol, maka satu di antaranya harus nol. Jadi, untuk (x + 2)^2 = 0, itu berarti x + 2 harus sama dengan 0, karena jika x + 2 tidak"

Let's compare the results with the system prompt but without our LoRA:

In [None]:
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user",   "content": "Solve (x + 2)^2 = 0"},
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)
from vllm import SamplingParams
sampling_params = SamplingParams(
    temperature = 1.0,
    top_k = 50,
    max_tokens = 2048,
)
output = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = None,
)[0].outputs[0].text

output

'<think>\nBaik, mari kita selesaikan persamaan kuadrat sederhana ini. Soalnya adalah: Solve (x + 2)^2 = 0.\n\nPertama, saya perlu memahami apa yang diminta. Ini adalah persamaan kuadrat yang diatur dalam bentuk kuadrat. Saya harus mencari nilai dari x yang memenuhi persamaan tersebut.\n\nSaya tahu bahwa ketika suatu bilangan kuadrat sama dengan nol, itu berarti bilangan tersebut adalah nol. Jadi, dalam hal ini, (x + 2)^2 = 0 berarti kuadrat dari (x + 2) adalah nol. Maka, untuk kuadrat suatu bilangan sama dengan nol, bilangan tersebut haruslah nol itu sendiri.\n\nOleh karena itu, (x + 2) haruslah sama dengan nol. Jadi, x + 2 = 0.\n\nSekarang, untuk mencari x, saya'

Let's take 20 samples and compare how often the model responds in the correct language (Indonesian) with and without our LoRA:

In [None]:
sample_dataset = dataset.shuffle(seed = 3407).select(range(20))
sample_dataset

Dataset({
    features: ['prompt', 'solution', 'data_source', 'source_prompt', 'ability', 'reward_model', 'extra_info', 'answer'],
    num_rows: 20
})

In [None]:
with_lora_id_count = 0
without_lora_id_count = 0
n_samples = len(sample_dataset)
lora_request = model.load_lora("grpo_lora")  # load the adapter once and reuse it for every sample

print(f"Comparing language usage with and without LoRA on {n_samples} samples:")
print("=" * 60)

for i, sample in enumerate(sample_dataset):
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": sample["prompt"][1]["content"]},
    ]

    text = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )

    # Same prompt and sampling settings: once with the GRPO LoRA, once with the base model
    output_with_lora = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=lora_request,
    )[0].outputs[0].text

    output_without_lora = model.fast_generate(
        text,
        sampling_params=sampling_params,
        lora_request=None,
    )[0].outputs[0].text

    # get_lang (defined earlier in the notebook) returns a language code; 'id' is Indonesian
    lang_with_lora = get_lang(output_with_lora)
    lang_without_lora = get_lang(output_without_lora)

    if lang_with_lora == 'id':
        with_lora_id_count += 1
    if lang_without_lora == 'id':
        without_lora_id_count += 1

    # Print progress every 5 samples
    if (i + 1) % 5 == 0:
        print(f"Processed {i + 1}/{n_samples} samples...")

print("\n" + "=" * 60)
print("RESULTS:")
print(f"With LoRA - Indonesian responses: {with_lora_id_count}/{n_samples} ({with_lora_id_count / n_samples * 100:.1f}%)")
print(f"Without LoRA - Indonesian responses: {without_lora_id_count}/{n_samples} ({without_lora_id_count / n_samples * 100:.1f}%)")
print(f"Improvement: {with_lora_id_count - without_lora_id_count:+d} Indonesian responses with LoRA")

Comparing language usage with and without LoRA on 20 samples:

Processed 5/20 samples...

Processed 10/20 samples...

Processed 15/20 samples...

Processed 20/20 samples...

RESULTS:
With LoRA - Indonesian responses: 16/20 (80.0%)
Without LoRA - Indonesian responses: 9/20 (45.0%)
Improvement: +7 Indonesian responses with LoRA
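
These percentages carry considerable sampling noise, because there are only 20 prompts and generation samples at `temperature = 1.0`. As a rough gauge, the sketch below computes 95% Wilson score intervals for the two reported counts. It treats the two runs as independent, which ignores that they share the same 20 prompts, so it is only an approximation. A paired comparison of per-prompt results would be more careful. The function name `wilson_interval` is ours.

```python
import math

def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Wilson score interval for a binomial proportion (z=1.96 gives ~95%)."""
    if n <= 0:
        raise ValueError("n must be positive")
    p_hat = successes / n
    denom = 1 + z**2 / n
    center = (p_hat + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p_hat * (1 - p_hat) / n + z**2 / (4 * n**2)) / denom
    return center - half, center + half

# Counts reported in the RESULTS output above
for label, k in [("With LoRA", 16), ("Without LoRA", 9)]:
    lo, hi = wilson_interval(k, 20)
    print(f"{label}: {k}/20, 95% CI [{lo:.2f}, {hi:.2f}]")
```

For 16/20 and 9/20 this gives roughly [0.58, 0.92] and [0.26, 0.66]. The LoRA run is clearly ahead, but 20 samples are too few to pin down how large the gain is.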


With the LoRA, the model responds in Indonesian far more often (16/20 vs. 9/20 samples). It's not always correct, since we only trained it for an hour or so; it should get better if we extend the sequence length and train for longer!

<a name="Save"></a>
### Saving to float16 for vLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.

In [None]:
# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    model.save_pretrained("model")
    tokenizer.save_pretrained("model")
if False:
    model.push_to_hub("hf/model", token = "")
    tokenizer.push_to_hub("hf/model", token = "")


### GGUF / llama.cpp Conversion
We now support saving to `GGUF` / `llama.cpp` natively! We clone `llama.cpp` and save to `q8_0` by default. We allow all methods like `q4_k_m`. Use `save_pretrained_gguf` for local saving and `push_to_hub_gguf` for uploading to HF.

Some supported quant methods (full list on our [Wiki page](https://github.com/unslothai/unsloth/wiki#gguf-quantization-options)):
* `q8_0` - Fast conversion. High resource use, but generally acceptable.
* `q4_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q4_K.
* `q5_k_m` - Recommended. Uses Q6_K for half of the attention.wv and feed_forward.w2 tensors, else Q5_K.
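
To get a feel for the size trade-off between these options, a back-of-the-envelope estimate is parameters × bits-per-weight / 8. The bits-per-weight figures below are the nominal block sizes of llama.cpp's quant types. For example, `Q8_0` stores 32 int8 weights plus one fp16 scale, which works out to 8.5 bpw. `Q4_K` is 4.5 bpw, `Q5_K` is 5.5 bpw and `Q6_K` is 6.5625 bpw. Real files come out somewhat larger than this estimate. The `_m` variants mix in `Q6_K` for some tensors, and metadata and other tensor choices also shift the total. The 8B parameter count is a hypothetical example, not this notebook's model.

```python
# Nominal bits per weight for llama.cpp quant block types (approximate;
# k-quant "_M" variants mix in Q6_K for some tensors, so files end up a bit larger).
BITS_PER_WEIGHT = {
    "f16": 16.0,
    "q8_0": 8.5,     # 32 x int8 + one fp16 scale per block of 32 -> 34 bytes / 32 weights
    "q6_k": 6.5625,
    "q5_k": 5.5,
    "q4_k": 4.5,
}

def estimated_size_gb(n_params: float, quant: str) -> float:
    """Rough file size in GB (1e9 bytes), ignoring metadata and mixed-precision tensors."""
    return n_params * BITS_PER_WEIGHT[quant] / 8 / 1e9

n_params = 8e9  # hypothetical 8B-parameter model
for quant in ["f16", "q8_0", "q5_k", "q4_k"]:
    print(f"{quant:>5}: ~{estimated_size_gb(n_params, quant):.1f} GB")
```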

[**NEW**] To finetune and auto export to Ollama, try our [Ollama notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)

In [None]:
# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",],
        token = "",
    )

Now, use the `model-unsloth.gguf` file or `model-unsloth-Q4_K_M.gguf` file in llama.cpp or a UI-based system like Jan or Open WebUI. You can install Jan [here](https://github.com/janhq/jan) and Open WebUI [here](https://github.com/open-webui/open-webui).
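
Before loading a file into one of these tools, a quick sanity check is to read its header. The GGUF specification in the ggml repository defines the layout. A GGUF file (version 2 and later) begins with the 4-byte magic `GGUF`, followed by a little-endian `uint32` version and then two `uint64` counts for tensors and metadata key/value pairs. The sketch below reads just those fields. The function name is ours, and you should adjust the path to wherever `save_pretrained_gguf` wrote the file.

```python
import struct

def read_gguf_header(path):
    """Return (version, tensor_count, metadata_kv_count) from a GGUF file header.

    Raises ValueError if the file does not start with the GGUF magic bytes.
    """
    with open(path, "rb") as f:
        header = f.read(4 + 4 + 8 + 8)
    if len(header) < 24 or header[:4] != b"GGUF":
        raise ValueError(f"{path} does not look like a GGUF file")
    # "<IQQ": little-endian uint32 version, uint64 tensor count, uint64 metadata KV count
    version, n_tensors, n_kv = struct.unpack("<IQQ", header[4:24])
    return version, n_tensors, n_kv

# Usage (path is an assumption; point it at your saved file):
# print(read_gguf_header("model-unsloth-Q4_K_M.gguf"))
```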

And we're done! If you have any questions on Unsloth, we have a [Discord](https://discord.gg/unsloth) channel! If you find any bugs, want to keep up with the latest LLM news, need help, or want to join projects, feel free to join our Discord!

Some other links:
1. Train your own reasoning model - Llama GRPO notebook [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.1_(8B)-GRPO.ipynb)
2. Saving finetunes to Ollama. [Free notebook](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3_(8B)-Ollama.ipynb)
3. Llama 3.2 Vision finetuning - Radiography use case. [Free Colab](https://colab.research.google.com/github/unslothai/notebooks/blob/main/nb/Llama3.2_(11B)-Vision.ipynb)
4. See notebooks for DPO, ORPO, continued pretraining, conversational finetuning and more on our [documentation](https://docs.unsloth.ai/get-started/unsloth-notebooks)!

