In [None]:
import pickle
from collections import Counter
from pathlib import Path

from medcat.cat import CAT
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# All paths hang off the parent of the current working directory
# (presumably the project root when the notebook runs from its own folder).
DATA_DIR = Path.cwd().parent.joinpath("data")

# Data files.
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train.pkl")
DATASET_NOTE_CUI_CACHE_PATH = DATA_DIR.joinpath("dataset_note_cui_cache.json")

# MedCAT model pack archive.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)

# NOTE(review): RANDOM_SEED, LOG_FILE and DATASET_NOTE_CUI_CACHE_PATH are not
# referenced by any cell shown here.
RANDOM_SEED = 23
LOG_FILE = "./medcat.log"

In [None]:
# Load the pickled training split and wrap each raw record (unpacked as
# keyword arguments) in a Record.
# NOTE: pickle.load can execute arbitrary code; only load files you trust.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
len(dataset)

In [None]:
# Load the MedCAT model pack archive configured in MODEL_PATH.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Turn off the "Status" pipeline component for the rest of the session.
# spaCy v3 deprecates Language.disable_pipes in favour of select_pipes. Both
# return a DisabledPipes handle; keep it (its .restore() re-enables the
# component) instead of letting it render as the cell's output.
status_pipe_disabled = cat.pipe.spacy_nlp.select_pipes(disable=["Status"])

In [None]:
# Inspect the spaCy pipeline's (name, component) pairs after disabling "Status".
cat.pipe.spacy_nlp.pipeline

In [None]:
# List each pipeline component with the attributes it declares it requires.
# spaCy v3 records this on the component's factory metadata
# (Language.get_pipe_meta) rather than as an attribute of the component object.
for name, component in cat.pipe.spacy_nlp.pipeline:
    dependencies = cat.pipe.spacy_nlp.get_pipe_meta(name).requires
    print(f"Component: {name} ({component}), Dependencies: {dependencies}")

In [None]:
def extract_cuis_from_text(text: str, cat: "CAT") -> set:
    """Return the distinct CUIs MedCAT links in ``text``.

    Args:
        text: Free text to annotate. Empty or otherwise falsy text is not sent
            to the model.
        cat: A loaded MedCAT ``CAT`` model; it is called as ``cat(text)``.

    Returns:
        The set of ``ent._.cui`` values for the entities in the resulting
        document. The set is empty when ``text`` is empty or when MedCAT
        produces no document.
    """
    if not text:
        return set()
    doc = cat(text)
    # MedCAT's CAT.__call__ may return None instead of a Doc; treat that as
    # "no entities" rather than failing on `.ents`.
    if doc is None:
        return set()
    return {ent._.cui for ent in doc.ents}

In [None]:
# Semantic type names (as stored in the model pack's type_id2name table) whose
# concepts MedCAT should be allowed to link.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "organism",
    "physical object",
    "clinical drug",
    "medicinal product form",
    "procedure",
    "product",
}

# Reverse the model pack's type id -> type name table so types can be looked
# up by name.
type_name_to_id = dict(
    (type_name, type_id)
    for type_id, type_name in cat.cdb.addl_info["type_id2name"].items()
)

# Raises KeyError if a requested name is not a type in this model pack.
type_ids_filter = [type_name_to_id[wanted_name] for wanted_name in filter_type_names]

# Every CUI that belongs to at least one selected type.
cui_filters = set().union(
    *(cat.cdb.addl_info["type_id2cuis"][type_id] for type_id in type_ids_filter)
)
# Install the CUI set as MedCAT's linking filter (presumably restricting linked
# entities to these CUIs; confirm against the MedCAT config docs).
cat.cdb.config.linking["filters"]["cuis"] = cui_filters
len(cui_filters)

In [None]:
# CUIs MedCAT finds in each record's physician notes. Only the first
# N_NOTE_CUI_DOCS records are annotated to keep this cell fast; set it to None
# to annotate the whole dataset.
N_NOTE_CUI_DOCS = 10

dataset_note_cuis = [
    extract_cuis_from_text(
        # Separate notes with a blank line so text never runs across notes.
        # The trailing space in "Physician " is presumably the exact raw
        # category label; confirm against the data.
        "\n\n".join(
            note.text for note in doc.physician_notes if note.category in {"Physician "}
        ),
        cat,
    )
    for doc in tqdm(dataset[:N_NOTE_CUI_DOCS])
]

In [None]:
# Put every non-empty BHC paragraph heading into exactly one bucket, checked in
# this order:
#   cui_hits        - the heading has CUIs and all of them occur in the notes
#   partial_cui_hit - some, but not all, heading CUIs occur in the notes
#   strict_match    - no CUI overlap, but the lower-cased heading appears
#                     verbatim in the lower-cased physician notes
#   no_match        - none of the above
num_headings = 0
cui_hits = []
partial_cui_hit = []
strict_match = []
no_match = []
# zip stops at the shorter input, so only records that have note CUIs are
# evaluated.
for doc, doc_note_cuis in tqdm(
    zip(dataset, dataset_note_cuis),
    total=min(len(dataset), len(dataset_note_cuis)),
):
    # The joined notes depend only on the record, so build them once per
    # record instead of once per paragraph.
    joined_doc_notes = "\n\n".join(
        note.text.lower()
        for note in doc.physician_notes
        if note.category in {"Physician "}
    )
    for para in doc.discharge_summary.bhc_paragraphs:
        if not para.heading:
            continue
        num_headings += 1
        para_cuis = extract_cuis_from_text(para.heading, cat)

        if para_cuis and para_cuis.issubset(doc_note_cuis):
            cui_hits.append(para.heading)
        elif not para_cuis.isdisjoint(doc_note_cuis):
            partial_cui_hit.append(para.heading)
        elif para.heading.lower() in joined_doc_notes:
            strict_match.append(para.heading)
        else:
            no_match.append(para.heading)

In [None]:
# Fraction of headings in each bucket. Each heading is appended to exactly one
# bucket, so the bucket sizes must add up to num_headings.
assert num_headings > 0, "No BHC paragraph headings were found; rates are undefined."
assert (
    len(cui_hits) + len(partial_cui_hit) + len(strict_match) + len(no_match)
    == num_headings
), "heading buckets do not partition the headings (was the matching cell interrupted?)"

total_hit_rate = (
    len(cui_hits) + len(partial_cui_hit) + len(strict_match)
) / num_headings
cui_hit_rate = len(cui_hits) / num_headings
partial_cui_hit_rate = len(partial_cui_hit) / num_headings
strict_match_rate = len(strict_match) / num_headings
no_match_rate = len(no_match) / num_headings

total_hit_rate, cui_hit_rate, partial_cui_hit_rate, strict_match_rate, no_match_rate

In [None]:
# Count of unmatched headings plus a short preview, instead of dumping the
# whole list into the notebook output.
len(no_match), no_match[:20]

In [None]:
# Count of headings found only by verbatim text match, plus a short preview,
# instead of dumping the whole list into the notebook output.
len(strict_match), strict_match[:20]

In [None]:
# Distinct unmatched headings with their occurrence counts, most frequent first.
Counter(no_match).most_common()