In [None]:
import pickle
import re
from pathlib import Path
from typing import List

import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import DischargeSummary, Note, Record
from discharge_summaries.schemas.output import Paragraph

In [None]:
# Seed for the train/test split further down.
RANDOM_SEED = 23

# All inputs/outputs live in a "data" directory that is a sibling of the
# notebook's working directory.
DATA_DIR = Path.cwd().parent / "data"

# PhysioNet download root and the dataset releases under it.
MIMIC_DIR = DATA_DIR.joinpath("physionet.org", "files")
MIMIC_III_DIR = MIMIC_DIR.joinpath("mimiciii", "1.4")
MIMIC_IV_DIR = MIMIC_DIR.joinpath("mimiciv", "2.2", "note")  # currently unused below

# Pickled train/test splits written at the end of the notebook.
TRAIN_SAVE_PATH = DATA_DIR.joinpath("train.pkl")
TEST_SAVE_PATH = DATA_DIR.joinpath("test.pkl")

## Read in MIMIC-III notes

In [None]:
# Load the MIMIC-III (v1.4 directory) NOTEEVENTS table in one go.
noteevents_path = MIMIC_III_DIR / "NOTEEVENTS.csv"
full_df = pd.read_csv(noteevents_path)

### Preprocessing

Remove notes flagged as errors (`ISERROR == 1`) and duplicate rows.

In [None]:
# Drop notes flagged as errors (ISERROR == 1), then the flag column itself.
full_df = full_df.loc[full_df["ISERROR"] != 1].drop(columns="ISERROR")
# Exclude ROW_ID from the duplicate check: it is MIMIC-III's per-row
# identifier, so including it makes every row unique and drop_duplicates()
# a no-op.
full_df = full_df.drop_duplicates(
    subset=full_df.columns.drop("ROW_ID", errors="ignore")
)
full_df.head()

In [None]:
# nunique() ignores notes without an admission (NaN HADM_ID); len(unique())
# would count NaN as one extra "admission".
len(full_df), full_df["HADM_ID"].nunique()

Inspect the note categories. Only physician notes (category `"Physician "`, with a trailing space) and discharge summaries are used below.

In [None]:
# Distinct note categories, in order of first appearance.
pd.unique(full_df["CATEGORY"])

Group notes by admission (`HADM_ID`) and keep only admissions that have both a discharge summary and at least one physician note.

In [None]:
# Keep admissions that have at least one note of every required category, and
# keep only the notes of those categories (no other category is read later).
required_categories = ["Discharge summary", "Physician "]
is_required_category = full_df["CATEGORY"].isin(required_categories)
# Number of distinct required categories present per admission; groupby drops
# NaN HADM_IDs, matching the previous groupby().filter() behaviour.
categories_per_admission = (
    full_df.loc[is_required_category].groupby("HADM_ID")["CATEGORY"].nunique()
)
complete_admissions = categories_per_admission.index[
    categories_per_admission == len(required_categories)
]
# .copy() so later column assignments on df never warn about chained assignment.
df = full_df.loc[
    is_required_category & full_df["HADM_ID"].isin(complete_admissions)
].copy()

In [None]:
def clean_text(text: str) -> str:
    """Normalise line breaks in a note and unwrap wrapped lines.

    The substitutions are applied in order:
    1. a line holding only "." becomes a blank line;
    2. indentation of two or more spaces after a newline is removed;
    3. runs of three or more newlines collapse to a single blank line;
    4. a newline (plus optional spaces) before a lowercase letter becomes a
       single space, joining hard-wrapped paragraph lines.
    """
    substitutions = (
        (r"\n\.\n", r"\n\n"),
        (r"\n {2,}", "\n"),
        (r"\n{3,}", "\n\n"),
        (r"\n *(?=[a-z])", " "),
    )
    result = text
    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)
    return result


# Normalise every note's text in one pass.
df = df.assign(TEXT=df["TEXT"].map(clean_text))

In [None]:
# Notes without a CHARTTIME get the end of their CHARTDATE, so they sort after
# timed notes from the same day. Use df's own CHARTDATE (not full_df's) so the
# fill never depends on index alignment with a different frame.
# NOTE(review): assumes both columns are "YYYY-MM-DD[ HH:MM:SS]" strings, so
# the lexicographic sort below is chronological — confirm.
df = (
    df.assign(CHARTTIME=df["CHARTTIME"].fillna(df["CHARTDATE"] + " 23:59:59"))
    .sort_values(by=["HADM_ID", "CHARTTIME"])
    .reset_index(drop=True)
)
len(df), df["HADM_ID"].nunique()

In [None]:
def extract_bhc(discharge_summary_text: str) -> str:
    """Return the "Brief Hospital Course" section of a discharge summary.

    The section is the (possibly multi-line) text between a line reading
    ``Brief Hospital Course:`` and the first following line reading
    ``Medications on Admission:``; both headings must be preceded and
    followed by a newline. Returns ``""`` when no such section is found.
    """
    section_regex = "".join(
        (r"\nBrief Hospital Course:\n", "(.*?)", r"\nMedications on Admission:\n")
    )
    found = re.search(section_regex, discharge_summary_text, re.DOTALL)
    return found.group(1) if found else ""


def extract_bhc_paragraphs(bhc: str) -> List[Paragraph]:
    """Split a Brief Hospital Course section into headed paragraphs.

    A new paragraph starts at a blank line ("\\n\\n") followed by a heading
    marker: ``#`` or a digit followed by ``.`` or ``)``. The marker counts only
    when, after any non-letter characters, it is followed by a letter and then
    by heading punctuation (``;``, ``:``, ``-`` or ``.``) on that same line
    followed by a space or newline. The marker and its non-letter padding are
    dropped. The heading is the text up to that punctuation and the rest is
    the paragraph text.

    The first paragraph may have no marker. In that case it gets an empty
    heading and its whole text becomes the paragraph text.

    Args:
        bhc: Brief Hospital Course text, e.g. as returned by ``extract_bhc``.

    Returns:
        The parsed paragraphs, or an empty list if any paragraph still
        contains a blank line after splitting (i.e. the section has paragraph
        breaks that are not introduced by a recognised heading).
    """
    heading_punctuation = r"[;:\-\.]"
    # Marker + padding are consumed; the lookahead keeps the heading text in
    # the paragraph so it can be split off below. (The previous classes
    # `[\.|\)]` / `[ |\n]` also matched a literal "|" by mistake.)
    heading_regex_pattern = (
        r"(?:#|\d[.)])"
        r"[^a-zA-Z]*?"
        rf"(?=[a-zA-Z][^\n]*?{heading_punctuation}[ \n])"
    )
    paragraph_split = re.compile("\n\n" + heading_regex_pattern)
    leading_heading_split = re.compile("^" + heading_regex_pattern)
    heading_text_split = re.compile(rf"{heading_punctuation}[ \n]")

    bhc_paragraphs = []
    for idx, paragraph_text in enumerate(paragraph_split.split(bhc.strip())):
        # A blank line that survived the split is an un-headed paragraph
        # break; treat the whole section as unparseable.
        if "\n\n" in paragraph_text:
            return []
        if idx == 0:
            # The first paragraph has no preceding blank line, so strip its
            # marker separately. Without a marker, prefix ": " so the split
            # below yields an empty heading and keeps the full text as body.
            first_split = leading_heading_split.split(paragraph_text, maxsplit=1)
            paragraph_text = (
                first_split[1] if len(first_split) > 1 else ": " + paragraph_text
            )
        heading_and_text = heading_text_split.split(paragraph_text, maxsplit=1)
        bhc_paragraphs.append(
            Paragraph(
                heading=heading_and_text[0].strip(),
                text=heading_and_text[1].strip() if len(heading_and_text) > 1 else "",
            )
        )
    return bhc_paragraphs

In [None]:
dataset = []
missing_bhc, missing_paragraphs = 0, 0
for hadm_id, group_df in tqdm(df.groupby("HADM_ID")):
    # Rows are sorted by CHARTTIME within each admission, so this is the
    # earliest discharge summary. NOTE(review): any further discharge
    # summaries of the same admission are ignored — confirm that is intended.
    discharge_summary_row = group_df.loc[
        group_df["CATEGORY"] == "Discharge summary"
    ].iloc[0]

    # Check for a BHC before parsing paragraphs or building notes, so skipped
    # admissions do no unnecessary work.
    bhc = extract_bhc(discharge_summary_row["TEXT"])
    if not bhc:
        missing_bhc += 1
        continue
    bhc_paragraphs = extract_bhc_paragraphs(bhc)
    # <= 1 covers both an unparseable BHC (empty list) and one that was not
    # split into multiple headed paragraphs.
    if len(bhc_paragraphs) <= 1:
        missing_paragraphs += 1
        continue

    physician_notes = [
        Note(
            text=note_row["TEXT"],
            datetime=note_row["CHARTTIME"],
            category=note_row["CATEGORY"],
            description=note_row["DESCRIPTION"],
        )
        for _, note_row in group_df.loc[
            group_df["CATEGORY"] == "Physician "
        ].iterrows()
    ]

    discharge_summary = DischargeSummary(
        text=discharge_summary_row["TEXT"],
        datetime=discharge_summary_row["CHARTTIME"],
        category=discharge_summary_row["CATEGORY"],
        description=discharge_summary_row["DESCRIPTION"],
        bhc=bhc,
        bhc_paragraphs=bhc_paragraphs,
    )

    record = Record(
        # NOTE(review): relies on Note defining an ordering.
        physician_notes=sorted(physician_notes),
        discharge_summary=discharge_summary,
        hadm_id=hadm_id,
        subject_id=group_df["SUBJECT_ID"].iloc[0],
    )
    dataset.append(record)

n_admissions = df["HADM_ID"].nunique()
# Every admission is either kept or counted under exactly one skip reason.
assert len(dataset) + missing_bhc + missing_paragraphs == n_admissions
len(dataset), missing_bhc, missing_paragraphs, n_admissions

In [None]:
# Hold out 20% of the records, reproducibly.
# NOTE(review): the split is per admission record, so admissions of the same
# subject_id can land in both train and test — consider a grouped split if
# patient-level leakage matters.
split_kwargs = dict(test_size=0.2, random_state=RANDOM_SEED)
train_dataset, test_dataset = train_test_split(dataset, **split_kwargs)
len(train_dataset), len(test_dataset)

In [None]:
with open(TRAIN_SAVE_PATH, "wb") as train_file:
    # Pickle plain dicts (Record.dict()) rather than Record instances —
    # presumably so loading does not require the schema classes; confirm.
    train_payload = [train_record.dict() for train_record in train_dataset]
    pickle.dump(train_payload, train_file)

In [None]:
with open(TEST_SAVE_PATH, "wb") as test_file:
    # Same format as the train split: a list of Record.dict() payloads.
    test_payload = [test_record.dict() for test_record in test_dataset]
    pickle.dump(test_payload, test_file)