<a href="https://colab.research.google.com/github/wasxy47/Medical_LLM_FineTuning_Colab/blob/main/Medical_LLM_FineTuning_Colab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Show the GPU attached to this runtime before loading the model.
!nvidia-smi

In [None]:
!pip install -q unsloth
!pip install -q transformers datasets accelerate bitsandbytes
!pip install -q trl peft torch

In [None]:
from unsloth import FastLanguageModel
import torch
from transformers import TrainingArguments
from trl import SFTTrainer
from datasets import load_dataset


In [None]:
# Load the `-bnb-4bit` Llama-3 8B checkpoint in 4-bit together with its tokenizer.
# Lower max_seq_length if the GPU runs out of memory.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    load_in_4bit=True,
    max_seq_length=2048,
    device_map="auto",  # let the loader place weights on the available device(s)
    # token=os.environ["HF_TOKEN"],  # only for gated/private repos; never hardcode it
)

In [None]:
# Try to download the MedAlpaca medical-flashcards dataset from the Hub.
# A failure (e.g. no network) is reported instead of raised, so "Run all" still
# reaches the hand-written fallback examples defined in the next cells.
# NOTE: a later cell reassigns `dataset` to that small hand-written Dataset
# before training, so this download is not what currently gets trained on.
try:
    dataset = load_dataset("medalpaca/medical_meadow_medical_flashcards")
except Exception as exc:  # best-effort: the notebook has a built-in fallback
    print(
        f"Could not load medical_meadow_medical_flashcards ({exc!r}); "
        "falling back to the built-in examples."
    )
    dataset = None

In [None]:
# Small hand-written instruction/answer set. The next cell turns it into the
# training Dataset, replacing anything loaded from the Hub above, which keeps
# this demo fast. Each question sits next to its answer, and the column lengths
# are derived from the list, so adding or removing an example cannot leave the
# columns with mismatched lengths.
_fallback_examples = [
    (
        "What are the symptoms of diabetes?",
        "Common symptoms of diabetes include frequent urination, excessive thirst, extreme hunger, unexplained weight loss, fatigue, blurred vision, and slow-healing sores.",
    ),
    (
        "How is hypertension treated?",
        "Hypertension is typically treated with lifestyle modifications including reduced salt intake, regular exercise, weight management, and medications like ACE inhibitors, beta-blockers, or diuretics.",
    ),
    (
        "What causes asthma attacks?",
        "Asthma attacks can be triggered by allergens like pollen and dust, respiratory infections, cold air, exercise, stress, air pollutants, and certain medications.",
    ),
    (
        "Describe the treatment for bacterial pneumonia",
        "Bacterial pneumonia is treated with antibiotics targeting the specific bacteria, along with supportive care including rest, hydration, and fever-reducing medications. Severe cases may require hospitalization.",
    ),
]

medical_data = {
    "instruction": [question for question, _ in _fallback_examples],
    "input": [""] * len(_fallback_examples),  # no extra context for any example
    "output": [answer for _, answer in _fallback_examples],
}

In [None]:
from datasets import Dataset

# Build the in-memory training set from the hand-written columns; this
# rebinds `dataset`, discarding whatever it referred to before.
dataset = Dataset.from_dict(mapping=medical_data)

In [None]:
# Attach LoRA adapters so fine-tuning trains small low-rank matrices rather
# than the full 4-bit base model (the save cell later stores only the adapter).
model = FastLanguageModel.get_peft_model(
    model,
    r=16,               # LoRA rank
    lora_alpha=16,      # with use_rslora=False the update is scaled by alpha / r
    lora_dropout=0,
    bias="none",
    # Adapt the attention projections and the MLP projections of every layer.
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    use_rslora=False,
    loftq_config=None,
    use_gradient_checkpointing=True,  # recompute activations to save memory
    random_state=3407,
)

In [None]:
# Configure training parameters.
# Query bf16 support once; fp16 mixed precision is the fallback.
use_bf16 = torch.cuda.is_bf16_supported()

training_args = TrainingArguments(
    output_dir = "medical-model",     # Where to save the model
    per_device_train_batch_size = 2,  # Reduce if you get memory errors
    gradient_accumulation_steps = 4,  # Effective batch = 2 * 4 per device
    # Warm up over a fraction of the run instead of a fixed 5 steps. With the
    # 4-example fallback dataset each epoch is a single optimizer step (~3 in
    # total), so a fixed 5-step warmup never let the LR reach its peak.
    warmup_ratio = 0.1,
    num_train_epochs = 3,             # Number of training cycles
    learning_rate = 2e-4,             # Peak learning rate
    fp16 = not use_bf16,              # Mixed precision: bf16 if supported,
    bf16 = use_bf16,                  # otherwise fp16
    logging_steps = 1,                # Log progress every step
    optim = "adamw_8bit",             # 8-bit AdamW optimizer
    weight_decay = 0.01,              # Regularization
    lr_scheduler_type = "linear",     # Linear decay after warmup
    seed = 3407,                      # Random seed
    report_to = "none",               # Disable external logging
)

In [None]:
def format_instruction_examples(example):
    """Render instruction/output example(s) as "### Human / ### Assistant" text.

    Accepts either a single example (each field a ``str``) or a batch as
    produced by ``datasets.Dataset.map(batched=True)`` (each field a list).
    Always returns a list of strings, one per example, matching the
    list-of-strings contract the SFTTrainer call relies on. Formatting a
    batched field directly would embed the list's repr and merge the whole
    batch into one malformed example.

    A non-empty ``input`` field is appended on its own line after the
    instruction; the key may be absent.
    """
    instructions = example["instruction"]
    answers = example["output"]
    extras = example.get("input")

    if isinstance(instructions, str):
        # Single (unbatched) example: wrap it so both shapes share one path.
        instructions, answers, extras = [instructions], [answers], [extras]
    elif extras is None:
        extras = [None] * len(instructions)

    texts = []
    for instruction, extra, answer in zip(instructions, extras, answers):
        question = f"{instruction}\n{extra}" if extra else instruction
        texts.append(f"### Human: {question}\n### Assistant: {answer}")
    return texts


def _format_with_eos(batch):
    """Format examples and end each with the tokenizer's EOS token.

    Without an EOS at the end of every training text the model never learns
    where an answer stops, so generation tends to run on to max_new_tokens.
    """
    eos = tokenizer.eos_token
    return [
        text if text.endswith(eos) else text + eos
        for text in format_instruction_examples(batch)
    ]


trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=_format_with_eos,  # returns list of strings
    max_seq_length=1024,
    args=training_args,
)

trainer.train()

In [None]:
# Monitor GPU memory usage
!pip install -q GPUtil
import GPUtil
GPUtil.showUtilization()

# Or use this for detailed monitoring
!nvidia-smi

In [None]:
# Persist the fine-tuned LoRA adapter (not the full base model) and the
# tokenizer into the same directory so they can be reloaded together.
for artifact in (model, tokenizer):
    artifact.save_pretrained("medical_lora_adapter")

# To publish instead, push both to the Hub (requires an authenticated token):
# model.push_to_hub("your-username/medical-llama-3")
# tokenizer.push_to_hub("your-username/medical-llama-3")

In [None]:
# Test with medical questions
questions = [
    "What are common symptoms of heart attack?",
    "How is diabetes diagnosed?",
    "What is the treatment for migraine?",
]

# Switch the Unsloth model into its inference mode before generating.
FastLanguageModel.for_inference(model)

for question in questions:
    # Same "### Human / ### Assistant" layout the model was trained on.
    prompt = f"### Human: {question}\n### Assistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=150,   # upper bound on generated tokens
        do_sample=True,       # sample rather than decode greedily
        temperature=0.7,      # controls randomness
        top_p=0.9,            # nucleus sampling
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the newly generated tokens. The previous
    # `str.replace(prompt, "")` removed *every* occurrence of the prompt,
    # including copies the model regenerates, and relied on decode()
    # reproducing the prompt text exactly.
    prompt_length = inputs["input_ids"].shape[1]
    answer = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # The model may continue with an invented next turn; keep only this answer.
    answer = answer.split("### Human:")[0].strip()
    print(f"Q: {question}")
    print(f"A: {answer}\n")