In [None]:
from medcat.cat import CAT

In [None]:
import os
import pickle
from collections import Counter
from pathlib import Path

import numpy as np
from dotenv import load_dotenv
from sklearn.model_selection import train_test_split
from tqdm.notebook import tqdm

from discharge_summaries.schemas.medcat import MedCATSpan
from discharge_summaries.schemas.mimic import Record

In [None]:
# Populate the process environment (python-dotenv) before reading it below.
load_dotenv()

# UMLS REST API settings; the key is None when the variable is not set.
UMLS_API_KEY = os.getenv("UMLS_API_KEY")
UMLS_BASE_URL = "https://uts-ws.nlm.nih.gov/rest"

# All paths hang off the parent of the current working directory.
DATA_DIR = Path.cwd().parent.joinpath("data")
CLEAN_DATASET_SAVE_PATH = DATA_DIR.joinpath("clean_df.pkl")
MODEL_PATH = Path.cwd().parent.joinpath("models", "umls_sm_pt2ch_533bab5115c6c2d6.zip")

# Seed passed to the train/test split so the split is reproducible.
RANDOM_SEED = 23

In [None]:
# NOTE: pickle.load can execute arbitrary code -- only load files you trust.
# Each unpickled item is expanded as keyword arguments into a Record.
with CLEAN_DATASET_SAVE_PATH.open("rb") as in_file:
    dataset = [Record(**fields) for fields in pickle.load(in_file)]
len(dataset)

In [None]:
# Fraction of records held out as the test split, and the cap on how many
# training records the cells below are run on. Named here so they can be
# tuned in one place instead of being buried as literals.
TEST_FRACTION = 0.5
MAX_TRAIN_RECORDS = 100

# Seeded split; the full training half is kept under its own name so that
# `train_dataset` is not rebound to a slice of itself.
train_split, test_dataset = train_test_split(
    dataset, test_size=TEST_FRACTION, random_state=RANDOM_SEED
)
train_dataset = train_split[:MAX_TRAIN_RECORDS]
len(train_dataset), len(test_dataset)

In [None]:
# Load the MedCAT model pack from MODEL_PATH; `cat` is the annotator every
# later cell passes to run_medcat_on_dataset.
cat = CAT.load_model_pack(MODEL_PATH)

In [None]:
def run_medcat_on_dataset(dataset: list[str], cat: CAT) -> list[list[MedCATSpan]]:
    """Annotate every text in ``dataset`` with MedCAT and collect the spans per text.

    Args:
        dataset: Texts to annotate, one string per document.
        cat: MedCAT model; it is called directly on each text.

    Returns:
        One list of ``MedCATSpan`` objects per input text, in input order.
        A text for which ``cat`` returns a falsy result contributes an
        empty list.
    """
    spans_per_text = []
    for text in tqdm(dataset):
        doc = cat(text)
        if doc:
            spans = [MedCATSpan.from_spacy_span(ent, cat) for ent in doc.ents]
        else:
            # Nothing usable came back for this text.
            spans = []
        spans_per_text.append(spans)
    return spans_per_text

## What is in a BHC (brief hospital course)?

In [None]:
# Annotate the `discharge_summary.bhc` text of every training record, then
# show the spans found in the first one.
bhc_entities = run_medcat_on_dataset(
    dataset=[record.discharge_summary.bhc for record in train_dataset],
    cat=cat,
)
bhc_entities[0]

In [None]:
# Flatten the type ids of every BHC entity across all documents and tally
# how often each one occurs.
type_ids = [
    tid for spans in bhc_entities for span in spans for tid in span.type_ids
]
Counter(type_ids)

In [None]:
# Per document: keep only the entities whose type ids include "T047"
# (presumably the UMLS semantic type "Disease or Syndrome" -- confirm),
# then reduce them to the set of distinct CUIs.
bhc_entities_047 = [
    [span for span in spans if "T047" in span.type_ids] for spans in bhc_entities
]
bhc_entities_047_cui = [{span.cui for span in spans} for spans in bhc_entities_047]
bhc_entities_047_cui[0]

## Prefixes

In [None]:
# One newline-joined string of `discharge_summary.prefixes` per training record.
prefixes = ["\n".join(sample.discharge_summary.prefixes) for sample in train_dataset]
# Preview only a few entries: displaying the whole list floods the notebook
# with raw record text, which bloats the file and is risky to share.
prefixes[:3]

In [None]:
# Annotate the joined prefixes, then tally the type ids of every entity found.
prefix_entities = run_medcat_on_dataset(dataset=prefixes, cat=cat)
Counter(
    tid for spans in prefix_entities for span in spans for tid in span.type_ids
)

In [None]:
# Same reduction as for the BHC entities: keep entities whose type ids
# include "T047", then collapse each document to its set of distinct CUIs.
prefix_entities_047 = [
    [span for span in spans if "T047" in span.type_ids] for spans in prefix_entities
]
prefix_entities_047_cui = [
    {span.cui for span in spans} for spans in prefix_entities_047
]
prefix_entities_047_cui[0]

## Notes entities

In [None]:
# Collect every CUI the model's concept database lists under the wanted type ids.
type_ids_filter = ["T047"]
cui_filters = set().union(
    *(cat.cdb.addl_info["type_id2cuis"][type_id] for type_id in type_ids_filter)
)
# WARNING: this mutates the loaded model's config in place, so every later
# call to `cat` (including re-runs of earlier cells) sees this filter.
# Presumably MedCAT then links only these CUIs -- confirm against MedCAT docs.
cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
# One newline-joined string of all physician note texts per training record,
# annotated with `cat` (whose CUI filter was assigned in the cell above).
notes = ["\n".join(n.text for n in record.physician_notes) for record in train_dataset]
notes_entities = run_medcat_on_dataset(dataset=notes, cat=cat)
notes_entities[0]

In [None]:
# Collapse each document's note entities to the set of distinct CUIs.
notes_entities_cui = [{span.cui for span in spans} for spans in notes_entities]
notes_entities_cui[0]

In [None]:
def calculate_precision_recall(actual_set, predicted_set) -> tuple[float, float]:
    """Compute set-based precision and recall of ``predicted_set`` against ``actual_set``.

    Args:
        actual_set: Reference items; any iterable of hashable values
            (duplicates are ignored).
        predicted_set: Predicted items; any iterable of hashable values
            (duplicates are ignored).

    Returns:
        ``(precision, recall)``, always as floats. Conventions for empty inputs:
        both empty -> ``(1.0, 1.0)``; only ``actual_set`` empty -> ``(0.0, 1.0)``;
        only ``predicted_set`` empty -> ``(1.0, 0.0)``.
    """
    # Coerce so lists/tuples/frozensets work as well as plain sets.
    actual = set(actual_set)
    predicted = set(predicted_set)

    # Guard clauses for the degenerate cases, which would otherwise divide by zero.
    if not actual and not predicted:
        return 1.0, 1.0
    if not actual:
        return 0.0, 1.0
    if not predicted:
        return 1.0, 0.0

    true_positives = len(actual & predicted)

    # TP + FP == len(predicted) and TP + FN == len(actual), so the false
    # positive/negative counts never need to be computed separately.
    precision = true_positives / len(predicted)
    recall = true_positives / len(actual)

    return precision, recall

In [None]:
# Mean (precision, recall) over documents, treating each document's prefix
# T047 CUIs as the reference and its notes CUIs as the prediction.
prefix_precision_recall = np.mean(
    [
        calculate_precision_recall(reference_cuis, notes_cuis)
        for reference_cuis, notes_cuis in zip(
            prefix_entities_047_cui, notes_entities_cui
        )
    ],
    axis=0,
)
prefix_precision_recall

In [None]:
# Mean (precision, recall) over documents, treating each document's BHC
# T047 CUIs as the reference and its notes CUIs as the prediction.
bhc_precision_recall = np.mean(
    [
        calculate_precision_recall(reference_cuis, notes_cuis)
        for reference_cuis, notes_cuis in zip(bhc_entities_047_cui, notes_entities_cui)
    ],
    axis=0,
)
bhc_precision_recall

In [None]:
# Inspect the "T047"-filtered BHC entities of the first training record.
bhc_entities_047[0]

In [None]:
# Spot check: does "neutropenia" occur (case-insensitively) in the first
# record's joined physician notes?
"neutropenia" in notes[0].lower()

In [None]:
# Distinct names of the first record's "T047" BHC entities, sorted.
sorted({entity.name for entity in bhc_entities_047[0]})

In [None]:
# Print the raw `bhc` text of the first training record.
print(train_dataset[0].discharge_summary.bhc)

In [None]:
# Distinct names of the entities found in the first record's notes, sorted.
sorted({entity.name for entity in notes_entities[0]})