# Simplification & Summarization of Medical Information for Elderly Patients

## Background Information & Approach

TODO:
- Describe the background information for why this model is useful, particularly for elderly patients
- Describe the datasets we are using and which models we are evaluating
- Give instructions for how to run our code

## Imports & Setup

In [89]:
from typing import Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from transformers import PreTrainedTokenizer, BertTokenizer, BertModel

from tqdm import tqdm

In [90]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {DEVICE}")

Using device cpu


## Loading in Datasets

### Wikismall Dataset for General Text Simplification

In [91]:
# Wikismall Dataset Constants

WIKISMALL_TRAIN_SOURCE_PATH = "datasets/wikismall/train_source.txt"
WIKISMALL_TRAIN_TARGET_PATH = "datasets/wikismall/train_target.txt"
WIKISMALL_VALIDATION_SOURCE_PATH = "datasets/wikismall/val_source.txt"
WIKISMALL_VALIDATION_TARGET_PATH = "datasets/wikismall/val_target.txt"
WIKISMALL_TEST_SOURCE_PATH = "datasets/wikismall/test_source.txt"
WIKISMALL_TEST_TARGET_PATH = "datasets/wikismall/test_target.txt"

WIKISMALL_BATCH_SIZE = 16

In [92]:
class WikismallDataset(Dataset):
    """
    Dataset class representation for loading in the Wikismall dataset, allowing for fine-tuning of pretrained models with general text simplification.
    """

    def __init__(self, source_file: str, target_file: str, tokenizer: PreTrainedTokenizer):
        """
        Initializes the Wikismall dataset.

        Parameters:
        - source_file (str): Path to text source file.
        - target_file (str): Path to text target file (simplified information).
        - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
        """

        self.tokenizer = tokenizer

        with open(source_file, "r") as f:
            self.source_lines = [line.strip() for line in f]
        
        with open(target_file, "r") as f:
            self.target_lines = [line.strip() for line in f]
        
        assert len(self.source_lines) == len(self.target_lines), "Source and target dataset files must have same number of lines"
    
    def __len__(self) -> int:
        """
        Calculates the length of the dataset.

        Returns:
        - (int): The number of samples in the dataset.
        """

        return len(self.source_lines)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """
        For a given index in the dataset, retrieves the source and target encodings.

        Parameters:
        - i (int): The index of the relevant sample in the dataset.

        Returns:
        - (Dict[str, torch.Tensor]): The mapping of source and target encodings for the data sample.
        """

        source_line = self.source_lines[i]
        target_line = self.target_lines[i]

        source_encodings = self.tokenizer(
            source_line,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        target_encodings = self.tokenizer(
            target_line,
            max_length=128,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        return {
            "input_ids": source_encodings["input_ids"].squeeze(),
            "attention_mask": source_encodings["attention_mask"].squeeze(),
            "labels": target_encodings["input_ids"].squeeze(),
            "decoder_attention_mask": target_encodings["attention_mask"].squeeze(),
        }

In [93]:
def load_wikismall_dataset(tokenizer: PreTrainedTokenizer, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Loads the Wikismall dataset into three different PyTorch DataLoaders, for training, validation, and test.

    Parameters:
    - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
    - batch_size (int): The batch size to use for data loading.

    Returns:
    - (Tuple[DataLoader, DataLoader, DataLoader]): The DataLoaders for the train, validation, and test datasets.
    """

    train_dataset = WikismallDataset(
        WIKISMALL_TRAIN_SOURCE_PATH,
        WIKISMALL_TRAIN_TARGET_PATH,
        tokenizer,
    )
    val_dataset = WikismallDataset(
        WIKISMALL_VALIDATION_SOURCE_PATH,
        WIKISMALL_VALIDATION_TARGET_PATH,
        tokenizer,
    )
    test_dataset = WikismallDataset(
        WIKISMALL_TEST_SOURCE_PATH,
        WIKISMALL_TEST_TARGET_PATH,
        tokenizer,
    )

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size)
    test_loader = DataLoader(test_dataset, batch_size)

    return train_loader, val_loader, test_loader

### MIMIC-IV-Ext-BHC Dataset for Medical Information Summarization

TODO: need to load in this dataset once granted access to it

## Fine-Tuning Pretrained Models on the Wikismall Dataset

### BERT

In [94]:
# Load in base BERT model and tokenizer

bert_base = BertModel.from_pretrained("bert-base-uncased")
for param in bert_base.parameters():
    param.requires_grad = False
hidden_size = bert_base.config.hidden_size
vocab_size = bert_base.config.vocab_size

bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

In [95]:
# Create data loaders for BERT training, validation, and test

bert_train_loader, bert_val_loader, bert_test_loader = load_wikismall_dataset(bert_tokenizer, WIKISMALL_BATCH_SIZE)

In [None]:
# Create a fine-tunable version of BERT

class BertSimplifier(nn.Module):
    """
    Represents a BERT model fine-tuned on the Wikismall dataset in order to do text simplification as a sequence generation model.
    """

    def __init__(self, bert_base: BertModel, hidden_size: int, vocab_size: int):
        """
        Initialized the PyTorch model for fine-tuned BERT.

        Parameters:
        - bert_base (BertModel): The BERT base model.
        - hidden_size (int): The BERT hidden size dimension.
        - vocab_size (int): The BERT vocab size dimension.
        """

        super(BertSimplifier, self).__init__()
        self.bert_base = bert_base
        self.decoder = nn.Linear(hidden_size, vocab_size)
        self.softmax = nn.LogSoftmax(dim=-1)
    
    def forward(self, input_ids, attention_mask=None, labels=None):
        """
        Forward pass for the text simplification fine-tuned BERT model.
        """

        outputs = self.bert_base(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs.last_hidden_state
        logits = self.decoder(sequence_output)

        loss = None
        if labels is not None:
            loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
            loss = loss_fn(logits.view(-1, logits.size(-1)), labels.view(-1))
        
        return loss, logits

In [97]:
# Initialize model and optimizer

NUM_EPOCHS = 5

bert_simplifier = BertSimplifier(bert_base, hidden_size, vocab_size)
bert_simplifier.to(DEVICE)
bert_simplifier.train()
optimizer = AdamW(bert_simplifier.parameters(), lr=5e-5)

# Training loop
for epoch in range(NUM_EPOCHS):
    total_loss = 0
    print(f"Epoch {epoch + 1} / {NUM_EPOCHS}")

    progress_bar = tqdm(bert_train_loader, desc="Training", unit="batch")

    for batch in progress_bar:
        optimizer.zero_grad()

        # Extract inputs and labels from the batch
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)
        decoder_attention_mask = batch["decoder_attention_mask"].to(DEVICE)

        # Forward pass
        loss, _ = bert_simplifier(input_ids=input_ids, attention_mask=attention_mask, labels=labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Update the progress bar description with the current loss
        progress_bar.set_postfix({"loss": loss.item()})

    avg_loss = total_loss / len(bert_train_loader)
    print(f"Epoch {epoch + 1}, Loss: {avg_loss:.4f}")

Epoch 1 / 5


Training:  32%|███▏      | 1756/5553 [26:54<58:11,  1.09batch/s, loss=3.78]  


KeyboardInterrupt: 