# Medical Reasoning Fine-tuning with TRL and LoRA

> Important
> ---------
> This notebook is for educational purposes only.

## Introduction

This notebook demonstrates how to fine-tune a language model for medical reasoning tasks using [**TRL (Transformer Reinforcement Learning)**](https://huggingface.co/docs/trl/index) and [**LoRA (Low-Rank Adaptation)**](https://tonyreina.github.io/lora/getting-started/what-is-lora/). If you're new to these concepts, here's what you need to know:

### What is LoRA?
[**LoRA (Low-Rank Adaptation)**](https://tonyreina.github.io/lora/getting-started/what-is-lora/) is an efficient fine-tuning technique that:
- Freezes the original model weights and adds small trainable matrices
- Substantially reduces memory usage and training time
- Often achieves performance comparable to full fine-tuning
- Creates lightweight adapters that can be easily shared and swapped
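
To make the "small trainable matrices" point concrete, here is a minimal sketch of the parameter arithmetic for a single linear layer. The 3072 x 3072 layer size is a hypothetical value chosen for illustration, not taken from any particular model.

```python
def lora_param_counts(d_in: int, d_out: int, rank: int) -> tuple[int, int]:
    """Return (full_params, lora_params) for one linear layer of shape d_out x d_in."""
    full_params = d_in * d_out           # frozen weight matrix W
    lora_params = rank * (d_in + d_out)  # trainable A (rank x d_in) and B (d_out x rank)
    return full_params, lora_params


# Hypothetical 3072 x 3072 projection adapted with rank 16
full, lora = lora_param_counts(d_in=3072, d_out=3072, rank=16)
print(f"full: {full:,}  lora: {lora:,}  ratio: {lora / full:.2%}")
```

For this one layer the adapter holds about 1% of the layer's parameters. Across a whole model the fraction is usually smaller, because only some layers receive adapters.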

### What is TRL?
[**TRL (Transformer Reinforcement Learning)**](https://huggingface.co/docs/trl/index) is a library that:
- Provides easy-to-use trainers for supervised fine-tuning (SFT)
- Supports advanced training techniques like RLHF and DPO
- Integrates seamlessly with Hugging Face transformers and PEFT

### What We'll Do
In this notebook, we'll:
1. Load and preprocess a medical reasoning dataset
2. Configure quantization for memory efficiency
3. Set up LoRA adapters for the language model
4. Train using TRL's SFTTrainer
5. Save the trained LoRA adapter for inference

Let's get started!

## Step 1: Loading and Preprocessing the Dataset

We'll use the [**FreedomIntelligence/medical-o1-reasoning-SFT** dataset](https://huggingface.co/datasets/FreedomIntelligence/medical-o1-reasoning-SFT), which contains medical questions with complex chain-of-thought reasoning.

### Dataset Format
Each example contains:
- **Question**: The medical question or scenario
- **Complex_CoT**: Chain-of-thought reasoning (the model's "thinking" process)
- **Response**: The final answer

### Chain-of-Thought (CoT) Format

1. **System prompt format**: Custom instructions followed by `/think` to enable extended thinking mode
2. **User message**: The medical question
3. **Assistant response**: `/think [reasoning process] [final answer]`

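> NOTE
> This example was written with the HuggingFace SmolLM2-135M-Instruct
> base model in mind. That model **does not** have a chat template
> that works with the CoT (`/think`) tags, but I include them here
> anyway to show how that would work with such a dataset and model.
> The model actually fine-tuned is whichever `base_model_name` is set in
> `config.yaml` (`microsoft/Phi-4-mini-instruct` in the run recorded below).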

In [19]:
from datasets import load_dataset

dataset = load_dataset("FreedomIntelligence/medical-o1-reasoning-SFT", "en")


def preprocess_function(example):
    # Format as a conversation for SFTTrainer with a system prompt for CoT
    # Note: models with an extended thinking mode (e.g., SmolLM3-3B) expect "/think"
    # at the end of the system prompt to enable it
    messages = [
        {
            "role": "system",
            "content": "You are a medical AI assistant. "
            "When answering medical questions, "
            "use /think to show your reasoning process "
            "before providing your final answer. "
            "Structure your response as: /think "
            "[your detailed reasoning] [final answer]./think",
        },
        {"role": "user", "content": example["Question"]},
        {
            "role": "assistant",
            "content": f"/think {example['Complex_CoT']} {example['Response']}",
        },
    ]
    return {"messages": messages}


dataset = dataset.map(preprocess_function, remove_columns=["Question", "Response", "Complex_CoT"])

# Split the training dataset to create train/validation/test sets
# (80% train, 10% validation, 10% test)
first_split = dataset["train"].train_test_split(test_size=0.2, seed=816)  # 80% train, 20% temp
temp_dataset = first_split["test"]
second_split = temp_dataset.train_test_split(test_size=0.5, seed=816)  # Split the 20% into 10% each

train_dataset = first_split["train"]  # 80%
eval_dataset = second_split["train"]  # 10%
test_dataset = second_split["test"]  # 10%

print("Sample:", next(iter(train_dataset)))

print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Validation samples: {len(eval_dataset)}")

Sample: {'messages': [{'content': 'You are a medical AI assistant. When answering medical questions, use /think to show your reasoning process before providing your final answer. Structure your response as: /think [your detailed reasoning] [final answer]./think', 'role': 'system'}, {'content': 'A patient presents with microcytic hypochromic anemia, hemoglobin level of 9%, serum iron of 20 µg/dL, ferritin level of 800 ng/mL, and transferrin percentage saturation of 64%. Based on these laboratory findings, what is the possible diagnosis?', 'role': 'user'}, {'content': "/think Okay, so we have a case of microcytic hypochromic anemia. That generally means the red blood cells are small and pale, which can occur in a few different conditions.\n\nLet's start by looking at the serum iron level. It’s reported at 20 µg/dL, which is definitely on the low side. Low serum iron is commonly seen in iron deficiency anemia, but it can also happen due to chronic diseases or other less common conditi
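
Before training, it can help to confirm that every example follows the conversation layout described above. The helper below is a small sketch of such a check. `check_messages` and the sample row are hypothetical names and placeholder data, not part of the notebook.

```python
def check_messages(messages: list[dict]) -> bool:
    """Return True if a conversation follows the system/user/assistant CoT layout."""
    roles = [message["role"] for message in messages]
    if roles != ["system", "user", "assistant"]:
        return False
    # The assistant turn should begin with the "/think" tag used in this notebook
    return messages[-1]["content"].startswith("/think ")


# Placeholder conversation with the same shape as one preprocessed row
sample = [
    {"role": "system", "content": "You are a medical AI assistant. /think"},
    {"role": "user", "content": "What causes iron deficiency anemia?"},
    {"role": "assistant", "content": "/think [reasoning] [final answer]"},
]
print(check_messages(sample))
```

With the real data, the same check could be applied to each row's `messages` field, for example `all(check_messages(row["messages"]) for row in train_dataset)`.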

## Step 2: Data Collation

The [**DataCollatorWithFlattening**](https://huggingface.co/docs/transformers/main/en/main_classes/data_collator) is a padding-free data collator from Transformers that:
- Handles variable-length sequences without padding
- Concatenates the examples in a batch into one long sequence and returns `position_ids` that mark where each example starts
- Reduces memory wasted on padding tokens during batch processing

In [20]:
from transformers import DataCollatorWithFlattening

data_collator = DataCollatorWithFlattening()
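
To illustrate what "flattening" means, here is a simplified pure-Python sketch of the idea. It is not the library's implementation. `flatten_batch` is a hypothetical helper, and the token IDs are made up.

```python
def flatten_batch(sequences: list[list[int]]) -> dict[str, list[int]]:
    """Concatenate tokenized examples into one sequence without padding."""
    input_ids: list[int] = []
    position_ids: list[int] = []
    labels: list[int] = []
    for seq in sequences:
        if not seq:
            continue  # nothing to add for an empty example
        input_ids.extend(seq)
        # Positions restart at 0 so each example keeps its own positional encoding
        position_ids.extend(range(len(seq)))
        # Mask the first token of each example (-100 is ignored by the loss)
        # so that no prediction crosses an example boundary
        labels.extend([-100] + seq[1:])
    return {"input_ids": input_ids, "position_ids": position_ids, "labels": labels}


batch = flatten_batch([[5, 6, 7], [8, 9]])
print(batch)
```

Because nothing is padded, a batch of short and long examples costs only as many tokens as it actually contains.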

## Step 3: Loading Configuration

We use a [YAML configuration file](config.yaml) to manage training parameters. This approach:
- Keeps settings organized and version-controlled
- Makes it easy to experiment with different hyperparameters
- Allows reproducible training runs

### Key Parameters Explained:
- **base_model_name**: The foundation model to fine-tune
- **lora_rank**: Controls the size of LoRA adapters (higher = more parameters)
- **lora_alpha**: Scaling factor for LoRA updates (affects learning strength)
- **batch_size**: Number of samples processed together
- **epochs_to_train**: Number of complete passes through the dataset

In [21]:
from os.path import join

import yaml

In [22]:
# Load configuration from config.yaml
with open("config.yaml") as f:
    config = yaml.safe_load(f)

MODEL_NAME = config["base_model_name"]
print(f"Using model: {MODEL_NAME}")

adapter_dir = join(config["adapter_dir_prefix"], MODEL_NAME)
print(f"LoRA adapter directory will be saved to: {adapter_dir}")

lora_rank = config["lora_rank"]
lora_alpha = config["lora_alpha"]
print(f"LoRA rank is {lora_rank} and LoRA alpha is {lora_alpha}")

batch_size = int(config["batch_size"])
epochs_to_train = int(config["epochs_to_train"])
max_output_length = int(config["max_output_length"])

Using model: microsoft/Phi-4-mini-instruct
LoRA adapter directory will be saved to: lora_adapter/microsoft/Phi-4-mini-instruct
LoRA rank is 16 and LoRA alpha is 32
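
Since the notebook reads several keys from `config.yaml`, a quick check for missing keys can catch typos before a long run starts. The sketch below lists the keys this notebook reads. `missing_config_keys` is a hypothetical helper, and the values marked as placeholders are assumptions rather than the settings used here.

```python
REQUIRED_KEYS = (
    "base_model_name",
    "adapter_dir_prefix",
    "lora_rank",
    "lora_alpha",
    "batch_size",
    "epochs_to_train",
    "max_output_length",
    "training_results_dir",
)


def missing_config_keys(config: dict) -> list[str]:
    """Return the required keys that are absent from the loaded configuration."""
    return [key for key in REQUIRED_KEYS if key not in config]


# Example configuration as it might look after yaml.safe_load()
example_config = {
    "base_model_name": "microsoft/Phi-4-mini-instruct",
    "adapter_dir_prefix": "lora_adapter",
    "lora_rank": 16,
    "lora_alpha": 32,
    "batch_size": 4,                             # placeholder value
    "epochs_to_train": 4,
    "max_output_length": 2048,                   # placeholder value
    "training_results_dir": "training_results",  # placeholder value
}
print(missing_config_keys(example_config))
```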


## Step 4: Setting Up Model Components

Now we'll import the essential libraries:

- [**PyTorch**](https://pytorch.org): The underlying tensor computation framework
- [**PEFT (LoraConfig)**](https://huggingface.co/docs/peft/en/index): Handles parameter-efficient fine-tuning
- [**Transformers**](https://huggingface.co/docs/transformers/en/index): Provides the model and tokenizer classes
- [**BitsAndBytesConfig**](https://huggingface.co/docs/bitsandbytes/en/index): Enables memory-efficient quantization
- [**TRL**](https://huggingface.co/docs/trl/index): Provides the specialized trainer for supervised fine-tuning

In [23]:
import torch
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import SFTConfig, SFTTrainer

## Step 5: Quantization Configuration

**Quantization** with BitsAndBytes reduces memory usage by representing model weights with fewer bits:

- **load_in_4bit**: Store weights in 4-bit precision instead of 16- or 32-bit (roughly a 4x reduction in weight memory relative to 16-bit)
- **bnb_4bit_compute_dtype**: Dequantize to bfloat16 for computations (stable training)
- **bnb_4bit_use_double_quant**: Also quantize the quantization constants for even more savings
- **bnb_4bit_quant_type**: "nf4" (4-bit NormalFloat) is a 4-bit format designed for normally distributed weights

The base model is quantized using BitsAndBytes, but the trainable
LoRA weights and their gradients are kept at `bfloat16` precision.

This allows us to train larger models on consumer GPUs!

In [24]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # Compute in bfloat16 rather than float16
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)
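
A rough back-of-the-envelope estimate shows where the savings come from. This sketch assumes the block sizes described in the QLoRA paper (64 weights per block, with second-level blocks of 256 constants) and a hypothetical 3.8 billion parameter model. It counts weight storage only, and real usage is higher because some layers (such as embeddings) are typically not quantized.

```python
def weight_memory_gib(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GiB for a given number of bits per parameter."""
    return n_params * bits_per_param / 8 / 2**30


N_PARAMS = 3.8e9  # hypothetical parameter count
NF4_BITS = 4.0

# One fp32 scaling constant per block of 64 weights
ABSMAX_OVERHEAD_BITS = 32 / 64
# Double quantization: 8-bit constants, plus one fp32 constant per 256 of them
DOUBLE_QUANT_OVERHEAD_BITS = 8 / 64 + 32 / (64 * 256)

bf16_gib = weight_memory_gib(N_PARAMS, 16)
nf4_gib = weight_memory_gib(N_PARAMS, NF4_BITS + ABSMAX_OVERHEAD_BITS)
nf4_dq_gib = weight_memory_gib(N_PARAMS, NF4_BITS + DOUBLE_QUANT_OVERHEAD_BITS)

print(f"bf16: {bf16_gib:.2f} GiB")
print(f"nf4: {nf4_gib:.2f} GiB")
print(f"nf4 + double quant: {nf4_dq_gib:.2f} GiB")
```

Under these assumptions double quantization trims the per-parameter overhead from 0.5 bits to roughly 0.127 bits.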

## Step 6: Loading Model and Tokenizer

Here we load the base model with several optimization flags:

- **quantization_config**: Apply the 4-bit quantization we configured
- **device_map="auto"**: Automatically distribute model across available GPUs
- **attn_implementation="flash_attention_2"**: Use optimized attention for speed
- **local_files_only**: Load only from the local cache and never download (loading fails if the model is not already cached)

The tokenizer converts text to tokens the model can understand. We set `pad_token = eos_token` because some models don't have a dedicated padding token.

In [25]:
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    dtype=torch.bfloat16,  # Use bfloat16 for the non-quantized weights
    use_cache=False,  # The KV cache only speeds up generation and is incompatible with gradient checkpointing
    quantization_config=bnb_config,
    local_files_only=True,  # Load from the local cache only; never download
    device_map="auto",
    attn_implementation="flash_attention_2",
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    local_files_only=True,
)
tokenizer.pad_token = tokenizer.eos_token


## Step 7: LoRA Configuration

Now we configure the **LoRA adapter** - this is the heart of parameter-efficient fine-tuning!

I have a more complete description of the LoRA approach [here](https://tonyreina.github.io/lora/getting-started/what-is-lora/).

### LoRA Parameters Explained:
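- **r (rank)**: Inner dimension of the low-rank matrices (higher = more trainable parameters)
- **lora_alpha**: Scaling factor for LoRA updates; the update is scaled by `lora_alpha / r`, so a higher value means stronger adaptation
- **lora_dropout**: Dropout applied to the LoRA layers to reduce overfitting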
- **target_modules**: Which parts of the model to adapt (attention layers are most effective)
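
The role of `r` and `lora_alpha` is easiest to see in the forward pass. Below is a minimal pure-Python sketch of a LoRA-adapted linear layer, `y = W x + (alpha / r) * B (A x)`. The tiny matrices are made up for illustration, and dropout is omitted.

```python
def matvec(matrix: list[list[float]], vector: list[float]) -> list[float]:
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(a * b for a, b in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, alpha: float, rank: int) -> list[float]:
    """Frozen base output plus the scaled low-rank update."""
    base = matvec(W, x)               # frozen pretrained weights
    update = matvec(B, matvec(A, x))  # low-rank path: project down, then up
    scale = alpha / rank
    return [b + scale * u for b, u in zip(base, update)]


W = [[1.0, 0.0], [0.0, 1.0]]  # frozen 2 x 2 weight
A = [[0.5, 0.5]]              # rank 1: shape (1 x 2)
B_init = [[0.0], [0.0]]       # B starts at zero, so the adapter is a no-op at first
B_trained = [[1.0], [-1.0]]   # hypothetical values after training
x = [2.0, 4.0]

print(lora_forward(W, A, B_init, x, alpha=2, rank=1))     # → [2.0, 4.0]
print(lora_forward(W, A, B_trained, x, alpha=2, rank=1))  # → [8.0, -2.0]
```

With rank 16 and alpha 32, as in this notebook's configuration, the scale factor `alpha / r` is 2.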

### Why These Modules?
We target the attention projection layers because they:
- Control how the model focuses on different parts of the input
- Are most impactful for learning new reasoning patterns
- Provide good performance-to-parameter ratio

In [26]:
peft_config = LoraConfig(
    r=lora_rank,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"],
)
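
Module names differ between architectures, and some models fuse the query, key, and value projections into a single layer, so it is worth checking which modules the `target_modules` list will actually match. The sketch below mimics suffix matching on module names as I understand PEFT to do it for a list of targets. `matched_modules` and both name lists are hypothetical.

```python
def matched_modules(module_names: list[str], target_modules: list[str]) -> list[str]:
    """Return module names that equal a target or end with '.<target>'."""
    return [
        name
        for name in module_names
        if any(name == target or name.endswith("." + target) for target in target_modules)
    ]


targets = ["self_attn.q_proj", "self_attn.v_proj", "self_attn.k_proj", "self_attn.o_proj"]

# Hypothetical layouts: separate projections vs. a fused qkv projection
separate_layout = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.down_proj",
]
fused_layout = [
    "model.layers.0.self_attn.qkv_proj",
    "model.layers.0.self_attn.o_proj",
    "model.layers.0.mlp.down_proj",
]

print(len(matched_modules(separate_layout, targets)))  # → 4
print(matched_modules(fused_layout, targets))          # → ['model.layers.0.self_attn.o_proj']
```

With a loaded model, the real names can be listed with `[name for name, _ in model.named_modules()]` and passed to the same check.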

## Step 8: Training Configuration

The [**SFTConfig**](https://huggingface.co/docs/trl/en/sft_trainer) defines how we want to train our model:

### Memory & Performance:
- **gradient_accumulation_steps=4**: Accumulate gradients over 4 batches before each optimizer update (a larger effective batch size without the extra memory)
- **bf16=True**: Use bfloat16 precision (faster, stable training)
- **gradient_checkpointing=True**: Trade compute for memory (essential for large models)
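
Gradient accumulation changes how many optimizer updates a run performs. The arithmetic can be sketched as follows. The sample count and batch size are hypothetical, and the exact rounding of a partial final accumulation window may differ between library versions.

```python
import math


def effective_batch_size(per_device_batch: int, grad_accum: int, num_devices: int = 1) -> int:
    """Number of samples that contribute to each optimizer update."""
    return per_device_batch * grad_accum * num_devices


def approx_optimizer_steps(
    num_samples: int, per_device_batch: int, grad_accum: int, epochs: int, num_devices: int = 1
) -> int:
    """Approximate number of optimizer updates for a full training run."""
    batches_per_epoch = math.ceil(num_samples / (per_device_batch * num_devices))
    updates_per_epoch = math.ceil(batches_per_epoch / grad_accum)
    return updates_per_epoch * epochs


# Hypothetical run: 10,000 samples, batch size 4, 4 accumulation steps, 2 epochs
print(effective_batch_size(4, 4))                # → 16
print(approx_optimizer_steps(10_000, 4, 4, 2))   # → 1250
```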

### Training Strategy:
- **completion_only_loss=True**: Compute the loss only on assistant responses, not on the system prompt or user questions
- **loss_type="dft"**: Dynamic fine-tuning loss, a probability-weighted variant of the standard loss (see the note below)

> NOTE
>
> [Dynamic Fine-Tuning (DFT)](https://huggingface.co/papers/2508.05629) modifies
> the standard log-loss (cross-entropy) used to train and fine-tune LLMs:
> each token's log-loss is scaled by the model's predicted probability for that token.
> The [original paper](https://arxiv.org/pdf/2508.05629) reports results
> competitive with RL approaches such as PPO, GRPO, and DPO.
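
As a rough illustration of the scaling described in the note, the sketch below compares the standard per-token log-loss with a probability-weighted version. This is my reading of the idea and not TRL's implementation. In an autograd framework the probability weight would be detached so that it acts as a constant, which plain Python can only mention in a comment. The token probabilities are made up.

```python
import math


def nll_loss(token_probs: list[float]) -> float:
    """Standard log-loss: mean of -log p over the target tokens."""
    return sum(-math.log(p) for p in token_probs) / len(token_probs)


def dft_loss(token_probs: list[float]) -> float:
    """Probability-weighted log-loss: mean of -p * log p.

    In a real implementation the weight p is treated as a constant
    (stop-gradient), so it rescales each token's gradient.
    """
    return sum(-p * math.log(p) for p in token_probs) / len(token_probs)


# Hypothetical probabilities the model assigns to three target tokens
probs = [0.9, 0.5, 0.05]
print(f"standard: {nll_loss(probs):.4f}")
print(f"dft: {dft_loss(probs):.4f}")
```

Low-probability tokens dominate the standard loss but are down-weighted here. Since `-p * log p` never exceeds `1/e` (about 0.37), logged values under this loss are on a smaller scale than ordinary cross-entropy.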

### Monitoring:
- **eval_strategy/save_strategy**: Evaluate and save a checkpoint every 100 steps (`eval_steps=100`, `save_steps=100`)
- **logging_steps=50**: Log training metrics every 50 steps

> TIP
>
> The logs are sent to a local [MLflow](https://mlflow.org) database.
> You can monitor the training in real time by starting the
> MLflow server (`pixi run -e cuda mlflow ui`) and opening
> the browser to `http://localhost:5000`.

In [None]:
# Configure the SFT training parameters
sft_config = SFTConfig(
    output_dir=config["training_results_dir"],
    num_train_epochs=epochs_to_train,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_length=max_output_length,
    logging_steps=50,
    save_strategy="steps",
    eval_strategy="steps",
    eval_steps=100,
    save_steps=100,
    bf16=True,
    gradient_checkpointing=True,
    loss_type="dft",  # Dynamic fine tuning
    completion_only_loss=True,  # Train only on assistant responses
    warmup_ratio=0.03,
    lr_scheduler_type="cosine_with_restarts",  # Cosine decay with hard restarts
    lr_scheduler_kwargs={
        "num_cycles": 2.0,  # Number of cosine cycles; the learning rate resets to its peak at the start of each
    },
    weight_decay=0.01,
    remove_unused_columns=False,  # Keep all columns for manual evaluation after training
    report_to="mlflow",  # Use MLflow to track training experiments
)
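
The shape of the `cosine_with_restarts` schedule with warmup can be sketched in a few lines. This is a simplified reimplementation for illustration rather than the library's own function, and the step counts are hypothetical.

```python
import math


def cosine_restarts_multiplier(
    step: int, total_steps: int, warmup_steps: int, num_cycles: float = 2.0
) -> float:
    """Learning-rate multiplier in [0, 1] for cosine decay with hard restarts."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # linear warmup from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    if progress >= 1.0:
        return 0.0
    # Each cycle decays from 1 to 0, then jumps back to 1 at the next cycle
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * ((num_cycles * progress) % 1.0))))


# Hypothetical run: 1000 steps with 3% warmup (30 steps) and 2 cycles
PEAK_LR = 2e-5
for step in (0, 15, 30, 514, 515, 999):
    lr = PEAK_LR * cosine_restarts_multiplier(step, total_steps=1000, warmup_steps=30)
    print(f"step {step:4d}: lr = {lr:.2e}")
```

In this example the learning rate falls almost to zero by step 514 and then jumps back to its peak at step 515, which is the "restart".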

## Step 9: Creating the SFT Trainer

The [**SFTTrainer** (Supervised Fine-Tuning Trainer)](https://huggingface.co/docs/trl/en/sft_trainer) is TRL's specialized trainer that:
- Handles conversation formatting automatically
- Integrates with PEFT for LoRA training  
- Provides optimized training loops for language model fine-tuning
- Supports advanced features like completion-only training

This single trainer handles all the complexity of modern LLM fine-tuning!

In [28]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    args=sft_config,
    peft_config=peft_config,
    data_collator=data_collator,
)

## Step 10: Memory Usage Check

Before training, let's check our GPU memory usage to ensure we have enough resources:

- **Memory allocated**: Currently used GPU memory
- **Memory reserved**: Memory reserved by PyTorch for operations
- **Memory available**: Total memory capacity of the GPU (not the amount currently free)

This helps us verify that our quantization and optimization settings are working correctly!

In [29]:
# Check GPU memory usage before training
GB = 2**30
if torch.cuda.is_available():
    print(f"GPU Memory allocated: {torch.cuda.memory_allocated() / GB:.2f} GB")
    print(f"GPU Memory reserved: {torch.cuda.memory_reserved() / GB:.2f} GB")
    print(f"GPU Memory available: {torch.cuda.get_device_properties(0).total_memory / GB:.2f} GB")
else:
    print("CUDA is not available")

GPU Memory allocated: 5.79 GB
GPU Memory reserved: 9.15 GB
GPU Memory available: 47.35 GB


## Step 11: Start Training! 🚀

This is where the magic happens! The trainer will:

1. **Forward pass**: Run examples through the model
2. **Compute loss**: Measure prediction error  
3. **Backward pass**: Calculate gradients
4. **Update LoRA weights**: Apply parameter updates (typically well under 1% of total parameters!)
5. **Evaluate**: Test on validation data periodically
6. **Save checkpoints**: Store progress for recovery
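
The first four steps can be seen in miniature in the toy example below, where a frozen base weight is combined with a single trainable "adapter" value. This is only an illustration of updating an adapter while the base stays fixed. It is not how the trainer is implemented, and all names and numbers are hypothetical.

```python
def train_adapter(data, base_weight: float, lr: float = 0.05, epochs: int = 50) -> float:
    """Fit y = (base_weight + delta) * x by updating only the adapter value delta."""
    delta = 0.0  # trainable adapter, starts as a no-op
    for _ in range(epochs):
        for x, y in data:
            pred = (base_weight + delta) * x  # 1. forward pass
            error = pred - y                  # 2. loss is error ** 2
            grad_delta = 2 * error * x        # 3. backward pass (gradient w.r.t. delta)
            delta -= lr * grad_delta          # 4. update the adapter; base_weight is untouched
    return delta


# The frozen base computes y = 2x, while the new task needs y = 3x
train_data = [(1.0, 3.0), (2.0, 6.0), (-1.0, -3.0)]
delta = train_adapter(train_data, base_weight=2.0)
print(f"learned adapter value: {delta:.4f}")
```

The adapter converges to about 1.0, the difference between the frozen weight and the weight the new task needs.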

**Note**: As training progresses, the loss should decrease and the evaluation metrics should improve. The beauty of LoRA is that we update only a small fraction of the model's parameters while often approaching the performance of full fine-tuning!

### Track the experiment using MLflow

On the command line, run `pixi run -e cuda mlflow ui` to start the MLflow tracking server. This will allow you to monitor the training in real time at [http://localhost:5000](http://localhost:5000).

In [30]:
trainer.train()


| Step | Training Loss | Validation Loss | Entropy | Num Tokens | Mean Token Accuracy |
|------|---------------|-----------------|---------|------------|---------------------|
| 100 | 0.1414 | 0.114936 | 0.878778 | 1017650.0 | 0.575286 |
| 200 | 0.0488 | 0.044849 | 0.268543 | 2035416.0 | 0.59463 |
| 300 | 0.0419 | 0.039879 | 0.246013 | 3062739.0 | 0.602813 |
| 400 | 0.0404 | 0.040773 | 0.246318 | 4092207.0 | 0.607576 |
| 500 | 0.0394 | 0.039847 | 0.239885 | 5110860.0 | 0.609942 |
| 600 | 0.0391 | 0.037489 | 0.223873 | 6132648.0 | 0.610539 |
| 700 | 0.0383 | 0.038177 | 0.217915 | 7159054.0 | 0.613317 |
| 800 | 0.039 | 0.03866 | 0.214454 | 8182179.0 | 0.614343 |
| 900 | 0.0375 | 0.037647 | 0.211119 | 9212964.0 | 0.614986 |
| 1000 | 0.0378 | 0.038232 | 0.211039 | 10235964.0 | 0.615383 |


TrainOutput(global_step=3944, training_loss=0.040891241561810336, metrics={'train_runtime': 17681.9589, 'train_samples_per_second': 3.566, 'train_steps_per_second': 0.223, 'total_flos': 7.810918543234253e+17, 'train_loss': 0.040891241561810336, 'entropy': 0.20005356283546183, 'num_tokens': 40371876.0, 'mean_token_accuracy': 0.62258922089042, 'epoch': 4.0})

## Step 12: Save the LoRA Adapter

After training completes, we save our LoRA adapter:

### What Gets Saved:
- **LoRA weight matrices**: The small adapters we trained (typically megabytes rather than gigabytes)
- **Adapter configuration**: Settings like rank, alpha, target modules
- **Tokenizer**: Ensures consistency during inference

### Why This is Amazing:
- The base model stays unchanged (no need to duplicate GBs of weights)
- Multiple LoRA adapters can be created for different tasks
- Adapters can be easily shared, version-controlled, and swapped
- You can even combine multiple LoRA adapters!

In [31]:
# Save the LoRA adapter
print(f"Saving LoRA adapter to {adapter_dir}")

trainer.model.save_pretrained(adapter_dir)
print(f"LoRA adapter saved successfully to {adapter_dir}!")
tokenizer.save_pretrained(adapter_dir)

Saving LoRA adapter to lora_adapter/microsoft/Phi-4-mini-instruct
LoRA adapter saved successfully to lora_adapter/microsoft/Phi-4-mini-instruct!


('lora_adapter/microsoft/Phi-4-mini-instruct/tokenizer_config.json',
 'lora_adapter/microsoft/Phi-4-mini-instruct/special_tokens_map.json',
 'lora_adapter/microsoft/Phi-4-mini-instruct/chat_template.jinja',
 'lora_adapter/microsoft/Phi-4-mini-instruct/vocab.json',
 'lora_adapter/microsoft/Phi-4-mini-instruct/merges.txt',
 'lora_adapter/microsoft/Phi-4-mini-instruct/added_tokens.json',
 'lora_adapter/microsoft/Phi-4-mini-instruct/tokenizer.json')
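
For reference, a saved adapter can later be attached to a freshly loaded base model. The sketch below shows one way this might look. `load_model_with_adapter` is a hypothetical helper, and the imports sit inside the function only so that defining it does not require the libraries to be installed. Quantization settings are omitted for brevity.

```python
def load_model_with_adapter(base_model_name: str, adapter_dir: str):
    """Load the base model, attach the saved LoRA adapter, and load the saved tokenizer."""
    # Local imports: these libraries are only needed when the helper is called
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base_model = AutoModelForCausalLM.from_pretrained(base_model_name, device_map="auto")
    # Wrap the frozen base model with the trained adapter weights
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
    return model, tokenizer


# Example call (downloads or loads the base model, so it is left commented out):
# model, tokenizer = load_model_with_adapter(MODEL_NAME, adapter_dir)
```

Swapping tasks then amounts to pointing `adapter_dir` at a different saved adapter while reusing the same base model.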

In [32]:
# Process test dataset to match the expected format for evaluation
# Apply the same chat template processing that SFTTrainer uses
def process_dataset_for_evaluation(dataset, tokenizer):
    def tokenize_function(examples):
        # Apply chat template to format messages
        formatted_texts = []
        for messages in examples["messages"]:
            formatted_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )
            formatted_texts.append(formatted_text)

        # Tokenize the formatted text
        tokenized = tokenizer(
            formatted_texts,
            truncation=True,
            padding=False,
            max_length=trainer.args.max_length,
            return_overflowing_tokens=False,
        )

        return tokenized

    # Process the dataset
    processed_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=dataset.column_names,
        desc="Tokenizing dataset for evaluation",
    )

    return processed_dataset


processed_test_dataset = process_dataset_for_evaluation(test_dataset, tokenizer)

test_results = trainer.evaluate(processed_test_dataset)
print("🎯 Test Results:")
for key, value in test_results.items():
    if isinstance(value, float):
        print(f"   {key}: {value:.4f}")
    else:
        print(f"   {key}: {value}")



Tokenizing dataset for evaluation:   0%|          | 0/1971 [00:00<?, ? examples/s]

🎯 Test Results:
   eval_loss: 0.0366
   eval_runtime: 128.1027
   eval_samples_per_second: 15.3860
   eval_steps_per_second: 1.9280
   eval_entropy: 0.2005
   eval_num_tokens: 40371876.0000
   eval_mean_token_accuracy: 0.6186
   epoch: 4.0000


In [33]:
print(
    "Now run the notebook `trl_medical_reasoning_inference.ipynb` to use the LoRA fine-tuned model."
)

Now run the notebook `trl_medical_reasoning_inference.ipynb` to use the LoRA fine-tuned model.


## 🎉 Training Complete!

Congratulations! You've successfully:

1. ✅ Learned about LoRA and TRL
2. ✅ Loaded and preprocessed medical reasoning data  
3. ✅ Set up memory-efficient quantization
4. ✅ Configured LoRA adapters for parameter-efficient training
5. ✅ Fine-tuned a language model using TRL's SFTTrainer
6. ✅ Saved your trained LoRA adapter

### Next Steps:
- Use the inference notebook to test your fine-tuned model
- Experiment with different LoRA ranks and alphas
- Try training on different datasets
- Combine multiple LoRA adapters for multi-task models

### Key Takeaways:
- **LoRA** enables efficient fine-tuning with minimal memory
- **TRL** provides state-of-the-art training techniques
- **Quantization** makes large models accessible on consumer hardware
- **Parameter-efficient fine-tuning** is the future of model customization!