In [None]:
import pickle
from pathlib import Path
from typing import Tuple

import numpy as np
from medcat.cat import CAT
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# All paths are resolved relative to the parent of the current working directory.
DATA_DIR = Path.cwd().parent.joinpath("data")
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train.pkl")

# MedCAT model pack archive, expected under "<parent of cwd>/models".
MODEL_PATH = Path.cwd().parent.joinpath(
    "models",
    "umls_sm_pt2ch_533bab5115c6c2d6mimic_tuned_ebf8b5bb5099c274.zip",
)

# NOTE(review): not referenced by any cell visible in this notebook.
RANDOM_SEED = 23

In [None]:
# Read the pickled training records, build a `Record` from each raw mapping and
# keep only the first 10 of them.
# NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
with TRAINING_DATASET_PATH.open("rb") as in_file:
    dataset = [Record(**fields) for fields in pickle.load(in_file)][:10]
len(dataset)

In [None]:
# Load the MedCAT model pack stored at MODEL_PATH; `cat` is the annotator that the
# later cells call on note and heading text.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Disable the pipeline component named "Status" on the model's underlying spaCy pipeline.
# NOTE(review): presumably a meta-annotation (status) classifier bundled with the
# model pack that is not needed for CUI extraction — confirm.
cat.pipe.spacy_nlp.disable_pipes(["Status"])

In [None]:
# NOTE: this whole cell is disabled (commented out). If re-enabled it would restrict
# MedCAT linking to the CUIs belonging to the listed type IDs by assigning them to
# `cat.cdb.config.linking["filters"]["cuis"]`.
# NOTE(review): `type_ids_filter` is assigned twice below, so the hard-coded list
# would override the name-based lookup. Either restore this cell deliberately or
# delete it rather than keeping dead code in the notebook.

# filter_type_names = {
#     "disorder",
#     "finding",
# }

# type_name_to_id = {
#     name: type_id for type_id, name in cat.cdb.addl_info["type_id2name"].items()
# }

# type_ids_filter = [type_name_to_id[type_name] for type_name in filter_type_names]
# type_ids_filter = [
#     "T020",
#     "T190",
#     "T049",
#     "T019",
#     "T047",
#     "T050",
#     "T033",
#     "T037",
#     "T048",
#     "T191",
#     "T046",
#     "T184",
# ] + ["T005", "T007"]
# cui_filters = {
#     cui
#     for type_ids in type_ids_filter
#     for cui in cat.cdb.addl_info["type_id2cuis"][type_ids]
# }
# cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
def extract_cuis_from_text(text: str, cat: CAT):
    """Return the distinct CUIs of the entities that ``cat`` finds in ``text``.

    Args:
        text: Text to annotate. A falsy value (``None`` or ``""``) is never
            passed to the model.
        cat: MedCAT annotator; it is called as ``cat(text)`` and the result's
            ``.ents`` are read.

    Returns:
        A set with the ``ent._.cui`` value of every entity, or an empty set
        when ``text`` is falsy.
    """
    # Guard clause: nothing to annotate, so skip the model call entirely.
    if not text:
        return set()

    found_cuis = set()
    for entity in cat(text).ents:
        found_cuis.add(entity._.cui)
    return found_cuis

What percentage of non-empty headings can be annotated by MedCAT, and how many of those annotations also appear in the notes?

In [None]:
# Classify the heading of every BHC paragraph that has non-empty text by how the
# CUIs found in the heading relate to the CUIs found in the same record's
# physician notes.
num_headings = 0  # headings examined (paragraphs with empty text are skipped)
hits = 0  # every heading CUI also occurs in the notes
not_annotated = 0  # no CUI was extracted from the heading
not_match = 0  # heading CUIs were extracted but none occurs in the notes
partial_match = 0  # some, but not all, heading CUIs occur in the notes
for doc in tqdm(dataset):
    # Union of the CUIs extracted from all physician notes of this record.
    doc_note_cuis = {
        cui
        for note in doc.physician_notes
        for cui in extract_cuis_from_text(note.text, cat)
    }
    for para in doc.discharge_summary.bhc_paragraphs:
        if not para.text:
            continue
        para_cuis = extract_cuis_from_text(para.heading, cat)
        num_headings += 1
        if not para_cuis:
            not_annotated += 1
        elif para_cuis.issubset(doc_note_cuis):
            hits += 1
        elif para_cuis.intersection(doc_note_cuis):
            partial_match += 1
        else:
            not_match += 1

# Fractions in the order (hits, partial_match, not_match, not_annotated).
# When no heading was examined every fraction is reported as 0.0 instead of
# raising ZeroDivisionError, matching the zero-denominator handling used for
# precision/recall further down.
tuple(
    (count / num_headings) if num_headings != 0 else 0.0
    for count in (hits, partial_match, not_match, not_annotated)
)

In [None]:
# One set per record: all CUIs extracted from the headings of its BHC paragraphs.
dataset_heading_cuis = []
for record in tqdm(dataset):
    record_heading_cuis = set()
    for paragraph in record.discharge_summary.bhc_paragraphs:
        record_heading_cuis.update(extract_cuis_from_text(paragraph.heading, cat))
    dataset_heading_cuis.append(record_heading_cuis)

In [None]:
# One set per record: all CUIs extracted from the text of its physician notes.
dataset_note_cuis = []
for record in tqdm(dataset):
    record_note_cuis = set()
    for physician_note in record.physician_notes:
        record_note_cuis.update(extract_cuis_from_text(physician_note.text, cat))
    dataset_note_cuis.append(record_note_cuis)

In [None]:
def calculate_tp_fp_fn(actual_set, predicted_set) -> Tuple[int, int, int]:
    """Count true positives, false positives and false negatives.

    Args:
        actual_set: Reference items; any iterable of hashable values.
        predicted_set: Predicted items; any iterable of hashable values.

    Returns:
        ``(true_positives, false_positives, false_negatives)`` as integer
        counts: items in both collections, items only in ``predicted_set``,
        and items only in ``actual_set``.
    """
    # Coerce to sets so lists/tuples (possibly with duplicates) are accepted as
    # well; for inputs that are already sets this is a cheap copy.
    actual = set(actual_set)
    predicted = set(predicted_set)

    true_positives = len(actual & predicted)
    false_positives = len(predicted - actual)
    false_negatives = len(actual - predicted)

    return true_positives, false_positives, false_negatives

In [None]:
# Per-record (TP, FP, FN) counts, treating heading CUIs as the reference set and
# note CUIs as the predicted set. The explicit dtype and reshape keep the array
# two-dimensional with shape (n_records, 3) even when there are no records, so the
# column-wise sum below always yields three values that can be unpacked.
tp_fp_fn = np.array(
    [
        calculate_tp_fp_fn(actual, pred)
        for actual, pred in zip(dataset_heading_cuis, dataset_note_cuis)
    ],
    dtype=int,
).reshape(-1, 3)
# Totals over all records.
true_positives, false_positives, false_negatives = tp_fp_fn.sum(axis=0)

In [None]:
# precision = TP / (TP + FP); reported as 0.0 when the denominator is zero.
if true_positives + false_positives != 0:
    precision = true_positives / (true_positives + false_positives)
else:
    precision = 0.0

# recall = TP / (TP + FN); reported as 0.0 when the denominator is zero.
if true_positives + false_negatives != 0:
    recall = true_positives / (true_positives + false_negatives)
else:
    recall = 0.0

# F1 is the harmonic mean of precision and recall; 0.0 when both are zero.
if precision + recall != 0:
    f1 = 2 * precision * recall / (precision + recall)
else:
    f1 = 0.0

precision, recall, f1