In [None]:
import pickle
from pathlib import Path

from medcat.cat import CAT
from tqdm.notebook import tqdm

from discharge_summaries.schemas.medcat import MedCATSpan
from discharge_summaries.schemas.mimic import Record

In [None]:
# All locations are resolved relative to the parent of the current working
# directory (the notebook is expected to be run from a sub-folder of the repo).
DATA_DIR = Path.cwd().parent.joinpath("data")
GT_DATA_PATH = DATA_DIR.joinpath("train.pkl")

# MedCAT model pack (zip archive) used for all concept extraction below.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)

In [None]:
# NOTE: unpickling executes code embedded in the file, so GT_DATA_PATH must
# point at a trusted, locally produced artefact.
with GT_DATA_PATH.open("rb") as in_file:
    raw_records = pickle.load(in_file)

# Each pickled entry is a mapping of keyword arguments for `Record`.
gt_dataset = [Record(**raw_record) for raw_record in raw_records]
len(gt_dataset)

In [None]:
# Distinct note categories present anywhere in the dataset.
categories = set()
for record in gt_dataset:
    categories.update(note.category for note in record.physician_notes)
categories

In [None]:
# Load the MedCAT model pack referenced by MODEL_PATH; the resulting `cat`
# object is passed to every extraction helper and cell that follows.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Remove the pipeline component named "Status" from the loaded MedCAT pipeline
# (presumably a meta-annotation step that is not needed here — TODO confirm).
cat.pipe.force_remove("Status")
# Display the remaining spaCy pipeline components to verify the removal.
cat.pipe.spacy_nlp.pipeline

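The semantic type identifiers (TUIs, e.g. `T047`) used in the filter below are defined by the UMLS Semantic Network: https://lhncbc.nlm.nih.gov/semanticnetwork/download.html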

In [None]:
# Restrict MedCAT linking to concepts whose UMLS semantic type is one of the
# "Disorders" types, plus Virus and Bacterium.
disorder_type_ids = [
    "T020",  # Acquired Abnormality
    "T190",  # Anatomical Abnormality
    "T049",  # Cell or Molecular Dysfunction
    "T019",  # Congenital Abnormality
    "T047",  # Disease or Syndrome
    "T050",  # Experimental Model of Disease
    "T033",  # Finding
    "T037",  # Injury or Poisoning
    "T048",  # Mental or Behavioral Dysfunction
    "T191",  # Neoplastic Process
    "T046",  # Pathologic Function
    "T184",  # Sign or Symptom
]
organism_type_ids = ["T005", "T007"]  # Virus, Bacterium
type_ids_filter = disorder_type_ids + organism_type_ids

# Union of the CUIs registered under each selected semantic type; a type id
# missing from the model's lookup raises KeyError, exactly as a direct index.
type_id_to_cuis = cat.cdb.addl_info["type_id2cuis"]
cui_filters = set().union(
    *(type_id_to_cuis[type_id] for type_id in type_ids_filter)
)
cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
def cui_to_name(cui: str, cat: CAT) -> str:
    """Look up the name for ``cui`` in the concept database of ``cat``.

    Args:
        cui: Concept unique identifier to resolve.
        cat: MedCAT instance whose ``cdb.get_name`` performs the lookup.

    Returns:
        Whatever ``cat.cdb.get_name(cui)`` returns for the identifier.
    """
    concept_db = cat.cdb
    return concept_db.get_name(cui)


def cuis_to_names(cuis: list[str], cat: CAT) -> list[str]:
    """Resolve several CUIs to names and return the names sorted ascending.

    Args:
        cuis: Concept unique identifiers to resolve.
        cat: MedCAT instance whose ``cdb.get_name`` performs each lookup.

    Returns:
        A new list with one name per input CUI, in sorted order.
    """
    names = [cat.cdb.get_name(cui) for cui in cuis]
    names.sort()
    return names

In [None]:
def extract_spans(text: str, cat: CAT, context_window: int = 30) -> list[MedCATSpan]:
    """Annotate ``text`` with ``cat`` and wrap every entity in a ``MedCATSpan``.

    Args:
        text: Raw text to annotate.
        cat: Callable MedCAT pipeline; ``cat(text)`` must return an object
            exposing ``.ents``, or a falsy value when there is nothing to read.
        context_window: Number of characters of ``text`` kept on each side of
            an entity as its context. Defaults to 30.

    Returns:
        One ``MedCATSpan`` per entity, in the order of ``.ents``; an empty list
        when ``cat(text)`` is falsy.

    Raises:
        ValueError: If ``context_window`` is negative.
    """
    if context_window < 0:
        raise ValueError("context_window must be non-negative")

    annotated_text = cat(text)
    if not annotated_text:
        return []

    spans = []
    for ent in annotated_text.ents:
        # Only the left edge needs clamping: a negative start index would wrap
        # around to the end of the string, whereas slicing past the end of the
        # string is already truncated by Python.
        context_start = max(0, ent.start_char - context_window)
        context = text[context_start : ent.end_char + context_window]
        spans.append(MedCATSpan.from_spacy_span(ent, cat, context=context))
    return spans


def extract_cuis(text: str, cat: CAT) -> set[str]:
    """Return the distinct CUIs of all entities ``cat`` finds in ``text``.

    Args:
        text: Raw text to annotate.
        cat: Callable MedCAT pipeline; ``cat(text)`` must return an object
            exposing ``.ents``, or a falsy value when there is nothing to read.

    Returns:
        The set of ``ent._.cui`` values over the entities, or an empty set when
        ``cat(text)`` is falsy.
    """
    doc = cat(text)
    return {entity._.cui for entity in doc.ents} if doc else set()

In [None]:
# num_hits = 0
# num_gts = 0
# for sample in tqdm(gt_dataset[:10]):
#     note_cuis = {
#         cui
#         for note in tqdm(sample.physician_notes)
#         for cui in extract_cuis(note.text, cat)
#     }
#     for para in sample.discharge_summary.bhc_paragraphs:
#         heading_cuis = extract_cuis(para.heading, cat)
#         if not heading_cuis:
#             continue
#         if heading_cuis.intersection(note_cuis):
#             num_hits += 1
#         num_gts += 1

# num_hits / num_gts

In [None]:
# Corpus-level check over the first 100 records: how often does a discharge
# summary heading share at least one CUI with the record's physician notes?
num_hits = num_gts = num_preds = 0
for sample in tqdm(gt_dataset[:100]):
    # CUIs found in the notes of this record. The trailing space in
    # "Physician " presumably mirrors the raw category string in the data —
    # verify against the `categories` output above.
    note_cuis = set()
    for note in sample.physician_notes:
        if note.category != "Physician ":
            continue
        note_cuis |= extract_cuis(note.text, cat)
    num_preds += len(note_cuis)

    for para in sample.discharge_summary.bhc_paragraphs:
        heading_cuis = extract_cuis(para.heading, cat)
        # Headings without any recognised concept are not counted at all.
        if heading_cuis:
            num_gts += 1
            if heading_cuis & note_cuis:
                num_hits += 1

# (share of concept-bearing headings matched by the notes,
#  matched headings per distinct note CUI)
num_hits / num_gts, num_hits / num_preds

In [None]:
# Drill into a single record: keep full spans (with context) instead of bare
# CUIs so that hits and misses can be inspected by eye.
sample = gt_dataset[6]

# Flat list of every span found in any of the record's notes.
note_spans = []
for note in tqdm(sample.physician_notes):
    note_spans.extend(extract_spans(note.text, cat))

# One list of spans per discharge-summary paragraph heading.
discharge_spans = [
    extract_spans(paragraph.heading, cat)
    for paragraph in sample.discharge_summary.bhc_paragraphs
]

In [None]:
def _index_by_cui(spans):
    """Map each span's CUI to the span; for repeated CUIs the last span wins."""
    return {span.cui: span for span in spans}


note_cui_to_span = _index_by_cui(note_spans)
discharge_cui_to_span = [_index_by_cui(para_spans) for para_spans in discharge_spans]

In [None]:
# Display the per-heading CUI -> span mappings for the selected sample.
discharge_cui_to_span

In [None]:
# Same heading-level hit rate as the corpus-level cell, restricted to `sample`,
# while also recording which heading concepts were never seen in the notes.
note_cuis = set(note_cui_to_span)
misses = []
num_hits = 0
num_gts = 0
for heading_cui_to_span in discharge_cui_to_span:
    # Headings without any recognised concept are not counted at all.
    if not heading_cui_to_span:
        continue
    num_gts += 1
    if note_cuis.isdisjoint(heading_cui_to_span):
        # No CUI in common: remember every concept of this heading.
        misses.extend(
            (span.name, span.context) for span in heading_cui_to_span.values()
        )
    else:
        num_hits += 1
num_hits / num_gts

In [None]:
# Print the `bhc` text of the selected sample's discharge summary for reference.
print(sample.discharge_summary.bhc)

In [None]:
# Concept names recognised in each heading (one inner list per paragraph).
[
    [heading_span.name for heading_span in cui_to_span.values()]
    for cui_to_span in discharge_cui_to_span
]

In [None]:
# Display the (name, context) pairs of heading concepts not found in the notes.
misses

In [None]:
# (name, context) of the retained span for every distinct CUI found in the notes.
[
    (note_span.name, note_span.context)
    for note_span in (note_cui_to_span[cui] for cui in note_cuis)
]

In [None]:
# Dump every note of the selected sample: category, full text, then a rule.
separator = "*" * 80
for note in sample.physician_notes:
    print(note.category, note.text, separator, sep="\n")