In [1]:
from unsloth import FastLanguageModel, is_bfloat16_supported
import torch

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 04-24 19:41:51 [__init__.py:239] Automatically detected platform cuda.


In [2]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 1024 # Can increase for longer reasoning traces
lora_rank = 32 # Larger rank = smarter, but slower

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct",
    max_seq_length = max_seq_length,
    load_in_4bit = True, # False for LoRA 16bit
    fast_inference = False, # Enable vLLM fast inference
    max_lora_rank = lora_rank,
    gpu_memory_utilization = 0.6, # Reduce if out of memory
)

model = FastLanguageModel.get_peft_model(
    model,
    r = lora_rank, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ], # Remove QKVO if out of memory
    lora_alpha = lora_rank,
    use_gradient_checkpointing = "unsloth", # Enable long context finetuning
    random_state = 3407,
)

==((====))==  Unsloth 2025.3.19: Fast Qwen2 patching. Transformers: 4.51.3. vLLM: 0.8.4.
   \\   /|    NVIDIA GeForce RTX 4070 Laptop GPU. Num GPUs = 1. Max memory: 7.747 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.9. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.3.19 patched 36 layers with 36 QKV layers, 36 O layers and 36 MLP layers.


In [12]:
import re
import pandas as pd
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

def extract_answer_from_output(completion):
    ANS_RE = re.compile(r"(\-?[0-9]+)")
    match = ANS_RE.search(completion)
    if match:
        match_str = match.group(1).strip()
        match_str = match_str.replace(",", "")
        return match_str
    else:
        return None

def check_answer(output_text, answer):
    # Extract the answer from the output text
    extracted_answer = extract_answer_from_output(extract_xml_answer(output_text))
    # print(f"Extracted answer: {extracted_answer}")
    # Compare with the expected answer
    return extracted_answer == answer


# uncomment middle messages for 1-shot prompting
def get_gsm8k_french_questions():
    data = pd.read_csv("gsm8k-french2-translated.csv")
    # print(data['french_question'])
    data = data.apply(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['french_question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }, axis=1) # type: ignore
    return data # type: ignore

def get_gsm8k_questions():
    data = pd.read_csv("gsm8k-french2-translated.csv")
    # print(data['french_question'])
    data = data.apply(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }, axis=1) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_answer_from_output(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_answer_from_output(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 8  # Tune based on your VRAM

def collate_fn(batch):
    tokenizer.padding_side = "right"
    prompts = [tokenizer.apply_chat_template(item['prompt'], tokenize=False, add_generation_prompt=True) for item in batch]
    inputs = tokenizer(prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
    return inputs, batch

dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)


In [None]:
total, correct = 0, 0

with torch.inference_mode():
    model.eval()
    for inputs, batch in dataloader:
        output_ids = model.generate(
            inputs['input_ids'],
            do_sample=False,
            max_new_tokens=512,
            output_attentions=False,
            output_hidden_states=False,
        )
        outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
        # print(outputs[0])
        for output_text, item in zip(outputs, batch):
            if check_answer(output_text, item['answer']):
                correct += 1
            # else:
            #     print(output_text)
            #     extracted_answer = extract_answer_from_output(extract_xml_answer(output_text))
            #     print(extracted_answer, item['answer'])
            total += 1

        print(correct, total)
print(f"Accuracy: {correct / total * 100:.2f}%")

system

Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>

user
Julie is reading a 120-page book. Yesterday, she was able to read 12 pages and today, she read twice as many pages as yesterday. If she wants to read half of the remaining pages tomorrow, how many pages should she read?
assistant
ością
<reasoning>
To determine how many pages Julie should read tomorrow, we need to follow these steps:

1. Calculate the total number of pages Julie has read so far.
2. Determine the number of pages remaining after her current readings.
3. Find out how many pages she plans to read tomorrow, which is half of the remaining pages.

First, let's calculate the total number of pages Julie has read so far:
- Yesterday, she read 12 pages.
- Today, she read twice as many pages as she did yesterday, which means she read \(2 \times 12 = 24\) pages.

So, the total number of pages she has read so far is \(12 + 24 = 36\) pages.

Next, we find out how many pages are left to b

KeyboardInterrupt: 