Medical Question-Answering: Evaluation & Inference
---------------------------------------------------
This script performs two tasks:
1. Evaluates multiple checkpoints of a fine-tuned Flan-T5 model by computing their average loss on a test split.
2. Demonstrates the model’s response quality on sample medical questions.

Author: Navdeep  
Last Modified: April 2025

In [5]:
# ============================================================================
# 1. LIBRARY IMPORTS AND SETUP
# ============================================================================

import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"  # Work around the duplicate OpenMP runtime error ("OMP: Error #15")

import pandas as pd
import torch
import numpy as np
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq
)
from datasets import Dataset
from tqdm import tqdm

# ============================================================================
# 2. DATA LOADING AND PREPROCESSING
# ============================================================================

# Load and clean the medical Q&A dataset
df = pd.read_csv(r"C:\Users\Navdeep\Documents\ML_Challenge\mle_screening_dataset.csv")
df["question"] = df["question"].fillna("").str.strip()  # Missing questions would otherwise break prefix concatenation
df["answer"] = df["answer"].fillna("").str.strip()
df = df.rename(columns={"question": "input_text", "answer": "target_text"})

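# Convert to a Hugging Face Dataset and split into train/test.
# test_size and seed must match the values used for training; otherwise the
# "test" split may overlap with examples the model was fine-tuned on.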
dataset = Dataset.from_pandas(df)
dataset = dataset.train_test_split(test_size=0.1, seed=42)

# Load tokenizer
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Define the preprocessing/tokenization function
def preprocess_function(examples):
    prefix = "Answer the medical question: "
    inputs = [prefix + question for question in examples["input_text"]]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True)
    labels = tokenizer(text_target=examples["target_text"], max_length=128, truncation=True)
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize the test set
tokenized_test = dataset["test"].map(preprocess_function, batched=True)
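The test loss is only meaningful if `dataset["test"]` contains no examples seen during fine-tuning. That holds only when the training run used the same cleaned CSV, `test_size=0.1`, and `seed=42`. The sketch below shows why a fixed seed guarantees this. It uses a hypothetical `split_indices` helper built on Python's `random` module. This is not the shuffling algorithm `datasets` uses internally. It only demonstrates that the same seed reproduces an identical, disjoint train/test partition.

```python
import math
import random


def split_indices(n, test_size, seed):
    """Hypothetical helper: shuffle row indices with a fixed seed and take
    the first ceil(test_size * n) of them as the test split."""
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # local RNG, global random state untouched
    n_test = math.ceil(test_size * n)
    return indices[n_test:], indices[:n_test]  # (train, test)


# Same seed -> identical split, e.g. in the training run and in this evaluation run
train_a, test_a = split_indices(1000, 0.1, seed=42)
train_b, test_b = split_indices(1000, 0.1, seed=42)
print(test_a == test_b)                 # True
print(set(train_a).isdisjoint(test_a))  # True
```

If training and evaluation used different seeds, each run's "test" rows would be a different random 10%. Most of them would then have been seen during training, and the loss would look optimistically low.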

# ============================================================================
# 3. EVALUATE MULTIPLE CHECKPOINTS
# ============================================================================

# Define the paths to the model checkpoints to be evaluated
checkpoint_paths = [
    r"C:\Users\Navdeep\Documents\ML_Challenge\flan-t5-medical-qa_3\checkpoint-35074",
    r"C:\Users\Navdeep\Documents\ML_Challenge\flan-t5-medical-qa_3\checkpoint-33228",
    r"C:\Users\Navdeep\Documents\ML_Challenge\flan-t5-medical-qa_3\checkpoint-25844"
]

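def compute_loss(model_path, num_samples=None):
    """
    Computes the average loss on the test set for a given model checkpoint.

    Each batch loss is the cross-entropy averaged over that batch's non-padding
    label tokens; the returned value is the mean of these per-batch losses.

    Args:
        model_path (str): Path to the checkpoint directory.
        num_samples (int or None): Number of test samples to evaluate. If None, uses all.

    Returns:
        tuple: (average_loss, checkpoint_name)
    """
    from torch.utils.data import DataLoader

    # Fall back to CPU when no GPU is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device).eval()
    print(f"\nEvaluating checkpoint: {model_path}")
    print(f"Using device: {device}")

    # Optional subset selection
    test_dataset = tokenized_test
    if num_samples and num_samples < len(tokenized_test):
        test_dataset = tokenized_test.select(range(num_samples))
        print(f"Using {num_samples} test samples")
    else:
        print(f"Using full test set: {len(test_dataset)} samples")

    # Keep only the model inputs; text columns (and any extra CSV columns)
    # cannot be collated into tensors
    model_columns = {"input_ids", "attention_mask", "labels"}
    test_dataset = test_dataset.remove_columns(
        [col for col in test_dataset.column_names if col not in model_columns]
    )

    # The collator pads input_ids with the pad token and labels with -100,
    # so padded label positions are ignored by the loss
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, return_tensors="pt")
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=8,
        collate_fn=data_collator,
        shuffle=False
    )

    # Compute average loss
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(test_dataloader, desc="Evaluating"):
            batch = {k: v.to(device) for k, v in batch.items()}
            loss = model(**batch).loss
            total_loss += loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches
    checkpoint_name = model_path.split("checkpoint-")[-1]

    # Free GPU memory before the next checkpoint is loaded
    del model
    if device == "cuda":
        torch.cuda.empty_cache()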
    return avg_loss, checkpoint_name
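`compute_loss` averages per-batch mean losses, so every batch counts equally no matter how many label tokens it contains. With the 1,641-example test set reported in the evaluation output and `batch_size=8`, the last batch holds a single example (1641 = 205 × 8 + 1). That one example still weighs as much as a full batch. A token-weighted average avoids this bias. The sketch below uses hypothetical helpers `count_label_tokens` and `token_weighted_mean`. It assumes each batch loss is a mean over label positions not equal to -100, which is the padding value `DataCollatorForSeq2Seq` uses by default. Under that assumption, multiplying a batch loss by its token count recovers the batch's summed loss.

```python
def count_label_tokens(labels, ignore_index=-100):
    """Count label positions that contribute to the loss (padding excluded)."""
    return sum(1 for row in labels for tok in row if tok != ignore_index)


def token_weighted_mean(batch_losses, batch_token_counts):
    """Per-token average over the whole set instead of per batch.

    Assumes each batch loss is already a mean over its non-ignored label
    tokens, so loss * count is that batch's summed loss.
    """
    total = sum(loss * n for loss, n in zip(batch_losses, batch_token_counts))
    return total / sum(batch_token_counts)


# Toy example: one full batch and a small final batch (padding marked with -100)
labels_full = [[5, 6, 7, 1], [8, 9, 1, -100]]  # 7 real label tokens
labels_last = [[4, 1, -100, -100]]             # 2 real label tokens
counts = [count_label_tokens(labels_full), count_label_tokens(labels_last)]
losses = [1.0, 4.0]

print(counts)                               # [7, 2]
print(sum(losses) / len(losses))            # 2.5, per-batch mean as in compute_loss
print(token_weighted_mean(losses, counts))  # 15 / 9, about 1.667
```

Inside `compute_loss`, a batch's token count could be collected with `(batch["labels"] != -100).sum().item()` alongside `loss.item()`.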

if __name__ == "__main__":
    print("Starting evaluation of model checkpoints...\n")
    results = []

    # Loop through each checkpoint and compute test loss
    for path in checkpoint_paths:
        loss, name = compute_loss(path)
        results.append({"checkpoint": name, "loss": loss})
        print(f"✅ Checkpoint {name} → Test Loss: {loss:.4f}")

    # Display all results in a tabular format
    print("\n📊 Summary of Test Losses:")
    print("=" * 50)
    print(f"{'Checkpoint':<20} | {'Test Loss':<10}")
    print("-" * 50)
    for res in results:
        print(f"{res['checkpoint']:<20} | {res['loss']:<10.4f}")
    print("=" * 50)


Computing test loss for models using None samples...
Using device: cuda
Using all 1641 samples from test set
100%|███████████████████████████████████████████████| 206/206 [00:12<00:00, 16.50it/s]
Checkpoint: 35074, Test Loss: 1.7212

Using device: cuda
Using all 1641 samples from test set
100%|███████████████████████████████████████████████| 206/206 [00:11<00:00, 17.33it/s]
Checkpoint: 33228, Test Loss: 1.7225

Using device: cuda
Using all 1641 samples from test set
100%|███████████████████████████████████████████████| 206/206 [00:11<00:00, 17.39it/s]
Checkpoint: 25844, Test Loss: 1.7358

Summary of Test Loss Results:
Checkpoint           | Test Loss 
--------------------------------------------------
35074                | 1.7212    
33228                | 1.7225    
25844                | 1.7358    
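
Each reported loss is an average of token-level cross-entropies in nats, because PyTorch's cross-entropy uses the natural log. Taking `exp(loss)` therefore gives the corresponding perplexity, which can be easier to interpret. The sketch below converts the three losses from the summary table above and selects the checkpoint with the lowest loss. The gap between the best and worst checkpoints is only 0.0146, so the ranking alone does not show how answer quality differs between them.

```python
import math

# Test losses copied from the summary table above
reported = [
    {"checkpoint": "35074", "loss": 1.7212},
    {"checkpoint": "33228", "loss": 1.7225},
    {"checkpoint": "25844", "loss": 1.7358},
]

# Perplexity = exp(mean cross-entropy in nats)
for res in reported:
    res["perplexity"] = math.exp(res["loss"])

for res in sorted(reported, key=lambda r: r["loss"]):
    print(f"{res['checkpoint']:<12} loss={res['loss']:.4f}  ppl={res['perplexity']:.2f}")

best = min(reported, key=lambda r: r["loss"])
print("Lowest test loss:", best["checkpoint"])  # Lowest test loss: 35074
```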

In [11]:
# ============================================================================
# 4. SAMPLE QUESTION-ANSWER INFERENCE
# ============================================================================

# Necessary imports
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def load_and_test_model(model_path, questions):
    """
    Generates answers for a list of medical questions using a fine-tuned model.

    Args:
        model_path (str): Path to a saved model checkpoint.
        questions (list): List of medical questions as strings.

    Returns:
        dict: Dictionary mapping each question to the generated answer.
    """
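    # Fall back to CPU when no GPU is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = AutoModelForSeq2SeqLM.from_pretrained(model_path).to(device).eval()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    results = {}

    for question in questions:
        # Prepare input with the same instruction prefix used in preprocess_function
        prefixed = "Answer the medical question: " + question
        inputs = tokenizer(prefixed, return_tensors="pt", max_length=512, truncation=True).to(device)

        # Generate with beam search; max_length caps the answer at 128 decoder
        # tokens, matching the label truncation in preprocess_function
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_length=128,
                num_beams=5,
                early_stopping=True,
                no_repeat_ngram_size=3,
                repetition_penalty=1.5
            )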
        answer = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        results[question] = answer

        # Print Q&A
        print(f"🔹 Question: {question}")
        print(f"🔸 Answer: {answer}")
        print("-" * 50)
    
    return results
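
`no_repeat_ngram_size=3` and `repetition_penalty=1.5` both reshape the next-token scores before a token is chosen. The sketch below re-implements their logic in plain Python for a single hypothesis. It illustrates what the corresponding `transformers` logits processors do and is not their actual code; the function names are hypothetical.

- The n-gram ban sets the score of any token that would complete an n-gram already present in the generated sequence to negative infinity.
- The repetition penalty divides positive scores, and multiplies negative ones, of previously generated tokens by the penalty.

```python
def banned_next_tokens(generated, n):
    """Tokens that would complete an n-gram already present in `generated`."""
    if n < 1 or len(generated) + 1 < n:
        return set()
    # The last n-1 tokens form the prefix of the n-gram being completed
    prefix = tuple(generated[len(generated) - (n - 1):]) if n > 1 else ()
    banned = set()
    for i in range(len(generated) - n + 1):
        if tuple(generated[i:i + n - 1]) == prefix:
            banned.add(generated[i + n - 1])
    return banned


def apply_repetition_penalty(scores, generated, penalty):
    """Push down the scores of tokens that already appear in `generated`."""
    out = list(scores)
    for tok in set(generated):
        s = out[tok]
        out[tok] = s * penalty if s < 0 else s / penalty
    return out


def mask_banned(scores, banned):
    """Make banned tokens impossible to select."""
    return [float("-inf") if i in banned else s for i, s in enumerate(scores)]


generated = [1, 2, 3, 1, 2]          # toy token ids produced so far
scores = [0.5, 2.0, -1.0, 3.0, 1.2]  # toy next-token scores over a 5-token vocabulary

scores = apply_repetition_penalty(scores, generated, penalty=1.5)
banned = banned_next_tokens(generated, n=3)
scores = mask_banned(scores, banned)
print(banned)  # {3}: choosing 3 would repeat the 3-gram (1, 2, 3)
print(scores)
```

These constraints only discourage literal repetition. As the sample outputs show, they do not prevent fluent but factually wrong answers.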

# Example inference usage
if __name__ == "__main__":
    print("\nRunning example inference...\n")
    example_model_path = r"C:\Users\Navdeep\Documents\ML_Challenge\flan-t5-medical-qa_3\checkpoint-35074"
    example_questions = [
        "What is hypertension and what blood pressure readings indicate this condition?",
        "What are the main functions of the kidneys in the human body?",
        "How does insulin work in controlling blood sugar levels?"
    ]
    load_and_test_model(example_model_path, example_questions)

Question: What is hypertension and what blood pressure readings indicate this condition?
Answer: Hypertension is a condition that affects the blood pressure of the body. It is caused by a change (mutation) in the X-linked gene, which is involved in the normal flow of blood to the body's bloodstream. This condition is inherited in an autosomal recessive pattern, which means both copies of the gene in each cell have mutations. The parents of an individual with hypertension typically have one copy of the mutated gene, but they typically do not show signs and symptoms of the condition.
--------------------------------------------------
Question: What are the main functions of the kidneys in the human body?
Answer: Key Points - The kidneys are made up of many different organs, including the kidneys, tissues, and organs. - There are three main functions of kidneys in the body: - Kidneys are part of the body's immune system, which helps fight infection and protects the body from infection - I