In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell runs, so edits to project code are picked up live.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import re
from datetime import datetime
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv
from medcat.cat import CAT

from discharge_summaries.schemas.mimic import PhysicianNote

In [None]:
# Load variables from a .env file into the process environment (python-dotenv).
# NOTE(review): none of the visible cells read an environment variable
# directly — presumably imported project code needs them; confirm.
load_dotenv()

In [None]:
# All locations are resolved from the kernel's current working directory, so
# the notebook must be started from the folder that contains "examples".
EXAMPLE_DIR = Path.cwd().joinpath("examples")
OUTPUT_DIR = Path.cwd().joinpath("output")

# The MedCAT model pack lives in a "models" folder next to the working directory.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models",
    "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip",
)

In [None]:
# Load the MedCAT model pack from MODEL_PATH, then override two config values.
cat = CAT.load_model_pack(MODEL_PATH)
# NOTE(review): presumably allows two-character names (short abbreviations)
# to be detected — confirm against the MedCAT config documentation.
cat.config.ner.min_name_len = 2
# NOTE(review): a threshold of 0 presumably accepts every linking candidate
# regardless of context similarity — confirm.
cat.config.linking.similarity_threshold = 0

## RCP Example

In [None]:
# Read the "Notes" sheet of the RCP practice workbook. Row 4 (0-indexed) is
# used as the header row; the first two columns come back as "Unnamed: 0" and
# "Unnamed: 1", so they are given meaningful names.
notes_workbook_path = Path.cwd().parent.joinpath(
    "data",
    "rcp",
    "5. Activity-practice discharge summary writing task_0.xlsx",
)
notes_df = pd.read_excel(
    notes_workbook_path,
    sheet_name="Notes",
    header=4,
).rename(columns={"Unnamed: 0": "timestamp", "Unnamed: 1": "text"})
notes_df.head()

In [None]:
# A note boundary is marked by two consecutive fully blank rows.
blank_rows = notes_df.isnull().all(axis=1)
# fill_value=False keeps the shifted mask boolean; a plain shift(-1) leaves NaN
# in the last slot, which turns the Series into object dtype.
consecutive_blank_rows = blank_rows & blank_rows.shift(-1, fill_value=False)

# Work with integer positions rather than index labels, because the slices
# below are positional (.iloc).
boundary_positions = consecutive_blank_rows.to_numpy().nonzero()[0]

split_dfs = []
start_index = 0
for end_index in [*boundary_positions, len(notes_df)]:
    # Runs of three or more blank rows (or trailing blank rows) produce empty
    # or all-blank slices; skip them so later cells never receive a note
    # without any content rows.
    if not blank_rows.iloc[start_index:end_index].all():
        split_dfs.append(notes_df.iloc[start_index:end_index])
    # Skip over the two blank separator rows.
    start_index = end_index + 2

In [None]:
# Turn each block of rows into one PhysicianNote.
notes = []

for split_df in split_dfs:
    # The first row's "timestamp" cell carries the note's date string. Its
    # leading token (presumably a weekday name — confirm against the sheet) is
    # dropped before parsing the remaining "day month-abbrev year HH:MM" part.
    header_stamp = split_df["timestamp"].iloc[0]
    stamp_without_day = header_stamp.split(" ", 1)[1]
    parsed_stamp = datetime.strptime(stamp_without_day, "%d %b %Y %H:%M")

    # The note body is every non-null "text" cell of the block, one per line.
    note_lines = split_df["text"].dropna().tolist()

    note = PhysicianNote(
        timestamp=parsed_stamp.strftime("%Y-%m-%d %H:%M"),
        text="\n".join(note_lines),
        hadm_id="0",
    )
    notes.append(note)

In [None]:
# Load the example response. Decode explicitly as UTF-8: without an encoding,
# read_text() falls back to the platform locale and can mis-decode non-ASCII
# characters in the JSON file.
response_json = json.loads(
    (EXAMPLE_DIR / "example_1.json").read_text(encoding="utf-8")
)

In [None]:
def flatten_dict(nested_dict, parent_key=""):
    """Flatten a nested dict into a single-level dict.

    Keys of nested dicts are joined to their parent's key with a double
    underscore (``{"a": {"b": 1}}`` becomes ``{"a__b": 1}``). Values that are
    not dicts (including lists) are copied through unchanged.

    Args:
        nested_dict: The mapping to flatten.
        parent_key: Prefix applied to every key; empty means no prefix.

    Returns:
        A new flat dict, in depth-first key order.
    """

    def _walk(mapping, prefix):
        # Yield (flattened_key, leaf_value) pairs depth-first.
        for name, item in mapping.items():
            path = f"{prefix}__{name}" if prefix else name
            if isinstance(item, dict):
                yield from _walk(item, path)
            else:
                yield path, item

    return dict(_walk(nested_dict, parent_key))


def _as_text(value):
    """Render a value as plain text, joining dict values / list items with ". "."""
    if isinstance(value, dict):
        return ". ".join(_as_text(v) for v in value.values() if v is not None)
    if isinstance(value, list):
        return ". ".join(_as_text(v) for v in value if v is not None)
    # Strings come back unchanged; numbers and booleans are stringified so
    # that joining never raises TypeError.
    return str(value)


def merge_dict_values(dict_with_lists):
    """Collapse every list value of a dict into one ". "-separated string.

    Dict elements inside a list contribute their values (also joined with
    ". "). Non-string elements such as numbers are converted with ``str`` and
    ``None`` elements are skipped, so parsed JSON containing numbers or nulls
    no longer raises ``TypeError``. Values that are not lists are returned
    unchanged.

    Args:
        dict_with_lists: Flat mapping whose values may be lists.

    Returns:
        A new dict with the same keys, where each list value is replaced by
        its joined text.
    """
    merge_value_dict = {}
    for k, v in dict_with_lists.items():
        merge_value_dict[k] = _as_text(v) if isinstance(v, list) else v
    return merge_value_dict

In [None]:
# Flatten the nested response to one value per "a__b__c" field path, with list
# values joined into a single ". "-separated string, then display the result.
response_json_flat = merge_dict_values(flatten_dict(response_json))
response_json_flat

In [None]:
# Concept types that should not be linked when annotating the response.
ignore_type_names = {"attribute", "qualifier value", "observable entity"}

type_id2name = cat.cdb.addl_info["type_id2name"]
type_id2cuis = cat.cdb.addl_info["type_id2cuis"]

# Every type id whose name is not on the ignore list.
keep_type_ids = set()
for type_id, type_name in type_id2name.items():
    if type_name not in ignore_type_names:
        keep_type_ids.add(type_id)

# Union of the CUIs belonging to the kept types.
cui_filters = set()
for type_id in keep_type_ids:
    cui_filters.update(type_id2cuis[type_id])

# Restrict the linker to those CUIs.
cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
# Annotate every flattened response field; the (field, text) pairs are passed
# in so the result is keyed by field name.
response_json_ents = cat.multiprocessing(response_json_flat.items())

# Collect the distinct CUIs detected across all response fields.
response_json_cuis = set()
for field_entity_dict in response_json_ents.values():
    response_json_cuis.update(
        entity["cui"] for entity in field_entity_dict["entities"].values()
    )

In [None]:
# For each response field, print its name, its text, and the detected
# (concept name, matched text) pairs, followed by a blank line.
for field, field_ents in response_json_ents.items():
    print(field)
    print(response_json_flat[field])
    name_and_match_pairs = []
    for ent in field_ents["entities"].values():
        name_and_match_pairs.append((ent["pretty_name"], ent["source_value"]))
    print(name_and_match_pairs)
    print()

In [None]:
# Narrow the linking filter to the CUIs found in the response, so later
# annotation calls are limited to those concepts.
cat.cdb.config.linking["filters"]["cuis"] = response_json_cuis

In [None]:
# Concatenate all note texts (newline-separated) into one document and
# annotate it in a single call.
notes_str = "\n".join(n.text for n in notes)
note_ents = cat.get_entities(notes_str)

In [None]:
# Distinct CUIs detected anywhere in the concatenated notes.
note_cuis = {entity["cui"] for entity in note_ents["entities"].values()}

In [None]:
def get_acronym(phrase: str) -> str:
    """Build an upper-case acronym from a phrase.

    The phrase is split on single spaces and hyphens. From each piece the
    first character is kept, plus any later character that is already upper
    case (so ``"non-ST elevation MI"`` gives ``"NSTEMI"``).

    Args:
        phrase: The text to abbreviate.

    Returns:
        The collected characters, upper-cased; empty for an empty phrase.
    """
    kept = []
    for token in re.split(" |-", phrase):
        # token[:1] is "" for the empty pieces produced by repeated separators.
        kept.append(token[:1])
        kept.extend(char for char in token[1:] if char.isupper())
    return "".join(kept).upper()


# Collect response entities that cannot be traced back to the notes. An entity
# counts as found when any of the following holds:
#   1. its CUI was also detected in the notes,
#   2. its matched text occurs in the notes (case-insensitive), or
#   3. its acronym (2+ characters) occurs in the notes surrounded by whitespace.
notes_str_lower = notes_str.lower()  # hoisted: invariant across entities

misses = set()
for entities in response_json_ents.values():
    for ent in entities["entities"].values():
        if ent["cui"] in note_cuis:
            continue
        if ent["source_value"].lower() in notes_str_lower:
            continue
        ent_acronym = get_acronym(ent["source_value"])
        # re.escape: the acronym can contain regex metacharacters (e.g. "(" or
        # "+" from the source text), which would otherwise raise re.error or
        # match unintended text.
        if len(ent_acronym) > 1 and re.search(
            r"\s" + re.escape(ent_acronym) + r"\s", notes_str
        ):
            continue
        misses.add(
            (ent["cui"], ent["pretty_name"], ent["types"][0], ent["source_value"])
        )
misses

In [None]:
# Print the concatenated notes for manual comparison with the misses above.
print(notes_str)