# Simplification & Summarization of Medical Information for Elderly Patients

## Background Information & Approach

TODO:
- Explain why this model is useful, particularly for elderly patients
- Describe the datasets we use and the models we evaluate
- Explain how to run our code

## Imports & Setup

In [None]:
from typing import Dict, Tuple

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import PreTrainedTokenizer

from tqdm import tqdm

In [5]:
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device {DEVICE}")

Using device cpu


## Loading in Datasets

### Wikismall Dataset for General Text Simplification

In [None]:
# Wikismall Dataset Constants

WIKISMALL_TRAIN_SOURCE_PATH = "datasets/wikismall/train_source.txt"
WIKISMALL_TRAIN_TARGET_PATH = "datasets/wikismall/train_target.txt"
WIKISMALL_VALIDATION_SOURCE_PATH = "datasets/wikismall/val_source.txt"
WIKISMALL_VALIDATION_TARGET_PATH = "datasets/wikismall/val_target.txt"
WIKISMALL_TEST_SOURCE_PATH = "datasets/wikismall/test_source.txt"
WIKISMALL_TEST_TARGET_PATH = "datasets/wikismall/test_target.txt"

WIKISMALL_BATCH_SIZE = 16

In [None]:
class WikismallDataset(Dataset):
    """
    Dataset class for loading the Wikismall dataset, used to fine-tune pretrained models on general text simplification.
    """

    def __init__(self, source_file: str, target_file: str, tokenizer: PreTrainedTokenizer, max_length: int = 128):
        """
        Initializes the Wikismall dataset.

        Parameters:
        - source_file (str): Path to the source text file (one sample per line).
        - target_file (str): Path to the target text file (the simplified text on the matching line).
        - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
        - max_length (int): Number of tokens every source and target sequence is truncated or padded to.
          The default of 128 is a placeholder and should be tuned per model.
        """

        self.tokenizer = tokenizer
        self.max_length = max_length

        with open(source_file, "r", encoding="utf-8") as f:
            self.source_lines = [line.strip() for line in f]

        with open(target_file, "r", encoding="utf-8") as f:
            self.target_lines = [line.strip() for line in f]

        assert len(self.source_lines) == len(self.target_lines), "Source and target dataset files must have the same number of lines"

    def __len__(self) -> int:
        """
        Calculates the length of the dataset.

        Returns:
        - (int): The number of samples in the dataset.
        """

        return len(self.source_lines)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        """
        For a given index in the dataset, retrieves the source and target encodings.

        Parameters:
        - i (int): The index of the relevant sample in the dataset.

        Returns:
        - (Dict[str, torch.Tensor]): The mapping of source and target encodings for the data sample.
        """

        source_line = self.source_lines[i]
        target_line = self.target_lines[i]

        # Pad and truncate to a fixed length so that the default DataLoader collate
        # function can stack samples of different lengths into one batch.
        source_encodings = self.tokenizer(
            source_line,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        target_encodings = self.tokenizer(
            target_line,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        labels = target_encodings["input_ids"].squeeze(0)
        decoder_attention_mask = target_encodings["attention_mask"].squeeze(0)

        # Set padded label positions to -100, the index ignored by the cross-entropy loss.
        labels[decoder_attention_mask == 0] = -100

        return {
            "input_ids": source_encodings["input_ids"].squeeze(0),
            "attention_mask": source_encodings["attention_mask"].squeeze(0),
            "labels": labels,
            "decoder_attention_mask": decoder_attention_mask,
        }
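PyTorch's default collate function stacks tensors, so every sample in a batch must have the same length. Besides padding each sample to a fixed length inside the tokenizer call, a common option is dynamic padding: pad each batch only to the length of its own longest sequence, which wastes less computation on padding tokens. The sketch below shows that padding logic on plain lists of token ids. The function name `pad_batch` and the use of `-100` for padded label positions are assumptions for illustration (`-100` is the default `ignore_index` of PyTorch's cross-entropy loss), not part of this notebook.

```python
from typing import Dict, List

# Assumption: label value skipped by the loss (PyTorch CrossEntropyLoss default ignore_index).
IGNORE_INDEX = -100


def pad_batch(
    batch: List[Dict[str, List[int]]],
    pad_token_id: int,
) -> Dict[str, List[List[int]]]:
    """
    Hypothetical helper: pads every sequence in a batch to the longest
    source and target length found in that batch.
    """
    max_source = max(len(item["input_ids"]) for item in batch)
    max_target = max(len(item["labels"]) for item in batch)

    padded: Dict[str, List[List[int]]] = {
        "input_ids": [],
        "attention_mask": [],
        "labels": [],
        "decoder_attention_mask": [],
    }
    for item in batch:
        source, target = item["input_ids"], item["labels"]
        source_pad = max_source - len(source)
        target_pad = max_target - len(target)

        # Source side: pad token ids, and mark padding with 0 in the attention mask.
        padded["input_ids"].append(source + [pad_token_id] * source_pad)
        padded["attention_mask"].append([1] * len(source) + [0] * source_pad)

        # Target side: padded label positions use IGNORE_INDEX so the loss skips them.
        padded["labels"].append(target + [IGNORE_INDEX] * target_pad)
        padded["decoder_attention_mask"].append([1] * len(target) + [0] * target_pad)

    return padded


example = [
    {"input_ids": [5, 6, 7], "labels": [8, 9]},
    {"input_ids": [5], "labels": [8, 9, 10]},
]
print(pad_batch(example, pad_token_id=0))
```

To use this approach as a `collate_fn` for a `DataLoader`, `__getitem__` would return unpadded token-id lists and the padded lists would be converted with `torch.tensor` before being returned; Hugging Face's `DataCollatorForSeq2Seq` is a ready-made alternative worth checking for the chosen models.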

In [None]:
def load_wikismall_dataset(tokenizer: PreTrainedTokenizer, batch_size: int) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Loads the Wikismall dataset into three different PyTorch DataLoaders, for training, validation, and test.

    Parameters:
    - tokenizer (PreTrainedTokenizer): The tokenizer to use for encoding.
    - batch_size (int): The batch size to use for data loading.

    Returns:
    - (Tuple[DataLoader, DataLoader, DataLoader]): The DataLoaders for the train, validation, and test datasets.
    """

    train_dataset = WikismallDataset(
        WIKISMALL_TRAIN_SOURCE_PATH,
        WIKISMALL_TRAIN_TARGET_PATH,
        tokenizer,
    )
    val_dataset = WikismallDataset(
        WIKISMALL_VALIDATION_SOURCE_PATH,
        WIKISMALL_VALIDATION_TARGET_PATH,
        tokenizer,
    )
    test_dataset = WikismallDataset(
        WIKISMALL_TEST_SOURCE_PATH,
        WIKISMALL_TEST_TARGET_PATH,
        tokenizer,
    )

    train_loader = DataLoader(train_dataset, batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size)
    test_loader = DataLoader(test_dataset, batch_size)

    return train_loader, val_loader, test_loader


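# TODO: pass each model's own tokenizer here. With `None`, the DataLoaders can be
# constructed, but iterating over them fails because there is no tokenizer to call.
train_loader, val_loader, test_loader = load_wikismall_dataset(None, WIKISMALL_BATCH_SIZE)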
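`WikismallDataset` pairs the source and target files purely by line position. Its length check catches files with different numbers of lines, but not blank lines, which would become empty inputs or empty simplifications. The hypothetical helper `check_parallel_files` below is one way to sanity-check each split before building the loaders; it assumes UTF-8 text with one sample per line and is not part of the original notebook.

```python
from typing import List, Tuple


def check_parallel_files(source_file: str, target_file: str) -> Tuple[int, List[int]]:
    """
    Hypothetical helper: checks that two line-aligned text files can be paired line by line.

    Returns the number of line pairs and the indices of pairs in which the source
    or the target line is empty after stripping whitespace.
    Raises ValueError if the files have different numbers of lines.
    """
    with open(source_file, "r", encoding="utf-8") as f:
        source_lines = [line.strip() for line in f]
    with open(target_file, "r", encoding="utf-8") as f:
        target_lines = [line.strip() for line in f]

    if len(source_lines) != len(target_lines):
        raise ValueError(
            f"{source_file} has {len(source_lines)} lines but {target_file} has {len(target_lines)}"
        )

    # Empty lines usually signal a formatting problem that would yield empty inputs or targets.
    empty_pairs = [
        i
        for i, (source, target) in enumerate(zip(source_lines, target_lines))
        if not source or not target
    ]
    return len(source_lines), empty_pairs
```

For example, calling it on `WIKISMALL_TRAIN_SOURCE_PATH` and `WIKISMALL_TRAIN_TARGET_PATH` (and likewise for the validation and test paths) reports the number of pairs and any indices worth inspecting by hand.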

### MIMIC-IV-Ext-BHC Dataset for Medical Information Summarization

TODO: Load this dataset once access to it is granted.

## Fine-Tuning Pretrained Models on the Wikismall Dataset