In [None]:
import pickle
from collections import Counter
from pathlib import Path

from medcat.cat import CAT
from spacy.matcher import PhraseMatcher
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record
from discharge_summaries.schemas.span import Span

In [None]:
# Every path is resolved from one project root: the parent of the kernel's
# working directory (presumably the notebook's own folder — run from there).
PROJECT_DIR = Path.cwd().parent
DATA_DIR = PROJECT_DIR / "data"

TRAINING_DATASET_PATH = DATA_DIR / "train.pkl"
TRAINING_ANNO_DATASET_PATH = DATA_DIR / "train_anno.pkl"
# NOTE(review): not referenced elsewhere in this notebook — TODO use or remove.
DATASET_NOTE_CUI_CACHE_PATH = DATA_DIR / "dataset_note_cui_cache.json"
MODEL_PATH = (
    PROJECT_DIR / "models" / "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)
# NOTE(review): RANDOM_SEED and LOG_FILE are not referenced elsewhere in this
# notebook; nothing is seeded and MedCAT logging is not configured from here.
RANDOM_SEED = 23
LOG_FILE = "./medcat.log"
# Label given to literal (case-insensitive) heading matches, as opposed to
# annotations labelled with a MedCAT CUI.
DIRECT_LABEL = "DIRECT"

In [None]:
# Build `Record` objects from the pickled training records.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
len(dataset)

In [None]:
# Load the MedCAT model pack at MODEL_PATH; `cat` performs all entity linking below.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Permanently disable the "Status" pipe: only `ent._.cui` is read in this
# notebook, so running it is wasted work. `select_pipes` replaces spaCy's
# deprecated `disable_pipes`; the trailing `;` hides the returned DisabledPipes.
cat.pipe.spacy_nlp.select_pipes(disable=["Status"]);

In [None]:
def extract_cuis_from_text(text: str, cat: CAT):
    """Return the set of CUIs that `cat` links in `text`.

    A falsy `text` (e.g. "") short-circuits to an empty set without calling
    `cat`. Which CUIs can appear depends on the CUI filter currently set in
    `cat.cdb.config.linking["filters"]["cuis"]`.
    """
    if not text:
        return set()
    return {entity._.cui for entity in cat(text).ents}

In [None]:
# Semantic types whose concepts may be linked; every CUI of these types forms
# the default linking filter installed on `cat` below.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "organism",
    "physical object",
    "clinical drug",
    "medicinal product form",
    "procedure",
    "product",
}

# Invert the CDB's id -> name mapping so types can be looked up by name.
type_name_to_id = {
    type_name: type_id
    for type_id, type_name in cat.cdb.addl_info["type_id2name"].items()
}

# Raises KeyError if a listed type name is absent from this model pack.
type_ids_filter = [type_name_to_id[name] for name in filter_type_names]

# Union of the CUIs belonging to each selected type id.
full_cui_filters = set().union(
    *(cat.cdb.addl_info["type_id2cuis"][type_id] for type_id in type_ids_filter)
)
cat.cdb.config.linking["filters"]["cuis"] = full_cui_filters
len(full_cui_filters)

In [None]:
# Annotate every physician note with two kinds of spans:
#   * MedCAT entities whose CUI also occurs in one of the record's BHC
#     paragraph headings (label = the CUI), and
#   * case-insensitive literal matches of the headings themselves
#     (label = DIRECT_LABEL), unless a CUI span covers exactly the same chars.
# Result shape: one list of Spans per note, one list of notes per record.
dataset_annotations = []
spacy_pipeline = cat.pipe.spacy_nlp
for doc in tqdm(dataset):
    doc_headings = [
        para.heading for para in doc.discharge_summary.bhc_paragraphs if para.heading
    ]
    matcher = PhraseMatcher(spacy_pipeline.vocab, attr="LOWER")
    matcher.add(DIRECT_LABEL, list(spacy_pipeline.tokenizer.pipe(doc_headings)))

    # Link the headings against the full type-filtered CUI set (the helper
    # also skips MedCAT entirely when the record has no headings)...
    cat.cdb.config.linking["filters"]["cuis"] = full_cui_filters
    para_cuis = extract_cuis_from_text("\n\n".join(doc_headings), cat)

    # ...then restrict note linking to the CUIs found in the headings.
    cat.cdb.config.linking["filters"]["cuis"] = para_cuis

    doc_annotations = []
    for note in doc.physician_notes:
        # An empty `cuis` filter means "no CUI restriction" in MedCAT, so only
        # link the note when the headings produced at least one CUI. Empty
        # note text is skipped, mirroring the guard in extract_cuis_from_text.
        # MedCAT runs at most once per note.
        note_entities = cat(note.text).ents if para_cuis and note.text else ()
        note_annotations = [
            Span(start=ent.start_char, end=ent.end_char, text=ent.text, label=ent._.cui)
            for ent in note_entities
        ]
        cui_anno_start_ends = {(anno.start, anno.end) for anno in note_annotations}
        for match in matcher(spacy_pipeline.tokenizer(note.text), as_spans=True):
            if (match.start_char, match.end_char) not in cui_anno_start_ends:
                note_annotations.append(
                    Span(
                        start=match.start_char,
                        end=match.end_char,
                        text=match.text,
                        label=DIRECT_LABEL,
                    )
                )
        doc_annotations.append(note_annotations)
    dataset_annotations.append(doc_annotations)

# Restore the default linking filter so later cells do not inherit the last
# record's heading CUIs as hidden state.
cat.cdb.config.linking["filters"]["cuis"] = full_cui_filters

In [None]:
# Sanity check: a PhraseMatcher built with attr="LOWER" (as in the annotation
# loop) matches phrases regardless of case. Uses its own names so the global
# `matcher` is not clobbered.
sanity_matcher = PhraseMatcher(spacy_pipeline.vocab, attr="LOWER")
sanity_matcher.add(DIRECT_LABEL, list(spacy_pipeline.tokenizer.pipe(["apple"])))
sanity_matches = sanity_matcher(spacy_pipeline.tokenizer("banana Apple"), as_spans=True)
assert [span.text for span in sanity_matches] == ["Apple"], sanity_matches
sanity_matches

In [None]:
# Eyeball the first record: distinct annotated texts across its notes next to
# its non-empty BHC headings (sorted alphabetically).
text_annotations = {
    span.text for note_spans in dataset_annotations[0] for span in note_spans
}
headings = sorted(
    para.heading
    for para in dataset[0].discharge_summary.bhc_paragraphs
    if para.heading
)
text_annotations, headings

In [None]:
# Put every non-empty BHC heading into exactly one bucket, based on how the
# same record's note annotations reflect it:
#   cui_hits        - all heading CUIs were annotated in the notes
#   partial_cui_hit - some (not all) heading CUIs were annotated
#   strict_match    - no CUI overlap, but the heading text was matched literally
#   no_match        - none of the above
num_headings = 0
num_matches = 0
cui_hits = []
partial_cui_hit = []
strict_match = []
no_match = []
cat.cdb.config.linking["filters"]["cuis"] = full_cui_filters

# zip() would silently drop records if the two lists got out of sync.
assert len(dataset) == len(dataset_annotations), "annotations out of sync with dataset"

for doc, docs_annotations in tqdm(
    zip(dataset, dataset_annotations), total=len(dataset)
):
    doc_anno_cuis = {
        anno.label
        for note_annotations in docs_annotations
        for anno in note_annotations
        if anno.label != DIRECT_LABEL
    }
    doc_anno_direct_text = {
        anno.text.lower()
        for note_annotations in docs_annotations
        for anno in note_annotations
        if anno.label == DIRECT_LABEL
    }

    for para in doc.discharge_summary.bhc_paragraphs:
        if not para.heading:
            continue
        num_headings += 1
        para_cuis = extract_cuis_from_text(para.heading, cat)
        if para_cuis and para_cuis.issubset(doc_anno_cuis):
            cui_hits.append(para.heading)
        elif not para_cuis.isdisjoint(doc_anno_cuis):
            partial_cui_hit.append(para.heading)
        elif para.heading.lower() in doc_anno_direct_text:
            strict_match.append(para.heading)
        else:
            no_match.append(para.heading)

# Headings matched by any strategy (everything except `no_match`).
num_matches = len(cui_hits) + len(partial_cui_hit) + len(strict_match)

In [None]:
# Fraction of evaluated headings in each bucket; total_hit_rate covers every
# bucket except no_match.
assert num_headings > 0, "No BHC headings were evaluated; hit rates are undefined."
total_hit_rate = (
    len(cui_hits) + len(partial_cui_hit) + len(strict_match)
) / num_headings
cui_hit_rate = len(cui_hits) / num_headings
partial_cui_hit_rate = len(partial_cui_hit) / num_headings
strict_match_rate = len(strict_match) / num_headings
no_match_rate = len(no_match) / num_headings

total_hit_rate, cui_hit_rate, partial_cui_hit_rate, strict_match_rate, no_match_rate

In [None]:
# Headings that matched neither by CUI (fully or partially) nor by lowercase text.
no_match

In [None]:
# Headings with no CUI overlap that were still found verbatim (case-insensitively) in the notes.
strict_match

In [None]:
# Unmatched headings ranked by how often they occur across the dataset.
Counter(no_match).most_common()

In [None]:
# Persist the per-record, per-note annotation spans next to the training data.
TRAINING_ANNO_DATASET_PATH.write_bytes(pickle.dumps(dataset_annotations))