In [None]:
import pickle
import re
from collections import Counter
from datetime import datetime
from pathlib import Path

from spacy.lang.en import English
from spacy.matcher import PhraseMatcher

from discharge_summaries.preprocessing.preprocess_snomed import Snomed
from discharge_summaries.schemas.mimic import Record
from medcat.cat import CAT
from tqdm.notebook import tqdm
import pandas as pd

In [None]:
# --- Input locations ---------------------------------------------------------
# Everything is resolved relative to the parent of the current working
# directory, so the notebook must be started from its own folder.
DATA_DIR = Path.cwd().parent / "data"
TRAINING_DATASET_PATH = DATA_DIR / "train.pkl"

# SNOMED CT terminology tables (description + relationship snapshot files).
SNOMED_DIR = DATA_DIR.joinpath(
    "SnomedCT_InternationalRF2_PRODUCTION_20230731T120000Z",
    "Snapshot",
    "Terminology",
)
description_file = SNOMED_DIR / "sct2_Description_Snapshot-en_INT_20230731.txt"
relation_file = SNOMED_DIR / "sct2_Relationship_Snapshot_INT_20230731.txt"

# --- Run settings ------------------------------------------------------------
# Minute-resolution timestamp of when this cell was executed.
TIMESTAMP = datetime.now().strftime("%Y_%m_%d_%H_%M")
RANDOM_SEED = 23
# NOTE(review): not referenced by any visible cell; spaCy's packaged English
# pipelines are normally named "en_core_web_*" — confirm this name before using it.
SPACY_MODEL = "en_core_md"

In [None]:
# Location of the MedCAT model pack archive, next to the data directory.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models",
    "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip",
)

In [None]:
# Load the pickled training records and rebuild each one as a `Record`
# (every pickled item is unpacked as keyword arguments).
# SECURITY: pickle.load can execute arbitrary code stored in the file — only
# point TRAINING_DATASET_PATH at a pickle you produced or otherwise trust.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
len(dataset)

In [None]:
# A candidate section title is a run of letters/spaces that directly follows a
# blank line and is followed by a colon on the same line.
title_pattern = re.compile("(?<=\n\n)[a-zA-Z ]*?(?=:.*?\n)")

# Tally how often each candidate title occurs across all discharge summaries.
titles = Counter()
for record in tqdm(dataset):
    titles.update(title_pattern.findall(record.discharge_summary.text))

In [None]:
# Keep the titles whose total number of occurrences exceeds 95% of the number
# of records, in descending order of frequency.
min_occurrences = len(dataset) * 0.95
common_titles = []
for title, occurrences in titles.most_common():
    if occurrences > min_occurrences:
        common_titles.append(title)
common_titles

In [None]:
# Split every discharge summary into a {section title: section body} mapping.
# A section boundary is a blank line that is immediately followed by one of the
# common titles and a colon. The pattern does not depend on the record, so it
# is compiled once instead of once per summary.
section_boundary = re.compile(
    f"\n\n(?=(?:{'|'.join(common_titles)}):.*?\n)"
)

dataset_title_to_body = []
for record in tqdm(dataset):
    title_to_body = {}
    for section in section_boundary.split(record.discharge_summary.text):
        # Everything before the first colon is the title, the rest is the body.
        title, separator, body = section.partition(":")
        if not separator:
            # A chunk without any colon (e.g. free text before the first
            # titled section) has no title to file it under; indexing the
            # missing body used to raise IndexError here, so skip it.
            continue
        # If a title repeats within one summary, the last occurrence wins.
        title_to_body[title] = body.strip()
    dataset_title_to_body.append(title_to_body)

In [None]:
# Build the SNOMED helper from the description and relationship snapshot files.
# Later cells query it for preferred terms, child CUIs and its synonym table.
snomed = Snomed(description_file, relation_file)


In [None]:
# Root concepts (SNOMED concept ids); their children define which concepts the
# matcher will look for.
parent_cuis = [
    "404684003",
    "118956008",
    "384760004",
    "365870005",
    "169443000",
]
# Show the preferred term of each root as a readability check.
list(map(snomed.get_preferred_term, parent_cuis))

In [None]:
# Hand-curated (concept id, synonym) pairs for abbreviations and short forms
# that are added to the synonym table by hand.
cui_and_missing_synonyms = [
    ("38341003", "HTN"),
    ("384760004", "FEN"),
    ("53741008", "CAD"),
    ("169443000", "PPX"),
    ("365870005", "Code Status"),
    ("365870005", "Code"),
    ("73211009", "Diabetes"),
    ("73211009", "DM"),
    ("44054006", "DM2"),
]
# Show each added synonym next to the preferred term of the concept it maps to.
[
    (snomed.get_preferred_term(concept_id), alias)
    for concept_id, alias in cui_and_missing_synonyms
]

In [None]:
# Append the hand-curated synonyms to the SNOMED synonym table.
# The 2-tuples are labelled with the table's own column names, so the table is
# expected to have exactly two columns in (cui, synonym text) order.
missing_synonyms_df = pd.DataFrame.from_records(
    cui_and_missing_synonyms, columns=snomed.synonyms_df.columns
)
# This cell mutates `snomed.synonyms_df`; dropping exact duplicate rows keeps it
# idempotent, so re-running the cell does not append the same synonyms again.
snomed.synonyms_df = pd.concat(
    [snomed.synonyms_df, missing_synonyms_df], ignore_index=True
).drop_duplicates(ignore_index=True)
snomed.synonyms_df.tail(10)

In [None]:
# Gather every child CUI of the selected parent concepts.
# This must run before the synonym filter below: the filter previously lived in
# an earlier cell and failed with NameError on a fresh kernel.
# NOTE(review): whether get_child_cuis() also returns the parent itself is not
# visible here; if it does not, synonyms attached directly to a parent CUI are
# filtered out — confirm.
cuis_of_interest = {
    child_cui
    for parent_cui in parent_cuis
    for child_cui in snomed.get_child_cuis(parent_cui)
}

# Restrict the synonym table to the concepts selected above.
synonyms_of_interest = snomed.synonyms_df[
    snomed.synonyms_df["cui"].isin(cuis_of_interest)
]
len(cuis_of_interest), len(synonyms_of_interest)

In [None]:
# Preview a small, deterministic sample instead of dumping the whole set into
# the cell output.
sorted(cuis_of_interest)[:10]

In [None]:
# Tokenizer of a blank English pipeline; used both to build the phrase
# patterns and to tokenize every text the matcher is applied to.
tokenizer_spacy = English().tokenizer

In [None]:
# Phrase matcher comparing on the LOWER token attribute (case-insensitive).
# Each CUI becomes one match key whose patterns are all of its synonyms.
snomed_matcher = PhraseMatcher(tokenizer_spacy.vocab, attr="LOWER")
for cui, group_df in tqdm(synonyms_of_interest.groupby("cui")):
    synonym_docs = list(tokenizer_spacy.pipe(group_df["name"]))
    snomed_matcher.add(cui, synonym_docs)

In [None]:
def filter_and_keep_longest(spans):
    """Reduce overlapping spans to a non-overlapping list, favouring longer ones.

    Spans are visited in order of ascending ``start``; ties are broken by
    larger ``end`` first, then by longer ``text``. A span that begins at or
    after the end of the currently held span closes the held one and becomes
    the new held span. A span that overlaps the held one replaces it only when
    it ends later and has a greater ``len()``; otherwise it is discarded.

    Args:
        spans: iterable of span-like objects exposing ``start``, ``end``,
            ``text`` and ``len()``.

    Returns:
        list of the retained spans, in ascending ``start`` order.
    """
    ordered = list(spans)
    ordered.sort(key=lambda s: (s.start, -s.end, -len(s.text)))

    kept = []
    current = None
    boundary = -1

    for candidate in ordered:
        if candidate.start >= boundary:
            # No overlap with the held span: emit it before moving on.
            if current:
                kept.append(current)
        elif not (candidate.end > boundary and len(candidate) > len(current)):
            # Overlaps the held span without being a better replacement.
            continue
        current = candidate
        boundary = candidate.end

    # Emit the span still being held once the input is exhausted.
    if current:
        kept.append(current)

    return kept

In [None]:
# For the first 1000 records, collect one CUI per line of the
# "Discharge Diagnosis" section: the label of the first span that survives
# overlap filtering. Lines are split on newlines and commas; lines without a
# match contribute nothing.
dataset_diagnosis_cuis = []

for sections in dataset_title_to_body[:1000]:
    diagnosis_text = sections.get("Discharge Diagnosis", "")
    matches_per_line = (
        filter_and_keep_longest(
            snomed_matcher(tokenizer_spacy(diagnosis_line), as_spans=True)
        )
        for diagnosis_line in re.split("[\n,]", diagnosis_text)
    )
    dataset_diagnosis_cuis.append(
        [line_matches[0].label_ for line_matches in matches_per_line if line_matches]
    )

In [None]:
# Measure how many "Brief Hospital Course" paragraph headings the matcher
# recognises, and collect the headings it misses.
count = 0  # paragraphs whose heading produced at least one match
num_paras = 0  # paragraphs inspected
misses = []  # lower-cased headings without any match
for record_title_to_body in dataset_title_to_body:
    bhc = record_title_to_body.get("Brief Hospital Course", "")
    # The first blank-line-separated paragraph is skipped.
    for bhc_paragraph in bhc.split("\n\n")[1:]:
        # Heading = text before the first ":" or "-". maxsplit is passed by
        # keyword because the positional form is deprecated in recent Python.
        heading = re.split("[:-]", bhc_paragraph, maxsplit=1)[0]
        matches = filter_and_keep_longest(
            snomed_matcher(tokenizer_spacy(heading), as_spans=True)
        )
        if matches:
            count += 1
        else:
            misses.append(heading.lower())
        num_paras += 1
# Fraction of headings matched; 0.0 (instead of ZeroDivisionError) when no
# paragraph was inspected.
count / num_paras if num_paras else 0.0

In [None]:
# access, rhythm, pump

In [None]:
# Probe: does the matcher return any span for the lone abbreviation "fen"?
snomed_matcher(tokenizer_spacy("fen"), as_spans=True)

In [None]:
# Probe: number of tokens the tokenizer produces for "fen".
len(tokenizer_spacy("fen"))

In [None]:
# Unmatched (lower-cased) headings with their counts, most frequent first.
Counter(misses).most_common()

In [None]:
# Print the "Brief Hospital Course" section of the record at index 21
# (prints an empty string when that section is absent).
print(dataset_title_to_body[21].get("Brief Hospital Course", ""))

In [None]:
# Manual inspection of the first ten records: print the first line of every
# "Brief Hospital Course" paragraph after the first one, then the "Discharge
# Diagnosis" and "Past Medical History" sections, then a separator line.
record_separator = "*" * 80
paired_records = zip(dataset_title_to_body[:10], dataset_diagnosis_cuis)
for record_title_to_body, record_diagnosis_cuis in paired_records:
    bhc = record_title_to_body.get("Brief Hospital Course", "")
    bhc_paragraphs = bhc.split("\n\n")
    for bhc_paragraph in bhc_paragraphs[1:]:
        first_line = bhc_paragraph.partition("\n")[0]
        # The matches (and `record_diagnosis_cuis`) are computed but not
        # displayed by this cell.
        matches = filter_and_keep_longest(
            snomed_matcher(tokenizer_spacy(first_line), as_spans=True)
        )
        print(first_line)
    print()
    for section_title in ("Discharge Diagnosis", "Past Medical History"):
        print(record_title_to_body.get(section_title, ""))
    print(record_separator)

In [None]:
# NOTE(review): `df_p_terms` and `diagnosis_spans` are not defined in any
# visible cell, so this raises NameError on a fresh kernel — confirm where they
# come from or remove the cell.
# Tallies the `description_type_ids` value looked up for each span's label and
# shows the five most frequent values.
Counter(df_p_terms.loc[span.label_]['description_type_ids'] for span in tqdm(diagnosis_spans)).most_common(5)

In [None]:
# Break the physician notes of the first record into sentence-like chunks and
# keep the chunks in which the SNOMED matcher finds at least one phrase.
keep_sentences = []
# Notes are split on blank lines, then on ". " and on newlines that are not
# followed by a lower-case letter; remaining newlines become spaces.
# The pattern is a raw string so that "\." is not an invalid string escape.
chunks = [
    sentence.replace("\n", " ")
    for note in dataset[0].physician_notes
    for chunk in note.text.split("\n\n")
    for sentence in re.split(r"(?<=\.) |\n(?![a-z])", chunk)
]
# De-duplicate while preserving first-seen order.
chunks = list(dict.fromkeys(chunks))
for chunk in chunks:
    doc = tokenizer_spacy(chunk)
    matches = snomed_matcher(doc)
    if matches:
        keep_sentences.append(chunk)
# `count` was previously initialised to 0 and never updated, so the summary
# below always reported zero matching chunks.
count = len(keep_sentences)
count, len(chunks)

In [None]:
# Preview the first ten de-duplicated chunks.
chunks[:10]

In [None]:
# Print the raw text of the first physician note of the first record.
print(dataset[0].physician_notes[0].text)

In [None]:
# Displays the current value of `matches`, i.e. whatever the most recently
# executed cell left bound to that name.
matches

In [None]:
# Load the MedCAT model pack; no visible cell creates `cat`, so the loop below
# previously failed with NameError on a fresh kernel.
cat = CAT.load_model_pack(str(MODEL_PATH))

# Count the chunks in which MedCAT links at least one entity.
# get_entities() returns a dict whose "entities" entry is empty when nothing
# was linked; the value returned by cat(chunk) is a document object, so testing
# its truthiness did not test for entities.
count = 0
for chunk in tqdm(chunks):
    matches = cat.get_entities(chunk)["entities"]
    if matches:
        count += 1
count, len(chunks)