In [None]:
import pickle
import re
from collections import Counter
from pathlib import Path

from matplotlib import pyplot as plt
from medcat.cat import CAT
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# All paths are resolved relative to the parent of the notebook's working
# directory (assumes the notebook is launched from a sub-folder of the project).
PROJECT_ROOT = Path.cwd().parent

DATA_DIR = PROJECT_ROOT / "data"
GT_DATA_PATH = DATA_DIR / "train.pkl"

MODEL_PATH = PROJECT_ROOT / "models" / "umls_sm_pt2ch_533bab5115c6c2d6.zip"

In [None]:
# Load the MedCAT model pack stored at MODEL_PATH, then remove the pipeline
# component named "Status" before any annotation is done.
# NOTE(review): presumably a meta-annotation ("Status") model that is not
# needed for heading type IDs — confirm why it is removed.
cat = CAT.load_model_pack(MODEL_PATH)
cat.pipe.force_remove("Status")

## Load train headings

In [None]:
# Deserialize the ground-truth training records and validate each one into a
# `Record`. NOTE: pickle can execute arbitrary code on load — only use this on
# trusted, locally produced files.
with GT_DATA_PATH.open("rb") as pickle_file:
    train_dataset = [Record(**raw_record) for raw_record in pickle.load(pickle_file)]
len(train_dataset)

In [None]:
# Every BHC paragraph heading across the training set, in record order;
# `filter(None, ...)` drops empty/missing headings.
train_headings = list(
    filter(
        None,
        (
            paragraph.heading
            for record in train_dataset
            for paragraph in record.discharge_summary.bhc_paragraphs
        ),
    )
)
len(train_headings)

In [None]:
# Eyeball the entities MedCAT extracts from the first few headings.
preview_count = 100
separator = "*" * 80
for heading in train_headings[:preview_count]:
    print(heading)
    heading_entities = cat.get_entities(heading)["entities"]
    for entity in heading_entities.values():
        print(entity["pretty_name"], entity["types"], entity["type_ids"])
    print(separator)

## UMLS type IDs for each heading

In [None]:
def extract_type_ids(text: str, cat: CAT) -> set[str]:
    """Return the set of type IDs of every entity MedCAT finds in ``text``.

    Each detected entity's CUI is looked up in ``cat.cdb.cui2type_ids``;
    CUIs with no entry there contribute nothing. If annotating ``text``
    yields a falsy result (an empty document, or ``None`` — presumably what
    MedCAT returns for empty input; TODO confirm) an empty set is returned.
    """
    doc = cat(text)
    if not doc:
        return set()

    cui2type_ids = cat.cdb.cui2type_ids
    type_ids = set()
    for ent in doc.ents:
        type_ids.update(cui2type_ids.get(ent._.cui, []))
    return type_ids

In [None]:
# Annotate every training heading once; element i holds the type IDs found
# in train_headings[i].
train_heading_type_ids = [
    extract_type_ids(heading_text, cat) for heading_text in tqdm(train_headings)
]

Recall: proportion of headings with at least one recognised entity (of any type ID)

In [None]:
# A heading is a "hit" when at least one type ID was found for it.
num_hits = sum(bool(heading_ids) for heading_ids in train_heading_type_ids)
num_hits / len(train_headings)

In [None]:
# Number of headings each type ID occurs in (each element of
# train_heading_type_ids is a set, so a heading counts an ID at most once).
type_id_frequency = Counter()
for heading_ids in train_heading_type_ids:
    type_id_frequency.update(heading_ids)
type_id_frequency

In [None]:
# Coverage curve: entry k-1 is the proportion of headings that contain at
# least one of the k most frequent type IDs.
# The ranking is computed once. most_common() and most_common(k) share tie
# ordering, so the prefixes match the original per-k sets. The chosen set is
# grown incrementally instead of re-sorting the Counter on every iteration.
ranked_type_ids = [type_id for type_id, _ in type_id_frequency.most_common()]
k_proportion_labelled = []
chosen_type_ids = set()  # after the loop: every type ID (the largest k)
for type_id in ranked_type_ids:
    chosen_type_ids.add(type_id)
    num_headings_labelled = sum(
        1
        for heading_ids in train_heading_type_ids
        if not heading_ids.isdisjoint(chosen_type_ids)
    )
    k_proportion_labelled.append(num_headings_labelled / len(train_headings))
k_proportion_labelled

In [None]:
# Derive the x-axis from the curve itself. `chosen_type_ids` is redefined by
# the "chosen type IDs" cell further down, so using its length here broke the
# plot (x/y length mismatch) whenever cells were re-run out of order.
fig, ax = plt.subplots()
ax.plot(range(1, len(k_proportion_labelled) + 1), k_proportion_labelled)
ax.set(
    xlabel="k (most frequent type IDs kept)",
    ylabel="Proportion of headings labelled",
    title="Heading coverage vs. number of type IDs",
);

Chosen type IDs: the UMLS Disorders semantic group plus Virus (T005) and Bacterium (T007)

In [None]:
# Disorder semantic types (the UMLS "Disorders" semantic group).
disorder_type_ids = {
    "T020",  # Acquired Abnormality
    "T190",  # Anatomical Abnormality
    "T049",  # Cell or Molecular Dysfunction
    "T019",  # Congenital Abnormality
    "T047",  # Disease or Syndrome
    "T050",  # Experimental Model of Disease
    "T033",  # Finding
    "T037",  # Injury or Poisoning
    "T048",  # Mental or Behavioral Dysfunction
    "T191",  # Neoplastic Process
    "T046",  # Pathologic Function
    "T184",  # Sign or Symptom
}
microorganism_type_ids = {"T005", "T007"}  # Virus, Bacterium
chosen_type_ids = disorder_type_ids | microorganism_type_ids
# Proportion of headings with at least one of the chosen type IDs.
num_headings_labelled = sum(
    not heading_ids.isdisjoint(chosen_type_ids)
    for heading_ids in train_heading_type_ids
)
num_headings_labelled / len(train_headings)

## Inspect headings that were missed entirely (no type IDs found)

In [None]:
# Lower-cased headings for which MedCAT found no type IDs at all.
missed_headings = [
    heading.lower()
    for heading, heading_ids in zip(train_headings, train_heading_type_ids)
    if len(heading_ids) == 0
]
# Strip leading list markers ("#", "-", "*", "1.", "2)", surrounding spaces)
# so the same heading written with different bullet styles is counted once.
# The earlier split-on-first-space heuristic dropped the first real word of
# headings written without a space after the marker ("#aki on ckd" -> "on
# ckd") and the first letter of unmarked single-word headings ("fen" -> "en").
leading_marker_re = re.compile(r"^(?:[\W_]+|\d+[.)])+")
formatted_misses = [
    leading_marker_re.sub("", heading).strip() for heading in missed_headings
]
Counter(formatted_misses)

In [None]:
# Drop headings that look like care-logistics sections rather than clinical
# problems: code status, access, communication, FEN (presumably fluids /
# electrolytes / nutrition — confirm).
# Word boundaries stop the old substring match from also discarding clinical
# headings such as "ibuprofen ...", "fentanyl ..." or "tamoxifen ...".
excluded_heading_re = re.compile(r"\b(?:codes?|access|communications?|fen)\b")
train_headings_reduced = [
    heading
    for heading in train_headings
    if not excluded_heading_re.search(heading.lower())
]
len(train_headings_reduced)

In [None]:
# Type IDs for each heading of the reduced set; element i corresponds to
# train_headings_reduced[i].
train_heading_type_ids_reduced = [
    extract_type_ids(heading_text, cat)
    for heading_text in tqdm(train_headings_reduced)
]

Recall on the reduced headings (any type ID)

In [None]:
# Reduced-set recall: headings with at least one type ID of any kind.
num_hits = sum(bool(heading_ids) for heading_ids in train_heading_type_ids_reduced)
num_hits / len(train_headings_reduced)

Recall on the reduced headings (selected type IDs only)

In [None]:
# Reduced-set proportion of headings with at least one chosen type ID.
num_headings_labelled = sum(
    not heading_ids.isdisjoint(chosen_type_ids)
    for heading_ids in train_heading_type_ids_reduced
)
num_headings_labelled / len(train_headings_reduced)