# Clinical Discharge Summarization using MedGemma 4B with QLoRA

**Project Overview:**
This notebook demonstrates Parameter-Efficient Fine-Tuning (PEFT) using QLoRA on the MedGemma 4B model for clinical discharge summarization. The objective is **high recall**: detailed summaries that capture every medical entity (diagnoses, medications, vitals, abnormal lab results) present in the source clinical notes.

**Key Technologies:**
- Model: google/medgemma-4b-it (or a general instruction-tuned Gemma 4B checkpoint as a fallback)
- Technique: QLoRA (4-bit quantization)
- Evaluation: Clinical BERTScore using Bio_ClinicalBERT
- Platform: Google Colab / Consumer GPUs

## 1. Environment Setup

First, we install all necessary libraries for model loading, quantization, fine-tuning, and evaluation.

In [None]:
# Install required libraries
# transformers: Hugging Face library for loading pre-trained models and tokenizers
# peft: Parameter-Efficient Fine-Tuning library for LoRA adapters
# bitsandbytes: Enables 4-bit/8-bit quantization for memory efficiency
# trl: Transformer Reinforcement Learning library with SFTTrainer for supervised fine-tuning
# accelerate: Distributed training and mixed precision support
# datasets: For loading and processing datasets
# bert_score: For computing BERTScore with clinical models
# scipy: Required dependency for bert_score

!pip install -q -U transformers
!pip install -q -U peft
!pip install -q -U bitsandbytes
!pip install -q -U trl
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U bert_score
!pip install -q -U scipy
!pip install -q -U einops  # Optional tensor-manipulation helper; Gemma in transformers does not depend on it

print("✓ All libraries installed successfully!")

In [None]:
# Import necessary libraries
import torch
import transformers
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    PeftModel
)
from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
from datasets import Dataset, load_dataset
import pandas as pd
import numpy as np
from bert_score import BERTScorer
import warnings
warnings.filterwarnings('ignore')

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 2. Configuration and Hyperparameters

Define all model paths, LoRA parameters, and training hyperparameters in one place for easy modification.

In [None]:
# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

# NOTE: MedGemma checkpoints are gated on the Hugging Face Hub. If you cannot access
# google/medgemma-4b-it, a general instruction-tuned Gemma checkpoint such as
# "google/gemma-3-4b-it" can serve as a stand-in base model (without the medical tuning).
MODEL_NAME = "google/medgemma-4b-it"  # Instruction-tuned MedGemma 4B

# ============================================================================
# LORA CONFIGURATION
# ============================================================================
# These parameters control the LoRA adapter architecture:
# - r (rank): The dimensionality of the low-rank matrices. Higher = more parameters = better fit but more memory
# - lora_alpha: Scaling factor for LoRA updates. The adapter output is multiplied by lora_alpha / r, so a higher alpha means larger effective updates
# - lora_dropout: Dropout probability for LoRA layers to prevent overfitting

LORA_R = 32  # Rank of 32 provides good balance between performance and memory
LORA_ALPHA = 64  # Alpha = 2*r is a common heuristic
LORA_DROPOUT = 0.05  # Small dropout for regularization

# Target modules for Gemma architecture
# These are the attention and MLP projection layers where LoRA adapters will be inserted
# Gemma uses a standard transformer architecture with:
# - q_proj, k_proj, v_proj: Query, Key, Value projections in attention
# - o_proj: Output projection after attention
# - gate_proj, up_proj, down_proj: MLP layers (Gemma uses a gated GeGLU feed-forward block)
TARGET_MODULES = [
    "q_proj",
    "k_proj", 
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj"
]

# ============================================================================
# TRAINING HYPERPARAMETERS
# ============================================================================
OUTPUT_DIR = "./medgemma-discharge-summarization"
NUM_EPOCHS = 3
BATCH_SIZE = 2  # Per device batch size (increase if you have more VRAM)
GRADIENT_ACCUMULATION_STEPS = 4  # Effective batch size = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS = 8
LEARNING_RATE = 2e-4  # Standard learning rate for LoRA fine-tuning
MAX_SEQ_LENGTH = 2048  # Maximum sequence length; the model supports longer contexts, but 2048 keeps memory manageable (longer examples are truncated)
WARMUP_STEPS = 100  # Warmup steps for learning rate scheduler
LOGGING_STEPS = 10  # Log training metrics every N steps
SAVE_STEPS = 100  # Save checkpoint every N steps

# ============================================================================
# GENERATION PARAMETERS (for high recall)
# ============================================================================
# These parameters are tuned to maximize detail and completeness in summaries:
MAX_NEW_TOKENS = 512  # Allow longer summaries to capture all details
TEMPERATURE = 0.7  # Moderate temperature for balance between creativity and coherence
TOP_P = 0.9  # Nucleus sampling for diverse but relevant outputs
TOP_K = 50  # Top-K sampling
REPETITION_PENALTY = 1.1  # Slight penalty to avoid repetitive text

print("✓ Configuration loaded successfully!")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA Rank: {LORA_R}, Alpha: {LORA_ALPHA}")
print(f"  Training: {NUM_EPOCHS} epochs, Batch Size: {BATCH_SIZE}, Gradient Accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"  Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
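The sampling parameters defined in the configuration (temperature, top-p, top-k) only influence decoding when sampling is switched on. As a minimal sketch, they could be gathered into keyword arguments for `model.generate`; the `build_generation_kwargs` helper is hypothetical and not part of the notebook, and its defaults simply copy the constants above.

```python
def build_generation_kwargs(max_new_tokens=512, temperature=0.7, top_p=0.9,
                            top_k=50, repetition_penalty=1.1):
    """Collect sampling settings into kwargs for `model.generate` (hypothetical helper)."""
    return {
        "max_new_tokens": max_new_tokens,
        "do_sample": True,  # temperature/top_p/top_k are ignored under greedy decoding
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "repetition_penalty": repetition_penalty,
    }

generation_kwargs = build_generation_kwargs()
print(generation_kwargs)
```

A call would then look like `model.generate(**inputs, **generation_kwargs)`. If reproducible summaries matter more than variety, setting `do_sample` to `False` gives deterministic greedy decoding instead.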

## 3. Load and Prepare Dataset

This section provides TWO options for loading data:

### **Option A (Recommended): Load Your Actual MIMIC Dataset**
- File: `mimic_cleaned_text_only.csv`
- Columns: `final_input` (clinical notes), `final_target` (reference summary)

### **Option B: Use Sample Data**
- For demonstration and testing purposes
- Contains 3 realistic clinical examples

**Instructions:**
- If you have the MIMIC dataset file, use **Section 3A** below
- Otherwise, skip to **Section 3B** to get started with sample data
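Before loading a CSV it can help to confirm that it really has the expected columns. The following is a minimal sketch using only the standard library; the `missing_columns` helper and the demo file are hypothetical, and the column names are the ones listed under Option A.

```python
import csv
import os
import tempfile

REQUIRED_COLUMNS = ("final_input", "final_target")

def missing_columns(csv_path, required=REQUIRED_COLUMNS):
    """Return the required column names that are absent from the CSV header row."""
    with open(csv_path, newline="", encoding="utf-8") as f:
        header = next(csv.reader(f), [])
    return [name for name in required if name not in header]

# Demo on a throwaway file with made-up content (not real patient data)
with tempfile.TemporaryDirectory() as tmp:
    demo_path = os.path.join(tmp, "demo.csv")
    with open(demo_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["final_input", "summary"])  # 'final_target' deliberately missing
        writer.writerow(["example note", "example summary"])
    demo_missing = missing_columns(demo_path)

print(demo_missing)
```

An empty list means the file has both columns and can be used as the MIMIC dataset option expects.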

## 3A. Load Your Actual MIMIC Dataset (RECOMMENDED)

**Use this section if you have the `mimic_cleaned_text_only.csv` file**

This loads your actual MIMIC clinical discharge dataset with the correct column mappings:
- `final_input` → clinical notes
- `final_target` → reference summaries

In [None]:
# ============================================================================
# LOAD ACTUAL MIMIC DATASET FROM CSV
# ============================================================================

# IMPORTANT: Run this cell if you have the mimic_cleaned_text_only.csv file
# If you don't have the file, skip to Section 3B for sample data

import os

# Path to your MIMIC dataset CSV file
# Adjust this path if your file is located elsewhere
MIMIC_CSV_PATH = "mimic_cleaned_text_only.csv"

# Check if the file exists
if os.path.exists(MIMIC_CSV_PATH):
    print(f"Loading MIMIC dataset from: {MIMIC_CSV_PATH}\n")
    
    # Load the CSV file using pandas
    # The file should have two columns: final_input and final_target
    mimic_df = pd.read_csv(MIMIC_CSV_PATH)
    
    print(f"✓ Dataset loaded successfully!")
    print(f"  Total samples: {len(mimic_df)}")
    print(f"  Columns: {list(mimic_df.columns)}\n")
    
    # Display basic statistics
    print("Dataset Statistics:")
    print(f"  Average input length: {mimic_df['final_input'].str.len().mean():.0f} characters")
    print(f"  Average target length: {mimic_df['final_target'].str.len().mean():.0f} characters")
    print(f"  Minimum input length: {mimic_df['final_input'].str.len().min():.0f} characters")
    print(f"  Maximum input length: {mimic_df['final_input'].str.len().max():.0f} characters")
    
    # Add a consistent instruction column
    # This instruction emphasizes HIGH RECALL - capturing all medical details
    instruction_text = "Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ensure complete coverage of all medical entities."
    mimic_df['instruction'] = instruction_text
    
    # Rename columns to match the expected format
    # final_input → input (clinical notes)
    # final_target → output (reference summary)
    mimic_df = mimic_df.rename(columns={
        'final_input': 'input',
        'final_target': 'output'
    })
    
    # Remove any rows with missing data
    initial_count = len(mimic_df)
    mimic_df = mimic_df.dropna(subset=['input', 'output'])
    dropped_count = initial_count - len(mimic_df)
    
    if dropped_count > 0:
        print(f"\n⚠ Removed {dropped_count} rows with missing data")
    
    # Convert to Hugging Face Dataset
    dataset = Dataset.from_pandas(mimic_df[['instruction', 'input', 'output']])
    
    # Split into train and test sets
    # Using 90/10 split since we have a larger dataset
    # Adjust test_size as needed (e.g., 0.15 for 85/15 split)
    dataset = dataset.train_test_split(test_size=0.1, seed=42)
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]
    
    print(f"\n✓ Dataset prepared and split!")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Test samples: {len(test_dataset)}")
    
    # Display a sample from the training set
    print(f"\n{'='*80}")
    print("SAMPLE TRAINING EXAMPLE:")
    print(f"{'='*80}\n")
    print(f"Instruction: {train_dataset[0]['instruction'][:150]}...")
    print(f"\nInput (first 300 chars):\n{train_dataset[0]['input'][:300]}...")
    print(f"\nOutput (first 300 chars):\n{train_dataset[0]['output'][:300]}...")
    print(f"\n{'='*80}")
    
    print("\n✓ MIMIC dataset loaded successfully! You can now skip Section 3B.")
    print("  Proceed to Section 4 (Load Model with 4-bit Quantization)")
    
else:
    print(f"⚠ File not found: {MIMIC_CSV_PATH}")
    print(f"\nPlease either:")
    print(f"  1. Place the mimic_cleaned_text_only.csv file in the current directory")
    print(f"  2. Update MIMIC_CSV_PATH variable with the correct file path")
    print(f"  3. Skip to Section 3B to use sample data instead\n")
    print(f"Current directory: {os.getcwd()}")

## 3B. Sample Data (Alternative Option)

**Use this section ONLY if you don't have the MIMIC dataset file**

This creates sample data for testing and demonstration purposes.

In [None]:
# ============================================================================
# SAMPLE DATA CREATION (USE ONLY IF YOU DON'T HAVE MIMIC DATASET)
# ============================================================================

# IMPORTANT: Skip this cell if you already ran Section 3A successfully
# This is fallback data for demonstration and testing purposes

print("Using sample data for demonstration...\n")
print("⚠ WARNING: This is limited sample data with only 3 examples.")
print("  For actual training, use Section 3A with your MIMIC dataset.\n")

sample_data = {
    "instruction": [
        "Summarize the following clinical discharge notes. Include all diagnoses, medications, vitals, and significant findings.",
        "Generate a comprehensive discharge summary from the following clinical notes. Ensure all medical entities are captured.",
        "Create a detailed discharge summary including all diagnoses, treatments, medications, and follow-up instructions."
    ],
    "input": [
        """Patient is a 65-year-old male with history of hypertension and type 2 diabetes mellitus who presented to the ED with chest pain. 
        Vital signs on admission: BP 156/92, HR 88, RR 18, O2 sat 96% on RA, Temp 98.6F. 
        ECG showed ST elevations in leads II, III, aVF. Troponin elevated at 2.4 ng/mL. 
        Patient underwent emergent cardiac catheterization revealing 95% occlusion of RCA. 
        Drug-eluting stent placed successfully. Post-procedure course uncomplicated. 
        Started on aspirin 81mg daily, clopidogrel 75mg daily, atorvastatin 80mg daily, metoprolol 25mg BID. 
        Blood glucose controlled with insulin sliding scale. HbA1c 8.2%. 
        Discharge vitals: BP 128/78, HR 72, stable. Patient educated on medication compliance and lifestyle modifications.""",
        
        """72-year-old female admitted with acute exacerbation of COPD. 
        History significant for 50 pack-year smoking history, chronic bronchitis, hypertension. 
        Presented with increased dyspnea, productive cough with yellow sputum, wheezing. 
        Vitals: BP 142/88, HR 102, RR 24, O2 sat 88% on RA improved to 94% on 2L NC, Temp 100.8F. 
        Chest X-ray showed hyperinflation, no infiltrates. ABG: pH 7.32, pCO2 58, pO2 62, HCO3 28. 
        Treated with nebulized albuterol/ipratropium q4h, IV methylprednisolone 125mg q6h tapered to prednisone 40mg PO daily, 
        azithromycin 500mg x1 then 250mg daily x 4 days. Clinical improvement noted by day 3. 
        Discharge on prednisone taper, continue home inhalers: tiotropium, fluticasone/salmeterol. 
        Smoking cessation counseling provided. Follow-up with pulmonology in 2 weeks.""",
        
        """58-year-old male with no significant past medical history presented with sudden onset severe headache ("worst headache of my life"), 
        nausea, vomiting, photophobia. Neurological exam revealed nuchal rigidity, positive Kernig's sign. 
        Vitals: BP 168/95, HR 78, RR 16, Temp 99.2F, O2 sat 99% RA. 
        CT head non-contrast showed subarachnoid hemorrhage in basal cisterns. 
        Neurosurgery consulted. CT angiogram revealed 7mm anterior communicating artery aneurysm. 
        Patient underwent successful endovascular coiling on hospital day 2. 
        Post-procedure monitoring in ICU for vasospasm prevention with nimodipine 60mg q4h, 
        maintaining systolic BP 140-160. No neurological deficits noted. 
        Started on levetiracetam 500mg BID for seizure prophylaxis. Discharge on day 7 with outpatient neurosurgery follow-up."""
    ],
    "output": [
        """DISCHARGE SUMMARY:
        Primary Diagnosis: ST-Elevation Myocardial Infarction (STEMI) - Inferior wall
        Secondary Diagnoses: Hypertension, Type 2 Diabetes Mellitus
        
        Hospital Course: 65-year-old male admitted with chest pain and ECG changes consistent with inferior STEMI. 
        Elevated troponin (2.4 ng/mL). Emergency cardiac catheterization revealed 95% RCA occlusion, successfully treated with drug-eluting stent placement. 
        Post-procedure recovery uncomplicated.
        
        Vitals: Admission BP 156/92, discharge BP 128/78. HR improved from 88 to 72. Remained afebrile.
        
        Medications at Discharge:
        - Aspirin 81mg daily (antiplatelet)
        - Clopidogrel 75mg daily (antiplatelet, continue for 12 months minimum)
        - Atorvastatin 80mg daily (high-intensity statin)
        - Metoprolol 25mg twice daily (beta-blocker)
        - Continue home diabetes medications, insulin adjustments made
        
        Labs: HbA1c 8.2% - diabetes management needs optimization.
        
        Follow-up: Cardiology in 1 week, Primary care in 2 weeks for diabetes management.
        Patient counseled on medication adherence, cardiac rehabilitation, smoking cessation if applicable, diet modification.""",
        
        """DISCHARGE SUMMARY:
        Primary Diagnosis: Acute Exacerbation of Chronic Obstructive Pulmonary Disease (COPD)
        Secondary Diagnoses: Chronic bronchitis, Hypertension, Tobacco use disorder (50 pack-years)
        
        Hospital Course: 72-year-old female with COPD admitted for acute exacerbation with increased dyspnea, productive cough, hypoxemia. 
        Initial O2 saturation 88% on room air, improved to 94% on 2L nasal cannula. Chest X-ray showed hyperinflation without infiltrates. 
        ABG revealed respiratory acidosis (pH 7.32, pCO2 58).
        
        Treatment: Aggressive bronchodilator therapy (albuterol/ipratropium nebulizers q4h), systemic corticosteroids 
        (IV methylprednisolone 125mg q6h transitioned to oral prednisone 40mg daily), antibiotics (azithromycin 5-day course). 
        Clinical improvement by hospital day 3.
        
        Vital Signs: Admission - BP 142/88, HR 102, RR 24, Temp 100.8F. Improved to normal by discharge.
        
        Medications at Discharge:
        - Prednisone 40mg daily (taper: 40mg x 3 days, 20mg x 3 days, 10mg x 3 days, then stop)
        - Tiotropium (continue home LAMA inhaler)
        - Fluticasone/salmeterol (continue home ICS/LABA inhaler)
        - Continue home antihypertensive medications
        
        Follow-up: Pulmonology in 2 weeks. Smoking cessation counseling provided - patient receptive to quitting. 
        Consider pulmonary rehabilitation referral.""",
        
        """DISCHARGE SUMMARY:
        Primary Diagnosis: Subarachnoid Hemorrhage (SAH) secondary to ruptured anterior communicating artery aneurysm
        
        Hospital Course: 58-year-old male presented with sudden severe headache ("thunderclap"), nuchal rigidity, positive meningeal signs. 
        Non-contrast head CT confirmed subarachnoid hemorrhage. CT angiogram identified 7mm AComm artery aneurysm. 
        Successful endovascular coiling performed on hospital day 2 by neurosurgery.
        
        ICU Monitoring: Post-procedure vasospasm prevention protocol initiated with nimodipine 60mg q4h. 
        Blood pressure maintained in target range (systolic 140-160). No delayed cerebral ischemia or neurological deficits observed. 
        Serial neurological exams remained normal.
        
        Vital Signs: Admission BP 168/95 (controlled post-procedure), HR 78, afebrile throughout stay.
        
        Medications at Discharge:
        - Nimodipine 60mg every 4 hours (continue for 21 days post-hemorrhage for vasospasm prevention)
        - Levetiracetam 500mg twice daily (seizure prophylaxis)
        
        Follow-up: Neurosurgery clinic in 1 week. Repeat CT angiogram in 6 months to assess aneurysm coiling stability. 
        Patient and family educated on warning signs of rebleeding, vasospasm (new headache, confusion, focal deficits). 
        No driving for 6 months per neurosurgery recommendations. Excellent prognosis given successful intervention and no complications."""
    ]
}

# Convert to pandas DataFrame then to Hugging Face Dataset
df = pd.DataFrame(sample_data)
dataset = Dataset.from_pandas(df)

# Split into train and test sets (80/20 split)
# In production, you should have a separate validation set as well
dataset = dataset.train_test_split(test_size=0.2, seed=42)
train_dataset = dataset["train"]
test_dataset = dataset["test"]

print(f"✓ Sample dataset created!")
print(f"  Training samples: {len(train_dataset)}")
print(f"  Test samples: {len(test_dataset)}")
print(f"\nSample training example:")
print(f"  Instruction: {train_dataset[0]['instruction'][:100]}...")
print(f"  Input length: {len(train_dataset[0]['input'])} characters")
print(f"  Output length: {len(train_dataset[0]['output'])} characters")
print("\n⚠ Remember: Use your actual MIMIC dataset (Section 3A) for real training!")

## 4. Load Model with 4-bit Quantization (QLoRA)

QLoRA (Quantized LoRA) enables fine-tuning large models on consumer GPUs by:
1. Loading the base model in 4-bit precision (NormalFloat 4-bit)
2. Using double quantization to further reduce memory
3. Dequantizing weights on the fly to a 16-bit compute dtype (float16 or bfloat16) for the forward and backward passes
4. Training only LoRA adapter weights (a small fraction of total parameters)
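To make the memory argument concrete, here is a back-of-the-envelope sketch of weight storage at different precisions. It assumes a nominal 4 billion parameters and ignores quantization constants, activations, gradients, and optimizer state, so the numbers are rough lower bounds rather than measured usage.

```python
def weight_memory_gb(num_params: float, bits_per_param: int) -> float:
    """Approximate weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9

NUM_PARAMS = 4e9  # assumption: a nominal 4B-parameter model

for label, bits in [("float32", 32), ("float16/bfloat16", 16), ("4-bit NF4", 4)]:
    print(f"{label:>18}: {weight_memory_gb(NUM_PARAMS, bits):.1f} GB")
```

Under these assumptions the 4-bit weights need about a quarter of the 16-bit footprint, which is what brings a model of this size within reach of a single consumer GPU.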

In [None]:
# ============================================================================
# QUANTIZATION CONFIGURATION
# ============================================================================

# BitsAndBytesConfig controls how the model is quantized:
# - load_in_4bit: Enable 4-bit quantization (reduces weight memory by ~75% vs FP16, ~87.5% vs FP32)
# - bnb_4bit_quant_type="nf4": Use NormalFloat 4-bit (optimal for normally distributed weights)
# - bnb_4bit_use_double_quant: Apply quantization to quantization constants (extra memory savings)
# - bnb_4bit_compute_dtype: Data type for computations (float16 for speed, bfloat16 for stability)

compute_dtype = torch.float16  # Use float16 for faster computation

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype
)

print("✓ Quantization configuration created")
print(f"  Quantization type: NF4 (4-bit NormalFloat)")
print(f"  Double quantization: Enabled")
print(f"  Compute dtype: {compute_dtype}")

In [None]:
# ============================================================================
# LOAD TOKENIZER
# ============================================================================

# The tokenizer converts text to tokens (numbers) that the model can process
# Important configurations:
# - padding_side="right": Pad sequences on the right (standard for causal LM)
# - add_eos_token: Automatically add end-of-sequence token (important for Gemma)

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",  # Right padding is standard for causal language models
    add_eos_token=True,  # Ensure EOS token is added for proper sequence termination
)

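# Gemma tokenizers normally ship with a dedicated <pad> token, so keep it when present.
# Fall back to EOS only if no PAD token is defined; reusing EOS for padding can cause
# the real EOS to be masked out of the loss together with the padding.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token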

print("✓ Tokenizer loaded successfully")
print(f"  Vocabulary size: {len(tokenizer)}")
print(f"  EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
print(f"  PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")
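The choice of padding token matters because language-modeling collators commonly replace padding positions in the labels with `-100` so they are excluded from the loss. The sketch below imitates that masking with plain lists; the token IDs are made up for illustration, and `mask_padding` is a hypothetical stand-in for what a collator does.

```python
IGNORE_INDEX = -100  # label value that PyTorch's cross-entropy loss skips by default

def mask_padding(input_ids, pad_token_id):
    """Copy input_ids to labels, replacing padding positions with IGNORE_INDEX."""
    return [IGNORE_INDEX if tok == pad_token_id else tok for tok in input_ids]

# Hypothetical token IDs for illustration only: 0 = <pad>, 1 = <eos>
PAD_ID, EOS_ID = 0, 1
sequence = [42, 43, 44, EOS_ID, PAD_ID, PAD_ID]

# Case 1: padding uses its own token, so the EOS label survives
labels_distinct_pad = mask_padding(sequence, pad_token_id=PAD_ID)

# Case 2: padding reuses EOS, so the real EOS is masked along with the padding
padded_with_eos = [EOS_ID if tok == PAD_ID else tok for tok in sequence]
labels_pad_is_eos = mask_padding(padded_with_eos, pad_token_id=EOS_ID)

print(labels_distinct_pad)
print(labels_pad_is_eos)
```

In the second case the model never receives a training signal for emitting EOS, which can show up later as summaries that do not stop cleanly.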

In [None]:
# ============================================================================
# LOAD MODEL WITH QUANTIZATION
# ============================================================================

# Load the base Gemma model with 4-bit quantization
# This significantly reduces memory usage (4-bit vs 32-bit = 8x reduction)
# making it possible to fine-tune on consumer GPUs

print("Loading model... This may take a few minutes.")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,  # Apply 4-bit quantization
    device_map="auto",  # Automatically distribute model across available devices
    trust_remote_code=True,
    torch_dtype=compute_dtype,  # Use float16 for non-quantized layers
)

# Prepare model for k-bit training
# This function:
# 1. Freezes all base model weights (only LoRA adapters will be trained)
# 2. Enables gradient checkpointing to save memory
# 3. Prepares input embeddings for training
model = prepare_model_for_kbit_training(model)

# Enable gradient checkpointing for memory efficiency
# This trades compute for memory by recomputing activations during backward pass
model.config.use_cache = False  # Required for gradient checkpointing
model.gradient_checkpointing_enable()

print("✓ Model loaded successfully with 4-bit quantization")
print(f"  Model type: {model.config.model_type}")
print(f"  Number of parameters: {model.num_parameters() / 1e9:.2f}B")
print(f"  Device map: {model.hf_device_map}")

## 5. Configure LoRA Adapters

LoRA (Low-Rank Adaptation) works by adding small trainable matrices to specific layers of the frozen base model. This dramatically reduces the number of trainable parameters while maintaining performance.
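A short arithmetic sketch shows where the savings come from. For a frozen weight of shape `d_out x d_in`, LoRA trains two small matrices, `A` (`r x d_in`) and `B` (`d_out x r`), and scales their product by `lora_alpha / r`. The layer size below is an illustrative assumption, not a value read from the model.

```python
def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """Trainable parameters LoRA adds to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)

def lora_scaling(lora_alpha: float, r: int) -> float:
    """Factor applied to the low-rank update B @ A before it is added to the frozen output."""
    return lora_alpha / r

d_in, d_out = 2560, 2560  # assumption: illustrative layer size, not read from the model
r, alpha = 32, 64

full = d_in * d_out
added = lora_param_count(d_in, d_out, r)
print(f"Frozen weight: {full:,} parameters")
print(f"LoRA adapter:  {added:,} parameters ({100 * added / full:.2f}% of the layer)")
print(f"Update scaling (alpha / r): {lora_scaling(alpha, r)}")
```

Because the adapter grows linearly in `r` while the frozen weight grows with the product of its dimensions, the trainable fraction stays small even at fairly high ranks.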

In [None]:
# ============================================================================
# LORA CONFIGURATION
# ============================================================================

# LoRA configuration parameters:
# - r: Rank of the low-rank matrices (higher = more capacity but more parameters)
# - lora_alpha: Scaling factor (controls magnitude of LoRA updates)
# - target_modules: Which model layers to apply LoRA to
# - lora_dropout: Dropout for regularization
# - bias: Whether to train bias parameters ("none" is standard)
# - task_type: Type of task (CAUSAL_LM for text generation)

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",  # Don't train bias parameters
    task_type="CAUSAL_LM",  # Causal language modeling task
)

# Apply LoRA configuration to the model
model = get_peft_model(model, lora_config)

# Print trainable parameters
# This shows the massive parameter reduction achieved by LoRA
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
trainable_percent = 100 * trainable_params / total_params

print("✓ LoRA adapters configured and applied")
print(f"  Target modules: {TARGET_MODULES}")
print(f"  LoRA rank (r): {LORA_R}")
print(f"  LoRA alpha: {LORA_ALPHA}")
print(f"  LoRA dropout: {LORA_DROPOUT}")
print(f"\n  Trainable parameters: {trainable_params:,} ({trainable_percent:.2f}% of total)")
print(f"  Total parameters: {total_params:,}")
print(f"  Memory savings: Training only {trainable_percent:.2f}% of parameters!")

## 6. Prepare Training Data with Gemma Prompt Format

**Critical:** Gemma models use a specific prompt format with special tokens:
- `<start_of_turn>user`: Indicates user input
- `<end_of_turn>`: Marks end of turn
- `<start_of_turn>model`: Indicates model output

Using the correct format is essential for optimal performance.

In [None]:
# ============================================================================
# GEMMA PROMPT FORMATTING FUNCTION
# ============================================================================

def format_prompt_gemma(sample):
    """
    Format a training sample using Gemma's conversation template.
    
    Gemma uses a turn-based conversation format:
    <start_of_turn>user
    {instruction}

    Clinical Notes:
    {input}<end_of_turn>
    <start_of_turn>model
    {output}<end_of_turn>
    
    Args:
        sample: Dictionary containing 'instruction', 'input', and 'output' keys
    
    Returns:
        Dictionary with formatted 'text' field
    """
    instruction = sample["instruction"]
    input_text = sample["input"]
    output_text = sample["output"]
    
    # Construct the full prompt using Gemma's format
    # The user turn contains both the instruction and the clinical notes
    # The model turn contains the expected summary output
    full_prompt = f"""<start_of_turn>user
{instruction}

Clinical Notes:
{input_text}<end_of_turn>
<start_of_turn>model
{output_text}<end_of_turn>"""
    
    return {"text": full_prompt}

# Apply formatting to both train and test datasets
train_dataset = train_dataset.map(format_prompt_gemma)
test_dataset = test_dataset.map(format_prompt_gemma)

print("✓ Dataset formatted with Gemma prompt template")
print("\nExample formatted prompt (truncated):")
print("=" * 80)
print(train_dataset[0]["text"][:500])
print("...")
print("=" * 80)
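At inference time the prompt should stop right after the model turn opens, so the model writes the summary itself. The following is a minimal sketch that mirrors the training layout used in this notebook; `build_inference_prompt` is a hypothetical helper and the example text is made up.

```python
def build_inference_prompt(instruction: str, clinical_notes: str) -> str:
    """Build a Gemma-style prompt that ends where the model's turn begins."""
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n\n"
        "Clinical Notes:\n"
        f"{clinical_notes}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

prompt = build_inference_prompt(
    "Summarize the following clinical discharge notes.",
    "Patient admitted with chest pain. Troponin elevated.",
)
print(prompt)
```

Keeping the inference prompt character-for-character consistent with the training template avoids a train/inference mismatch in how the turns are delimited.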

## 7. Training Configuration and Trainer Setup

Configure the training process using Hugging Face's `TrainingArguments` and the specialized `SFTTrainer` from the TRL library.

In [None]:
# ============================================================================
# TRAINING ARGUMENTS
# ============================================================================

# TrainingArguments control all aspects of the training loop:
# Memory optimization:
#   - per_device_train_batch_size: Batch size per GPU (keep low for memory)
#   - gradient_accumulation_steps: Accumulate gradients over N steps (simulates larger batch)
#   - gradient_checkpointing: Trade compute for memory
#   - fp16: Use mixed precision training (faster + less memory)
# 
# Optimization:
#   - learning_rate: Step size for parameter updates
#   - weight_decay: L2 regularization
#   - warmup_steps: Gradually increase LR at start of training
#   - lr_scheduler_type: How to adjust LR during training
#   - optim: Optimizer choice (adamw_torch is standard)
#
# Logging and checkpointing:
#   - logging_steps: How often to log metrics
#   - save_steps: How often to save model checkpoints
#   - eval_strategy: When to run evaluation (named evaluation_strategy in older transformers releases)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    gradient_checkpointing=True,
    optim="adamw_torch",  # Standard AdamW optimizer
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,  # L2 regularization
    fp16=True,  # Mixed precision training (use bf16 if your GPU supports it)
    max_grad_norm=1.0,  # Gradient clipping to prevent exploding gradients
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="cosine",  # Cosine learning rate schedule
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,  # Keep only the 3 most recent checkpoints
    eval_strategy="steps",  # Use evaluation_strategy="steps" on older transformers releases
    eval_steps=SAVE_STEPS,
    do_eval=True,
    report_to="none",  # Disable wandb/tensorboard (can enable if you want tracking)
    push_to_hub=False,  # Don't push to Hugging Face Hub automatically
)

print("✓ Training arguments configured")
print(f"  Total training steps: ~{len(train_dataset) * NUM_EPOCHS // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Warmup steps: {WARMUP_STEPS}")
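The interaction between dataset size, accumulation, and warmup is easier to see with numbers. The sketch below approximates the number of optimizer steps and evaluates a linear-warmup, cosine-decay schedule at a few points. Both helpers are hypothetical, the dataset size of 1000 is made up, and the step count is an estimate rather than the trainer's exact bookkeeping.

```python
import math

def estimate_total_steps(num_samples, batch_size, grad_accum, epochs):
    """Approximate optimizer steps: batches per epoch, grouped by accumulation, times epochs."""
    batches_per_epoch = math.ceil(num_samples / batch_size)
    updates_per_epoch = max(math.ceil(batches_per_epoch / grad_accum), 1)
    return updates_per_epoch * epochs

def cosine_lr_with_warmup(step, base_lr, warmup_steps, total_steps):
    """Linear warmup to base_lr, then half-cosine decay to zero."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

# Hypothetical dataset of 1000 samples with this notebook's batch settings
total = estimate_total_steps(num_samples=1000, batch_size=2, grad_accum=4, epochs=3)
print(f"Estimated optimizer steps: {total}")
for step in (0, 50, 100, total // 2, total):
    print(f"step {step:>4}: lr = {cosine_lr_with_warmup(step, 2e-4, 100, total):.2e}")
```

With only a handful of samples the estimated step count falls far below 100, in which case training would end while the learning rate is still warming up; reducing the warmup for small experiments is worth considering.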

In [None]:
# ============================================================================
# CREATE SUPERVISED FINE-TUNING TRAINER
# ============================================================================

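# SFTTrainer (Supervised Fine-Tuning Trainer) from the TRL library wraps the
# Hugging Face Trainer for instruction fine-tuning. With the settings below it:
# - Tokenizes the "text" column and truncates to MAX_SEQ_LENGTH
# - Computes the loss over the full sequence (prompt + completion); masking the
#   prompt would require a completion-only data collator
# - Leaves sequence packing disabled (packing=False)
#
# The model was already wrapped by get_peft_model() in Section 5, so peft_config
# is not passed again here.
#
# NOTE: These keyword arguments follow older TRL releases. Newer releases rename
# `tokenizer` to `processing_class` and expect `dataset_text_field` and the
# sequence-length limit to be set on `SFTConfig`; adjust to your installed version.

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    dataset_text_field="text",  # Column containing formatted prompts
    max_seq_length=MAX_SEQ_LENGTH,
    tokenizer=tokenizer,
    args=training_args,
    packing=False,  # Set to True to pack multiple samples per sequence (more efficient but complex)
)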

print("✓ SFTTrainer initialized successfully")
print("\nTrainer is ready to begin fine-tuning!")

## 8. Fine-Tune the Model

Now we train the model. This process will:
1. Iterate through the training data for `NUM_EPOCHS` epochs
2. Update only the LoRA adapter weights (not the base model)
3. Log training metrics periodically
4. Save checkpoints for recovery and evaluation

**Note:** Training time depends on your GPU and dataset size. For the sample data, this should complete in a few minutes.

In [None]:
# ============================================================================
# START TRAINING
# ============================================================================

print("Starting fine-tuning...\n")
print("This will train for {} epochs with:".format(NUM_EPOCHS))
print(f"  - {len(train_dataset)} training samples")
print(f"  - Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
print(f"  - Learning rate: {LEARNING_RATE}")
print("\nMonitor the loss below. For good convergence, loss should decrease steadily.\n")
print("=" * 80)

# Train the model
# The trainer will handle:
# - Forward pass (compute predictions)
# - Loss computation (compare predictions to ground truth)
# - Backward pass (compute gradients)
# - Optimizer step (update LoRA weights)
# - Logging and checkpointing
training_output = trainer.train()

print("=" * 80)
print("\n✓ Training completed successfully!")
print(f"\nFinal training loss: {training_output.training_loss:.4f}")
print(f"Total training time: {training_output.metrics['train_runtime']:.2f} seconds")
print(f"Samples per second: {training_output.metrics['train_samples_per_second']:.2f}")

In [None]:
# ============================================================================
# SAVE THE FINE-TUNED MODEL
# ============================================================================

# Save the trained LoRA adapters
# Note: This saves only the adapter weights, not the full model. The adapter is
# small relative to the base model; its size depends on LoRA rank and target modules
# To use the model later, you'll load the base model + these adapters

output_dir_final = f"{OUTPUT_DIR}/final"
trainer.model.save_pretrained(output_dir_final)
tokenizer.save_pretrained(output_dir_final)

print(f"✓ Model saved to: {output_dir_final}")
print("\nThe saved files include:")
print("  - adapter_config.json: LoRA configuration")
print("  - adapter_model.safetensors: Trained LoRA weights (adapter_model.bin in older PEFT versions)")
print("  - tokenizer files")
print("\nTo load this model later, use:")
print(f"  model = AutoModelForCausalLM.from_pretrained('{MODEL_NAME}', ...)")
print(f"  model = PeftModel.from_pretrained(model, '{output_dir_final}')")

## 9. Clinical BERTScore Evaluation

**Why Clinical BERTScore?**

Traditional metrics like BLEU or ROUGE measure word overlap, which is insufficient for medical text where:
- Synonyms are common ("myocardial infarction" = "heart attack")
- Semantic equivalence matters more than exact wording
- Clinical accuracy is critical

**BERTScore** measures semantic similarity using contextual embeddings. By using **Bio_ClinicalBERT** (trained on clinical notes from MIMIC-III), we get embeddings that understand medical terminology and context.

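**Interpretation:**
- Precision: How much of the generated summary is supported by the reference?
- Recall: How much of the reference summary is captured? (Our primary metric for completeness)
- F1: Harmonic mean of precision and recall

Raw (unrescaled) scores are averaged cosine similarities and in practice fall between 0 and 1, with higher being better. Because contextual embeddings of even unrelated text are fairly similar, raw scores tend to cluster in a narrow, high range, so they are most useful for comparing models or checkpoints rather than as absolute percentages.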
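
To make these definitions concrete, here is a minimal pure-Python sketch of the greedy matching behind BERTScore. It assumes each token has already been mapped to an embedding vector; the 2-dimensional toy vectors and the function names are made up for illustration, and the real `bert_score` library additionally handles tokenization, batching, and optional IDF weighting.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def greedy_bertscore(cand_vecs, ref_vecs):
    """Toy BERTScore: greedy best-match similarity in both directions."""
    # Precision: match each candidate token to its most similar reference token.
    precision = sum(max(cosine(c, r) for r in ref_vecs) for c in cand_vecs) / len(cand_vecs)
    # Recall: match each reference token to its most similar candidate token.
    recall = sum(max(cosine(r, c) for c in cand_vecs) for r in ref_vecs) / len(ref_vecs)
    total = precision + recall
    f1 = 2 * precision * recall / total if total > 0 else 0.0
    return precision, recall, f1


# Toy embeddings: the reference holds two distinct "concepts", the candidate only one.
reference = [[1.0, 0.0], [0.0, 1.0]]
candidate = [[1.0, 0.0]]
p, r, f1 = greedy_bertscore(candidate, reference)
print(f"P={p:.2f} R={r:.2f} F1={f1:.2f}")
```

In this toy case the candidate covers only one of the two reference concepts, so precision stays high while recall drops. That is the pattern to watch for when a generated summary is accurate but incomplete.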

In [None]:
# ============================================================================
# INITIALIZE CLINICAL BERTSCORE
# ============================================================================

# Create a BERTScorer with Bio_ClinicalBERT as the backbone
# This model was trained on MIMIC-III clinical notes and understands medical language
#
# Important parameters:
# - model_type: The BERT model to use for embeddings
# - num_layers: Which layer's embeddings to use. bert_score has no tuned default
#   for Bio_ClinicalBERT; 9 is the library's setting for bert-base-uncased and is
#   borrowed here as a starting point for this 12-layer BERT-base model
# - rescale_with_baseline: Rescale scores using precomputed baseline statistics.
#   bert_score ships no baseline file for Bio_ClinicalBERT, so this stays False
#   unless you compute one yourself and pass it via baseline_path
# - lang: Language (en for English)
# - device: GPU if available, else CPU

print("Initializing Clinical BERTScore...")
print("This will download emilyalsentzer/Bio_ClinicalBERT if not cached.\n")

clinical_scorer = BERTScorer(
    model_type="emilyalsentzer/Bio_ClinicalBERT",
    num_layers=9,  # Borrowed from bert-base-uncased; not tuned for clinical text
    rescale_with_baseline=False,  # No packaged baseline exists for this model
    lang="en",
    device="cuda" if torch.cuda.is_available() else "cpu"
)

print("✓ Clinical BERTScore initialized")
print(f"  Model: emilyalsentzer/Bio_ClinicalBERT")
print(f"  Device: {clinical_scorer.device}")
print("\nThis model was trained on MIMIC-III clinical notes and understands:")
print("  - Medical terminology and abbreviations")
print("  - Clinical context and relationships")
print("  - Semantic equivalence in healthcare text")

In [None]:
# ============================================================================
# GENERATE PREDICTIONS ON TEST SET
# ============================================================================

print("Generating predictions on test set...\n")

# Put model in evaluation mode
model.eval()

predictions = []
references = []

# Generate predictions for each test example
for i, sample in enumerate(test_dataset):
    print(f"Generating summary {i+1}/{len(test_dataset)}...")
    
    # Extract the input (clinical notes)
    instruction = sample["instruction"]
    input_text = sample["input"]
    reference = sample["output"]
    
    # Format the prompt for inference (same format as training, but without the model's response)
    inference_prompt = f"""<start_of_turn>user
{instruction}

Clinical Notes:
{input_text}<end_of_turn>
<start_of_turn>model
"""
    
    # Tokenize the prompt
    inputs = tokenizer(inference_prompt, return_tensors="pt").to(model.device)
    
    # Generate the summary
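    # Generation parameters are tuned for high recall (detailed outputs):
    # - max_new_tokens: Allow long summaries
    # - temperature: Control randomness (lower = more deterministic)
    # - top_p: Nucleus sampling for diverse but coherent text
    # - do_sample: Enable sampling (vs greedy decoding)
    # Note: with sampling enabled, outputs and scores vary between runs. Call
    # transformers.set_seed(...) beforehand or use do_sample=False if the
    # evaluation needs to be reproducible.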
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            top_k=TOP_K,
            repetition_penalty=REPETITION_PENALTY,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens. generate() returns the prompt
    # followed by the completion, and skip_special_tokens=True can strip the
    # <start_of_turn> markers, so splitting the decoded text on them (or slicing
    # by the prompt's character length) does not isolate the summary reliably.
    prompt_length = inputs["input_ids"].shape[1]
    generated_summary = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()
    
    predictions.append(generated_summary)
    references.append(reference)
    
    print(f"  Generated {len(generated_summary)} characters\n")

print("✓ All predictions generated")
print(f"  Total predictions: {len(predictions)}")

In [None]:
# ============================================================================
# COMPUTE CLINICAL BERTSCORE
# ============================================================================

print("Computing Clinical BERTScore...\n")
print("This metric measures semantic similarity using Bio_ClinicalBERT embeddings.")
print("Unlike BLEU/ROUGE (word overlap), BERTScore captures:")
print("  - Semantic equivalence (synonyms, paraphrases)")
print("  - Clinical context and medical terminology")
print("  - Conceptual similarity beyond surface form\n")

# Compute BERTScore
# Returns three tensors: Precision, Recall, F1
P, R, F1 = clinical_scorer.score(
    cands=predictions,  # Generated summaries
    refs=references,    # Reference summaries
)

# Convert to numpy for easier manipulation
precision_scores = P.cpu().numpy()
recall_scores = R.cpu().numpy()
f1_scores = F1.cpu().numpy()

# Compute averages
avg_precision = np.mean(precision_scores)
avg_recall = np.mean(recall_scores)
avg_f1 = np.mean(f1_scores)

print("=" * 80)
print("CLINICAL BERTSCORE RESULTS (using Bio_ClinicalBERT)")
print("=" * 80)
print(f"\nAverage Precision: {avg_precision:.4f}")
print("  → Measures: How much of the generated summary is clinically relevant?")
print("  → Interpretation: Higher = fewer irrelevant or hallucinated details\n")

print(f"Average Recall: {avg_recall:.4f}")
print("  → Measures: How much of the reference summary is captured?")
print("  → Interpretation: Higher = more complete, captures more medical entities")
print("  → THIS IS YOUR PRIMARY METRIC FOR HIGH RECALL!\n")

print(f"Average F1: {avg_f1:.4f}")
print("  → Measures: Harmonic mean of precision and recall")
print("  → Interpretation: Balanced measure of overall quality\n")

print("=" * 80)
print("\nPer-sample scores:")
for i in range(len(predictions)):
    print(f"\nSample {i+1}:")
    print(f"  Precision: {precision_scores[i]:.4f}")
    print(f"  Recall: {recall_scores[i]:.4f}")
    print(f"  F1: {f1_scores[i]:.4f}")

print("\n" + "=" * 80)
print("\nWHY CLINICAL BERTSCORE?")
print("=" * 80)
print("""
Standard metrics like BLEU and ROUGE measure word overlap, which penalizes valid paraphrases in medical text:

Example:
  Reference: "Patient had myocardial infarction with ST elevations"
  Candidate: "Patient experienced heart attack with ST segment elevation"
  
  BLEU/ROUGE: Tends to score low (few shared words)
  Clinical BERTScore: Tends to score higher (similar medical meaning)

Bio_ClinicalBERT was initialized from BioBERT and further pre-trained on the
clinical notes in MIMIC-III (roughly 2 million notes), so its embeddings
reflect medical synonyms, abbreviations, and clinical context.

This makes BERTScore with Bio_ClinicalBERT a well-suited automatic metric for
clinical text generation tasks like discharge summarization.
""")
print("=" * 80)

## 10. Qualitative Analysis

Let's examine the actual generated summaries to qualitatively assess how well the model captures medical entities and details.
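
A simple complement to reading the outputs is to check whether specific expected terms appear in each generated summary. The sketch below is a rough illustration under clear assumptions: `entity_recall` is a hypothetical helper, the entity list is hand-written for the example, and case-insensitive substring matching will miss synonyms and abbreviations, so it does not replace clinical review.

```python
def entity_recall(summary, expected_entities):
    """Return (recall, missing) using case-insensitive substring matching."""
    text = summary.lower()
    # An entity counts as covered only if its exact string appears in the summary.
    missing = [entity for entity in expected_entities if entity.lower() not in text]
    if not expected_entities:
        return 1.0, missing
    recall = 1 - len(missing) / len(expected_entities)
    return recall, missing


# Hand-written example data (not produced by the model).
example_summary = "Asthma exacerbation treated with albuterol and IV methylprednisolone."
expected = ["asthma", "albuterol", "methylprednisolone", "prednisone 40mg"]

recall, missing = entity_recall(example_summary, expected)
print(f"Entity recall: {recall:.2f}; missing: {missing}")
```

Listing the missing entities per sample makes it easier to see which category (medications, vitals, labs) the model tends to drop.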

In [None]:
# ============================================================================
# DISPLAY PREDICTIONS VS REFERENCES
# ============================================================================

print("=" * 80)
print("QUALITATIVE ANALYSIS: Generated vs Reference Summaries")
print("=" * 80)

for i in range(len(predictions)):
    print(f"\n{'=' * 80}")
    print(f"EXAMPLE {i+1}")
    print(f"{'=' * 80}\n")
    
    print("INPUT (Clinical Notes):")
    print("-" * 80)
    print(test_dataset[i]["input"][:500] + "...\n")  # Show first 500 chars
    
    print("REFERENCE SUMMARY:")
    print("-" * 80)
    print(references[i])
    print()
    
    print("GENERATED SUMMARY:")
    print("-" * 80)
    print(predictions[i])
    print()
    
    print("SCORES:")
    print("-" * 80)
    print(f"Precision: {precision_scores[i]:.4f}")
    print(f"Recall: {recall_scores[i]:.4f}")
    print(f"F1: {f1_scores[i]:.4f}")
    
print(f"\n{'=' * 80}")
print("END OF QUALITATIVE ANALYSIS")
print(f"{'=' * 80}\n")

print("""
EVALUATION CHECKLIST FOR HIGH RECALL:
□ Are all diagnoses mentioned?
□ Are all medications listed with dosages?
□ Are vital signs included?
□ Are abnormal lab results reported?
□ Are procedures and treatments described?
□ Are follow-up instructions present?
□ Is the timeline/hospital course clear?

If any of these are missing, consider:
1. Training for more epochs
2. Increasing MAX_NEW_TOKENS for generation
3. Using more training data
4. Adjusting prompt engineering to emphasize completeness
""")

## 11. Inference Function for New Clinical Notes

This function allows you to generate summaries for new clinical notes using your fine-tuned model.
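
Two small steps in this function are easy to get wrong: building the prompt in the same turn format used for training, and separating newly generated tokens from the echoed prompt. The following is a minimal sketch of both, using plain Python lists in place of real token ids. `build_prompt` and `strip_prompt_tokens` are hypothetical helper names, and the integer ids are made up.

```python
def build_prompt(instruction, clinical_notes):
    """Format a single-turn prompt in the Gemma chat layout used in this notebook."""
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n\n"
        "Clinical Notes:\n"
        f"{clinical_notes}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


def strip_prompt_tokens(output_ids, prompt_length):
    """Keep only the tokens that follow the prompt in a decoder-only output."""
    return output_ids[prompt_length:]


prompt = build_prompt("Summarize the notes.", "45F with asthma exacerbation.")

# Made-up token ids standing in for a tokenized prompt and a generated continuation.
prompt_ids = [2, 105, 2364, 107]
output_ids = prompt_ids + [900, 901, 902]
new_ids = strip_prompt_tokens(output_ids, len(prompt_ids))
print(new_ids)
```

Slicing by token count avoids depending on whether special tokens survive decoding, which string-based splitting on the turn markers does.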

In [None]:
# ============================================================================
# INFERENCE FUNCTION
# ============================================================================

def generate_discharge_summary(
    clinical_notes: str,
    instruction: str = "Summarize the following clinical discharge notes. Include all diagnoses, medications, vitals, and significant findings.",
    max_new_tokens: int = MAX_NEW_TOKENS,
    temperature: float = TEMPERATURE,
    top_p: float = TOP_P,
    top_k: int = TOP_K,
    repetition_penalty: float = REPETITION_PENALTY,
) -> str:
    """
    Generate a discharge summary from clinical notes using the fine-tuned model.
    
    Args:
        clinical_notes: Raw clinical notes as a string
        instruction: Task instruction (default is optimized for completeness)
        max_new_tokens: Maximum length of generated summary
        temperature: Sampling temperature (higher = more creative)
        top_p: Nucleus sampling parameter
        top_k: Top-K sampling parameter
        repetition_penalty: Penalty for repeating tokens
    
    Returns:
        Generated discharge summary as a string
    """
    
    # Format the prompt using Gemma's template
    inference_prompt = f"""<start_of_turn>user
{instruction}

Clinical Notes:
{clinical_notes}<end_of_turn>
<start_of_turn>model
"""
    
    # Tokenize
    inputs = tokenizer(inference_prompt, return_tensors="pt").to(model.device)
    
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
    
    # Decode only the newly generated tokens (everything after the prompt).
    # skip_special_tokens=True can strip the <start_of_turn> markers, so splitting
    # the decoded text on them would not isolate the response reliably.
    prompt_length = inputs["input_ids"].shape[1]
    summary = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()
    
    return summary

print("✓ Inference function defined")
print("\nUsage example:")
print("  summary = generate_discharge_summary(clinical_notes)")
print("  print(summary)")

In [None]:
# ============================================================================
# TEST THE INFERENCE FUNCTION
# ============================================================================

# Example: Generate a summary for a new clinical note
sample_clinical_note = """45-year-old female with history of asthma presented to ED with acute dyspnea and wheezing.
Vitals: BP 118/76, HR 110, RR 28, O2 sat 89% on RA improved to 95% on 4L NC, Temp 98.4F.
Patient reports missed doses of controller inhaler. Exam notable for diffuse expiratory wheezes.
Peak flow 40% of predicted. Treated with continuous albuterol nebulizers, IV methylprednisolone 125mg,
and magnesium sulfate 2g IV. Clinical improvement noted within 2 hours. Transitioned to albuterol q4h.
Discharge on prednisone 40mg daily x 5 days, continue home fluticasone/salmeterol, albuterol PRN.
Follow-up with pulmonology in 1 week. Patient educated on importance of daily controller medication."""

print("Generating summary for new clinical note...\n")
print("=" * 80)

generated_summary = generate_discharge_summary(sample_clinical_note)

print("GENERATED DISCHARGE SUMMARY:")
print("=" * 80)
print(generated_summary)
print("=" * 80)

print("\n✓ Summary generated successfully!")
print("\nYou can now use this function to generate summaries for any new clinical notes.")

## 12. Saving and Loading Instructions

Important notes on how to save and reload your fine-tuned model for future use.
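
Before relying on a saved adapter in a new session, it can help to confirm that the directory holds the files PEFT is expected to write. This sketch assumes the file names `adapter_config.json` plus either `adapter_model.safetensors` or `adapter_model.bin` (the weights file name depends on the PEFT version). `check_adapter_dir` is a hypothetical helper, and the demo uses a temporary directory with placeholder files rather than a real checkpoint.

```python
import tempfile
from pathlib import Path


def check_adapter_dir(path):
    """Return a list of problems found; an empty list means the layout looks right."""
    directory = Path(path)
    problems = []
    if not (directory / "adapter_config.json").is_file():
        problems.append("missing adapter_config.json")
    # Accept either weights format, since the name depends on the PEFT version.
    weight_files = ["adapter_model.safetensors", "adapter_model.bin"]
    if not any((directory / name).is_file() for name in weight_files):
        problems.append("missing adapter weights file")
    return problems


# Demo on a throwaway directory with empty placeholder files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "adapter_config.json").write_text("{}")
    before = check_adapter_dir(tmp)
    (Path(tmp) / "adapter_model.safetensors").write_bytes(b"")
    after = check_adapter_dir(tmp)

print(before)
print(after)
```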

In [None]:
# ============================================================================
# HOW TO RELOAD YOUR FINE-TUNED MODEL
# ============================================================================

print("""
=============================================================================
SAVING AND LOADING YOUR FINE-TUNED MODEL
=============================================================================

Your fine-tuned model has been saved to: {}

This directory contains:
  - adapter_config.json: LoRA configuration
  - adapter_model.safetensors: Trained LoRA weights (adapter_model.bin in older PEFT versions)
  - Tokenizer files

TO RELOAD THE MODEL IN A NEW SESSION:
---------------------------------------------------------------------------

1. Install dependencies:
   pip install transformers peft bitsandbytes torch

2. Load the base model with quantization:
   
   from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
   from peft import PeftModel
   import torch
   
   # Quantization config
   bnb_config = BitsAndBytesConfig(
       load_in_4bit=True,
       bnb_4bit_quant_type="nf4",
       bnb_4bit_use_double_quant=True,
       bnb_4bit_compute_dtype=torch.float16
   )
   
   # Load base model
   base_model = AutoModelForCausalLM.from_pretrained(
       "{}",
       quantization_config=bnb_config,
       device_map="auto",
       trust_remote_code=True,
   )
   
   # Load LoRA adapters
   model = PeftModel.from_pretrained(base_model, "{}")
   
   # Load tokenizer
   tokenizer = AutoTokenizer.from_pretrained("{}")
   
3. Use the generate_discharge_summary() function defined above

ALTERNATIVE: Merge adapters into base model (for deployment)
---------------------------------------------------------------------------

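If you want a standalone model without separate adapter files:

   # Reload the base model in half precision (omit quantization_config and
   # pass torch_dtype=torch.bfloat16), load the adapters onto it with
   # PeftModel.from_pretrained(), then merge. Merging directly into 4-bit
   # quantized weights is lossy and may not save as a regular checkpoint.
   model = model.merge_and_unload()
   model.save_pretrained("./merged_model")
   tokenizer.save_pretrained("./merged_model")
   
   # This creates a single model with adapters merged in
   # Can be loaded like a regular Hugging Face model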

=============================================================================
""".format(output_dir_final, MODEL_NAME, output_dir_final, output_dir_final))

## 13. Next Steps and Improvements

Recommendations for improving model performance and achieving higher recall.

In [None]:
print("""
=============================================================================
NEXT STEPS FOR IMPROVING HIGH RECALL PERFORMANCE
=============================================================================

1. DATA IMPROVEMENTS:
   □ Collect more training examples (aim for 1000+ samples)
   □ Ensure reference summaries are comprehensive and capture all entities
   □ Add data augmentation (paraphrasing, entity variations)
   □ Balance dataset across different clinical scenarios

2. PROMPT ENGINEERING:
   □ Experiment with more explicit instructions:
     "List ALL diagnoses, medications, vitals, lab results, procedures..."
   □ Add structured output format in prompt:
     "Include sections: Diagnoses, Medications, Vitals, Labs, Procedures..."
   □ Provide few-shot examples in the prompt

3. HYPERPARAMETER TUNING:
   □ Increase MAX_NEW_TOKENS (current: {}) to allow longer summaries
   □ Lower temperature (current: {}) for more deterministic outputs
   □ Train for more epochs if not overfitting
   □ Increase LoRA rank (current: {}) for more capacity

4. TRAINING IMPROVEMENTS:
   □ Use larger batch size if memory allows
   □ Implement custom loss that penalizes missing entities
   □ Add entity extraction as auxiliary task during training
   □ Use curriculum learning (easy → hard examples)

5. POST-PROCESSING:
   □ Add entity extraction to verify all entities are present
   □ Implement retrieval-augmented generation (RAG) to ensure completeness
   □ Use template-based post-processing to enforce structure

6. EVALUATION:
   □ Create entity-level recall metrics (diagnoses, meds, vitals)
   □ Manual clinical review by domain experts
   □ Compare against baseline models (GPT-4, Claude, etc.)
   □ A/B testing with clinicians

7. ALTERNATIVE APPROACHES:
   □ Try extractive + abstractive hybrid approach
   □ Use larger model (7B or 13B parameters)
   □ Fine-tune specialized medical models (BioGPT, ClinicalGPT)
   □ Multi-stage generation (extract entities → generate summary)

8. DEPLOYMENT CONSIDERATIONS:
   □ Implement confidence scores for generated summaries
   □ Add human-in-the-loop review system
   □ Monitor for hallucinations and factual errors
   □ Ensure HIPAA compliance and data privacy

=============================================================================

CURRENT CONFIGURATION SUMMARY:
  Model: {}
  LoRA Rank: {}
  Training Epochs: {}
  Max Generation Length: {} tokens
  Temperature: {}
  
  Clinical BERTScore Results:
    - Precision: {:.4f}
    - Recall: {:.4f} ← PRIMARY METRIC
    - F1: {:.4f}

TARGET RECALL: Aim for ≥0.90 for production use

=============================================================================
""".format(
    MAX_NEW_TOKENS,
    TEMPERATURE,
    LORA_R,
    MODEL_NAME,
    LORA_R,
    NUM_EPOCHS,
    MAX_NEW_TOKENS,
    TEMPERATURE,
    avg_precision,
    avg_recall,
    avg_f1
))

---

## Summary

This notebook demonstrated:

1. ✅ **Environment setup** with all required libraries
2. ✅ **Model loading** with QLoRA (4-bit quantization)
3. ✅ **LoRA configuration** for efficient fine-tuning
4. ✅ **Data formatting** with Gemma prompt template
5. ✅ **Training** with SFTTrainer
6. ✅ **Clinical BERTScore evaluation** using Bio_ClinicalBERT
7. ✅ **Inference function** for generating new summaries
8. ✅ **Comprehensive documentation** for your project report

**Key Takeaways:**
- QLoRA enables fine-tuning large models on consumer GPUs
- Clinical BERTScore is better suited to medical text than word-overlap metrics such as BLEU and ROUGE
- High recall requires careful prompt engineering and sufficient training data
- The fine-tuned model can be easily saved and reloaded