# Medical Chatbot - LoRA Fine-tuning with Phi-4-mini-instruct

This notebook demonstrates how to fine-tune a language model for medical assistance using **LoRA (Low-Rank Adaptation)** with 4-bit quantization (QLoRA) on **Phi-4-mini-instruct**.

**Key Features:**
- 💾 Memory-efficient training with 4-bit quantization
- 🎯 Parameter-efficient fine-tuning using LoRA
- 📊 Perplexity metrics computed from the model's next-token loss
- 🧪 Validation monitoring during training with a custom callback and early stopping, plus a held-out test set evaluation
- 💬 Chat template formatting for conversational AI
- 🔧 **Function calling support** for tools like internet search

**Training Pipeline:**
1. Load and prepare dataset
2. Configure model with QLoRA
3. Apply chat templates
4. Train with metrics monitoring
5. Evaluate and save adapters
6. Test inference with function calling

## 🔐 Model Access

**Phi-4-mini-instruct** is an ungated model, so no authentication is required.

**Why Phi-4?**
- Improved reasoning and instruction-following
- Native function calling support for tools
- 128K token context length
- Better multilingual support

**Alternative Models:**
- `meta-llama/Llama-3.2-3B-Instruct` (requires access)
- `teknium/OpenHermes-2.5-Mistral-7B`
- `google/gemma-2b-it`

## 📦 Dependencies

Required libraries for LoRA fine-tuning with 4-bit quantization.

NOTE: If you are using [pixi](https://pixi.sh), it sets up the virtual environment for you:

```bash
pixi install
```

In [None]:
# Install dependencies if needed (uncomment for a fresh environment)
# %pip install -q "transformers>=4.40.0" "datasets>=2.18.0" "peft>=0.11.0" "accelerate>=0.28.0" "bitsandbytes" "evaluate"

## ⚙️ Configuration & Imports

Set up the training environment:
- **Seed**: 816 for reproducibility
- **Model**: microsoft/Phi-4-mini-instruct (3.8B parameters)
- **Quantization**: 4-bit NF4 with double quantization
- **Training**: 10 steps (demo setting) with batch size 1 and gradient accumulation 4
- **Learning Rate**: 2e-4 with 3% warmup
- **Function Calling**: Support for tools (internet search, clinical guideline retrieval)

In [None]:
import os
import json
from dataclasses import dataclass
from typing import Dict, List

import torch
import numpy as np
import evaluate
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    EarlyStoppingCallback,
    set_seed,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel

In [None]:
SEED = 816
set_seed(SEED)

# Configurable training parameters
base_model = "microsoft/Phi-4-mini-instruct"
train_file = "./data/my_custom_data.jsonl"
output_dir = "./checkpoints/phi4-mini-lora-med"
max_length = 512  # Reduced from 2048 for 8GB GPU
batch_size = 1
lr = 2e-4
max_steps = 10
validate_steps = 5  # Evaluate the validation set every N steps

In [None]:
# Define available tools for function calling
tools = [
    {
        "name": "search_internet",
        "description": "Search the internet for current medical information, research, and clinical guidelines",
        "parameters": {
            "query": {
                "description": "The medical search query",
                "type": "str"
            },
            "num_results": {
                "description": "Number of search results to return",
                "type": "int",
                "default": 5
            }
        }
    },
    {
        "name": "retrieve_clinical_guidelines",
        "description": "Retrieve clinical guidelines and best practices for specific conditions",
        "parameters": {
            "condition": {
                "description": "The medical condition to retrieve guidelines for",
                "type": "str"
            },
            "guideline_source": {
                "description": "Source of guidelines (e.g., 'NICE', 'AHA', 'CDC', 'WHO')",
                "type": "str",
                "default": "general"
            }
        }
    }
]

# System prompt for medical assistant with tool awareness
system_prompt = (
    "You are a careful medical assistant with access to tools for finding current information. "
    "You can search the internet for recent medical research and clinical guidelines. "
    "Reason step by step, cite sources when available, and avoid guessing beyond provided information. "
    "Your primary goal is to do no harm. If you are unsure, tell the user you don't know or that you will search for information. "
    "Use tools proactively when questions require current information or specific guidelines. "
    "Always end each response with 'This response was generated by AI. "
    "Please check with professional medical practitioners to confirm the results are safe and appropriate.'"
)

In [None]:
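# 4-bit quantization configuration for memory efficiency
# Use bfloat16 compute on GPUs that support it (Ampere and newer), otherwise float16
compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=compute_dtype,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)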

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

## 🔧 Function Calling & Tools

**Phi-4 supports function calling** for accessing external tools. The system prompt includes tool definitions in JSON format.

**Tool Format:**
- Tools are wrapped in `<|tool|>` and `<|/tool|>` tokens
- Model generates function calls when appropriate
- Supports internet search, guideline retrieval, and more
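A minimal sketch of the wrapping described in the list above, assuming the tool JSON sits between `<|tool|>` and `<|/tool|>` inside the system turn and that the turn closes with `<|end|>`, as in the chat format used elsewhere in this notebook. `build_system_turn` and `extract_tools` are hypothetical helpers for illustration; with a real tokenizer these markers are special tokens, while here they are plain strings.

```python
import json
from typing import Any, Dict, List


def build_system_turn(system_prompt: str, tools: List[Dict[str, Any]]) -> str:
    """Wrap JSON tool definitions in <|tool|>...<|/tool|> inside the system turn (assumed layout)."""
    return f"<|system|>{system_prompt}<|tool|>{json.dumps(tools)}<|/tool|><|end|>"


def extract_tools(system_turn: str) -> List[Dict[str, Any]]:
    """Recover the tool list from a system turn built by build_system_turn."""
    start = system_turn.index("<|tool|>") + len("<|tool|>")
    end = system_turn.index("<|/tool|>")
    return json.loads(system_turn[start:end])


example_tools = [
    {
        "name": "search_internet",
        "description": "Search the internet for current medical information",
        "parameters": {
            "query": {"description": "The medical search query", "type": "str"},
        },
    }
]

turn = build_system_turn("You are a careful medical assistant.", example_tools)
print(turn)
```

Keeping the tool list as JSON means the same `tools` Python list can be serialized for the prompt and parsed back for validation or logging.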

**Model Output Example:**
```
I'll search for the latest information about this condition.

<|function_call|>
search_internet
{"query": "recent research on hypertension treatment 2025"}
<|/function_call|>
```

During inference, you can capture these calls and execute actual tool logic, then feed results back to the model.
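The capture step can be sketched with the standard library. This is a minimal sketch that assumes the model emits blocks exactly in the format of the example above (a `<|function_call|>` tag, the tool name on its own line, a JSON object of arguments, then `<|/function_call|>`); the real Phi-4 output format may differ, so inspect actual generations before relying on it. `parse_function_calls`, `dispatch`, and the stub `search_internet` handler are hypothetical names.

```python
import json
import re
from typing import Any, Callable, Dict, List, Tuple

# Assumed block layout, taken from the example output above (not a specification)
CALL_PATTERN = re.compile(
    r"<\|function_call\|>\s*(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*(?P<args>\{.*?\})\s*<\|/function_call\|>",
    re.DOTALL,
)


def parse_function_calls(text: str) -> List[Tuple[str, Dict[str, Any]]]:
    """Return (tool_name, arguments) pairs for every well-formed call block."""
    calls = []
    for match in CALL_PATTERN.finditer(text):
        try:
            args = json.loads(match.group("args"))
        except json.JSONDecodeError:
            continue  # skip blocks whose arguments are not valid JSON
        calls.append((match.group("name"), args))
    return calls


def dispatch(calls: List[Tuple[str, Dict[str, Any]]],
             registry: Dict[str, Callable[..., Any]]) -> List[Any]:
    """Run each parsed call through a registry of Python handlers."""
    results = []
    for name, args in calls:
        handler = registry.get(name)
        if handler is None:
            results.append({"error": f"unknown tool: {name}"})
        else:
            results.append(handler(**args))
    return results


# Hypothetical stub; a real application would call a search API here.
def search_internet(query: str, num_results: int = 5) -> Dict[str, Any]:
    return {"query": query, "num_results": num_results, "results": []}


example_output = """I'll search for the latest information about this condition.

<|function_call|>
search_internet
{"query": "recent research on hypertension treatment 2025"}
<|/function_call|>"""

calls = parse_function_calls(example_output)
print(calls)
print(dispatch(calls, {"search_internet": search_internet}))
```

The handler results would then be added to the conversation and the model called again so it can answer using them.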

## 📚 Load Dataset

Load the JSONL dataset and split into train/validation/test sets.

**Dataset Format:**
```json
{"instruction": "What is hypertension?", "response": "Hypertension, or high blood pressure, means ..."}
```

**Splits:**
- Training: 80% of data
- Validation: 10% of data
- Test: 10% of data

### Dataset Split Details

The data is loaded from a JSONL file where each line contains a medical Q&A pair:
- **instruction**: The user's question
- **response**: The model's expected answer

**Cascading Split Strategy:**
- First, we split 80/20 to create training and temporary data
- Then we split the temporary data 50/50 to create validation and test sets
- This ensures three non-overlapping subsets for training, validation, and testing
- Using the same seed (816) ensures reproducibility

This approach prevents data leakage: the model never sees test data during training.
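The two-stage idea can be illustrated in plain Python. This sketch uses `random.Random` rather than the NumPy-based shuffling inside `datasets.train_test_split`, so the exact membership (and, for awkward sizes, the rounding) will differ from the notebook's splits; `cascading_split` is a hypothetical helper.

```python
import random
from typing import List, Sequence, Tuple


def cascading_split(items: Sequence, seed: int = 816) -> Tuple[List, List, List]:
    """Split items roughly 80/10/10 using two successive shuffled splits."""
    rng = random.Random(seed)

    # Stage 1: 80% train, 20% temporary
    shuffled = list(items)
    rng.shuffle(shuffled)
    n_temp = round(len(shuffled) * 0.2)
    train, temp = shuffled[n_temp:], shuffled[:n_temp]

    # Stage 2: split the temporary data 50/50 into validation and test
    rng.shuffle(temp)
    n_test = round(len(temp) * 0.5)
    validation, test = temp[n_test:], temp[:n_test]
    return train, validation, test


train, validation, test = cascading_split(range(1000))
print(len(train), len(validation), len(test))  # 800 100 100
```

Because every element lands in exactly one slice, the three subsets cannot overlap, and a fixed seed reproduces the same split on every run.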

In [None]:
# Load dataset from JSONL file
full_dataset = load_dataset("json", data_files={"train": train_file})["train"]

# Split into train/validation/test using cascading train_test_split
# First split: 80% train, 20% temp (for validation + test)
split_1 = full_dataset.train_test_split(test_size=0.2, seed=SEED)
train_data = split_1["train"]
temp_data = split_1["test"]

# Second split: split the temp data 50/50 into validation and test
split_2 = temp_data.train_test_split(test_size=0.5, seed=SEED)
validation_data = split_2["train"]
test_data = split_2["test"]

# Create DatasetDict with all three splits
raw_dataset = DatasetDict({
    "train": train_data,
    "validation": validation_data,
    "test": test_data
})

print(raw_dataset)
print("\nExample from training set:")
print(raw_dataset["train"][0])

## 🤖 Model Setup

**Steps:**
1. Load tokenizer and configure padding
2. Load model in 4-bit quantization
3. Prepare model for k-bit training
4. Enable gradient checkpointing (saves memory)
5. Attach LoRA adapters

**LoRA Configuration:**
- Rank (r): 16
- Alpha: 32
- Target modules: All attention and MLP layers
- Dropout: 5%
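LoRA leaves each frozen weight matrix `W` unchanged and learns a low-rank update, so an adapted layer computes `W x + (alpha / r) * B A x`. The NumPy sketch below illustrates that forward pass and the parameter count for the r=16, alpha=32 settings above, using an illustrative 3072 x 3072 projection (the shape is an assumption for the arithmetic, not read from the Phi-4 config). `B` starts at zero, as in the LoRA paper and PEFT's default initialization, so the adapted layer initially matches the base layer exactly. Dropout is omitted for brevity.

```python
import numpy as np

rng = np.random.default_rng(816)

d_in, d_out = 3072, 3072   # illustrative projection size (assumption)
r, lora_alpha = 16, 32     # values from the LoRA configuration above
scaling = lora_alpha / r   # the update is scaled by alpha / r, here 2.0

W = rng.standard_normal((d_out, d_in)).astype(np.float32)     # frozen base weight
A = rng.standard_normal((r, d_in)).astype(np.float32) * 0.01  # trainable, random init
B = np.zeros((d_out, r), dtype=np.float32)                     # trainable, zero init


def lora_forward(x: np.ndarray) -> np.ndarray:
    """Base projection plus the scaled low-rank update B @ A."""
    return W @ x + scaling * (B @ (A @ x))


x = rng.standard_normal(d_in).astype(np.float32)
# With B = 0 the adapter contributes nothing yet, so outputs match the base layer.
assert np.allclose(lora_forward(x), W @ x)

full_params = d_in * d_out        # parameters in the frozen matrix
lora_params = r * (d_in + d_out)  # parameters in A and B together
print(f"full: {full_params:,}  lora: {lora_params:,}  ratio: {lora_params / full_params:.2%}")
```

For this shape the adapter trains about 1% of the parameters of the matrix it modifies, which is why only a small fraction of the model is reported as trainable.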

In [None]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"  # pad on the right for causal-LM training

# Load model with 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare for LoRA fine-tuning in 4-bit
model.config.use_cache = False  # Required for gradient checkpointing
model = prepare_model_for_kbit_training(model)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Attach LoRA adapters
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
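    # Phi-4-mini fuses its projections (qkv_proj, gate_up_proj), so names such as
    # q_proj or gate_proj would not match; "all-linear" targets every linear layer
    # except the output head, covering all attention and MLP projections
    target_modules="all-linear",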
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

### Understanding 4-bit Quantization

**What is quantization?**

Quantization reduces memory usage by representing model weights with fewer bits. Standard 16-bit (float16) weights become 4-bit weights, reducing memory by ~75%.

**NF4 (Normal Float 4):**

- Uses 4 bits instead of 16
- Optimized for normal distributions of neural network weights
- Double quantization: quantizes the quantization scale itself for more compression

**Why this matters:**

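- Phi-4-mini has 3.8B parameters. In float16, the weights alone need ~7.6GB of VRAM (3.8B × 2 bytes)
- With 4-bit quantization, the quantized weights take roughly 2GB; layers kept in 16-bit (such as the embeddings), activations, and optimizer state add to this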
- Enables fine-tuning on consumer GPUs (like NVIDIA RTX 3060/4090)
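The memory figures above are simple arithmetic, worked through below. This sketch assumes every parameter is quantized (which overstates the savings, since some layers stay in 16-bit) and uses the block sizes reported in the QLoRA paper for the per-block scaling constants (64 weights per block with 32-bit constants, and 256 constants per second-level block with double quantization); treat those as assumptions about the bitsandbytes defaults.

```python
def weight_memory_gb(n_params: float, bits_per_param: float) -> float:
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


n_params = 3.8e9  # Phi-4-mini parameter count

fp16_gb = weight_memory_gb(n_params, 16)  # 7.6 GB
nf4_gb = weight_memory_gb(n_params, 4)    # 1.9 GB

# Per-block scaling constants add overhead on top of the 4 bits per weight:
# 32-bit constants over 64-weight blocks cost 32/64 = 0.5 extra bits per weight;
# double quantization stores them in 8 bits plus a 32-bit constant per 256 blocks.
scales_gb = weight_memory_gb(n_params, 32 / 64)
scales_dq_gb = weight_memory_gb(n_params, 8 / 64 + 32 / (64 * 256))

print(f"fp16 weights:      {fp16_gb:.2f} GB")
print(f"4-bit weights:     {nf4_gb:.2f} GB")
print(f"scales (no DQ):    {scales_gb:.3f} GB")
print(f"scales (double Q): {scales_dq_gb:.3f} GB")
```

Double quantization saves roughly 0.37 bits per parameter of scaling overhead, which is small next to the 16-to-4-bit reduction but adds up at billions of parameters.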

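**Trade-off:**

Quantization slightly reduces the precision of the frozen base weights. The QLoRA paper reported that 4-bit QLoRA fine-tuning matched 16-bit fine-tuning performance on its benchmarks, but the effect on your own task should still be checked on the validation and test sets.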

## 🔤 Data Preprocessing

**Process:**
1. **Format Examples**: Apply chat template (system, user, assistant roles)
2. **Tokenize**: Convert text to token IDs with padding/truncation
3. **Create Labels**: Clone input_ids for next-token prediction

**Chat Template Example:**
```
<|system|>You are a careful medical assistant...<|end|>
<|user|>What is hypertension?<|end|>
<|assistant|>Hypertension is...<|end|>
```
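The layout in the example above can be mirrored with a few lines of string formatting. `render_phi4_chat` is a hypothetical helper for illustration only: it does not reproduce every detail of the real template (which may, for example, append an end-of-text token), so training data should always go through `tokenizer.apply_chat_template`.

```python
from typing import Dict, List


def render_phi4_chat(messages: List[Dict[str, str]], add_generation_prompt: bool = False) -> str:
    """Mirror the <|role|>content<|end|> layout shown above (illustrative only)."""
    parts = [f"<|{m['role']}|>{m['content']}<|end|>" for m in messages]
    if add_generation_prompt:
        parts.append("<|assistant|>")  # leave the assistant turn open for generation
    return "".join(parts)


messages = [
    {"role": "system", "content": "You are a careful medical assistant."},
    {"role": "user", "content": "What is hypertension?"},
    {"role": "assistant", "content": "Hypertension is..."},
]
print(render_phi4_chat(messages))
print(render_phi4_chat(messages[:2], add_generation_prompt=True))
```

The first call corresponds to `add_generation_prompt=False` (a complete training example); the second leaves an open `<|assistant|>` turn, which is what inference needs.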

In [None]:
def format_example(example: Dict[str, str]) -> Dict[str, str]:
    """Apply chat template to each example."""
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example.get("instruction", "")},
        {"role": "assistant", "content": example.get("response", "")},
    ]
    example["text"] = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example

# Apply chat template to all examples
formatted = raw_dataset.map(format_example, remove_columns=raw_dataset["train"].column_names)


def tokenize_batch(batch: Dict[str, List[str]]) -> Dict[str, List[List[int]]]:
    """Tokenize a batch of examples and add labels for causal language modeling."""
    tokenized = tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
    )
    # Labels start as a copy of input_ids; the data collator later masks padding with -100
    tokenized["labels"] = [ids.copy() for ids in tokenized["input_ids"]]
    return tokenized

# Tokenize all examples
tokenized = formatted.map(tokenize_batch, batched=True, remove_columns=["text"])
train_dataset = tokenized["train"].with_format("torch")
eval_dataset = tokenized["validation"].with_format("torch")

print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(eval_dataset)}")
print("\nFirst 120 tokens of training example:")
print(tokenizer.decode(train_dataset[0]["input_ids"][:120]))

### Chat Template & Tokenization Process

**Step 1: Format Examples**
The `format_example()` function applies Phi-4's chat template, which structures conversations with role markers:
- `<|system|>`: System instructions and context
- `<|user|>`: User's question
- `<|assistant|>`: Model's response

**Step 2: Tokenize**
The `tokenize_batch()` function converts text to token IDs:
- **max_length=512**: Truncate longer sequences (prevents GPU memory overflow)
- **padding="max_length"**: Pad shorter sequences to 512 tokens
- **labels**: Clone of input_ids for supervised learning (predict next token)

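**Why labels matter:**
During training, the model learns to predict the next token at each position. The labels tell the trainer which token is "correct" at each position; the model shifts them internally by one, so position *t* is scored against token *t+1*. Using `input_ids` as labels is standard causal language modeling. Note that `DataCollatorForLanguageModeling` (with `mlm=False`) rebuilds the labels from `input_ids` and replaces padding positions with -100 so the loss ignores them; the loss still covers the system prompt and the user turn, not only the assistant's answer.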
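The padding, masking, and shifting described above can be traced on a toy sequence. This sketch uses made-up token ids and a hypothetical pad id of 0; in the notebook, the real tokenizer does the truncation and padding, and `DataCollatorForLanguageModeling` does the -100 masking by replacing pad-token ids in the labels.

```python
from typing import List, Tuple

PAD_ID = 0           # hypothetical pad token id for this toy example
IGNORE_INDEX = -100  # label value that the loss ignores


def pad_and_label(ids: List[int], max_length: int) -> Tuple[List[int], List[int]]:
    """Truncate/pad to max_length; copy ids to labels and mask padding with -100."""
    ids = ids[:max_length]
    n_pad = max_length - len(ids)
    input_ids = ids + [PAD_ID] * n_pad
    labels = ids + [IGNORE_INDEX] * n_pad
    return input_ids, labels


def next_token_pairs(input_ids: List[int], labels: List[int]) -> List[Tuple[List[int], int]]:
    """List (context, target) pairs: the prefix up to position t predicts label t+1."""
    pairs = []
    for t in range(len(input_ids) - 1):
        target = labels[t + 1]
        if target != IGNORE_INDEX:
            pairs.append((input_ids[: t + 1], target))
    return pairs


input_ids, labels = pad_and_label([11, 12, 13], max_length=5)
print(input_ids)  # [11, 12, 13, 0, 0]
print(labels)     # [11, 12, 13, -100, -100]
for context, target in next_token_pairs(input_ids, labels):
    print(context, "->", target)
```

Only two training signals survive here: padded positions contribute nothing to the loss.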

In [None]:
# Data collator for causal language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

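## 📊 Metrics

- **Perplexity**: exp(loss), which measures how well the model predicts the next token
  - Lower is better.
  - The lowest possible score is 1.
  - A score of 5, for example, means the model is on average as uncertain as if it were choosing uniformly among 5 possible next tokens.
- Computed in `compute_metrics` with PyTorch from the logits returned by the Trainer
- The Trainer gathers the full logits for the evaluation set before calling `compute_metrics`; with Phi-4-mini's large vocabulary (~200K tokens) this can use a lot of memory, so keep the validation set small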

In [None]:
def compute_metrics(eval_preds):
    """Compute perplexity metric from model predictions."""
    predictions, labels = eval_preds
    
    # Extract logits if predictions is a tuple
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    
    # Shift for next-token prediction (standard for causal LM)
    shift_logits = predictions[..., :-1, :]
    shift_labels = labels[..., 1:]
    
    # Flatten for loss calculation
    shift_logits = shift_logits.reshape(-1, shift_logits.shape[-1])
    shift_labels = shift_labels.reshape(-1)
    
    # Calculate cross-entropy loss (ignore padding tokens with -100)
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(
        torch.tensor(shift_logits, dtype=torch.float32),
        torch.tensor(shift_labels, dtype=torch.long)
    )
    
    # Perplexity is exp(loss)
    perplexity = torch.exp(loss).item()
    
    return {"perplexity": perplexity}

### How Perplexity is Calculated

Perplexity measures how "surprised" the model is by the correct next token:

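**Formula:**
$$\text{Perplexity} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p_\theta(x_i \mid x_{<i})\right) = e^{\text{loss}}$$
where $N$ is the number of target tokens not masked with -100, and the loss is the mean cross-entropy over those tokens.

**Process:**
1. For each token position, compute the cross-entropy loss: the negative log-probability the model assigns to the correct next token
2. Average the loss across all non-masked tokens
3. Take the exponential of the average loss

**Interpretation:**
- **Perplexity = 1**: The model assigns probability 1 to every correct token (perfect prediction)
- **Perplexity = 5**: The model is, on average, as uncertain as a uniform choice among 5 tokens
- **Perplexity = 100**: The model is very uncertain about next tokens

**For chatbots:** There is no universal target value. Perplexity depends on the tokenizer, the dataset, and which tokens are included in the loss, so compare it across runs on the same validation or test set. Lower is better.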
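The same shift, mask, and average steps can be written in NumPy to check the interpretation above on toy inputs. `perplexity_from_logits` is a hypothetical helper that mirrors the logic of `compute_metrics` (shifted logits and labels, -100 ignored, exp of the mean loss); it is a sketch, not a drop-in replacement.

```python
import numpy as np


def perplexity_from_logits(logits: np.ndarray, labels: np.ndarray) -> float:
    """exp(mean next-token cross-entropy), skipping labels equal to -100.

    logits: (batch, seq_len, vocab); labels: (batch, seq_len)
    """
    shift_logits = logits[:, :-1, :]  # position t predicts token t+1
    shift_labels = labels[:, 1:]
    mask = shift_labels != -100

    # Numerically stable log-softmax over the vocabulary axis
    max_logits = shift_logits.max(axis=-1, keepdims=True)
    log_probs = shift_logits - max_logits - np.log(
        np.exp(shift_logits - max_logits).sum(axis=-1, keepdims=True)
    )

    # Log-probability of each target token (masked positions use index 0, then dropped)
    safe_labels = np.where(mask, shift_labels, 0)
    target_log_probs = np.take_along_axis(log_probs, safe_labels[..., None], axis=-1)[..., 0]

    mean_nll = -target_log_probs[mask].mean()
    return float(np.exp(mean_nll))


# Uniform logits over a 5-token vocabulary: every target has probability 1/5,
# so perplexity equals the vocabulary size.
uniform_logits = np.zeros((1, 4, 5))
labels = np.array([[0, 1, 2, -100]])
print(perplexity_from_logits(uniform_logits, labels))
```

A model that spreads probability evenly over 5 tokens scores a perplexity of 5, and a model that puts nearly all probability on each correct token approaches 1.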

### Custom Validation Callback

The `ValidationEvalCallback` class provides real-time monitoring of model performance on the validation set during training.

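**Why it's useful:**
- With `eval_strategy="steps"`, the Trainer already evaluates the validation set and records `eval_loss` and `eval_perplexity` in its logs
- This callback additionally prints a readable loss/perplexity summary at the same interval, which makes overfitting easy to spot while training runs
- Comparing validation loss with training loss shows whether the model is generalizing to unseen data

**How it works:**
- Runs at the end of each training step
- Every N steps (defined by `validate_steps`), it runs `trainer.predict()` on the validation set, which is a second evaluation pass on top of the built-in one
- Prints validation metrics so you can monitor training progress
- Early stopping is handled separately by `EarlyStoppingCallback`, which uses the Trainer's built-in evaluation results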
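The patience/threshold rule behind early stopping can be modeled in a few lines. `early_stop_step` is a hypothetical, simplified model: an evaluation counts as an improvement only if it beats the best loss so far by more than the threshold, otherwise a counter grows until it reaches the patience. In the real `EarlyStoppingCallback`, the best value is tracked by the Trainer (`state.best_metric`), so details can differ.

```python
from typing import List, Optional


def early_stop_step(eval_losses: List[float], patience: int = 3,
                    threshold: float = 0.001) -> Optional[int]:
    """Return the 1-based evaluation index at which training would stop, or None."""
    best = None
    bad_evals = 0
    for i, loss in enumerate(eval_losses, start=1):
        if best is None or best - loss > threshold:
            best = loss      # meaningful improvement: reset the counter
            bad_evals = 0
        else:
            bad_evals += 1   # too small an improvement (or worse)
            if bad_evals >= patience:
                return i
    return None


print(early_stop_step([1.20, 1.10, 1.0995, 1.0999, 1.2]))  # 5
```

Under this model, stopping needs at least `patience + 1 = 4` evaluations. With the demo's `max_steps=10` and `validate_steps=5` there are only two, so early stopping cannot trigger until `max_steps` is increased.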

In [None]:
# Custom callback to evaluate validation set during training
class ValidationEvalCallback(TrainerCallback):
    """Evaluate the validation set every `validate_steps` training steps."""

    def __init__(self, eval_dataset):
        self.eval_dataset = eval_dataset
        self.trainer_obj = None  # set after the Trainer is created

    @staticmethod
    def _fmt(value):
        # Metrics may be missing, so never apply a float format to None
        return f"{value:.4f}" if isinstance(value, (int, float)) else "N/A"

    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of each training step."""
        if state.global_step % validate_steps == 0 and state.global_step > 0 and self.trainer_obj is not None:
            # predict() prefixes metric names with "test_" by default
            eval_results = self.trainer_obj.predict(self.eval_dataset)
            eval_loss = eval_results.metrics.get("test_loss")
            eval_perp = eval_results.metrics.get("test_perplexity")
            print(f"\n[Step {state.global_step}] Validation Loss: {self._fmt(eval_loss)}, "
                  f"Validation Perplexity: {self._fmt(eval_perp)} "
                  "(Lower is better. Lowest possible score is 1.)")


# Initialize callback with validation dataset
val_callback = ValidationEvalCallback(eval_dataset)

### Training Arguments Explained

Key parameters that control the training process:

**Batch Size & Accumulation:**
- `per_device_train_batch_size=1`: Process 1 example per GPU step
- `gradient_accumulation_steps=4`: Accumulate gradients over 4 steps before updating weights
- *Effective batch size = 1 × 4 = 4*

**Learning Schedule:**
- `learning_rate=2e-4`: Step size for weight updates (smaller = more stable but slower)
- `warmup_ratio=0.03`: Gradually increase LR for first 3% of training (prevents instability)
- `max_steps=10`: Train for 10 steps total (for demo; increase for production)

**Checkpointing:**
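- `save_steps=50`: Save a checkpoint every 50 steps (with the demo's `max_steps=10`, no intermediate checkpoint is written; for best-model selection to matter, increase `max_steps` and keep `save_steps` a multiple of `eval_steps`)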
- `save_total_limit=2`: Keep only the 2 most recent checkpoints
- `load_best_model_at_end=True`: Load the best checkpoint after training completes

**Evaluation:**
- `eval_strategy="steps"`: Evaluate during training (not just at end)
- `eval_steps=5`: Evaluate every 5 steps
- `metric_for_best_model="eval_loss"`: Track validation loss
- `greater_is_better=False`: Lower loss is better

**Precision:**
- `bf16` / `fp16`: Mixed-precision training to reduce memory use; bfloat16 is preferred on GPUs that support it, and mixed precision requires a CUDA GPU
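The settings above imply a few derived numbers, worked through below for the demo configuration. The sketch assumes a single GPU and that `warmup_ratio` is converted to whole steps by rounding up (recent `transformers` versions do this, but treat it as an assumption); `max_steps` counts optimizer updates, each of which consumes one effective batch.

```python
import math

per_device_train_batch_size = 1
gradient_accumulation_steps = 4
num_devices = 1  # assumption: single GPU
max_steps = 10
warmup_ratio = 0.03
eval_steps = 5

effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices
examples_seen = effective_batch * max_steps        # one optimizer step per effective batch
warmup_steps = math.ceil(max_steps * warmup_ratio)  # 0.3 steps rounds up to 1
num_evals = max_steps // eval_steps

print(effective_batch, examples_seen, warmup_steps, num_evals)  # 4 40 1 2
```

So the demo run sees only 40 training examples and evaluates twice, which is enough to check that the pipeline works but not to fine-tune meaningfully.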

In [None]:
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # Reduced from 8 for 8GB GPU
    learning_rate=lr,
    max_steps=max_steps,
    warmup_ratio=0.03,
    logging_steps=5,
    save_steps=50,
    save_total_limit=2,
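    # Mixed precision needs CUDA: bf16 where supported (Ampere+), fp16 on older GPUs
    bf16=torch.cuda.is_available() and torch.cuda.is_bf16_supported(),
    fp16=torch.cuda.is_available() and not torch.cuda.is_bf16_supported(),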
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    report_to=["none"],
    eval_strategy="steps",  # Enable evaluation during training
    eval_steps=validate_steps,  # Evaluate every N steps
    load_best_model_at_end=True,  # Load the best checkpoint at the end
    metric_for_best_model="eval_loss",  # Track validation loss
    greater_is_better=False,  # Lower loss is better
)


# Early stopping: stop if validation loss doesn't improve for N evaluations
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=3,  # Stop if no improvement for 3 evaluations
    early_stopping_threshold=0.001,  # Minimum improvement threshold
)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[val_callback, early_stopping],  # Add early stopping callback
)

# Set trainer reference in callback (needed for validation evaluation)
val_callback.trainer_obj = trainer


In [None]:
print("\n🚀 Starting training with early stopping (patience=3)...")
trainer.train()

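## 💾 Save Model

Save the trained LoRA adapters and optionally merge them into the base model.

**Saved Artifacts:**
- LoRA adapter weights (small compared with the base model: typically tens of MB rather than GB)
- Tokenizer configuration

**Loading the adapter:** Load the base model first, then call `PeftModel.from_pretrained(model, adapter_path)`, where `model` is the loaded base model (not the `base_model` name string).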

In [None]:
# Save LoRA adapter weights
adapter_dir = os.path.join(output_dir, "lora_adapter")
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print(f"✅ Adapter saved to {adapter_dir}")

# Optionally merge and save full model (set the environment variable MERGE_LORA=1)
# Note: merging into a 4-bit base can introduce rounding errors; for a cleaner
# merge, reload the base model in float16 and attach the saved adapter first.
merge = os.environ.get("MERGE_LORA", "0") == "1"
if merge:
    merged_dir = os.path.join(output_dir, "merged")
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print(f"✅ Merged model saved to {merged_dir}")

## 🧪 Test Set Evaluation

Evaluate the final model performance on the held-out test set.

**Metrics:**
- **Test Loss**: Cross-entropy loss on test examples
- **Test Perplexity**: exp(loss)
  - Lower is better.
  - The lowest possible score is 1.
  - A score of 5, for example, means the model is on average as uncertain as if it were choosing uniformly among 5 possible next tokens.

There is no universal target value: compare test perplexity with the validation perplexity from training, and across runs on the same test set.

In [None]:
# Evaluate on the held-out test set (first time using it!)
# The test split was already formatted and tokenized together with train/validation
test_dataset = tokenized["test"].with_format("torch")

print(f"Test dataset size: {len(test_dataset)}")

# Evaluate on test set
test_results = trainer.predict(test_dataset)
metrics = test_results.metrics
print("\n" + "=" * 50)
print("📊 FINAL TEST METRICS")
print("=" * 50)
print(f"Test Loss: {metrics['test_loss']:.4f}")
print(f"Test Perplexity: {metrics['test_perplexity']:.4f}")
print("\nAll metrics:")
print(metrics)

## 💬 Inference Demo with Function Calling

Test the fine-tuned model with a medical query. Phi-4 can generate function calls for tools.

**Process:**
1. Load base model in fp16
2. Attach trained LoRA adapters
3. Format query with chat template and tool definitions
4. Generate response (may include function calls)
5. Parse and handle function calls if present

**Try modifying the question to see if the model decides to use tools!**

### Inference Setup Steps

**Step 1: Load Base Model in float16**
- Use float16 precision (not quantized) to maximize quality for inference
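- In float16, the weights alone take about 7.6GB (3.8B parameters × 2 bytes), so an 8GB GPU is a tight fit; free the training model first or load the base model in 4-bit with `bnb_config` instead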

**Step 2: Attach LoRA Adapters**
- Load the trained LoRA weights we saved earlier
- These are small (typically tens of MB) compared to the base model
- They contain the low-rank updates to the attention and feed-forward layers learned during fine-tuning

**Step 3: Format Prompt with Tools**
- Include the system prompt and tool definitions
- Use Phi-4's special tokens: `<|system|>`, `<|user|>`, `<|assistant|>`
- The model can now reference tools in its response

**Step 4: Generate Response**
- Use `do_sample=False` for greedy, deterministic decoding, which gives consistent answers
- `temperature` (e.g. 0.3) only has an effect when `do_sample=True`; with greedy decoding it is ignored
- Higher `max_new_tokens` allows longer responses
- The model may generate `<|function_call|>` blocks when it wants to use a tool
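Why `temperature` does not matter under greedy decoding can be seen with toy logits. The sketch below applies temperature-scaled softmax to made-up next-token logits: lower temperatures sharpen the distribution used for sampling, but the highest-scoring token, which is what greedy decoding picks, is the same for every temperature.

```python
import numpy as np


def softmax_with_temperature(logits: np.ndarray, temperature: float) -> np.ndarray:
    """Divide logits by temperature, then normalize (numerically stable softmax)."""
    scaled = logits / temperature
    scaled = scaled - scaled.max()
    probs = np.exp(scaled)
    return probs / probs.sum()


logits = np.array([2.0, 1.0, 0.5, -1.0])  # toy next-token logits (made up)

for t in (0.3, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(f"T={t}: top-token prob={probs.max():.3f}, argmax={probs.argmax()}")

# Greedy decoding (do_sample=False) takes the argmax of the logits,
# which is the same token for every temperature.
greedy_token = int(np.argmax(logits))
print("greedy token:", greedy_token)
```

Temperature therefore only changes behavior when sampling is enabled, where it trades consistency (low T) against variety (high T).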

**Output:**
The model's response may include function call blocks if it decides to search for information. In a real application, you'd parse these calls and execute actual tool logic.

In [None]:
# Load base model and attach LoRA adapters
adapter_dir = os.path.join(output_dir, "lora_adapter")
inference_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
)
inference_model = PeftModel.from_pretrained(inference_model, adapter_dir)
inference_model.eval()

# Format tools for Phi-4 function calling
tools_json = json.dumps(tools)

user_question = "What are the latest treatment guidelines for hypertension?"

# Build prompt with tools for Phi-4 (this format is Phi-4 specific):
# tool definitions go inside <|tool|>...<|/tool|> in the system turn
formatted_prompt = (
    f"<|system|>{system_prompt}<|tool|>{tools_json}<|/tool|><|end|>"
    f"<|user|>{user_question}<|end|>"
    "<|assistant|>"
)

# Tokenize with the tokenizer loaded earlier (PeftModel has no .tokenizer attribute)
input_ids = tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(inference_model.device)

# Generate response (greedy decoding; temperature is ignored when do_sample=False)
print("🤖 Generating response with function calling enabled...\n")
with torch.no_grad():
    gen = inference_model.generate(
        input_ids,
        max_new_tokens=300,
        do_sample=False,
    )

# Decode only the newly generated tokens, keeping special tokens such as tool-call markers
response = tokenizer.decode(gen[0][input_ids.shape[-1]:], skip_special_tokens=False)
print("Generated Response:")
print("-" * 60)
print(response)
print("-" * 60)

# Check for function calls in response
if "<|function_call|>" in response:
    print("\n✅ Model generated function call(s)!")
    print("Function calls detected - you can parse and execute these in a real application")
else:
    print("\n📝 Model provided direct answer without tool usage")