In [None]:
import json
import pickle
import re
from collections import Counter
from pathlib import Path

from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# --- Filesystem layout, resolved relative to the notebook's working directory ---
DATA_DIR = Path.cwd().parent.joinpath("data")
OUTPUT_DIR = Path.cwd().joinpath("output")
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train_all_ds.pkl")
GUIDELINES_PYDANTIC_PATH = Path.cwd().parent.joinpath(
    "guidelines", "eDischarge-Summary-v2.1-1st-Feb-21_pydantic.json"
)

# --- Run settings ---
RANDOM_SEED = 23

# Engine / API version identifiers.
# Alternative engine (same API version): AZURE_ENGINE = "gpt-35-turbo"
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"

In [None]:
# Read the pickled training records and wrap each one in a Record.
# NOTE: unpickling can execute arbitrary code; only load dataset files you trust.
with TRAINING_DATASET_PATH.open("rb") as in_file:
    dataset = [Record(**fields) for fields in pickle.load(in_file)]
len(dataset)

In [None]:
# Count, for every candidate heading, how many discharge summaries contain it.
# A candidate heading is a run of letters/spaces that directly follows a blank
# line and is followed by a colon on the same line.
#
# Each summary contributes at most once per heading: the counts are compared
# against a fraction of len(dataset) afterwards, which is only meaningful for
# per-document counts. Counting raw occurrences would let a heading repeated
# several times in a minority of summaries look like a near-universal one.
# dict.fromkeys de-duplicates while keeping first-seen order, so the Counter's
# insertion order (and therefore tie order in most_common) stays deterministic.
titles = Counter(
    title
    for record in tqdm(dataset)
    for title in dict.fromkeys(
        re.findall("(?<=\n\n)[a-zA-Z ]*?(?=:.*?\n)", record.discharge_summary.text)
    )
)

In [None]:
# Keep only the headings whose count exceeds 95% of the number of records,
# in the order most_common() yields them (highest count first).
common_titles = [
    heading
    for heading, occurrences in titles.most_common()
    if len(dataset) * 0.95 < occurrences
]
common_titles

In [None]:
# Maps each guideline section name to the MIMIC discharge-summary headings whose
# text is combined to fill that section. Sections mapped to an empty list have
# no MIMIC source heading and end up as an empty string in the converted
# summary. Key order is the order of the sections in the generated output.
guideline_section_name_to_mimic_headings: dict[str, list[str]] = {
    "Social context": ["Family History", "Social History"],
    "Individual requirements": [],
    "Participation in research": [],
    "Admission details": [],
    "Diagnoses": ["Discharge Diagnosis"],
    "Procedures": ["Major Surgical or Invasive Procedure"],
    "Clinical summary": ["Brief Hospital Course"],
    "Investigation results": ["Pertinent Results"],
    "Assessment scale": [],
    "Legal information": [],
    "Safety alerts": [],
    "Allergies and adverse reactions": ["Allergies"],
    "Patient and carer concerns, expectations and wishes": [],
    "Information and advice given": [],
    "Plan and requested actions": ["Discharge Instructions", "Followup Instructions"],
}

In [None]:
def mimic_discharge_summary_to_prsb_format(
    discharge_summary_text: str,
    mimic_title_headings: list[str],
    prsb_section_name_to_mimic_heading: dict[str, list[str]],
) -> dict[str, str]:
    """Regroup the text of a MIMIC discharge summary under guideline section names.

    The text is cut at every blank line that is immediately followed by one of
    ``mimic_title_headings`` and a colon. Each resulting piece is separated at
    its first colon into a heading and a body (whitespace-stripped).

    Args:
        discharge_summary_text: Full text of one discharge summary.
        mimic_title_headings: Headings that mark the start of a new section.
            They are matched literally (regex metacharacters are escaped).
        prsb_section_name_to_mimic_heading: For each output section name, the
            headings whose bodies are combined to form that section.

    Returns:
        A dict with the same keys, in the same order, as
        ``prsb_section_name_to_mimic_heading``. Each value is the bodies of the
        mapped headings joined by a blank line, or ``""`` when no headings are
        mapped to that section.

    Raises:
        AssertionError: If a heading referenced by the mapping was not found
            in ``discharge_summary_text``.

    Note:
        If the same heading occurs more than once in the text, the body of the
        last occurrence is the one that is kept.
    """
    # Escape the headings so that characters such as "(" or "." in a heading
    # are matched literally instead of being interpreted as regex syntax.
    heading_alternatives = "|".join(re.escape(title) for title in mimic_title_headings)
    section_boundary = f"\n\n(?=(?:{heading_alternatives}):.*?\n)"

    title_to_body: dict[str, str] = {}
    for section in re.split(section_boundary, discharge_summary_text):
        title, colon, body = section.partition(":")
        if not colon:
            # A piece without any colon (e.g. free text before the first
            # heading) has no heading to file it under, so it is skipped.
            continue
        title_to_body[title] = body.strip()

    required_headings = {
        heading
        for headings in prsb_section_name_to_mimic_heading.values()
        for heading in headings
    }
    missing_headings = required_headings - title_to_body.keys()
    assert (
        not missing_headings
    ), f"Headings not found in discharge summary: {sorted(missing_headings)}"

    return {
        section_name: "\n\n".join(title_to_body[heading] for heading in mimic_headings)
        for section_name, mimic_headings in prsb_section_name_to_mimic_heading.items()
    }

In [None]:
# Convert the first two records and write each result as indented JSON to
# OUTPUT_DIR/<idx>_json_schema_gt.txt.
# The output directory is created up front: write_text does not create missing
# parent directories, so on a fresh checkout the first write would fail.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
for idx, record in enumerate(dataset[:2]):
    prsb_format_summary = mimic_discharge_summary_to_prsb_format(
        record.discharge_summary.text,
        common_titles,
        guideline_section_name_to_mimic_headings,
    )
    (OUTPUT_DIR / f"{idx}_json_schema_gt.txt").write_text(
        json.dumps(prsb_format_summary, indent=4)
    )