<a href="https://colab.research.google.com/github/perfect7613/sarvam-reasoning/blob/main/Sarvam_Reasoning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%%capture
# Skip restarting message in Colab
import sys; modules = list(sys.modules.keys())
for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None

!pip install unsloth vllm
!pip install --upgrade pillow

In [2]:
from unsloth import FastLanguageModel, PatchFastRL, is_bfloat16_supported
PatchFastRL("GRPO", FastLanguageModel)
import torch

Unsloth: Patching Xformers to fix some performance issues.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
INFO 02-20 11:28:54 __init__.py:190] Automatically detected platform cuda.


In [4]:
max_seq_length = 1024  # Increase for longer reasoning traces
lora_rank = 64

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="sarvamai/sarvam-2b",
    max_seq_length=max_seq_length,
    load_in_4bit=False,       # Disabled for this model
    fast_inference=False,     # Disabled for this model
    max_lora_rank=lora_rank,
    gpu_memory_utilization=0.5,
)

# --- Apply LoRA adapter ---
model = FastLanguageModel.get_peft_model(
    model,
    r=lora_rank,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha=lora_rank,
    use_gradient_checkpointing="unsloth",  # Enable long-context fine-tuning
    random_state=3407,
)

==((====))==  Unsloth 2025.2.12: Fast Llama patching. Transformers: 4.48.3.
   \\   /|    GPU: Tesla T4. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.28.post3. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json:   0%|          | 0.00/21.0k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.77G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/279M [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/193 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/775k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/8.51M [00:00<?, ?B/s]

Unsloth: Will load sarvamai/sarvam-2b as a legacy tokenizer.


sarvamai/sarvam-2b does not have a padding token! Will use pad_token = <unk>.


Unsloth 2025.2.12 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.


In [5]:
import re
from datasets import load_dataset, Dataset

# Load and prep dataset
SYSTEM_PROMPT = """
Respond in the following format:
<reasoning>
...
</reasoning>
<answer>
...
</answer>
"""

XML_COT_FORMAT = """\
<reasoning>
{reasoning}
</reasoning>
<answer>
{answer}
</answer>
"""

def extract_xml_answer(text: str) -> str:
    answer = text.split("<answer>")[-1]
    answer = answer.split("</answer>")[0]
    return answer.strip()

def extract_hash_answer(text: str) -> str | None:
    if "####" not in text:
        return None
    return text.split("####")[1].strip()

# uncomment middle messages for 1-shot prompting
def get_gsm8k_questions(split = "train") -> Dataset:
    data = load_dataset('openai/gsm8k', 'main')[split] # type: ignore
    data = data.map(lambda x: { # type: ignore
        'prompt': [
            {'role': 'system', 'content': SYSTEM_PROMPT},
            {'role': 'user', 'content': x['question']}
        ],
        'answer': extract_hash_answer(x['answer'])
    }) # type: ignore
    return data # type: ignore

dataset = get_gsm8k_questions()

# Reward functions
def correctness_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    q = prompts[0][-1]['content']
    extracted_responses = [extract_xml_answer(r) for r in responses]
    print('-'*20, f"Question:\n{q}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

def int_reward_func(completions, **kwargs) -> list[float]:
    responses = [completion[0]['content'] for completion in completions]
    extracted_responses = [extract_xml_answer(r) for r in responses]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

def strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Reward function that checks if the completion has a specific format."""
    pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
    responses = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, r) for r in responses]
    return [0.5 if match else 0.0 for match in matches]

def count_xml(text) -> float:
    count = 0.0
    if text.count("<reasoning>\n") == 1:
        count += 0.125
    if text.count("\n</reasoning>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        count -= len(text.split("\n</answer>\n")[-1])*0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        count -= (len(text.split("\n</answer>")[-1]) - 1)*0.001
    return count

def xmlcount_reward_func(completions, **kwargs) -> list[float]:
    contents = [completion[0]["content"] for completion in completions]
    return [count_xml(c) for c in contents]

README.md:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.31M [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/419k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

In [6]:
from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    use_vllm=False,  # Disable vLLM since our current setup doesn't support it with this model
    learning_rate=5e-6,
    adam_beta1=0.9,
    adam_beta2=0.99,
    weight_decay=0.1,
    warmup_ratio=0.1,
    lr_scheduler_type="cosine",
    optim="adamw_8bit",
    logging_steps=1,
    bf16=is_bfloat16_supported(),
    fp16=not is_bfloat16_supported(),
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    num_generations=8,
    max_prompt_length=256,
    max_completion_length=200,
    max_steps=250,
    save_steps=250,
    max_grad_norm=0.1,
    report_to="none",
    output_dir="outputs",
)

Unsloth: We know expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 8


In [7]:
trainer = GRPOTrainer(
    model=model,
    processing_class=tokenizer,
    reward_funcs=[
        xmlcount_reward_func,
        soft_format_reward_func,
        strict_format_reward_func,
        int_reward_func,
        correctness_reward_func,
    ],
    args=training_args,
    train_dataset=dataset,
)
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 7,473 | Num Epochs = 1
O^O/ \_/ \    Batch size per device = 8 | Gradient Accumulation steps = 1
\        /    Total batch size = 8 | Total steps = 250
 "-____-"     Number of trainable parameters = 95,879,168


-------------------- Question:
A concert ticket costs $40. Mr. Benson bought 12 tickets and received a 5% discount for every ticket bought that exceeds 10. How much did Mr. Benson pay in all? 
Answer:
476 
Response:
<reasoning> 
 Mr. Benson bought 12 tickets, but 5% of the first 10 is $2.50, and 5% of the next 2 is $1.25. Therefore, there are 12 - 10 - 2 - 1 = 1 tickets that exceeds 10, and 12 - 10 - 2 - 1 - 1 = 4 tickets that are below 10. The total discount is $12 * 5% = $12. 
 So, Mr. Benson paid 2.50 + 1.25 + $12 = $15.25 for 12 tickets. $40 * 12 - $15.25 = ₹ 503.75. 
 </reasoning> 
 < 
Extracted:
<reasoning> 
 Mr. Benson bought 12 tickets, but 5% of the first 10 is $2.50, and 5% of the next 2 is $1.25. Therefore, there are 12 - 10 - 2 - 1 = 1 tickets that exceeds 10, and 12 - 10 - 2 - 1 - 1 = 4 tickets that are below 10. The total discount is $12 * 5% = $12. 
 So, Mr. Benson paid 2.50 + 1.25 + $12 = $15.25 for 12 tickets. $40 * 12 - $15.25 = ₹ 503.75. 
 </reasoning> 
 <


Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / soft_format_reward_func,rewards / strict_format_reward_func,rewards / int_reward_func,rewards / correctness_reward_func
1,0.0,0.0,0.0,195.25,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,155.375,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,157.625,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,172.125,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,115.125,0.0,0.0,0.0,0.0,0.0,0.0
6,0.0,0.0,0.0,193.25,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,0.0,193.375,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,0.0625,0.176777,171.75,0.0,0.0,0.0,0.0,0.0625,0.0
9,0.0,0.25,0.377964,158.625,2.8e-05,0.0,0.0625,0.0,0.1875,0.0
10,0.0,0.0,0.0,196.375,1.9e-05,0.0,0.0,0.0,0.0,0.0


-------------------- Question:
Jane is trying to decide whether to buy a house or a trailer. A house costs $480,000 and a trailer costs $120,000. Each loan will be paid in monthly installments over 20 years. How much more is the monthly payment on the house compared to the trailer? 
Answer:
1500 
Response:
<answer> $72,000</answer> 
 <reasoning> 
 The monthly payment on the house will be $480,000 / 20 years * $120,000 / $120,000. 
 The monthly payment on the trailer will be $120,000 / 20 years * $120,000 / $120,000. 
 
 The difference between the monthly payments is $72,000. </reasoning> 
 
 <answer> $72,000</answer> </s> 
 
Extracted:
$72,000
-------------------- Question:
Janet pays $40/hour for 3 hours per week of clarinet lessons and $28/hour for 5 hours a week of piano lessons. How much more does she spend on piano lessons than clarinet lessons in a year? 
Answer:
1040 
Response:
<reasoning> 
 Janet spends $1,148 on piano lessons in a year and $1,084 on clarinet lessons in a year.

Step,Training Loss,reward,reward_std,completion_length,kl,rewards / xmlcount_reward_func,rewards / soft_format_reward_func,rewards / strict_format_reward_func,rewards / int_reward_func,rewards / correctness_reward_func
1,0.0,0.0,0.0,195.25,0.0,0.0,0.0,0.0,0.0,0.0
2,0.0,0.0,0.0,155.375,0.0,0.0,0.0,0.0,0.0,0.0
3,0.0,0.0,0.0,157.625,0.0,0.0,0.0,0.0,0.0,0.0
4,0.0,0.0,0.0,172.125,0.0,0.0,0.0,0.0,0.0,0.0
5,0.0,0.0,0.0,115.125,0.0,0.0,0.0,0.0,0.0,0.0
6,0.0,0.0,0.0,193.25,0.0,0.0,0.0,0.0,0.0,0.0
7,0.0,0.0,0.0,193.375,0.0,0.0,0.0,0.0,0.0,0.0
8,0.0,0.0625,0.176777,171.75,0.0,0.0,0.0,0.0,0.0625,0.0
9,0.0,0.25,0.377964,158.625,2.8e-05,0.0,0.0625,0.0,0.1875,0.0
10,0.0,0.0,0.0,196.375,1.9e-05,0.0,0.0,0.0,0.0,0.0


-------------------- Question:
Mr. and Mrs. Hugo went on a road trip. On the first day, they traveled 200 miles. On the second day, they traveled 3/4 as far. On the third day, they traveled 1/2 as many miles as the first two days combined. How many miles did they travel for 3 days? 
Answer:
525 
Response:
<reasoning>200</reasoning> <answer>200</answer> 
 
 <answer>200</answer> </s> 
 
Extracted:
200
-------------------- Question:
Billy's family likes to keep their bicycles stored in the garage when they're not being used.  They own a total of 4 bicycles.  Each bicycle wheel has 10 spokes.  How many spokes are inside the garage? 
Answer:
80 
Response:
<reasoning> 
 <answer>4 spokes </answer> 
 </reasoning> 
 <answer>4</answer> </s> 
 
Extracted:
4
-------------------- Question:
Beth is a scuba diver.  She is excavating a sunken ship off the coast of a small Caribbean island and she must remain underwater for long periods.  Her primary tank, which she wears when she first enters the wate

TrainOutput(global_step=250, training_loss=0.0008460036162418874, metrics={'train_runtime': 4374.7416, 'train_samples_per_second': 0.457, 'train_steps_per_second': 0.057, 'total_flos': 0.0, 'train_loss': 0.0008460036162418874})

In [36]:
model.eval()

# IMPORTANT: Enable fast inference for the Unsloth model.
FastLanguageModel.for_inference(model)

# Define a plain text prompt suitable for a text-completion model.
prompt = "Natalia sold clips to 50 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May?"

# Tokenize the prompt and move tensors to the appropriate device.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

# Generate a text completion using the standard generate method.
output_ids = model.generate(
    inputs.input_ids,
    max_length=1024,
    do_sample=True,
    top_p=0.95,
    temperature=0.8
)

# Decode the generated tokens to produce the final text.
generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
print(generated_text)

Natalia sold clips to 50 of her friends in April, and then she sold half as many clips in May. How many clips did Natalia sell altogether in April and May? 
 To solve this problem, we need to figure out how much Natalia sold each month individually, multiply that by the number of people who bought from her, and add it up. Let's start with April: 
 April - Total clip sales = (1/2 * 50) + 50 
 Now let's do the same for May: 
 May - Total clip sales = 50 
 Adding them together gives us: 
 (1/2 * 50) + 50 + 50 = total_clips_sold 
 Let's simplify this equation: 
 (25 + 50) + 50 = total_clips_sold 
 35 + 50 = total_clips_sold 
 65 = total_clips_sold 
 This means that on average, Natalia sold 65 clips per day over these two months. However, since she only had a limited time to sell all the clips before they expired, it's likely that the actual number was higher but still within the range of what could be managed during such a short period. The key takeaway here is understanding the concept be

In [31]:
model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit",)
model.push_to_hub_merged("Perfect7613/sarvam-reasoning", tokenizer, save_method = "merged_16bit", token = "YOUR_HF_TOKEN")

Unsloth: You have 1 CPUs. Using `safe_serialization` is 10x slower.
We shall switch to Pytorch saving, which might take 3 minutes and not 30 minutes.
To force `safe_serialization`, set it to `None` instead.
Unsloth: Kaggle/Colab has limited disk space. We need to delete the downloaded
model which will save 4-16GB of disk space, allowing you to save on Kaggle/Colab.
Unsloth: Will remove a cached repo with size 5.1G


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 4.36 out of 12.67 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


100%|██████████| 28/28 [00:01<00:00, 27.61it/s]


Unsloth: Saving tokenizer... Done.
Unsloth: Saving model/pytorch_model-00001-of-00002.bin...
Unsloth: Saving model/pytorch_model-00002-of-00002.bin...
Done.


Unsloth: You are pushing to hub, but you passed your HF username = Perfect7613.
We shall truncate Perfect7613/sarvam-reasoning to sarvam-reasoning


Unsloth: Merging 4bit and LoRA weights to 16bit...
Unsloth: Will use up to 4.3 out of 12.67 RAM for saving.
Unsloth: Saving model... This might take 5 minutes ...


100%|██████████| 28/28 [00:00<00:00, 38.16it/s]


Unsloth: Saving tokenizer...

  0%|          | 0/1 [00:00<?, ?it/s]

tokenizer.model:   0%|          | 0.00/1.94M [00:00<?, ?B/s]

 Done.
Unsloth: Saving sarvam-reasoning/pytorch_model-00001-of-00002.bin...
Unsloth: Saving sarvam-reasoning/pytorch_model-00002-of-00002.bin...


  0%|          | 0/2 [00:00<?, ?it/s]

pytorch_model-00001-of-00002.bin:   0%|          | 0.00/4.77G [00:00<?, ?B/s]

pytorch_model-00002-of-00002.bin:   0%|          | 0.00/279M [00:00<?, ?B/s]

Done.
Saved merged model to https://huggingface.co/Perfect7613/sarvam-reasoning
