# Experiment: MedGemma LoRA Fine-tuning for Maternal Health Risk Assessment

**Objective:**
- Fine-tune Google's MedGemma-4B-IT model using LoRA for maternal health risk classification
- Train on structured clinical dialogue data (vitals → risk assessment)
- Evaluate model's ability to generate clinically appropriate risk assessments

**Success Criteria:**
- Training loss decreases steadily without overfitting
- Model generates responses with correct RISK LEVEL prefix (LOW/MID/HIGH)
- Outputs contain required sections: Clinical Reasoning, Complications, Actions, Warning Signs
- Validation perplexity (exp of the eval loss) < 5.0 after 3 epochs

## 1. Setup: Imports and Reproducibility

In [None]:
# Core imports for MedGemma fine-tuning
from __future__ import annotations

import json
import os
import random
import re
from pathlib import Path
from typing import Any

import numpy as np
import torch

# HuggingFace ecosystem
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    EarlyStoppingCallback,
)
from peft import LoraConfig, get_peft_model, TaskType, PeftModel
from datasets import Dataset, load_from_disk

# Reproducibility: seed every RNG this notebook touches, in order:
# Python's `random`, NumPy, torch (CPU), then torch on all CUDA devices.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Device setup: report which accelerator (if any) PyTorch can see.
_cuda_ok = torch.cuda.is_available()
device = "cuda" if _cuda_ok else "cpu"
print(f"Device: {device}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 2. Configuration

Define model, LoRA, and training hyperparameters.

In [None]:
# Model and data configuration
MODEL_NAME = "google/medgemma-4b-it"  # 4B instruction-tuned MedGemma
DATA_PATH = "./mamaguard_train.jsonl"  # Training data from prepare_training_data.py
EVAL_PATH = "./mamaguard_eval.jsonl"   # Evaluation data
OUTPUT_DIR = "./medgemma-lora-maternal-health"

# LoRA hyperparameters
LORA_CONFIG = {
    "r": 16,              # LoRA rank (low = faster, high = more expressive)
    "lora_alpha": 32,     # Scaling factor (typically 2*r)
    "lora_dropout": 0.05, # Dropout for regularization
    "target_modules": [   # Which layers to apply LoRA
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    "bias": "none",
    "task_type": TaskType.CAUSAL_LM,
}

# Training hyperparameters
TRAINING_CONFIG = {
    "num_train_epochs": 3,
    "per_device_train_batch_size": 1,
    "per_device_eval_batch_size": 1,
    "gradient_accumulation_steps": 4,    # effective batch = 1 x 4 per device
    "learning_rate": 2e-4,
    "warmup_steps": 100,
    "weight_decay": 0.01,
    "logging_steps": 10,
    "save_steps": 200,                   # kept equal to eval_steps so every eval has a checkpoint
    "eval_steps": 200,
    "save_total_limit": 3,
    "load_best_model_at_end": True,
    "metric_for_best_model": "eval_loss",
    "greater_is_better": False,          # lower eval_loss is better
    "fp16": torch.cuda.is_available(),  # Mixed precision training
    # paged_adamw_8bit is a bitsandbytes paged optimizer intended for CUDA devices;
    # fall back to plain PyTorch AdamW on CPU-only machines so Trainer setup does not fail there.
    "optim": "paged_adamw_8bit" if torch.cuda.is_available() else "adamw_torch",
}

# Tokenization config
MAX_SEQ_LENGTH = 2048  # Max sequence length for training (longer samples are truncated)

print("Configuration loaded:")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA rank: {LORA_CONFIG['r']}, alpha: {LORA_CONFIG['lora_alpha']}")
print(f"  Learning rate: {TRAINING_CONFIG['learning_rate']}")
print(f"  Batch size: {TRAINING_CONFIG['per_device_train_batch_size']} x {TRAINING_CONFIG['gradient_accumulation_steps']} grad accum")

## 3. Load and Prepare Dataset

In [None]:
def load_jsonl(path: str) -> list[dict[str, Any]]:
    """Load a JSON Lines file into a list of dictionaries.

    Blank (whitespace-only) lines are skipped. A trailing empty line or a blank
    separator line therefore no longer aborts loading with a decode error.

    Args:
        path: Path to a UTF-8 encoded ``.jsonl`` file with one JSON object per line.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON. The message
            is prefixed with ``"<path>:<line number>:"`` (1-based) to locate it.
    """
    records: list[dict[str, Any]] = []
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # tolerate blank lines instead of crashing in json.loads("")
            try:
                records.append(json.loads(stripped))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, with the file and line attached.
                raise json.JSONDecodeError(f"{path}:{lineno}: {err.msg}", err.doc, err.pos) from err
    return records

# Load training and evaluation data
train_data = load_jsonl(DATA_PATH)
eval_data = load_jsonl(EVAL_PATH)

# Fail fast with a clear message. Later cells index train_data[0], and the
# Trainer is configured to evaluate (and early-stop) on eval_data.
if not train_data:
    raise ValueError(f"No training samples found in {DATA_PATH}")
if not eval_data:
    raise ValueError(f"No evaluation samples found in {EVAL_PATH}")

print(f"Loaded {len(train_data)} training samples")
print(f"Loaded {len(eval_data)} evaluation samples")

# Display a sample
print("\nSample training entry:")
print(train_data[0]["text"][:800] + "...")

## 4. Load Tokenizer and Model

In [None]:
# Load the tokenizer that ships with the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# Reuse EOS as the padding token only when the checkpoint defines no pad token.
# NOTE(review): DataCollatorForLanguageModeling(mlm=False) excludes every token equal
# to pad_token_id from the loss. If this fallback fires, EOS is never trained on and the
# model may not learn to stop. Confirm the MedGemma tokenizer ships its own pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token, tokenizer.pad_token_id = tokenizer.eos_token, tokenizer.eos_token_id

for _label, _value in (
    ("Tokenizer vocab size", len(tokenizer)),
    ("Pad token", tokenizer.pad_token),
    ("EOS token", tokenizer.eos_token),
):
    print(f"{_label}: {_value}")

In [None]:
# Choose the load dtype. The MedGemma model card loads the model in bfloat16.
# bfloat16 keeps float32's exponent range, while float16 tops out near 65504,
# and Gemma-family activations have been reported to overflow in float16.
# So: use bf16 on GPUs that support it, fall back to fp16 on older GPUs,
# and use float32 on CPU.
if torch.cuda.is_available():
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    model_dtype = torch.float32

# Load model with memory-efficient settings
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=model_dtype,
    device_map="auto" if torch.cuda.is_available() else None,  # shard across visible GPUs
    trust_remote_code=True,
)

# Print model info
print(f"Model loaded: {MODEL_NAME}")
print(f"Load dtype: {model_dtype}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.1f}M")

## 5. Apply LoRA Configuration

In [None]:
# Build the PEFT LoRA configuration from the LORA_CONFIG hyperparameter dict.
# NOTE(review): PEFT matches target_modules by module-name suffix. If the loaded
# checkpoint also contains a vision encoder, its attention projections may match the
# same names. Check the trainable-parameter count printed below.
_LORA_KWARGS = ("r", "lora_alpha", "target_modules", "lora_dropout", "bias", "task_type")
lora_config = LoraConfig(**{name: LORA_CONFIG[name] for name in _LORA_KWARGS})

# Wrap the base model: its original weights stay frozen and only the injected
# low-rank adapter matrices are trainable.
model = get_peft_model(model, lora_config)

# Report trainable vs. total parameter counts.
model.print_trainable_parameters()

## 6. Tokenize Dataset

In [None]:
def tokenize_function(examples: dict[str, list]) -> dict[str, list]:
    """Tokenize a batch of raw ``text`` strings for causal-LM fine-tuning.

    Intended for ``Dataset.map(..., batched=True)``. ``examples["text"]`` is a list
    of strings, and each returned field holds one entry per string.

    Sequences are truncated to ``MAX_SEQ_LENGTH`` but deliberately not padded. The
    notebook's ``DataCollatorForLanguageModeling`` pads each batch to its longest
    member. Padding every row to ``MAX_SEQ_LENGTH`` made every training and eval step
    process 2048 positions even for short samples (pure overhead at batch size 1).

    No ``labels`` field is produced. With ``mlm=False`` that collator derives labels
    from ``input_ids`` and masks pad positions with -100, so labels built here would be
    overwritten. With dynamic padding they would also be ragged and fail to batch.

    Args:
        examples: Batch dict with a ``"text"`` list of strings.

    Returns:
        Tokenizer output with ``input_ids`` and ``attention_mask`` lists (one per text).
    """
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
        return_tensors=None,  # plain Python lists, as Dataset.map expects
    )

def _tokenized_text_dataset(records: list[dict[str, Any]]) -> Dataset:
    """Build a Dataset from each record's "text" field, tokenize it in batches, and drop the raw text."""
    text_only = Dataset.from_list([{"text": rec["text"]} for rec in records])
    return text_only.map(tokenize_function, batched=True, remove_columns=["text"])


train_dataset = _tokenized_text_dataset(train_data)
eval_dataset = _tokenized_text_dataset(eval_data)

for _split_name, _ds in (("Train", train_dataset), ("Eval", eval_dataset)):
    print(f"{_split_name} dataset size: {len(_ds)}")
print(f"Example tokenized length: {len(train_dataset[0]['input_ids'])}")

## 7. Setup Training Arguments

In [None]:
# Create output directory
Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Resolve mixed precision. Prefer bf16 on GPUs that support it: it has float32's
# exponent range, needs no loss scaling, and avoids the fp16 overflow issues reported
# for Gemma-family models. Keep fp16 only as the fallback the config asked for.
# TrainingArguments rejects fp16 and bf16 both being True. The resolved flags are
# written back into TRAINING_CONFIG so the saved training_config.json reflects the
# precision actually used.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
TRAINING_CONFIG["bf16"] = use_bf16
TRAINING_CONFIG["fp16"] = bool(TRAINING_CONFIG["fp16"]) and not use_bf16

# Training arguments
training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=TRAINING_CONFIG["num_train_epochs"],
    per_device_train_batch_size=TRAINING_CONFIG["per_device_train_batch_size"],
    per_device_eval_batch_size=TRAINING_CONFIG["per_device_eval_batch_size"],
    gradient_accumulation_steps=TRAINING_CONFIG["gradient_accumulation_steps"],
    learning_rate=TRAINING_CONFIG["learning_rate"],
    warmup_steps=TRAINING_CONFIG["warmup_steps"],
    weight_decay=TRAINING_CONFIG["weight_decay"],
    logging_steps=TRAINING_CONFIG["logging_steps"],
    save_steps=TRAINING_CONFIG["save_steps"],
    eval_steps=TRAINING_CONFIG["eval_steps"],
    save_total_limit=TRAINING_CONFIG["save_total_limit"],
    load_best_model_at_end=TRAINING_CONFIG["load_best_model_at_end"],
    metric_for_best_model=TRAINING_CONFIG["metric_for_best_model"],
    greater_is_better=TRAINING_CONFIG["greater_is_better"],
    eval_strategy="steps",
    save_strategy="steps",
    logging_strategy="steps",
    fp16=TRAINING_CONFIG["fp16"],
    bf16=TRAINING_CONFIG["bf16"],
    optim=TRAINING_CONFIG["optim"],
    report_to="none",  # Disable wandb/tensorboard for local runs
    seed=SEED,
)

print(f"Training output dir: {OUTPUT_DIR}")
print(f"Total training steps: ~{len(train_dataset) * TRAINING_CONFIG['num_train_epochs'] // (TRAINING_CONFIG['per_device_train_batch_size'] * TRAINING_CONFIG['gradient_accumulation_steps'])}")

## 8. Initialize Trainer and Train

In [None]:
# Causal-LM collator (mlm=False, i.e. not masked LM): pads each batch and builds the
# labels from the input ids.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

# Stop training once the tracked metric (eval_loss per the config) has failed to
# improve for 3 consecutive evaluations.
early_stopping = EarlyStoppingCallback(early_stopping_patience=3)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=data_collator,
    callbacks=[early_stopping],
)

print("Trainer initialized. Starting training...")

In [None]:
# Run fine-tuning. The returned TrainOutput carries the training loss (averaged over
# the run's steps, not just the last step) and a metrics dict with the runtime.
train_result = trainer.train()

runtime_minutes = train_result.metrics.get("train_runtime", 0) / 60
print("\nTraining completed!")
print(f"Final train loss: {train_result.training_loss:.4f}")
print(f"Training runtime: {runtime_minutes:.1f} minutes")

## 9. Save Model and Tokenizer

In [None]:
output_path = Path(OUTPUT_DIR)

# Persist the LoRA adapter (the PEFT-wrapped model saves adapter weights/config) and the tokenizer.
model.save_pretrained(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)

# Record the run's hyperparameters next to the adapter for reproducibility.
config_save = dict(
    model_name=MODEL_NAME,
    lora_config=LORA_CONFIG,
    training_config=TRAINING_CONFIG,
    seed=SEED,
    max_seq_length=MAX_SEQ_LENGTH,
)
with (output_path / "training_config.json").open("w", encoding="utf-8") as fh:
    json.dump(config_save, fh, indent=2)

print(f"Model saved to: {OUTPUT_DIR}")
print(f"Files in output dir: {list(output_path.glob('*'))}")

## 10. Evaluation: Generate Sample Responses

In [None]:
def generate_response(
    model,
    tokenizer,
    prompt: str,
    max_new_tokens: int = 512,
    *,
    do_sample: bool = True,
    temperature: float = 0.7,
    top_p: float = 0.9,
) -> str:
    """Generate a completion for ``prompt`` and return only the newly generated text.

    The model is put in eval mode for the duration of the call, so dropout (including
    the LoRA adapter's ``lora_dropout``) is disabled while generating. Its previous
    train/eval mode is restored afterwards. This matters here because the model may
    still be in training mode right after ``trainer.train()``.

    Args:
        model: A causal LM exposing ``generate``, ``device`` and ``training``
            (e.g. the PEFT-wrapped model).
        tokenizer: The tokenizer matching ``model``.
        prompt: Fully formatted prompt text, including the model-turn header.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Sample (default, matches previous behaviour) or decode greedily
            (``False``), which is deterministic and better suited to scoring.
        temperature: Sampling temperature; ignored when ``do_sample`` is False.
        top_p: Nucleus-sampling threshold; ignored when ``do_sample`` is False.

    Returns:
        The decoded continuation with special tokens removed and surrounding
        whitespace stripped.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Only pass sampling knobs when sampling; greedy decoding does not use them.
    sampling_kwargs: dict[str, Any] = {"do_sample": do_sample}
    if do_sample:
        sampling_kwargs.update(temperature=temperature, top_p=top_p)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id,
                **sampling_kwargs,
            )
    finally:
        if was_training:
            model.train()

    # The returned sequence starts with the prompt tokens; keep only the continuation.
    input_length = inputs["input_ids"].shape[1]
    response_tokens = outputs[0][input_length:]
    return tokenizer.decode(response_tokens, skip_special_tokens=True).strip()

# Smoke test: rebuild the first training example's prompt (everything up to the first
# <end_of_turn>, re-appended) and add the model-turn header so the model answers it.
# NOTE: this example was seen during training, so this checks output format, not generalization.
_first_text = train_data[0]["text"]
_user_turn = _first_text.split("<end_of_turn>")[0]
sample_prompt = f"{_user_turn}<end_of_turn>\n<start_of_turn>model\n"
print("Test prompt:")
print(f"{sample_prompt[:500]}...\n")

# Generate response
response = generate_response(model, tokenizer, sample_prompt)
print("Generated response:")
print(response)

## 11. Validate Output Format

In [None]:
# Matches a response that opens with "RISK LEVEL: <LOW|MID|HIGH>" (case-insensitive).
# The trailing \b rejects words that merely start with a label, e.g. "HIGHER".
_RISK_LEVEL_LINE = re.compile(r"RISK LEVEL:\s*(LOW|MID|HIGH)\b", re.IGNORECASE)


def validate_response_format(response: str) -> dict[str, Any]:
    """Check whether a generated response follows the expected report format.

    All checks are case-insensitive, and leading whitespace is ignored.
    ``risk_label_valid`` requires the label declared right after the leading
    "RISK LEVEL:" to be exactly LOW, MID or HIGH. A bare substring search would
    accept words such as "FOLLOW", "BELOW", "HIGHER" or "MIDWIFE" anywhere in the
    text, which made the check pass for almost any response.

    Args:
        response: Generated model text.

    Returns:
        Mapping of check name to bool, plus ``"all_pass"``, which is True only
        when every individual check passed.
    """
    text = response.lstrip().upper()
    checks = {
        "has_risk_level": text.startswith("RISK LEVEL:"),
        "has_clinical_reasoning": "CLINICAL REASONING" in text or "CLINICAL ASSESSMENT" in text,
        "has_complications": "COMPLICATION" in text,
        "has_actions": "ACTION" in text or "RECOMMENDED" in text,
        "has_warning_signs": "WARNING" in text,
        "risk_label_valid": _RISK_LEVEL_LINE.match(text) is not None,
    }
    checks["all_pass"] = all(checks.values())
    return checks

# Run the format checks on the smoke-test response and print a ✓/✗ line per check.
validation = validate_response_format(response)
print("Validation results:")
for check_name, ok in validation.items():
    print(f"  {'✓' if ok else '✗'} {check_name}")

## 12. Test on Multiple Samples

In [None]:
# Test on up to two eval samples from each risk level.
_RISK_LABELS = ("LOW", "MID", "HIGH")
# First explicit "RISK LEVEL: <label>" declaration in a generated response. The \b
# keeps words such as "HIGHER" from counting. Parsing the declaration avoids substring
# hits like "BELOW"/"FOLLOW" being read as LOW.
_RISK_DECLARATION = re.compile(r"RISK LEVEL:\s*(LOW|MID|HIGH)\b", re.IGNORECASE)


def _normalize_risk(label: object) -> str:
    """Return the upper-cased first word of a risk label ("high risk" -> "HIGH"); "" if missing."""
    words = str(label or "").strip().upper().split()
    return words[0] if words else ""


test_samples = []
for risk in _RISK_LABELS:
    samples = [d for d in eval_data if _normalize_risk(d.get("risk_level")) == risk][:2]
    test_samples.extend(samples)

if not test_samples:
    print("WARNING: no eval samples have a LOW/MID/HIGH 'risk_level' field; nothing to test.")
print(f"Testing on {len(test_samples)} samples...\n")

results = []
for i, sample in enumerate(test_samples, start=1):
    # Rebuild the prompt: the first turn (through its <end_of_turn>) plus the model-turn header.
    prompt = sample["text"].split("<end_of_turn>")[0] + "<end_of_turn>\n<start_of_turn>model\n"
    expected_risk = _normalize_risk(sample.get("risk_level")) or "UNKNOWN"

    # Generate
    generated = generate_response(model, tokenizer, prompt, max_new_tokens=400)
    validation = validate_response_format(generated)

    # Predicted label = the one declared after "RISK LEVEL:", not whichever label
    # word first appears as a substring near the start of the text.
    declared = _RISK_DECLARATION.search(generated)
    predicted_risk = declared.group(1).upper() if declared else "UNKNOWN"

    results.append({
        "expected": expected_risk,
        "predicted": predicted_risk,
        "format_valid": validation["all_pass"],
        "response_preview": generated[:200] + "..." if len(generated) > 200 else generated,
    })

    print(f"Sample {i} | Expected: {expected_risk} | Predicted: {predicted_risk}")
    print(f"  Format valid: {validation['all_pass']}")
    print(f"  Preview: {results[-1]['response_preview']}\n")

## Results Summary

In [None]:
# Compile results
n_tested = len(results)
correct_predictions = sum(1 for r in results if r["expected"] == r["predicted"])
valid_formats = sum(1 for r in results if r["format_valid"])

# Validation perplexity for the stated success criterion, computed as exp(eval_loss).
# metric_for_best_model is "eval_loss", so trainer.state.best_metric holds the best
# checkpoint's eval loss. It stays None if no evaluation ran (or no trainer exists).
_trainer = globals().get("trainer")
best_eval_loss = getattr(getattr(_trainer, "state", None), "best_metric", None)
eval_perplexity = float(np.exp(best_eval_loss)) if best_eval_loss is not None else None

summary = {
    "total_samples_tested": n_tested,
    "correct_risk_predictions": correct_predictions,
    # 0.0 (a float) when nothing was tested, so the rates always format as percentages.
    "prediction_accuracy": correct_predictions / n_tested if n_tested else 0.0,
    "valid_format_count": valid_formats,
    "format_compliance_rate": valid_formats / n_tested if n_tested else 0.0,
    "training_loss": train_result.training_loss if "train_result" in globals() else None,
    "best_eval_loss": best_eval_loss,
    "eval_perplexity": eval_perplexity,
    "model_output_dir": OUTPUT_DIR,
}

# Only the two rates are fractions. Other floats (losses, perplexity) must not be
# rendered as percentages, which previously turned a loss of 1.23 into "123.00%".
_RATE_KEYS = {"prediction_accuracy", "format_compliance_rate"}
print("=== EXPERIMENT SUMMARY ===")
for key, value in summary.items():
    if key in _RATE_KEYS:
        print(f"{key}: {value:.2%}")
    elif isinstance(value, float):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")

# Save summary
summary_path = Path(OUTPUT_DIR) / "experiment_summary.json"
with summary_path.open("w", encoding="utf-8") as f:
    json.dump(summary, f, indent=2)
print(f"\nSummary saved to: {OUTPUT_DIR}/experiment_summary.json")

## Next Steps

- **If results are good:** Merge LoRA weights with base model for deployment using `model.merge_and_unload()`
- **If underfitting:** Increase LoRA rank (r=32), train for more epochs, or increase learning rate
- **If overfitting:** Increase dropout, reduce epochs, or add more training data
- **Production:** Convert to GGUF or ONNX format for efficient inference
- **Evaluation:** Run full evaluation on held-out test set with clinical expert review