# Fine-Tuning a Language Model for Healthcare Question Answering

**A practical guide for students using Ollama and Hugging Face.**

This notebook walks step by step through fine-tuning a small language model for a specific domain. We focus on healthcare and build a question-answering model that responds to medical queries.

## 1. Introduction to Fine-Tuning

### What is Fine-Tuning?

Large Language Models (LLMs) are pre-trained on vast amounts of general text data, giving them a broad understanding of language. However, for specialized tasks or domains, their performance can be significantly improved by **fine-tuning**. Fine-tuning is the process of taking a pre-trained model and continuing its training on a smaller, task-specific dataset. This adapts the model's knowledge and capabilities to the new domain, leading to more accurate and relevant outputs.

### Why is Fine-Tuning Important?

- **Domain Adaptation:** Fine-tuning allows a general-purpose model to learn the specific jargon, entities, and relationships of a particular field, such as medicine, law, or finance.
- **Improved Performance:** A fine-tuned model typically outperforms a general model of similar size on tasks within its specialized domain.
- **Cost and Time Efficiency:** Training a model from scratch is computationally expensive and time-consuming. Fine-tuning offers a more efficient way to achieve high performance on specific tasks.

### Parameter-Efficient Fine-Tuning (PEFT)

Fine-tuning all the weights of an LLM can still be resource-intensive. **Parameter-Efficient Fine-Tuning (PEFT)** methods address this by updating only a small subset of the model's parameters (or a small number of added ones). This significantly reduces computational and storage costs, often with performance close to that of full fine-tuning. In this tutorial we use **Low-Rank Adaptation (LoRA)**, a popular PEFT technique.

## 2. Understanding LoRA (Low-Rank Adaptation)

### How LoRA Works

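LoRA freezes the pre-trained weights and, next to selected weight matrices (for example the attention projections), adds a pair of small trainable matrices whose product forms a low-rank update. Only these small matrices are trained. During the forward pass their output is added to that of the frozen weights, and after training the update can be merged into the base weights so inference has no extra cost.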
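In equation form (using the notation of the LoRA paper listed under Resources): for a frozen weight matrix $W_0 \in \mathbb{R}^{d \times k}$, LoRA trains $B \in \mathbb{R}^{d \times r}$ and $A \in \mathbb{R}^{r \times k}$ with rank $r \ll \min(d, k)$, and a layer input $x$ produces

$$
h = W_0 x + \frac{\alpha}{r} B A x ,
$$

where $\alpha$ corresponds to the `lora_alpha` parameter. Only $A$ and $B$ are trained, which is $r(d + k)$ parameters instead of $dk$. $B$ is initialized to zero, so training starts from exactly the base model's behavior.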

**Key LoRA Parameters:**

- **r (rank):** The inner dimension of the low-rank matrices. Lower values mean fewer trainable parameters. Typical values: 8, 16, 32.
- **lora_alpha:** Scaling factor for the LoRA update, which is multiplied by `lora_alpha / r`. A common heuristic is to set it to 2× the rank.
- **lora_dropout:** Dropout probability applied to the inputs of the LoRA layers during training, to reduce overfitting.

**Benefits:**
- Typically only about 0.1-1% of parameters are trained, depending on the rank and which layers are targeted
- Significantly reduced memory requirements
- Faster training times
- Small adapter files (often just a few MB)

## 3. Setup and Installation

First, let's install the necessary Python libraries: `transformers` for loading the model, `datasets` for loading our data, `accelerate` for device placement (used by `device_map="auto"`), `peft` for the LoRA implementation, `trl` for the training loop, and `bitsandbytes` for quantization.

In [None]:
!pip install -q -U transformers datasets accelerate peft trl bitsandbytes
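Because `-U` always pulls the newest releases, and the `trl` API in particular has changed between versions, it helps to record which versions you actually got. A small sketch using only the standard library (`installed_versions` is a hypothetical helper name):

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Return {package: version string, or None if the package is not installed}."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

for pkg, ver in installed_versions(
    ["transformers", "datasets", "accelerate", "peft", "trl", "bitsandbytes"]
).items():
    print(f"{pkg:>14}: {ver or 'NOT INSTALLED'}")
```

Keeping this output alongside the notebook makes it much easier to reproduce a run later or to diagnose an API mismatch.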

## 4. Loading the Dataset

We will use the `MedQuad` dataset from Hugging Face, which contains medical question-answer pairs. This dataset has 16,407 examples covering various medical topics.

**Dataset Structure:**
- `qtype`: Type of question (symptoms, treatment, prevention, etc.)
- `Question`: The medical question
- `Answer`: The detailed answer

In [None]:
from datasets import load_dataset

# Load the MedQuad dataset
dataset_name = "keivalya/MedQuad-MedicalQnADataset"
dataset = load_dataset(dataset_name, split="train")

print(f"Dataset size: {len(dataset)} examples")
print("\nFirst example:")
print(dataset[0])
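Before training, it is worth checking how the examples are distributed across question types, since a heavily skewed `qtype` mix shapes what the model learns. A minimal sketch; the rows below are hypothetical placeholders in the MedQuad schema, and the same function works on the real `dataset`, because iterating over a Hugging Face `Dataset` yields one dict per row:

```python
from collections import Counter

def qtype_counts(rows):
    """Count how many examples fall under each question type."""
    return Counter(row["qtype"] for row in rows)

# Hypothetical rows in the MedQuad schema, for illustration only
sample_rows = [
    {"qtype": "symptoms", "Question": "placeholder question", "Answer": "placeholder answer"},
    {"qtype": "treatment", "Question": "placeholder question", "Answer": "placeholder answer"},
    {"qtype": "symptoms", "Question": "placeholder question", "Answer": "placeholder answer"},
]
print(qtype_counts(sample_rows).most_common())

# On the real data:
# print(qtype_counts(dataset).most_common(10))
```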

## 5. Data Preprocessing

To prepare the data for fine-tuning, we need to format it into a prompt template that the model can understand. A good prompt helps the model learn the desired input-output structure. We will create a simple prompt that clearly separates the question from the answer.

We'll also take a smaller subset of the data for this demonstration to speed up training.

In [None]:
def create_prompt(sample):
    """Format a sample into a training prompt."""
    prompt = f"""### Question:
{sample['Question']}

### Answer:
{sample['Answer']}"""
    return prompt

# Let's see what a formatted prompt looks like
print("Example formatted prompt:")
print(create_prompt(dataset[0]))
print("\n" + "="*50 + "\n")

# For this demo, we'll use a subset of the data (first 1000 examples)
# In practice, you would use the full dataset
train_dataset = dataset.select(range(min(1000, len(dataset))))
print(f"Training on {len(train_dataset)} examples")
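The training prompt and the inference prompt used later must match character for character; otherwise the model sees a format at test time that it never saw during training. One way to guarantee this, sketched here with a hypothetical `build_prompt` helper, is to generate both from a single function:

```python
def build_prompt(question, answer=None):
    """Build the prompt used for both training and inference.

    With an answer, it produces a full training example in the same format as
    create_prompt(). Without one, it stops right after the '### Answer:' header
    so the model generates the answer itself.
    """
    prompt = f"### Question:\n{question}\n\n### Answer:\n"
    if answer is not None:
        prompt += answer
    return prompt

train_text = build_prompt("Q?", "A.")
infer_text = build_prompt("Q?")
# The inference prompt is an exact prefix of the training text
print(train_text.startswith(infer_text))
```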

## 6. Model Loading and Configuration

Now, let's load our base model. We will use `TinyLlama`, a 1.1-billion-parameter chat model that is small enough for demonstrations on consumer hardware. We will also load it with 4-bit quantization to further reduce the memory footprint.

### Quantization

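**4-bit quantization** stores the model weights in a 4-bit format (here NF4, "NormalFloat 4") instead of 16- or 32-bit floating point, while computation still runs in 16-bit (`bnb_4bit_compute_dtype`). Relative to 16-bit weights this cuts weight memory by about 75% (about 87.5% relative to 32-bit), usually with only a small loss in model quality.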
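The memory savings follow from simple arithmetic: bytes = parameters × bits ÷ 8. A rough sketch for TinyLlama's roughly 1.1 billion parameters; it counts weights only and ignores activations, optimizer state, and the small per-block constants NF4 stores. In practice bitsandbytes typically quantizes only the linear layers, so the embeddings and output head stay in 16-bit and the real footprint is somewhat higher:

```python
def weight_memory_gb(n_params, bits_per_param):
    """Approximate memory for the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9

n_params = 1.1e9  # TinyLlama-1.1B, rounded
for label, bits in [("fp32", 32), ("fp16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: {weight_memory_gb(n_params, bits):.2f} GB")
```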

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Configure 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

# Load the base model with quantization
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
)

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

# Prepare model for k-bit training
model = prepare_model_for_kbit_training(model)

print("Base model loaded successfully!")

## 7. Configuring LoRA

Now we'll configure the LoRA parameters and apply them to our model. This creates the trainable adapter layers that will be fine-tuned.

In [None]:
# Configure LoRA
lora_config = LoraConfig(
    r=16,                      # Rank of the low-rank matrices
    lora_alpha=32,             # Scaling factor (usually 2x rank)
    lora_dropout=0.05,         # Dropout for regularization
    bias="none",               # Don't train bias parameters
    task_type="CAUSAL_LM",     # Task type: Causal Language Modeling
    target_modules=["q_proj", "v_proj"],  # Which layers to apply LoRA to
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)

# Print trainable parameters
model.print_trainable_parameters()

Only a small fraction of the parameters, well under 1%, is trainable. This is the main advantage of PEFT: we train a tiny part of the model while the pre-trained weights stay frozen.
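The number reported by `print_trainable_parameters()` can be sanity-checked by hand. The sketch below assumes TinyLlama's published dimensions (hidden size 2048, 22 layers, 32 query heads, 4 key/value heads); verify them against `model.config` before relying on the result. The exact percentage PEFT prints may differ slightly because of how it counts quantized weights.

```python
# TinyLlama-1.1B dimensions (assumed from its published config; check model.config)
hidden_size = 2048
num_layers = 22
num_heads = 32
num_kv_heads = 4
head_dim = hidden_size // num_heads  # 64

def lora_params(in_features, out_features, r):
    """LoRA adds A (r x in_features) and B (out_features x r)."""
    return r * (in_features + out_features)

r = 16
q_out = num_heads * head_dim     # q_proj: 2048 -> 2048
v_out = num_kv_heads * head_dim  # v_proj: 2048 -> 256 (grouped-query attention)
per_layer = lora_params(hidden_size, q_out, r) + lora_params(hidden_size, v_out, r)
total_lora = per_layer * num_layers

print(f"LoRA parameters: {total_lora:,}")
print(f"Fraction of ~1.1B: {total_lora / 1.1e9:.3%}")
```

Adding more target modules (for example `k_proj`, `o_proj`, or the MLP projections) or raising `r` increases this count roughly linearly.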

## 8. Training the Model

We will use the `SFTTrainer` (Supervised Fine-Tuning Trainer) from the `trl` library to perform the fine-tuning. This trainer simplifies the training process and is optimized for supervised fine-tuning tasks.

### Training Parameters:

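- **num_train_epochs:** Number of passes over the entire training dataset
- **per_device_train_batch_size:** Number of examples processed in each forward/backward pass on each GPU
- **gradient_accumulation_steps:** Number of such passes whose gradients are accumulated before each weight update
- **learning_rate:** Step size of each weight update, i.e. how quickly the model updates its weights
- **max_length:** Maximum input length in tokens; longer examples are truncated (called `max_seq_length` in older TRL releases)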
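These arguments interact: gradients from `gradient_accumulation_steps` consecutive micro-batches are accumulated before each weight update, so the effective batch size is their product. A back-of-the-envelope sketch for the 1,000-example subset and the values used in the training cell, assuming a single GPU (`batch_schedule` is a hypothetical helper):

```python
import math

def batch_schedule(num_examples, per_device_batch, grad_accum, num_devices=1):
    """Return (effective batch size, micro-batches per epoch, approx. optimizer steps per epoch)."""
    effective = per_device_batch * grad_accum * num_devices
    micro_batches = math.ceil(num_examples / (per_device_batch * num_devices))
    # The Trainer rounds this to a whole number; the exact rule varies by version
    optimizer_steps = micro_batches / grad_accum
    return effective, micro_batches, optimizer_steps

effective, micro, steps = batch_schedule(1000, per_device_batch=4, grad_accum=4)
print(effective, micro, steps)
```

So one epoch here is only about 62 to 63 weight updates, which is why `logging_steps=10` yields just a handful of loss readings.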

In [None]:
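from trl import SFTConfig, SFTTrainer

# SFTConfig extends transformers.TrainingArguments with SFT-specific options.
# Assumes a recent TRL release: older releases used `max_seq_length`
# (instead of `max_length`) and `tokenizer=` (instead of `processing_class=`).
training_args = SFTConfig(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,   # effective batch size = 4 x 4 = 16
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_strategy="epoch",
    optim="paged_adamw_8bit",        # 8-bit paged AdamW to save GPU memory
    max_length=512,                  # truncate examples longer than 512 tokens
)

# Turn each raw example into the prompt string defined earlier
def formatting_func(example):
    return create_prompt(example)

# `model` already carries the LoRA adapters from get_peft_model(), so we do NOT
# pass peft_config here; doing so would try to wrap the model a second time.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    formatting_func=formatting_func,
    processing_class=tokenizer,
)

print("Starting training...")
trainer.train()
print("Training complete!")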

## 9. Testing the Fine-Tuned Model

Let's test our fine-tuned model with a medical question to see how it performs!

In [None]:
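def ask_question(question, max_new_tokens=256):
    """Ask the fine-tuned model a medical question."""
    # Same template as training, ending right after the answer header
    prompt = f"""### Question:
{question}

### Answer:
"""
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    model.eval()  # disable dropout for generation
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            pad_token_id=tokenizer.eos_token_id,  # pad == eos for this tokenizer
        )

    # Decode only the newly generated tokens, not the prompt
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    answer = tokenizer.decode(new_tokens, skip_special_tokens=True)
    # Cut off any follow-up question the model starts to invent
    return answer.split("### Question:")[0].strip()

# Test with a sample question
test_question = "What are the symptoms of diabetes?"
print(f"Question: {test_question}")
print(f"\nAnswer: {ask_question(test_question)}")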

## 10. Saving the Model

After training, we save the LoRA adapters. These are small files (typically just a few MB) that contain only the trained weights.

In [None]:
# Save the LoRA adapters
output_dir = "tinyllama-medical-qa-lora"
trainer.model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

print(f"Model saved to {output_dir}")
print("\nYou can now load this model using:")
print("from peft import PeftModel")
print(f"model = PeftModel.from_pretrained(base_model, '{output_dir}')")
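To confirm that only the small adapter was written, and not a full copy of the model, you can list the file sizes. A minimal standard-library sketch (`dir_size_report` is a hypothetical helper). For the configuration above, the adapter holds roughly 2.25 million values, which is about 4.5 MB at 2 bytes or 9 MB at 4 bytes per value, depending on the dtype PEFT saves them in:

```python
from pathlib import Path

def dir_size_report(path):
    """Return {relative file name: size in MB (1 MB = 1e6 bytes)} for every file under path."""
    root = Path(path)
    return {
        str(p.relative_to(root)): p.stat().st_size / 1e6
        for p in sorted(root.rglob("*"))
        if p.is_file()
    }

# Usage after saving (output_dir is defined in the cell above):
# for name, mb in dir_size_report(output_dir).items():
#     print(f"{name:40s} {mb:8.2f} MB")
```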

## 11. Merging and Exporting for Ollama

The simplest way to use the model with Ollama is to merge the LoRA adapters into the base model, producing a standalone model, and then convert it to GGUF format.

In [None]:
from peft import PeftModel

# Load base model without quantization for merging
base_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Load and merge the LoRA weights
merged_model = PeftModel.from_pretrained(base_model, output_dir)
merged_model = merged_model.merge_and_unload()

# Save the merged model
merged_output_dir = "tinyllama-medical-qa-merged"
merged_model.save_pretrained(merged_output_dir)
tokenizer.save_pretrained(merged_output_dir)

print(f"Merged model saved to {merged_output_dir}")
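Conceptually, `merge_and_unload()` folds each adapter into its base weight, W = W0 + (alpha / r) · B A, so the merged model computes the same outputs without any extra layers. A toy pure-Python check of that identity, with made-up numbers; this illustrates the math and is not PEFT's implementation. With real fp16 weights, rounding can make merged and unmerged outputs differ very slightly.

```python
def matmul(X, Y):
    """Multiply two matrices given as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*Y)] for row in X]

def matvec(M, x):
    """Multiply matrix M (list of rows) by vector x."""
    return [sum(m * v for m, v in zip(row, x)) for row in M]

def merge_lora(W0, A, B, alpha, r):
    """Fold the adapter into the base weight: W = W0 + (alpha / r) * (B @ A)."""
    delta = matmul(B, A)
    scale = alpha / r
    return [[w + scale * dw for w, dw in zip(w_row, d_row)] for w_row, d_row in zip(W0, delta)]

# Tiny made-up example: d = 2, k = 3, r = 1
W0 = [[1.0, 0.0, 2.0], [0.0, 1.0, 0.0]]
A = [[1.0, 1.0, 0.0]]  # r x k
B = [[0.5], [0.0]]     # d x r
alpha, r = 2, 1
x = [1.0, 2.0, 3.0]

merged_out = matvec(merge_lora(W0, A, B, alpha, r), x)
# Unmerged: frozen path plus the scaled low-rank path
unmerged_out = [b + (alpha / r) * u for b, u in zip(matvec(W0, x), matvec(B, matvec(A, x)))]
print(merged_out, unmerged_out)
```

This is also why the base model is reloaded in fp16 rather than 4-bit for merging: the update has to be added to ordinary floating-point weights.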

## 12. Converting to GGUF and Using with Ollama

To use the model with Ollama, follow these steps in your terminal:

### Step 1: Install llama.cpp

```bash
git clone https://github.com/ggerganov/llama.cpp.git
cd llama.cpp
pip install -r requirements.txt
```

### Step 2: Convert to GGUF

```bash
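# convert_hf_to_gguf.py is the Hugging Face converter in current llama.cpp (older checkouts shipped convert.py)
python convert_hf_to_gguf.py /path/to/tinyllama-medical-qa-merged \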
  --outfile tinyllama-medical-qa.gguf \
  --outtype f16
```

### Step 3: Create a Modelfile

Create a file named `Modelfile` with this content:

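```
FROM ./tinyllama-medical-qa.gguf

TEMPLATE """### Question:
{{ .Prompt }}

### Answer:
"""

PARAMETER temperature 0.7
PARAMETER top_p 0.9
PARAMETER stop "### Question:"
```

The template mirrors the training prompt exactly. There is no `SYSTEM` line: this template never references `{{ .System }}`, so Ollama would not insert a system message, and the model never saw one during training anyway.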

### Step 4: Create and Run in Ollama

```bash
# Create the model in Ollama
ollama create tinyllama-medical-qa -f Modelfile

# Run the model
ollama run tinyllama-medical-qa
```

### Step 5: Test Your Model

```bash
ollama run tinyllama-medical-qa "What are the symptoms of diabetes?"
```
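Ollama also exposes a local HTTP API, which is handy for scripting evaluations from Python. A minimal standard-library sketch that calls the `/api/generate` endpoint on Ollama's default port 11434, using the model name created above (`build_generate_request` and `ask_ollama` are hypothetical helper names):

```python
import json
import urllib.request

OLLAMA_URL = "http://localhost:11434/api/generate"  # Ollama's default local endpoint

def build_generate_request(model, prompt, url=OLLAMA_URL):
    """Build a non-streaming POST request for Ollama's /api/generate endpoint."""
    body = json.dumps({"model": model, "prompt": prompt, "stream": False}).encode("utf-8")
    return urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}, method="POST"
    )

def ask_ollama(model, prompt):
    """Send the request and return the generated text (requires the Ollama server to be running)."""
    with urllib.request.urlopen(build_generate_request(model, prompt)) as resp:
        return json.loads(resp.read())["response"]

# Usage, with Ollama running locally:
# print(ask_ollama("tinyllama-medical-qa", "What are the symptoms of diabetes?"))
```

Setting `"stream": False` returns one JSON object instead of a stream of partial chunks, which keeps the client code short.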

## 13. Summary and Key Takeaways

In this notebook, we successfully fine-tuned a small language model for healthcare question answering. Here are the key concepts we covered:

### What We Learned:

1. **Fine-Tuning Fundamentals:** Fine-tuning adapts a pre-trained model to a specific domain or task, improving performance without training from scratch.

2. **Parameter-Efficient Fine-Tuning (PEFT):** Methods like LoRA allow us to fine-tune models by training only a small fraction of parameters, making it feasible on consumer hardware.

3. **LoRA (Low-Rank Adaptation):** Adds small trainable low-rank matrices alongside frozen model weights, often approaching full fine-tuning performance with far fewer resources.

4. **Quantization:** 4-bit quantization cuts weight memory by about 75% relative to 16-bit weights, enabling larger models to run on limited hardware.

5. **Practical Workflow:**
   - Load dataset and preprocess with prompt templates
   - Load base model with quantization
   - Configure and apply LoRA
   - Train using SFTTrainer
   - Save, merge, and export for deployment

### Next Steps:

- **Experiment with different datasets:** Try fine-tuning on other domains like legal, financial, or technical documentation.
- **Adjust hyperparameters:** Experiment with different LoRA ranks, learning rates, and batch sizes.
- **Use larger models:** Apply the same techniques to larger models like Llama 2 7B or Mistral 7B.
- **Evaluate performance:** Create test sets and measure accuracy, relevance, and factual correctness.
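As a cheap first metric for such a test set, token-overlap F1 compares a generated answer with a reference answer word by word. The sketch below is a simplified variant of the SQuAD-style metric (it lowercases and keeps alphanumeric tokens but does not strip articles). It rewards word overlap, not factual correctness, so for medical answers it complements expert review rather than replacing it.

```python
import re
from collections import Counter

def tokenize(text):
    """Lowercase and split into alphanumeric word tokens."""
    return re.findall(r"[a-z0-9]+", text.lower())

def token_f1(prediction, reference):
    """Token-overlap F1 between a generated answer and a reference answer."""
    pred, ref = tokenize(prediction), tokenize(reference)
    common = Counter(pred) & Counter(ref)  # per-token minimum counts
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(ref)
    return 2 * precision * recall / (precision + recall)

print(token_f1("increased thirst and frequent urination",
               "frequent urination, increased thirst, and fatigue"))  # prints about 0.909
```

Averaging `token_f1` over a held-out set of MedQuad questions, before and after fine-tuning, gives a rough signal of whether training helped.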

### Resources:

- [Hugging Face PEFT Documentation](https://huggingface.co/docs/peft)
- [LoRA Paper](https://arxiv.org/abs/2106.09685)
- [Ollama Documentation](https://ollama.ai)
- [TRL Library](https://huggingface.co/docs/trl)