# Medical LLM Fine-Tuning with QLoRA
 
This notebook demonstrates how to fine-tune Llama 3.1-8B-Instruct for 
medical reasoning using QLoRA (Quantized Low-Rank Adaptation).
Inspired by: HuatuoGPT-o1 (https://github.com/FreedomIntelligence/HuatuoGPT-o1)

Dataset: FreedomIntelligence/medical-o1-reasoning-SFT
Model: meta-llama/Llama-3.1-8B-Instruct

Hardware: Optimized for Google Colab Free (T4 GPU, ~15GB VRAM)
Training time: ~2-3 hours on Colab Free
============================================================================

In [None]:
# ============================================================================
# 1. ENVIRONMENT SETUP
# ============================================================================

# Install required packages with compatible versions
# (version specifiers are quoted so the shell does not treat ">" as a redirect)
!pip install -q \
    torch \
    "transformers>=4.44.0" \
    "datasets>=2.20.0" \
    "accelerate>=0.34.0" \
    "peft>=0.12.0" \
    "trl>=0.9.0" \
    "bitsandbytes>=0.44.0" \
    sentencepiece

# Restart the kernel (uncomment the two lines below if needed)
#import IPython
#IPython.Application.instance().kernel.do_shutdown(True)

# ============================================================================
# 2. HUGGINGFACE AUTHENTICATION (REQUIRED)
# ============================================================================

**IMPORTANT**: This notebook uses Llama 3.1-8B-Instruct, which requires:
1. HuggingFace account
2. Approval to access Meta-Llama models
3. Valid access token

## Setup Instructions:

### Step 1: Request Llama Access (One-time)
1. Visit: https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct
2. Click **"Request Access"** and fill out Meta's form
3. Wait for approval email (typically 1-24 hours)

### Step 2: Create Access Token
1. Go to: https://huggingface.co/settings/tokens
2. Click **"New token"** → Name it (e.g., "medical-llm") → Select **"Read"** permission
3. Copy the token (starts with `hf_...`)

### Step 3: Run the authentication cell below ⬇️

---
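
For non-interactive runs, the token can also be supplied through an environment variable instead of the prompt. A minimal sketch, assuming the token is stored under the name `HF_TOKEN`; the helpers `token_from_env` and `login_from_env` are hypothetical names, not part of `huggingface_hub`:

```python
import os

def token_from_env(env=os.environ, var_name="HF_TOKEN"):
    """Return the stripped token from the environment, or None if unset or empty."""
    value = env.get(var_name, "").strip()
    return value or None

def login_from_env():
    """Log in with the environment token if present; return whether it was used."""
    token = token_from_env()
    if token is None:
        return False
    # Imported lazily so the helper above works even without the package installed
    from huggingface_hub import login
    login(token=token)
    return True
```

If `login_from_env()` returns `False`, falling back to the interactive `login()` keeps the flow described above.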

In [None]:
# ============================================================================
# 2.1 AUTHENTICATE WITH HUGGINGFACE
# ============================================================================

from huggingface_hub import login
import os

# Check if already authenticated
try:
    from huggingface_hub import HfFolder
    token = HfFolder.get_token()
    if not token:
        print("⚠️ No token found. Please authenticate below.")
        raise ValueError("Authentication required")
    print("✅ Already authenticated with HuggingFace!")
    # Do not echo any part of the token; notebook outputs are often shared
    print("Token found (value hidden).")
except Exception:  # missing token or unavailable helper: fall back to interactive login
    print("=" * 80)
    print("🔐 HUGGINGFACE AUTHENTICATION REQUIRED")
    print("=" * 80)
    print("\nThis will open an interactive prompt to enter your HuggingFace token.")
    print("Paste your token (starts with 'hf_...') and press Enter.\n")
    print("⚠️ If you don't have a token yet, follow the instructions above ⬆️\n")
    
    # Interactive login (will prompt for token)
    login()
    
    print("\n✅ Authentication successful!")
    print("You can now download Llama 3.1-8B-Instruct")
    print("=" * 80)

# Verify access to Llama model
print("\n🔍 Verifying access to meta-llama/Meta-Llama-3.1-8B-Instruct...")
try:
    from huggingface_hub import model_info
    info = model_info("meta-llama/Meta-Llama-3.1-8B-Instruct")
    print("✅ Access verified! You can proceed with the notebook.")
except Exception as e:
    print("❌ ERROR: Cannot access Llama 3.1 model")
    print(f"Error: {str(e)}\n")
    print("Possible issues:")
    print("  1. You haven't requested access to Meta-Llama models yet")
    print("     → Visit: https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct")
    print("  2. Your access request is still pending approval")
    print("     → Wait for approval email from Meta (1-24 hours)")
    print("  3. Invalid token")
    print("     → Generate a new token at: https://huggingface.co/settings/tokens")
    print("\n⚠️ DO NOT PROCEED until this is resolved")
    raise

In [None]:
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from trl import SFTTrainer
from datasets import load_dataset, Dataset
import json
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

# Check GPU availability
# QLoRA requires CUDA for 4-bit quantization
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

if device == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name}")
    print(f"Total VRAM: {gpu_memory:.2f} GB")
else:
    print("⚠️ WARNING: No GPU detected. 4-bit QLoRA training expects a CUDA GPU and is not practical on CPU.")

# ============================================================================
# 3. DATA UPLOAD (REQUIRED)
# ============================================================================

**IMPORTANT**: You need to upload the training data files to Colab.

## Option 1: Upload from Local Machine

If you have cloned the repository and have the data files:

1. Click the **folder icon** 📁 in the left sidebar
2. Click the **upload icon** ⬆️ 
3. Upload both files:
   - `train_data.jsonl` (~1.3 MB, ~450 examples)
   - `test_data.jsonl` (~145 KB, ~50 examples)

**The data loading cell further down checks that both files are present and prompts for an upload if either is missing.**
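
Before training, it can help to confirm that every record carries the three fields the later cells read (`question`, `complex_cot`, `response`). A small sketch of such a check; the helper name `find_invalid_records` is an illustration and not part of the notebook:

```python
import json

REQUIRED_FIELDS = ("question", "complex_cot", "response")

def find_invalid_records(filepath, required=REQUIRED_FIELDS):
    """Return (line_number, problem) pairs for malformed or incomplete JSONL records."""
    problems = []
    with open(filepath, "r", encoding="utf-8") as f:
        for line_number, line in enumerate(f, start=1):
            if not line.strip():
                continue  # tolerate blank lines
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                problems.append((line_number, f"invalid JSON: {exc.msg}"))
                continue
            if not isinstance(record, dict):
                problems.append((line_number, "record is not a JSON object"))
                continue
            # A field counts as missing if it is absent, null, or only whitespace
            missing = [k for k in required if not str(record.get(k) or "").strip()]
            if missing:
                problems.append((line_number, "missing or empty: " + ", ".join(missing)))
    return problems
```

An empty list from `find_invalid_records("train_data.jsonl")` means every line parsed and had all three fields.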

In [None]:
# ============================================================================
# 2. CONFIGURATION
# ============================================================================

class TrainingConfig:
    """
    Centralized configuration for reproducibility and easy experimentation.
    
    These hyperparameters are optimized for:
    - Google Colab Free (T4 GPU, ~15GB VRAM)
    - ~450 training examples
    - Balance between quality and training time
    """
    
    # Model Configuration
    MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    
    # QLoRA Configuration
    # Using rank=16 provides good balance between parameter efficiency and model capacity
    LORA_R = 16
    # Alpha is typically set to 2*rank for stable training
    LORA_ALPHA = 32
    # We target attention layers as they're most impactful for adapting LLM behavior
    LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
    LORA_DROPOUT = 0.05
    
    # Quantization Configuration
    # NF4 (4-bit NormalFloat) is optimal for LLMs as it preserves important weight distributions
    QUANTIZATION_TYPE = "nf4"
    # Double quantization further reduces memory with minimal quality loss
    USE_DOUBLE_QUANT = True
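    # bfloat16 is more stable than float16 for training, but it needs an Ampere-or-newer GPU
    # (compute capability >= 8.0); the Colab T4 is older, so fall back to float16 there
    COMPUTE_DTYPE = (
        torch.bfloat16
        if torch.cuda.is_available() and torch.cuda.get_device_capability(0)[0] >= 8
        else torch.float16
    )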
    
    # Training Hyperparameters
    # Small batch size due to memory constraints; gradient accumulation compensates
    PER_DEVICE_BATCH_SIZE = 2
    # Effective batch size = 2 * 4 = 8
    GRADIENT_ACCUMULATION_STEPS = 4
    # 2e-4 is a common LoRA learning rate; it is higher than typical full fine-tuning
    # rates because only the small adapter matrices are updated
    LEARNING_RATE = 2e-4
    # More epochs on small dataset helps model learn the reasoning pattern
    NUM_EPOCHS = 3
    # Warmup prevents early training instability
    WARMUP_RATIO = 0.03
    # Longer sequences needed for complex medical reasoning chains
    MAX_SEQ_LENGTH = 2048
    
    # Optimization
    # AdamW is standard; the paged variant can move optimizer state to CPU RAM
    # during memory spikes, which helps avoid out-of-memory errors
    OPTIMIZER_TYPE = "paged_adamw_32bit"
    # Cosine decay smoothly reduces LR, helping convergence
    LR_SCHEDULER_TYPE = "cosine"
    
    # Logging and Checkpointing
    LOGGING_STEPS = 10
    SAVE_STEPS = 50
    EVAL_STEPS = 50
    
    # Output
    OUTPUT_DIR = "./medical-llm-finetuned"
    
    # Reproducibility
    SEED = 42

config = TrainingConfig()
print("✓ Configuration loaded")
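
To see what `WARMUP_RATIO` and the cosine scheduler mean in terms of steps, the schedule can be worked out by hand. The sketch below assumes exactly 450 training examples (the notebook only says ~450) and uses ceiling division for steps per epoch, which may differ by a step or two from what the trainer reports. `lr_at` is a hypothetical helper following the usual linear-warmup-then-cosine formula.

```python
import math

NUM_EXAMPLES = 450        # assumption: the notebook states ~450 examples
EFFECTIVE_BATCH = 2 * 4   # PER_DEVICE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
NUM_EPOCHS = 3
BASE_LR = 2e-4
WARMUP_RATIO = 0.03

steps_per_epoch = math.ceil(NUM_EXAMPLES / EFFECTIVE_BATCH)
total_steps = steps_per_epoch * NUM_EPOCHS
warmup_steps = math.ceil(total_steps * WARMUP_RATIO)

def lr_at(step):
    """Learning rate at an optimizer step: linear warmup, then cosine decay to zero."""
    if step < warmup_steps:
        return BASE_LR * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return BASE_LR * 0.5 * (1.0 + math.cos(math.pi * progress))

for step in (0, warmup_steps, total_steps // 2, total_steps):
    print(f"step {step:>3}: lr = {lr_at(step):.2e}")
```

Under these assumptions the warmup covers only a handful of steps, so the learning rate reaches its peak almost immediately and spends the rest of training decaying.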

In [None]:
# ============================================================================
# 3. DATA LOADING AND PREPARATION
# ============================================================================

# Upload your train_data.jsonl and test_data.jsonl files using the file upload button
# or drag and drop them into Colab's file browser

from google.colab import files
import os

def load_local_jsonl(filepath):
    """Load JSONL file into list of dictionaries."""
    data = []
    with open(filepath, 'r', encoding='utf-8') as f:
        for line in f:
            data.append(json.loads(line))
    return data

# Check if data files exist, otherwise prompt upload
if not os.path.exists('train_data.jsonl') or not os.path.exists('test_data.jsonl'):
    print("Please upload train_data.jsonl and test_data.jsonl")
    uploaded = files.upload()
else:
    print("✓ Data files found")

# Load datasets
train_data = load_local_jsonl('train_data.jsonl')
test_data = load_local_jsonl('test_data.jsonl')

print(f"Training examples: {len(train_data)}")
print(f"Test examples: {len(test_data)}")

# Display sample
print("\n" + "="*80)
print("SAMPLE TRAINING EXAMPLE")
print("="*80)
sample = train_data[0]
print(f"Question: {sample['question'][:200]}...")
print(f"\nComplex CoT: {sample['complex_cot'][:300]}...")
print(f"\nResponse: {sample['response'][:200]}...")

In [None]:
# ============================================================================
# 4. PROMPT FORMATTING
# ============================================================================

def format_medical_prompt(example):
    """
    Format training example into Llama 3.1 chat template with HuatuoGPT-o1 style.
    
    The model learns to:
    1. First generate step-by-step reasoning (## Thinking)
    2. Then provide a concise final answer (## Final Response)
    
    This two-stage format improves reasoning quality and interpretability.
    """
    system_prompt = (
        "You are a medical expert AI assistant. When answering medical questions, "
        "first provide your step-by-step reasoning in a '## Thinking' section, "
        "then provide your final answer in a '## Final Response' section."
    )
    
    # Combine CoT and response in HuatuoGPT-o1 format
    assistant_response = f"""## Thinking
{example['complex_cot']}

## Final Response
{example['response']}"""
    
    # Llama 3.1 chat template format
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example['question']},
        {"role": "assistant", "content": assistant_response}
    ]
    
    return messages

# Test formatting
formatted_example = format_medical_prompt(train_data[0])
print("✓ Prompt formatting function ready")
print(f"\nFormatted message structure:")
for msg in formatted_example:
    print(f"  - {msg['role']}: {len(msg['content'])} chars")

In [None]:
# ============================================================================
# 5. MODEL AND TOKENIZER LOADING
# ============================================================================

# Configure 4-bit quantization
# This reduces the weight memory footprint from ~16 GB in 16-bit to roughly 5-6 GB;
# the embeddings and output head are not quantized, so the saving is less than a flat 4x
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type=config.QUANTIZATION_TYPE,
    bnb_4bit_compute_dtype=config.COMPUTE_DTYPE,
    bnb_4bit_use_double_quant=config.USE_DOUBLE_QUANT,
)

print("Loading base model (this may take a few minutes)...")

# Load model with quantization
# device_map="auto" automatically distributes model across available devices
model = AutoModelForCausalLM.from_pretrained(
    config.MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(config.MODEL_NAME)

# Llama models don't have a default pad token; we set it to EOS to avoid warnings
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

# Prepare model for k-bit training
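# This freezes the base model weights, upcasts normalization layers for numerical
# stability, and enables input gradients; the LoRA adapters are added in the next section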
model = prepare_model_for_kbit_training(model)

print("✓ Model and tokenizer loaded successfully")
print(f"Model memory footprint: {model.get_memory_footprint() / 1e9:.2f} GB")
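
The footprint printed above can be sanity-checked with some arithmetic. This is a rough sketch under stated assumptions: the parameter count, vocabulary size, and hidden size are approximate figures for Llama 3.1 8B rather than values read from the checkpoint; the embeddings and output head are assumed to stay in 16-bit; and the block sizes (64 for the first quantization level, 256 for the second) follow the QLoRA paper.

```python
# Assumed architecture figures for Llama 3.1 8B (approximate, not read from the checkpoint)
TOTAL_PARAMS = 8_030_261_248
VOCAB_SIZE = 128_256
HIDDEN_SIZE = 4096

# Assumption: input embeddings and output head stay in 16-bit; the tiny norm
# weights are lumped in with the quantized layers for simplicity
unquantized_params = 2 * VOCAB_SIZE * HIDDEN_SIZE
quantized_params = TOTAL_PARAMS - unquantized_params

def bits_per_quantized_param(double_quant=True, block_size=64, second_block_size=256):
    """4-bit payload plus the per-block overhead of the quantization constants."""
    if double_quant:
        # 8-bit constants per block, plus 32-bit constants for each group of blocks
        overhead = 8 / block_size + 32 / (block_size * second_block_size)
    else:
        overhead = 32 / block_size  # one 32-bit constant per block
    return 4 + overhead

def estimated_footprint_gb(double_quant=True):
    quantized_bytes = quantized_params * bits_per_quantized_param(double_quant) / 8
    unquantized_bytes = unquantized_params * 2  # 2 bytes per 16-bit value
    return (quantized_bytes + unquantized_bytes) / 1e9

print(f"16-bit weights:        {TOTAL_PARAMS * 2 / 1e9:.2f} GB")
print(f"NF4 + double quant:    {estimated_footprint_gb(True):.2f} GB")
print(f"NF4, no double quant:  {estimated_footprint_gb(False):.2f} GB")
```

The estimate covers weights only; activations, LoRA gradients, and optimizer state come on top of it during training.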

In [None]:
# ============================================================================
# 6. LORA CONFIGURATION
# ============================================================================

# Configure LoRA adapters
# LoRA adds trainable low-rank matrices to attention layers while keeping
# the base model frozen, dramatically reducing trainable parameters
lora_config = LoraConfig(
    r=config.LORA_R,
    lora_alpha=config.LORA_ALPHA,
    target_modules=config.LORA_TARGET_MODULES,
    lora_dropout=config.LORA_DROPOUT,
    bias="none",
    task_type="CAUSAL_LM",
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

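# Print trainable parameters
# Note: 4-bit weights are stored packed, so numel() under-counts the base model
# and the trainable percentage below is somewhat overstated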
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
trainable_percentage = 100 * trainable_params / total_params

print("✓ LoRA adapters configured")
print(f"Trainable parameters: {trainable_params:,} ({trainable_percentage:.2f}%)")
print(f"Total parameters: {total_params:,}")
print(f"\nFrozen parameters: ~{100 - trainable_percentage:.1f}% of the total")
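
The trainable count can be cross-checked by hand: a LoRA adapter of rank `r` on a `d_in x d_out` projection adds `r * (d_in + d_out)` parameters. The sketch below assumes the usual Llama 3.1 8B dimensions (hidden size 4096, 32 layers, 8 key/value heads of dimension 128); these are assumptions, not values read from the loaded model, and `lora_param_count` is a hypothetical helper.

```python
HIDDEN_SIZE = 4096
NUM_LAYERS = 32
NUM_KV_HEADS, HEAD_DIM = 8, 128
KV_DIM = NUM_KV_HEADS * HEAD_DIM  # grouped-query attention: K and V are narrower

# (input dim, output dim) of each targeted projection
PROJECTION_SHAPES = {
    "q_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
    "k_proj": (HIDDEN_SIZE, KV_DIM),
    "v_proj": (HIDDEN_SIZE, KV_DIM),
    "o_proj": (HIDDEN_SIZE, HIDDEN_SIZE),
}

def lora_param_count(rank, shapes=PROJECTION_SHAPES, num_layers=NUM_LAYERS):
    """Each adapter has an (d_in x r) and an (r x d_out) matrix per projection."""
    per_layer = sum(rank * (d_in + d_out) for d_in, d_out in shapes.values())
    return per_layer * num_layers

count = lora_param_count(rank=16)
print(f"{count:,} trainable LoRA parameters")
print(f"~{count * 4 / 1e6:.0f} MB if stored in float32")
```

Doubling the rank doubles the adapter size, which is one way to gauge the cost of experimenting with `LORA_R`.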

In [None]:
# ============================================================================
# 7. DATASET PREPARATION FOR TRAINING
# ============================================================================

def format_for_trainer(example):
    """
    Format example for SFTTrainer using Llama's chat template.
    
    SFTTrainer expects a 'text' field with the fully formatted prompt.
    We use apply_chat_template to ensure proper special token handling.
    """
    messages = format_medical_prompt(example)
    # apply_chat_template handles special tokens (<|begin_of_text|>, etc.)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False  # We already have the assistant's response
    )
    return {"text": text}

# Convert to HuggingFace Dataset format
# SFTTrainer works best with Dataset objects for built-in batching and shuffling
train_dataset = Dataset.from_list(train_data)
test_dataset = Dataset.from_list(test_data)

# Apply formatting
train_dataset = train_dataset.map(format_for_trainer)
test_dataset = test_dataset.map(format_for_trainer)

print("✓ Datasets prepared for training")
print(f"Train dataset: {len(train_dataset)} examples")
print(f"Test dataset: {len(test_dataset)} examples")

# Preview a formatted training example
print("\n" + "="*80)
print("FORMATTED TRAINING EXAMPLE (first 500 chars)")
print("="*80)
print(train_dataset[0]['text'][:500] + "...")

In [None]:
# ============================================================================
# 8. TRAINING CONFIGURATION (updated for recent trl)
# ============================================================================

from trl import SFTConfig

# SFTConfig replaces TrainingArguments and adds SFT-specific parameters
training_args = SFTConfig(
    # Output
    output_dir=config.OUTPUT_DIR,
    
    # Training hyperparameters
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.PER_DEVICE_BATCH_SIZE,
    per_device_eval_batch_size=config.PER_DEVICE_BATCH_SIZE,
    gradient_accumulation_steps=config.GRADIENT_ACCUMULATION_STEPS,
    
    # Optimization
    learning_rate=config.LEARNING_RATE,
    optim=config.OPTIMIZER_TYPE,
    lr_scheduler_type=config.LR_SCHEDULER_TYPE,
    warmup_ratio=config.WARMUP_RATIO,
    weight_decay=0.01,
    
    # Memory optimization
    gradient_checkpointing=True,
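    # Keep mixed precision in sync with the quantization compute dtype
    fp16=config.COMPUTE_DTYPE == torch.float16,
    bf16=config.COMPUTE_DTYPE == torch.bfloat16,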
    
    # Logging and evaluation
    logging_steps=config.LOGGING_STEPS,
    eval_steps=config.EVAL_STEPS,
    save_steps=config.SAVE_STEPS,
    eval_strategy="steps",
    save_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    
    # Reproducibility
    seed=config.SEED,
    data_seed=config.SEED,
    
    # Performance
    dataloader_num_workers=0,
    group_by_length=False,
    
    # Reporting
    report_to="none",
    run_name="medical-llm-sft",
    
    # ============================================
    # SFT-specific parameters (new in SFTConfig)
    # ============================================
    max_length=config.MAX_SEQ_LENGTH,  # named max_seq_length in older trl releases
    packing=False,  # One example per sequence, so unrelated cases never share a context
    dataset_text_field="text",
)

print("✓ Training configuration ready (SFTConfig)")

In [None]:
# ============================================================================
# 9. INITIALIZE TRAINER (SIMPLIFIED)
# ============================================================================

from trl import SFTTrainer

# Initialize SFTTrainer (now much simpler)
trainer = SFTTrainer(
    model=model,
    args=training_args,  # SFTConfig holds all training settings
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    processing_class=tokenizer,
)

print("✓ Trainer initialized and ready for training")
effective_batch = config.PER_DEVICE_BATCH_SIZE * config.GRADIENT_ACCUMULATION_STEPS
steps_per_epoch = -(-len(train_dataset) // effective_batch)  # ceiling division
print("\nTraining schedule:")
print(f"  ~{steps_per_epoch} optimizer steps per epoch (effective batch size {effective_batch})")
print(f"  ~{steps_per_epoch * config.NUM_EPOCHS} optimizer steps in total over {config.NUM_EPOCHS} epochs")

In [None]:
# ============================================================================
# 10. TRAINING
# ============================================================================

print("Starting training...")
print("This will take approximately 2-3 hours on Colab Free T4 GPU")
print("="*80)

# Train the model
# The trainer will automatically handle batching, gradient accumulation,
# checkpointing, and evaluation
train_result = trainer.train()

print("="*80)
print("✓ Training completed!")
print(f"Training loss: {train_result.training_loss:.4f}")
print(f"Training took: {train_result.metrics['train_runtime']:.2f} seconds")

# Save final metrics
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
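
The logged losses can be pulled out of the trainer state to plot training and evaluation curves. A minimal sketch, assuming the log history is a list of dictionaries in which training entries carry `loss` and evaluation entries carry `eval_loss`, each with a `step`; `split_loss_history` is a hypothetical helper.

```python
def split_loss_history(log_history):
    """Separate (step, loss) pairs for training and evaluation from Trainer-style logs."""
    train_points = [
        (entry["step"], entry["loss"])
        for entry in log_history
        if "loss" in entry and "step" in entry
    ]
    eval_points = [
        (entry["step"], entry["eval_loss"])
        for entry in log_history
        if "eval_loss" in entry and "step" in entry
    ]
    return train_points, eval_points

# Usage in the notebook (not executed here):
#   train_points, eval_points = split_loss_history(trainer.state.log_history)
#   plt.plot(*zip(*train_points), label="train")
#   plt.plot(*zip(*eval_points), label="eval")
#   plt.xlabel("step"); plt.ylabel("loss"); plt.legend(); plt.show()
```

An evaluation curve that rises while the training curve keeps falling is a sign of overfitting on a dataset this small.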

In [None]:
# ============================================================================
# 11. EVALUATION
# ============================================================================

print("Evaluating fine-tuned model on test set...")
eval_results = trainer.evaluate()

print("="*80)
print("EVALUATION RESULTS")
print("="*80)
for key, value in eval_results.items():
    print(f"{key}: {value:.4f}")

# Save evaluation results
trainer.save_metrics("eval", eval_results)
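
The `eval_loss` reported above is a mean per-token cross-entropy, so exponentiating it gives perplexity, which is often easier to compare across runs. A short sketch, assuming the loss is measured in nats (the natural-log convention of PyTorch's cross-entropy); `perplexity_from_loss` is a hypothetical helper.

```python
import math

def perplexity_from_loss(mean_cross_entropy):
    """Convert a mean per-token cross-entropy (in nats) to perplexity."""
    return math.exp(mean_cross_entropy)

# Usage in the notebook (not executed here):
#   print(f"Perplexity: {perplexity_from_loss(eval_results['eval_loss']):.2f}")
```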

In [None]:
# ============================================================================
# 12. QUALITATIVE EVALUATION - GENERATE SAMPLE PREDICTIONS
# ============================================================================

def generate_response(question, max_new_tokens=512):
    """
    Generate a response for a medical question using the fine-tuned model.
    
    Returns both the thinking process and final response in HuatuoGPT-o1 format.
    """
    system_prompt = (
        "You are a medical expert AI assistant. When answering medical questions, "
        "first provide your step-by-step reasoning in a '## Thinking' section, "
        "then provide your final answer in a '## Final Response' section."
    )
    
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question}
    ]
    
    # Format and tokenize
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True  # Add prompt for model to continue
    )
    
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    # Generate
    # temperature=0.7 balances creativity and consistency
    # top_p=0.9 uses nucleus sampling for more natural text
    model.eval()  # disable dropout (including LoRA dropout) for inference
    with torch.no_grad():  # gradients are not needed when sampling
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode and extract only the new tokens (response)
    response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
    return response

# Test on a few examples from test set
print("="*80)
print("SAMPLE PREDICTIONS FROM FINE-TUNED MODEL")
print("="*80)

num_samples = 3
for i in range(min(num_samples, len(test_data))):
    example = test_data[i]
    
    print(f"\n{'='*80}")
    print(f"EXAMPLE {i+1}")
    print(f"{'='*80}")
    print(f"\nQUESTION:")
    print(example['question'])
    
    print(f"\n{'-'*80}")
    print("MODEL RESPONSE:")
    print(f"{'-'*80}")
    response = generate_response(example['question'])
    print(response)
    
    print(f"\n{'-'*80}")
    print("EXPECTED (Ground Truth):")
    print(f"{'-'*80}")
    print(f"## Thinking\n{example['complex_cot'][:300]}...")
    print(f"\n## Final Response\n{example['response'][:200]}...")
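
To compare outputs section by section, a generated response can be split on the two headings used in the prompt format. A small sketch, assuming the model reproduces the `## Thinking` and `## Final Response` headings; `split_sections` is a hypothetical helper and returns the whole text as the final answer when the format is not followed.

```python
import re

SECTION_PATTERN = re.compile(
    r"##\s*Thinking\s*(?P<thinking>.*?)\s*##\s*Final Response\s*(?P<final>.*)",
    re.DOTALL,
)

def split_sections(response):
    """Return (thinking, final_response); thinking is None if the headings are absent."""
    match = SECTION_PATTERN.search(response)
    if match is None:
        return None, response.strip()
    return match.group("thinking").strip(), match.group("final").strip()

# Usage in the notebook (not executed here):
#   thinking, final = split_sections(generate_response(example["question"]))
```

Counting how often `thinking` comes back as `None` over the test set gives a quick measure of format adherence.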

In [None]:
# ============================================================================
# 13. SAVE FINE-TUNED MODEL
# ============================================================================

# Save LoRA adapters
# These are small (tens of MB for this rank-16 setup) and contain only the trained adapter weights
final_model_path = f"{config.OUTPUT_DIR}/final_model"
model.save_pretrained(final_model_path)
tokenizer.save_pretrained(final_model_path)

print(f"✓ Model saved to: {final_model_path}")
print(f"\nSaved files:")
!ls -lh {final_model_path}

# Optional: Create a zip file for easy download
!zip -r medical_llm_adapters.zip {final_model_path}
print("\n✓ Created medical_llm_adapters.zip for download")

In [None]:
# ============================================================================
# 14. DOWNLOAD MODEL (Optional)
# ============================================================================

# Download the model adapters to your local machine
# Uncomment the line below to trigger download
# files.download('medical_llm_adapters.zip')

print("To download the model:")
print("1. Uncomment the line above, or")
print("2. Use the file browser on the left to download manually")

In [None]:
# ============================================================================
# 15. NEXT STEPS
# ============================================================================

print("""
✅ TRAINING COMPLETED SUCCESSFULLY!

Next Steps:
-----------
1. Download the fine-tuned LoRA adapters (medical_llm_adapters.zip)
   
2. To use the model locally, load it with:
   ```python
   from transformers import AutoModelForCausalLM, AutoTokenizer
   from peft import PeftModel
   
   base_model = AutoModelForCausalLM.from_pretrained(
       "meta-llama/Meta-Llama-3.1-8B-Instruct",
       torch_dtype="auto",  # use the checkpoint's 16-bit dtype instead of float32
       device_map="auto",
   )
   model = PeftModel.from_pretrained(base_model, "./path/to/adapters")
   tokenizer = AutoTokenizer.from_pretrained("./path/to/adapters")
   ```

3. Optional: Push to HuggingFace Hub for easy sharing
   - Create account at huggingface.co
   - Run: model.push_to_hub("your-username/medical-llm-finetuned")

4. Run more comprehensive evaluation on medical benchmarks

5. Compare with baseline Llama 3.1-8B-Instruct to quantify improvement

Resources:
----------
- HuatuoGPT-o1 Paper: https://arxiv.org/abs/2412.18925
- QLoRA Paper: https://arxiv.org/abs/2305.14314
- TRL Documentation: https://huggingface.co/docs/trl
""")