In [None]:
import pickle
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import List

from medcat.utils.preprocess_snomed import Snomed
from spacy.lang.en import English
from spacy.matcher import PhraseMatcher
from spacy.tokens import Span
from spacy.training import biluo_to_iob, offsets_to_biluo_tags
from tqdm.notebook import tqdm

from discharge_summaries.schemas.mimic import Record

In [None]:
# --- Filesystem locations, all resolved from the parent of the working dir ---
DATA_DIR = Path.cwd().parent / "data"
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train.pkl")

# Minute-resolution stamp, embedded in the output file name below.
TIMESTAMP = datetime.now().strftime("%Y_%m_%d_%H_%M")
TRAINING_ANNO_DATASET_PATH = DATA_DIR.joinpath(f"train_anno_{TIMESTAMP}.pkl")
DATASET_NOTE_CUI_CACHE_PATH = DATA_DIR.joinpath("dataset_note_cui_cache.json")

MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)
SNOMED_PATH = Path.cwd().parent.joinpath(
    "data", "SnomedCT_InternationalRF2_PRODUCTION_20230731T120000Z"
)

# --- Run settings ---
RANDOM_SEED = 23
LOG_FILE = "./medcat.log"
SPACY_MODEL = "en_core_web_md"

# Span label that the statistics cell compares against to tell verbatim heading
# matches apart from SNOMED CUI matches.
DIRECT_LABEL = "DIRECT"

In [None]:
# Load the pickled training split and rebuild each entry as a `Record`
# (every entry is unpacked as keyword arguments, so it must be a mapping).
# SECURITY: `pickle.load` can execute arbitrary code stored in the file; only
# point TRAINING_DATASET_PATH at data you produced yourself or otherwise trust.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
# Number of records loaded (shown as the cell output).
len(dataset)

Preprocessing SNOMED CT with MedCAT's `Snomed` utility to build a spaCy phrase matcher

In [None]:
# Point MedCAT's SNOMED preprocessing helper at the release directory.
sowmed = Snomed(str(SNOMED_PATH))
# NOTE(review): `uk_ext` presumably marks the release as a UK extension, yet
# SNOMED_PATH names an International release — confirm this flag is intended.
sowmed.uk_ext = True

In [None]:
# Build the concept table from the SNOMED release. Later cells read its
# "cui", "name", "name_status" and "description_type_ids" columns.
df = sowmed.to_concept_df()
# Preview the first rows.
df.head()

In [None]:
# List the distinct "description_type_ids" values present in the concept table;
# the type-name filter defined further down is checked against these.
df["description_type_ids"].unique()

In [None]:
# Type names to keep: rows of the concept table whose "description_type_ids"
# value is not listed here are filtered out in a later cell.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "organism",
    "physical object",
    "clinical drug",
    "medicinal product form",
    "procedure",
    "product",
}

# Fail fast on a typo or on a release that lacks one of the names. The distinct
# column values are collected once (not once per name), and the failure message
# says exactly which names are missing.
_available_type_names = set(df["description_type_ids"].unique())
_missing_type_names = filter_type_names - _available_type_names
assert not _missing_type_names, (
    "Type names missing from df['description_type_ids']: "
    f"{sorted(_missing_type_names)}"
)

In [None]:
# Keep only rows whose "description_type_ids" value is in filter_type_names.
# `df` is rebound here, so the unfiltered table is not available afterwards.
df = df[df["description_type_ids"].isin(filter_type_names)]
# Row count after filtering.
len(df)

In [None]:
# Restrict the concept table to rows whose "name_status" equals "A".
is_status_a = df["name_status"] == "A"
df_subset = df[is_status_a]

# (number of retained rows, number of distinct CUIs among them)
len(df_subset), len(df_subset["cui"].unique())

In [None]:
# Preview the first ten rows of the retained subset.
df_subset.head(10)

In [None]:
# Tokenizer taken from spaCy's `English` language class; the cells below use it
# for the SNOMED names, the paragraph headings and the physician notes alike.
tokenizer = English().tokenizer

In [None]:
# Phrase matcher keyed on the LOWER token attribute. Every name in df_subset is
# registered under its CUI, so a match's label is the CUI of the matched name.
snomed_matcher = PhraseMatcher(tokenizer.vocab, attr="LOWER")
for cui, cui_rows in tqdm(df_subset.groupby("cui")):
    name_docs = list(tokenizer.pipe(cui_rows["name"]))
    snomed_matcher.add(cui, name_docs)

In [None]:
# Smoke test: run the matcher on a sample phrase and show the label of the
# first match, i.e. the CUI key under which the matched name was registered.
# Raises IndexError if the phrase matches nothing.
matches = snomed_matcher(tokenizer("heart attack"), as_spans=True)
matches[0].label_

In [None]:
def resolve_overlapping_spans(spans: List[Span]) -> List[Span]:
    """Return a non-overlapping subset of ``spans``, ordered by start token.

    Spans are considered longest first (``len(span)``, i.e. token count). A
    span is kept only if it overlaps none of the spans already kept, so the
    result never contains two spans sharing a token. This matters because the
    spans are later converted to exactly one tag per token.

    When two overlapping spans have the same length, the one that appears
    later in ``spans`` wins (e.g. of two identical spans with different
    labels, the later one is kept).

    Args:
        spans: Candidate spans; each must expose ``start``, ``end`` (end
            exclusive) and ``len()``. The list is not modified.

    Returns:
        A new list of the kept spans, sorted by ``start``.
    """
    # Longest first; among equal lengths the later input position comes first.
    # `sorted` (rather than `list.sort`) leaves the caller's list untouched.
    ranked = sorted(
        enumerate(spans),
        key=lambda indexed: (len(indexed[1]), indexed[0]),
        reverse=True,
    )

    kept: List[Span] = []
    for _, candidate in ranked:
        # Half-open intervals are disjoint iff one ends before the other starts.
        # Checking against *every* kept span is what guarantees an overlap-free
        # result, including chains of equal-length spans.
        if all(
            candidate.end <= other.start or candidate.start >= other.end
            for other in kept
        ):
            kept.append(candidate)

    return sorted(kept, key=lambda span: span.start)

In [None]:
# Collect each CUI's names once, instead of re-filtering the whole concept
# table for every matched CUI of every record inside the loop below.
names_by_cui = {
    cui: list(names) for cui, names in df_subset.groupby("cui")["name"]
}

dataset_annotations = []
for doc in tqdm(dataset):
    doc_annotations = []

    # Non-empty headings of this record's discharge-summary paragraphs.
    doc_headings = [
        para.heading for para in doc.discharge_summary.bhc_paragraphs if para.heading
    ]
    # CUIs of every SNOMED name found inside those headings.
    para_cuis = {
        match.label_
        for match in snomed_matcher(tokenizer("\n\n".join(doc_headings)), as_spans=True)
    }

    # Per-record matcher: the headings themselves (labelled DIRECT_LABEL) plus
    # every known name of each CUI mentioned in a heading (labelled with the CUI).
    doc_matcher = PhraseMatcher(tokenizer.vocab, attr="LOWER")
    doc_matcher.add(DIRECT_LABEL, list(tokenizer.pipe(doc_headings)))
    for cui in para_cuis:
        doc_matcher.add(cui, list(tokenizer.pipe(names_by_cui.get(cui, []))))

    # Annotate every physician note, keeping one span per overlapping group.
    for note in doc.physician_notes:
        spacy_note = tokenizer(note.text)
        spans = doc_matcher(spacy_note, as_spans=True)
        resolved_spans = resolve_overlapping_spans(spans)
        doc_annotations.append({"doc": spacy_note, "spans": resolved_spans})
    dataset_annotations.append(doc_annotations)

In [None]:
# Spot check on the first record: the distinct annotated texts found across its
# notes, shown next to the record's sorted non-empty headings.
first_record_spans = [
    anno for note_annos in dataset_annotations[0] for anno in note_annos["spans"]
]
text_annotations = {anno.text for anno in first_record_spans}
headings = sorted(
    para.heading
    for para in dataset[0].discharge_summary.bhc_paragraphs
    if para.heading
)
text_annotations, headings

In [None]:
# Sort every non-empty heading into exactly one outcome bucket, tested in this
# order:
#   cui_hits        - the heading has CUIs and all of them were annotated
#   partial_cui_hit - at least one of the heading's CUIs was annotated
#   strict_match    - the lower-cased heading text was annotated verbatim
#   no_match        - none of the above
num_headings = 0
num_matches = 0
cui_hits, partial_cui_hit, strict_match, no_match = [], [], [], []

for doc, docs_annotations in tqdm(zip(dataset, dataset_annotations)):
    # Every annotated span of this record, across all of its notes.
    doc_spans = [
        anno
        for note_annotations in docs_annotations
        for anno in note_annotations["spans"]
    ]
    doc_anno_cuis = {
        anno.label_ for anno in doc_spans if anno.label_ != DIRECT_LABEL
    }
    doc_anno_direct_text = {
        anno.text.lower() for anno in doc_spans if anno.label_ == DIRECT_LABEL
    }

    for para in doc.discharge_summary.bhc_paragraphs:
        heading = para.heading
        if not heading:
            continue
        num_headings += 1

        para_cuis = {
            match.label_
            for match in snomed_matcher(tokenizer(heading), as_spans=True)
        }
        if para_cuis and para_cuis <= doc_anno_cuis:
            bucket = cui_hits
        elif para_cuis & doc_anno_cuis:
            bucket = partial_cui_hit
        elif heading.lower() in doc_anno_direct_text:
            bucket = strict_match
        else:
            bucket = no_match
        bucket.append(heading)

In [None]:
# Fraction of all counted headings that fell into each outcome bucket;
# total_hit_rate pools the three buckets that count as a match.
matched_headings = len(cui_hits) + len(partial_cui_hit) + len(strict_match)
total_hit_rate = matched_headings / num_headings

cui_hit_rate, partial_cui_hit_rate, strict_match_rate, no_match_rate = (
    len(bucket) / num_headings
    for bucket in (cui_hits, partial_cui_hit, strict_match, no_match)
)

total_hit_rate, cui_hit_rate, partial_cui_hit_rate, strict_match_rate, no_match_rate

In [None]:
# Unmatched headings with their frequencies, most frequent first.
Counter(no_match).most_common()

In [None]:
# Turn each note's resolved spans into one IOB tag per token. Every span is
# emitted with the single entity label "PRIORITY", whatever its matcher label.
dataset_iob_annotations = []
for doc_annotations in dataset_annotations:
    doc_iob_annotations = []
    for note_annotations in doc_annotations:
        note_doc = note_annotations["doc"]
        entity_offsets = [
            (span.start_char, span.end_char, "PRIORITY")
            for span in note_annotations["spans"]
        ]
        biluo_tags = offsets_to_biluo_tags(note_doc, entity_offsets)
        doc_iob_annotations.append(
            {
                "tokens": [token.text for token in note_doc],
                "ner_tags": biluo_to_iob(biluo_tags),
            }
        )
    dataset_iob_annotations.append(doc_iob_annotations)

In [None]:
# Print every token whose tag is not "O" from (at most) the first 100 notes of
# the first record, as a quick visual check of the IOB conversion.
for para_iob_annotations in dataset_iob_annotations[0][:100]:
    for token, tag in zip(
        para_iob_annotations["tokens"], para_iob_annotations["ner_tags"]
    ):
        if tag != "O":
            print(repr(token), tag)

In [None]:
# Persist the IOB-tagged dataset as a pickle at the timestamped output path.
with open(TRAINING_ANNO_DATASET_PATH, "wb") as out_file:
    pickle.dump(dataset_iob_annotations, out_file)