In [None]:
# Automatically reload imported modules (e.g. the local discharge_summaries
# package) when their source changes, so edits are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [None]:
import pickle
import re
from pathlib import Path

import pandas as pd
from sklearn.model_selection import train_test_split

from discharge_summaries.schemas.mimic import DischargeSummary, Note, Record
from discharge_summaries.schemas.output import Paragraph

In [None]:
# Random seed passed to train_test_split below.
RANDOM_SEED = 23

# Data root: a `data` directory next to the notebook's working directory.
DATA_DIR = Path.cwd().parent / "data"
MIMIC_DIR = DATA_DIR.joinpath("physionet.org", "files")

# Versioned MIMIC releases (directory layout mirrors physionet.org/files/...).
MIMIC_III_DIR = MIMIC_DIR.joinpath("mimiciii", "1.4")
MIMIC_IV_DIR = MIMIC_DIR.joinpath("mimiciv", "2.2", "note")

# Destinations for the pickled train/test splits.
TRAIN_SAVE_PATH = DATA_DIR.joinpath("train.pkl")
TEST_SAVE_PATH = DATA_DIR.joinpath("test.pkl")

## Read in MIMIC-III notes

In [None]:
# Load the full MIMIC-III notes table into memory.
noteevents_csv = MIMIC_III_DIR / "NOTEEVENTS.csv"
full_df = pd.read_csv(noteevents_csv)

### Preprocessing

Remove error and duplicate rows

In [None]:
# Drop notes flagged as errors (ISERROR == 1); the flag column is not needed
# afterwards. Chaining instead of `drop(..., inplace=True)` avoids mutating a
# filtered slice (SettingWithCopyWarning).
full_df = full_df[full_df["ISERROR"] != 1].drop(columns="ISERROR")

# NOTEEVENTS carries a per-row ROW_ID (see the MIMIC-III docs), so a plain
# drop_duplicates() over all columns can never match two rows. Ignore ROW_ID
# when looking for duplicates; if the column is absent this is identical to
# drop_duplicates() over every column.
dedup_columns = [column for column in full_df.columns if column != "ROW_ID"]
full_df = full_df.drop_duplicates(subset=dedup_columns)
full_df.head()

In [None]:
# Number of notes and number of distinct HADM_ID values (a missing ID counts once).
full_df.shape[0], full_df["HADM_ID"].nunique(dropna=False)

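Keep only Physician and discharge notes: inspect the note categories, then drop discharge summary addenda. **TODO:** notes of other categories are not removed by these cells, so they are later included in each record's `physician_notes` — confirm whether that is intended.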

In [None]:
# Distinct note categories present after preprocessing.
note_categories = full_df["CATEGORY"].unique()
note_categories

In [None]:
# Remove discharge summary addenda so only non-addendum discharge summaries
# remain as BHC sources. Filtering with a boolean mask (rather than dropping
# by index label) cannot remove extra rows if the index has duplicate labels.
is_discharge_addendum = (full_df["CATEGORY"] == "Discharge summary") & (
    full_df["DESCRIPTION"] == "Addendum"
)
full_df = full_df[~is_discharge_addendum]
len(full_df), len(full_df["HADM_ID"].unique())

Keep only admissions (`HADM_ID`) that have both a discharge summary and at least one physician note; all notes of those admissions are retained.

In [None]:
# Keep every note of admissions that have at least one discharge summary AND
# at least one physician note. The physician category literal includes a
# trailing space ("Physician ") and must match CATEGORY exactly.
# Vectorised equivalent of a per-group groupby().filter(): membership is
# computed once per category instead of calling unique() inside a Python
# lambda for every admission.
discharge_hadm_ids = full_df.loc[
    full_df["CATEGORY"] == "Discharge summary", "HADM_ID"
].unique()
physician_hadm_ids = full_df.loc[
    full_df["CATEGORY"] == "Physician ", "HADM_ID"
].unique()

hadm_ids = full_df["HADM_ID"]
# isin() treats NaN as matching NaN, whereas groupby drops missing keys;
# notna() keeps notes without an admission out, as groupby().filter() did.
df = full_df[
    hadm_ids.notna()
    & hadm_ids.isin(discharge_hadm_ids)
    & hadm_ids.isin(physician_hadm_ids)
].copy()  # own copy so later column assignments on df are not on a slice

In [None]:
# Notes kept and distinct HADM_ID values among them (a missing ID counts once).
df.shape[0], df["HADM_ID"].nunique(dropna=False)

In [None]:
def clean_text(text: str) -> str:
    """Normalize the line structure of a note's text.

    Transformations, applied in this order:

    1. Lines consisting solely of a period are blanked, including runs of
       several such lines in a row.
    2. Two or more spaces at the start of a line are removed.
    3. Three or more consecutive newlines are collapsed to one blank line.
    4. A newline (plus any leading spaces) followed by a lowercase letter is
       replaced by a single space, joining wrapped continuation lines.

    Args:
        text: Raw note text.

    Returns:
        The cleaned text.
    """
    # The lookahead leaves the closing newline unconsumed so it can start the
    # next match; with "\n\.\n" non-overlapping matches skipped every second
    # line in a run like "\n.\n.\n", leaving a stray period behind.
    cleaned_text = re.sub(r"\n\.(?=\n)", "\n", text)
    cleaned_text = re.sub(r"\n {2,}", "\n", cleaned_text)
    cleaned_text = re.sub(r"\n{3,}", "\n\n", cleaned_text)
    # NOTE(review): this also fires right after a blank line, so "\n\nx"
    # becomes "\n x" (the paragraph break is reduced to one newline).
    cleaned_text = re.sub(r"\n *(?=[a-z])", " ", cleaned_text)
    return cleaned_text


# Normalise every note's text in place on df.
df["TEXT"] = df["TEXT"].map(clean_text)

In [None]:
# Notes without a CHARTTIME get end of day on their CHARTDATE. Assuming the
# timestamps are "YYYY-MM-DD HH:MM:SS" strings, lexicographic order is
# chronological and these notes sort after timed notes from the same day.
# Use df's own CHARTDATE: reading full_df's only worked through index
# alignment and breaks once df's index is reset.
df["CHARTTIME"] = df["CHARTTIME"].fillna(df["CHARTDATE"] + " 23:59:59")
df = df.sort_values(by=["HADM_ID", "CHARTTIME"]).reset_index(drop=True)
len(df), len(df["HADM_ID"].unique())

In [None]:
def extract_bhc(discharge_summary_text: str) -> str:
    """Return the Brief Hospital Course (BHC) section of a discharge summary.

    The section is the text between a "Brief Hospital Course:" line and the
    first following "Medications on Admission:" line; each heading must sit on
    its own line, preceded by a newline.

    Args:
        discharge_summary_text: Full discharge summary text.

    Returns:
        The BHC body without the surrounding heading lines, or "" when the
        section cannot be located.
    """
    heading_before = r"\nBrief Hospital Course:\n"
    heading_after = r"\nMedications on Admission:\n"
    # Lazy body + DOTALL: span multiple lines but stop at the first end heading.
    bhc_regex = re.compile(heading_before + r"(.*?)" + heading_after, re.DOTALL)
    found = bhc_regex.search(discharge_summary_text)
    return found.group(1) if found else ""


def extract_bhc_paragraphs(bhc: str) -> list[Paragraph]:
    """Split a Brief Hospital Course (BHC) into heading-delimited paragraphs.

    The BHC is split on blank lines that are immediately followed by a line
    starting with "#" and containing a colon (e.g. "# Pneumonia: ...").
    Each chunk that starts with such a heading line becomes a Paragraph whose
    heading is the text before the first colon ("#" included) and whose text
    is the stripped remainder. A leading chunk without a heading (free-text
    preamble) becomes a Paragraph with an empty heading and unmodified text.

    Args:
        bhc: BHC text, e.g. as returned by extract_bhc; may be empty.

    Returns:
        The paragraphs in document order. Returns an empty list when any chunk
        still contains a blank line, i.e. the BHC has paragraph breaks that are
        not introduced by a "#" heading. An empty bhc yields a single
        Paragraph with empty heading and text.
    """
    bhc_paragraphs = []
    for paragraph_text in re.split("\n\n(?=#[^\n]*:)", bhc):
        if "\n\n" in paragraph_text:
            return []
        # Chunks after the first start with a heading by construction of the
        # split. The first chunk may also start with one (BHC without a
        # preamble); parse it the same way instead of treating it as headless.
        if re.match("#[^\n]*:", paragraph_text):
            heading, text = paragraph_text.split(":", 1)
            text = text.strip()
        else:
            heading, text = "", paragraph_text
        bhc_paragraphs.append(Paragraph(text=text, heading=heading))
    return bhc_paragraphs

In [None]:
# Build one Record per admission. The target is the admission's first
# discharge summary in CHARTTIME order (df is sorted by HADM_ID, CHARTTIME);
# every other note of the admission becomes an input Note.
dataset = []
for hadm_id, group_df in df.groupby("HADM_ID"):
    is_discharge_summary = group_df["CATEGORY"] == "Discharge summary"

    # Parse the target first so admissions whose BHC yields at most one
    # paragraph (including BHCs rejected with []) are skipped before any Note
    # objects are built for them. Every admission in df has at least one
    # discharge summary because of the earlier filtering, so iloc[0] exists.
    discharge_summary_row = group_df[is_discharge_summary].iloc[0]
    bhc = extract_bhc(discharge_summary_row["TEXT"])
    bhc_paragraphs = extract_bhc_paragraphs(bhc)
    if len(bhc_paragraphs) <= 1:
        continue

    # NOTE(review): despite the name, this keeps every non-discharge-summary
    # note of the admission, not only the "Physician " category.
    physician_notes = [
        Note(
            text=note_row["TEXT"],
            datetime=note_row["CHARTTIME"],
            category=note_row["CATEGORY"],
            description=note_row["DESCRIPTION"],
        )
        for _, note_row in group_df[~is_discharge_summary].iterrows()
    ]

    discharge_summary = DischargeSummary(
        text=discharge_summary_row["TEXT"],
        datetime=discharge_summary_row["CHARTTIME"],
        category=discharge_summary_row["CATEGORY"],
        description=discharge_summary_row["DESCRIPTION"],
        bhc=bhc,
        bhc_paragraphs=bhc_paragraphs,
    )

    record = Record(
        physician_notes=sorted(physician_notes),
        discharge_summary=discharge_summary,
        hadm_id=hadm_id,
        subject_id=group_df["SUBJECT_ID"].iloc[0],
    )
    dataset.append(record)
len(dataset)

In [None]:
# Show the parsed BHC paragraphs of one record, separated by a rule.
sample = dataset[3]
separator = "*" * 80
for paragraph in sample.discharge_summary.bhc_paragraphs:
    print(paragraph.heading, ":", paragraph.text)
    print(separator)

In [None]:
# RANDOM_SEED is reused here only as an arbitrary record index. Print the raw
# BHC of that same record so it can be compared with its parsed paragraphs
# (previously this printed dataset[10], a different record).
sample = dataset[RANDOM_SEED]
print(sample.discharge_summary.bhc)

In [None]:
# Parsed paragraphs of `sample`, displayed via their repr.
sample_paragraphs = sample.discharge_summary.bhc_paragraphs
sample_paragraphs

In [None]:
# Split records 50/50 into train and test, reproducibly via RANDOM_SEED.
# NOTE(review): the split is per Record (admission); records sharing a
# subject_id can land in both halves — group by subject_id if patient-level
# separation is required.
train_dataset, test_dataset = train_test_split(
    dataset,
    test_size=0.5,
    random_state=RANDOM_SEED,
)
(len(train_dataset), len(test_dataset))

In [None]:
# Persist the training records as a pickled list of plain dicts.
with TRAIN_SAVE_PATH.open("wb") as train_file:
    pickle.dump([train_record.dict() for train_record in train_dataset], train_file)

In [None]:
# Persist the test records as a pickled list of plain dicts.
with TEST_SAVE_PATH.open("wb") as test_file:
    pickle.dump([test_record.dict() for test_record in test_dataset], test_file)