# Simplification & Summarization of Medical Information for Elderly Patients

Final Project for CS4120 Natural Language Processing (Fall 2024).

Contributors: Lucas Dunker, Shashank Jarmale, Andrew Sun, Dylan Weinmann

## Background Information & Approach

TODO:
- Describe the background information for why this model is useful, particularly for elderly patients
- Describe the datasets we are using and which models we are evaluating
- Give instructions for how to run our code

## Imports & Setup

In [None]:
from typing import Dict, Tuple

from tqdm import tqdm

import pandas as pd
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import PreTrainedTokenizer, T5Tokenizer, T5ForConditionalGeneration

In [None]:
# Pick the best available accelerator: CUDA first, then Apple-Silicon MPS, then CPU.
# The MPS check goes through `torch.backends.mps`, which is the availability API documented
# for every PyTorch release with MPS support; the `getattr` guard keeps very old builds
# (which have no `torch.backends.mps` at all) falling back to CPU instead of crashing.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device {DEVICE}")

## Loading in Datasets

### Wikismall Dataset for General Text Simplification

In [None]:
# Wikismall Dataset Constants

# Parallel plain-text files for each split. WikismallDataset reads a source file and a target
# file side by side, pairing line i of the source with line i of the (simplified) target.
WIKISMALL_TRAIN_SOURCE_PATH = "datasets/wikismall/train_source.txt"
WIKISMALL_TRAIN_TARGET_PATH = "datasets/wikismall/train_target.txt"
WIKISMALL_VALIDATION_SOURCE_PATH = "datasets/wikismall/val_source.txt"
WIKISMALL_VALIDATION_TARGET_PATH = "datasets/wikismall/val_target.txt"
WIKISMALL_TEST_SOURCE_PATH = "datasets/wikismall/test_source.txt"
WIKISMALL_TEST_TARGET_PATH = "datasets/wikismall/test_target.txt"

# TODO: might want to adjust both of these
# Batch size handed to the Wikismall DataLoaders.
WIKISMALL_BATCH_SIZE = 16
# Every source/target line is padded or truncated to exactly this many tokens.
WIKISMALL_MAX_LENGTH = 128

In [None]:
class WikismallDataset(Dataset):
    """
    Dataset class representation for loading in the Wikismall dataset, allowing for fine-tuning of pretrained models with general text simplification.

    The whole dataset is tokenized eagerly in the constructor, so indexing is a plain list lookup.
    """

    def __init__(self, source_file: str, target_file: str, tokenizer: PreTrainedTokenizer, max_length: int):
        """
        Initializes the Wikismall dataset.

        Parameters:
        - source_file (str): Path to text source file.
        - target_file (str): Path to text target file (simplified information).
        - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
        - max_length (int): The maximum length to use for tokenization.

        Raises:
        - ValueError: If the source and target files do not contain the same number of lines.
        """

        # Read with an explicit encoding so the result does not depend on the platform's locale default
        with open(source_file, "r", encoding="utf-8") as f:
            source_lines = [line.strip() for line in f]

        with open(target_file, "r", encoding="utf-8") as f:
            target_lines = [line.strip() for line in f]

        # Raise (rather than assert) so the check still runs when Python is started with -O
        if len(source_lines) != len(target_lines):
            raise ValueError("Source and target dataset files must have same number of lines")

        # Identical tokenizer settings for both sides of every pair
        tokenize_kwargs = {
            "padding": "max_length",
            "truncation": True,
            "max_length": max_length,
            "return_tensors": "pt",
        }

        # Tokenize the data and save it for later
        self.data = []
        for source_line, target_line in zip(source_lines, target_lines):
            source_tokenized = tokenizer(source_line, **tokenize_kwargs)
            target_tokenized = tokenizer(target_line, **tokenize_kwargs)

            # squeeze(0) removes only the leading batch axis added by return_tensors="pt";
            # a bare squeeze() would also drop the sequence axis whenever it has length 1.
            self.data.append({
                "input_ids": source_tokenized["input_ids"].squeeze(0),
                "attention_mask": source_tokenized["attention_mask"].squeeze(0),
                "labels": target_tokenized["input_ids"].squeeze(0),
            })

    def __len__(self) -> int:
        """
        Calculates the length of the dataset.

        Returns:
        - (int): The number of samples in the dataset.
        """

        return len(self.data)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """
        For a given index in the dataset, retrieves the source and target encodings.

        Parameters:
        - i (int): The index of the relevant sample in the dataset.

        Returns:
        - (Dict[str, torch.Tensor]): The input IDs, attention mask, and labels for the data sample.
        """

        return self.data[i]

In [None]:
def load_wikismall_dataset(tokenizer: PreTrainedTokenizer, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Builds the train, validation, and test PyTorch DataLoaders for the Wikismall dataset.

    Parameters:
    - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
    - batch_size (int): The batch size to use for data loading.

    Returns:
    - (Tuple[DataLoader, DataLoader, DataLoader]): The DataLoaders for the train, validation, and test datasets.
    """

    # (source, target) file pairs, listed in train / validation / test order
    split_paths = (
        (WIKISMALL_TRAIN_SOURCE_PATH, WIKISMALL_TRAIN_TARGET_PATH),
        (WIKISMALL_VALIDATION_SOURCE_PATH, WIKISMALL_VALIDATION_TARGET_PATH),
        (WIKISMALL_TEST_SOURCE_PATH, WIKISMALL_TEST_TARGET_PATH),
    )
    train_split, val_split, test_split = (
        WikismallDataset(source_path, target_path, tokenizer, WIKISMALL_MAX_LENGTH)
        for source_path, target_path in split_paths
    )

    # Only the training split is shuffled; validation and test keep file order
    return (
        DataLoader(train_split, batch_size=batch_size, shuffle=True),
        DataLoader(val_split, batch_size=batch_size),
        DataLoader(test_split, batch_size=batch_size),
    )

### MIMIC-IV-Ext-BHC Dataset for Medical Information Summarization

In [None]:
# MIMIC-IV-Ext-BHC Dataset Constants

# Single CSV file holding the dataset; it is read with pandas.
MIMIC_DATASET_CSV_PATH = "datasets/mimic-iv-ext-bhc/mimic-iv-bhc.csv"

# TODO: this dataset is massive. we might not want to use all of it, to keep training times within reason.
# Fraction of rows kept (via DataFrame.sample) before the train/val/test split.
MIMIC_DATA_USAGE_FRACTION = 0.25

# Seed passed to train_test_split for both splits.
RANDOM_STATE = 42
# Share of the sampled rows held out as the test set.
MIMIC_TEST_SIZE = 0.2
# Applied to the rows left after the test split, so 0.25 * 0.8 = 0.2 of the sampled rows
# end up in validation (a 60/20/20 train/val/test split overall).
MIMIC_VAL_SIZE = 0.25  # Proportion of training data used for validation

# TODO: might want to adjust both of these
# Batch size handed to the MIMIC DataLoaders.
MIMIC_BATCH_SIZE = 16
# Every input/target text is padded or truncated to exactly this many tokens.
MIMIC_MAX_LENGTH = 256

In [None]:
class MIMICDataset(Dataset):
    """
    Dataset class representation for loading in the MIMIC dataset, allowing for fine-tuning of pretrained models with text simplification
    specific to medical information/patient summaries.

    The dataframe must provide an "input" column (source text) and a "target" column (reference text).
    All rows are tokenized eagerly in the constructor and the resulting tensors are stored in the
    "input_ids", "attention_mask", and "labels" columns of `self.df`.
    """

    def __init__(self, df: pd.DataFrame, tokenizer: PreTrainedTokenizer, max_length: int):
        """
        Initializes the MIMIC dataset.

        Parameters:
        - df (pd.DataFrame): The pandas dataframe for this dataset.
        - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
        - max_length (int): The maximum length to use for tokenization.
        """

        def encode(text: str):
            # Same tokenizer settings for inputs and targets
            return tokenizer(text, padding="max_length", truncation=True, max_length=max_length, return_tensors="pt")

        # Work on a private copy: the caller's frame is typically a slice produced by train_test_split,
        # and adding columns to it in place would both mutate the caller's data and trigger
        # pandas' SettingWithCopyWarning.
        self.df = df.copy()

        # Tokenize each input once and reuse the encoding for both the IDs and the attention mask
        # (tokenization dominates the load time, so encoding every input twice is wasted work).
        input_encodings = self.df["input"].apply(encode)

        # squeeze(0) removes only the batch axis added by return_tensors="pt"
        self.df["input_ids"] = input_encodings.apply(lambda enc: enc["input_ids"].squeeze(0))
        self.df["attention_mask"] = input_encodings.apply(lambda enc: enc["attention_mask"].squeeze(0))
        self.df["labels"] = self.df["target"].apply(lambda text: encode(text)["input_ids"].squeeze(0))

    def __len__(self) -> int:
        """
        Calculates the length of the dataset.

        Returns:
        - (int): The number of samples in the dataset.
        """

        return len(self.df)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """
        For a given index in the dataset, retrieves the source and target encodings.

        Parameters:
        - i (int): The (positional) index of the relevant sample in the dataset.

        Returns:
        - (Dict[str, torch.Tensor]): The input IDs, attention mask, and labels for the data sample.
        """

        # Positional lookup per column avoids materializing a whole mixed-type row for each sample
        return {
            "input_ids": self.df["input_ids"].iloc[i],
            "attention_mask": self.df["attention_mask"].iloc[i],
            "labels": self.df["labels"].iloc[i],
        }

In [None]:
def load_mimic_dataset(mimic_df: pd.DataFrame, tokenizer: PreTrainedTokenizer, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Loads the MIMIC-IV-Ext-BHC dataset into three different PyTorch DataLoaders, for training, validation, and test.

    A fraction (MIMIC_DATA_USAGE_FRACTION) of the rows is sampled first, then split into
    train/validation/test using MIMIC_TEST_SIZE and MIMIC_VAL_SIZE. Sampling and both splits are
    seeded with RANDOM_STATE, so repeated calls on the same dataframe select the same rows.

    Parameters:
    - mimic_df (pd.DataFrame): The entire MIMIC dataset as a pandas dataframe.
    - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
    - batch_size (int): The batch size to use for data loading.

    Returns:
    - (Tuple[DataLoader, DataLoader, DataLoader]): The DataLoaders for the train, validation, and test datasets.
    """

    # Sample from the dataframe to only use however much data we want.
    # Seeded like the splits below; an unseeded sample would pick different rows on every run.
    mimic_df = mimic_df.sample(frac=MIMIC_DATA_USAGE_FRACTION, random_state=RANDOM_STATE)

    # Split overall dataframe into train, val, and test dataframes
    train_val_df, test_df = train_test_split(mimic_df, test_size=MIMIC_TEST_SIZE, random_state=RANDOM_STATE)
    train_df, val_df = train_test_split(train_val_df, test_size=MIMIC_VAL_SIZE, random_state=RANDOM_STATE)

    train_dataset = MIMICDataset(train_df, tokenizer, MIMIC_MAX_LENGTH)
    val_dataset = MIMICDataset(val_df, tokenizer, MIMIC_MAX_LENGTH)
    test_dataset = MIMICDataset(test_df, tokenizer, MIMIC_MAX_LENGTH)

    # Only the training split is shuffled between epochs
    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size)
    test_loader = DataLoader(test_dataset, batch_size)

    return train_loader, val_loader, test_loader

In [None]:
# Read the full MIMIC-IV-Ext-BHC CSV into memory; it is sub-sampled later when the loaders are built
mimic_df = pd.read_csv(filepath_or_buffer=MIMIC_DATASET_CSV_PATH)

# Trailing bare expression so the notebook renders a preview of the dataframe
mimic_df

## Fine-Tuning Pretrained Models on the Wikismall Dataset

### T5

In [None]:
# Load in base T5 model and tokenizer

# TODO: could use "t5-small" if this takes too long to train
t5_model_name = "t5-base"

t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_base = T5ForConditionalGeneration.from_pretrained(t5_model_name).to(DEVICE)

# Freeze base model parameters (sets requires_grad = False on every parameter of the module)
t5_base.requires_grad_(False)

In [None]:
# Create Wikismall data loaders for T5 training, validation, and test

(
    t5_wikismall_train_loader,
    t5_wikismall_val_loader,
    t5_wikismall_test_loader,
) = load_wikismall_dataset(tokenizer=t5_tokenizer, batch_size=WIKISMALL_BATCH_SIZE)

In [None]:
# Define a fine-tunable T5 model (base T5 with an additional layer processing the output)

class T5Simplifier(nn.Module):
    """
    Represents a T5 model fine-tuned on the Wikismall dataset in order to do text simplification as a sequence generation model.

    The wrapped T5 model produces decoder hidden states, and a linear head (`layer1`) maps each
    decoder position to vocabulary logits.
    """

    def __init__(self, t5_base):
        """
        Initializes the PyTorch model for fine-tuning T5.

        Parameters:
        - t5_base (T5ForConditionalGeneration): The T5 base model.
        """

        super().__init__()
        self.t5_base = t5_base
        # Projection from the model's hidden size to one logit per vocabulary entry
        self.layer1 = nn.Linear(t5_base.config.d_model, t5_base.config.vocab_size)

    def forward(self, input_ids, attention_mask, labels=None, decoder_input_ids=None):
        """
        Runs a forward pass for the T5 model fine-tuned to do text simplification/summarization.

        Parameters:
        - input_ids: Token IDs of the source text.
        - attention_mask: Attention mask matching `input_ids`.
        - labels: Optional target token IDs. When given, they are passed to the base model (which
          derives the decoder inputs from them) and a cross-entropy loss is computed.
        - decoder_input_ids: Optional explicit decoder inputs, forwarded to the base model. Needed
          when `labels` is not given, since the base model must be told what to decode.

        Returns:
        - (Tuple): `(loss, logits)`, where `loss` is None when `labels` is None and `logits` holds
          one vocabulary distribution per decoder position.
        """

        outputs = self.t5_base(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            decoder_input_ids=decoder_input_ids,
            output_hidden_states=True,
        )

        # Use the decoder's final hidden states: they have one position per *target* token, so the
        # logits line up with `labels`. The encoder states are indexed by *source* position, which
        # would pair each label with an unrelated source token (and fail outright whenever the
        # source and target lengths differ).
        hidden_states = outputs.decoder_hidden_states[-1]
        logits = self.layer1(hidden_states)

        loss = None
        if labels is not None:
            # Exclude padding positions from the loss so it is not dominated by predicting <pad>
            pad_token_id = getattr(self.t5_base.config, "pad_token_id", None)
            loss_labels = labels
            if pad_token_id is not None:
                loss_labels = labels.masked_fill(labels == pad_token_id, -100)

            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(logits.reshape(-1, logits.size(-1)), loss_labels.reshape(-1))

        return loss, logits

In [None]:
# Initialize the fine-tunable T5 model & optimizer

t5_simplifier = T5Simplifier(t5_base)
t5_simplifier.to(DEVICE)

# Only the added linear layer is handed to the optimizer
t5_optimizer = AdamW(params=t5_simplifier.layer1.parameters(), lr=5e-5)

In [None]:
# Training loop (fine-tune the T5 model on the Wikismall dataset)

NUM_EPOCHS = 3

for epoch in range(NUM_EPOCHS):
    t5_simplifier.train()
    total_loss = 0
    for batch in tqdm(t5_wikismall_train_loader, desc=f"Training Epoch {epoch + 1}"):
        t5_optimizer.zero_grad()

        # Move pre-tokenized inputs and labels to the device
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # Forward pass
        loss, _ = t5_simplifier(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        t5_optimizer.step()

    # Report the mean per-batch training loss for this epoch
    print(f"Epoch {epoch + 1} Loss: {total_loss / len(t5_wikismall_train_loader)}")

    # Validation (no gradient tracking; eval() also switches off dropout)
    t5_simplifier.eval()
    with torch.no_grad():
        val_loss = 0
        for batch in t5_wikismall_val_loader:
            # Move pre-tokenized inputs and labels to the device
            input_ids = batch["input_ids"].to(DEVICE)
            attention_mask = batch["attention_mask"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            # Forward pass
            loss, _ = t5_simplifier(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            val_loss += loss.item()

        print(f"Validation Loss: {val_loss / len(t5_wikismall_val_loader)}")

# Save the fine-tuned model.
# T5Simplifier is a plain nn.Module, so it has no save_pretrained() method; calling it would raise
# AttributeError after training has finished. Instead, save the wrapped base model and the tokenizer
# in Hugging Face format (this also creates the directory) and store the trained head's weights
# alongside them as a regular PyTorch state dict.
t5_simplifier.t5_base.save_pretrained("t5-wikismall-finetuned")
t5_tokenizer.save_pretrained("t5-wikismall-finetuned")
torch.save(t5_simplifier.layer1.state_dict(), "t5-wikismall-finetuned/simplifier_head.pt")

## Fine-Tuning Pretrained Models on the MIMIC Dataset

### T5

In [None]:
# Load in base T5 model and tokenizer

# TODO: we should eventually be fine-tuning the version of T5 already fine-tuned on Wikismall here

t5_model_name = "t5-base"

t5_tokenizer = T5Tokenizer.from_pretrained(t5_model_name)
t5_base = T5ForConditionalGeneration.from_pretrained(t5_model_name).to(DEVICE)

# Freeze base model parameters (sets requires_grad = False on every parameter of the module)
t5_base.requires_grad_(False)

In [None]:
# Create MIMIC data loaders for T5 training, validation, and test

# TODO: took 8 minutes to run with MIMIC_DATA_USAGE_FRACTION set to 0.25 - probably worth saving the final dataframes and just loading/using those directly
(
    t5_mimic_train_loader,
    t5_mimic_val_loader,
    t5_mimic_test_loader,
) = load_mimic_dataset(mimic_df=mimic_df, tokenizer=t5_tokenizer, batch_size=MIMIC_BATCH_SIZE)

## Evaluate Model Performance

### GPT-2

### BERT

### T5

### BART

### BigBirdPegasus