Medical Question-Answering Model Training Pipeline
---------------------------------------------------
This notebook implements a fine-tuning pipeline for a medical question-answering system
using Google's Flan-T5 model. The model is fine-tuned on a CSV dataset of medical
question-answer pairs so that it generates answers to medical questions.

Author: Navdeep
Last Modified: April 2025

In [1]:
# ============================================================================
# 1. ENVIRONMENT AND LIBRARY SETUP
# ============================================================================

# Allow duplicate OpenMP runtimes to load (works around "OMP: Error #15", common on Windows; an unsafe workaround in general)
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Import core libraries
import pandas as pd  # For data manipulation
import torch  # For deep learning operations
import numpy as np  # For numerical computation

# Import HuggingFace Transformers libraries for fine-tuning
from transformers import (
    AutoModelForSeq2SeqLM,  # For loading pre-trained seq2seq models
    AutoTokenizer,  # For tokenizing text
    Seq2SeqTrainingArguments,  # For configuring training parameters
    Seq2SeqTrainer,  # For managing the training process
    DataCollatorForSeq2Seq,  # For batching and padding sequences
    EarlyStoppingCallback  # For stopping training when performance plateaus
)

# Import dataset and evaluation utilities
from datasets import Dataset  # For creating and managing datasets
import evaluate  # For loading evaluation metrics
import nltk  # For natural language processing
from nltk.tokenize import sent_tokenize  # For sentence tokenization


In [2]:
# ============================================================================
# 2. DATA LOADING AND PREPROCESSING
# ============================================================================
import re  # For text cleaning with regular expressions

# Load the medical question-answer dataset from CSV
df = pd.read_csv(r"C:\Users\Navdeep\Documents\ML_Challenge\mle_screening_dataset.csv")

# Define a function to clean special formatting and unwanted content from text
def clean_special_terms(text):
    """
    Cleans text by removing URLs, email addresses, HTML tags, and normalizing whitespace.
    
    Args:
        text (str or None): The text to clean
        
    Returns:
        str: The cleaned text
    """
    # Handle non-string inputs by returning empty string
    if not isinstance(text, str):
        return ""
    
    # Remove URLs
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    
    # Remove email addresses
    text = re.sub(r'\S+@\S+', '', text)
    
    # Remove HTML tags
    text = re.sub(r'<.*?>', '', text)
    
    # Normalize whitespace (remove extra spaces, tabs, newlines)
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text

# Apply the cleaning function to the answers
df["answer"] = df["answer"].apply(clean_special_terms)

# Clean questions and handle missing values
df["question"] = df["question"].fillna("").str.strip()  # Replace NaN with "" and remove leading/trailing whitespace
df["answer"] = df["answer"].fillna("").str.strip()  # Safeguard; clean_special_terms already returns stripped strings

# Rename columns to match the expected format for seq2seq models
df = df.rename(columns={"question": "input_text", "answer": "target_text"})

# Convert pandas DataFrame to HuggingFace Dataset format for easier processing
dataset = Dataset.from_pandas(df)

# Split the dataset into training (90%) and testing (10%) sets
# Fixed random seed (42) ensures reproducibility
dataset = dataset.train_test_split(test_size=0.1, seed=42)
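To show what the cleaning step does, the sketch below applies a standalone copy of `clean_special_terms` (the same four regex passes) to a made-up answer string; the sample text is hypothetical, not taken from the dataset. Because the URL pattern stops only at whitespace, punctuation attached to a URL is removed along with it.

```python
import re


def clean_special_terms(text):
    """Standalone copy of the notebook's cleaning function, for illustration."""
    if not isinstance(text, str):
        return ""
    text = re.sub(r'https?://\S+|www\.\S+', '', text)  # URLs (and anything attached to them)
    text = re.sub(r'\S+@\S+', '', text)                 # email-like tokens
    text = re.sub(r'<.*?>', '', text)                   # HTML tags
    return re.sub(r'\s+', ' ', text).strip()            # collapse whitespace


# Hypothetical sample answer containing each kind of noise the function targets
sample = "See https://example.com. or mail info@example.org <b>now</b>\n\tplease"
print(clean_special_terms(sample))      # See or mail now please
print(repr(clean_special_terms(None)))  # ''
```

The `\S+@\S+` pattern is broad: any whitespace-delimited token with characters on both sides of an `@` is dropped, not only real email addresses.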

In [3]:
# ============================================================================
# 3. MODEL AND TOKENIZER INITIALIZATION
# ============================================================================

# Select the pre-trained model: Flan-T5-small is the smallest Flan-T5 checkpoint
model_name = "google/flan-t5-small"  # Chosen to fit within GPU memory constraints
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
# ============================================================================
# 4. TEXT PREPROCESSING AND TOKENIZATION
# ============================================================================

# Download NLTK resources for sentence tokenization
nltk.download('punkt')
nltk.download('punkt_tab')

# Define a function to preprocess and tokenize the dataset
def preprocess_function(examples):
    """
    Prepares examples for the model by adding a task-specific prefix and tokenizing.
    
    Args:
        examples (dict): Batch of examples with input_text and target_text fields
        
    Returns:
        dict: Processed examples with input_ids, attention_mask, and labels
    """
    # Add a task-specific prefix to help the model understand the task
    prefix = "Answer the medical question: "
    inputs = [prefix + question for question in examples["input_text"]]
    
    # Tokenize inputs (questions) with truncation at 512 tokens
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    
    # Tokenize targets (answers) with truncation at 128 tokens
    labels = tokenizer(text_target=examples["target_text"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    
    return model_inputs

# Apply the preprocessing function to both training and test datasets
tokenized_datasets = dataset.map(preprocess_function, batched=True)

# Load the pre-trained T5 model
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Navdeep\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt_tab to
[nltk_data]     C:\Users\Navdeep\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt_tab is already up-to-date!


Map:   0%|          | 0/14765 [00:00<?, ? examples/s]

Map:   0%|          | 0/1641 [00:00<?, ? examples/s]
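The Map progress bars report 14,765 training and 1,641 test examples. Assuming `train_test_split` rounds the fractional test share up and assigns the remainder to training (the scikit-learn convention), those counts imply 16,406 rows after cleaning. A small sketch of that arithmetic; `split_sizes` is a hypothetical helper, not part of the datasets API:

```python
import math


def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a fractional test_size.

    Assumes the test share is rounded up and the rest goes to training.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    return n_train, n_test


# Rows implied by the two Map progress bars (train + test)
n_total = 14765 + 1641
print(n_total)                    # 16406
print(split_sizes(n_total, 0.1))  # (14765, 1641)
```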

In [5]:
# ============================================================================
# 5. TRAINING CONFIGURATION
# ============================================================================

# Configure training parameters
training_args = Seq2SeqTrainingArguments(
    output_dir=r"C:\Users\Navdeep\Documents\ML_Challenge\flan-t5-medical-qa_3",  # Directory to save model checkpoints
    eval_strategy="epoch",  # Evaluate after each epoch
    save_strategy="epoch",  # Save checkpoint after each epoch
    learning_rate=3e-5,  # Learning rate for fine-tuning (below the Trainer default of 5e-5)
    per_device_train_batch_size=8,  # Batch size for training
    per_device_eval_batch_size=8,  # Batch size for evaluation
    weight_decay=0.01,  # Decoupled (AdamW) weight decay to reduce overfitting
    save_total_limit=3,  # Keep at most 3 checkpoints: the most recent ones plus the best one
    num_train_epochs=20,  # Maximum number of training epochs
    predict_with_generate=True,  # Use generation for evaluation (needed for ROUGE)
    load_best_model_at_end=True,  # Load the best model at the end of training
    metric_for_best_model='eval_rougeL'  # Metric to track for best model selection
)
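The `Gen Len` column in the training log stays near 12 words in every epoch, even though answers are truncated at 128 tokens. One possible cause, offered as an assumption to check rather than a confirmed diagnosis: with `predict_with_generate=True`, evaluation output length is capped by `generation_max_length`, which falls back to the model's generation config when unset, and the transformers `GenerationConfig` default `max_length` is 20 tokens unless the checkpoint overrides it. A minimal sketch of aligning the cap with the label truncation length (the constant and dict names are hypothetical):

```python
# Keep in sync with max_length for the labels in preprocess_function
MAX_TARGET_LENGTH = 128  # answer/label truncation length in tokens

# Hypothetical extra keyword arguments for the Seq2SeqTrainingArguments call.
# generation_max_length caps the tokens generated during evaluation when
# predict_with_generate=True; if unset, the model's generation config decides.
extra_training_kwargs = {
    "generation_max_length": MAX_TARGET_LENGTH,
}

print(extra_training_kwargs)  # {'generation_max_length': 128}

# Usage (not run here; requires transformers):
# training_args = Seq2SeqTrainingArguments(..., **extra_training_kwargs)
```

Longer generations make each evaluation slower, and ROUGE values obtained this way would not be directly comparable with the epochs already logged.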

In [6]:
# ============================================================================
# 6. DATA COLLATION AND METRICS SETUP
# ============================================================================

# Define data collator for padding batches to the same length
data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=-100,  # Ignore padding tokens in loss calculation
    pad_to_multiple_of=8 if training_args.fp16 else None,  # Pad to multiples of 8 (Tensor Core friendly) only when fp16 is enabled
)

# Load the ROUGE metric for evaluation
metric = evaluate.load("rouge")

# Define a function to compute evaluation metrics
def compute_metrics(eval_preds):
    """
    Compute ROUGE scores between predictions and references.
    
    Args:
        eval_preds (tuple): Tuple containing predictions and labels
        
    Returns:
        dict: Dictionary of evaluation metrics
    """
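    preds, labels = eval_preds
    
    # Generation outputs can arrive as a tuple; keep only the token ids
    if isinstance(preds, tuple):
        preds = preds[0]
    
    # Replace -100 padding in predictions and labels with the pad token id so they can be decoded
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)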
    
    # Decode the predictions and labels back into text
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Format for ROUGE: expects newlines between sentences
    decoded_preds = ["\n".join(sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(sent_tokenize(label.strip())) for label in decoded_labels]
    
    # Compute ROUGE scores
    result = metric.compute(
        predictions=decoded_preds, 
        references=decoded_labels, 
        use_stemmer=True  # Use stemming for better matching
    )
    
    # Convert scores to percentages for easier interpretation
    result = {key: value * 100 for key, value in result.items()}
    
    # Add mean generated length (in whitespace-separated words) as an additional metric
    prediction_lens = [len(pred.split()) for pred in decoded_preds]
    result["gen_len"] = np.mean(prediction_lens)
    
    # Round values for cleaner output
    return {k: round(v, 4) for k, v in result.items()}    
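ROUGE-L scores a prediction by the longest common subsequence (LCS) of tokens it shares with the reference. The sketch below is a simplified single-pair version on lowercased whitespace tokens, not the `rouge_score` implementation that `evaluate` wraps; that package applies its own tokenization and, with `use_stemmer=True`, Porter stemming. `rougeLsum` additionally treats newlines as sentence boundaries, which is why `compute_metrics` puts each sentence on its own line.

```python
def lcs_length(a, b):
    """Length of the longest common subsequence of two token lists (dynamic programming)."""
    prev = [0] * (len(b) + 1)
    for x in a:
        curr = [0]
        for j, y in enumerate(b, start=1):
            if x == y:
                curr.append(prev[j - 1] + 1)
            else:
                curr.append(max(prev[j], curr[j - 1]))
        prev = curr
    return prev[-1]


def rouge_l_f1(prediction, reference):
    """Simplified ROUGE-L F1 on lowercased whitespace tokens (no stemming)."""
    pred_tokens = prediction.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens or not ref_tokens:
        return 0.0
    lcs = lcs_length(pred_tokens, ref_tokens)
    if lcs == 0:
        return 0.0
    precision = lcs / len(pred_tokens)
    recall = lcs / len(ref_tokens)
    return 2 * precision * recall / (precision + recall)


score = rouge_l_f1("the cat is on the mat", "the cat sat on the mat")
print(round(score * 100, 2))  # 83.33 (LCS of 5 tokens out of 6 on each side)
```

Because recall divides by the reference length, a short prediction caps the achievable score when references are long.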

In [7]:
# Initialize the Seq2Seq trainer
trainer = Seq2SeqTrainer(
    model=model,  # Pre-trained model to fine-tune
    args=training_args,  # Training arguments
    train_dataset=tokenized_datasets["train"],  # Training dataset
    eval_dataset=tokenized_datasets["test"],  # Evaluation dataset
    tokenizer=tokenizer,  # Tokenizer
    data_collator=data_collator,  # Data collator for batching
    compute_metrics=compute_metrics,  # Metrics computation function
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5, early_stopping_threshold=0.01)]  # Stop if eval_rougeL fails to improve by more than 0.01 for 5 consecutive evaluations
)

# Start the training process
trainer.train()

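# Checkpoints are written to output_dir/checkpoint-<step>. With load_best_model_at_end=True,
# trainer.model holds the best checkpoint once training ends; save_model() writes it (and the
# tokenizer) to output_dir so it can be reloaded with AutoModelForSeq2SeqLM.from_pretrained(output_dir)
trainer.save_model()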

Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch  Training Loss  Validation Loss  ROUGE-1  ROUGE-2  ROUGE-L  ROUGE-Lsum  Gen Len (words)
1      2.3951         2.05575          20.078   12.1841  18.2536  19.1895     12.0817
2      2.2077         1.962967         20.8936  12.9038  19.038   19.9827     12.1749
3      2.119          1.909121         20.7993  12.8772  18.9372  19.8883     12.1432
4      2.1019         1.870995         20.9152  12.9891  19.1003  20.0403     12.1542
5      2.0216         1.84213          21.0432  13.1766  19.2303  20.1409     12.1255
6      1.9954         1.820188         21.028   13.1404  19.1948  20.1126     12.1853
7      1.9698         1.802049         21.1719  13.1873  19.3244  20.2379     12.2163
8      1.9594         1.785552         21.1476  13.2325  19.3132  20.2249     12.2322
9      1.9324         1.77328          21.2294  13.3077  19.3781  20.3292     12.2163
10     1.9202         1.762884         21.4212  13.3602  19.5399  20.4887     12.2639
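The early-stopping rule configured for the trainer can be replayed on the ROUGE-L scores in the log above. Because `compute_metrics` multiplies scores by 100, `early_stopping_threshold=0.01` means 0.01 ROUGE-L points. The sketch below is a simplified re-implementation under an assumption about `EarlyStoppingCallback`, not the transformers source: an evaluation resets the patience counter only if it beats the best score so far by more than the threshold, while the best score itself is updated by any improvement. On the ten epochs shown, five non-improving evaluations never occur in a row, so stopping must have happened after epoch 10.

```python
def early_stop_epoch(scores, patience=5, threshold=0.01):
    """Return the 1-based epoch at which patience-based early stopping would
    trigger for a higher-is-better metric, or None if it never triggers.

    Simplified model for illustration:
    - the counter resets only when a score beats the best so far by > threshold;
    - the best score is updated by any improvement, however small.
    """
    best = None
    bad_evals = 0
    for epoch, score in enumerate(scores, start=1):
        if best is None or (score > best and score - best > threshold):
            bad_evals = 0
        else:
            bad_evals += 1
        if best is None or score > best:
            best = score
        if bad_evals >= patience:
            return epoch
    return None


# eval_rougeL (percent) for epochs 1-10, copied from the training log
rouge_l = [18.2536, 19.038, 18.9372, 19.1003, 19.2303,
           19.1948, 19.3244, 19.3132, 19.3781, 19.5399]
print(early_stop_epoch(rouge_l))  # None

# Hypothetical plateau: a gain smaller than the threshold still counts as "no improvement"
plateau = [10.0, 10.005, 9.9, 9.9, 9.9, 9.9]
print(early_stop_epoch(plateau))  # 6
```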


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].


TrainOutput(global_step=35074, training_loss=1.9800004664906865, metrics={'train_runtime': 8594.5524, 'train_samples_per_second': 34.359, 'train_steps_per_second': 4.296, 'total_flos': 3005267464034304.0, 'train_loss': 1.9800004664906865, 'epoch': 19.0})
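The `TrainOutput` numbers can be cross-checked with some arithmetic. Assuming a single device (so the effective batch size equals `per_device_train_batch_size=8`) and no gradient accumulation, one epoch over 14,765 examples takes 1,846 steps, and `global_step=35074` is exactly 19 epochs, consistent with early stopping before the 20-epoch limit. The reported `train_samples_per_second` and `train_steps_per_second` match the planned 20 epochs (36,920 steps) rather than the 19 completed, so they appear to overstate actual throughput slightly.

```python
import math

n_train = 14765        # training examples (from the Map progress bar)
batch_size = 8         # per_device_train_batch_size, assuming one device
planned_epochs = 20    # num_train_epochs
runtime_s = 8594.5524  # train_runtime from TrainOutput
global_step = 35074    # global_step from TrainOutput

steps_per_epoch = math.ceil(n_train / batch_size)
completed_epochs = global_step / steps_per_epoch
planned_steps = steps_per_epoch * planned_epochs

print(steps_per_epoch)   # 1846
print(completed_epochs)  # 19.0
print(round(n_train * planned_epochs / runtime_s, 3))  # 34.359, the reported train_samples_per_second
print(round(planned_steps / runtime_s, 3))            # 4.296, the reported train_steps_per_second
print(round(global_step / runtime_s, 3))              # steps per second actually achieved over 19 epochs
```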