# Clinical Discharge Summarization using MedGemma 4B with QLoRA

**Project Overview:**
This notebook demonstrates Parameter-Efficient Fine-Tuning (PEFT) using QLoRA on the MedGemma 4B model for clinical discharge summarization. The objective is to achieve **high recall** - generating detailed, verbose summaries that capture all medical entities (diagnoses, medications, vitals, abnormal lab results) from source clinical notes.

**Key Technologies:**
- Model: google/medgemma-4b-it (Gemma 3 based; a general-purpose Gemma 3 4B checkpoint can be substituted)
- Technique: QLoRA (4-bit quantization)
- Evaluation: Clinical BERTScore using Bio_ClinicalBERT
- Platform: Google Colab / Consumer GPUs

## 1. Environment Setup

First, we install all necessary libraries for model loading, quantization, fine-tuning, and evaluation.

In [1]:
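# Install required libraries
# NOTE: the commands below sit inside a string literal, so they do not execute as written.
# Remove the surrounding triple quotes to actually run the installs.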
"""
!pip install -q -U transformers
!pip install -q -U peft
!pip install -q -U bitsandbytes
!pip install -q -U trl
!pip install -q -U accelerate
!pip install -q -U datasets
!pip install -q -U bert_score
!pip install -q -U scipy
!pip install -q -U hf-xet
# Required for Gemma model architecture
!pip install -q -U einops
!pip install -q torch==2.9.1+cu130 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu130
"""
print("✓ All libraries installed successfully!")

✓ All libraries installed successfully!


In [2]:
import warnings

import pandas as pd
import torch
from bert_score import BERTScorer
from datasets import Dataset
from peft import (
    LoraConfig,
    prepare_model_for_kbit_training
)
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    BitsAndBytesConfig
)
from trl import SFTTrainer

warnings.filterwarnings('ignore')

# Check GPU availability
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

PyTorch version: 2.9.0+cu126
CUDA available: True
GPU: NVIDIA A100-SXM4-80GB
GPU Memory: 85.17 GB


## 2. Configuration and Hyperparameters

Define all model paths, LoRA parameters, and training hyperparameters in one place for easy modification.

In [3]:
# ============================================================================
# MODEL CONFIGURATION
# ============================================================================

# NOTE: If the MedGemma checkpoint is not available to you, a Gemma 3 model such as "google/gemma-3-4b-it" is the closest general-purpose substitute
MODEL_NAME = "google/medgemma-4b-it"  # Instruction-tuned MedGemma 4B

# ============================================================================
# LORA CONFIGURATION
# ============================================================================
# These parameters control the LoRA adapter architecture:
# - r (rank): The dimensionality of the low-rank matrices. Higher = more parameters = better fit but more memory
# - lora_alpha: Scaling factor for LoRA updates. The adapter output is multiplied by lora_alpha / r, so a higher alpha amplifies the update
# - lora_dropout: Dropout probability for LoRA layers to prevent overfitting

LORA_R = 32  # Rank of 32 provides good balance between performance and memory
LORA_ALPHA = 64  # Alpha = 2*r is a common heuristic
LORA_DROPOUT = 0.05  # Small dropout for regularization

# Target modules for Gemma architecture
# These are the attention and MLP projection layers where LoRA adapters will be inserted
# Gemma uses a standard transformer architecture with:
# - q_proj, k_proj, v_proj: Query, Key, Value projections in attention
# - o_proj: Output projection after attention
# - gate_proj, up_proj, down_proj: MLP layers (Gemma uses a gated GELU feed-forward block, i.e. GeGLU)
TARGET_MODULES = [
    "q_proj",
    "k_proj",
    "v_proj",
    "o_proj",
    "gate_proj",
    "up_proj",
    "down_proj"
]

# ============================================================================
# TRAINING HYPERPARAMETERS
# ============================================================================
# Changed to 1 by Bryan: we want to avoid memorization as much as possible
NUM_EPOCHS = 1
BATCH_SIZE = 1  # Per device batch size (increase if you have more VRAM)
GRADIENT_ACCUMULATION_STEPS = 8  # Effective batch size = 8
LEARNING_RATE = 2e-4  # Standard learning rate for LoRA fine-tuning
MAX_SEQ_LENGTH = 2048  # Maximum sequence length (Gemma 3 supports far longer contexts, but we use 2048 for memory efficiency)
WARMUP_STEPS = 100  # Warmup steps for learning rate scheduler
LOGGING_STEPS = 10  # Log training metrics every N steps
SAVE_STEPS = 100  # Save checkpoint every N steps

# ============================================================================
# GENERATION PARAMETERS
# ============================================================================
MAX_NEW_TOKENS = 512  # Allow longer summaries to capture all details
TEMPERATURE = 0.7  # Moderate temperature for balance between creativity and coherence
TOP_P = 0.9  # Nucleus sampling for diverse but relevant outputs
TOP_K = 50  # Top-K sampling
REPETITION_PENALTY = 1.1  # Slight penalty to avoid repetitive text

print("✓ Configuration loaded successfully!")
print(f"  Model: {MODEL_NAME}")
print(f"  LoRA Rank: {LORA_R}, Alpha: {LORA_ALPHA}")
print(
    f"  Training: {NUM_EPOCHS} epochs, Batch Size: {BATCH_SIZE}, Gradient Accumulation: {GRADIENT_ACCUMULATION_STEPS}")
print(f"  Effective Batch Size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")

✓ Configuration loaded successfully!
  Model: google/medgemma-4b-it
  LoRA Rank: 32, Alpha: 64
  Training: 1 epochs, Batch Size: 1, Gradient Accumulation: 8
  Effective Batch Size: 8


## 2A. Google Colab Setup (OPTIONAL - Only for Colab Users)

**Use this section ONLY if you're running this notebook on Google Colab**

This section will:
1. Mount your Google Drive
2. Set the path to your dataset in Google Drive
3. Verify GPU availability

**Instructions:**
- If running on **Google Colab**, run the cells below
- If running **locally**, skip this entire section and go directly to Section 3

In [5]:
# ============================================================================
# MOUNT GOOGLE DRIVE (COLAB ONLY)
# ============================================================================

# This cell will mount your Google Drive to access your dataset
# You'll be prompted to authorize access to your Google Drive

try:
    from google.colab import drive

    # Mount Google Drive at /content/drive
    drive.mount('/content/drive')
    OUTPUT_DIR = "/content/drive/MyDrive/Colab Notebooks/Gen AI/Semester Project/medgemma-discharge-summarization/"

    print("✓ Google Drive mounted successfully!")
    print("  Your Drive is accessible at: /content/drive/MyDrive/")
    print("\nYou can now access files from your Google Drive.")

    IS_COLAB = True

except ImportError:
    print("⚠ Not running on Google Colab - skipping Drive mount")
    print("  If you're running locally, this is expected. Skip to Section 3.")
    IS_COLAB = False
    OUTPUT_DIR = "./medgemma-discharge-summarization"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
✓ Google Drive mounted successfully!
  Your Drive is accessible at: /content/drive/MyDrive/

You can now access files from your Google Drive.


## 3. Load and Prepare Dataset

This section provides **TWO** options for loading data:

### **Option A (Colab): Load from Google Drive**
- **Use if**: Running on Google Colab with dataset in Google Drive
- **Section**: 3A (Colab) below

### **Option B (Local): Load from Local File**
- **Use if**: Running locally with `mimic_cleaned_text_only.csv` in project directory
- **Section**: 3A (Local) below

**Instructions:**
- **Colab users**: Run Section 2A first, then use Section 3A (Colab)
- **Local users**: Skip Section 2A, use Section 3A (Local)

## 3A (Colab). Load Dataset from Google Drive

**Use this section if you're running on Google Colab and have your dataset in Google Drive**

This will load your MIMIC dataset directly from your Google Drive.

**Setup Instructions:**
1. Upload `mimic_cleaned_text_only.csv` to your Google Drive
2. Update the `DRIVE_DATASET_PATH` below with the correct path
3. Common paths:
   - `"/content/drive/MyDrive/mimic_cleaned_text_only.csv"` (root of My Drive)

In [6]:
# ============================================================================
# VERIFY GPU AND SETUP (COLAB ONLY)
# ============================================================================

# Check if running on Colab and verify GPU setup
if IS_COLAB:
    import subprocess

    print("Checking GPU availability on Colab...\n")

    # Run nvidia-smi to check GPU
    try:
        gpu_info = subprocess.check_output(['nvidia-smi'], encoding='utf-8')
        print(gpu_info)
        print("GPU is available!")
        print("\nIMPORTANT: Make sure you're using a GPU runtime:")
        print("Runtime → Change runtime type → Hardware accelerator → GPU (T4 or better recommended)")
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("No GPU detected!")
        print("\nYou MUST enable GPU for this notebook:")
        print("  1. Go to Runtime → Change runtime type")
        print("  2. Set Hardware accelerator to 'GPU'")
        print("  3. Click Save")
        print("  4. Restart the runtime")

    # Check RAM
    import psutil

    ram_gb = psutil.virtual_memory().total / 1e9
    print(f"\nTotal RAM: {ram_gb:.2f} GB")

    if ram_gb < 12:
        print("WARNING: Low RAM detected. Consider using Colab Pro for High-RAM runtime.")
else:
    print("Skipping Colab-specific checks (running locally)")

Checking GPU availability on Colab...

Tue Dec  9 22:18:23 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   31C    P0             51W /  400W |       5MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
         

In [7]:
# ============================================================================
# LOAD MIMIC DATASET FROM GOOGLE DRIVE (COLAB ONLY)
# ============================================================================

# IMPORTANT: Only run this cell if you're on Google Colab
# Update the path below to match where you uploaded your dataset in Google Drive

if IS_COLAB:
    import os

    # ========================================================================
    # CONFIGURE THIS PATH TO MATCH YOUR GOOGLE DRIVE STRUCTURE
    # ========================================================================
    # Update this to the actual path where you uploaded your CSV file in Google Drive
    DRIVE_DATASET_PATH = "/content/drive/MyDrive/Colab Notebooks/Gen AI/Semester Project/mimic_cleaned_text_only.csv"

    # Check if file exists
    if os.path.exists(DRIVE_DATASET_PATH):
        print(f"Loading MIMIC dataset from Google Drive...")
        print(f"Path: {DRIVE_DATASET_PATH}\n")

        # Load the CSV file using pandas
        mimic_df = pd.read_csv(DRIVE_DATASET_PATH)
        mimic_df = mimic_df[:40_000]

        print(f" Dataset loaded successfully from Google Drive!")
        print(f"  Total samples: {len(mimic_df)}")
        print(f"  Columns: {list(mimic_df.columns)}\n")

        # Display basic statistics
        print("Dataset Statistics:")
        print(f"  Average input length: {mimic_df['final_input'].str.len().mean():.0f} characters")
        print(f"  Average target length: {mimic_df['final_target'].str.len().mean():.0f} characters")
        print(f"  Minimum input length: {mimic_df['final_input'].str.len().min():.0f} characters")
        print(f"  Maximum input length: {mimic_df['final_input'].str.len().max():.0f} characters")

        # Add instruction column emphasizing HIGH RECALL
        instruction_text = "Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ensure complete coverage of all medical entities."
        mimic_df['instruction'] = instruction_text

        # Rename columns to match expected format
        mimic_df = mimic_df.rename(columns={
            'final_input': 'input',
            'final_target': 'output'
        })

        # Remove rows with missing data
        initial_count = len(mimic_df)
        mimic_df = mimic_df.dropna(subset=['input', 'output'])
        dropped_count = initial_count - len(mimic_df)

        if dropped_count > 0:
            print(f"\n Removed {dropped_count} rows with missing data")

        # Convert to Hugging Face Dataset
        dataset = Dataset.from_pandas(mimic_df[['instruction', 'input', 'output']])

        # Split into train and test sets (95/5 split)
        dataset = dataset.train_test_split(test_size=0.05, seed=42)
        train_dataset = dataset["train"]
        test_dataset = dataset["test"]

        print(f"\n Dataset prepared and split!")
        print(f"  Training samples: {len(train_dataset)}")
        print(f"  Test samples: {len(test_dataset)}")

        # Display a sample
        print(f"\n{'=' * 80}")
        print("SAMPLE TRAINING EXAMPLE:")
        print(f"{'=' * 80}\n")
        print(f"Instruction: {train_dataset[0]['instruction'][:150]}...")
        print(f"\nInput (first 300 chars):\n{train_dataset[0]['input'][:300]}...")
        print(f"\nOutput (first 300 chars):\n{train_dataset[0]['output'][:300]}...")
        print(f"\n{'=' * 80}")

        print("\n MIMIC dataset loaded from Google Drive!")
        print("  You can now skip sections 3A (Local) and 3B (Sample Data)")
        print("  Proceed to Section 4 (Load Model with 4-bit Quantization)")

    else:
        print(f"File not found at: {DRIVE_DATASET_PATH}")
        print(f"\nPlease check:")
        print(f"  1. Is the file uploaded to your Google Drive?")
        print(f"  2. Is the path correct?")
        print(f"  3. Did you mount Google Drive (run Section 2A)?")
        print(f"\nTo find the correct path:")
        print(f"  1. In the left sidebar, click the folder icon")
        print(f"  2. Navigate to drive/MyDrive/")
        print(f"  3. Find your CSV file")
        print(f"  4. Right-click → Copy path")
        print(f"  5. Update DRIVE_DATASET_PATH above")

else:
    print("Not running on Google Colab - skipping Google Drive dataset loading")
    print("Use Section 3A (Local) or 3B (Sample Data) instead")

Loading MIMIC dataset from Google Drive...
Path: /content/drive/MyDrive/Colab Notebooks/Gen AI/Semester Project/mimic_cleaned_text_only.csv

 Dataset loaded successfully from Google Drive!
  Total samples: 40000
  Columns: ['final_input', 'final_target']

Dataset Statistics:
  Average input length: 2580 characters
  Average target length: 1207 characters
  Minimum input length: 237 characters
  Maximum input length: 14963 characters

 Dataset prepared and split!
  Training samples: 38000
  Test samples: 2000

SAMPLE TRAINING EXAMPLE:

Instruction: Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ens...

Input (first 300 chars):
summarize chief complaint respiratory distress hypotension acute elevation cardiac biomarkers history present illness yo man pmh significant severe copd home htn initially presented dyspnearespiratory distress patient called neighbor called em em found tripod positi

## 3A (Local). Load Dataset from Local File

**Use this section if you have the `mimic_cleaned_text_only.csv` file**

This loads your actual MIMIC clinical discharge dataset with the correct column mappings:
- `final_input` → clinical notes
- `final_target` → reference summaries

In [21]:
import os

# Path to your MIMIC dataset CSV file
# Adjust this path if your file is located elsewhere
MIMIC_CSV_PATH = "mimic_cleaned_text_only.csv"

# Check if the file exists
if os.path.exists(MIMIC_CSV_PATH):
    print(f"Loading MIMIC dataset from: {MIMIC_CSV_PATH}\n")

    # Load the CSV file using pandas
    # The file should have two columns: final_input and final_target
    mimic_df = pd.read_csv(MIMIC_CSV_PATH)

    mimic_df = mimic_df[:20_000]

    print(f"✓ Dataset loaded successfully!")
    print(f"  Total samples: {len(mimic_df)}")
    print(f"  Columns: {list(mimic_df.columns)}\n")

    # Display basic statistics
    print("Dataset Statistics:")
    print(f"  Average input length: {mimic_df['final_input'].str.len().mean():.0f} characters")
    print(f"  Average target length: {mimic_df['final_target'].str.len().mean():.0f} characters")
    print(f"  Minimum input length: {mimic_df['final_input'].str.len().min():.0f} characters")
    print(f"  Maximum input length: {mimic_df['final_input'].str.len().max():.0f} characters")

    # Add a consistent instruction column
    # This instruction emphasizes HIGH RECALL - capturing all medical details
    instruction_text = "Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ensure complete coverage of all medical entities."
    mimic_df['instruction'] = instruction_text

    # Rename columns to match the expected format
    # final_input → input (clinical notes)
    # final_target → output (reference summary)
    mimic_df = mimic_df.rename(columns={
        'final_input': 'input',
        'final_target': 'output'
    })

    # Remove any rows with missing data
    initial_count = len(mimic_df)
    mimic_df = mimic_df.dropna(subset=['input', 'output'])
    dropped_count = initial_count - len(mimic_df)

    if dropped_count > 0:
        print(f"\nRemoved {dropped_count} rows with missing data")

    # Convert to Hugging Face Dataset
    dataset = Dataset.from_pandas(mimic_df[['instruction', 'input', 'output']])

    # Split into train and test sets
    dataset = dataset.train_test_split(test_size=0.05, seed=42)
    train_dataset = dataset["train"]
    test_dataset = dataset["test"]

    print(f"\n✓ Dataset prepared and split!")
    print(f"  Training samples: {len(train_dataset)}")
    print(f"  Test samples: {len(test_dataset)}")

    # Display a sample from the training set
    print(f"\n{'=' * 80}")
    print("SAMPLE TRAINING EXAMPLE:")
    print(f"{'=' * 80}\n")
    print(f"Instruction: {train_dataset[0]['instruction'][:150]}...")
    print(f"\nInput (first 300 chars):\n{train_dataset[0]['input'][:300]}...")
    print(f"\nOutput (first 300 chars):\n{train_dataset[0]['output'][:300]}...")
    print(f"\n{'=' * 80}")

    print("\n✓ MIMIC dataset loaded successfully! You can now skip Section 3B.")
    print("  Proceed to Section 4 (Load Model with 4-bit Quantization)")

else:
    print(f"⚠ File not found: {MIMIC_CSV_PATH}")
    print(f"\nPlease either:")
    print(f"  1. Place the mimic_cleaned_text_only.csv file in the current directory")
    print(f"  2. Update MIMIC_CSV_PATH variable with the correct file path")
    print(f"  3. Skip to Section 3B to use sample data instead\n")
    print(f"Current directory: {os.getcwd()}")

⚠ File not found: mimic_cleaned_text_only.csv

Please either:
  1. Place the mimic_cleaned_text_only.csv file in the current directory
  2. Update MIMIC_CSV_PATH variable with the correct file path
  3. Skip to Section 3B to use sample data instead

Current directory: /content


## 4. Load Model with 4-bit Quantization (QLoRA)

QLoRA (Quantized LoRA) enables fine-tuning large models on consumer GPUs by:
1. Loading the base model in 4-bit precision (NormalFloat 4-bit)
2. Using double quantization to further reduce memory
3. Dequantizing weights on the fly to a 16-bit compute dtype (float16 in this notebook) for the forward and backward passes
4. Training only LoRA adapter weights (a small fraction of total parameters)
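
To make the memory argument concrete, here is a back-of-the-envelope sketch. It assumes the block sizes described in the QLoRA paper (64 weights per first-level block, 256 constants per second-level block) and reuses the 4.30B parameter count that this notebook reports after loading. The helper names `quant_bits_per_param` and `weights_gb` are hypothetical, and the estimate ignores activations, optimizer state, and any layers that stay unquantized.

```python
def quant_bits_per_param(double_quant: bool, block: int = 64, block2: int = 256) -> float:
    """Approximate storage cost per weight for 4-bit blockwise quantization."""
    weight_bits = 4.0
    if double_quant:
        # 8-bit constant per block, plus one fp32 constant per group of block2 constants
        overhead = 8 / block + 32 / (block * block2)
    else:
        # one fp32 scaling constant per block
        overhead = 32 / block
    return weight_bits + overhead


N_PARAMS = 4.30e9  # assumption: parameter count printed when the model is loaded


def weights_gb(bits_per_param: float, n_params: float = N_PARAMS) -> float:
    """Weight storage in gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


fp16_gb = weights_gb(16)
nf4_gb = weights_gb(quant_bits_per_param(double_quant=False))
nf4_dq_gb = weights_gb(quant_bits_per_param(double_quant=True))

print(f"float16 weights:          {fp16_gb:.2f} GB")
print(f"NF4 weights:              {nf4_gb:.2f} GB")
print(f"NF4 + double quantization: {nf4_dq_gb:.2f} GB")
```

Under these assumptions double quantization trims the per-weight overhead from 0.5 bits to roughly 0.127 bits, which is a small but free saving on top of the 4x reduction from 16-bit to 4-bit storage.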

In [8]:
compute_dtype = torch.float16  # Dtype for matmuls on dequantized weights; torch.bfloat16 would match the bf16=True trainer setting used later

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=compute_dtype
)

print("✓ Quantization configuration created")
print(f"  Quantization type: NF4 (4-bit NormalFloat)")
print(f"  Double quantization: Enabled")
print(f"  Compute dtype: {compute_dtype}")

✓ Quantization configuration created
  Quantization type: NF4 (4-bit NormalFloat)
  Double quantization: Enabled
  Compute dtype: torch.float16


In [9]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [10]:
# Load the tokenizer.
# add_eos_token=True appends the end-of-sequence token so the model learns where a summary stops.

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    padding_side="right",  # Right padding is standard for causal language models
    add_eos_token=True,  # Ensure EOS token is added for proper sequence termination
)

# Use the EOS token for padding.
# (Gemma tokenizers usually ship with a dedicated <pad> token, so this override is a choice rather than a requirement.)
tokenizer.pad_token = tokenizer.eos_token

print("✓ Tokenizer loaded successfully")
print(f"  Vocabulary size: {len(tokenizer)}")
print(f"  EOS token: {tokenizer.eos_token} (ID: {tokenizer.eos_token_id})")
print(f"  PAD token: {tokenizer.pad_token} (ID: {tokenizer.pad_token_id})")

✓ Tokenizer loaded successfully
  Vocabulary size: 262145
  EOS token: <eos> (ID: 1)
  PAD token: <eos> (ID: 1)


In [11]:
# ============================================================================
# LOAD MODEL WITH QUANTIZATION
# ============================================================================

print("Loading model... This may take a few minutes.")

model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,  # Apply 4-bit quantization
    device_map="auto",  # Automatically distribute model across available devices
    trust_remote_code=True,
    dtype=compute_dtype,  # Use float16 for non-quantized layers
)

# Prepare model for k-bit training
# This function:
# 1. Freezes all base model weights
# 2. Upcasts the remaining non-quantized parameters (e.g. norm layers) to float32 for stability
# 3. Enables gradient checkpointing and input gradients so the adapters can be trained
model = prepare_model_for_kbit_training(model)

# Enable gradient checkpointing for memory efficiency
# This trades compute for memory by recomputing activations during backward pass
model.config.use_cache = False  # Required for gradient checkpointing
model.gradient_checkpointing_enable()

print("  Model loaded successfully with 4-bit quantization")
print(f"  Model type: {model.config.model_type}")
print(f"  Number of parameters: {model.num_parameters() / 1e9:.2f}B")
print(f"  Device map: {model.hf_device_map}")

Loading model... This may take a few minutes.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  Model loaded successfully with 4-bit quantization
  Model type: gemma3
  Number of parameters: 4.30B
  Device map: {'': 0}


## 5. Configure LoRA Adapters

LoRA (Low-Rank Adaptation) works by adding small trainable matrices to specific layers of the frozen base model. This dramatically reduces the number of trainable parameters while maintaining performance.
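
For a frozen weight matrix of shape `d_out x d_in`, LoRA learns two small matrices `B` (`d_out x r`) and `A` (`r x d_in`) and applies `W + (lora_alpha / r) * B @ A`, so each adapted layer adds `r * (d_in + d_out)` trainable weights. The sketch below counts those weights for rank 32. The layer counts and shapes are assumptions about the MedGemma 4B checkpoint (a Gemma 3 language model plus a SigLIP-style vision encoder), not values read from this notebook; `lora_params` is a hypothetical helper.

```python
LORA_R = 32  # same rank as the notebook's configuration


def lora_params(d_in: int, d_out: int, r: int = LORA_R) -> int:
    """Trainable weights added to one linear layer: A is (r x d_in), B is (d_out x r)."""
    return r * (d_in + d_out)


# ASSUMED (in_features, out_features) for each targeted module in the language model
TEXT_LAYERS = 34
TEXT_SHAPES = {
    "q_proj": (2560, 2048),
    "k_proj": (2560, 1024),
    "v_proj": (2560, 1024),
    "o_proj": (2048, 2560),
    "gate_proj": (2560, 10240),
    "up_proj": (2560, 10240),
    "down_proj": (10240, 2560),
}

# ASSUMED shapes for the vision encoder; only these three names overlap with TARGET_MODULES
VISION_LAYERS = 27
VISION_SHAPES = {
    "q_proj": (1152, 1152),
    "k_proj": (1152, 1152),
    "v_proj": (1152, 1152),
}

text_total = TEXT_LAYERS * sum(lora_params(*shape) for shape in TEXT_SHAPES.values())
vision_total = VISION_LAYERS * sum(lora_params(*shape) for shape in VISION_SHAPES.values())
total = text_total + vision_total

print(f"Language model adapters: {text_total:,}")
print(f"Vision encoder adapters: {vision_total:,}")
print(f"Total trainable:         {total:,}")
```

With these assumed shapes the total comes to 65,576,960, the same trainable-parameter figure the trainer setup prints later in this notebook. If the assumption holds, about 6M of those weights sit in the vision encoder, which a text-only summarization task never exercises; restricting the target modules to the language model would be one way to avoid them.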

In [12]:
# ============================================================================
# LORA CONFIGURATION
# ============================================================================

# LoRA configuration parameters:
# - r: Rank of the low-rank matrices (higher = more capacity but more parameters)
# - lora_alpha: Scaling factor (controls magnitude of LoRA updates)
# - target_modules: Which model layers to apply LoRA to
# - lora_dropout: Dropout for regularization
# - bias: Whether to train bias parameters ("none" is standard)
# - task_type: Type of task (CAUSAL_LM for text generation)

lora_config = LoraConfig(
    r=LORA_R,
    lora_alpha=LORA_ALPHA,
    target_modules=TARGET_MODULES,
    lora_dropout=LORA_DROPOUT,
    bias="none",  # Don't train bias parameters
    task_type="CAUSAL_LM",  # Causal language modeling task
)

# NOTE: We will NOT apply LoRA here with get_peft_model()
# Instead, we'll pass lora_config to SFTTrainer, which will handle it
# This is required for newer versions of TRL

print("  LoRA configuration created")
print(f"  Target modules: {TARGET_MODULES}")
print(f"  LoRA rank (r): {LORA_R}")
print(f"  LoRA alpha: {LORA_ALPHA}")
print(f"  LoRA dropout: {LORA_DROPOUT}")

  LoRA configuration created
  Target modules: ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
  LoRA rank (r): 32
  LoRA alpha: 64
  LoRA dropout: 0.05


## 6. Prepare Training Data with Gemma Prompt Format

**Critical:** Gemma models use a specific prompt format with special tokens:
- `<start_of_turn>user`: Indicates user input
- `<end_of_turn>`: Marks end of turn
- `<start_of_turn>model`: Indicates model output

Using the correct format is essential for optimal performance.
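
At inference time the same template applies, except that the prompt stops right after the opening of the model turn so the model writes the summary itself. A minimal sketch, assuming the training layout used in this notebook (instruction, a blank line, a `Clinical Notes:` label, then the notes); `build_inference_prompt` is a hypothetical helper, not part of any library.

```python
def build_inference_prompt(instruction: str, clinical_notes: str) -> str:
    """Build a Gemma-style prompt that ends where the model's answer should begin."""
    return (
        "<start_of_turn>user\n"
        f"{instruction}\n\n"
        "Clinical Notes:\n"
        f"{clinical_notes}<end_of_turn>\n"
        "<start_of_turn>model\n"
    )


example_prompt = build_inference_prompt(
    "Summarize the following clinical discharge notes.",
    "chief complaint respiratory distress",
)
print(example_prompt)
```

Keeping the inference prompt byte-for-byte consistent with the training template matters because the adapters only ever see this exact layout during fine-tuning.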

In [13]:
# ============================================================================
# GEMMA PROMPT FORMATTING FUNCTION
# ============================================================================

def format_prompt_gemma(sample):
    """
    Format a training sample using Gemma's conversation template.

    Gemma uses a turn-based conversation format:
    <start_of_turn>user
    {instruction}

    Clinical Notes:
    {input}<end_of_turn>
    <start_of_turn>model
    {output}<end_of_turn>

    Args:
        sample: Dictionary containing 'instruction', 'input', and 'output' keys

    Returns:
        Dictionary with formatted 'text' field
    """
    instruction = sample["instruction"]
    input_text = sample["input"]
    output_text = sample["output"]

    # Construct the full prompt using Gemma's format
    # The user turn contains both the instruction and the clinical notes
    # The model turn contains the expected summary output
    full_prompt = f"""<start_of_turn>user
{instruction}

Clinical Notes:
{input_text}<end_of_turn>
<start_of_turn>model
{output_text}<end_of_turn>"""

    return {"text": full_prompt}


# Apply formatting to both train and test datasets
train_dataset = train_dataset.map(format_prompt_gemma)
test_dataset = test_dataset.map(format_prompt_gemma)

print("✓ Dataset formatted with Gemma prompt template")
print("\nExample formatted prompt (truncated):")
print("=" * 80)
print(train_dataset[0]["text"][:500])
print("...")
print("=" * 80)

Map:   0%|          | 0/38000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

✓ Dataset formatted with Gemma prompt template

Example formatted prompt (truncated):
<start_of_turn>user
Summarize the following clinical discharge notes. Include ALL diagnoses, medications, vitals, lab results, procedures, and follow-up instructions. Ensure complete coverage of all medical entities.

Clinical Notes:
summarize chief complaint respiratory distress hypotension acute elevation cardiac biomarkers history present illness yo man pmh significant severe copd home htn initially presented dyspnearespiratory distress patient called neighbor called em em found tripod positi
...


## 7. Training Configuration and Trainer Setup

Configure the training process using Hugging Face's `TrainingArguments` and the specialized `SFTTrainer` from the TRL library.
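
The configuration in this notebook combines a linear warmup with a cosine decay. The sketch below mirrors the usual shape of such a schedule so the learning rate at any step can be inspected by hand; it is an approximation written for illustration, not the scheduler code from `transformers`, and `lr_at_step` is a hypothetical helper. The defaults reuse the notebook's values (learning rate 2e-4, 100 warmup steps, about 4,750 total steps).

```python
import math


def lr_at_step(step: int, base_lr: float = 2e-4, warmup_steps: int = 100,
               total_steps: int = 4750) -> float:
    """Learning rate after `step` optimizer updates: linear warmup, then half-cosine decay."""
    if step < warmup_steps:
        # Ramp linearly from 0 up to base_lr
        return base_lr * step / max(1, warmup_steps)
    # Fraction of the post-warmup phase that has elapsed, in [0, 1]
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return base_lr * max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))


for s in (0, 50, 100, 2425, 4750):
    print(f"step {s:>4}: lr = {lr_at_step(s):.2e}")
```

The peak rate is reached exactly at the end of warmup, falls to half of the peak midway through the remaining steps, and reaches zero at the final step.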

In [14]:
# ============================================================================
# TRAINING ARGUMENTS
# ============================================================================

from transformers import TrainingArguments

# TrainingArguments control all aspects of the training loop:
# Memory optimization:
#   - gradient_accumulation_steps: Accumulate gradients over N steps (simulates larger batch)
#   - gradient_checkpointing: Trade compute for memory
#   - bf16: Use bfloat16 mixed precision training (faster + less memory)
#
# Optimization:
#   - learning_rate: Step size for parameter updates
#   - weight_decay: Decoupled weight decay applied by AdamW (regularization)
#   - warmup_steps: Gradually increase LR at start of training
#   - lr_scheduler_type: How to adjust LR during training
#   - optim: Optimizer choice (adamw_torch is standard)

training_args = TrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
    gradient_checkpointing=True,
    optim="adamw_torch",  # Standard AdamW optimizer
    learning_rate=LEARNING_RATE,
    weight_decay=0.01,
    bf16=True,
    max_grad_norm=1.0,  # Gradient clipping to prevent exploding gradients
    warmup_steps=WARMUP_STEPS,
    lr_scheduler_type="cosine",  # Cosine learning rate schedule
    logging_steps=LOGGING_STEPS,
    save_steps=SAVE_STEPS,
    save_total_limit=3,  # Keep only the 3 most recent checkpoints
    eval_strategy="steps",
    eval_steps=SAVE_STEPS,
    do_eval=True,
    report_to="none",  # Disable wandb/tensorboard (can enable if you want tracking)
    push_to_hub=False,  # Don't push to Hugging Face Hub automatically
)

print(" Training arguments configured")
print(f"  Total training steps: ~{len(train_dataset) * NUM_EPOCHS // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS)}")
print(f"  Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"  Learning rate: {LEARNING_RATE}")
print(f"  Warmup steps: {WARMUP_STEPS}")

 Training arguments configured
  Total training steps: ~4750
  Effective batch size: 8
  Learning rate: 0.0002
  Warmup steps: 100


In [15]:
# ============================================================================
# CREATE SUPERVISED FINE-TUNING TRAINER
# ============================================================================

# SFTTrainer from TRL library is specifically
# designed for instruction fine-tuning of language models. It handles:
# - Proper formatting of training data
# - Automatic application of PEFT/LoRA adapters
# - Memory-efficient training with large sequence lengths

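# Request a maximum sequence length.
# NOTE: recent TRL versions rebuild plain TrainingArguments into an SFTConfig from the declared fields,
# so an attribute assigned afterwards may be ignored; building the arguments with trl.SFTConfig is more reliable.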
training_args.max_seq_length = MAX_SEQ_LENGTH


# Define formatting function to extract the 'text' field
def formatting_func(example):
    """
    Extract the formatted text from the dataset.
    Our dataset already has a 'text' column with Gemma-formatted prompts.
    """
    return example["text"]  # Already a fully formatted prompt string


trainer = SFTTrainer(
    model=model,  # Pass the base model (before PEFT was applied)
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=lora_config,  # SFTTrainer will apply LoRA adapters
    args=training_args,
    processing_class=tokenizer,  # Explicitly pass tokenizer to avoid processor issues
    formatting_func=formatting_func,  # Function to extract text from examples
)

print("  SFTTrainer initialized successfully")
print(f"  Using max sequence length: {MAX_SEQ_LENGTH}")
print(f"  LoRA adapters applied automatically by SFTTrainer")

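# Print trainable parameters (now that PEFT has been applied by SFTTrainer)
# NOTE: 4-bit weights are stored packed, so numel() under-counts the quantized layers.
# The total below is therefore smaller than the 4.30B reported at load time, and the percentage is overstated.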
trainable_params = sum(p.numel() for p in trainer.model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in trainer.model.parameters())
trainable_percent = 100 * trainable_params / total_params

print(f"\n  Trainable parameters: {trainable_params:,} ({trainable_percent:.2f}% of total)")
print(f"  Total parameters: {total_params:,}")
print(f"  Memory savings: Training only {trainable_percent:.2f}% of parameters!")

print("\nTrainer is ready to begin fine-tuning!")


  SFTTrainer initialized successfully
  Using max sequence length: 2048
  LoRA adapters applied automatically by SFTTrainer

  Trainable parameters: 65,576,960 (2.57% of total)
  Total parameters: 2,555,799,920
  Memory savings: Training only 2.57% of parameters!

Trainer is ready to begin fine-tuning!


## 8. Fine-Tune the Model

Now we train the model. This process will:
1. Iterate through the training data for `NUM_EPOCHS` epochs
2. Update only the LoRA adapter weights (not the base model)
3. Log training metrics periodically
4. Save checkpoints for recovery and evaluation

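**Note:** Training time depends on your GPU and dataset size. A small sample dataset completes in a few minutes; the full run recorded below (38,000 training samples, 1 epoch) took about 37,580 seconds, roughly 10.4 hours.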
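
To make the training-time note concrete, the sketch below estimates the number of optimizer steps per epoch and the wall-clock time from a measured throughput. The inputs (38,000 samples, batch size 1, effective batch 8, 1.01 samples per second) are taken from this notebook's recorded output; the function names are illustrative, and the rounding of a final partial batch may differ between trainer versions.

```python
import math


def optimizer_steps_per_epoch(num_samples, batch_size, grad_accum_steps):
    """Number of optimizer updates in one epoch (a final partial batch counts as one)."""
    effective_batch = batch_size * grad_accum_steps
    return math.ceil(num_samples / effective_batch)


def estimated_hours(num_samples, num_epochs, samples_per_second):
    """Wall-clock estimate in hours from a measured training throughput."""
    return num_samples * num_epochs / samples_per_second / 3600


# Values reported by the run in this notebook
steps = optimizer_steps_per_epoch(num_samples=38_000, batch_size=1, grad_accum_steps=8)
hours = estimated_hours(num_samples=38_000, num_epochs=1, samples_per_second=1.01)
print(f"{steps} optimizer steps per epoch, about {hours:.1f} hours")
```

Measuring throughput on a few hundred samples first and plugging it into `estimated_hours` gives a reasonable idea of whether a full epoch fits in a Colab session.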

In [16]:
# ============================================================================
# START TRAINING
# ============================================================================

print("Starting fine-tuning...\n")
print(f"This will train for {NUM_EPOCHS} epochs with:")
print(f"  - {len(train_dataset)} training samples")
print(f"  - Batch size: {BATCH_SIZE} (effective: {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})")
print(f"  - Learning rate: {LEARNING_RATE}")
print("\nMonitor the loss below. For good convergence, loss should decrease steadily.\n")
print("=" * 80)

# Train the model
# The trainer will handle:
# - Forward pass (compute predictions)
# - Loss computation (compare predictions to ground truth)
# - Backward pass (compute gradients)
# - Optimizer step (update LoRA weights)
# - Logging and checkpointing
training_output = trainer.train()

print("=" * 80)
print("\n  Training completed successfully!")
print(f"\nFinal training loss: {training_output.training_loss:.4f}")
print(f"Total training time: {training_output.metrics['train_runtime']:.2f} seconds")
print(f"Samples per second: {training_output.metrics['train_samples_per_second']:.2f}")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 1}.


Starting fine-tuning...

This will train for 1 epochs with:
  - 38000 training samples
  - Batch size: 1 (effective: 8)
  - Learning rate: 0.0002

Monitor the loss below. For good convergence, loss should decrease steadily.



| Step | Training Loss | Validation Loss | Entropy | Num Tokens | Mean Token Accuracy |
|------|---------------|-----------------|---------|------------|---------------------|
| 100 | 2.6271 | 2.751517 | 2.843249 | 549583.0 | 0.529274 |
| 200 | 2.5006 | 2.526204 | 2.556307 | 1093350.0 | 0.558953 |
| 300 | 2.2617 | 2.418662 | 2.393801 | 1650397.0 | 0.573743 |
| 400 | 2.3538 | 2.364438 | 2.287739 | 2193041.0 | 0.580509 |
| 500 | 2.229 | 2.317925 | 2.334433 | 2744977.0 | 0.58601 |
| 600 | 2.2344 | 2.28521 | 2.271238 | 3289212.0 | 0.590102 |
| 700 | 2.2852 | 2.25522 | 2.231842 | 3840295.0 | 0.593936 |
| 800 | 2.3493 | 2.235487 | 2.250539 | 4387801.0 | 0.595337 |
| 900 | 2.1742 | 2.210688 | 2.204488 | 4937220.0 | 0.599104 |
| 1000 | 2.1955 | 2.194798 | 2.221955 | 5484043.0 | 0.601481 |



  Training completed successfully!

Final training loss: 2.1071
Total training time: 37580.10 seconds
Samples per second: 1.01
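
One way to read the loss values in the log above is to convert them to perplexity. The sketch below assumes the reported loss is the mean per-token cross-entropy in nats, which is the usual convention for causal language model training; the `perplexity` helper is illustrative, and the three loss values are copied from the validation column of the log.

```python
import math

# (step -> validation loss) pairs copied from the training log above
validation_loss = {100: 2.751517, 500: 2.317925, 1000: 2.194798}


def perplexity(mean_token_loss):
    """Perplexity is the exponential of the mean per-token cross-entropy (in nats)."""
    return math.exp(mean_token_loss)


for step, loss in validation_loss.items():
    print(f"step {step:>4}: loss {loss:.4f} -> perplexity {perplexity(loss):.2f}")
```

A falling perplexity means the model is spreading its probability mass over fewer candidate tokens at each position, which matches the rising mean token accuracy in the same log.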


In [17]:
# ============================================================================
# SAVE THE FINE-TUNED MODEL
# ============================================================================

# Save the trained LoRA adapters

output_dir_final = f"{OUTPUT_DIR}/final"
trainer.model.save_pretrained(output_dir_final)
tokenizer.save_pretrained(output_dir_final)

print(f"✓ Model saved to: {output_dir_final}")
print("\nThe saved files include:")
print("  - adapter_config.json: LoRA configuration")
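# Note: recent PEFT versions write adapter_model.safetensors by default;
# adapter_model.bin appears only when saving with safe_serialization=False.
print("  - adapter_model.bin: Trained LoRA weights")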
print("  - tokenizer files")
print("\nTo load this model later, use:")
print(f"  model = AutoModelForCausalLM.from_pretrained('{MODEL_NAME}', ...)")
print(f"  model = PeftModel.from_pretrained(model, '{output_dir_final}')")

✓ Model saved to: /content/drive/MyDrive/Colab Notebooks/Gen AI/Semester Project/medgemma-discharge-summarization//final

The saved files include:
  - adapter_config.json: LoRA configuration
  - adapter_model.bin: Trained LoRA weights
  - tokenizer files

To load this model later, use:
  model = AutoModelForCausalLM.from_pretrained('google/medgemma-4b-it', ...)
  model = PeftModel.from_pretrained(model, '/content/drive/MyDrive/Colab Notebooks/Gen AI/Semester Project/medgemma-discharge-summarization//final')
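
The reload steps printed above can be sketched as follows. This is a minimal sketch rather than the notebook's own code: `find_adapter_weights` and `load_finetuned` are hypothetical helper names, and the 4-bit settings (NF4, bfloat16 compute) are assumptions that should be replaced with the quantization config used when the model was loaded for training.

```python
from pathlib import Path


def find_adapter_weights(adapter_dir):
    """Return the adapter weight file in adapter_dir, or None if none is present."""
    # PEFT writes one of these two names depending on version and save options.
    for name in ("adapter_model.safetensors", "adapter_model.bin"):
        candidate = Path(adapter_dir) / name
        if candidate.is_file():
            return candidate
    return None


def load_finetuned(base_model_name, adapter_dir):
    """Load the quantized base model and attach the saved LoRA adapters."""
    # Heavy imports are deferred so the helper above works without a GPU stack.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    if find_adapter_weights(adapter_dir) is None:
        raise FileNotFoundError(f"No adapter weights found in {adapter_dir}")

    # Assumed quantization settings; match these to the training configuration.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_name,
        quantization_config=bnb_config,
        device_map="auto",
    )
    model = PeftModel.from_pretrained(base_model, adapter_dir)
    model.eval()  # inference mode: disables dropout in the adapters
    tokenizer = AutoTokenizer.from_pretrained(adapter_dir)
    return model, tokenizer
```

With the variables from this notebook, the call would be `model, tokenizer = load_finetuned(MODEL_NAME, output_dir_final)`.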


---

## Summary

This notebook demonstrated:

1. ✅ **Environment setup** with all required libraries
2. ✅ **Model loading** with QLoRA (4-bit quantization)
3. ✅ **LoRA configuration** for efficient fine-tuning
4. ✅ **Data formatting** with Gemma prompt template
5. ✅ **Training** with SFTTrainer


**Key Takeaways:**
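- QLoRA enables fine-tuning large models on consumer GPUs: here, only 65,576,960 adapter parameters were trained on top of a 4-bit quantized base model
- High recall requires careful prompt engineering and sufficient training data
- The trained LoRA adapters and tokenizer are saved separately from the base model and can be reloaded with `PeftModel.from_pretrained`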
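
As a rough illustration of the first takeaway, the sketch below compares the memory needed just to hold the base model weights at 16-bit and 4-bit precision. It is a back-of-envelope estimate under stated assumptions: the parameter count is the nominal "4B" taken at face value, and quantization constants, activations, gradients, and optimizer state are all ignored. `weight_memory_gb` is an illustrative helper.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Memory for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


NOMINAL_PARAMS = 4e9  # assumption: "4B" taken at face value

for label, bits in [("bf16", 16), ("4-bit", 4)]:
    print(f"{label:>5}: {weight_memory_gb(NOMINAL_PARAMS, bits):.1f} GB")
```

The fourfold reduction in weight storage is what leaves room for activations and adapter training state on a single consumer GPU.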