# Gemma2: Fine-Tuning with LoRA on ChatDoctor-HealthCareMagic

This notebook demonstrates how to fine-tune Google's Gemma2-2B model using Parameter-Efficient Fine-Tuning (PEFT) with LoRA (Low-Rank Adaptation) on the [ChatDoctor-HealthCareMagic medical Q&A dataset](https://huggingface.co/datasets/lavita/ChatDoctor-HealthCareMagic-100k).

## To run:
This notebook was run on Kaggle with two NVIDIA T4 GPUs (the "GPU T4 x2" accelerator option). Kaggle offers T4 GPUs for free once you have verified your phone number.

## 1. Setup Libraries

Install the required packages and configure Keras to use the JAX backend, with JAX allowed to claim all available GPU memory.

In [1]:
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"
!pip install -q datasets
!pip install -q pandas

In [2]:
import os
os.environ["KERAS_BACKEND"] = "jax"  # Use JAX backend for Keras (optimized for GPU training)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.00"  # Allocate 100% of GPU memory to JAX

In [3]:
import keras
import keras_nlp
from datasets import load_dataset

## 2. Dataset Preparation
Load the ChatDoctor-HealthCareMagic dataset, take a shuffled sample, and split it into training and test sets. Format each example as an instruction-response prompt, and set aside a few evaluation samples to compare model outputs before and after fine-tuning.

In [4]:
NUM_TRAINING_INSTANCES = 1000  # Total number of examples to sample (training + test)
RANDOM_SEED = 42
TRAIN_SPLIT = 0.9
NUM_EVAL_SAMPLES = min(5, int(NUM_TRAINING_INSTANCES * (1 - TRAIN_SPLIT)))  
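
The evaluation-sample count above depends on floating-point arithmetic: `1 - 0.9` is not exactly `0.1` in binary floating point, so `int(1000 * (1 - 0.9))` evaluates to 99 rather than 100. The `min(5, ...)` cap hides this here, but deriving both sizes from a single integer split index avoids the issue. A minimal sketch; the helper name `split_sizes` is hypothetical and not part of the notebook.

```python
def split_sizes(total, train_fraction):
    """Return (n_train, n_test) derived from one integer split index."""
    n_train = int(total * train_fraction)
    return n_train, total - n_train

# The float subtraction lands just below 0.1, so truncation loses one example.
naive_test_size = int(1000 * (1 - 0.9))

n_train, n_test = split_sizes(1000, 0.9)
num_eval_samples = min(5, n_test)
print(naive_test_size, n_train, n_test, num_eval_samples)
```

Computing the test size as `total - n_train` guarantees the two parts always add up to the total, whatever the split fraction.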

In [5]:
dataset = load_dataset("lavita/ChatDoctor-HealthCareMagic-100k", split="train")
dataset = dataset.shuffle(seed=RANDOM_SEED)

In [6]:
data = []
for example in dataset:
    text = f"Instruction:\n{example['input']}\n\nResponse:\n{example['output']}"
    data.append(text)
    if len(data) >= NUM_TRAINING_INSTANCES:
        break

# Split 90% train, 10% test
split_idx = int(len(data) * TRAIN_SPLIT)
train_data = data[:split_idx]
test_data = data[split_idx:]

print(f"Train: {len(train_data)} examples")
print(f"Test: {len(test_data)} examples")
print("== Sample == \n", train_data[0])


Train: 900 examples
Test: 100 examples
== Sample == 
 Instruction:
I have been having alot of catching ,pain and discomfort under my right rib.  If I twist to either side especially my right it feels like my rib actually catches on something and at times I have to stop try to catch my breath and wait for it to subside.  There are times if I am laughing too hard that it will do the same thing but normally its more so if I have twisted or moved  a certain way

Response:
Hi thanks for asking question. Here you are complaining pain in particular position esp. While turning to a side. So strong possibility is about moderate degree muscular strain. It might have occurred by heavyweight lift or during some activities. Simple analgesic taken. Take rest. Sleep in supine position. Second here Costco Chat Doctor.  Ribs are tender to touch.x-ray also useful. If cough, cold, sore throat present then respiratory infections also has to be ruled out. Treat it symptomatically. If still seems serious th
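
The formatting loop above can also be written with `itertools.islice`, which stops after a fixed number of records without an explicit `break`. This is a minimal sketch, assuming each record is a dict with `input` and `output` keys as in the dataset. The names `format_example` and `build_examples` are illustrative, and the toy records are invented for the example.

```python
from itertools import islice

def format_example(record):
    """Render one record in the notebook's Instruction/Response template."""
    return f"Instruction:\n{record['input']}\n\nResponse:\n{record['output']}"

def build_examples(records, limit):
    """Format at most `limit` records from any iterable."""
    return [format_example(r) for r in islice(records, limit)]

# Toy stand-in for the shuffled dataset (not real data).
toy_records = [
    {"input": f"question {i}", "output": f"answer {i}"} for i in range(10)
]
examples = build_examples(toy_records, limit=4)
print(len(examples))
print(examples[0])
```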

In [7]:
import random

random.seed(RANDOM_SEED)

eval_samples = random.sample(test_data, NUM_EVAL_SAMPLES)

# Store evaluation data structure
eval_data = []
for i, test_sample in enumerate(eval_samples):
    # Parse the test sample to extract question and answer
    parts = test_sample.split("Instruction:\n")[1].split("\n\nResponse:\n")
    question = parts[0].strip()
    ground_truth_answer = parts[1].strip()
    
    eval_data.append({
        'question': question,
        'ground_truth': ground_truth_answer,
        'before_finetuning': None,  # Will be filled in next section
        'after_finetuning': None   # Will be filled after training
    })

print(f"Prepared {len(eval_data)} evaluation samples for comparison dataframe")
print(f"First evaluation question: {eval_data[0]['question'][:100]}...")

Prepared 5 evaluation samples for comparison dataframe
First evaluation question: i am female 50 yrs old i have urine infection and stabbing bladder pain for more then week now, and ...
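
The parsing step above uses `str.split`, which cuts at every occurrence of a marker, so an answer that happens to contain the text `\n\nResponse:\n` would be truncated at `parts[1]`. A hedged alternative is `str.partition`, which splits at the first occurrence only. The function name `parse_example` and the sample string are hypothetical.

```python
INSTRUCTION_PREFIX = "Instruction:\n"
RESPONSE_MARKER = "\n\nResponse:\n"

def parse_example(text):
    """Split a formatted example into (question, answer)."""
    if not text.startswith(INSTRUCTION_PREFIX):
        raise ValueError("example does not start with the instruction prefix")
    body = text[len(INSTRUCTION_PREFIX):]
    # partition splits at the first marker only, so marker-like text
    # later in the answer stays part of the answer.
    question, marker, answer = body.partition(RESPONSE_MARKER)
    if not marker:
        raise ValueError("example has no response marker")
    return question.strip(), answer.strip()

sample = "Instruction:\nI have a cough.\n\nResponse:\nRest.\n\nResponse:\nDrink fluids."
question, answer = parse_example(sample)
print(repr(question))
print(repr(answer))
```

Keeping the question and answer as separate fields from the start, rather than re-parsing the formatted string, would sidestep the ambiguity entirely.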


## 3. Load Model & Baseline Evaluation
Load the Gemma 2 2B model and generate baseline responses on the evaluation samples before any fine-tuning. This establishes the "before" performance for comparison.

In [8]:
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.summary()

In [9]:
import time

print("Generating responses on Gemma2-2b without fine-tuning...")
start_time = time.time()

for i, eval_item in enumerate(eval_data):
    sample_start = time.time()
    # Use raw question for original model (no special formatting)
    response = gemma_lm.generate(eval_item['question'], max_length=512)
    eval_data[i]['before_finetuning'] = response.strip()
    sample_time = time.time() - sample_start
    print(f"Processed sample {i+1}/{len(eval_data)} in {sample_time:.2f}s")

total_time = time.time() - start_time
avg_time = total_time / len(eval_data)

print(f"\n✅ Captured {len(eval_data)} before fine-tuning responses")
print(f"⏱️  Total inference time: {total_time:.2f}s")
print(f"⏱️  Average time per sample: {avg_time:.2f}s")

Generating responses on Gemma2-2b without fine-tuning...
Processed sample 1/5 in 14.91s
Processed sample 2/5 in 21.91s
Processed sample 3/5 in 26.62s
Processed sample 4/5 in 23.86s
Processed sample 5/5 in 26.42s

✅ Captured 5 before fine-tuning responses
⏱️  Total inference time: 113.72s
⏱️  Average time per sample: 22.74s


## 4. PEFT using LoRA

Enable Low-Rank Adaptation (LoRA) on the model backbone with rank 4, configure the optimizer, and fine-tune on the training data. LoRA makes fine-tuning efficient by freezing the pretrained weights and training only small low-rank matrices that are added alongside them.
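
The idea behind LoRA can be sketched without any framework: a pretrained weight matrix `W` stays frozen, and a trainable low-rank product `A @ B` is added to it, so the layer output becomes `x @ W + (x @ A) @ B`. The sketch below is illustrative only. The dimensions and values are made up and do not correspond to Gemma's actual layer shapes, the scaling factor that LoRA implementations usually apply is omitted, and this is not how KerasNLP implements `enable_lora`.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def lora_forward(x, w, a, b):
    """Compute x @ W + (x @ A) @ B, where only A and B would be trained."""
    base = matmul(x, w)
    update = matmul(matmul(x, a), b)
    return [[p + q for p, q in zip(r1, r2)] for r1, r2 in zip(base, update)]

def lora_param_count(d_in, d_out, rank):
    """Trainable parameters in A (d_in x rank) plus B (rank x d_out)."""
    return rank * (d_in + d_out)

# Toy layer: 3 inputs, 2 outputs, rank 1 (illustrative values).
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # frozen pretrained weights
a = [[0.5], [0.5], [0.5]]                  # trainable, d_in x rank
b_zero = [[0.0, 0.0]]                      # zero-initialized, rank x d_out
x = [[1.0, 2.0, 3.0]]

# With B at zero, the adapted layer reproduces the frozen layer exactly.
print(lora_forward(x, w, a, b_zero), matmul(x, w))
# Hypothetical 2048 x 2048 layer: rank-4 adapter size vs. full matrix size.
print(lora_param_count(2048, 2048, rank=4), "vs", 2048 * 2048)
```

Because `B` starts at zero, training begins from the pretrained model's behavior, and the adapter for the hypothetical layer holds well under one percent of the parameters of the full matrix.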

In [10]:
gemma_lm.backbone.enable_lora(rank=4)
gemma_lm.summary()

In [11]:
optimizer = keras.optimizers.AdamW(
    learning_rate=1e-4,
    weight_decay=0.01,
)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.preprocessor.sequence_length = 256  # Truncate inputs to 256 tokens to limit memory use
gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)
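
`SparseCategoricalCrossentropy(from_logits=True)` takes integer target ids and unnormalized scores (logits). For a single token position, the loss is the negative log-softmax of the logit at the target index. A minimal pure-Python sketch of that per-position computation follows. The function name is hypothetical, and Keras additionally reduces the loss across positions and batch items.

```python
import math

def sparse_crossentropy_from_logits(logits, target_index):
    """Negative log-softmax probability of the target class."""
    peak = max(logits)  # subtract the max for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target_index]

# Uniform logits over 4 classes: the loss equals log(4).
uniform_loss = sparse_crossentropy_from_logits([0.0, 0.0, 0.0, 0.0], target_index=2)
# A large logit on the correct class drives the loss toward zero.
confident_loss = sparse_crossentropy_from_logits([0.0, 10.0, 0.0, 0.0], target_index=1)
print(uniform_loss, confident_loss)
```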

In [None]:
# Fine-tune the model
history = gemma_lm.fit(train_data, epochs=1, batch_size=1)

 55/900 ━━━━━━━━━━━━━━━━━━━━ 17:40 1s/step - loss: 2.3238 - sparse_categorical_accuracy: 0.4242

In [None]:
# Print training metrics
print("Training History:")
print(f"Available metrics: {list(history.history.keys())}")
print("\nTraining Progress:")
for epoch in range(len(history.history['loss'])):
    print(f"Epoch {epoch + 1}:")
    for metric, values in history.history.items():
        print(f"  {metric}: {values[epoch]:.4f}")
    print()

## 5. Inference Pipeline & Evaluation

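Recompile the fine-tuned model with a Top-K sampler, generate responses to the same evaluation questions, and build a comparison DataFrame of the responses before and after fine-tuning. The baseline run used the raw question with `max_length=512`, whereas this run uses the training prompt format with `max_length=200`, so the two sets of responses differ in prompt and length settings as well as in weights.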

In [None]:
# Recompile the model so that generation uses the Top-K sampler
strategy = keras_nlp.samplers.TopKSampler(k=10, temperature=0.7, seed=42)  
gemma_lm.compile(sampler=strategy)
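
Top-K sampling restricts each generation step to the `k` highest-scoring tokens and then samples among them, while the temperature controls how sharply the distribution favors the top candidates. The sketch below shows the general technique in pure Python for a single step. It is not KerasNLP's `TopKSampler` implementation, and the function name and logits are made up for illustration.

```python
import math
import random

def top_k_sample(logits, k, temperature=1.0, rng=random):
    """Sample one index from the k highest-scoring logits."""
    # Keep the indices of the k largest logits.
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    # Temperature < 1 sharpens the distribution; > 1 flattens it.
    scaled = [logits[i] / temperature for i in top]
    peak = max(scaled)
    weights = [math.exp(s - peak) for s in scaled]  # unnormalized softmax
    return rng.choices(top, weights=weights, k=1)[0]

rng = random.Random(42)
logits = [2.0, 0.5, 3.0, -1.0, 1.0]
samples = [top_k_sample(logits, k=2, temperature=0.7, rng=rng) for _ in range(100)]
# Only the two highest-scoring indices (2 and 0) can ever be drawn.
print(sorted(set(samples)))
```

With `k=1` the procedure reduces to greedy decoding, since only the single best token remains eligible.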

In [None]:
def format_prompt(user_question):
    return f"Instruction:\n{user_question}\n\nResponse:\n"

In [None]:
print("Generating after fine-tuning responses...")
for i, eval_item in enumerate(eval_data):
    formatted_prompt = format_prompt(eval_item['question'])
    response = gemma_lm.generate(formatted_prompt, max_length=200)
    clean_response = response.replace(formatted_prompt, "").strip()
    eval_data[i]['after_finetuning'] = clean_response

print(f"Captured {len(eval_data)} after fine-tuning responses")
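
The cleanup step above relies on `str.replace`, which removes every occurrence of the prompt text wherever it appears in the output. A slightly stricter sketch strips the prompt only when it is a leading prefix. The helper name `strip_prompt` and the example strings are hypothetical.

```python
def strip_prompt(generated, prompt):
    """Remove `prompt` only if `generated` starts with it."""
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()

prompt = "Instruction:\nWhat helps a sore throat?\n\nResponse:\n"

# Typical case: the model echoes the prompt before its answer.
echoed = prompt + "Warm fluids and rest."
# Edge case: the output does not begin with the prompt, so it is left intact.
not_echoed = "Warm fluids and rest."

print(repr(strip_prompt(echoed, prompt)))
print(repr(strip_prompt(not_echoed, prompt)))
```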

In [None]:
import pandas as pd

# Convert evaluation data directly to DataFrame and save
comparison_df = pd.DataFrame(eval_data)

print(f"✅ DataFrame created with {len(comparison_df)} samples")
print(f"Columns: {list(comparison_df.columns)}")

# Display sample results
print("\n== SAMPLE COMPARISON ==")
for i in range(min(2, len(comparison_df))):
    print(f"Question: {comparison_df.iloc[i]['question']}")
    print(f"Ground Truth: {comparison_df.iloc[i]['ground_truth']}")
    print(f"Before Fine-tuning: {comparison_df.iloc[i]['before_finetuning']}")
    print(f"After Fine-tuning: {comparison_df.iloc[i]['after_finetuning']}")
    print("-------\n")

# Save to CSV
comparison_df.to_csv('gemma2_2b_comparison.csv', index=False)
print("💾 Results saved to 'gemma2_2b_comparison.csv'")