In [None]:
# 1. Install Unsloth for 2x faster training and 70% less VRAM usage.
# 2. We install xformers for efficient attention mechanisms.
!pip install unsloth
!pip install --no-deps "xformers<0.0.29" "trl<0.13.0" peft accelerate bitsandbytes

In [None]:
from unsloth import FastLanguageModel
import torch

# Loading configuration shared by the model loader and the trainer below.
max_seq_length = 4096 # Adjust based on how long you want the "thinking" to be
dtype = None # None for auto detection
load_in_4bit = True # Use 4-bit quantization to save memory

# Load the 4-bit base model together with its tokenizer.
# The context length must be passed as `max_seq_length` (Unsloth's parameter
# name). Passing it as `max_length` leaves Unsloth on its default context
# length and forwards the stray keyword to the underlying loader.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen2.5-3B-Instruct", # Excellent base for reasoning
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
)

# Attach LoRA (Low-Rank Adaptation) adapters to the loaded model.
# Only the small adapter matrices are trained; the base weights stay frozen,
# which is what makes fine-tuning feasible on a consumer GPU.
model = FastLanguageModel.get_peft_model(
    model,
    # --- Adapter size -------------------------------------------------------
    r = 16,             # Rank: more capacity at the cost of more VRAM
    lora_alpha = 16,
    # --- Settings Unsloth is optimized for -----------------------------------
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = "unsloth", # Large VRAM saving
    # --- Which projection layers receive adapters ----------------------------
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    # --- Reproducibility -----------------------------------------------------
    random_state = 3407,
)

In [None]:
from datasets import load_dataset

# Load a small reasoning dataset with explicit step-by-step logic: the first
# 500 training rows of the "en" configuration of medical-o1-reasoning-SFT.
# (Other reasoning sets such as 'Magpie-Reasoning' or 'OpenThoughts' would
# work too, but their column names differ.)
dataset = load_dataset(
    path = "FreedomIntelligence/medical-o1-reasoning-SFT",
    name = "en",
    split = "train[:500]",
)

# Show the schema so the column names used by the prompt formatter can be
# checked against what the dataset actually provides.
print(dataset.features)

# ChatML-style chat template for Qwen.
# It has three positional slots, in order: the user question, the reasoning
# chain (wrapped in <think> ... </think>), and the final answer. The assistant
# turn therefore ALWAYS opens with <think>.
prompt_style = """<|im_start|>system
You are a helpful assistant that thinks step-by-step before answering.<|im_end|>
<|im_start|>user
{}<|im_end|>
<|im_start|>assistant
<think>
{}
</think>
{}<|im_end|>"""

def formatting_prompts_func(examples):
    """Render a batch of dataset rows into full training prompts.

    Args:
        examples: A batched mapping with the columns "Question",
            "Complex_CoT" (the reasoning chain) and "Response" (the final
            answer), each holding a list of strings. These keys must match
            the dataset's actual column names.

    Returns:
        A dict with a single key "text" whose value is the list of prompts
        produced by filling ``prompt_style`` with each row.

    Raises:
        KeyError: If one of the expected columns is missing.
    """
    rows = zip(examples["Question"], examples["Complex_CoT"], examples["Response"])
    # Each row is (question, reasoning, answer), matching the template's slots.
    return {"text": [prompt_style.format(*row) for row in rows]}

# Run the formatter over the dataset in batches; its returned "text" list
# becomes the column the trainer reads.
dataset = dataset.map(function = formatting_prompts_func, batched = True)

In [None]:
from trl import SFTTrainer
from transformers import TrainingArguments

# Supervised fine-tuning on the rendered "text" column.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    # --- Data -----------------------------------------------------------------
    train_dataset = dataset,
    dataset_text_field = "text",
    dataset_num_proc = 2,
    max_seq_length = max_seq_length,
    # --- Optimisation ---------------------------------------------------------
    args = TrainingArguments(
        output_dir = "outputs",
        seed = 3407,
        # Batch of 2 per device, accumulated over 4 steps -> 8 per update.
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        # Short run for testing; increase max_steps to 300+ for better results.
        max_steps = 60,
        warmup_steps = 5,
        learning_rate = 2e-4,
        lr_scheduler_type = "linear",
        optim = "adamw_8bit",
        weight_decay = 0.01,
        # Mixed precision: bf16 when the GPU supports it, otherwise fp16.
        bf16 = torch.cuda.is_bf16_supported(),
        fp16 = not torch.cuda.is_bf16_supported(),
        logging_steps = 1,
    ),
)

trainer.train()

In [None]:
# Enable native 2x faster inference
FastLanguageModel.for_inference(model)

# We use a streamer to see the "thinking" happen in real time
from transformers import TextStreamer

question = """Alice knows that Bob knows that Carol is wrong.
Carol knows that Bob does not know whether Alice is lying.
Can Alice know that Carol knows the truth? Why or why not?"""

# The training template closes both the reasoning block ("</think>") and the
# assistant turn ("<|im_end|>"). Filling it with empty reasoning and an empty
# answer would hand the model an already-finished reply and leave it nothing
# to generate. For generation, keep only the part of the template up to and
# including the opening "<think>" tag, so the model continues from there.
prompt_head, think_open, _ = prompt_style.partition("<think>\n")
assert think_open, "prompt_style must contain an opening '<think>' line"
generation_prompt = (prompt_head + think_open).format(question)

inputs = tokenizer([generation_prompt], return_tensors = "pt").to("cuda")

text_streamer = TextStreamer(tokenizer)

# The return value is assigned to `_` only to keep the token tensor out of the
# cell output; the streamer prints the decoded text as it is produced.
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512)