In [None]:
import pickle
import re
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import List, Tuple

from spacy.lang.en import English
from spacy.matcher import PhraseMatcher
from spacy.training import biluo_to_iob, offsets_to_biluo_tags
from tqdm.notebook import tqdm
from transformers import AutoTokenizer

from discharge_summaries.preprocessing.preprocess_snomed import Snomed
from discharge_summaries.schemas.mimic import Record

In [None]:
# Every location is derived from the parent of the current working directory.
PROJECT_ROOT = Path.cwd().parent
DATA_DIR = PROJECT_ROOT / "data"

# Input: pickled training records.
TRAINING_DATASET_PATH = DATA_DIR.joinpath("train.pkl")
# Output: annotations are written to a minute-resolution timestamped file,
# so separate runs do not overwrite each other.
TIMESTAMP = datetime.now().strftime("%Y_%m_%d_%H_%M")
TRAINING_ANNO_DATASET_PATH = DATA_DIR.joinpath(f"train_anno_{TIMESTAMP}.pkl")
SNOMED_PATH = DATA_DIR.joinpath("SnomedCT_InternationalRF2_PRODUCTION_20230731T120000Z")

# Prefix used for labels that match a heading's literal text rather than a CUI.
DIRECT_LABEL = "DIRECT"
# Token budget per chunk and the Hugging Face tokenizer used to measure it.
MAX_SEGMENT_TOKEN_LENGTH = 400
HF_MODEL_NAME = "roberta-base"

# NOTE(review): the constants below are not referenced by the cells visible
# in this part of the notebook — confirm they are still needed.
DATASET_NOTE_CUI_CACHE_PATH = DATA_DIR.joinpath("dataset_note_cui_cache.json")
MODEL_PATH = PROJECT_ROOT.joinpath(
    "models", "mc_modelpack_snomed_int_16_mar_2022_25be3857ba34bdd5.zip"
)
RANDOM_SEED = 23
LOG_FILE = "./medcat.log"
SPACY_MODEL = "en_core_sci_md"

In [None]:
# Load the pickled training records and validate each one into a Record.
# NOTE: pickle.load can execute arbitrary code embedded in the file, so only
# point TRAINING_DATASET_PATH at a file from a trusted source.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
# Display the number of loaded records.
len(dataset)

Preprocessing SNOMED CT for MedCAT

In [None]:
# Build the SNOMED helper from the release directory configured in SNOMED_PATH.
# NOTE(review): `sowmed` looks like a typo for `snomed`; the name is also used
# by the following cell, so rename both together if it is changed.
sowmed = Snomed(str(SNOMED_PATH))
# NOTE(review): uk_ext is switched on although SNOMED_PATH names an
# International release folder — confirm the flag is intended here.
sowmed.uk_ext = True

In [None]:
# Export the SNOMED concepts as a table (`df`) and preview its first rows.
df = sowmed.to_concept_df()
df.head()

In [None]:
# List the distinct values of the "description_type_ids" column; the filter
# defined in the next cell is chosen from these.
df["description_type_ids"].unique()

In [None]:
# Values of the "description_type_ids" column that are kept when filtering `df`.
filter_type_names = {
    "disorder",
    "finding",
    "morphologic abnormality",
    "organism",
    "physical object",
    "clinical drug",
    "medicinal product form",
    "procedure",
    "product",
}
# Fail fast if a requested name does not occur in the data (e.g. a typo).
# The distinct values are computed once rather than once per name, and the
# assertion message names the offending entries.
available_type_names = set(df["description_type_ids"].unique())
missing_type_names = filter_type_names - available_type_names
assert (
    not missing_type_names
), f"Unknown description types: {sorted(missing_type_names)}"

In [None]:
# Keep only rows whose "description_type_ids" is one of filter_type_names.
# NOTE: this rebinds `df`, so the unfiltered table is no longer reachable
# under that name once this cell has run.
df = df[df["description_type_ids"].isin(filter_type_names)]
# Display the number of remaining rows.
len(df)

In [None]:
# Restrict the table to rows whose "name_status" is "A".
# NOTE(review): what status "A" denotes is defined by the Snomed preprocessing
# helper, not in this notebook — confirm it selects the intended names.
# (An earlier variant filtered on the 'finding' and 'disorder' types instead.)
df_subset = df[df["name_status"] == "A"]
# Display the row count next to the number of distinct CUIs.
len(df_subset), len(df_subset["cui"].unique())

In [None]:
# Preview the first ten rows of the filtered name table.
df_subset.head(10)

In [None]:
# Take only the tokenizer of a blank English pipeline. Its vocab is shared by
# the phrase matchers built below, and it tokenizes both names and note text.
tokenizer_spacy = English().tokenizer

In [None]:
# Phrase matcher over every name in df_subset, with the CUI as the match key.
# Patterns are compared on the LOWER token attribute (lower-cased text).
snomed_matcher = PhraseMatcher(tokenizer_spacy.vocab, attr="LOWER")
for cui, group_df in tqdm(df_subset.groupby("cui")):
    # All names of one concept are registered under that concept's CUI.
    name_docs = list(tokenizer_spacy.pipe(group_df["name"]))
    snomed_matcher.add(cui, name_docs)

In [None]:
# Sanity check: run the SNOMED matcher on a known phrase and display the key
# (CUI) of the first match.
doc = tokenizer_spacy("heart attack")
matches = snomed_matcher(doc)
# The first element of a match is its id; the vocab's string store turns it
# back into the key it was added under. Raises IndexError if nothing matched.
tokenizer_spacy.vocab.strings[matches[0][0]]

In [None]:
def split_note_into_chunks(
    note_text: str, max_segment_token_length: int, tokenizer: AutoTokenizer
) -> List[str]:
    """Split a note into non-empty text chunks that stay under a token budget.

    The note is first cut into sections at blank lines; a chunk never spans
    two sections. Each section is cut into lines at a newline that is followed
    by a character other than a space or a lowercase letter, or at whitespace
    that follows a full stop. Lines are then packed greedily into a chunk
    while the running token count stays strictly below
    ``max_segment_token_length``.

    Token counts are taken per line as ``len(tokenizer(line)["input_ids"])``,
    so any tokens the tokenizer adds on each call are counted once per line.

    Args:
        note_text: Full text of the note.
        max_segment_token_length: Token budget for a chunk.
        tokenizer: Callable returning a mapping with an ``"input_ids"`` entry
            for a piece of text.

    Returns:
        The chunks in document order, stripped of surrounding whitespace.
        Chunks that would be empty (blank sections, or flushing an empty
        buffer) are omitted.

    Raises:
        ValueError: If a single line has more tokens than
            ``max_segment_token_length``; the offending line is the message.
    """
    chunks: List[str] = []

    def flush(lines: List[str]) -> None:
        # Join the buffered lines and keep the chunk only if it has content.
        text = "\n".join(lines).strip()
        if text:
            chunks.append(text)

    for section in note_text.split("\n\n"):
        chunk_lines: List[str] = []
        chunk_token_length = 0
        for line in re.split("\n(?=[^ a-z])|(?<=\\.)\\s", section):
            line_token_length = len(tokenizer(line)["input_ids"])
            if line_token_length > max_segment_token_length:
                raise ValueError(line)
            if chunk_token_length + line_token_length < max_segment_token_length:
                chunk_lines.append(line)
                chunk_token_length += line_token_length
            else:
                # Budget reached: close the current chunk and start a new one
                # with this line.
                flush(chunk_lines)
                chunk_lines = [line]
                chunk_token_length = line_token_length
        # Whatever is left at the end of the section forms the final chunk.
        flush(chunk_lines)
    return chunks

In [None]:
def resolve_overlapping_matches(
    matches: List[Tuple[str, int, int]]
) -> List[Tuple[str, int, int]]:
    """Reduce a list of matches to a non-overlapping subset, preferring long spans.

    Each match is a ``(label, start, end)`` tuple with ``end`` exclusive.
    Matches are considered longest first; a match is kept only if it does not
    overlap any match that has already been kept. Among matches of equal
    length the one starting earlier wins, and for identical spans the one
    appearing first in ``matches`` wins.

    Args:
        matches: Candidate matches. The list is not modified.

    Returns:
        The kept matches ordered by start position. No two returned matches
        overlap.
    """
    # sorted() works on a copy, so the caller's list keeps its order.
    # Negative length sorts longest first; start position breaks ties.
    longest_first = sorted(
        matches, key=lambda match: (match[1] - match[2], match[1])
    )

    kept: List[Tuple[str, int, int]] = []
    for match in longest_first:
        _, start, end = match
        # Two half-open spans are disjoint when one ends before the other begins.
        if all(
            end <= kept_start or start >= kept_end for _, kept_start, kept_end in kept
        ):
            kept.append(match)

    return sorted(kept, key=lambda match: match[1])

In [None]:
# Fast Hugging Face tokenizer for HF_MODEL_NAME. In the cells shown here it is
# only passed to split_note_into_chunks, which uses it to count tokens.
# NOTE(review): add_prefix_space=True is presumably set for word-level
# (pre-tokenized) input in downstream training — confirm.
tokenizer_hf = AutoTokenizer.from_pretrained(
    HF_MODEL_NAME, add_prefix_space=True, use_fast=True
)

In [None]:
# CUI -> list of names, built once up front so the per-document matcher below
# does not have to rescan df_subset for every CUI of every document.
names_by_cui = {
    concept_id: concept_rows["name"].tolist()
    for concept_id, concept_rows in df_subset.groupby("cui")
}

# One entry per document; each entry is a list with one
# {"tokens": [...], "ner_tags": [...]} dict per note chunk.
dataset_annotations = []
for doc in tqdm(dataset):
    doc_annotations = []

    # Non-empty headings of this document's BHC paragraphs.
    doc_headings = [
        para.heading for para in doc.discharge_summary.bhc_paragraphs if para.heading
    ]
    # CUIs whose names occur somewhere in those headings.
    para_cuis = {
        match.label_
        for match in snomed_matcher(
            tokenizer_spacy("\n\n".join(doc_headings)), as_spans=True
        )
    }

    # Document-specific matcher: the literal heading texts (labelled with the
    # DIRECT prefix) plus every name of the CUIs found in the headings.
    doc_matcher = PhraseMatcher(tokenizer_spacy.vocab, attr="LOWER")
    for heading in doc_headings:
        doc_matcher.add(
            f"{DIRECT_LABEL}-{heading}", list(tokenizer_spacy.pipe([heading]))
        )
    for cui in para_cuis:
        doc_matcher.add(
            cui, list(tokenizer_spacy.pipe(names_by_cui.get(cui, [])))
        )

    for note in doc.physician_notes:
        for chunk in split_note_into_chunks(
            note.text, MAX_SEGMENT_TOKEN_LENGTH, tokenizer_hf
        ):
            spacy_chunk = tokenizer_spacy(chunk)
            matches = doc_matcher(spacy_chunk)
            resolved_matches = resolve_overlapping_matches(matches)
            # Convert token-index matches into (start_char, end_char, label)
            # offsets, slicing each span only once.
            offsets = []
            for match_id, start_token, end_token in resolved_matches:
                span = spacy_chunk[start_token:end_token]
                offsets.append(
                    (
                        span.start_char,
                        span.end_char,
                        tokenizer_spacy.vocab.strings[match_id],
                    )
                )
            iob_annotations = biluo_to_iob(offsets_to_biluo_tags(spacy_chunk, offsets))
            tokens = [token.text for token in spacy_chunk]
            doc_annotations.append({"tokens": tokens, "ner_tags": iob_annotations})
    dataset_annotations.append(doc_annotations)

In [None]:
# Spot check on the first document: the distinct tokens that received a
# non-"O" tag, shown next to that document's sorted non-empty headings.
first_doc_chunks = dataset_annotations[0]
text_annotations = set()
for chunk_annos in first_doc_chunks:
    for token, tag in zip(chunk_annos["tokens"], chunk_annos["ner_tags"]):
        if tag != "O":
            text_annotations.add(token)

headings = sorted(
    para.heading
    for para in dataset[0].discharge_summary.bhc_paragraphs
    if para.heading
)
text_annotations, headings

In [None]:
# Per-heading evaluation: for every non-empty heading, check whether the
# document's annotations cover it, and file the heading into one bucket.
num_headings = 0
num_matches = 0  # NOTE(review): not updated or read in the cells shown here.
cui_hits = []  # every CUI found in the heading was annotated in the notes
partial_cui_hit = []  # at least one, but not all, of the heading's CUIs
strict_match = []  # the heading's literal text was annotated (DIRECT label)
no_match = []  # none of the above

direct_prefix = f"{DIRECT_LABEL}-"
# zip() has no length, so the total is passed explicitly for the progress bar.
for doc, docs_annotations in tqdm(
    zip(dataset, dataset_annotations),
    total=min(len(dataset), len(dataset_annotations)),
):
    # Collect the labels annotated anywhere in this document in one pass,
    # separating direct heading-text labels from CUI labels.
    doc_anno_cuis = set()
    doc_anno_direct_text = set()
    for chunk_annotations in docs_annotations:
        for tag in chunk_annotations["ner_tags"]:
            if tag == "O":
                continue
            # Drop the two-character IOB prefix ("B-" / "I-").
            label = tag[len("B-") :]
            if label.startswith(direct_prefix):
                doc_anno_direct_text.add(label[len(direct_prefix) :].lower())
            else:
                doc_anno_cuis.add(label)

    for para in doc.discharge_summary.bhc_paragraphs:
        if not para.heading:
            continue
        num_headings += 1
        para_cuis = {
            tokenizer_spacy.vocab.strings[label_id]
            for label_id, _, _ in snomed_matcher(tokenizer_spacy(para.heading))
        }
        if para_cuis and para_cuis.issubset(doc_anno_cuis):
            cui_hits.append(para.heading)
        elif para_cuis & doc_anno_cuis:
            partial_cui_hit.append(para.heading)
        elif para.heading.lower() in doc_anno_direct_text:
            strict_match.append(para.heading)
        else:
            no_match.append(para.heading)

In [None]:
# Fraction of all counted headings that fell into each outcome bucket.
cui_hit_rate = len(cui_hits) / num_headings
partial_cui_hit_rate = len(partial_cui_hit) / num_headings
strict_match_rate = len(strict_match) / num_headings
no_match_rate = len(no_match) / num_headings

# Any kind of hit: full CUI, partial CUI, or literal heading text.
num_hits = len(cui_hits) + len(partial_cui_hit) + len(strict_match)
total_hit_rate = num_hits / num_headings

total_hit_rate, cui_hit_rate, partial_cui_hit_rate, strict_match_rate, no_match_rate

In [None]:
# Headings that ended up in the no_match bucket, most frequent first.
Counter(no_match).most_common()

In [None]:
# Persist the per-document annotations to the timestamped output file and
# display the file name that was written.
with TRAINING_ANNO_DATASET_PATH.open("wb") as out_file:
    pickle.dump(dataset_annotations, out_file)
TRAINING_ANNO_DATASET_PATH.name

In [None]:
# Percentage of all tags across the dataset that are not "O".
num_tags = 0
num_annos = 0
for doc_annos in dataset_annotations:
    for chunk_annos in doc_annos:
        chunk_tags = chunk_annos["ner_tags"]
        num_tags += len(chunk_tags)
        # Each True counts as 1.
        num_annos += sum(tag != "O" for tag in chunk_tags)
num_annos / num_tags * 100

In [None]:
# Percentage of chunks that contain at least one non-"O" tag.
# A chunk with no tags at all is counted as unannotated: the previous
# set(...) != {"O"} test was also true for an empty tag list.
num_annos_chunks = sum(
    1
    for doc_annos in dataset_annotations
    for chunk_annos in doc_annos
    if any(tag != "O" for tag in chunk_annos["ner_tags"])
)
num_chunks = sum(len(doc_annos) for doc_annos in dataset_annotations)
num_annos_chunks / num_chunks * 100