In [None]:
import pickle
import re
from collections import Counter
from pathlib import Path
from typing import List

import pandas as pd
from pydantic import BaseModel
from tqdm.notebook import tqdm

from discharge_summaries.snomed.lookup import SnomedLookup

In [None]:
# Data locations, all relative to the repository's top-level "data" directory
# (one level above the notebook's working directory).
MIMIC_III_DIR = Path.cwd().parent.joinpath(
    "data", "physionet.org", "files", "mimiciii", "1.4"
)
SNOMED_DIR = Path.cwd().parent.joinpath("data", "snomed")
# Pickled phrase matcher used to map BHC headings to SNOMED CUIs.
PHRASE_MATCHER_FPATH = SNOMED_DIR.joinpath("snomed_phrase_matcher.pkl")

In [None]:
# Discharge-summary table; later cells read its "BHC" and "HADM_ID" columns.
discharge_summary_df = pd.read_csv(
    MIMIC_III_DIR.joinpath("discharge_summaries_mimic.csv")
)

In [None]:
# NOTE: unpickling can execute arbitrary code — only load a matcher file that
# was produced locally / comes from a trusted source.
with PHRASE_MATCHER_FPATH.open("rb") as phrase_matcher_file:
    snomed_phrase_matcher = pickle.load(phrase_matcher_file)

In [None]:
# Eyeball the raw BHC text of the first five summaries, separated by a rule.
for bhc_text in discharge_summary_df["BHC"].head(5):
    print(bhc_text, "*" * 80, sep="\n")

In [None]:
# Building blocks for the BHC paragraph / heading regexes.
punctuation_prefix = "[^A-Za-z]*"  # any leading non-letters ("#", "1.", "-", whitespace, ...)
heading_regex = "[A-Za-z][^\n]*?"  # shortest single-line run that starts with a letter
heading_delimiter = "[-:] "  # "-" or ":" followed by a space ends the heading
paragraph_text = ".*"  # remainder of the paragraph (crosses newlines under DOTALL)

# Split on a blank line only when the following paragraph opens with a
# "<heading>: " / "<heading> - " prefix; other blank lines stay inside the
# current paragraph.
paragraph_split_regex = re.compile(
    "".join(
        ["\n\n", "(?=", punctuation_prefix, heading_regex, heading_delimiter, ")"]
    )
)
# Capture (heading, body) from a single paragraph.
heading_grouping_regex = re.compile(
    "".join(
        [
            "^",
            punctuation_prefix,
            "(",
            heading_regex,
            ")",
            heading_delimiter,
            "(",
            paragraph_text,
            ")",
        ]
    ),
    flags=re.DOTALL,
)

In [None]:
class ProblemSection(BaseModel):
    """One problem paragraph of a BHC, split into its heading and body."""

    heading: str  # stripped heading text before the "-"/":" delimiter; "" if none was parsed
    text: str  # stripped body after the heading, or the whole paragraph if no heading
    snomed_heading_cuis: List[int]  # CUIs returned by the SNOMED phrase matcher for the heading


class BHC(BaseModel):
    """Parsed BHC section (presumably "Brief Hospital Course") of one discharge summary."""

    hadm_id: str  # value of the HADM_ID column, stringified
    full_text: str  # unparsed BHC text
    reason_for_admission: str  # leading paragraph treated as the reason for admission; may be ""
    problem_sections: List[ProblemSection]  # remaining paragraphs, in original order

In [None]:
# Parse every discharge summary's BHC text into a reason-for-admission paragraph
# plus a list of (heading, text, SNOMED CUIs) problem sections.
bhcs = []
for _, discharge_summary in tqdm(
    discharge_summary_df.iterrows(), total=len(discharge_summary_df)
):
    # NOTE(review): a missing BHC (NaN) is stringified to the literal text "nan".
    bhc_text = str(discharge_summary["BHC"])
    paragraphs = paragraph_split_regex.split(bhc_text)

    # The first paragraph is the reason for admission when it has no
    # "<heading>: " prefix, or when its heading is an assessment / "A/P" heading.
    first_match = heading_grouping_regex.match(paragraphs[0])
    first_heading = first_match.group(1).strip().lower() if first_match else ""
    has_reason_for_admission = (
        not first_match or "assessment" in first_heading or "a/p" in first_heading
    )
    reason_for_admission = paragraphs[0] if has_reason_for_admission else ""
    # Decide on the flag, not on the text's truthiness: if the BHC starts with
    # "\n\n<heading>: ", paragraphs[0] is "" and must still be skipped rather
    # than turned into an empty, heading-less problem section.
    problem_paragraph_start_idx = 1 if has_reason_for_admission else 0

    headings, texts = [], []
    for para in paragraphs[problem_paragraph_start_idx:]:
        match = heading_grouping_regex.match(para)
        if match:
            headings.append(match.group(1).strip())
            texts.append(match.group(2).strip())
        else:
            # No heading prefix: keep the whole paragraph as the section text.
            headings.append("")
            texts.append(para.strip())
    heading_snomed_cuis = snomed_phrase_matcher.pipe(headings)
    problem_sections = [
        ProblemSection(heading=h, text=t, snomed_heading_cuis=cuis)
        for h, t, cuis in zip(headings, texts, heading_snomed_cuis)
    ]

    bhcs.append(
        BHC(
            hadm_id=str(discharge_summary["HADM_ID"]),
            full_text=bhc_text,
            reason_for_admission=reason_for_admission,
            problem_sections=problem_sections,
        )
    )

In [None]:
# A BHC counts as well-formatted when a strict majority of its problem sections
# were parsed with a heading; everything else is set aside for inspection.
valid_bhcs = []
incorrect_format_bhcs = []
for bhc in bhcs:
    num_problem_paragraphs_w_heading = len(
        [section for section in bhc.problem_sections if section.heading]
    )
    is_valid = 2 * num_problem_paragraphs_w_heading > len(bhc.problem_sections)
    (valid_bhcs if is_valid else incorrect_format_bhcs).append(bhc)
(len(valid_bhcs), len(incorrect_format_bhcs))

In [None]:
# Show the raw text of the first ten BHCs rejected by the format check.
for bhc in incorrect_format_bhcs[:10]:
    print(bhc.full_text, "*" * 80, sep="\n")

In [None]:
# Spot-check the parse of ten valid BHCs: full text, reason for admission, and
# each heading/text section.
for bhc in valid_bhcs[20:30]:
    print(
        "Full Text:",
        bhc.full_text,
        "---",
        "Reason for Admission:",
        bhc.reason_for_admission,
        "Sections",
        sep="\n",
    )
    for section in bhc.problem_sections:
        print("---", section.heading, section.text, sep="\n")
    print("*" * 80)

In [None]:
# Fraction of problem sections in valid BHCs whose heading got no SNOMED CUI.
# Sections without a parsed heading ("") are included in both counts.
missed_heading = [
    section.heading
    for bhc in valid_bhcs
    for section in bhc.problem_sections
    if not section.snomed_heading_cuis
]
para_count = sum(len(bhc.problem_sections) for bhc in valid_bhcs)
len(missed_heading) / para_count

In [None]:
# The 50 most frequent headings (from valid BHCs) that got no SNOMED match.
Counter(missed_heading).most_common(50)

In [None]:
# Load the SNOMED lookup tables. SNOMED_DIR comes from the config cell at the
# top of the notebook; it is deliberately not re-defined here so the two
# definitions cannot drift apart.
snomed_lookup = SnomedLookup.load(SNOMED_DIR)

In [None]:
# Extra SNOMED concepts whose synonyms (and their children's) get registered
# with the phrase matcher further down; print their preferred terms as a
# sanity check that the CUIs are the intended concepts.
extra_parent_cuis = {169443000, 311788003, 384760004}
print(
    [
        snomed_lookup.cui_to_preferred_term[parent_cui]
        for parent_cui in extra_parent_cuis
    ]
)

In [None]:
# Hand-curated (CUI, heading text) pairs for BHC headings the phrase matcher
# missed (see the Counter output above). Each text is later registered as an
# extra lower-cased synonym of its CUI. The same CUI may appear with several
# spellings (e.g. 106048009: "Pulmonary", "Respiratory", "Resp", "Pulm").
cui_and_missing_synonyms = [
    (365870005, "Code"),
    (311788003, "Access"),  # Access is also in extra_parent_cuis
    (384760004, "FEN"),  # Added as parent cui
    (118231006, "Communication"),
    (301113001, "Rhythm"),
    (106063007, "Pump"),
    (169443000, "PPX"),  # Added as parent cui
    (73211009, "Diabetes"),
    (102957003, "Neuro"),
    (251015000, "Coronaries"),
    (36456004, "Dispo"),
    (384760004, "Nutrition"),
    (106063007, "CV"),
    (903081000000107, "Contact"),
    (36456004, "Disposition"),
    (118231006, "Comm"),
    (160931000119108, "Transaminitis"),
    (106048009, "Pulmonary"),  # ? mapping not yet confirmed
    (44054006, "DM2"),
    (299691001, "Heme"),
    (49436004, "Afib"),
    (401314000, "NSTEMI"),
    (106176003, "Endocrine"),
    (118238000, "Renal"),  # ? mapping not yet confirmed
    (19943007, "Cirrhosis"),
    (106048009, "Respiratory"),
    (116367006, "Psych"),
    (299691001, "Hematology"),
    (419284004, "AMS"),
    (301095005, "Cardiac"),
    (74474003, "GIB"),
    (166603001, "Elevated LFTs"),
    (106048009, "Resp"),
    (44054006, "DMII"),
    (301120008, "EKG changes"),
    # Frequent headings deliberately left unmapped; the number is presumably
    # the heading's count from the Counter output above.
    # ('micu course', 30),
    (444931001, "Elevated troponin"),
    (106176003, "Endo"),
    (191480000, "ETOH withdrawal"),
    (37372002, "UGIB"),
    # ('goals of care', 77),
    # ('last name (un)', 25),
    (401303003, "STEMI"),
    (235856003, "ESLD"),
    #  ('anticoagulation', 24),
    (398137007, "CRI"),
    (106048009, "Pulm"),
    (233604007, "PNA"),
    (106063007, "Cardiovascular"),
    (284465006, "Social"),
    (405729008, "BRBPR"),  # 5 letter acronym
    (237840007, "Anion gap"),  # 3 different options here so chose parent
    (721104000, "Urosepsis"),
]

In [None]:
# Register the SNOMED synonyms of each extra parent CUI, and of its child CUIs,
# with the phrase matcher.
#   keep_child_cuis=True  -> each child keeps its own CUI as the match key.
#   keep_child_cuis=False -> child synonyms are folded into the parent's key.
# NOTE(review): re-running this cell presumably registers the same patterns a
# second time — confirm the matcher's `add` semantics before re-running.
keep_child_cuis = True
cui_to_synonyms_lower = {}
for parent_cui in extra_parent_cuis:
    parent_synonyms = {
        synonym.lower()
        for synonym in snomed_lookup.cui_to_synonyms.get(parent_cui, set())
    }
    if not parent_synonyms:
        raise ValueError(f"Parent CUI {parent_cui} has no synonyms")
    cui_to_synonyms_lower[parent_cui] = parent_synonyms

    # NOTE(review): unclear whether get_child_cuis returns direct children only
    # or all descendants — confirm in SnomedLookup.
    for child_cui in snomed_lookup.get_child_cuis(parent_cui):
        child_synonyms = {
            synonym.lower()
            for synonym in snomed_lookup.cui_to_synonyms.get(child_cui, set())
        }
        if keep_child_cuis:
            cui_to_synonyms_lower[child_cui] = child_synonyms
        else:
            # Must mutate in place: set.union() returns a new set, so using it
            # here would silently discard the child synonyms.
            cui_to_synonyms_lower[parent_cui].update(child_synonyms)

for cui, synonyms_lower in tqdm(cui_to_synonyms_lower.items()):
    snomed_phrase_matcher._phrase_matcher.add(
        str(cui), list(snomed_phrase_matcher._nlp.pipe(synonyms_lower))
    )

In [None]:
# Register each hand-curated heading spelling as an extra lower-cased pattern
# under its SNOMED CUI key.
for concept_id, heading_synonym in cui_and_missing_synonyms:
    synonym_docs = list(snomed_phrase_matcher._nlp.pipe([heading_synonym.lower()]))
    snomed_phrase_matcher._phrase_matcher.add(str(concept_id), synonym_docs)

In [None]:
# Re-run the (now extended) matcher over every problem-section heading of the
# valid BHCs and collect the lower-cased headings that still get no CUI.
headings = [
    section.heading for bhc in valid_bhcs for section in bhc.problem_sections
]
snomed_codes = snomed_phrase_matcher.pipe(headings)
missed_headings = []
for heading, cuis in zip(headings, snomed_codes):
    if not cuis:
        missed_headings.append(heading.lower())

In [None]:
# Fraction of valid-BHC problem-section headings still unmatched after adding the
# extra synonyms; empty headings ("") are part of the denominator.
len(missed_headings) / len(headings)

In [None]:
# The 100 most frequent still-unmatched headings (lower-cased).
Counter(missed_headings).most_common(100)

In [None]:
# Print (index, section) for every problem section headed "CRI" (any case)
# within the first 1000 valid BHCs.
for bhc in valid_bhcs[:1000]:
    cri_sections = (
        (section_idx, section)
        for section_idx, section in enumerate(bhc.problem_sections)
        if section.heading.lower() == "cri"
    )
    for idx, para in cri_sections:
        print(idx, para)