# MedGemma Fine-Tuned Model Evaluation

**Purpose**: Evaluate the fine-tuned MedGemma-4B model using Clinical BERTScore and qualitative analysis.

**Model**: Fine-tuned MedGemma-4B with LoRA adapters for clinical discharge summarization.

**Evaluation Metrics**:
- Clinical BERTScore (Precision, Recall, F1)
- Qualitative comparison of generated vs reference summaries

## 1. Installation and Imports

Install required libraries for model loading and evaluation.

In [1]:
# Uncomment for Google Colab
# !pip install -q -U transformers peft bitsandbytes accelerate bert_score datasets pandas scipy torch

print("✓ Installation complete (or skipped for local environment)")

✓ Installation complete (or skipped for local environment)


In [2]:
import warnings

import numpy as np
import pandas as pd
import torch
from bert_score import BERTScorer
from datasets import Dataset
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)

warnings.filterwarnings('ignore')

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.1+cu130
CUDA available: True
GPU: NVIDIA GeForce RTX 5060 Laptop GPU
GPU Memory: 8.55 GB


## 2. Configuration

Set paths and model parameters.
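
The `NUM_TEST_SAMPLES` convention (a positive cap, or -1 for all samples) can be captured in a small helper. This is a minimal sketch; `resolve_sample_count` is a hypothetical name and is not part of the notebook.

```python
def resolve_sample_count(requested: int, available: int) -> int:
    """Return how many samples to evaluate.

    Hypothetical helper: a non-positive `requested` means "use everything",
    otherwise the request is capped at the number of available samples.
    """
    if requested <= 0:
        return available
    return min(requested, available)


print(resolve_sample_count(100, 500))  # 100
print(resolve_sample_count(-1, 500))   # 500
print(resolve_sample_count(800, 500))  # 500
```

Capping with `min` mirrors the later check that only subsamples when the request is smaller than the test split.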

In [4]:
# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

# Base model
MODEL_NAME = "google/medgemma-4b-it"

# Path to fine-tuned LoRA adapters
ADAPTER_PATH = "./medgemma-discharge-summarization/final"

# Dataset path
MIMIC_CSV_PATH = "mimic_cleaned_text_only.csv"

# ============================================================================
# EVALUATION SETTINGS
# ============================================================================

# Number of test samples to evaluate (set to -1 for all)
NUM_TEST_SAMPLES = 100

# ============================================================================
# GENERATION PARAMETERS
# ============================================================================

MAX_NEW_TOKENS = 512
TEMPERATURE = 0.7
TOP_P = 0.9
TOP_K = 50
REPETITION_PENALTY = 1.1

# ============================================================================
# BERTSCORE CONFIGURATION
# ============================================================================

CLINICAL_BERT = "emilyalsentzer/Bio_ClinicalBERT"

print("✓ Configuration loaded")
print(f"  Base model: {MODEL_NAME}")
print(f"  Adapter path: {ADAPTER_PATH}")
print(f"  Dataset: {MIMIC_CSV_PATH}")
print(f"  Test samples: {NUM_TEST_SAMPLES if NUM_TEST_SAMPLES > 0 else 'All'}")

✓ Configuration loaded
  Base model: google/medgemma-4b-it
  Adapter path: ./medgemma-discharge-summarization/final
  Dataset: mimic_cleaned_text_only.csv
  Test samples: 100


## 3. Load Test Dataset

Load and prepare the MIMIC dataset for evaluation.

In [5]:
import os

print(f"Loading dataset from: {MIMIC_CSV_PATH}\n")

if os.path.exists(MIMIC_CSV_PATH):
    # Load the CSV file
    mimic_df = pd.read_csv(MIMIC_CSV_PATH)

    # Keep the first 10,000 rows; the train/test split below is drawn from this subset
    mimic_df = mimic_df[:10_000]

    print(f"✓ Dataset loaded successfully!")
    print(f"  Total samples: {len(mimic_df)}")
    print(f"  Columns: {list(mimic_df.columns)}\n")

    # Add instruction column
    instruction_text = "Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ensure complete coverage of all medical entities."
    mimic_df['instruction'] = instruction_text

    # Rename columns
    mimic_df = mimic_df.rename(columns={
        'final_input': 'input',
        'final_target': 'output'
    })

    # Remove rows with missing data
    mimic_df = mimic_df.dropna(subset=['input', 'output'])

    # Convert to Hugging Face Dataset
    dataset = Dataset.from_pandas(mimic_df[['instruction', 'input', 'output']])

    # Split into train and test sets (5% test)
    dataset = dataset.train_test_split(test_size=0.05, seed=42)
    test_dataset = dataset["test"]

    # Limit test samples if configured
    if NUM_TEST_SAMPLES > 0 and NUM_TEST_SAMPLES < len(test_dataset):
        test_dataset = test_dataset.select(range(NUM_TEST_SAMPLES))

    print(f"✓ Test dataset prepared!")
    print(f"  Test samples: {len(test_dataset)}")

else:
    # Fail fast: every later cell depends on `test_dataset`
    raise FileNotFoundError(
        f"Dataset not found: {MIMIC_CSV_PATH}. "
        "Please ensure the dataset is in the project directory."
    )

Loading dataset from: mimic_cleaned_text_only.csv

✓ Dataset loaded successfully!
  Total samples: 10000
  Columns: ['final_input', 'final_target']

✓ Test dataset prepared!
  Test samples: 100


## 4. Load Fine-Tuned Model

Load the base MedGemma model with 4-bit quantization, then load the fine-tuned LoRA adapters.

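**Memory Optimization**: The base model is loaded with 4-bit NF4 quantization and double quantization via bitsandbytes (the scheme introduced with QLoRA) to reduce GPU memory use during inference.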
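
A back-of-the-envelope estimate shows why 4-bit weights matter on the 8.55 GB GPU reported above. The sketch below assumes a nominal 4 billion parameters (taking "4B" at face value) and counts weight storage only. Quantization constants, any layers kept in higher precision, activations, and the KV cache are ignored, so the real footprint will be higher.

```python
def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NOMINAL_PARAMS = 4e9  # assumption: "4B" taken at face value

bf16_gb = weight_memory_gb(NOMINAL_PARAMS, 16)
nf4_gb = weight_memory_gb(NOMINAL_PARAMS, 4)

print(f"bf16 weights:  ~{bf16_gb:.1f} GB")  # ~8.0 GB
print(f"4-bit weights: ~{nf4_gb:.1f} GB")   # ~2.0 GB
```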

In [6]:
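# Enable synchronous CUDA kernel launches for clearer error traces (this slows execution).
# Note: the variable generally needs to be set before CUDA is first initialized to take effect.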
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
print("✓ CUDA synchronous mode enabled")

✓ CUDA synchronous mode enabled


In [7]:
print("=" * 80)
print("LOADING FINE-TUNED MEDGEMMA MODEL")
print("=" * 80)

# Configure 4-bit quantization
compute_dtype = torch.bfloat16
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype
)

print("\nStep 1: Loading base model...")
print(f"  Model: {MODEL_NAME}")
print(f"  Quantization: 4-bit NF4")
print(f"  This may take 2-3 minutes...\n")

# Load base model
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True,
    dtype=torch.bfloat16
)
print("✓ Base model loaded")

# Load LoRA adapters
if os.path.exists(ADAPTER_PATH):
    print(f"\nStep 2: Loading LoRA adapters...")
    print(f"  Path: {ADAPTER_PATH}")
    model = PeftModel.from_pretrained(model, ADAPTER_PATH)
    print("✓ LoRA adapters loaded")
else:
    print(f"\n⚠️  WARNING: Adapter path not found: {ADAPTER_PATH}")
    print("   Using base model only (not fine-tuned)")

# Get actual vocab size from model
embedding_layer = model.get_input_embeddings()
actual_vocab_size = embedding_layer.weight.shape[0]
print(f"\n  Model embedding vocab size: {actual_vocab_size}")

# Load tokenizer
print(f"\nStep 3: Loading tokenizer...")
if os.path.exists(ADAPTER_PATH):
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            ADAPTER_PATH,
            trust_remote_code=True,
            padding_side="right",
            add_eos_token=True
        )
        print("✓ Tokenizer loaded from adapter path")
    except Exception as e:
        print(f"⚠️  Adapter tokenizer failed: {e}")
        print("   Loading base model tokenizer instead")
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            trust_remote_code=True,
            padding_side="right",
            add_eos_token=True
        )
else:
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        padding_side="right",
        add_eos_token=True
    )

tokenizer.pad_token = tokenizer.eos_token

print(f"\n  Tokenizer vocab size: {len(tokenizer)}")
print(f"  PAD token ID: {tokenizer.pad_token_id}")
print(f"  EOS token ID: {tokenizer.eos_token_id}")

# Validation check
if len(tokenizer) != actual_vocab_size:
    print(f"\n⚠️  MISMATCH DETECTED!")
    print(f"   Tokenizer vocab: {len(tokenizer)}")
    print(f"   Model vocab: {actual_vocab_size}")

    if len(tokenizer) > actual_vocab_size:
        print(f"\n   Resizing model embeddings to {len(tokenizer)}...")
        model.resize_token_embeddings(len(tokenizer))
        actual_vocab_size = model.get_input_embeddings().weight.shape[0]
        print(f"   ✓ New model vocab size: {actual_vocab_size}")

# Validation test
print(f"\n{'=' * 80}")
print("VALIDATION TEST")
print(f"{'=' * 80}")

test_text = "Patient presented with chest pain."
test_tokens = tokenizer(test_text, return_tensors="pt")
max_id = test_tokens['input_ids'].max().item()

print(f"Test text: '{test_text}'")
print(f"Max token ID: {max_id}")
print(f"Valid range: [0, {actual_vocab_size - 1}]")

if max_id >= actual_vocab_size:
    print(f"\n❌ CRITICAL ERROR: Token ID out of range!")
    raise ValueError(f"Token ID {max_id} >= vocab size {actual_vocab_size}")
else:
    print(f"\n✅ VALIDATION PASSED!")
    print(f"   All token IDs are within valid range")

model.eval()
print(f"\n✓ Model ready for evaluation")

LOADING FINE-TUNED MEDGEMMA MODEL

Step 1: Loading base model...
  Model: google/medgemma-4b-it
  Quantization: 4-bit NF4
  This may take 2-3 minutes...



`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

✓ Base model loaded

Step 2: Loading LoRA adapters...
  Path: ./medgemma-discharge-summarization/final
✓ LoRA adapters loaded

  Model embedding vocab size: 262208

Step 3: Loading tokenizer...
✓ Tokenizer loaded from adapter path

  Tokenizer vocab size: 262145
  PAD token ID: 1
  EOS token ID: 1

⚠️  MISMATCH DETECTED!
   Tokenizer vocab: 262145
   Model vocab: 262208

VALIDATION TEST
Test text: 'Patient presented with chest pain.'
Max token ID: 236761
Valid range: [0, 262207]

✅ VALIDATION PASSED!
   All token IDs are within valid range

✓ Model ready for evaluation


## 5. Generate Predictions on Test Set

Generate clinical summaries for all test samples.
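
The prompt used in the cell below follows the Gemma turn format: a user turn that holds the instruction and notes, followed by an open model turn. As a sketch, the same string can be produced by a small function; `build_inference_prompt` is a hypothetical name introduced here for illustration.

```python
def build_inference_prompt(instruction: str, clinical_notes: str) -> str:
    """Build the single-turn prompt, ending with an open model turn."""
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n\n"
        "Clinical Notes:\n"
        f"{clinical_notes}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


prompt = build_inference_prompt(
    "Summarize the following clinical discharge notes.",
    "Patient presented with chest pain.",
)
print(prompt)
```

Keeping the template in one function makes it easier to guarantee that evaluation uses exactly the format seen during fine-tuning.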

In [9]:
print("Generating predictions on test set...\n")

predictions = []
references = []

for i, sample in enumerate(test_dataset):
    print(f"Generating summary {i + 1}/{len(test_dataset)}...", end=" ")

    instruction = sample["instruction"]
    input_text = sample["input"]
    reference = sample["output"]

    # Format the prompt (without the model's response)
    inference_prompt = f"""<start_of_turn>user
{instruction}

Clinical Notes:
{input_text}<end_of_turn>
<start_of_turn>model
"""

    # Tokenize
    inputs = tokenizer(inference_prompt, return_tensors="pt").to(model.device)

    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )

    # Decode only the newly generated tokens. Decoding the full sequence with
    # skip_special_tokens=True can drop the turn markers, so neither splitting on
    # "<start_of_turn>model" nor slicing by len(inference_prompt) is reliable.
    prompt_length = inputs["input_ids"].shape[1]
    generated_summary = tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True
    ).strip()

    predictions.append(generated_summary)
    references.append(reference)

    print(f"✓ ({len(generated_summary)} chars)")

print(f"\n✓ All predictions generated")
print(f"  Total predictions: {len(predictions)}")

Generating predictions on test set...

Generating summary 1/100... ✓ (807 chars)
Generating summary 2/100... ✓ (1775 chars)
Generating summary 3/100... ✓ (1409 chars)
Generating summary 4/100... ✓ (530 chars)
Generating summary 5/100... ✓ (800 chars)
Generating summary 6/100... ✓ (1470 chars)
Generating summary 7/100... ✓ (1152 chars)
Generating summary 8/100... ✓ (348 chars)
Generating summary 9/100... ✓ (367 chars)
Generating summary 10/100... ✓ (1734 chars)
Generating summary 11/100... ✓ (1236 chars)
Generating summary 12/100... ✓ (2064 chars)
Generating summary 13/100... ✓ (1194 chars)
Generating summary 14/100... ✓ (322 chars)
Generating summary 15/100... ✓ (670 chars)
Generating summary 16/100... ✓ (373 chars)
Generating summary 17/100... ✓ (3233 chars)
Generating summary 18/100... ✓ (1026 chars)
Generating summary 19/100... ✓ (1470 chars)
Generating summary 20/100... ✓ (205 chars)
Generating summary 21/100... ✓ (58 chars)
Generating summary 22/100... ✓ (249 chars)
Generating sum

## 6. Compute Clinical BERTScore

Evaluate semantic similarity using Bio_ClinicalBERT.
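
BERTScore matches each token embedding in one text to its most similar token embedding in the other text (cosine similarity) and averages those maxima. The toy sketch below illustrates that greedy matching on hand-made 2-D vectors. It omits IDF weighting and baseline rescaling, and it is not the `bert_score` implementation; `greedy_bertscore` is a hypothetical name.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def greedy_bertscore(cand_vecs, ref_vecs):
    """Greedy-matching precision, recall and F1 over token vectors."""
    # Precision: each candidate token is matched to its closest reference token
    precision = sum(max(cosine(c, r) for r in ref_vecs) for c in cand_vecs) / len(cand_vecs)
    # Recall: each reference token is matched to its closest candidate token
    recall = sum(max(cosine(r, c) for c in cand_vecs) for r in ref_vecs) / len(ref_vecs)
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


# Toy "embeddings": the candidate covers only one of two reference tokens
cand = [[1.0, 0.0]]
ref = [[1.0, 0.0], [0.0, 1.0]]

p, r, f1 = greedy_bertscore(cand, ref)
print(f"P={p:.3f} R={r:.3f} F1={f1:.3f}")  # P=1.000 R=0.500 F1=0.667
```

In this example precision is perfect because everything the candidate says is in the reference, while recall drops because half of the reference is not covered. That is the failure mode recall is meant to expose for summaries that omit clinical details.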

In [None]:
print("=" * 80)
print("COMPUTING CLINICAL BERTSCORE")
print("=" * 80)

print(f"\nInitializing BERTScorer with {CLINICAL_BERT}...")
clinical_scorer = BERTScorer(
    model_type=CLINICAL_BERT,
    num_layers=9,
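    # Baseline rescaling needs a baseline file for the scoring model. bert_score
    # appears to ship these only for its built-in models, so raw scores are used here.
    rescale_with_baseline=False,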
    lang="en",
    device="cuda" if torch.cuda.is_available() else "cpu"
)
print("✓ BERTScorer initialized")

# Get the tokenizer from the scorer to do proper truncation
bert_tokenizer = clinical_scorer._tokenizer


def truncate_with_bert_tokenizer(text: str, tokenizer, max_length: int = 500) -> str:
    """
    Properly truncate text using BERT's tokenizer to ensure it fits within token limit.
    
    Args:
        text: Input text to truncate
        tokenizer: BERT tokenizer
        max_length: Maximum number of tokens (BERT supports 512, we use 500 for safety)
    
    Returns:
        Truncated text that will tokenize to <= max_length tokens
    """
    # Tokenize and truncate
    tokens = tokenizer.encode(
        text,
        add_special_tokens=True,
        truncation=True,
        max_length=max_length
    )

    # Decode back to text
    truncated_text = tokenizer.decode(tokens, skip_special_tokens=True)
    return truncated_text


# Truncate texts to fit BERT's 512 token limit
print("\nPreparing texts (truncating long sequences for BERT)...")
print("  Using BERT tokenizer for accurate truncation...")

truncated_predictions = []
truncated_references = []

for pred, ref in zip(predictions, references):
    truncated_predictions.append(truncate_with_bert_tokenizer(pred, bert_tokenizer))
    truncated_references.append(truncate_with_bert_tokenizer(ref, bert_tokenizer))

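# Check truncation statistics. A prediction was truncated if its original token
# count (including special tokens) exceeded the 500-token limit. Comparing
# re-encoded lengths can miscount because decode/encode is not an exact round trip.
orig_pred_lens = [len(bert_tokenizer.encode(p)) for p in predictions]
trunc_pred_lens = [len(bert_tokenizer.encode(p)) for p in truncated_predictions]
num_truncated = sum(1 for n in orig_pred_lens if n > 500)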

print(f"  {num_truncated}/{len(predictions)} predictions were truncated")
print(f"  Average prediction tokens: {np.mean(trunc_pred_lens):.0f}")
print(f"  Max prediction tokens: {max(trunc_pred_lens)}")

print("\nCalculating BERTScores (this may take a few minutes)...\n")

# Compute scores with truncated texts
P, R, F1 = clinical_scorer.score(
    cands=truncated_predictions,
    refs=truncated_references,
)

# Convert to numpy
precision_scores = P.cpu().numpy()
recall_scores = R.cpu().numpy()
f1_scores = F1.cpu().numpy()

# Compute averages
avg_precision = np.mean(precision_scores)
avg_recall = np.mean(recall_scores)
avg_f1 = np.mean(f1_scores)

print("=" * 80)
print("CLINICAL BERTSCORE RESULTS")
print("=" * 80)
print(f"\nAverage Precision: {avg_precision:.4f}")
print(f"  → How much of the generated summary is supported by the reference")

print(f"\nAverage Recall: {avg_recall:.4f}")
print(f"  → How much of the reference summary is captured")
print(f"  → Primary metric: the task prioritizes complete coverage of medical entities")

print(f"\nAverage F1: {avg_f1:.4f}")
print(f"  → Harmonic mean of precision and recall")

print(f"\n{'=' * 80}")
print(f"\nNote: Texts were truncated to 500 tokens using BERT's tokenizer.")
print(f"This ensures all texts fit within BERT's 512 token limit.")

## 7. Detailed Per-Sample Analysis

Display scores for each test sample.
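
Beyond listing the first few samples, it is often more informative to look at the worst-scoring ones. The following sketch ranks samples by a score list and returns the lowest entries. `lowest_scoring` is a hypothetical helper, and the numbers are illustrative values, not results from this evaluation.

```python
def lowest_scoring(scores, k=3):
    """Return (index, score) pairs for the k lowest scores, in ascending order."""
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1])
    return ranked[:k]


example_recall = [0.81, 0.42, 0.77, 0.65]  # illustrative values only

for idx, score in lowest_scoring(example_recall, k=2):
    print(f"Sample {idx + 1}: recall={score:.2f}")
# Sample 2: recall=0.42
# Sample 4: recall=0.65
```

In the notebook, `recall_scores` could be passed in place of `example_recall` once the BERTScore cell has run.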

In [None]:
print("Per-Sample BERTScore Results:")
print("=" * 80)

for i in range(min(10, len(predictions))):  # Show first 10 samples
    print(f"\nSample {i + 1}:")
    print(f"  Precision: {precision_scores[i]:.4f}")
    print(f"  Recall: {recall_scores[i]:.4f}")
    print(f"  F1: {f1_scores[i]:.4f}")

print(f"\n... (showing first 10 of {len(predictions)} samples)")
print("=" * 80)

## 8. Qualitative Analysis

Compare generated summaries with reference summaries for qualitative assessment.
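
Manual comparison can be supported by simple automated flags. As one crude heuristic (an assumption of this sketch, not a method used in the notebook), numbers that appear in a generated summary but nowhere in the source notes can be surfaced for review, since invented doses or lab values are a common form of hallucination. `unsupported_numbers` is a hypothetical name.

```python
import re
from typing import List

# Matches integers and simple decimals such as "25" or "3.1"
NUMBER_PATTERN = re.compile(r"\d+(?:\.\d+)?")


def unsupported_numbers(source: str, summary: str) -> List[str]:
    """Return numbers found in `summary` that never occur in `source`."""
    source_numbers = set(NUMBER_PATTERN.findall(source))
    flagged: List[str] = []
    for number in NUMBER_PATTERN.findall(summary):
        if number not in source_numbers and number not in flagged:
            flagged.append(number)
    return flagged


source = "Metoprolol 25 mg twice daily. Potassium 3.1 on admission."
summary = "Discharged on metoprolol 50 mg; potassium was 3.1."

print(unsupported_numbers(source, summary))  # ['50']
```

A flagged number is only a prompt for human review: unit conversions or reformatted dates will trigger it, and a hallucination that reuses a number from the source will not.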

In [None]:
print("=" * 80)
print("QUALITATIVE ANALYSIS: Generated vs Reference Summaries")
print("=" * 80)

# Show 3 examples
num_examples = min(3, len(predictions))

for i in range(num_examples):
    print(f"\n{'=' * 80}")
    print(f"EXAMPLE {i + 1}")
    print(f"{'=' * 80}\n")

    print("INPUT (Clinical Notes - first 400 chars):")
    print("-" * 80)
    print(test_dataset[i]["input"][:400] + "...\n")

    print("REFERENCE SUMMARY:")
    print("-" * 80)
    print(references[i])
    print()

    print("GENERATED SUMMARY:")
    print("-" * 80)
    print(predictions[i])
    print()

    print("SCORES:")
    print("-" * 80)
    print(f"Precision: {precision_scores[i]:.4f}")
    print(f"Recall: {recall_scores[i]:.4f}")
    print(f"F1: {f1_scores[i]:.4f}")

print(f"\n{'=' * 80}")
print("END OF QUALITATIVE ANALYSIS")
print(f"{'=' * 80}")

## 9. Save Evaluation Results

Save predictions and scores to files for further analysis.
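
The summary file stores the same four statistics for each metric. A small helper can compute them in one place. This sketch uses only the standard library; `summarize` is a hypothetical name, and `statistics.pstdev` is chosen because it matches the population standard deviation that `np.std` returns by default.

```python
import statistics


def summarize(scores):
    """Return mean, population std, min and max as plain floats."""
    values = [float(s) for s in scores]
    return {
        "mean": statistics.fmean(values),
        "std": statistics.pstdev(values),  # population std, like np.std's default
        "min": min(values),
        "max": max(values),
    }


example = summarize([0.5, 0.7, 0.9])  # illustrative values only
print(example)
```

Converting to plain `float` keeps the result JSON-serializable when the inputs are NumPy scalars.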

In [None]:
import json

# Create results directory
results_dir = "./evaluation_results"
os.makedirs(results_dir, exist_ok=True)

# Prepare results dataframe
results_df = pd.DataFrame({
    'input': [sample['input'] for sample in test_dataset],
    'reference': references,
    'prediction': predictions,
    'bertscore_precision': precision_scores,
    'bertscore_recall': recall_scores,
    'bertscore_f1': f1_scores
})

# Save as CSV
csv_path = os.path.join(results_dir, "evaluation_results.csv")
results_df.to_csv(csv_path, index=False)
print(f"✓ Results saved to CSV: {csv_path}")

# Save summary statistics
summary_stats = {
    "model": MODEL_NAME,
    "adapter_path": ADAPTER_PATH,
    "num_test_samples": len(predictions),
    "bertscore": {
        "precision": {
            "mean": float(avg_precision),
            "std": float(np.std(precision_scores)),
            "min": float(np.min(precision_scores)),
            "max": float(np.max(precision_scores))
        },
        "recall": {
            "mean": float(avg_recall),
            "std": float(np.std(recall_scores)),
            "min": float(np.min(recall_scores)),
            "max": float(np.max(recall_scores))
        },
        "f1": {
            "mean": float(avg_f1),
            "std": float(np.std(f1_scores)),
            "min": float(np.min(f1_scores)),
            "max": float(np.max(f1_scores))
        }
    },
    "generation_config": {
        "max_new_tokens": MAX_NEW_TOKENS,
        "temperature": TEMPERATURE,
        "top_p": TOP_P,
        "top_k": TOP_K,
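        "repetition_penalty": REPETITION_PENALTY,
        # Also record the remaining settings passed to generate()
        "do_sample": True,
        "no_repeat_ngram_size": 2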
    }
}

summary_path = os.path.join(results_dir, "summary_statistics.json")
with open(summary_path, 'w') as f:
    json.dump(summary_stats, f, indent=2)

print(f"✓ Summary statistics saved: {summary_path}")

print(f"\n{'=' * 80}")
print("EVALUATION COMPLETE")
print(f"{'=' * 80}")
print(f"\nAll results saved to: {results_dir}")

## 10. Evaluation Checklist

Use this checklist to assess the quality of generated summaries:

**High Recall Checklist**:
- ☐ Are all diagnoses mentioned?
- ☐ Are all medications listed with dosages?
- ☐ Are vital signs included?
- ☐ Are abnormal lab results reported?
- ☐ Are procedures and treatments described?
- ☐ Are follow-up instructions present?
- ☐ Is the timeline/hospital course clear?

**Quality Assessment**:
- Target BERTScore Recall: ≥0.90 for production use
- Target BERTScore F1: ≥0.85 for balanced performance
- Check for hallucinations (invented facts not in source)
- Verify medical terminology accuracy
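
The numeric targets above can be checked programmatically once the averages are available. The sketch below is a minimal example; `meets_targets` is a hypothetical helper, and the input values are illustrative, not results from this evaluation.

```python
def meets_targets(mean_recall, mean_f1, recall_target=0.90, f1_target=0.85):
    """Compare average scores against the checklist thresholds."""
    return {
        "recall_ok": mean_recall >= recall_target,
        "f1_ok": mean_f1 >= f1_target,
    }


print(meets_targets(0.92, 0.80))  # {'recall_ok': True, 'f1_ok': False}
```

Thresholds like these are only comparable across runs that use the same scoring model and the same baseline-rescaling setting, and they do not replace the manual checks for hallucinations and terminology.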