https://htmlpreview.github.io/?https://github.com/CogStack/MedCATtutorials/blob/main/notebooks/specialised/Preprocessing_SNOMED_CT.html

In [None]:
# import medcat
# import os
# import logging
# import medcat.linking.vector_context_model as vcm

# LOGFILE = './medcat.log'
# vcm.logger.level = logging.DEBUG


# def reset_all_logger_handlers(log_file='temp_medcat.log'):
#   # reset logger handlers in case a block is run multiple times
#     medcat.logger.handlers = medcat.logger.handlers[:1] # include the default NullHandler
#     vcm.logger.handlers = []
#     # remove temp log file if it exists
#     if os.path.exists(log_file):
#         os.remove(log_file)

# reset_all_logger_handlers(LOGFILE)
# vcm.logger.addHandler(logging.FileHandler(LOGFILE))

In [None]:
import pickle
from collections import Counter
from pathlib import Path
from typing import Dict, List, Set, Tuple

import rich
from medcat.cat import CAT
from medcat.cdb import CDB
from spacy.tokens import Span
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# Data and model files are resolved relative to the parent of the current
# working directory (the notebook is expected to run from a sub-folder).
DATA_DIR = Path.cwd().parent.joinpath("data")
GT_DATA_PATH = DATA_DIR.joinpath("train.pkl")

# Pretrained MedCAT model pack that is loaded and tuned below.
MODEL_PATH = Path.cwd().parent.joinpath(
    "models",
    "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip",
)

## Load Data

In [None]:
# Load the pickled ground-truth records and validate each one by building a
# Record from its keyword arguments.
# NOTE: unpickling can execute arbitrary code; only load a train.pkl that
# comes from a trusted source.
with open(GT_DATA_PATH, "rb") as in_file:
    gt_dataset = [Record(**record) for record in pickle.load(in_file)]
# Last expression of the cell: shows the number of loaded records.
len(gt_dataset)

## Load Model

In [None]:
# Load the pretrained MedCAT model pack referenced by MODEL_PATH.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
# Pretty-print the loaded model's configuration for inspection.
rich.print(cat.config.dict())

In [None]:
# Loosen the NER settings before annotating headings.
# NOTE(review): the option meanings below are my reading of the MedCAT
# config documentation, not something this notebook shows — confirm.
# Presumably: detect names as short as 2 characters.
cat.config.ner.min_name_len = 2
# Presumably: no minimum length below which a name must be upper-case.
cat.config.ner.upper_case_limit_len = 0
# Presumably: removes the minimum word length for spell checking.
cat.config.general.spell_check_len_limit = 0

In [None]:
# Switch off the "Status" component of the spaCy pipeline so it is not run
# while annotating headings.
# NOTE(review): "Status" is presumably a MetaCAT component of this model
# pack — confirm. The object returned by disable_pipes is discarded here.
cat.pipe.spacy_nlp.disable_pipes(["Status"])

## Helper functions

In [None]:
def annotate_dataset(dataset: List[Record], cat: CAT) -> List[List[Span]]:
    """Annotate the heading of every BHC paragraph in ``dataset`` with ``cat``.

    Args:
        dataset: Records whose ``discharge_summary.bhc_paragraphs`` are read.
        cat: Callable MedCAT model; ``cat(text).ents`` is collected.

    Returns:
        One list per record, holding one entry per BHC paragraph: the
        entities found in that paragraph's heading, or an empty tuple when
        the paragraph has no heading (the model is not called in that case).
    """
    annotated_records = []
    # tqdm shows progress over records, as annotation is the slow step.
    for record in tqdm(dataset):
        heading_ents = []
        for paragraph in record.discharge_summary.bhc_paragraphs:
            if paragraph.heading:
                heading_ents.append(cat(paragraph.heading).ents)
            else:
                heading_ents.append(())
        annotated_records.append(heading_ents)
    return annotated_records


def calc_hit_rate(dataset: List[Record], dataset_ents: List[List[Span]]) -> float:
    """Return the fraction of paragraph headings that received an annotation.

    Args:
        dataset: Records whose ``discharge_summary.bhc_paragraphs`` are
            inspected; only paragraphs with a non-empty heading are counted
            in the denominator.
        dataset_ents: Per record, per paragraph, the entities found in the
            heading; a non-empty entry counts as a hit.

    Returns:
        ``annotated headings / headings``, or ``0.0`` when the dataset has no
        headings at all (instead of raising ``ZeroDivisionError``).
    """
    num_headings = sum(
        1
        for doc in dataset
        for para in doc.discharge_summary.bhc_paragraphs
        if para.heading
    )
    # Nothing to hit: avoid dividing by zero on empty/heading-less datasets.
    if not num_headings:
        return 0.0

    num_annotated_headings = sum(
        1 for doc_ents in dataset_ents for para_ents in doc_ents if para_ents
    )
    return num_annotated_headings / num_headings


def get_top_n_missed_headings(
    dataset: List[Record], dataset_ents: List[List[Span]], n=10
):
    """Return the ``n`` most frequent headings that received no entity.

    Records are paired with their entity lists, and paragraphs with their
    per-heading entities, positionally. Headings are lower-cased before
    counting; paragraphs without a heading are ignored.

    Returns:
        A list of ``(lower-cased heading, count)`` tuples, most common first.
    """
    miss_counts = Counter()
    for record, record_ents in zip(dataset, dataset_ents):
        paragraphs = record.discharge_summary.bhc_paragraphs
        for paragraph, heading_ents in zip(paragraphs, record_ents):
            # Skip paragraphs with no heading, or whose heading was annotated.
            if not paragraph.heading or heading_ents:
                continue
            miss_counts[paragraph.heading.lower()] += 1
    return miss_counts.most_common(n)

## Off the shelf performance

In [None]:
# Baseline: annotate all paragraph headings before any concepts are added.
dataset_ents = annotate_dataset(gt_dataset, cat)

In [None]:
# Baseline fraction of headings with at least one detected entity.
calc_hit_rate(gt_dataset, dataset_ents)

In [None]:
# Ten most frequent (lower-cased) headings the baseline model did not annotate.
get_top_n_missed_headings(gt_dataset, dataset_ents)

## Add commonly missed concepts to model

In [None]:
def collect_missed_heading_and_bhcs(
    dataset: List[Record], dataset_ents: List[List[Span]]
) -> List[Tuple[str, str]]:
    """Pair every un-annotated heading with the full BHC text of its record.

    Records are matched to their entity lists, and paragraphs to their
    per-heading entities, positionally.

    Returns:
        ``(heading, bhc)`` tuples, one per paragraph that has a heading but
        no entities. The heading keeps its original casing.
    """
    missed = []
    for record, record_ents in zip(dataset, dataset_ents):
        summary = record.discharge_summary
        for paragraph, heading_ents in zip(summary.bhc_paragraphs, record_ents):
            # Only headings that exist and were not annotated are of interest.
            if not paragraph.heading or heading_ents:
                continue
            missed.append((paragraph.heading, record.discharge_summary.bhc))
    return missed


def add_missing_heading_concepts_to_medcat_model(
    missed_heading_and_bhcs: List[Tuple[str, str]],
    missing_name_to_mapped_cui: Dict[str, str],
    cat: CAT,
):
    """Teach ``cat`` new names for existing CUIs using the BHC text as context.

    For every ``(heading, bhc)`` pair whose lower-cased heading is a key of
    ``missing_name_to_mapped_cui``, the BHC is tokenised and each token whose
    lower-cased text equals the lower-cased heading is passed to
    ``cat.add_and_train_concept`` as a new name for the mapped CUI.

    Args:
        missed_heading_and_bhcs: ``(heading, bhc text)`` pairs to learn from.
        missing_name_to_mapped_cui: Heading name -> CUI. Lookups use the
            lower-cased heading, so keys need to be lower-case to match.
        cat: MedCAT model that is modified in place.

    Raises:
        AssertionError: If any key does not tokenise to exactly one token.
    """
    # Check assumption that all added headings are 1 token long
    for heading_name in missing_name_to_mapped_cui:
        assert len(cat.pipe.spacy_nlp.make_doc(heading_name)) == 1

    for missed_heading, bhc in tqdm(missed_heading_and_bhcs):
        heading_lower = missed_heading.lower()
        mapped_cui = missing_name_to_mapped_cui.get(heading_lower, "")
        if not mapped_cui:
            # Heading is not one of the names we want to add.
            continue

        bhc_doc = cat.pipe.spacy_nlp.make_doc(bhc)
        for token in bhc_doc:
            if token.text.lower() != heading_lower:
                continue
            # Train on every occurrence, keeping the token's original casing.
            cat.add_and_train_concept(
                mapped_cui,
                token.text,
                ontologies={"Added"},
                type_ids=cat.cdb.cui2type_ids[mapped_cui],
                spacy_doc=bhc_doc,
                spacy_entity=[token],
            )

In [None]:
# Frequently missed heading names mapped to the CUI they should link to.
# Keys are lower-case single tokens: the training helper looks headings up
# by their lower-cased text and asserts each key is exactly one token.
missing_name_to_mapped_cui = {
    "code": "365870005",
    "fen": "300893006",
    "rhythm": "251149006",
    "ppx": "169443000",
    "dm": "73211009",
    "dispo": "726542003",
    "transaminitis": "160931000119108",
    "comm": "263536004",
}

# Print each name next to the CUI's preferred name to eyeball the mapping.
for missed_heading, mapped_cui in missing_name_to_mapped_cui.items():
    print(missed_heading, cat.cdb.cui2preferred_name[mapped_cui])

In [None]:
# Gather (heading, BHC text) pairs for headings the model missed, then add
# the mapped names to the model, training on their occurrences in the BHCs.
missed_heading_and_bhcs = collect_missed_heading_and_bhcs(gt_dataset, dataset_ents)
add_missing_heading_concepts_to_medcat_model(
    missed_heading_and_bhcs, missing_name_to_mapped_cui, cat
)

Sanity check that training worked: each added heading should now yield exactly one entity, in both lower and upper case.

In [None]:
# Every added name must now be recognised as exactly one entity, regardless
# of whether it is written in lower or upper case.
for missed_heading in missing_name_to_mapped_cui:
    for cased_heading in (missed_heading.lower(), missed_heading.upper()):
        assert len(cat(cased_heading).ents) == 1

### Evaluate improvement

In [None]:
# Re-annotate all headings now that the extra names have been added.
dataset_ents = annotate_dataset(gt_dataset, cat)

In [None]:
# Hit rate after adding the missing names (compare with the baseline above).
calc_hit_rate(gt_dataset, dataset_ents)

In [None]:
# Headings that are still missed most often after adding the names.
get_top_n_missed_headings(gt_dataset, dataset_ents)

### Now with filtering...

In [None]:
def cui_to_type_id_name(cui: str, cdb: CDB) -> Set[str]:
    """Return the names of all type ids attached to ``cui`` in ``cdb``.

    Type ids are read from ``cdb.cui2type_ids[cui]`` and translated to names
    through ``cdb.addl_info["type_id2name"]``.
    """
    type_names = set()
    for type_id in cdb.cui2type_ids[cui]:
        type_names.add(cdb.addl_info["type_id2name"][type_id])
    return type_names

In [None]:
# Count how often each combination of type names occurs per heading and show
# the 20 most common. A heading's combination is the sorted tuple of the type
# names of all its entities; headings without entities give an empty tuple.
Counter(
    tuple(
        sorted(
            type_name
            for span in heading_spans
            for type_name in cui_to_type_id_name(span._.cui, cat.cdb)
        )
    )
    for record_spans in dataset_ents
    for heading_spans in record_spans
).most_common(20)

In [None]:
# Type names considered relevant for a heading annotation.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "procedure",
    "situation",
    "observable entity",
    "attribute",
    "substance",
    "organism",
    "disposition",
}

# For the first 100 records, print every annotated heading none of whose
# entities has a type name in the filter, to see what the filter would drop.
for doc_ents in dataset_ents[:100]:
    for para_ents in doc_ents:
        if not para_ents:
            continue
        type_names = set().union(
            *(cui_to_type_id_name(ent._.cui, cat.cdb) for ent in para_ents)
        )
        if filter_type_names.isdisjoint(type_names):
            print(
                para_ents[0].doc,
                [(ent, cui_to_type_id_name(ent._.cui, cat.cdb)) for ent in para_ents],
            )

In [None]:
# Same nesting as dataset_ents (record -> paragraph -> entities), but keeping
# only entities that have at least one type name in filter_type_names.
filtered_dataset_ents = [
    [
        [
            span
            for span in heading_spans
            if not filter_type_names.isdisjoint(
                cui_to_type_id_name(span._.cui, cat.cdb)
            )
        ]
        for heading_spans in record_spans
    ]
    for record_spans in dataset_ents
]

In [None]:
# Hit rate when only entities with an allowed type name count as hits.
calc_hit_rate(gt_dataset, filtered_dataset_ents)

In [None]:
# Headings most often left without any entity once the type filter is applied.
get_top_n_missed_headings(gt_dataset, filtered_dataset_ents)

The following two cells don't work, so they are commented out for now.

In [None]:
# missing_name_to_mapped_cui = {
#     "coronaries" : "53741008",
#     "pump" : "739122008",
#     "contact" : "263536004",
# }

# for missed_heading, mapped_cui in missing_name_to_mapped_cui.items():
#     print(missed_heading, cat.cdb.cui2preferred_name[mapped_cui])

In [None]:
# missed_heading_and_bhcs = collect_missed_heading_and_bhcs(gt_dataset, filtered_dataset_ents)
# add_missing_heading_concepts_to_medcat_model(
# missed_heading_and_bhcs, missing_name_to_mapped_cui, cat
# )

In [None]:
# Invert the model's type-id -> name table so type names can be looked up.
type_name_to_id = dict(
    (name, type_id) for type_id, name in cat.cdb.addl_info["type_id2name"].items()
)

# Translate the allowed type names into type ids, then collect every CUI
# that belongs to any of those types.
type_ids_filter = [type_name_to_id[type_name] for type_name in filter_type_names]
cui_filters = set().union(
    *(cat.cdb.addl_info["type_id2cuis"][type_id] for type_id in type_ids_filter)
)
# Install the collected CUIs as the model's linking filter.
cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
# Re-annotate all headings with the CUI filter set on the model's config.
dataset_ents = annotate_dataset(gt_dataset, cat)

In [None]:
# Hit rate of the model with the CUI filter in place.
calc_hit_rate(gt_dataset, dataset_ents)

In [None]:
# Headings the filtered model still misses most often.
get_top_n_missed_headings(gt_dataset, dataset_ents)

## Save Model

In [None]:
# Save the tuned model as a new model pack next to the original one.
# NOTE(review): the suffix is concatenated to the original stem without a
# separator, so the pack name passed here ends in
# "...25be3857ba34bdd5mimic_tuned" — confirm that is intended before
# changing it, since other code may load the pack by this name.
cat.create_model_pack(MODEL_PATH.parent, MODEL_PATH.stem + "mimic_tuned")