<a href="https://colab.research.google.com/github/shashanksrajak/finetuned-gemma-3-medical-QnA/blob/main/fine_tuning_gemma3_medical.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Fine Tuning Gemma 3 270M for Medical QnA

### Install Packages

In [1]:
%%capture
import os
if "COLAB_" not in "".join(os.environ.keys()):
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf "datasets>=3.4.1,<4.0.0" "huggingface_hub>=0.34.0" hf_transfer
    !pip install --no-deps unsloth

### Loading the Gemma3 Model

`FastModel.from_pretrained` loads `unsloth/gemma-3-270m-it`, the instruction-tuned Gemma 3 270M checkpoint that we will fine-tune.

In [2]:
from unsloth import FastModel
import torch
max_seq_length = 2048

model, tokenizer = FastModel.from_pretrained(
    model_name = "unsloth/gemma-3-270m-it",
    max_seq_length = max_seq_length,
    load_in_4bit = False,  # Set True for 4 bit quantization to reduce memory
    load_in_8bit = False, # [NEW!] Set True for 8 bit: a bit more accurate, uses 2x memory
    full_finetuning = False,
    # token = "hf_...", # use one if using gated models
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.8.6: Fast Gemma3 patching. Transformers: 4.55.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 7.5. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Using float16 precision for gemma3 won't work! Using float32.
Unsloth: QLoRA and full finetuning all not selected. Switching to 16bit LoRA.



We now add LoRA adapters so that only a small fraction of the parameters needs to be updated!

In [3]:
model = FastModel.get_peft_model(
    model,
    r = 128, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 128,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth: Making `model.base_model.model.model` require gradients
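As a rough sanity check on what `r = 128` means in practice, the sketch below counts the LoRA parameters by hand. Each adapted linear layer of shape `d_in -> d_out` gains two low-rank matrices holding `r * (d_in + d_out)` parameters in total. The layer dimensions are assumptions taken from the published Gemma 3 270M configuration (hidden size 640, 4 query heads and 1 key/value head of dimension 256, MLP width 2048, 18 layers); this notebook does not print them. Under those assumptions the total agrees with the `Trainable parameters` figure in the training log further down.

```python
# Back-of-the-envelope LoRA parameter count.
# ASSUMPTION: these dimensions are the Gemma 3 270M config values;
# they are not printed anywhere in this notebook.
HIDDEN = 640          # hidden_size
Q_OUT = 4 * 256       # num_attention_heads * head_dim
KV_OUT = 1 * 256      # num_key_value_heads * head_dim
MLP = 2048            # intermediate_size
LAYERS = 18           # num_hidden_layers

R = 128               # LoRA rank used above
LORA_ALPHA = 128      # lora_alpha used above

# (d_in, d_out) of every module listed in target_modules
target_shapes = {
    "q_proj": (HIDDEN, Q_OUT),
    "k_proj": (HIDDEN, KV_OUT),
    "v_proj": (HIDDEN, KV_OUT),
    "o_proj": (Q_OUT, HIDDEN),
    "gate_proj": (HIDDEN, MLP),
    "up_proj": (HIDDEN, MLP),
    "down_proj": (MLP, HIDDEN),
}


def lora_params(d_in, d_out, r):
    """A is (r x d_in) and B is (d_out x r), so r * (d_in + d_out) in total."""
    return r * (d_in + d_out)


per_layer = sum(lora_params(d_in, d_out, R) for d_in, d_out in target_shapes.values())
total = per_layer * LAYERS
scaling = LORA_ALPHA / R  # factor applied to the low-rank update

print(f"{per_layer:,} LoRA parameters per layer")
print(f"{total:,} LoRA parameters in total, scaling = {scaling}")
```

With `lora_alpha` equal to `r`, the scaling factor is 1, so the low-rank update is added to the frozen weights unscaled.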


<a name="Data"></a>
### Data Prep

We will use a medical QnA dataset with complex chain-of-thought reasoning, sourced from Hugging Face: [FreedomIntelligence/medical-o1-reasoning-SFT](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT)

This dataset has 4 subsets and 1 train split; the `en` subset has approximately 19.7k samples. We will use only the first 10k samples from the `en` subset for this tuning task.

In [7]:
from datasets import load_dataset
dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", 'en',  split = "train[:10000]")

medical_o1_sft.json:   0%|          | 0.00/58.2M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/19704 [00:00<?, ? examples/s]

In [9]:
# Let's see what this data looks like
dataset

Dataset({
    features: ['Question', 'Complex_CoT', 'Response'],
    num_rows: 10000
})

In [10]:
dataset[100]

{'Question': 'A 25-year-old woman presents to the ED with a diffuse, erythematous rash, nausea, vomiting, and fever for 2 days. Physical examination reveals a soaked tampon in her vagina, and blood cultures are negative, suggesting toxic shock syndrome. Which specific molecule on T cells does the toxin most likely bind to?',
 'Complex_CoT': "Alright, here's a situation with a 25-year-old woman who showed up in the emergency department. She's got this widespread red rash, feeling nauseous, she's vomiting, and running a fever for two days. Something's not quite right here, and it all starts connecting to the idea of toxic shock syndrome. Oh, and there's a crucial detail: they found a soaked tampon during her exam.\n\nOkay, let's dig into what's happening in toxic shock syndrome. It's a bit of a nightmare because it's associated with these things called superantigens. These are basically like the rogue agents of the bacterial world, and they're mostly coming from bugs like Staphylococcus 

#### Chat Template

We now use the `Gemma-3` format for conversation-style finetunes. Each row of the medical dataset is first converted into a list of role-tagged messages (system, user, assistant), and the chat template is then applied to that list.

Gemma-3 renders multi-turn conversations as shown below:

```
<bos><start_of_turn>user
Hello!<end_of_turn>
<start_of_turn>model
Hey there!<end_of_turn>
```
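To make the turn format concrete, here is a minimal hand-written renderer. It is an illustrative sketch only, not the real template that `get_chat_template` installs; `render_gemma3` is a hypothetical helper. The handling of the `system` role (prepended to the first user turn, followed by a blank line) is an assumption inferred from the rendered `text` output shown later in this notebook.

```python
def render_gemma3(messages, add_generation_prompt=False):
    """Simplified, illustrative renderer for the Gemma-3 turn format."""
    system_prefix = ""
    if messages and messages[0]["role"] == "system":
        # ASSUMPTION: there is no separate system turn; the system text is
        # prepended to the first user turn, followed by a blank line.
        system_prefix = messages[0]["content"] + "\n\n"
        messages = messages[1:]

    out = "<bos>"
    for i, msg in enumerate(messages):
        # The assistant role is written as "model" in the rendered text.
        role = "model" if msg["role"] == "assistant" else msg["role"]
        content = msg["content"].strip()
        if i == 0:
            content = system_prefix + content
        out += f"<start_of_turn>{role}\n{content}<end_of_turn>\n"

    if add_generation_prompt:
        # Open a model turn so that generation continues from here.
        out += "<start_of_turn>model\n"
    return out


chat = [
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hey there!"},
]
rendered = render_gemma3(chat)
prompt = render_gemma3(chat[:1], add_generation_prompt=True)
print(rendered)
```

For training, the full conversation is rendered; for inference, only the user turn is rendered and `add_generation_prompt=True` leaves an open model turn for the model to complete.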

We use the `get_chat_template` function to get the correct chat template.

In [11]:
from unsloth.chat_templates import get_chat_template
tokenizer = get_chat_template(
    tokenizer,
    chat_template = "gemma3",
)

We now use `convert_to_chatml` to convert the dataset into the conversation format needed for finetuning. The assistant turn joins the chain of thought (`Complex_CoT`) and the final answer (`Response`).

In [13]:
SYSTEM_INSTRUCTION = """
You are a highly knowledgeable medical expert. Your task is to analyze a given medical question and provide a comprehensive, step-by-step reasoning process (Chain-of-Thought) followed by a clear and concise final answer.

The Chain-of-Thought should:
1.  **Deconstruct the clinical scenario**: Identify key symptoms, patient demographics, and relevant medical findings (e.g., "diffuse, erythematous rash," "soaked tampon").
2.  **Formulate a differential diagnosis**: Based on the deconstructed elements, consider the most likely medical condition (e.g., toxic shock syndrome).
3.  **Explain the underlying pathophysiology**: Describe the biological mechanism of the identified condition, focusing on the role of the toxin and its interaction with the immune system.
4.  **Identify the specific target**: Pinpoint the precise molecule or structure the toxin binds to, explaining why this interaction leads to the observed symptoms.

The final answer should be a direct, succinct summary of the Chain-of-Thought, providing the correct medical term or concept without extraneous details. The final answer must be a single paragraph.

Ensure that both the Chain-of-Thought and the final answer are accurate, medically sound, and logically consistent.
"""

In [17]:
def convert_to_chatml(example):

    assistant_response = f"{example['Complex_CoT']}\n\nFinal Answer: {example['Response']}"

    return {
        "conversations": [
            {"role": "system", "content": SYSTEM_INSTRUCTION},
            {"role": "user", "content": example["Question"]},
            {"role": "assistant", "content": assistant_response},
        ]
    }

dataset = dataset.map(
    convert_to_chatml
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]
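Because the assistant turn joins the reasoning and the answer with a literal `Final Answer:` marker, a generated response can be split back into its two parts. The sketch below shows one way to do that; `split_response` is a hypothetical helper that is not part of the notebook, and the toy row contains placeholder text rather than real dataset content.

```python
FINAL_MARKER = "\n\nFinal Answer: "


def build_assistant_response(example):
    """Same layout as in convert_to_chatml: reasoning, blank line, final answer."""
    return f"{example['Complex_CoT']}{FINAL_MARKER}{example['Response']}"


def split_response(text):
    """Hypothetical helper: split a response into (reasoning, final_answer).

    Returns (text, None) when the marker is missing, which can happen
    when generation is cut off before the answer is produced.
    """
    reasoning, sep, answer = text.rpartition(FINAL_MARKER)
    if not sep:
        return text.strip(), None
    return reasoning.strip(), answer.strip()


# Placeholder row, NOT taken from the dataset.
toy_row = {
    "Question": "Placeholder question?",
    "Complex_CoT": "First consider the findings. Then narrow the options.",
    "Response": "Placeholder answer.",
}

full = build_assistant_response(toy_row)
reasoning, answer = split_response(full)
print(reasoning)
print(answer)
```

Using `rpartition` means that only the last occurrence of the marker is treated as the boundary, in case the reasoning itself mentions a final answer.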

Let's see what row 100 looks like!

In [18]:
dataset[100]["conversations"]

[{'content': '\nYou are a highly knowledgeable medical expert. Your task is to analyze a given medical question and provide a comprehensive, step-by-step reasoning process (Chain-of-Thought) followed by a clear and concise final answer.\n\nThe Chain-of-Thought should:\n1.  **Deconstruct the clinical scenario**: Identify key symptoms, patient demographics, and relevant medical findings (e.g., "diffuse, erythematous rash," "soaked tampon").\n2.  **Formulate a differential diagnosis**: Based on the deconstructed elements, consider the most likely medical condition (e.g., toxic shock syndrome).\n3.  **Explain the underlying pathophysiology**: Describe the biological mechanism of the identified condition, focusing on the role of the toxin and its interaction with the immune system.\n4.  **Identify the specific target**: Pinpoint the precise molecule or structure the toxin binds to, explaining why this interaction leads to the observed symptoms.\n\nThe final answer should be a direct, succin

We now have to apply the chat template for `Gemma3` onto the conversations, and save it to `text`.

In [19]:
def formatting_prompts_func(examples):
   convos = examples["conversations"]
   texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False).removeprefix('<bos>') for convo in convos]
   return { "text" : texts, }

dataset = dataset.map(formatting_prompts_func, batched = True)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]
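The `.removeprefix('<bos>')` call matters because the trainer tokenizes the `text` column again, and the tokenizer prepends its own `<bos>` at that point (the decoded `input_ids` printed later in this notebook start with a single `<bos>`). The sketch below illustrates the effect with `fake_tokenize`, a hypothetical stand-in that only mimics the prepending behavior.

```python
BOS = "<bos>"


def fake_tokenize(text):
    """Hypothetical stand-in for a tokenizer that always prepends BOS."""
    return BOS + text


rendered = "<bos><start_of_turn>user\nHello!<end_of_turn>\n"

kept = fake_tokenize(rendered)                        # template BOS left in place
stripped = fake_tokenize(rendered.removeprefix(BOS))  # template BOS removed first

print(kept.count(BOS), stripped.count(BOS))
```

Leaving the template's `<bos>` in place would give every training example two beginning-of-sequence tokens.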

Let's see how the chat template did!


In [21]:
dataset

Dataset({
    features: ['Question', 'Complex_CoT', 'Response', 'conversations', 'text'],
    num_rows: 10000
})

In [20]:
dataset[100]['text']

'<start_of_turn>user\n\nYou are a highly knowledgeable medical expert. Your task is to analyze a given medical question and provide a comprehensive, step-by-step reasoning process (Chain-of-Thought) followed by a clear and concise final answer.\n\nThe Chain-of-Thought should:\n1.  **Deconstruct the clinical scenario**: Identify key symptoms, patient demographics, and relevant medical findings (e.g., "diffuse, erythematous rash," "soaked tampon").\n2.  **Formulate a differential diagnosis**: Based on the deconstructed elements, consider the most likely medical condition (e.g., toxic shock syndrome).\n3.  **Explain the underlying pathophysiology**: Describe the biological mechanism of the identified condition, focusing on the role of the toxin and its interaction with the immune system.\n4.  **Identify the specific target**: Pinpoint the precise molecule or structure the toxin binds to, explaining why this interaction leads to the observed symptoms.\n\nThe final answer should be a direct

## See the model responses before fine tuning

In [29]:
dataset['conversations'][1]

[{'content': '\nYou are a highly knowledgeable medical expert. Your task is to analyze a given medical question and provide a comprehensive, step-by-step reasoning process (Chain-of-Thought) followed by a clear and concise final answer.\n\nThe Chain-of-Thought should:\n1.  **Deconstruct the clinical scenario**: Identify key symptoms, patient demographics, and relevant medical findings (e.g., "diffuse, erythematous rash," "soaked tampon").\n2.  **Formulate a differential diagnosis**: Based on the deconstructed elements, consider the most likely medical condition (e.g., toxic shock syndrome).\n3.  **Explain the underlying pathophysiology**: Describe the biological mechanism of the identified condition, focusing on the role of the toxin and its interaction with the immune system.\n4.  **Identify the specific target**: Pinpoint the precise molecule or structure the toxin binds to, explaining why this interaction leads to the observed symptoms.\n\nThe final answer should be a direct, succin

In [30]:
messages = [
    {'role': 'system','content':dataset['conversations'][1][0]['content']},
    {"role" : 'user', 'content' : dataset['conversations'][1][1]['content']}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
).removeprefix('<bos>')

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    # max_new_tokens = 125,
    temperature = 1, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

The 33-year-old woman is brought to the emergency department 15 minutes after being stabbed in the chest with a screwdriver. Given her vital signs of pulse 110/min, respirations 22/min, and blood pressure 90/65 mm Hg, along with the presence of a 5-cm deep stab wound at the upper border of the 8th rib in the left midaxillary line, which anatomical structure in her chest is most likely to be injured, the patient is considered to be a potential stab wound.
<end_of_turn>


<a name="Train"></a>
### Train the model
Now let's train our model. We run 100 steps to speed things up; for a full run, set `num_train_epochs = 1` and remove the `max_steps` argument.

In [31]:
from trl import SFTTrainer, SFTConfig

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = dataset,
    eval_dataset = None, # Can set up evaluation!
    args = SFTConfig(
        dataset_text_field = "text",
        per_device_train_batch_size = 8,
        gradient_accumulation_steps = 1, # Use GA to mimic batch size!
        warmup_steps = 5,
        # num_train_epochs = 1, # Set this for 1 full training run.
        max_steps = 100,
        learning_rate = 5e-5, # Reduce to 2e-5 for long training runs
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir="outputs",
        report_to = "none", # Use this for WandB etc
    ),
)

Unsloth: Switching to float32 training since model cannot work with float16


Unsloth: Tokenizing ["text"] (num_proc=2):   0%|          | 0/10000 [00:00<?, ? examples/s]
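To put the configuration above in perspective, the sketch below works out how much of the dataset 100 steps actually cover and what the learning rate looks like over the run. The schedule function is the usual linear-warmup-then-linear-decay shape, which is assumed here to be what `lr_scheduler_type = "linear"` produces; it is written out by hand rather than taken from the trainer.

```python
import math

NUM_EXAMPLES = 10_000
PER_DEVICE_BATCH = 8
GRAD_ACCUM = 1
NUM_GPUS = 1
MAX_STEPS = 100
WARMUP_STEPS = 5
BASE_LR = 5e-5

effective_batch = PER_DEVICE_BATCH * GRAD_ACCUM * NUM_GPUS
steps_per_epoch = math.ceil(NUM_EXAMPLES / effective_batch)
examples_seen = MAX_STEPS * effective_batch
fraction_of_epoch = examples_seen / NUM_EXAMPLES


def linear_lr(step):
    """Learning rate after `step` optimizer steps (assumed schedule shape)."""
    if step < WARMUP_STEPS:
        # Linear warmup from 0 up to the base learning rate.
        return BASE_LR * step / WARMUP_STEPS
    # Linear decay from the base learning rate down to 0 at MAX_STEPS.
    return BASE_LR * max(0.0, (MAX_STEPS - step) / (MAX_STEPS - WARMUP_STEPS))


print(f"effective batch size: {effective_batch}")
print(f"steps for one full epoch: {steps_per_epoch}")
print(f"examples seen in {MAX_STEPS} steps: {examples_seen} ({fraction_of_epoch:.0%} of the data)")
for step in (0, 5, 50, 100):
    print(f"step {step:3d}: lr = {linear_lr(step):.2e}")
```

So the 100-step run only touches a small part of the 10k examples; a full epoch at this batch size would take 1,250 steps.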

We also use Unsloth's `train_on_responses_only` function to train only on the assistant outputs and ignore the loss on the user's inputs. This helps increase the accuracy of finetunes!

In [32]:
from unsloth.chat_templates import train_on_responses_only
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<start_of_turn>user\n",
    response_part = "<start_of_turn>model\n",
)

Map (num_proc=2):   0%|          | 0/10000 [00:00<?, ? examples/s]
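Conceptually, response-only training replaces every label outside a model turn with `-100`, the index that the cross-entropy loss ignores. The sketch below shows that idea on a toy token sequence. It is a simplified illustration and not Unsloth's implementation; the token ids and the `mask_non_response` helper are made up for the example.

```python
IGNORE_INDEX = -100  # label value that the loss skips


def find_sub(seq, sub, start=0):
    """Index of the first occurrence of `sub` in `seq` at or after `start`, else -1."""
    n = len(sub)
    for i in range(start, len(seq) - n + 1):
        if seq[i:i + n] == sub:
            return i
    return -1


def mask_non_response(input_ids, instruction_ids, response_ids):
    """Keep labels only for tokens that follow a response marker."""
    labels = [IGNORE_INDEX] * len(input_ids)
    pos = 0
    while True:
        start = find_sub(input_ids, response_ids, pos)
        if start == -1:
            break
        begin = start + len(response_ids)
        # The response runs until the next instruction marker or the end.
        end = find_sub(input_ids, instruction_ids, begin)
        if end == -1:
            end = len(input_ids)
        labels[begin:end] = input_ids[begin:end]
        pos = end
    return labels


# Made-up ids: 0=<bos> 1=<start_of_turn> 2=user 3=model 4=\n 5=<end_of_turn>
instruction_ids = [1, 2, 4]   # "<start_of_turn>user\n"
response_ids = [1, 3, 4]      # "<start_of_turn>model\n"
input_ids = [0, 1, 2, 4, 10, 11, 5, 4, 1, 3, 4, 20, 21, 5, 4]

labels = mask_non_response(input_ids, instruction_ids, response_ids)
print(labels)
```

Only the tokens after the model marker keep their ids as labels, which is why the decoded labels later in the notebook show blank space followed by the answer.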

Let's verify that the instruction part is masked! First, let's print the 100th row again.

In [33]:
tokenizer.decode(trainer.train_dataset[100]["input_ids"])

'<bos><start_of_turn>user\n\nYou are a highly knowledgeable medical expert. Your task is to analyze a given medical question and provide a comprehensive, step-by-step reasoning process (Chain-of-Thought) followed by a clear and concise final answer.\n\nThe Chain-of-Thought should:\n1.  **Deconstruct the clinical scenario**: Identify key symptoms, patient demographics, and relevant medical findings (e.g., "diffuse, erythematous rash," "soaked tampon").\n2.  **Formulate a differential diagnosis**: Based on the deconstructed elements, consider the most likely medical condition (e.g., toxic shock syndrome).\n3.  **Explain the underlying pathophysiology**: Describe the biological mechanism of the identified condition, focusing on the role of the toxin and its interaction with the immune system.\n4.  **Identify the specific target**: Pinpoint the precise molecule or structure the toxin binds to, explaining why this interaction leads to the observed symptoms.\n\nThe final answer should be a d

Now let's print the masked out example - you should see only the answer is present:

In [34]:
tokenizer.decode([tokenizer.pad_token_id if x == -100 else x for x in trainer.train_dataset[100]["labels"]]).replace(tokenizer.pad_token, " ")

"                                                                                                                                                                                                                                                                                                                                            Alright, here's a situation with a 25-year-old woman who showed up in the emergency department. She's got this widespread red rash, feeling nauseous, she's vomiting, and running a fever for two days. Something's not quite right here, and it all starts connecting to the idea of toxic shock syndrome. Oh, and there's a crucial detail: they found a soaked tampon during her exam.\n\nOkay, let's dig into what's happening in toxic shock syndrome. It's a bit of a nightmare because it's associated with these things called superantigens. These are basically like the rogue agents of the bacterial world, and they're mostly coming from bugs like Staphylococcus aureus.\n\n

In [35]:
# @title Show current memory stats
gpu_stats = torch.cuda.get_device_properties(0)
start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)
print(f"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.")
print(f"{start_gpu_memory} GB of memory reserved.")

GPU = Tesla T4. Max memory = 14.741 GB.
1.646 GB of memory reserved.


Let's train the model! To resume a training run, call `trainer.train(resume_from_checkpoint = True)`.

In [36]:
trainer_stats = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10,000 | Num Epochs = 1 | Total steps = 100
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 1
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 1 x 1) = 8
 "-____-"     Trainable parameters = 30,375,936 of 298,474,112 (10.18% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


| Step | Training Loss |
|------|---------------|
| 1    | 2.8597        |
| 2    | 2.7577        |
| 3    | 2.8233        |
| 4    | 2.5999        |
| 5    | 2.5743        |
| 6    | 2.5388        |
| 7    | 2.5573        |
| 8    | 2.4769        |
| 9    | 2.4997        |
| 10   | 2.6075        |


Unsloth: Will smartly offload gradients to save VRAM!
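One way to read the logged losses is to convert them to perplexity, `exp(loss)`. The sketch below does that for the ten values shown in the log above. It assumes the logged number is the mean per-token cross-entropy in natural-log units, which is the usual convention but is not stated in this output.

```python
import math

# The ten training-loss values shown in the log above.
logged_losses = [2.8597, 2.7577, 2.8233, 2.5999, 2.5743,
                 2.5388, 2.5573, 2.4769, 2.4997, 2.6075]

mean_loss = sum(logged_losses) / len(logged_losses)
first_ppl = math.exp(logged_losses[0])
last_ppl = math.exp(logged_losses[-1])

print(f"mean loss over the first {len(logged_losses)} steps: {mean_loss:.4f}")
print(f"perplexity at step 1:  {first_ppl:.2f}")
print(f"perplexity at step {len(logged_losses)}: {last_ppl:.2f}")
```

With a batch size of 8 the per-step loss is noisy, so the trend over many steps matters more than any single value.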


In [37]:
# @title Show final memory and time stats
used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)
used_memory_for_lora = round(used_memory - start_gpu_memory, 3)
used_percentage = round(used_memory / max_memory * 100, 3)
lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)
print(f"{trainer_stats.metrics['train_runtime']} seconds used for training.")
print(
    f"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training."
)
print(f"Peak reserved memory = {used_memory} GB.")
print(f"Peak reserved memory for training = {used_memory_for_lora} GB.")
print(f"Peak reserved memory % of max memory = {used_percentage} %.")
print(f"Peak reserved memory for training % of max memory = {lora_percentage} %.")

265.1092 seconds used for training.
4.42 minutes used for training.
Peak reserved memory = 3.047 GB.
Peak reserved memory for training = 1.401 GB.
Peak reserved memory % of max memory = 20.67 %.
Peak reserved memory for training % of max memory = 9.504 %.


<a name="Inference"></a>
### Inference
Let's run the model via Unsloth native inference! According to the `Gemma-3` team, the recommended settings for inference are `temperature = 1.0, top_p = 0.95, top_k = 64`
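For readers unfamiliar with these three settings, the sketch below applies them to a toy next-token distribution in plain Python: temperature rescales the logits, `top_k` keeps the k most likely tokens, and `top_p` keeps the smallest set of remaining tokens whose cumulative probability reaches the threshold. It is a simplified illustration of the idea and not the code that `model.generate` runs; the toy logits are made up.

```python
import math
import random


def filter_distribution(logits, temperature=1.0, top_k=64, top_p=0.95):
    """Return {token: prob} after temperature scaling, top-k, then top-p filtering."""
    # Temperature: divide the logits, then apply a numerically stable softmax.
    scaled = {tok: logit / temperature for tok, logit in logits.items()}
    peak = max(scaled.values())
    exps = {tok: math.exp(value - peak) for tok, value in scaled.items()}
    norm = sum(exps.values())
    ranked = sorted(((tok, e / norm) for tok, e in exps.items()),
                    key=lambda pair: pair[1], reverse=True)

    # Top-k: keep the k most likely tokens and renormalize.
    ranked = ranked[:top_k]
    norm = sum(p for _, p in ranked)
    ranked = [(tok, p / norm) for tok, p in ranked]

    # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for tok, p in ranked:
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break

    norm = sum(p for _, p in kept)
    return {tok: p / norm for tok, p in kept}


# Made-up logits for four candidate tokens.
toy_logits = {"A": 2.0, "B": 1.0, "C": 0.0, "D": -3.0}

filtered = filter_distribution(toy_logits, temperature=1.0, top_k=64, top_p=0.95)
tokens, weights = zip(*filtered.items())
sampled = random.choices(tokens, weights=weights, k=1)[0]
print(filtered)
print("sampled:", sampled)
```

In this toy case the very unlikely token `D` falls outside the 0.95 nucleus and can never be sampled, while the other three keep their relative odds.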

In [38]:
messages = [
    {'role': 'system','content':dataset['conversations'][1][0]['content']},
    {"role" : 'user', 'content' : dataset['conversations'][1][1]['content']}
]
text = tokenizer.apply_chat_template(
    messages,
    tokenize = False,
    add_generation_prompt = True, # Must add for generation
).removeprefix('<bos>')

from transformers import TextStreamer
_ = model.generate(
    **tokenizer(text, return_tensors = "pt").to("cuda"),
    # max_new_tokens = 125,
    temperature = 1, top_p = 0.95, top_k = 64,
    streamer = TextStreamer(tokenizer, skip_prompt = True),
)

Alright, I know this is bad stuff here, because after we've been doing our due diligence, these signs are not passing on these doctors like they should. So it's really difficult to determine what we're talking about, it's really bad because she's usually a 33-year-old woman that looks like she's got her chest settled up.

Okay, let's peek into this situation. The emergency doctor has put her at 15 minutes after the injury, which seems like everything should be putting her right in the morning. So, it's okay to say the emergency doctor is putting her in the morning, but it's wondering if the doctors might still be putting her in the morning coming out because everything is ready.

But wait, there's something more serious to tell me – the emergency doctor also had a five-cm deep stab wound in the upper border of the 8th rib. So that might tell me something is not going to seem okay coming back down the road, which it does not in question.

Right, if this doctor says this was done by an e

<a name="Save"></a>
### Saving, loading finetuned models
To save the final model as LoRA adapters, either use Hugging Face's `push_to_hub` for an online save or `save_pretrained` for a local save.

**[NOTE]** This ONLY saves the LoRA adapters, and not the full model. To save to 16bit or GGUF, scroll down!

In [39]:
model.save_pretrained("lora_model")  # Local saving
tokenizer.save_pretrained("lora_model")
# model.push_to_hub("your_name/lora_model", token = "...") # Online saving
# tokenizer.push_to_hub("your_name/lora_model", token = "...") # Online saving

('lora_model/tokenizer_config.json',
 'lora_model/special_tokens_map.json',
 'lora_model/chat_template.jinja',
 'lora_model/tokenizer.model',
 'lora_model/added_tokens.json',
 'lora_model/tokenizer.json')

Now if you want to load the LoRA adapters we just saved for inference, change `if False` to `if True`:

In [None]:
if False:
    from unsloth import FastModel
    model, tokenizer = FastModel.from_pretrained(
        model_name = "lora_model", # The directory the LoRA adapters were saved to above
        max_seq_length = 2048,
        load_in_4bit = False, # Match the 16bit LoRA setup used for training
    )

### Saving to float16 for vLLM

We also support saving to `float16` directly. Select `merged_16bit` for float16 or `merged_4bit` for int4. We also allow `lora` adapters as a fallback. Use `push_to_hub_merged` to upload to your Hugging Face account! You can go to https://huggingface.co/settings/tokens for your personal tokens.
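The "merged" save methods fold the adapter into the base weights so that the result loads like an ordinary model. The sketch below shows the standard LoRA merge formula, `W_merged = W + (alpha / r) * B @ A`, on tiny hand-made matrices in plain Python. It illustrates the arithmetic only and is not necessarily the exact code path that `save_pretrained_merged` follows.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merge_lora(W, A, B, alpha, r):
    """Return W + (alpha / r) * B @ A, where A is (r x d_in) and B is (d_out x r)."""
    scale = alpha / r
    delta = matmul(B, A)
    return [[w + scale * d for w, d in zip(w_row, d_row)]
            for w_row, d_row in zip(W, delta)]


# Toy sizes: d_in = 3, d_out = 2, rank r = 1 (made-up numbers).
W = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0]]
A = [[1.0, 2.0, 3.0]]
B = [[0.5],
     [-1.0]]

merged = merge_lora(W, A, B, alpha=2, r=1)

# The merged matrix gives the same output as base path + adapter path.
x = [[1.0], [1.0], [1.0]]
base_out = matmul(W, x)
adapter_out = [[2 / 1 * v for v in row] for row in matmul(B, matmul(A, x))]
separate = [[b + a for b, a in zip(b_row, a_row)]
            for b_row, a_row in zip(base_out, adapter_out)]
print(merged)
print(matmul(merged, x), separate)
```

After merging, inference no longer needs the adapter matrices, which is what makes the merged checkpoint usable by engines that expect plain weights.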

In [None]:
# Merge to 16bit
if False: # Set to True to save the merged model locally in the "model" directory
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_16bit")
if True: # Pushing to HF Hub
    model.push_to_hub_merged("shashanksrajak/gemma-3-finetuned-medical-QnA", tokenizer, save_method = "merged_16bit", token = "") # Add your Hugging Face token

# Merge to 4bit
if False:
    model.save_pretrained_merged("model", tokenizer, save_method = "merged_4bit",)
if False: # Pushing to HF Hub
    model.push_to_hub_merged("hf/model", tokenizer, save_method = "merged_4bit", token = "")

# Just LoRA adapters
if False:
    model.save_pretrained("model")
    tokenizer.save_pretrained("model")
if False: # Pushing to HF Hub
    model.push_to_hub("hf/model", token = "")
    tokenizer.push_to_hub("hf/model", token = "")


### GGUF / llama.cpp Conversion
To save to `GGUF` / `llama.cpp`, we support it natively now for all models! For now, you can convert easily to `Q8_0, F16 or BF16` precision. `Q4_K_M` for 4bit will come later!
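To give a feel for what `Q8_0` does to the weights, the sketch below quantizes one block of values to 8-bit integers with a single shared scale and then reconstructs them. It assumes the commonly described `Q8_0` layout (blocks of 32 values, one 16-bit scale per block, scale equal to the largest magnitude divided by 127). It is a simplified illustration in plain Python and not llama.cpp's implementation.

```python
BLOCK_SIZE = 32  # ASSUMPTION: Q8_0 groups weights in blocks of 32


def quantize_block(values):
    """Quantize one block to (scale, int8 values) with a shared per-block scale."""
    assert len(values) == BLOCK_SIZE
    amax = max(abs(v) for v in values)
    scale = amax / 127.0
    if scale == 0.0:
        return 0.0, [0] * BLOCK_SIZE
    return scale, [round(v / scale) for v in values]


def dequantize_block(scale, quants):
    """Reconstruct approximate float values from a quantized block."""
    return [scale * q for q in quants]


# Made-up block of weights spread evenly between -0.5 and 0.5.
block = [i / (BLOCK_SIZE - 1) - 0.5 for i in range(BLOCK_SIZE)]

scale, quants = quantize_block(block)
restored = dequantize_block(scale, quants)
max_error = max(abs(a - b) for a, b in zip(block, restored))

# 32 one-byte values plus a 2-byte scale per block.
bits_per_weight = (BLOCK_SIZE * 8 + 16) / BLOCK_SIZE

print(f"scale = {scale:.6f}, max reconstruction error = {max_error:.6f}")
print(f"storage cost: {bits_per_weight} bits per weight")
```

The reconstruction error per value is bounded by half the block's scale, which is why `Q8_0` is usually treated as a near-lossless choice compared with 4-bit formats.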

In [None]:
!pip install mistral_common

In [None]:
if True: # Set to False to skip saving to GGUF
    model.save_pretrained_gguf(
        "model",
        quantization_type = "Q8_0", # For now only Q8_0, BF16, F16 supported
    )

Likewise, if you instead want to push the GGUF to your Hugging Face account, add your Hugging Face token and upload location!

In [None]:
if True: # Set to False to skip uploading GGUF
    model.push_to_hub_gguf(
        "model",
        quantization_type = "Q8_0", # Only Q8_0, BF16, F16 supported
        repo_id = "shashanksrajak/gemma-3-finetuned-medicalQnA-gguf",
        token = "",
    )

### Convert to `tflite` format
Now we will convert the model into the `tflite` format, which can run entirely on device for local inference.

We will use `ai-edge-torch` from Google for this task.

See this repo: https://github.com/google-ai-edge/ai-edge-torch/tree/main/ai_edge_torch/generative/

In [None]:
!pip install ai-edge-torch -q

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# The directory where you saved your merged model
# Use the directory name from your `save_pretrained_merged` command
model_directory = "model"

# Load the merged model and tokenizer
model = AutoModelForCausalLM.from_pretrained(model_directory)
tokenizer = AutoTokenizer.from_pretrained(model_directory)

# Set the model to evaluation mode
model.eval()