In [1]:
from unsloth import FastLanguageModel
import torch

from huggingface_hub import snapshot_download

# NOTE(review): this pre-fetches "bert-base-uncased" into /workspace/cache, but no
# other visible cell uses that model or that cache directory — confirm it is needed.
snapshot_download(repo_id="bert-base-uncased", cache_dir="/workspace/cache")

# Shared hyper-parameters. max_seq_length is reused by the GRPO config cell to
# derive the completion budget; lora_rank is reused for max_lora_rank, r and lora_alpha.
max_seq_length = 2048 # Can increase for longer reasoning traces
lora_rank = 64 # Larger rank = smarter, but slower

# Load the base model and its tokenizer through Unsloth.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.2-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = False, # False for LoRA 16bit
    fast_inference = True, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

# Wrap the base model with LoRA adapters on the attention (q/k/v/o) and MLP
# (gate/up/down) projections; `model` is rebound to the wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank, # alpha is set equal to the rank
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407, # fixed seed passed to the adapter setup
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 07-06 02:35:12 [importing.py:53] Triton module has been replaced with a placeholder.
INFO 07-06 02:35:12 [__init__.py:239] Automatically detected platform cuda.
Standard import failed for UnslothORPOTrainer: No module named 'UnslothORPOTrainer'. Using tempfile instead!
==((====))==  Unsloth 2025.6.12: Fast Llama patching. Transformers: 4.53.1. vLLM: 0.8.5.post1.
   \\   /|    NVIDIA A100-SXM4-80GB. Num GPUs = 1. Max memory: 79.138 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: vLLM loading unsloth/Llama-3.2-3B-Instruct with actual GPU utilization = 59.66%
Unsloth: Y

Loading safetensors checkpoint shards:   0% Completed | 0/2 [00:00<?, ?it/s]


INFO 07-06 02:35:32 [loader.py:458] Loading weights took 1.17 seconds
INFO 07-06 02:35:32 [punica_selector.py:18] Using PunicaWrapperGPU.
INFO 07-06 02:35:33 [gpu_model_runner.py:1347] Model loading took 6.2472 GiB and 2.335246 seconds
INFO 07-06 02:35:56 [backends.py:420] Using cache directory: /root/.cache/vllm/torch_compile_cache/487b710aa9/rank_0_0 for vLLM's torch.compile
INFO 07-06 02:35:56 [backends.py:430] Dynamo bytecode transform time: 11.28 s


Inductor Compilation: 100%|██████████| 6/6 [00:00<00:00,  7.80it/s, triton_poi_fused_cat_5]                                     


INFO 07-06 02:36:00 [backends.py:136] Cache the graph of shape None for later use


Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00,  8.88it/s, triton_poi_fused_cat_7]                            
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 105.08it/s, triton_poi_fused_cat_7]                           
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 98.18it/s, triton_poi_fused_cat_7]                            
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 91.49it/s, triton_poi_fused_cat_7]                            
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 111.09it/s, triton_poi_fused_cat_7]                           
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 106.89it/s, triton_poi_fused_cat_7]                           
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 106.42it/s, triton_poi_fused_cat_7]                           
Inductor Compilation: 100%|██████████| 8/8 [00:00<00:00, 102.90it/s, triton_poi_fused_cat_7]                           
Inductor Compilation: 100%|██████████| 8

INFO 07-06 02:36:34 [backends.py:148] Compiling a graph for general shape takes 36.61 s





INFO 07-06 02:38:10 [monitor.py:33] torch.compile takes 47.88 s in total
INFO 07-06 02:38:14 [kv_cache_utils.py:634] GPU KV cache size: 368,208 tokens
INFO 07-06 02:38:14 [kv_cache_utils.py:637] Maximum concurrency for 2,048 tokens per request: 179.79x


Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7a9f2c6a5290>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


KeyboardInterrupt: 

In [None]:
from datasets import load_dataset
# Load the "train" split of the "main" configuration of the reasoning-traces
# dataset. Later cells read its "prompt" and "response" columns.
dataset = load_dataset("tech-tao/my-reasoning-traces-10k", "main", split = "train")
dataset

In [None]:
# Peek at the first example's raw prompt.
dataset[0]["prompt"]

In [None]:
# Peek at the first example's raw response (the text the "####" answer is parsed from).
dataset[0]["response"]

In [None]:
def extract_hash_answer(text):
    """Return the text that follows the first "####" marker, stripped.

    Only the segment between the first and (if present) second marker is
    returned. Returns None when `text` contains no "####" marker at all.
    """
    segments = text.split("####")
    # A single segment means the marker never occurred.
    if len(segments) < 2:
        return None
    return segments[1].strip()
extract_hash_answer(dataset[0]["response"])

In [None]:
# Tags the model is asked to wrap its reasoning and its final answer in.
reasoning_start, reasoning_end = "<think>", "</think>"
answer_start, answer_end = "<answer>", "</answer>"

# System message spelling out the expected tag layout, one instruction per line.
system_prompt = "\n".join([
    "You are a thoughtful assistant.",
    "You will first think systematically before answering.",
    f"Place your thought process between {reasoning_start} and {reasoning_end}.",
    f"Then, provide your answer between {answer_start} and {answer_end}",
])
system_prompt

In [None]:
def _to_chat_example(row):
    """Build the chat-style prompt and the reference answer for one dataset row.

    The prompt becomes a two-message conversation (system prompt + the row's
    original prompt); the answer is whatever follows "####" in the row's
    response, or None when the marker is absent.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user",   "content": row["prompt"]},
    ]
    reference = extract_hash_answer(row["response"])
    return {"prompt": conversation, "answer": reference}


# Rebinds `dataset`: "prompt" is overwritten with the conversation and an
# "answer" column is added.
dataset = dataset.map(_to_chat_example)
dataset[0]

In [None]:
import re

# Template a well-formed response must follow:
#   optional whitespace, one reasoning block, any filler, one answer block
#   (its content is capture group 1), optional whitespace.
# The markers go through re.escape so they are always matched literally, even if
# they are later changed to text that contains regex metacharacters.
# DOTALL lets `.` span newlines; with MULTILINE, `^` / `$` anchor at line
# boundaries rather than only at the very start / end of the response.
match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{re.escape(reasoning_start)}.+?{re.escape(reasoning_end)}.*?"\
    rf"{re.escape(answer_start)}(.+?){re.escape(answer_end)}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

# Sanity check: a minimal well-formed response should produce a match object.
match_format.search(
    "<think>Let me think!</think>"\
    "<answer>2</answer>",
)

In [None]:
def match_format_exactly(completions, **kwargs):
    """Reward function: 3.0 for each completion that matches `match_format`, else 0.

    Each completion is a list of chat messages; only the first message's
    "content" is inspected. Returns one score per completion, in order.
    """
    rewards = []
    for candidate in completions:
        reply = candidate[0]["content"]
        follows_template = match_format.search(reply) is not None
        rewards.append(3.0 if follows_template else 0)
    return rewards

def match_format_approximately(completions, **kwargs):
    """Reward function giving partial credit per tag that appears exactly once.

    For each of the four tags (reasoning start/end, answer start/end) the
    completion earns +0.5 when the tag occurs exactly once and -1.0 otherwise
    (missing or repeated), so scores range from -4.0 to +2.0.

    Each completion is a list of chat messages; only the first message's
    "content" is inspected. Returns one score per completion, in order.
    """
    # The answer tags are defined as answer_start / answer_end in this notebook;
    # the previous solution_start / solution_end names did not exist.
    expected_tags = (reasoning_start, reasoning_end, answer_start, answer_end)

    scores = []
    for completion in completions:
        response = completion[0]["content"]
        score = 0
        for tag in expected_tags:
            # Exactly one occurrence is rewarded; zero or several are penalized.
            score += 0.5 if response.count(tag) == 1 else -1.0
        scores.append(score)
    return scores

def check_answer(prompts, completions, answer, **kwargs):
    """Reward function comparing the extracted answer with the reference answer.

    The answer is taken from capture group 1 of `match_format`. Per completion:
      *  0    when the response does not match the format, or the reference
              answer is None (no "####" marker was found in the source row)
      * +3.0  exact string match
      * +1.5  match after stripping surrounding whitespace
      * +1.0  both parse as numbers and guess / truth is within [0.9, 1.1]
      * +0.5  ... within [0.8, 1.2]
      * -1.5  otherwise (wrong, non-numeric, or division by zero)

    `prompts` is accepted for interface compatibility but not used.
    Returns one score per completion, in order.
    """
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        found.group(1) if (found := match_format.search(r)) is not None else None
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        # Nothing to grade: no parsable answer, or no reference to grade against.
        if guess is None or true_answer is None:
            scores.append(0)
            continue

        score = 0
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Match if spaces are seen, but less reward
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # We also reward it if the answer is close via ratios!
            # Ie if the answer is within some range, reward it!
            try:
                ratio = float(guess) / float(true_answer)
            except (ValueError, ZeroDivisionError):
                # Non-numeric text or a zero reference: treat as a wrong answer.
                score -= 1.5
            else:
                if 0.9 <= ratio <= 1.1:
                    score += 1.0
                elif 0.8 <= ratio <= 1.2:
                    score += 0.5
                else:
                    score -= 1.5 # Penalize wrong answers
        scores.append(score)
    return scores

In [None]:
# Lenient extractor: the first run of digits / dots / commas that follows the
# answer-start tag (the closing tag is not required). The tag is escaped so it
# is always matched literally. The tag variable is answer_start; the previously
# referenced solution_start is not defined anywhere in this notebook.
match_numbers = re.compile(
    re.escape(answer_start) + r".*?([\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
print(match_numbers.findall("<answer>  0.34  </answer>"))
print(match_numbers.findall("<answer>  123,456  </answer>"))


def check_numbers(prompts, completions, answer, **kwargs):
    """Reward function: 1.5 when the first number after the answer tag equals the reference.

    The number is extracted with `match_numbers`; thousands separators (",")
    are removed from both sides before converting to float. A completion
    scores 0.0 when no number is found, the reference answer is None, or
    either side cannot be parsed as a number.

    `prompts` is accepted for interface compatibility but not used.
    Returns one score per completion, in order.
    """
    scores = []
    for completion, true_answer in zip(completions, answer):
        found = match_numbers.search(completion[0]["content"])
        if found is None or true_answer is None:
            scores.append(0.0)
            continue
        try:
            guess_value = float(found.group(1).strip().replace(",", ""))
            true_value = float(true_answer.strip().replace(",", ""))
        except ValueError:
            # e.g. "1.2.3" or a non-numeric reference answer
            scores.append(0.0)
            continue
        scores.append(1.5 if guess_value == true_value else 0.0)
    return scores

## Fine-tune the model with GRPO

In [None]:
# Token budget reserved for the prompt; the remainder of max_seq_length is
# given to the completion (see max_completion_length below).
max_prompt_length = 1024

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # 4 accumulation steps for smoother training
    num_generations = 4, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    max_completion_length = max_seq_length - max_prompt_length, # prompt + completion fit in max_seq_length
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 500,
    save_steps = 250,
    max_grad_norm = 1.0,
    report_to = "none", # Can use Weights & Biases
    output_dir = "outputs",
)

# NOTE: every function listed in reward_funcs must already be defined by an
# earlier cell, otherwise constructing the trainer raises NameError.
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

## Inference

In [None]:
from vllm import SamplingParams

# Baseline sample: user question only (no system prompt) and no LoRA adapter
# attached to the request.
baseline_messages = [
    {"role": "user", "content": "What is the sqrt of 101?"},
]
text = tokenizer.apply_chat_template(
    baseline_messages,
    tokenize = False,
    add_generation_prompt = True,
)

sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens = 1024)

# One prompt in, so take the first request's first candidate.
generations = model.fast_generate(
    [text],
    sampling_params = sampling_params,
    lora_request = None,
)
output = generations[0].outputs[0].text

output

In [None]:
# Write the LoRA adapter to ./grpo_saved_lora; the verification and inference
# cells below read it back from this directory.
model.save_lora("grpo_saved_lora")

### Verify LoRA

In [None]:
from safetensors import safe_open

# Sanity-check the saved adapter: every stored tensor must contain at least one
# non-zero element. The element *count* of zeros is compared with the tensor
# size; comparing a zero *fraction* with the size (as before) could never fail
# for tensors with more than one element, so the check was vacuous.
with safe_open("grpo_saved_lora/adapter_model.safetensors", framework = "pt") as f:
    # Verify both A and B are non zero
    for key in f.keys():
        tensor = f.get_tensor(key)
        n_total = tensor.numel()
        n_zeros = int((tensor == 0).sum().item())
        assert n_zeros < n_total, (
            f"LoRA tensor {key!r} is entirely zero ({n_zeros}/{n_total} elements)"
        )


from vllm import SamplingParams

# Same question as the baseline sample, now with the system prompt and with the
# saved LoRA adapter attached to the request.
messages = [
    dict(role = "system", content = system_prompt),
    dict(role = "user",   content = "What is the sqrt of 101?"),
]

text = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt = True, # Must add for generation
    tokenize = False,
)

sampling_params = SamplingParams(temperature = 0.8, top_p = 0.95, max_tokens = 1024)

# Load the adapter saved earlier and sample one completion with it.
adapter_request = model.load_lora("grpo_saved_lora")
generations = model.fast_generate(
    text,
    sampling_params = sampling_params,
    lora_request = adapter_request,
)
output = generations[0].outputs[0].text

output

### Saving to float16 for vLLM

In [None]:
# Export options. Every call is guarded by `if False:` so nothing runs until the
# guard is flipped to True. "hf/model" and the empty token strings are
# placeholders to fill in before pushing.

# Merge to 16bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_16bit", token = "")

# Merge to 4bit
if False: model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    # Local save of adapter + tokenizer into ./model
    model.save_pretrained("model")
    tokenizer.save_pretrained("model")
if False:
    # Push adapter + tokenizer to the Hub repo
    model.push_to_hub("hf/model", token = "")
    tokenizer.push_to_hub("hf/model", token = "")

### GGUF Conversion

In [None]:
# GGUF export options. Every call is guarded by `if False:` so nothing runs
# until the guard is flipped to True. "hf/model" and the empty token strings
# are placeholders to fill in before pushing.

# Save to 8bit Q8_0
if False: model.save_pretrained_gguf("model", tokenizer,)
# Remember to go to https://huggingface.co/settings/tokens for a token!
# And change hf to your username!
if False: model.push_to_hub_gguf("hf/model", tokenizer, token = "")

# Save to 16bit GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "f16")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "f16", token = "")

# Save to q4_k_m GGUF
if False: model.save_pretrained_gguf("model", tokenizer, quantization_method = "q4_k_m")
if False: model.push_to_hub_gguf("hf/model", tokenizer, quantization_method = "q4_k_m", token = "")

# Save to multiple GGUF options - much faster if you want multiple!
if False:
    model.push_to_hub_gguf(
        "hf/model", # Change hf to your username!
        tokenizer,
        quantization_method = ["q4_k_m", "q8_0", "q5_k_m",], # one push, several quantizations
        token = "",
    )