<a href="https://colab.research.google.com/github/roshjaison03/roshjaison03-Fine-tuned-Models-using-Unsloth-Framework-/blob/main/LLama_Fine_Tuning_for_Medical_Data_%26_Reasoning_with_3840_size_context_window_using_Unsloth.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

To run this, press "*Runtime*" and press "*Run all*" on a **free** Tesla T4 Google Colab instance!
<div class="align-center">
<a href="https://unsloth.ai/"><img src="https://github.com/unslothai/unsloth/raw/main/images/unsloth%20new%20logo.png" width="115"></a>
<a href="https://discord.gg/unsloth"><img src="https://github.com/unslothai/unsloth/raw/main/images/Discord button.png" width="145"></a>
<a href="https://docs.unsloth.ai/"><img src="https://github.com/unslothai/unsloth/blob/main/images/documentation%20green%20button.png?raw=true" width="125"></a> Join Discord if you need help + ⭐ <i>Star us on <a href="https://github.com/unslothai/unsloth">Github</a> </i> ⭐
</div>

To install Unsloth on your own computer, follow the installation instructions on our Github page [here](https://docs.unsloth.ai/get-started/installing-+-updating).

You will learn how to do [data prep](#Data), how to [train](#Train), how to [run the model](#Inference), and [how to save it](#Save).


### News

Read our **[TTS Guide](https://docs.unsloth.ai/basics/text-to-speech-tts-fine-tuning)** for instructions and all our notebooks.

Read our **[Qwen3 Guide](https://docs.unsloth.ai/basics/qwen3-how-to-run-and-fine-tune)** and check out our new **[Dynamic 2.0](https://docs.unsloth.ai/basics/unsloth-dynamic-2.0-ggufs)** quants, which outperform other quantization methods!

Visit our docs for all our [model uploads](https://docs.unsloth.ai/get-started/all-our-models) and [notebooks](https://docs.unsloth.ai/get-started/unsloth-notebooks).


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


### Installation

In [None]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
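    # Colab only: install without dependencies (the next cell adds the needed ones selectively).
    # Outside Colab, use a plain `pip install unsloth vllm` as in the branch above.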
    !pip install --no-deps unsloth vllm

In [None]:
#@title Colab Extra Install { display-mode: "form" }
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth vllm
else:
    # Colab only: install without dependencies, then add the needed ones selectively below.
    # Outside Colab, use a plain `pip install unsloth vllm` as in the branch above.
    !pip install --no-deps unsloth vllm
    # Skip restarting message in Colab
    import sys, re, requests; modules = list(sys.modules.keys())
    for x in modules: sys.modules.pop(x) if "PIL" in x or "google" in x else None
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft "trl==0.15.2" triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1" huggingface_hub hf_transfer

    # vLLM requirements - vLLM breaks Colab due to reinstalling numpy
    f = requests.get("https://raw.githubusercontent.com/vllm-project/vllm/refs/heads/main/requirements/common.txt").content
    with open("vllm_requirements.txt", "wb") as file:
        file.write(re.sub(rb"(transformers|numpy|xformers)[^\n]{1,}\n", b"", f))
    !pip install -r vllm_requirements.txt
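
To see what the `re.sub` call above strips from vLLM's requirements file, here is the same pattern applied to a small sample (the sample lines are made up for illustration, not copied from vLLM's actual file). One subtlety: `[^\n]{1,}` needs at least one character after the package name, so a bare `numpy` line with no version specifier would survive the filter.

```python
import re

# Same pattern as the install cell: remove from a pinned package name to the end of its line
PATTERN = rb"(transformers|numpy|xformers)[^\n]{1,}\n"

# Illustrative requirements content (not vLLM's real file)
sample = (
    b"numpy < 2.0.0\n"
    b"requests >= 2.26.0\n"
    b"transformers >= 4.51.1\n"
    b"xformers == 0.0.29.post2\n"
    b"tqdm\n"
)

filtered = re.sub(PATTERN, b"", sample)
print(filtered.decode())  # only the requests and tqdm lines remain

# Caveat: a bare package name with no version specifier is NOT removed
bare = re.sub(PATTERN, b"", b"numpy\n")
print(bare)  # b'numpy\n'
```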

### Unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

# Load base model with extended context length
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "meta-llama/Llama-3.2-3B-Instruct",
    max_seq_length = 3840,  # Increased context length
    load_in_4bit = False,
    fast_inference = False,  # Avoid vLLM for training
    max_lora_rank = 64,
)

# Inject LoRA adapters (same config as your original training)
model = FastLanguageModel.get_peft_model(
    model,
    r = 64,
    target_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    lora_alpha = 64,
    use_gradient_checkpointing = "unsloth",
    random_state = 3407,
)

==((====))==  Unsloth 2025.5.7: Fast Llama patching. Transformers: 4.51.3. vLLM: 0.8.5.post1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Unsloth 2025.5.7 patched 28 layers with 28 QKV layers, 28 O layers and 28 MLP layers.
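
As a sanity check on what `r = 64` costs, the LoRA adapter size can be worked out by hand: every adapted linear layer mapping `in` to `out` features gains a down-projection A (r x in) and an up-projection B (out x r), i.e. r * (in + out) trainable parameters. The sketch below assumes the Llama 3.2 3B dimensions from its published config (hidden size 3072, MLP size 8192, 28 layers, 24 query heads and 8 key/value heads); if you swap models, read these values from `model.config` instead. With `lora_alpha = 64` the LoRA scaling factor alpha / r is exactly 1.0.

```python
# Assumed Llama 3.2 3B dimensions (check model.config if you change models)
hidden, intermediate, layers = 3072, 8192, 28
n_heads, n_kv_heads = 24, 8
head_dim = hidden // n_heads          # 128
kv_dim = n_kv_heads * head_dim        # 1024, because of grouped-query attention

r, lora_alpha = 64, 64

# (in_features, out_features) of each module listed in target_modules
shapes = {
    "q_proj": (hidden, hidden),
    "k_proj": (hidden, kv_dim),
    "v_proj": (hidden, kv_dim),
    "o_proj": (hidden, hidden),
    "gate_proj": (hidden, intermediate),
    "up_proj": (hidden, intermediate),
    "down_proj": (intermediate, hidden),
}

# LoRA adds A (r x in) and B (out x r) to each adapted module
per_layer = sum(r * (fan_in + fan_out) for fan_in, fan_out in shapes.values())
total = per_layer * layers

print(f"LoRA params per layer: {per_layer:,}")
print(f"Total trainable LoRA params: {total:,}")
print(f"Scaling alpha / r = {lora_alpha / r}")
```

Halving `r` halves the adapter size, since every term is linear in r.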


In [None]:
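# Run this cell only AFTER training: it merges the trained LoRA weights into the base model
model = model.merge_and_unload()  # 🔁 Merges LoRA weights into base model

# Save merged model (and the tokenizer, so the folder can be reloaded on its own)
model.save_pretrained("/content/drive/MyDrive/llamamodel/reasoning_model")
tokenizer.save_pretrained("/content/drive/MyDrive/llamamodel/reasoning_model")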
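
Merging folds each adapter back into its frozen base weight, W' = W + (alpha / r) * B @ A, so the merged model produces the same outputs without the extra LoRA matrix multiplications. The pure-Python sketch below checks that identity on tiny hand-written matrices; the numbers are arbitrary and the helpers `matmul` and `matvec` are illustrative, not part of PEFT.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def matvec(M, v):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]

W = [[1.0, 2.0], [3.0, 4.0]]   # frozen base weight (out=2, in=2)
A = [[0.5, -1.0]]              # LoRA down-projection, shape r x in (r = 1 here)
B = [[2.0], [1.0]]             # LoRA up-projection, shape out x r
scale = 1.0                    # lora_alpha / r (both are 64 in this notebook)
x = [1.0, 1.0]

# Unmerged forward pass: base path plus scaled low-rank path
unmerged = [w + scale * d for w, d in zip(matvec(W, x), matvec(B, matvec(A, x)))]

# Merged weight W' = W + scale * (B @ A), then a single matrix-vector product
BA = matmul(B, A)
W_merged = [[w + scale * d for w, d in zip(w_row, d_row)] for w_row, d_row in zip(W, BA)]
merged = matvec(W_merged, x)

print(unmerged, merged)  # both are [2.0, 6.5]
```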

**Dataset: medical-o1-reasoning-SFT**

In [None]:
from datasets import load_dataset

# Load the English ("en") configuration, train split
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en", split="train")
dataset

Dataset({
    features: ['Question', 'Complex_CoT', 'Response'],
    num_rows: 19704
})

Let's look at the first row:

In [None]:
dataset[0]["Question"]

'Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?'

In [None]:
dataset[0]["Complex_CoT"]

"Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the left side of the heart, it could shoot off to the brain and cause that sudden weakness by blocking blood flow there.\n\nHmm, but how would the clot get from the right side of the heart to the left without going through the lungs and getting filtered out?\n\nHere's where our cardiac anomaly comes in: 

In [None]:
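def clean_reasoning_entry(entry):
    """Return (question, chain-of-thought) for one dataset row."""
    # Strip stray brackets, quotes, commas and newlines from both ends of the question
    prompt = entry["Question"].strip("[]',\"\n")
    # The chain-of-thought reasoning is used as-is
    documents = entry["Complex_CoT"]
    return prompt, documents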

query, docs = clean_reasoning_entry(dataset[2])
print("Clean Query:", query)
print("\nClean Documents:", docs)

Clean Query: A 61-year-old woman with a long history of involuntary urine loss during activities like coughing or sneezing but no leakage at night undergoes a gynecological exam and Q-tip test. Based on these findings, what would cystometry most likely reveal about her residual volume and detrusor contractions?

Clean Documents: Okay, let's think about this step by step. There's a 61-year-old woman here who's been dealing with involuntary urine leakages whenever she's doing something that ups her abdominal pressure like coughing or sneezing. This sounds a lot like stress urinary incontinence to me. Now, it's interesting that she doesn't have any issues at night; she isn't experiencing leakage while sleeping. This likely means her bladder's ability to hold urine is fine when she isn't under physical stress. Hmm, that's a clue that we're dealing with something related to pressure rather than a bladder muscle problem. 

The fact that she underwent a Q-tip test is intriguing too. This test
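
Note that `str.strip` takes a *set of characters*, not a prefix or suffix: it keeps removing any of `[`, `]`, `'`, `,`, `"` and newline from both ends until it meets another character. That cleans list-like wrappers, but it can also remove a legitimate trailing quote. The example strings below are made up for illustration.

```python
CHARS = "[]',\"\n"  # same character set as clean_reasoning_entry

wrapped = "['What is the most likely diagnosis?']\n"
print(wrapped.strip(CHARS))   # What is the most likely diagnosis?

# Caveat: a question that legitimately ends in a quote loses it
quoted = 'Which drug is described as "first-line"'
print(quoted.strip(CHARS))    # Which drug is described as "first-line
```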

We now create a customizable system prompt. It uses four special markers: one pair delimits the working-out (reasoning) section and the other pair delimits the final answer:

In [None]:
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start = "<SOLUTION>"
solution_end = "</SOLUTION>"

system_prompt = f"""You are given a complex question that requires careful reasoning and document-based inference. You are not allowed to reason about or answer anything outside the question given to you.
Think through the problem, and write your reasoning between {reasoning_start} and {reasoning_end}.
Then, provide your final answer or conclusion between {solution_start} and {solution_end}."""

system_prompt

'You are given a complex question that requires careful reasoning and document-based inference. You are not allowed to reason about or answer anything outside the question given to you.\nThink through the problem, and write your reasoning between <start_working_out> and <end_working_out>.\nThen, provide your final answer or conclusion between <SOLUTION> and </SOLUTION>.'
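
For reference, a completion that satisfies this system prompt places the reasoning and the answer inside the four markers. The hypothetical helper `format_completion` below (not used elsewhere in the notebook) assembles such a string, which can be handy for eyeballing what the reward functions expect or for building test inputs.

```python
reasoning_start = "<start_working_out>"
reasoning_end   = "<end_working_out>"
solution_start  = "<SOLUTION>"
solution_end    = "</SOLUTION>"

def format_completion(reasoning: str, answer: str) -> str:
    """Wrap reasoning and answer in the markers the system prompt asks for (illustrative helper)."""
    return (
        f"{reasoning_start}{reasoning.strip()}{reasoning_end}\n"
        f"{solution_start}{answer.strip()}{solution_end}"
    )

example = format_completion("Step-by-step reasoning goes here.", "Final answer goes here.")
print(example)
```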

Let's map the dataset into chat-style prompts (a system message plus the user's question) and look at the first row:

In [None]:
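dataset = dataset.map(lambda x: {
    "prompt": [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": x["Question"].strip(" []',\"\n")},
    ],
    # Reference for the answer-checking rewards: the final response, since the
    # model's <SOLUTION> block holds a final answer rather than the full reasoning
    "answer": x["Response"],
})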
dataset[0]


{'Question': 'Given the symptoms of sudden weakness in the left arm and leg, recent long-distance travel, and the presence of swollen and tender right lower leg, what specific cardiac abnormality is most likely to be found upon further evaluation that could explain these findings?',
 'Complex_CoT': "Okay, let's see what's going on here. We've got sudden weakness in the person's left arm and leg - and that screams something neuro-related, maybe a stroke?\n\nBut wait, there's more. The right lower leg is swollen and tender, which is like waving a big flag for deep vein thrombosis, especially after a long flight or sitting around a lot.\n\nSo, now I'm thinking, how could a clot in the leg end up causing issues like weakness or stroke symptoms?\n\nOh, right! There's this thing called a paradoxical embolism. It can happen if there's some kind of short circuit in the heart - like a hole that shouldn't be there.\n\nLet's put this together: if a blood clot from the leg somehow travels to the l

We build a regular expression that matches the reasoning section followed by the solution section, and captures the text inside `<SOLUTION>...</SOLUTION>`:

In [None]:
import re

match_format = re.compile(
    rf"^[\s]{{0,}}"\
    rf"{reasoning_start}.+?{reasoning_end}.*?"\
    rf"{solution_start}(.+?){solution_end}"\
    rf"[\s]{{0,}}$",
    flags = re.MULTILINE | re.DOTALL
)

We verify it works:

In [None]:
match_format.search(
    "<start_working_out>Let me think!<end_working_out>"\
    "<SOLUTION>2</SOLUTION>",
)

<re.Match object; span=(0, 71), match='<start_working_out>Let me think!<end_working_out>>
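
A few more probes show what the pattern does and does not accept: surrounding whitespace is tolerated, the lazy `(.+?)` group captures only the answer text, and a completion with no `<SOLUTION>` block does not match at all. Because of `re.MULTILINE`, `$` can also match at the end of any line, so extra text on later lines after `</SOLUTION>` is still accepted. The test strings are made up.

```python
import re

reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"

# Same pattern as the cell above
match_format = re.compile(
    rf"^[\s]{{0,}}"
    rf"{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}"
    rf"[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

ok = match_format.search("  <start_working_out>think<end_working_out>\n<SOLUTION>42</SOLUTION>\n")
print(ok.group(1))   # 42

missing = match_format.search("<start_working_out>think<end_working_out> 42")
print(missing)       # None

# re.MULTILINE lets `$` match before the newline, so the trailing line is ignored
trailing = match_format.search("<start_working_out>think<end_working_out><SOLUTION>1</SOLUTION>\nextra line")
print(trailing.group(1))  # 1
```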

**Creating the reward functions**

In [None]:
def match_format_exactly(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # Match if format is seen exactly!
        if match_format.search(response) is not None: score += 3.0
        scores.append(score)
    return scores

If the exact format is not matched, we still want to give partial credit when the model follows it approximately, by counting each marker:

In [None]:
def match_format_approximately(completions, **kwargs):
    scores = []
    for completion in completions:
        score = 0
        response = completion[0]["content"]
        # +0.5 for each marker that appears exactly once;
        # -1.0 if it is missing or repeated
        score += 0.5 if response.count(reasoning_start) == 1 else -1.0
        score += 0.5 if response.count(reasoning_end)   == 1 else -1.0
        score += 0.5 if response.count(solution_start)  == 1 else -1.0
        score += 0.5 if response.count(solution_end)    == 1 else -1.0
        scores.append(score)
    return scores
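
To see how the two format rewards combine, the sketch below re-declares them in condensed form (same scoring as the cells above) and scores two made-up completions in the nested `[[{"role": ..., "content": ...}]]` shape that the functions index with `completion[0]["content"]`. A fully formatted completion earns 3.0 + 2.0 = 5.0; one that forgets the solution tags earns 0 from the exact check and 0.5 + 0.5 - 1.0 - 1.0 = -1.0 from the approximate one.

```python
import re

reasoning_start, reasoning_end = "<start_working_out>", "<end_working_out>"
solution_start, solution_end = "<SOLUTION>", "</SOLUTION>"
match_format = re.compile(
    rf"^[\s]{{0,}}{reasoning_start}.+?{reasoning_end}.*?"
    rf"{solution_start}(.+?){solution_end}[\s]{{0,}}$",
    flags=re.MULTILINE | re.DOTALL,
)

def match_format_exactly(completions, **kwargs):
    # 3.0 if the whole format matches, else 0
    return [3.0 if match_format.search(c[0]["content"]) else 0.0 for c in completions]

def match_format_approximately(completions, **kwargs):
    scores = []
    for c in completions:
        text = c[0]["content"]
        # +0.5 for each marker that appears exactly once, -1.0 otherwise
        markers = (reasoning_start, reasoning_end, solution_start, solution_end)
        scores.append(sum(0.5 if text.count(tag) == 1 else -1.0 for tag in markers))
    return scores

completions = [
    [{"role": "assistant", "content": "<start_working_out>reasoning<end_working_out><SOLUTION>answer</SOLUTION>"}],
    [{"role": "assistant", "content": "<start_working_out>reasoning<end_working_out> answer"}],
]
exact = match_format_exactly(completions)
approx = match_format_approximately(completions)
print(exact, approx)                              # [3.0, 0.0] [2.0, -1.0]
print([e + a for e, a in zip(exact, approx)])     # [5.0, -1.0]
```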

Next, we extract the generated answer and reward or penalize it. For numeric answers, we also give partial credit based on how close the ratio of the guess to the true value is to 1:

In [None]:
def check_answer(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_format.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    for guess, true_answer in zip(extracted_responses, answer):
        score = 0
        if guess is None:
            scores.append(0)
            continue
        # Correct answer gets 3 points!
        if guess == true_answer:
            score += 3.0
        # Matches once surrounding whitespace is stripped: smaller reward
        elif guess.strip() == true_answer.strip():
            score += 1.5
        else:
            # Numeric answers get partial credit when guess / true_answer is close to 1
            try:
                ratio = float(guess) / float(true_answer)
                if   ratio >= 0.9 and ratio <= 1.1: score += 1.0
                elif ratio >= 0.8 and ratio <= 1.2: score += 0.5
                else: score -= 1.5 # Penalize wrong answers
            except (ValueError, ZeroDivisionError):
                score -= 1.5 # Non-numeric or zero reference: penalize
        scores.append(score)
    return scores
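
The ratio bands only apply when both strings parse as numbers; any other mismatch gets -1.5. A standalone replica of the scoring branch (the name `score_answer` is hypothetical; the thresholds are the same as in `check_answer`) makes the bands easy to check. Since the reference texts in this medical dataset are mostly free text rather than numbers, expect the numeric bands to fire rarely.

```python
def score_answer(guess: str, true_answer: str) -> float:
    """Replicates check_answer's scoring for one (guess, reference) pair."""
    if guess == true_answer:
        return 3.0
    if guess.strip() == true_answer.strip():
        return 1.5
    try:
        ratio = float(guess) / float(true_answer)
    except (ValueError, ZeroDivisionError):
        return -1.5  # non-numeric text or a zero reference
    if 0.9 <= ratio <= 1.1:
        return 1.0
    if 0.8 <= ratio <= 1.2:
        return 0.5
    return -1.5

for g, t in [("100", "100"), (" 100 ", "100"), ("95", "100"), ("85", "100"), ("50", "100"), ("yes", "no")]:
    print(repr(g), repr(t), score_answer(g, t))
```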

Sometimes the answer is not a single number but a sentence, for example "The solution is $20"; in that case we extract 20.

We also remove thousands separators, as in 123,456.

In [None]:
match_numbers = re.compile(
    solution_start + r".*?([\d\.\,]{1,})",
    flags = re.MULTILINE | re.DOTALL
)
print(match_numbers.findall("<SOLUTION>  0.34  </SOLUTION>"))
print(match_numbers.findall("<SOLUTION>  123,456  </SOLUTION>"))

['0.34']
['123,456']
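
Two details of `match_numbers` are worth knowing. The lazy `.*?` stops at the *first* run of digits, dots or commas after `<SOLUTION>`, so "The solution is $20" yields `20`, but an answer with no digits that ends in a full stop yields just `"."`, which `float()` rejects (so `check_numbers` scores it 0). The example strings are illustrative.

```python
import re

solution_start = "<SOLUTION>"
# Same pattern as the cell above
match_numbers = re.compile(solution_start + r".*?([\d\.\,]{1,})", flags=re.MULTILINE | re.DOTALL)

dollars = match_numbers.search("<SOLUTION>The solution is $20</SOLUTION>").group(1)
first = match_numbers.search("<SOLUTION>Take 5 mg, not 10.</SOLUTION>").group(1)
no_digits = match_numbers.search("<SOLUTION>No further tests are needed.</SOLUTION>").group(1)
print(dollars, first, no_digits)  # 20 5 .

# check_numbers strips thousands separators before converting
value = float("123,456".replace(",", ""))
print(value)  # 123456.0
```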


In [None]:
# Module-level counters used to print a sample only every few calls
PRINTED_TIMES = 0
PRINT_EVERY_STEPS = 5

def check_numbers(prompts, completions, answer, **kwargs):
    question = prompts[0][-1]["content"]
    responses = [completion[0]["content"] for completion in completions]

    extracted_responses = [
        guess.group(1)
        if (guess := match_numbers.search(r)) is not None else None \
        for r in responses
    ]

    scores = []
    # Print only every few steps
    global PRINTED_TIMES
    global PRINT_EVERY_STEPS
    if PRINTED_TIMES % PRINT_EVERY_STEPS == 0:
        print('*'*20, f"Question:\n{question}", f"\nAnswer:\n{answer[0]}", f"\nResponse:\n{responses[0]}", f"\nExtracted:\n{extracted_responses[0]}")
    PRINTED_TIMES += 1

    for guess, true_answer in zip(extracted_responses, answer):
        if guess is None:
            scores.append(0)
            continue
        # Convert to numbers
        try:
            true_answer = float(true_answer.strip())
            # Remove commas like in 123,456
            guess       = float(guess.strip().replace(",", ""))
            scores.append(1.5 if guess == true_answer else -0.5)
        except ValueError:
            # Non-numeric answer or reference: neither reward nor penalize
            scores.append(0)
            continue
    return scores

Get the maximum prompt length in tokens, so we don't accidentally truncate any prompt:

In [None]:
max(dataset.map(
    lambda x: {"tokens" : tokenizer.apply_chat_template(x["prompt"], add_generation_prompt = True, tokenize = True)},
    batched = True,
).map(lambda x: {"length" : len(x["tokens"])})["length"])

793
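
The longest tokenized prompt is 793 tokens, so a prompt budget of 793 + 1 = 794 tokens leaves 3840 - 794 = 3046 tokens of the `max_seq_length = 3840` context for the completion. This notebook does not pass `max_completion_length` to `GRPOConfig`, so TRL's own default applies; Unsloth's GRPO example notebooks instead derive it from the remaining budget, as in this sketch (a suggestion, not what this notebook runs).

```python
max_seq_length = 3840           # context length passed to FastLanguageModel.from_pretrained
longest_prompt = 793            # measured by the cell above
max_prompt_length = longest_prompt + 1
max_completion_length = max_seq_length - max_prompt_length

print(max_prompt_length, max_completion_length)  # 794 3046
```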

<a name="Train"></a>
### Train the model

Now we set up the GRPO trainer and its configuration:

In [None]:
max_prompt_length = 793 + 1 # longest prompt measured above, + 1 just in case!

from trl import GRPOConfig, GRPOTrainer
training_args = GRPOConfig(
    learning_rate = 5e-6,
    weight_decay = 0.1,
    warmup_ratio = 0.1,
    lr_scheduler_type = "cosine",
    optim = "adamw_8bit",
    logging_steps = 1,
    per_device_train_batch_size = 1,
    gradient_accumulation_steps = 4, # Accumulate over 4 steps for smoother training
    num_generations = 2, # Decrease if out of memory
    max_prompt_length = max_prompt_length,
    # num_train_epochs = 1, # Set to 1 for a full training run
    max_steps = 50,
    save_steps = 5,
    max_grad_norm = 1.0,
    report_to = "none", # Can use Weights & Biases
    output_dir = "/content/drive/MyDrive/med_model",
)

Unsloth: We now expect `per_device_train_batch_size` to be a multiple of `num_generations`.
We will change the batch size of 1 to the `num_generations` of 2
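
The batch-size adjustment reflects how GRPO works: it samples `num_generations` completions per prompt and scores each one *relative to its group*. As we understand TRL's GRPO implementation at the pinned version, each completion's reward is the sum over all reward functions, and its advantage is that reward minus the group mean, divided by the group standard deviation plus a small epsilon. The pure-Python sketch below mirrors that normalization; the reward values are made up, `eps = 1e-4` is an assumed constant, and `statistics.stdev` is the sample standard deviation (matching `torch.std`'s default).

```python
from statistics import mean, stdev

def group_advantages(rewards, num_generations, eps=1e-4):
    """GRPO-style advantages: (reward - group mean) / (group std + eps).

    Rewards must be ordered group by group; eps is an assumed small constant.
    """
    advantages = []
    for start in range(0, len(rewards), num_generations):
        group = rewards[start:start + num_generations]
        mu, sigma = mean(group), stdev(group)  # sample std, like torch.std's default
        advantages.extend((r - mu) / (sigma + eps) for r in group)
    return advantages

# Made-up total rewards (sum of the four reward functions) for 2 prompts x 2 generations
rewards = [5.0, -1.0,   # prompt 1: one well-formatted, one malformed completion
           3.5, 3.5]    # prompt 2: tied completions
print([round(a, 3) for a in group_advantages(rewards, num_generations=2)])  # [0.707, -0.707, 0.0, 0.0]
```

With only two generations, any two unequal rewards normalize to roughly +/-0.71 regardless of how far apart they are (rewards of 5.0 and 4.9 give about +/-0.71 too), and tied rewards give zero advantage, i.e. no learning signal for that prompt. Raising `num_generations` gives the normalization more to work with, at the cost of memory.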


In [None]:
trainer = GRPOTrainer(
    model = model,
    processing_class = tokenizer,
    reward_funcs = [
        match_format_exactly,
        match_format_approximately,
        check_answer,
        check_numbers,
    ],
    args = training_args,
    train_dataset = dataset,
)
trainer.train()

In [None]:
!huggingface-cli login


**Loading Tokenizer and Model**
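
Loading the merged model in 4-bit NF4 mainly shrinks the weight footprint. Here is a back-of-the-envelope estimate, assuming roughly 3.2 billion parameters for Llama 3.2 3B and ignoring quantization constants, any modules bitsandbytes leaves unquantized, activations and the KV cache, so treat the numbers as rough lower bounds. For comparison, the training log above reports about 14.7 GB of usable memory on the T4.

```python
params = 3.2e9  # assumed approximate parameter count of Llama 3.2 3B

# Bytes needed per weight at each precision
bytes_per_param = {"float16": 2.0, "int8": 1.0, "nf4 (4-bit)": 0.5}

estimates = {dtype: params * nbytes / 1e9 for dtype, nbytes in bytes_per_param.items()}
for dtype, gb in estimates.items():
    print(f"{dtype:>12}: ~{gb:.1f} GB of weights")
```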

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
import gradio as gr
import os

tokenizer_name = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer_cache_dir = "./tokenizer_cache"  # folder to save/load tokenizer

if os.path.exists(tokenizer_cache_dir):
    print("Loading tokenizer from local cache...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_cache_dir, trust_remote_code=True)
else:
    print("Downloading tokenizer and saving locally...")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, trust_remote_code=True)
    tokenizer.save_pretrained(tokenizer_cache_dir)

# Define 4-bit quantization config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# Load merged model in 4-bit
model = AutoModelForCausalLM.from_pretrained(
    "Rosh03/Reasoning_finetuned_llama",  # ✅ This is your merged model
    quantization_config=bnb_config,
    device_map="auto",  # auto = use GPU if available, fallback to CPU
    trust_remote_code=True
)

# Set evaluation mode
model.eval()

# Create text generation pipeline (the model is already placed on devices by device_map above)
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

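# Define chatbot function: gr.ChatInterface calls fn(message, history) and expects the reply text
def chat(message, history):
    # Chat-formatted input so the tokenizer's chat template is applied, as during training
    # (optionally prepend the training system prompt as a {"role": "system"} message)
    messages = [{"role": "user", "content": message}]
    output = generator(messages, max_new_tokens=200, do_sample=True, temperature=0.7)
    # For chat input, generated_text is the whole conversation; the last message is the reply
    return output[0]["generated_text"][-1]["content"]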

# Gradio interface
chat_ui = gr.ChatInterface(fn=chat, title="🧠 LoRA Reasoning Chatbot (Merged 4-bit)", theme="default")

# Launch Gradio app
if __name__ == "__main__":
    chat_ui.launch()
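
Because the model was trained to wrap its reasoning and answer in the four markers, the chatbot could display just the final answer. Below is a minimal sketch of a hypothetical helper `split_reply` that separates the two sections and falls back to the raw text when the markers are missing; it is not wired into the Gradio app above.

```python
import re

REASONING = re.compile(r"<start_working_out>(.*?)<end_working_out>", re.DOTALL)
SOLUTION = re.compile(r"<SOLUTION>(.*?)</SOLUTION>", re.DOTALL)

def split_reply(text: str):
    """Return (reasoning, answer); the answer falls back to the whole reply if there is no <SOLUTION> block."""
    reasoning = REASONING.search(text)
    solution = SOLUTION.search(text)
    return (
        reasoning.group(1).strip() if reasoning else "",
        solution.group(1).strip() if solution else text.strip(),
    )

reply = "<start_working_out>Step-by-step reasoning.<end_working_out>\n<SOLUTION> Final answer. </SOLUTION>"
print(split_reply(reply))               # ('Step-by-step reasoning.', 'Final answer.')
print(split_reply("No markers here."))  # ('', 'No markers here.')
```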