# Medical Chatbot - LoRA Fine-tuning with Phi-3.5

This notebook demonstrates how to fine-tune a language model for medical assistance using **LoRA (Low-Rank Adaptation)** with 4-bit quantization (QLoRA). 

**Key Features:**
- üíæ Memory-efficient training with 4-bit quantization
- üéØ Parameter-efficient fine-tuning using LoRA
- üìä Perplexity metrics using HuggingFace Evaluate library
- üß™ Test set evaluation during training with custom callbacks
- üí¨ Chat template formatting for conversational AI

**Training Pipeline:**
1. Load and prepare dataset
2. Configure model with QLoRA
3. Apply chat templates
4. Train with metrics monitoring
5. Evaluate and save adapters
6. Test inference

## üîê Model Access

**Important:** Some models may require authentication or have gated access.

**Options:**
- Use an ungated model like `microsoft/Phi-3.5-mini-instruct` (default)
- For gated models (e.g., Llama), authenticate with: `pixi run huggingface-cli login`
- Or set environment variable: `export HUGGINGFACE_TOKEN=your_token`

**Alternative Models:**
- `meta-llama/Llama-3.2-3B-Instruct` (requires access)
- `teknium/OpenHermes-2.5-Mistral-7B`
- `google/gemma-2b-it`

## üì¶ Dependencies

Required libraries for LoRA fine-tuning with 4-bit quantization.

NOTE: If you are using [pixi](https://pixi.sh) then the virtual environment will already be setup for you. 

```bash
pixi install
```

In [1]:
# Install dependencies if needed (uncomment for a fresh environment)
# %pip install -q "transformers>=4.40.0" "datasets>=2.18.0" "peft>=0.11.0" "accelerate>=0.28.0" "bitsandbytes" "evaluate"

## ‚öôÔ∏è Configuration & Imports

Set up the training environment:
- **Seed**: 816 for reproducibility
- **Model**: microsoft/Phi-3.5-mini-instruct (3.8B parameters)
- **Quantization**: 4-bit NF4 with double quantization
- **Training**: 100 steps with batch size 1, gradient accumulation 8
- **Learning Rate**: 2e-4 with 3% warmup

In [2]:
import os
from dataclasses import dataclass
from typing import Dict, List

import torch
import numpy as np
import evaluate
from datasets import load_dataset, DatasetDict
from huggingface_hub import login
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    TrainerCallback,
    set_seed,
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel


In [None]:
SEED = 816
set_seed(SEED)

# Configurable training parameters
base_model = "microsoft/Phi-3.5-mini-instruct"
train_file = "./data/my_custom_data.jsonl"
output_dir = "./checkpoints/llama3-lora-med"
max_length = 2048
batch_size = 1
lr = 2e-4
max_steps = 100
validate_steps = 50  # Evaluate test set every N steps


In [4]:
# System prompt for medical assistant
system_prompt = (
    "You are a careful medical assistant. Reason step by step, cite sources when available, "
    "and avoid guessing beyond the provided information. "
    "Your primary function is to not harm. If you are unsure, then tell the user that you don't know the answer. "
    "Always end the conversation with the words 'This response was generated by AI. "
    "Please check with professional medical practictioners to confirm the results are safe and appropriate.'"
)

In [5]:
# 4-bit quantization configuration for memory efficiency
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


## üìö Load Dataset

Load the JSONL dataset and split into train/validation/test sets.

**Dataset Format:**
```json
{"instruction": "What is hypertension?", "response": "Hypertension, or high blood pressure, means ..."}
```

**Splits:**
- Training: 80% of data
- Validation: 10% of data
- Test: 10% of data

In [6]:
# Load dataset from JSONL file
full_dataset = load_dataset("json", data_files={"train": train_file})["train"]

# Split into train/validation/test using cascading train_test_split
# First split: 80% train, 20% temp (for validation + test)
split_1 = full_dataset.train_test_split(test_size=0.2, seed=SEED)
train_data = split_1["train"]
temp_data = split_1["test"]

# Second split: split the temp data 50/50 into validation and test
split_2 = temp_data.train_test_split(test_size=0.5, seed=SEED)
validation_data = split_2["train"]
test_data = split_2["test"]

# Create DatasetDict with all three splits
raw_dataset = DatasetDict({
    "train": train_data,
    "validation": validation_data,
    "test": test_data
})

print(raw_dataset)
print("\nExample from training set:")
print(raw_dataset["train"][0])

Generating train split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['instruction', 'response'],
        num_rows: 80
    })
    validation: Dataset({
        features: ['instruction', 'response'],
        num_rows: 10
    })
    test: Dataset({
        features: ['instruction', 'response'],
        num_rows: 11
    })
})

Example from training set:
{'instruction': 'Describe the assessment and management of chronic kidney disease.', 'response': 'Stage by eGFR: Stage 1 (‚â•90), 2 (60-89), 3a (45-59), 3b (30-44), 4 (15-29), 5 (<15, ESRD). Slow progression: BP <120 mmHg (ACEi/ARB), glycemic control, avoid NSAIDs, limit protein (~0.6-0.8 g/kg), manage lipids, smoking cessation. Screen for complications: anemia (EPO if target Hgb 10-12), bone disease (check PTH, phosphate, calcium), cardiovascular disease. Nephrology referral if rapid decline, proteinuria, or stage 4-5. Prepare for renal replacement therapy (transplant, hemodialysis, peritoneal dialysis) when approaching ESRD.'}


## ü§ñ Model Setup

**Steps:**
1. Load tokenizer and configure padding
2. Load model in 4-bit quantization
3. Prepare model for k-bit training
4. Enable gradient checkpointing (saves memory)
5. Attach LoRA adapters

**LoRA Configuration:**
- Rank (r): 16
- Alpha: 32
- Target modules: All attention and MLP layers
- Dropout: 5%

In [7]:
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "right"

# Load model with 4-bit quantization
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare for LoRA fine-tuning in 4-bit
model.config.use_cache = False  # Required for gradient checkpointing
model = prepare_model_for_kbit_training(model)
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Attach LoRA adapters
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

trainable params: 8,912,896 || all params: 3,829,992,448 || trainable%: 0.2327


## üî§ Data Preprocessing

**Process:**
1. **Format Examples**: Apply chat template (system, user, assistant roles)
2. **Tokenize**: Convert text to token IDs with padding/truncation
3. **Create Labels**: Clone input_ids for next-token prediction

**Chat Template Example:**
```
<|system|>You are a careful medical assistant...<|end|>
<|user|>What is hypertension?<|end|>
<|assistant|>Hypertension is...<|end|>
```

In [None]:
def format_example(example: Dict[str, str]) -> Dict[str, str]:
    """Apply chat template to each example."""
    messages: List[Dict[str, str]] = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": example.get("instruction", "")},
        {"role": "assistant", "content": example.get("response", "")},
    ]
    example["text"] = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=False,
    )
    return example

# Apply chat template to all examples
formatted = raw_dataset.map(format_example, remove_columns=raw_dataset["train"].column_names)


def tokenize_batch(batch: Dict[str, List[str]]) -> Dict[str, torch.Tensor]:
    """Tokenize a batch of examples."""
    tokenized = tokenizer(
        batch["text"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )
    tokenized["labels"] = tokenized["input_ids"].clone()
    return tokenized

# Tokenize all examples
tokenized = formatted.map(tokenize_batch, batched=True, remove_columns=["text"])
train_dataset = tokenized["train"].with_format("torch")
eval_dataset = tokenized["validation"].with_format("torch")

print(f"Training examples: {len(train_dataset)}")
print(f"Validation examples: {len(eval_dataset)}")
print("\nFirst 120 tokens of training example:")
print(tokenizer.decode(train_dataset[0]["input_ids"][:120]))

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Map:   0%|          | 0/80 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/11 [00:00<?, ? examples/s]

Training examples: 80
Validation examples: 10

First 120 tokens of training example:
<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|end

: 

## üìä Metrics & Training Setup

**Metrics:**
- **Perplexity**: exp(loss) - measures how well the model predicts the next token
  - Lower is better. 
  - Lowest possible score is 1. 
  - A score of 5, for example, means the model is effectively chosing from a list of 5 possiblities for each next word completion.
- Computed using HuggingFace Evaluate library

**Custom Callback:**
- Evaluates **validation set** every 50 steps during training
- Provides real-time feedback on model performance
- **Test set is reserved for final evaluation only**

**Training Arguments:**
- Optimizer: paged_adamw_32bit (memory efficient)
- Mixed precision: bf16 (if available) or fp16
- Gradient accumulation: 8 steps (effective batch size = 8)
- Checkpointing: Save every 50 steps, keep last 2

In [None]:
# Data collator for causal language modeling
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)


def compute_metrics(eval_preds):
    """Compute perplexity metric from model predictions."""
    predictions, labels = eval_preds
    
    # Extract logits if predictions is a tuple
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    
    # Shift for next-token prediction (standard for causal LM)
    shift_logits = predictions[..., :-1, :]
    shift_labels = labels[..., 1:]
    
    # Flatten for loss calculation
    shift_logits = shift_logits.reshape(-1, shift_logits.shape[-1])
    shift_labels = shift_labels.reshape(-1)
    
    # Calculate cross-entropy loss (ignore padding tokens with -100)
    loss_fct = torch.nn.CrossEntropyLoss(ignore_index=-100)
    loss = loss_fct(
        torch.tensor(shift_logits, dtype=torch.float32),
        torch.tensor(shift_labels, dtype=torch.long)
    )
    
    # Perplexity is exp(loss)
    perplexity = torch.exp(loss).item()
    
    return {"perplexity": perplexity}


# No need to prepare test dataset here - it's already in raw_dataset["test"]
# We'll only use it for final evaluation at the end

# Custom callback to evaluate validation set during training
class ValidationEvalCallback(TrainerCallback):
    """Evaluate validation set periodically during training."""
    
    def __init__(self, eval_dataset):
        self.eval_dataset = eval_dataset
        self.trainer_obj = None
    
    def on_step_end(self, args, state, control, **kwargs):
        """Called at the end of each training step."""
        if state.global_step % validate_steps == 0 and state.global_step > 0 and self.trainer_obj is not None:
            eval_results = self.trainer_obj.predict(self.eval_dataset)
            eval_loss = eval_results.metrics.get('test_loss', 'N/A')
            eval_perp = eval_results.metrics.get('test_perplexity', 'N/A')
            print(f"\n[Step {state.global_step}] Validation Loss: {eval_loss:.4f}, "
                  f"Validation Perplexity: {eval_perp:.4f} "
                  "(Lower is better. Lowest possible score is 1.)")


# Initialize training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=8,
    learning_rate=lr,
    max_steps=max_steps,
    warmup_ratio=0.03,
    logging_steps=5,
    save_steps=50,
    save_total_limit=2,
    bf16=torch.cuda.is_available(),
    fp16=not torch.cuda.is_available(),
    gradient_checkpointing=True,
    optim="paged_adamw_32bit",
    report_to=["none"],
)

# Initialize callback with validation dataset
val_callback = ValidationEvalCallback(eval_dataset)

# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[val_callback],
)

# Set trainer reference in callback (needed for validation evaluation)
val_callback.trainer_obj = trainer

print("\nüöÄ Starting training...")
trainer.train()


üöÄ Starting training...


Step,Training Loss
5,2.4158
10,2.0002
15,1.606
20,1.3234
25,1.0364
30,0.9216
35,0.8639
40,0.815
45,0.7807
50,0.759



[Step 50] Validation Loss: 0.8353, Validation Perplexity: 2.3047 (Lower is better)

[Step 100] Validation Loss: 1.0489, Validation Perplexity: 2.8650 (Lower is better)

[Step 150] Validation Loss: 1.5096, Validation Perplexity: 4.5627 (Lower is better)


## üíæ Save Model

Save the trained LoRA adapters and optionally merge with base model.

**Saved Artifacts:**
- LoRA adapter weights (lightweight, ~few MB)
- Tokenizer configuration

**Loading the adapter:** Use `PeftModel.from_pretrained(base_model, adapter_path)`

In [None]:
# Save LoRA adapter weights
adapter_dir = os.path.join(output_dir, "lora_adapter")
model.save_pretrained(adapter_dir)
tokenizer.save_pretrained(adapter_dir)
print(f"‚úÖ Adapter saved to {adapter_dir}")

# Optionally merge and save full model
merge = os.environ.get("MERGE_LORA", "0") == "1"
if merge:
    merged_dir = os.path.join(output_dir, "merged")
    merged_model = model.merge_and_unload()
    merged_model.save_pretrained(merged_dir)
    tokenizer.save_pretrained(merged_dir)
    print(f"‚úÖ Merged model saved to {merged_dir}")

## üß™ Test Set Evaluation

Evaluate the final model performance on the held-out test set.

**Metrics:**
- **Test Loss**: Cross-entropy loss on test examples
- **Test Perplexity**: exp(loss) 
  - Lower is better. 
  - Lowest possible score is 1. 
  - A score of 5, for example, means the model is effectively chosing from a list of 5 possiblities for each next word completion.

A perplexity of ~10-20 is typical for well-trained medical chatbots.

In [None]:
# Evaluate on the held-out test set (first time using it!)
test_formatted = raw_dataset["test"].map(format_example, remove_columns=raw_dataset["test"].column_names)
test_tokenized = test_formatted.map(tokenize_batch, batched=True, remove_columns=["text"])
test_dataset = test_tokenized.with_format("torch")

print(f"Test dataset size: {len(test_dataset)}")

# Evaluate on test set
test_results = trainer.predict(test_dataset)
print("\n" + "="*50)
print("üìä FINAL TEST METRICS")
print("="*50)
print(f"Test Loss: {test_results.metrics.get('test_loss', 'N/A'):.4f}")
print(f"Test Perplexity: {test_results.metrics.get('test_perplexity', 'N/A'):.4f}")
print("\nAll metrics:")
print(test_results.metrics)

## üí¨ Inference Demo

Test the fine-tuned model with a medical query.

**Process:**
1. Load base model in fp16
2. Attach trained LoRA adapters
3. Format query with chat template
4. Generate response with sampling disabled (deterministic)

**Try modifying the question to test different medical scenarios!**

In [None]:
# Load base model and attach LoRA adapters
adapter_dir = os.path.join(output_dir, "lora_adapter")
inference_model = AutoModelForCausalLM.from_pretrained(
    base_model,
    device_map="auto",
    torch_dtype=torch.float16,
)
inference_model = PeftModel.from_pretrained(inference_model, adapter_dir)
inference_model.eval()

# Prepare medical query
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": "Give two differential considerations for chest pain."},
]

# Tokenize with chat template
inputs = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
)
inputs = inputs.to(inference_model.device)

# Generate response
print("ü§ñ Generating response...\n")
gen = inference_model.generate(
    inputs,
    max_new_tokens=200,
    temperature=0.3,
    do_sample=False,
)

# Decode and print response
response = tokenizer.decode(gen[0][inputs.shape[-1]:], skip_special_tokens=True)
print("Response:")
print("-" * 50)
print(response)