In [None]:
import re
from collections import Counter
from pathlib import Path
from typing import List, Set

import pandas as pd
from pydantic import BaseModel
from spacy.lang.en import English
from spacy.tokens.span import Span

from discharge_summaries.preprocessing.preprocess_snomed import Snomed

In [None]:
# Input locations, both resolved relative to the parent of the current
# working directory (``<cwd>/../data/...``).
MIMIC_III_DIR = Path.cwd().parent.joinpath(
    "data", "physionet.org", "files", "mimiciii", "1.4"
)
SNOMED_DIR = Path.cwd().parent.joinpath("data", "snomed")

In [None]:
# Load the discharge-summary table from the MIMIC-III directory. Later cells
# read its "BHC" and "HADM_ID" columns.
discharge_summary_df = pd.read_csv(MIMIC_III_DIR / "discharge_summaries_mimic.csv")

In [None]:
# Eyeball the first five BHC texts, each followed by a row of asterisks.
bhc_separator = "*" * 80
for text in discharge_summary_df["BHC"].iloc[:5]:
    print(text, bhc_separator, sep="\n")

In [None]:
# Load the project's SNOMED wrapper from SNOMED_DIR; it is used below to build
# the phrase matcher that tags section headings.
snomed = Snomed.load(SNOMED_DIR)

In [None]:
# Tokenizer of a spaCy ``English`` language object. The same tokenizer is
# passed to the phrase-matcher builder and used to tokenize headings, so both
# sides tokenize text identically.
spacy_tokenizer = English().tokenizer

In [None]:
# Names handed to ``Snomed.get_phrase_matcher`` to select what the matcher
# covers.
# NOTE(review): presumably SNOMED top-level hierarchy names - confirm against
# the ``Snomed`` helper.
snomed_matcher_categories = {
    "Clinical finding",
    "Organism",
    "Body structure, altered from its original anatomical structure",
}
whole_snomed_matcher = snomed.get_phrase_matcher(
    snomed_matcher_categories, spacy_tokenizer
)

In [None]:
def is_subspan(sub_span: Span, span: Span) -> bool:
    """Return True when ``sub_span`` lies entirely inside ``span``.

    Only the ``start`` and ``end`` attributes are compared and both bounds are
    inclusive, so a span with the same bounds as ``span`` counts as a subspan.
    """
    starts_inside = sub_span.start >= span.start
    ends_inside = sub_span.end <= span.end
    return starts_inside and ends_inside


def filter_out_subspans(spans: List[Span]) -> List[Span]:
    """Drop every span that is contained in another, already-kept span.

    Spans are considered longest first, and a span is kept only if it is not a
    subspan (per ``is_subspan``) of a span kept before it. Among spans of equal
    length the one that appears first in ``spans`` is considered first.

    Args:
        spans: Candidate spans; only ``len()``, ``start`` and ``end`` are used.

    Returns:
        The surviving spans in the same relative order as in ``spans``.
        Returns an empty list for empty input.
    """
    # Longest first; ``sorted`` is stable, so ties keep their input order.
    longest_first = sorted(
        enumerate(spans), key=lambda idx_and_span: len(idx_and_span[1]), reverse=True
    )
    kept_indices: List[int] = []
    for i, span in longest_first:
        # Keep the span unless it is fully contained in one we already kept.
        if not any(is_subspan(span, spans[j]) for j in kept_indices):
            kept_indices.append(i)
    # Sort the indices so the result follows input order. Iterating a set of
    # indices (as before) gave an arbitrary, implementation-defined order.
    return [spans[i] for i in sorted(kept_indices)]


def extract_snomed_cuis(heading: str, tokenizer, snomed_phrase_matcher) -> List[int]:
    """Return the SNOMED CUIs matched in ``heading``.

    The heading is tokenized with ``tokenizer`` and the resulting document is
    passed to ``snomed_phrase_matcher`` with ``as_spans=True``. Matches are
    filtered with ``filter_out_subspans``, and each remaining span's ``label_``
    is parsed as an integer CUI.
    """
    heading_doc = tokenizer(heading)
    matched_spans = snomed_phrase_matcher(heading_doc, as_spans=True)
    return [int(span.label_) for span in filter_out_subspans(matched_spans)]

In [None]:
# Hand-written sample BHC text: an unheaded opening paragraph followed by
# "# <heading>: <text>" problem paragraphs and a final heading-only line.
# NOTE(review): not referenced by any other cell visible here - presumably kept
# for ad-hoc experiments with the splitting regexes.
test = """Mr Known lastname 52368 is a 59M w HCV cirrhosis w grade II esophageal varices
admitted w coffee-ground emesis and melena concerning for UGIB,
s/p MICU stay for hypotension.

# UGIB: Pt did not have any more bleeds while in hospital. EGD
revealed erythema and erosion in the antrum and pylorus
compatible with non-steroidal induced gastritis. Pt did remember
taking increased doses of naproxen for backache. Started on
pantoprazole 40mg PO BID for one week with repeat endoscopy
scheduled in one week (4-30). Recommended to take tylenol (max
daily dose of 2gm) for pain instead of NSAIDs. Blood pressure
meds were held at first, given MICU admission for hypotension,
but were restarted on discharge.

# HCV Cirrhosis: appears to be progressing to liver failure,
with elevated INR at 1.6, decreased albumin at 2.6, tbili
slightly elevated at 3.6, and chronic LE edema. Pt was continued
on prophylactic medications.

# FULL CODE"""

In [None]:
# Exploratory split of one BHC: break on blank lines that are followed by a
# line containing a heading delimiter ("-", "." or ":" then space/newline),
# and show the first ten characters of each piece.
exploratory_split_pattern = (
    r"\n\n(?=(?:[^A-Za-z]*)(?:[A-Za-z][^\n]*?)(?:[-\.:][ \n]))"
)
sample_bhc_text = discharge_summary_df["BHC"][60]
for section in re.split(exploratory_split_pattern, sample_bhc_text):
    print(section[:10])

In [None]:
# Regex building blocks for a headed paragraph:
#   <non-letters><heading on one line><"-" or ":" followed by a space><text>
punctuation_prefix = "[^A-Za-z]*"
heading_regex = "[A-Za-z][^\n]*?"
heading_delimiter = "[-:] "
paragraph_text = ".*"

# Splits on a blank line only when the next paragraph starts with a heading.
# The heading sits in a lookahead, so it stays part of the following piece.
paragraph_split_regex = re.compile(
    "\n\n(?=" + punctuation_prefix + heading_regex + heading_delimiter + ")"
)
# Group 1 is the heading, group 2 the rest of the paragraph. DOTALL lets
# group 2 run across line breaks.
heading_grouping_regex = re.compile(
    "^"
    + punctuation_prefix
    + "("
    + heading_regex
    + ")"
    + heading_delimiter
    + "("
    + paragraph_text
    + ")",
    flags=re.DOTALL,
)

In [None]:
# Display the compiled paragraph-splitting pattern for inspection.
paragraph_split_regex

In [None]:
# for section in (re.split(paragraph_split_regex, discharge_summary_df["BHC"][80])):
#     print(section)
#     groups = re.match(heading_grouping_regex, section)
#     if groups:
#         print(groups.group(1))
#         print(groups.group(2))
#         print("*" * 80)

In [None]:
class ProblemSection(BaseModel):
    """One paragraph of a BHC, split into its heading and body text."""

    # Heading captured from the start of the paragraph; the parsing loop stores
    # "" when the paragraph has no recognisable heading.
    heading: str
    # Paragraph body (stripped), or the whole paragraph when there is no heading.
    text: str
    # Integer SNOMED CUIs matched in the heading; empty when nothing matched or
    # when there is no heading.
    snomed_heading_cuis: List[int]


class BHC(BaseModel):
    """A parsed BHC text for one row of the discharge-summary table."""

    # String form of the row's "HADM_ID" value.
    hadm_id: str
    # String form of the row's unmodified "BHC" value.
    full_text: str
    # First paragraph when it carries no heading, otherwise "".
    reason_for_admission: str
    # Remaining paragraphs, in document order.
    problem_sections: List[ProblemSection]

In [None]:
# Parse every BHC into a reason-for-admission paragraph plus problem sections.
bhcs = []
for _, summary_row in discharge_summary_df.iterrows():
    # NOTE(review): str() turns a missing value (NaN) into the literal "nan" -
    # confirm the "BHC" column has no missing entries.
    bhc_text = str(summary_row["BHC"])
    paragraphs = paragraph_split_regex.split(bhc_text)

    # An unheaded first paragraph is the reason for admission. If it already
    # has a heading, it is treated as the first problem section instead.
    first_has_heading = heading_grouping_regex.match(paragraphs[0]) is not None
    if first_has_heading:
        reason_for_admission = ""
        problem_paragraphs = paragraphs
    else:
        reason_for_admission = paragraphs[0]
        problem_paragraphs = paragraphs[1:]

    problem_sections = []
    for para in problem_paragraphs:
        heading_match = heading_grouping_regex.match(para)
        if heading_match is None:
            # No heading: keep the paragraph with an empty heading and no CUIs.
            section = ProblemSection(
                heading="", text=para.strip(), snomed_heading_cuis=[]
            )
        else:
            heading = heading_match.group(1).strip()
            section = ProblemSection(
                heading=heading,
                text=heading_match.group(2).strip(),
                snomed_heading_cuis=extract_snomed_cuis(
                    heading, spacy_tokenizer, whole_snomed_matcher
                ),
            )
        problem_sections.append(section)

    bhcs.append(
        BHC(
            hadm_id=str(summary_row["HADM_ID"]),
            full_text=bhc_text,
            reason_for_admission=reason_for_admission,
            problem_sections=problem_sections,
        )
    )

In [None]:
# A BHC counts as correctly formatted when at least one problem section has a
# heading and headed sections make up at least half of all sections.
valid_bhcs = []
incorrect_format_bhcs = []
for bhc in bhcs:
    sections = bhc.problem_sections
    headed_count = len([section for section in sections if section.heading])
    is_valid_format = headed_count > 0 and headed_count >= len(sections) / 2
    target_list = valid_bhcs if is_valid_format else incorrect_format_bhcs
    target_list.append(bhc)
len(valid_bhcs), len(incorrect_format_bhcs)

In [None]:
# Inspect the raw text of the first ten BHCs that failed the format check.
rejected_separator = "*" * 80
for bhc in incorrect_format_bhcs[:10]:
    print(bhc.full_text, rejected_separator, sep="\n")

In [None]:
# Show a slice of valid BHCs next to their parsed pieces. "---" separates the
# raw text, the reason for admission and each problem section.
piece_separator = "---"
for bhc in valid_bhcs[20:30]:
    print(bhc.full_text, piece_separator, bhc.reason_for_admission, sep="\n")
    for section in bhc.problem_sections:
        print(piece_separator, section.heading, section.text, sep="\n")
    print("*" * 80)

In [None]:
# Share of problem sections in valid BHCs whose heading produced no SNOMED CUI.
# Sections without any heading (heading == "") have no CUIs either, so they
# are counted here too.
missed_heading = [
    para.heading
    for bhc in valid_bhcs
    for para in bhc.problem_sections
    if not para.snomed_heading_cuis
]
para_count = sum(len(bhc.problem_sections) for bhc in valid_bhcs)
len(missed_heading) / para_count

In [None]:
# Twenty most frequent headings that produced no SNOMED match. The empty
# string stands for sections that had no heading at all.
Counter(missed_heading).most_common(20)

In [None]:
# synonyms_of_interest_df = snomed.synonyms_df[
#     snomed.synonyms_df["cui"].isin(cuis_of_interest)
# ]
# len(synonyms_of_interest_df)

In [None]:
# def add_missing_synonyms(
#     missing_cui_and_synonyms: List[Tuple[str, str]], synonyms_df
# ) -> None:
#     missing_synonyms_df = pd.DataFrame.from_records(
#         missing_cui_and_synonyms, columns=synonyms_df.columns
#     )
#     synonyms_df = pd.concat([synonyms_df, missing_synonyms_df], ignore_index=True)
#     synonyms_df.drop_duplicates(inplace=True)
#     return synonyms_df

In [None]:
# cui_and_missing_synonyms = [
#     ("384760004", "FEN"),
#     # ("365870005", "Code Status"),
#     # ("365870005", "Code"),
#     ("73211009", "Diabetes"),
#     ("73211009", "DM"),  # Acronym but shorter than 2 characters
#     ("44054006", "DM2"),
#     ("169443000", "PPX"),
#     # ("432138007", "Communication"),
#     # ("432138007", "Comm"),
#     # ("726711005", "Dispo"),
#     ("74474003", "GI Bleed"),
#     ("160931000119108", "transaminitis"),
#     ("49436004", "Afib"),
#     ("19943007", "Cirrhosis"),
#     ("74474003", "GIB"),
#     # ("386661006", "Fevers"),
#     # ("91175000", "Seizures"),
#     # ("311788003", "Access"),
#     # ("251149006", "Rhythm"),
#     # ("739122008", "Pump"),
# ]
# synonyms_of_interest_df = add_missing_synonyms(
#     cui_and_missing_synonyms, synonyms_of_interest_df
# )
# [(snomed.get_preferred_term(cui), synonym) for cui, synonym in cui_and_missing_synonyms]