In [None]:
import re
from collections import Counter
from pathlib import Path
from typing import List, Set

import pandas as pd
from pydantic import BaseModel
from spacy.lang.en import English
from spacy.tokens.span import Span

from discharge_summaries.preprocessing.preprocess_snomed import Snomed

In [None]:
# Data locations, resolved relative to the kernel's working directory: the
# notebook expects to run one level below the directory that holds "data/".
MIMIC_III_DIR = Path.cwd().parent.joinpath(
    "data", "physionet.org", "files", "mimiciii", "1.4"
)
SNOMED_DIR = Path.cwd().parent.joinpath("data", "snomed")

In [None]:
# One row per discharge summary; later cells read the "HADM_ID" and "BHC" columns.
discharge_summary_df = pd.read_csv(
    MIMIC_III_DIR.joinpath("discharge_summaries_mimic.csv")
)

In [None]:
# Eyeball the raw BHC text of the first five summaries, separated by a rule.
for bhc_text in discharge_summary_df["BHC"].head(5):
    print(bhc_text, "*" * 80, sep="\n")

In [None]:
# Load the SNOMED terminology from SNOMED_DIR via the project's `Snomed` helper;
# it is used below to build the heading phrase matcher.
snomed = Snomed.load(SNOMED_DIR)

In [None]:
# Only the tokenizer of a blank English pipeline is needed: it is handed to the
# phrase-matcher builder and used to tokenize headings in `extract_snomed_cuis`.
spacy_tokenizer = English().tokenizer

In [None]:
# SNOMED concept categories whose terms are matched against problem headings.
# Passed straight to `Snomed.get_phrase_matcher`; presumably these name SNOMED CT
# hierarchies / semantic tags — TODO confirm against preprocess_snomed.
# Edit this set to widen or narrow what counts as a recognised heading.
HEADING_SNOMED_CATEGORIES = {
    "Clinical finding",
    "Organism",
    "Body structure, altered from its original anatomical structure",
}

whole_snomed_matcher = snomed.get_phrase_matcher(
    HEADING_SNOMED_CATEGORIES,
    spacy_tokenizer,
)

In [None]:
def is_subspan(sub_span: Span, span: Span) -> bool:
    """Return True when ``sub_span`` lies entirely inside ``span``.

    Containment is judged on the spans' ``start``/``end`` token offsets, so a
    span counts as a subspan of itself (and of any span with equal bounds).
    """
    starts_inside = sub_span.start >= span.start
    ends_inside = sub_span.end <= span.end
    return starts_inside and ends_inside


def filter_out_subspans(spans: List[Span]) -> List[Span]:
    """Keep only spans that are not contained in another, already-kept span.

    Spans are visited longest-first (by ``len``, i.e. token count). A span is
    kept unless it is a subspan (per ``is_subspan``) of a span kept before it.
    The sort is stable, so among equal-length spans the one earlier in
    ``spans`` is visited first. Repeated matches over identical bounds (e.g.
    one phrase carrying several labels) therefore keep only the first.
    Spans that merely overlap, without containment, are all kept.

    Args:
        spans: candidate spans, e.g. ``PhraseMatcher(..., as_spans=True)`` output.

    Returns:
        The kept spans, in the same relative order as they appear in ``spans``.
    """
    # Stable sort: equal-length spans keep their input order.
    longest_first = sorted(
        range(len(spans)), key=lambda idx: len(spans[idx]), reverse=True
    )
    kept_indices: Set[int] = set()
    for idx in longest_first:
        candidate = spans[idx]
        if not any(is_subspan(candidate, spans[k]) for k in kept_indices):
            kept_indices.add(idx)
    # Walk `spans` rather than the set so the output order is deterministic and
    # follows the input, not the set's hashing order.
    return [span for idx, span in enumerate(spans) if idx in kept_indices]


def extract_snomed_cuis(heading: str, tokenizer, snomed_phrase_matcher) -> List[int]:
    """Return the SNOMED concept ids found in ``heading``.

    The heading is tokenized and run through ``snomed_phrase_matcher``. Nested
    matches are collapsed to the outermost ones via ``filter_out_subspans``,
    and each surviving match's ``label_`` is parsed as an int.

    Args:
        heading: free-text heading to search.
        tokenizer: callable that turns text into a spaCy ``Doc``.
        snomed_phrase_matcher: spaCy phrase matcher whose match labels are
            numeric strings (assumed from the ``int`` conversion).

    Returns:
        One concept id per kept match, in the order ``filter_out_subspans``
        returns them; the same id may appear more than once.
    """
    heading_doc = tokenizer(heading)
    raw_matches = snomed_phrase_matcher(heading_doc, as_spans=True)
    outermost = filter_out_subspans(raw_matches)
    return [int(match.label_) for match in outermost]

In [None]:
class ProblemParagraph(BaseModel):
    """One paragraph of a BHC after the opening one, split into heading and body."""

    # Heading parsed from the start of the paragraph; "" when none was recognised.
    heading: str
    # Paragraph text after the heading separator (the whole paragraph if unheaded).
    text: str
    # SNOMED concept ids matched in `heading`; empty for unheaded paragraphs.
    snomed_heading_cuis: List[int]


class BHC(BaseModel):
    """A discharge summary's "BHC" text split into paragraphs.

    NOTE(review): "BHC" presumably means the Brief Hospital Course section — confirm.
    """

    # Admission id taken from the HADM_ID column, stored as a string.
    hadm_id: str
    # The BHC text before any splitting.
    full_text: str
    # Opening paragraph of the BHC, treated as the reason for admission.
    reason_for_admission_para: str
    # Every later paragraph, each parsed into a ProblemParagraph.
    problem_paragraphs: List[ProblemParagraph]


# A problem paragraph starts with optional non-letter characters (e.g. numbering
# or bullets), then a single-line heading, terminated by "-", "." or ":" followed
# by a space or newline. Everything after that (DOTALL) is the paragraph body.
HEADING_PATTERN = re.compile(
    r"^([^A-Za-z]*)([A-Za-z ][^\n]*?)([-\.:][ \n])(.*)", re.DOTALL
)


def _split_paragraphs(bhc_text: str) -> List[str]:
    """Split BHC text on blank lines, dropping empty/whitespace-only pieces.

    Splitting on "\\n\\n" yields such pieces for leading/trailing blank lines or
    long runs of newlines. Kept, they would count as headingless problem
    paragraphs and skew the format check below.
    """
    return [para for para in bhc_text.split("\n\n") if para.strip()]


def _parse_problem_paragraph(para: str) -> ProblemParagraph:
    """Split one paragraph into heading (+ SNOMED ids) and body text.

    Paragraphs that don't match HEADING_PATTERN get an empty heading and no ids.
    """
    match = HEADING_PATTERN.match(para)
    if match is None:
        return ProblemParagraph(heading="", text=para.strip(), snomed_heading_cuis=[])
    heading = match.group(2).strip()
    return ProblemParagraph(
        heading=heading,
        text=match.group(4).strip(),
        snomed_heading_cuis=extract_snomed_cuis(
            heading, spacy_tokenizer, whole_snomed_matcher
        ),
    )


def _parse_bhc(hadm_id, bhc_text) -> BHC:
    """Build a BHC: first non-blank paragraph is the reason for admission,
    every later paragraph is parsed as a problem paragraph.

    A missing BHC cell (NaN) is treated as empty text rather than the string "nan".
    """
    full_text = "" if pd.isna(bhc_text) else str(bhc_text)
    paragraphs = _split_paragraphs(full_text)
    return BHC(
        hadm_id=str(hadm_id),
        full_text=full_text,
        reason_for_admission_para=paragraphs[0] if paragraphs else "",
        problem_paragraphs=[_parse_problem_paragraph(p) for p in paragraphs[1:]],
    )


def _has_expected_format(parsed: BHC) -> bool:
    """True when at least one, and at least half, of the problem paragraphs have a heading."""
    num_with_heading = sum(1 for para in parsed.problem_paragraphs if para.heading)
    return num_with_heading > 0 and num_with_heading >= len(parsed.problem_paragraphs) / 2


valid_bhcs: List[BHC] = []
incorrect_format_bhcs: List[BHC] = []
for hadm_id, bhc_text in zip(
    discharge_summary_df["HADM_ID"], discharge_summary_df["BHC"]
):
    bhc = _parse_bhc(hadm_id, bhc_text)
    if _has_expected_format(bhc):
        valid_bhcs.append(bhc)
    else:
        incorrect_format_bhcs.append(bhc)
len(valid_bhcs), len(incorrect_format_bhcs)

In [None]:
# Inspect a sample of BHCs rejected by the heading-format heuristic.
for rejected in incorrect_format_bhcs[:10]:
    print(rejected.full_text, "*" * 80, sep="\n")

In [None]:
# Inspect a sample of BHCs accepted by the heading-format heuristic.
for accepted in valid_bhcs[:10]:
    print(accepted.full_text, "*" * 80, sep="\n")

In [None]:
# Heading of every problem paragraph (across valid BHCs) with no SNOMED match.
# Paragraphs without a parsed heading contribute "" here.
missed_heading = [
    problem.heading
    for summary in valid_bhcs
    for problem in summary.problem_paragraphs
    if not problem.snomed_heading_cuis
]
# Total number of problem paragraphs across valid BHCs, for comparison.
para_count = sum(len(summary.problem_paragraphs) for summary in valid_bhcs)
len(missed_heading), para_count

In [None]:
# The 20 most frequent unmatched headings with their counts; the "" entry
# counts paragraphs for which no heading was parsed at all.
Counter(missed_heading).most_common(20)