In [None]:
import pickle
from pathlib import Path
from typing import Tuple

import numpy as np
from medcat.cat import CAT
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# Every path is resolved against the parent of the notebook's working directory.
DATA_DIR = Path.cwd().parent / "data"
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train.pkl")
MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)
# Seed constant for stochastic steps (not referenced by the visible cells).
RANDOM_SEED = 23

In [None]:
# Load the training split and unpack every pickled item (a mapping of field
# values) into the project's `Record` schema.
# NOTE(review): `pickle.load` can execute arbitrary code; only load trusted,
# locally generated files from TRAINING_DATASET_PATH.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
len(dataset)

In [None]:
# Load the MedCAT model pack; `cat.cdb` (concept DB) and `cat.pipe` (spaCy
# pipeline) are used by the cells below.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Semantic type names (as listed in the model's `type_id2name` table) to which
# MedCAT linking can optionally be restricted.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "organism",
    "physical object",
    "clinical drug",
    "medicinal product form",
    "procedure",
    "product",
}

# Reverse lookup: type name -> type id.
type_name_to_id = {
    name: type_id for type_id, name in cat.cdb.addl_info["type_id2name"].items()
}

# Raises KeyError if the loaded model pack lacks one of the names above.
type_ids_filter = [type_name_to_id[type_name] for type_name in filter_type_names]
# Alternative UMLS semantic-type filter tried earlier, kept for reference:
# T020, T190, T049, T019, T047, T050, T033, T037, T048, T191, T046, T184, T005, T007

cui_filters = {
    cui
    for type_id in type_ids_filter
    for cui in cat.cdb.addl_info["type_id2cuis"][type_id]
}

# Type filtering is off by default, so an empty CUI filter is installed (the
# state the rest of this notebook was run with). Set to True to restrict
# linking to `cui_filters`.
# NOTE(review): presumably MedCAT treats an empty `cuis` filter as "no
# restriction"; confirm against the MedCAT linking configuration.
APPLY_TYPE_FILTER = False
cat.cdb.config.linking["filters"]["cuis"] = cui_filters if APPLY_TYPE_FILTER else set()

In [None]:
# Show the components of the model pack's spaCy pipeline (e.g. to find pipe names to disable).
cat.pipe.spacy_nlp.pipeline

In [None]:
# Disable the "Status" pipe (presumably a MedCAT meta-annotation component);
# only linked CUIs are used in this notebook.
# NOTE(review): `disable_pipes` is deprecated in spaCy v3 in favour of
# `select_pipes(disable=...)`; confirm the installed spaCy version before switching.
cat.pipe.spacy_nlp.disable_pipes(["Status"])

In [None]:
def extract_cuis_from_text(text: str, cat: CAT):
    """Return the set of CUIs that MedCAT links anywhere in ``text``.

    Falsy input (empty string or ``None``) yields an empty set without
    running the model.
    """
    if not text:
        return set()
    return {entity._.cui for entity in cat(text).ents}

What percentage of non-empty headings can MedCAT annotate such that the same annotations (CUIs) also appear in the corresponding physician notes?

In [None]:
# Disabled (commented-out) string-search baseline, kept for reference.
# For each "/"-separated part of a heading it looks for a verbatim lower-cased
# occurrence in the record's concatenated notes; otherwise it searches for any
# MedCAT synonym (from `cui2names`, "~" replaced by spaces) of the CUIs linked
# in that part.
# NOTE(review): `num_headings` here counts heading parts, not whole headings.
# import re
# num_headings = 0
# hits_search = []
# partial_search = []
# not_match_search = []
# for doc in tqdm(dataset):
#     doc_notes = "\n".join(note.text.lower() for note in doc.physician_notes)
#     for para in doc.discharge_summary.bhc_paragraphs:
#         if para.heading == "":
#             continue
#         heading_split = [heading.lower() for heading in re.split("/", para.heading)]
#         for heading in heading_split:
#             num_headings += 1
#             if heading in doc_notes:
#                 hits_search.append(para.heading)
#                 break
#             synonyms = {name.replace("~", " ") for ent in cat(heading).ents for name in cat.cdb.cui2names[ent._.cui]}
#             if any(synonym in doc_notes for synonym in synonyms):
#                 partial_search.append((para.heading, [synonym for synonym in synonyms if synonym in doc_notes]))
#             else:
#                 not_match_search.append(para.heading)
# len(hits_search) / num_headings, len(partial_search) / num_headings, len(not_match_search) / num_headings,

In [None]:
# Classify every non-empty BHC heading of the first N_SAMPLE_DOCS records by
# how well the physician notes of the same record support it:
#   hits          - heading has CUIs and all of them occur in the notes
#   partial_match - some, but not all, of the heading's CUIs occur in the notes
#   strict_match  - no CUI overlap, but the heading text occurs verbatim
#                   (case-insensitive) in the notes
#   not_annotated - MedCAT linked no CUIs and there is no verbatim match
#   no_match      - heading has CUIs, none occur in the notes, no verbatim match
N_SAMPLE_DOCS = 10

num_headings = 0
hits = []
not_annotated = []
partial_match = []
strict_match = []
no_match = []
for doc in tqdm(dataset[:N_SAMPLE_DOCS]):
    doc_note_cuis = {
        cui
        for note in doc.physician_notes
        for cui in extract_cuis_from_text(note.text, cat)
    }
    # Built once per record instead of once per heading.
    doc_notes_text = "\n".join(note.text.lower() for note in doc.physician_notes)
    for para in doc.discharge_summary.bhc_paragraphs:
        if not para.heading:
            continue
        para_cuis = extract_cuis_from_text(para.heading, cat)
        num_headings += 1
        # The empty set is a subset of every set, so require at least one CUI.
        # Otherwise every un-annotated heading would count as a hit and the
        # `not_annotated` branch below could never be reached.
        if para_cuis and para_cuis.issubset(doc_note_cuis):
            hits.append(para.heading)
        elif para_cuis & doc_note_cuis:
            partial_match.append(para.heading)
        elif para.heading.lower() in doc_notes_text:
            strict_match.append(para.heading)
        elif not para_cuis:
            not_annotated.append(para.heading)
        else:
            no_match.append(para.heading)

# Avoid ZeroDivisionError when the sample contains no non-empty headings.
_denominator = max(num_headings, 1)
(
    len(hits) / _denominator,
    len(partial_match) / _denominator,
    len(strict_match) / _denominator,
    len(no_match) / _denominator,
    len(not_annotated) / _denominator,
)

In [None]:
# Headings from the sampled records whose CUIs are only partly found in the same record's notes.
partial_match

In [None]:
# Non-empty BHC headings (across the whole dataset) for which MedCAT links no
# entity at all; the model only runs on headings that are non-empty.
missed = [
    paragraph.heading
    for record in tqdm(dataset)
    for paragraph in record.discharge_summary.bhc_paragraphs
    if paragraph.heading and not cat(paragraph.heading).ents
]

In [None]:
# Number of non-empty headings (whole dataset) for which MedCAT linked no entities.
len(missed)

In [None]:
# Spot check: does MedCAT link anything for this heading?
# NOTE(review): "Pancreatis" looks like a misspelling of "pancreatitis";
# presumably deliberate (probing typo robustness on a missed heading). Confirm.
cat("Acute Pancreatis").ents

In [None]:
# Fraction of non-empty BHC headings (whole dataset) that MedCAT leaves un-annotated.
len(missed) / sum(
    1
    for record in dataset
    for paragraph in record.discharge_summary.bhc_paragraphs
    if paragraph.heading
)

In [None]:
# Distinct un-annotated headings, for manual review.
set(missed)

In [None]:
# Share of sampled headings supported by the notes, either through CUI overlap
# (full or partial) or through a verbatim text match.
sum(len(bucket) for bucket in (hits, partial_match, strict_match)) / num_headings

In [None]:
# Headings with no CUI overlap that still occur verbatim (case-insensitive) in the notes.
strict_match

In [None]:
# Full BHC text (presumably the Brief Hospital Course section) of the first record,
# printed so that embedded newlines render.
# NOTE(review): this is clinical (MIMIC) text; clear the output before sharing.
print(dataset[0].discharge_summary.bhc)

In [None]:
# Spot check: entities MedCAT links for a single common term.
cat("diarrhea").ents

In [None]:
# Number of distinct CUIs linked across all physician notes of the first record.
len(
    set().union(
        *(extract_cuis_from_text(n.text, cat) for n in dataset[0].physician_notes)
    )
)

In [None]:
# Type name of the first type id (in iteration order) of the first CUI that the
# name "hypertension" maps to in the concept DB.
cat.cdb.addl_info["type_id2name"][
    [*cat.cdb.cui2type_ids[cat.cdb.name2cuis["hypertension"][0]]][0]
]

In [None]:
# Per record: union of the CUIs MedCAT links in its BHC paragraph headings.
dataset_heading_cuis = [
    set().union(
        *(
            extract_cuis_from_text(paragraph.heading, cat)
            for paragraph in record.discharge_summary.bhc_paragraphs
        )
    )
    for record in tqdm(dataset)
]

In [None]:
# Per record: union of the CUIs MedCAT links in its physician notes.
dataset_note_cuis = [
    set().union(
        *(
            extract_cuis_from_text(physician_note.text, cat)
            for physician_note in record.physician_notes
        )
    )
    for record in tqdm(dataset)
]

In [None]:
def calculate_tp_fp_fn(actual_set, predicted_set) -> Tuple[int, int, int]:
    """Count set-level true positives, false positives and false negatives.

    Args:
        actual_set: Reference ("gold") items. Any iterable of hashable items is
            accepted; it is converted to a set, so duplicates are ignored.
        predicted_set: Predicted items, same conventions as ``actual_set``.

    Returns:
        ``(tp, fp, fn)`` as integer counts, where ``tp`` is the size of the
        intersection, ``fp`` the size of ``predicted - actual`` and ``fn`` the
        size of ``actual - predicted``.
    """
    actual = set(actual_set)
    predicted = set(predicted_set)
    true_positives = len(actual & predicted)
    false_positives = len(predicted - actual)
    false_negatives = len(actual - predicted)

    return true_positives, false_positives, false_negatives

In [None]:
# Sum set-level counts over all records. Heading CUIs are the reference
# ("actual") and note CUIs the prediction, so this measures how well the notes'
# CUIs cover the headings' CUIs.
# `reshape(-1, 3)` keeps a (0, 3) shape for an empty dataset, so the column
# sums still unpack into three zeros instead of raising ValueError.
tp_fp_fn = np.array(
    [
        calculate_tp_fp_fn(actual, pred)
        for actual, pred in zip(dataset_heading_cuis, dataset_note_cuis)
    ],
    dtype=int,
).reshape(-1, 3)
true_positives, false_positives, false_negatives = tp_fp_fn.sum(axis=0)

In [None]:
# Micro-averaged precision, recall and F1 from the summed counts. Each ratio
# falls back to 0.0 when its denominator is zero.
if true_positives + false_positives == 0:
    precision = 0.0
else:
    precision = true_positives / (true_positives + false_positives)

if true_positives + false_negatives == 0:
    recall = 0.0
else:
    recall = true_positives / (true_positives + false_negatives)

if precision + recall == 0:
    f1 = 0.0
else:
    f1 = 2 * precision * recall / (precision + recall)

precision, recall, f1