# Medical Reasoning Fine-tuning with TRL and LoRA

## Introduction

This notebook demonstrates how to fine-tune a language model for medical reasoning tasks using **TRL (Transformer Reinforcement Learning)** and **LoRA (Low-Rank Adaptation)**. If you're new to these concepts, here's what you need to know:

### What is LoRA?
**LoRA (Low-Rank Adaptation)** is an efficient fine-tuning technique that:
- Freezes the original model weights and adds small trainable low-rank matrices alongside them
- Reduces memory usage and training time, because gradients and optimizer state are kept only for the adapter weights
- Often achieves performance comparable to full fine-tuning
- Creates lightweight adapters that can be easily shared and swapped
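
To make the "small trainable matrices" idea concrete, here is a minimal plain-Python sketch of a LoRA-style forward pass. It is an illustration under assumptions, not PEFT's implementation: the helper names (`matvec`, `lora_forward`, `lora_param_counts`) and the tiny dimensions are made up for this example.

```python
import random


def matvec(matrix, vector):
    """Multiply a (rows x cols) matrix by a length-cols vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha, r):
    """Compute W x + (alpha / r) * B (A x); only A and B would be trained."""
    base = matvec(W, x)               # frozen path through the original weights
    update = matvec(B, matvec(A, x))  # low-rank path through the adapter
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


def lora_param_counts(d_out, d_in, r):
    """Return (frozen weight count, trainable LoRA weight count) for one layer."""
    return d_out * d_in, r * (d_in + d_out)


rng = random.Random(0)
d_in, d_out, r, alpha = 6, 4, 2, 4  # toy sizes, chosen only for illustration
W = [[rng.gauss(0, 1) for _ in range(d_in)] for _ in range(d_out)]  # frozen
A = [[rng.gauss(0, 0.01) for _ in range(d_in)] for _ in range(r)]   # trainable, r x d_in
B = [[0.0] * r for _ in range(d_out)]                               # trainable, d_out x r, starts at zero
x = [rng.gauss(0, 1) for _ in range(d_in)]

# With B initialised to zero the adapter contributes nothing yet,
# so the output matches the frozen base layer exactly.
assert lora_forward(W, A, B, x, alpha, r) == matvec(W, x)

full, lora = lora_param_counts(4096, 4096, 16)
print(f"full: {full:,}  LoRA: {lora:,}  ratio: {lora / full:.2%}")
```

For a single 4096 x 4096 projection at rank 16, the adapter holds 131,072 weights against 16,777,216 frozen ones, which is under 1% of that layer.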

### What is TRL?
**TRL (Transformer Reinforcement Learning)** is a library that:
- Provides easy-to-use trainers for supervised fine-tuning (SFT)
- Supports advanced training techniques like RLHF and DPO
- Integrates seamlessly with Hugging Face transformers and PEFT

### What We'll Do
In this notebook, we'll:
1. Load and preprocess a medical reasoning dataset
2. Configure quantization for memory efficiency
3. Set up LoRA adapters for the language model
4. Train using TRL's SFTTrainer
5. Save the trained LoRA adapter for inference

Let's get started!

## Step 1: Loading and Preprocessing the Dataset

We'll use the **FreedomIntelligence/medical-o1-reasoning-SFT** dataset, which contains medical questions with complex chain-of-thought reasoning.

### Dataset Format
Each example contains:
- **Question**: The medical question or scenario
- **Complex_CoT**: Chain-of-thought reasoning (the model's "thinking" process)
- **Response**: The final answer

### Why This Format?
We format the data as conversations with a special `<think>` tag to teach the model to:
1. Show its reasoning process explicitly
2. Provide clear, structured answers
3. Think step-by-step through complex medical scenarios
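
As a small illustration of why the tag is convenient, the sketch below builds an assistant turn in this format and then splits it back into reasoning and answer. The helper names and the placeholder strings are assumptions made for this example; they are not part of the dataset or of TRL.

```python
import re

# Reasoning sits between <think> and </think>; everything after is the answer.
THINK_PATTERN = re.compile(r"<think>(.*?)</think>(.*)", re.DOTALL)


def format_assistant_turn(reasoning: str, answer: str) -> str:
    """Join reasoning and answer the same way the preprocessing step does."""
    return f"<think>{reasoning}</think>{answer}"


def split_assistant_turn(text: str) -> tuple[str, str]:
    """Return (reasoning, answer); reasoning is empty if no tag is present."""
    match = THINK_PATTERN.fullmatch(text)
    if match is None:
        return "", text.strip()
    return match.group(1).strip(), match.group(2).strip()


# Placeholder content, not real dataset text.
turn = format_assistant_turn("Step 1: ...\nStep 2: ...", "Final answer goes here.")
reasoning, answer = split_assistant_turn(turn)
print("reasoning:", repr(reasoning))
print("answer:", repr(answer))
```

The same split is handy at inference time when you want to show only the final answer, or to log the reasoning separately.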

In [None]:
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")


def preprocess_function(example):
    # Format as a conversation for SFTTrainer
    messages = [
        {"role": "user", "content": example["Question"]},
        {
            "role": "assistant",
            "content": f"<think>{example['Complex_CoT']}</think>{example['Response']}",
        },
    ]
    return {"messages": messages}


dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])

# Split the training dataset to create train/validation/test sets
# (80% train, 10% validation, 10% test)
first_split = dataset["train"].train_test_split(test_size=0.2, seed=816)  # 80% train, 20% temp
temp_dataset = first_split["test"]
second_split = temp_dataset.train_test_split(test_size=0.5, seed=816)  # Split the 20% into 10% each

train_dataset = first_split["train"]  # 80%
eval_dataset = second_split["train"]  # 10%
test_dataset = second_split["test"]  # 10%

print("Sample:", next(iter(train_dataset)))

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")

## Step 2: Data Collation

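The **DataCollatorWithFlattening** is a padding-free data collator from Transformers that:
- Handles variable-length sequences without padding them to a common length
- Concatenates the examples in a batch into one long sequence and emits `position_ids` that restart at each example boundary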
- Optimizes memory usage during batch processing
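
The flattening idea can be sketched in plain Python. This is an illustration of the concept only, not the library's code: `flatten_batch` is a hypothetical helper, and the token ids are made up.

```python
def flatten_batch(sequences, separator_label=-100):
    """Concatenate tokenized examples into one sequence without padding.

    position_ids restart at 0 for every example so an attention kernel that
    supports packed inputs can tell the examples apart. The first label of
    each example is masked with -100 so no loss is computed across the
    boundary between two examples.
    """
    input_ids, position_ids, labels = [], [], []
    for seq in sequences:
        input_ids.extend(seq)
        position_ids.extend(range(len(seq)))
        labels.extend([separator_label] + list(seq[1:]))
    return {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}


batch = flatten_batch([[11, 12, 13], [21, 22]])  # made-up token ids
for key, value in batch.items():
    print(f"{key:>12}: {value}")
```

Compared with padding both examples to length 3, the flattened batch carries five tokens instead of six; the saving grows with the spread of sequence lengths.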

In [None]:
from transformers import DataCollatorWithFlattening

data_collator = DataCollatorWithFlattening()

## Step 3: Loading Configuration

We use a YAML configuration file to manage training parameters. This approach:
- Keeps settings organized and version-controlled
- Makes it easy to experiment with different hyperparameters
- Allows reproducible training runs

### Key Parameters Explained:
- **base_model_name**: The foundation model to fine-tune
- **lora_rank**: Controls the size of LoRA adapters (higher = more parameters)
- **lora_alpha**: Scaling factor for LoRA updates (affects learning strength)
- **batch_size**: Number of samples processed together
- **epochs_to_train**: Number of complete passes through the dataset

In [None]:
from os.path import join

import yaml

In [None]:
# Load configuration from config.yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

MODEL_NAME = config["base_model_name"]
print(f"Using model: {MODEL_NAME}")

adapter_dir = join(config["adapter_dir_prefix"], MODEL_NAME)
print(f"LoRA adapter directory will be saved to: {adapter_dir}")

lora_rank = config["lora_rank"]
lora_alpha = config["lora_alpha"]
print(f"LoRA rank is {lora_rank} and LoRA alpha is {lora_alpha}")

batch_size = int(config["batch_size"])
epochs_to_train = int(config["epochs_to_train"])
max_output_length = int(config["max_output_length"])

## Step 4: Setting Up Model Components

Now we'll import the essential libraries:

- **PyTorch**: The underlying tensor computation framework
- **PEFT (LoraConfig)**: Handles parameter-efficient fine-tuning
- **Transformers**: Provides the model and tokenizer classes
- **BitsAndBytesConfig**: Enables memory-efficient quantization
- **TRL**: Provides the specialized trainer for supervised fine-tuning

In [None]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

## Step 5: Quantization Configuration

**Quantization** reduces memory usage by representing model weights with fewer bits:

- **load_in_4bit**: Store weights in 4-bit precision instead of 16- or 32-bit (roughly 4x less weight memory than 16-bit, 8x less than 32-bit)
- **bnb_4bit_compute_dtype**: Use bfloat16 for computations (stable training)
- **bnb_4bit_use_double_quant**: Also quantize the quantization constants themselves for additional savings
- **bnb_4bit_quant_type**: "nf4" is an optimized 4-bit format

This allows us to train larger models on consumer GPUs!
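
The arithmetic behind that claim is easy to check. The sketch below estimates weight memory only; it ignores quantization constants, activations, gradients, and optimizer state, and the 7-billion-parameter size is a hypothetical example rather than the model used in this notebook.

```python
GIB = 2**30  # bytes per GiB


def weight_memory_gib(n_params: int, bits_per_param: float) -> float:
    """Approximate memory needed to store the weights alone, in GiB."""
    return n_params * bits_per_param / 8 / GIB


n_params = 7_000_000_000  # hypothetical 7B-parameter model
for name, bits in [("float32", 32), ("bfloat16", 16), ("nf4 (4-bit)", 4)]:
    print(f"{name:>12}: {weight_memory_gib(n_params, bits):6.2f} GiB")
```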

In [None]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16; weights stay stored in 4-bit
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

## Step 6: Loading Model and Tokenizer

Here we load the base model with several optimization flags:

- **quantization_config**: Apply the 4-bit quantization we configured
- **device_map="auto"**: Automatically distribute model across available GPUs
- **attn_implementation="flash_attention_2"**: Use optimized attention for speed
- **local_files_only=True**: Load only from the local Hugging Face cache; loading fails if the model has not been downloaded yet

The tokenizer converts text to tokens the model can understand. We set `pad_token = eos_token` because some models don't have a dedicated padding token.
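
If you would rather keep a model's own padding token when it has one, a slightly more defensive variant only falls back to the EOS token when no padding token is set. This is a sketch, not what the cell below does; `ensure_pad_token` is a hypothetical helper, and the `SimpleNamespace` objects merely stand in for real tokenizers so the example runs without downloading anything.

```python
from types import SimpleNamespace


def ensure_pad_token(tokenizer):
    """Reuse the EOS token for padding only when no pad token is defined."""
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return tokenizer


# Stand-ins for tokenizers; a real one exposes the same two attributes.
no_pad = SimpleNamespace(pad_token=None, eos_token="</s>")
has_pad = SimpleNamespace(pad_token="<pad>", eos_token="</s>")

print(ensure_pad_token(no_pad).pad_token)
print(ensure_pad_token(has_pad).pad_token)
```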

In [None]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
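    dtype=torch.bfloat16,  # dtype for the parts of the model that are not quantized
    use_cache=True,  # KV cache speeds up generation; Transformers disables it while gradient checkpointing is active
    quantization_config=bnb_config,
    local_files_only=True,  # Never download; the model must already be in the local cache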
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    local_files_only=True,
)
tokenizer.pad_token = tokenizer.eos_token

## Step 7: LoRA Configuration

Now we configure the **LoRA adapter** - this is the heart of parameter-efficient fine-tuning!

### LoRA Parameters Explained:
- **r (rank)**: Size of the low-rank matrices
- **lora_alpha**: Scaling factor for LoRA updates (higher = stronger adaptation)
- **lora_dropout**: Prevents overfitting in the LoRA layers
- **target_modules**: Which parts of the model to adapt (attention layers are most effective)
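
The interaction between `r` and `lora_alpha` is worth spelling out: in standard LoRA the adapter's output is multiplied by `lora_alpha / r`, so raising the rank while keeping alpha fixed weakens the scale of each update. The sketch below shows that relationship; `lora_scale` is a hypothetical helper, and it assumes the default scaling rather than rank-stabilized variants.

```python
def lora_scale(lora_alpha: float, r: int) -> float:
    """Default LoRA scaling factor applied to the low-rank update."""
    return lora_alpha / r


lora_alpha = 32  # illustrative value, not a recommendation
for r in (8, 16, 32, 64):
    print(f"r={r:<3} lora_alpha={lora_alpha} -> scale={lora_scale(lora_alpha, r)}")
```

A common starting point is to keep `lora_alpha` at one or two times `r` and adjust the learning rate from there.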

### Why These Modules?
We target the attention projection layers because they:
- Control how the model focuses on different parts of the input
- Are most impactful for learning new reasoning patterns
- Provide good performance-to-parameter ratio
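
To see which layers a list of `target_modules` would select, the sketch below applies suffix matching to a handful of module names. Two assumptions are baked in: the matching rule mirrors my understanding of how PEFT treats a list of strings (exact name or dotted suffix), and the module names follow a Llama-style layout that may differ for your base model.

```python
TARGET_MODULES = [
    "self_attn.q_proj",
    "self_attn.v_proj",
    "self_attn.k_proj",
    "self_attn.o_proj",
]


def is_targeted(module_name: str, targets=TARGET_MODULES) -> bool:
    """True if the module name equals a target or ends with '.<target>'."""
    return any(module_name == t or module_name.endswith("." + t) for t in targets)


# Hypothetical module names for one transformer block.
module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.gate_proj",
    "model.layers.0.mlp.up_proj",
    "model.layers.0.mlp.down_proj",
]

for name in module_names:
    marker = "LoRA" if is_targeted(name) else "frozen only"
    print(f"{name:<40} {marker}")
```

On a real model you can run the same check over `model.named_modules()` to confirm the adapter lands where you expect.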

In [None]:
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"],
)

## Step 8: Training Configuration

The **SFTConfig** defines how we want to train our model:

### Memory & Performance:
- **gradient_accumulation_steps=4**: Accumulate gradients over 4 batches before each optimizer step (a larger effective batch size without the extra memory)
- **bf16=True**: Use bfloat16 precision (faster, stable training)
- **gradient_checkpointing=True**: Trade compute for memory (essential for large models)

### Training Strategy:
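- **completion_only_loss=True**: Intended to compute the loss only on assistant responses, not user questions. TRL documents this flag for prompt-completion datasets; for `messages`-style data, `assistant_only_loss` is the corresponding option, so check which one takes effect in your TRL version
- **loss_type="dft"**: Dynamic Fine-Tuning loss, a variant of cross-entropy that weights each token's loss by the model's own probability for that token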
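
As a rough intuition for the DFT loss, the sketch below compares the per-token value of ordinary cross-entropy with a probability-weighted version. It assumes the formulation in which each token's negative log-likelihood is multiplied by that token's predicted probability, with the weight treated as a constant during backpropagation; the function names are hypothetical and this is not TRL's code.

```python
import math


def cross_entropy_token_loss(p_target: float) -> float:
    """Standard negative log-likelihood for the target token."""
    return -math.log(p_target)


def dft_token_loss(p_target: float) -> float:
    """Negative log-likelihood scaled by the target token's probability.

    In an autograd framework the weight p_target would be detached,
    so gradients flow only through the log term.
    """
    return -p_target * math.log(p_target)


for p in (0.01, 0.1, 0.5, 0.9):
    ce, dft = cross_entropy_token_loss(p), dft_token_loss(p)
    print(f"p={p:<4} cross-entropy={ce:.3f}  dft={dft:.3f}")
```

The weighting shrinks the contribution of tokens the model currently finds very unlikely, which standard cross-entropy would otherwise penalise most heavily.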

### Monitoring:
- **eval_strategy/save_strategy**: Evaluate and save a checkpoint every 100 steps (`eval_steps=100`, `save_steps=100`)
- **logging_steps=50**: Log training metrics every 50 steps
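
These settings interact, so it helps to work out how many optimizer steps and evaluations a run will produce. The sketch below is an approximation under assumptions: the sample count and batch size are hypothetical, a single GPU is assumed, and the exact rounding the trainer uses may differ by a step.

```python
import math


def effective_batch_size(per_device_batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of examples that contribute to one optimizer step."""
    return per_device_batch_size * gradient_accumulation_steps * num_devices


def optimizer_steps_per_epoch(num_samples, per_device_batch_size,
                              gradient_accumulation_steps, num_devices=1):
    """Approximate number of optimizer steps in one pass over the data."""
    per_step = effective_batch_size(per_device_batch_size,
                                    gradient_accumulation_steps, num_devices)
    return math.ceil(num_samples / per_step)


num_samples = 16_000  # hypothetical training-set size
steps = optimizer_steps_per_epoch(num_samples, per_device_batch_size=2,
                                  gradient_accumulation_steps=4)
print(f"effective batch size: {effective_batch_size(2, 4)}")
print(f"optimizer steps per epoch: {steps}")
print(f"evaluations per epoch at eval_steps=100: {steps // 100}")
```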

In [None]:
# Configure the SFT training parameters
sft_config = SFTConfig(
    output_dir="./results",
    num_train_epochs=epochs_to_train,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_length=max_output_length,
    logging_steps=50,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    bf16=True,
    gradient_checkpointing=True,
    loss_type="dft",  # Dynamic fine tuning
    completion_only_loss=True,  # Train only on assistant responses
)

## Step 9: Creating the SFT Trainer

The **SFTTrainer** (Supervised Fine-Tuning Trainer) is TRL's specialized trainer that:
- Handles conversation formatting automatically
- Integrates with PEFT for LoRA training  
- Provides optimized training loops for language model fine-tuning
- Supports advanced features like completion-only training

This single trainer handles all the complexity of modern LLM fine-tuning!

In [None]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=sft_config,
    peft_config=peft_config,
    data_collator=data_collator,
)

## Step 10: Memory Usage Check

Before training, let's check our GPU memory usage to ensure we have enough resources:

- **Memory allocated**: Currently used GPU memory
- **Memory reserved**: Memory reserved by PyTorch for operations
- **Total memory**: Total memory capacity of GPU 0, not the amount currently free

This helps us verify that our quantization and optimization settings are working correctly!

In [None]:
# Check GPU memory usage before training
GB = 2**30
if torch.cuda.is_available():
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / GB:.2f} GB")
    print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / GB:.2f} GB")
    print(f"GPU Memory total: {torch.cuda.get_device_properties(0).total_memory / GB:.2f} GB")
else:
    print("CUDA is not available")

## Step 11: Start Training! 🚀

This is where the magic happens! The trainer will:

1. **Forward pass**: Run examples through the model
2. **Compute loss**: Measure how far the predictions are from the target tokens
3. **Backward pass**: Calculate gradients
4. **Update LoRA weights**: Apply parameter updates (typically well under 1% of total parameters, depending on rank and target modules)
5. **Evaluate**: Test on validation data periodically
6. **Save checkpoints**: Store progress for recovery

**Note**: During training you should see the training loss trend downward, with the evaluation loss following it. The beauty of LoRA is that we update only a tiny fraction of the model's parameters while often getting close to full fine-tuning performance!
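
The loop above can be illustrated with a toy problem small enough to follow by hand. The sketch below is a teaching aid under assumptions, not what `SFTTrainer` runs: a single frozen weight plays the role of the base model, a single trainable `delta` plays the role of the adapter, and the data is made up.

```python
def train_adapter(xs, ys, w_frozen, steps=200, lr=0.05):
    """Fit y = (w_frozen + delta) * x by gradient descent on delta only."""
    delta = 0.0  # the only trainable parameter; w_frozen is never modified
    losses = []
    n = len(xs)
    for _ in range(steps):
        preds = [(w_frozen + delta) * x for x in xs]                        # forward pass
        loss = sum((p - y) ** 2 for p, y in zip(preds, ys)) / n             # mean squared error
        grad = sum(2 * (p - y) * x for p, y, x in zip(preds, ys, xs)) / n   # d(loss)/d(delta)
        delta -= lr * grad                                                  # update the adapter
        losses.append(loss)
    return delta, losses


xs = [1.0, 2.0, 3.0]
ys = [3.0 * x for x in xs]  # made-up targets with slope 3.0
w_frozen = 2.0              # the "base model" slope, kept fixed

delta, losses = train_adapter(xs, ys, w_frozen)
print(f"first loss: {losses[0]:.4f}, last loss: {losses[-1]:.2e}")
print(f"learned delta: {delta:.4f} (frozen weight still {w_frozen})")
```

The frozen weight never changes; all of the adaptation lives in `delta`, which is the same division of labour LoRA applies to each targeted layer.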

In [None]:
trainer.train()

## Step 12: Save the LoRA Adapter

After training completes, we save our LoRA adapter:

### What Gets Saved:
- **LoRA weight matrices**: The small adapters we trained (typically megabytes rather than gigabytes, depending on rank and model size)
- **Adapter configuration**: Settings like rank, alpha, target modules
- **Tokenizer**: Ensures consistency during inference

### Why This is Amazing:
- The base model stays unchanged (no need to duplicate GBs of weights)
- Multiple LoRA adapters can be created for different tasks
- Adapters can be easily shared, version-controlled, and swapped
- You can even combine multiple LoRA adapters!
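
Because only the adapter is saved, using it later means loading the base model first and attaching the adapter on top. The sketch below shows one way to do that with `PeftModel.from_pretrained`; the function name is hypothetical, the imports are deferred so the definition itself runs without a GPU, and the companion inference notebook may load the model differently.

```python
def load_adapter_for_inference(base_model_name: str, adapter_dir: str):
    """Load the frozen base model, then attach the saved LoRA adapter to it."""
    # Imported lazily so this sketch can be defined without the libraries loaded.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto")
    model = PeftModel.from_pretrained(base_model, adapter_dir)  # adds the adapter weights
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)      # saved next to the adapter
    model.eval()
    return model, tokenizer


# Example call, using the names defined earlier in this notebook:
# model, tokenizer = load_adapter_for_inference(MODEL_NAME, adapter_dir)
```

Swapping tasks then comes down to pointing `adapter_dir` at a different adapter while reusing the same base model.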

In [None]:
# Save the LoRA adapter
print(f"Saving LoRA adapter to {adapter_dir}")

trainer.model.save_pretrained(adapter_dir)
print(f"LoRA adapter saved successfully to {adapter_dir}!")
tokenizer.save_pretrained(adapter_dir)

In [None]:
print(
    "Now run the notebook `trl_medical_reasoning_inference.ipynb` to use the LoRA fine-tuned model."
)

## 🎉 Training Complete!

Congratulations! You've successfully:

1. ✅ Learned about LoRA and TRL
2. ✅ Loaded and preprocessed medical reasoning data
3. ✅ Set up memory-efficient quantization
4. ✅ Configured LoRA adapters for parameter-efficient training
5. ✅ Fine-tuned a language model using TRL's SFTTrainer
6. ✅ Saved your trained LoRA adapter

### Next Steps:
- Use the inference notebook to test your fine-tuned model
- Experiment with different LoRA ranks and alphas
- Try training on different datasets
- Combine multiple LoRA adapters for multi-task models

### Key Takeaways:
- **LoRA** enables efficient fine-tuning with minimal memory
- **TRL** provides ready-made trainers for supervised fine-tuning and for preference-based methods such as DPO
- **Quantization** makes large models accessible on consumer hardware
- **Parameter-efficient fine-tuning** is the future of model customization!