<a href="https://colab.research.google.com/github/wasxy47/Medical_LLM_FineTuning_Colab/blob/main/Medical_LLM_FineTuning_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!nvidia-smi

Sat Nov 15 10:34:10 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   59C    P0             34W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [2]:
!pip install -q unsloth
!pip install -q transformers datasets accelerate bitsandbytes
!pip install -q trl peft torch

In [3]:
from unsloth import FastLanguageModel
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset


🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!


In [4]:
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-bnb-4bit",
    max_seq_length = 2048,  # You can reduce this if you get memory errors
    load_in_4bit = True,
    device_map = "auto", # Explicitly set device_map to 'auto'
    # token = "hf_...", # Add your HuggingFace token if needed
)

==((====))==  Unsloth 2025.11.2: Fast Llama patching. Transformers: 4.57.1.
   \\   /|    Tesla T4. Num GPUs = 1. Max memory: 14.741 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.9.0+cu128. CUDA: 7.5. CUDA Toolkit: 12.8. Triton: 3.5.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.33.post1. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


In [5]:
dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards")

In [6]:
# Fallback: a tiny hand-written medical dataset, in case the Hub download above fails
medical_data = {
    "instruction": [
        "What are the symptoms of diabetes?",
        "How is hypertension treated?",
        "What causes asthma attacks?",
        "Describe the treatment for bacterial pneumonia",
    ],
    "input": [""] * 4,  # Empty input
    "output": [
        "Common symptoms of diabetes include frequent urination, excessive thirst, extreme hunger, unexplained weight loss, fatigue, blurred vision, and slow-healing sores.",
        "Hypertension is typically treated with lifestyle modifications including reduced salt intake, regular exercise, weight management, and medications like ACE inhibitors, beta-blockers, or diuretics.",
        "Asthma attacks can be triggered by allergens like pollen and dust, respiratory infections, cold air, exercise, stress, air pollutants, and certain medications.",
        "Bacterial pneumonia is treated with antibiotics targeting the specific bacteria, along with supportive care including rest, hydration, and fever-reducing medications. Severe cases may require hospitalization.",
    ]
}

In [7]:
from datasets import Dataset
# Note: this replaces the `dataset` loaded in cell 5, so training below uses only these 4 examples
dataset = Dataset.from_dict(medical_data)
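
Cell 6 is described as a fallback, but cell 7 assigns `dataset` unconditionally. Below is a minimal sketch of how the fallback could be made conditional. The helper `load_with_fallback` and the `failing_loader` stand-in are hypothetical names introduced for illustration; they are not part of `datasets` or Unsloth.

```python
from typing import Any, Callable, Tuple


def load_with_fallback(loader: Callable[[], Any], fallback: Any) -> Tuple[Any, str]:
    """Return (data, source): the loader's result, or `fallback` if the loader raises."""
    try:
        return loader(), "primary"
    except Exception as err:  # e.g. network or Hub errors
        print(f"Primary dataset unavailable ({err!r}); using fallback data.")
        return fallback, "fallback"


def failing_loader():
    # Hypothetical stand-in for a download that fails
    raise ConnectionError("simulated download failure")


fallback_rows = {"instruction": ["What causes asthma attacks?"]}
data, source = load_with_fallback(failing_loader, fallback_rows)
print(source)
```

In this notebook the loader could be something like `lambda: load_dataset("medalpaca/medical_meadow_medical_flashcards")["train"]` (assuming the Hub dataset exposes a `train` split) and the fallback `Dataset.from_dict(medical_data)`.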

In [8]:
# Add LoRA adapters to the model for efficient fine-tuning
model = FastLanguageModel.get_peft_model(
    model,
    r = 16,  # Rank of LoRA adaptation
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                     "gate_proj", "up_proj", "down_proj"],
    lora_alpha = 16,
    lora_dropout = 0,
    bias = "none",
    use_gradient_checkpointing = True,
    random_state = 3407,
    use_rslora = False,
    loftq_config = None,
)

Unsloth 2025.11.2 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.
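
The trainable-parameter count that the training log reports later (41,943,040) can be reproduced by hand. A LoRA adapter of rank `r` on a linear layer with `in_features` inputs and `out_features` outputs adds `r * (in_features + out_features)` weights. The sketch below assumes the published Llama 3 8B layer shapes (hidden size 4096, MLP size 14336, key/value projection width 1024, 32 layers); these dimensions are not stated in the notebook itself.

```python
def lora_params(in_features: int, out_features: int, r: int) -> int:
    """Weights added by one LoRA adapter: A is (r x in), B is (out x r)."""
    return r * (in_features + out_features)


# Assumed Llama 3 8B dimensions (not taken from this notebook)
HIDDEN, INTERMEDIATE, KV_DIM, NUM_LAYERS = 4096, 14336, 1024, 32
RANK = 16  # matches r = 16 in get_peft_model

# (in_features, out_features) for each target module in one decoder layer
module_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_params(i, o, RANK) for i, o in module_shapes.values())
total_trainable = per_layer * NUM_LAYERS
share = 100 * total_trainable / 8_072_204_288  # total count from the training log

print(f"{total_trainable:,} trainable parameters ({share:.2f}% of the model)")
```

Under these assumptions the result agrees with the figure Unsloth prints when training starts.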


In [9]:
# Configure training parameters
training_args = TrainingArguments(
    output_dir = "medical-model",     # Where checkpoints are written
    per_device_train_batch_size = 2,  # Reduce if you get memory errors
    gradient_accumulation_steps = 4,  # Effective batch size = 2 x 4 = 8
    warmup_steps = 5,                 # Learning rate warmup steps (this run has only 3 steps in total)
    num_train_epochs = 3,             # Number of passes over the dataset
    learning_rate = 2e-4,             # Peak learning rate
    fp16 = not torch.cuda.is_bf16_supported(),  # fp16 on GPUs without bf16 support, such as the T4
    bf16 = torch.cuda.is_bf16_supported(),      # bf16 where the GPU supports it
    logging_steps = 1,                # Log the loss at every step
    optim = "adamw_8bit",             # 8-bit AdamW optimizer
    weight_decay = 0.01,              # Regularization
    lr_scheduler_type = "linear",     # Linear warmup, then linear decay
    seed = 3407,                      # Random seed
    report_to = "none",               # Disable external logging
)
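
With only four training examples, these settings produce very few optimizer updates. The sketch below estimates the number of updates from the batch settings. The helper `optimizer_steps` is a hypothetical name, and its rounding (a partial accumulation window still counts as one update) is an assumption about the trainer's behavior rather than something the notebook states.

```python
import math


def optimizer_steps(num_examples: int, per_device_batch: int, grad_accum: int,
                    epochs: int, num_devices: int = 1) -> int:
    """Estimate optimizer updates for a run (assumed rounding, see lead-in)."""
    # Micro-batches the dataloader yields per epoch
    micro_batches = math.ceil(num_examples / (per_device_batch * num_devices))
    # Each update consumes up to `grad_accum` micro-batches
    steps_per_epoch = max(math.ceil(micro_batches / grad_accum), 1)
    return steps_per_epoch * epochs


effective_batch = 2 * 4 * 1  # per-device batch x accumulation steps x GPUs
total_steps = optimizer_steps(num_examples=4, per_device_batch=2, grad_accum=4, epochs=3)
print(f"effective batch size = {effective_batch}, total steps = {total_steps}")
```

For this run the estimate matches the `Total steps = 3` and `Total batch size (2 x 4 x 1) = 8` lines in the training log.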

In [10]:
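def format_instruction_examples(example):
    """Render one example as a single '### Human / ### Assistant' training string."""
    prompt = f"### Human: {example['instruction']}\n### Assistant:"
    answer = example['output']
    # Returned as a one-element list of strings; note that no EOS token is appended
    return [f"{prompt} {answer}"]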


trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=format_instruction_examples,  # returns list of strings
    max_seq_length=1024,
    args=training_args,
)

trainer.train()

num_proc must be <= 4. Reducing num_proc to 4 for dataset of size 4.


Unsloth: Tokenizing ["text"] (num_proc=4):   0%|          | 0/4 [00:00<?, ? examples/s]

The model is already on multiple devices. Skipping the move to device specified in `args`.
==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 4 | Num Epochs = 3 | Total steps = 3
O^O/ \_/ \    Batch size per device = 2 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (2 x 4 x 1) = 8
 "-____-"     Trainable parameters = 41,943,040 of 8,072,204,288 (0.52% trained)


Step,Training Loss
1,2.427
2,2.427
3,2.3397


TrainOutput(global_step=3, training_loss=2.397874593734741, metrics={'train_runtime': 32.6402, 'train_samples_per_second': 0.368, 'train_steps_per_second': 0.092, 'total_flos': 27621535825920.0, 'train_loss': 2.397874593734741, 'epoch': 3.0})
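
The identical loss at steps 1 and 2 is consistent with the warmup setting. A minimal sketch of a linear warmup schedule follows, assuming the trainer applies the usual "linear warmup, then linear decay" multiplier for `lr_scheduler_type = "linear"`. The function name `linear_schedule_lr` is hypothetical.

```python
def linear_schedule_lr(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate used for the update at zero-based `step`."""
    if step < warmup_steps:
        # Warmup: ramp linearly from 0 toward base_lr
        return base_lr * step / max(1, warmup_steps)
    # Decay: ramp linearly from base_lr down to 0
    remaining = max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))
    return base_lr * remaining


# Settings from this notebook: learning_rate = 2e-4, warmup_steps = 5, 3 total steps
rates = [linear_schedule_lr(s, 2e-4, 5, 3) for s in range(3)]
for number, lr in enumerate(rates, start=1):
    print(f"update {number}: lr = {lr:.1e}")
```

Under this assumption the first update uses a learning rate of 0, so the weights would not change before the second logged loss, and the run ends at 40% of the peak learning rate because `warmup_steps` exceeds the total step count. For a dataset this small, a shorter warmup may be worth trying.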

In [11]:
# Monitor GPU memory usage
!pip install -q GPUtil
import GPUtil
GPUtil.showUtilization()

# Or use this for detailed monitoring
!nvidia-smi

| ID | GPU | MEM |
------------------
|  0 |  0% | 42% |
Sat Nov 15 10:37:41 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   58C    P0             28W /   70W |    6402MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
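
The two monitoring outputs above agree with each other. As a small sketch, the memory column of the `nvidia-smi` table can be parsed and converted to the percentage that GPUtil prints. The helper `parse_memory_usage` is a hypothetical name, and the regular expression assumes the `used MiB / total MiB` layout shown in this output.

```python
import re


def parse_memory_usage(line: str) -> tuple:
    """Extract (used_mib, total_mib) from one nvidia-smi table row."""
    match = re.search(r"(\d+)MiB\s*/\s*(\d+)MiB", line)
    if match is None:
        raise ValueError("no 'used MiB / total MiB' field found")
    return int(match.group(1)), int(match.group(2))


# Row copied from the nvidia-smi output after training
row = "| N/A   58C    P0             28W /   70W |    6402MiB /  15360MiB |      0%      Default |"
used_mib, total_mib = parse_memory_usage(row)
percent_used = 100 * used_mib / total_mib
print(f"{used_mib} MiB of {total_mib} MiB in use ({percent_used:.0f}%)")
```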

In [12]:
# Save the fine-tuned model
model.save_pretrained("medical_lora_adapter")  # Saves only the adapter
tokenizer.save_pretrained("medical_lora_adapter")

# model.push_to_hub("your-username/medical-llama-3")
# tokenizer.push_to_hub("your-username/medical-llama-3")

('medical_lora_adapter/tokenizer_config.json',
 'medical_lora_adapter/special_tokens_map.json',
 'medical_lora_adapter/tokenizer.json')

In [13]:
# Test with medical questions
questions = [
    "What are common symptoms of heart attack?",
    "How is diabetes diagnosed?",
    "What is the treatment for migraine?",
]

for question in questions:
    prompt = f"### Human: {question}\n### Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=150,   # only generate new content
        do_sample=True,       # makes output more natural
        temperature=0.7,      # controls randomness
        top_p=0.9,            # nucleus sampling
        pad_token_id=tokenizer.eos_token_id
    )

    # Decode only the newly generated tokens, skipping the prompt tokens
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()
    print(f"Q: {question}")
    print(f"A: {answer}\n")

Q: What are common symptoms of heart attack?
A: Chest pain or discomfort. Pain or discomfort in one or both arms, the back, neck, jaw or stomach. Shortness of breath. Cold sweat, nausea or vomiting.

 Chest pain or discomfort. Pain or discomfort in one or both arms, the back, neck, jaw or stomach. Shortness of breath. Cold sweat, nausea or vomiting.

 Chest pain or discomfort. Pain or discomfort in one or both arms, the back, neck, jaw or stomach. Shortness of breath. Cold sweat, nausea or vomiting.

 Chest pain or

Q: How is diabetes diagnosed?
A: Diabetes is diagnosed when the patient has one or more of the following criteria:
    - Fasting plasma glucose level greater than or equal to 126 mg/dL (7.0 mmol/L). Fasting means that no food (including liquids) was consumed for at least 8 hours.
    - A1C level greater than or equal to 6.5%
    - 2-hour plasma glucose level greater than or equal to 200 mg/dL (11.1 mmol/L) during a 75 gram oral glucose tolerance test
### Human: What are the
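
Both answers above run on until the token limit: the first repeats itself and the second starts a new `### Human:` turn. One plausible cause, not confirmed by this notebook, is that the training strings never end with an end-of-sequence token, so the model is not shown where an answer stops. The sketch below illustrates two possible mitigations. The helpers `format_with_eos` and `truncate_at_next_turn` are hypothetical names, and `"<eos>"` stands in for `tokenizer.eos_token`.

```python
HUMAN_MARKER = "### Human:"


def format_with_eos(instruction: str, output: str, eos_token: str) -> str:
    """Training string in the notebook's prompt format, terminated by an EOS token."""
    return f"{HUMAN_MARKER} {instruction}\n### Assistant: {output}{eos_token}"


def truncate_at_next_turn(answer: str, marker: str = HUMAN_MARKER) -> str:
    """Keep only the text generated before the model starts a new human turn."""
    return answer.split(marker, 1)[0].strip()


# "<eos>" is a placeholder; in the notebook this would be tokenizer.eos_token
sample = format_with_eos("How is hypertension treated?", "With lifestyle changes.", "<eos>")
print(sample)

generated = "A1C level greater than or equal to 6.5%\n### Human: What are the"
print(truncate_at_next_turn(generated))
```

Appending the EOS token affects training, while truncation only cleans up the displayed output; the two can be used together.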

In [14]:
# Fix for GitHub upload - Clear widget states
from IPython.display import Javascript
Javascript('''
// Clear widget states
if (typeof IPython !== 'undefined') {
    IPython.notebook.metadata.widgets = [];
}
// Save notebook
if (typeof Jupyter !== 'undefined') {
    Jupyter.notebook.save_checkpoint();
}
''')

<IPython.core.display.Javascript object>