In [None]:
import pickle
import re
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Set

# from langchain.chains import LLMChain
# from langchain.chat_models import AzureChatOpenAI
from medcat.cat import CAT
from spacy.pipeline import Sentencizer
from tqdm.notebook import tqdm

# from discharge_summaries.prompts.diagnosis_summary import diagnosis_summary_prompt
from discharge_summaries.schemas.medcat import MedCATSpan
from discharge_summaries.schemas.mimic import Note, Record

In [None]:
# All paths are resolved one level above the current working directory
# (presumably the repository root when the notebook runs from its own folder).
DATA_DIR = Path.cwd().parent.joinpath("data")
DATA_PATH = DATA_DIR.joinpath("train.pkl")

MODEL_PATH = Path.cwd().parent.joinpath(
    "models", "medcat_model_pack_328be3555e2c7a37"
)

In [None]:
# Load the pickled training records and keep a small sample for fast iteration.
# NOTE: pickle.load can execute arbitrary code; only load trusted, locally produced files.
N_RECORDS = 5  # set to None to process the full training set
with open(DATA_PATH, "rb") as in_file:
    # Slice before validation so only the sampled records are turned into `Record`s
    # (slicing with None keeps everything).
    dataset = [Record(**record) for record in pickle.load(in_file)[:N_RECORDS]]
len(dataset)

## Models

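Reference for the UMLS semantic type IDs (TUIs such as `T047`) used by the optional linking filter below: [SemanticTypes_2018AB](https://lhncbc.nlm.nih.gov/ii/tools/MetaMap/Docs/SemanticTypes_2018AB.txt).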

In [None]:
# Load the MedCAT model pack from MODEL_PATH.
cat = CAT.load_model_pack(MODEL_PATH)
# Optional (disabled): restrict linking to concepts whose semantic type ID is in
# `type_ids_filter`, by collecting their CUIs from the CDB's `type_id2cuis` index
# and installing them as the linker's CUI filter.
# type_ids_filter = [
#     "T020",
#     "T190",
#     "T049",
#     "T019",
#     "T047",
#     "T050",
#     "T033",
#     "T037",
#     "T048",
#     "T191",
#     "T046",
#     "T184",
# ] + ["T005", "T007"]
# cui_filters = {
#     cui
#     for type_ids in type_ids_filter
#     for cui in cat.cdb.addl_info["type_id2cuis"][type_ids]
# }
# cat.cdb.config.linking["filters"]["cuis"] = cui_filters

In [None]:
# Inspect the spaCy pipeline components of the loaded MedCAT model.
cat.pipe.spacy_nlp.pipeline

In [None]:
# Smoke test: entities MedCAT finds in the first BHC paragraph of the first record.
cat(dataset[0].discharge_summary.bhc_paragraphs[0].text).ents

In [None]:
# Add a rule-based sentencizer so `doc.sents` works in `annotate_medcat_entities`,
# splitting on newlines in addition to spaCy's default sentence-final punctuation.
# A new list is built rather than `.append`-ing to `Sentencizer.default_punct_chars`:
# `append` returns None and would permanently mutate the class-level default.
# The guard keeps the cell re-runnable (`add_pipe` raises on a duplicate name);
# `component_names` also lists disabled components, unlike `pipe_names`.
if "sentencizer" not in cat.pipe.spacy_nlp.component_names:
    cat.pipe.spacy_nlp.add_pipe(
        "sentencizer",
        config={"punct_chars": Sentencizer.default_punct_chars + ["\n"]},
    )
cat.pipe.spacy_nlp.enable_pipe("sentencizer")

In [None]:
# Confirm the sentencizer now appears in the pipeline.
cat.pipe.spacy_nlp.pipeline

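## Generate summary

Extract MedCAT entities from each record's physician notes, attach the surrounding sentences as context, group them by CUI, and drop mentions whose context largely repeats an earlier one.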

In [None]:
def annotate_medcat_entities(
    text: str, cat: CAT, datetime: str, num_context_sentences: int = 1
) -> List[MedCATSpan]:
    """Run MedCAT over ``text`` and wrap every entity with its surrounding sentences.

    Each entity's context is a window of ``2 * num_context_sentences + 1`` sentences
    centred on the sentence that contains the entity. Near the start or end of the
    document the window is shifted, not truncated, so it still holds that many
    sentences. If the document is shorter than the window, the whole document is used.

    Args:
        text: Note text to annotate.
        cat: Loaded MedCAT model. Its spaCy pipeline must set sentence boundaries
            (e.g. a ``sentencizer``) for ``doc.sents`` to be available.
        datetime: Timestamp forwarded to ``MedCATSpan.from_spacy_span``.
        num_context_sentences: Number of context sentences on each side of the
            entity's sentence.

    Returns:
        One ``MedCATSpan`` per entity, in document order.

    Raises:
        ValueError: If ``num_context_sentences`` is negative.
    """
    if num_context_sentences < 0:
        raise ValueError("num_context_sentences must be non-negative")

    doc = cat(text)
    sentences = list(doc.sents)
    window = 2 * num_context_sentences + 1
    # Latest valid window start; negative when the document is shorter than the window.
    last_start = len(sentences) - window

    ents = []
    for sent_idx, sent in enumerate(sentences):
        # Centre the window on this sentence, then clamp it inside the document.
        start = max(0, min(sent_idx - num_context_sentences, last_start))
        context_str = " ".join(str(context_sent) for context_sent in sentences[start : start + window])
        for ent in sent.ents:
            ents.append(
                MedCATSpan.from_spacy_span(
                    ent, cat, context=context_str, datetime=datetime
                )
            )
    return ents


def group_entities_by_cui(entities: List[MedCATSpan]) -> Dict[str, List[MedCATSpan]]:
    """Bucket entities by their ``cui`` attribute.

    Input order is preserved within each bucket. The result is a
    ``defaultdict(list)``, so looking up an unseen CUI yields an empty list.
    """
    buckets: Dict[str, List[MedCATSpan]] = defaultdict(list)
    for span in entities:
        buckets[span.cui].append(span)
    return buckets


def string_to_word_set(string: str) -> Set[str]:
    """Return the distinct, lower-cased runs of regex word characters in ``string``."""
    lowered = string.lower()
    return {match.group(0) for match in re.finditer(r"\w+", lowered)}


def de_duplicate_entities_based_on_content_overlap(
    entities: List[MedCATSpan], threshold: float = 0.75
) -> List[MedCATSpan]:
    """Greedily drop entities whose context largely repeats an already-kept one.

    Entities are visited in order, and each ``context`` is reduced to a word set with
    ``string_to_word_set``. An entity is discarded when, for some previously kept
    entity, the fraction of the kept entity's words that also appear in this entity's
    context exceeds ``threshold``. Otherwise the entity is kept.

    Args:
        entities: Candidate entities (in this notebook, all mentions of one CUI).
        threshold: Overlap ratio above which an entity counts as a duplicate.

    Returns:
        The kept entities, in their original order.
    """
    kept: List[tuple[MedCATSpan, Set[str]]] = []

    for entity in entities:
        words = string_to_word_set(entity.context)
        for _, kept_words in kept:
            if kept_words:
                overlap = len(words & kept_words) / len(kept_words)
            else:
                # The ratio is undefined for an empty kept context: treat two empty
                # contexts as identical, and anything else as unrelated.
                overlap = 0.0 if words else 1.0
            if overlap > threshold:
                break
        else:
            # No kept entity was similar enough, so this one is new.
            kept.append((entity, words))

    return [entity for entity, _ in kept]


def extract_cui_to_entities(
    physician_notes: List[Note], cat: CAT, category: str = "Physician "
) -> Dict[str, List[MedCATSpan]]:
    """Annotate notes of one category and return de-duplicated entities per CUI.

    Only notes whose ``category`` equals ``category`` are annotated. The comparison
    ignores surrounding whitespace, so ``"Physician "`` (trailing space, as in the
    default) and ``"Physician"`` both match. This keeps the filter working if category
    strings are ever normalised upstream.

    Args:
        physician_notes: Notes of one record; other categories are skipped.
        cat: Loaded MedCAT model (see ``annotate_medcat_entities``).
        category: Note category to keep.

    Returns:
        Mapping from CUI to that CUI's entities, after
        ``de_duplicate_entities_based_on_content_overlap``.
    """
    wanted_category = category.strip()
    entities = [
        ent
        for note in physician_notes
        if (note.category or "").strip() == wanted_category
        for ent in annotate_medcat_entities(note.text, cat, datetime=note.datetime)
    ]

    return {
        cui: de_duplicate_entities_based_on_content_overlap(cui_entities)
        for cui, cui_entities in group_entities_by_cui(entities).items()
    }


def cui_to_name(cui: str, cat: CAT) -> str:
    """Look up the name MedCAT's concept database (``cat.cdb``) holds for ``cui``."""
    concept_db = cat.cdb
    return concept_db.get_name(cui)


def cuis_to_names(cuis: List[str], cat: CAT) -> List[str]:
    """Return MedCAT names for ``cuis``, sorted alphabetically.

    ``cuis`` may be any iterable of CUIs (lists, sets and dict views are all used
    in this notebook).
    """
    names = [cat.cdb.get_name(cui) for cui in cuis]
    names.sort()
    return names

In [None]:
# Build one {CUI: [MedCATSpan, ...]} mapping per record.
extracts = [
    extract_cui_to_entities(physician_notes=record.physician_notes, cat=cat)
    for record in tqdm(dataset)
]

In [None]:
# NOTE(review): placeholder for an unfinished draft. The assumed intent is the CUIs
# extracted for each record, which are the keys of each `extracts` mapping
# (in first-mention order). Confirm before relying on this.
heading_cuis = [list(cui_to_entities) for cui_to_entities in extracts]

In [None]:
# Inspect the CUI -> de-duplicated entities mapping for the first record.
extracts[0]

In [None]:
# Disabled: save the generated summaries as a timestamped JSON file.
# NOTE(review): needs `datetime`, `json`, `OUTPUT_DIR` and a `generations` variable, none
# of which are defined or imported in this notebook; `para.Dict()` may be meant to be
# the pydantic `.dict()` method. Confirm before re-enabling.
# timestamp_str = datetime.now().strftime("%Y_%m_%d_%H_%M")
# output_path = OUTPUT_DIR / (timestamp_str + ".json")

# with output_path.open("w") as fout:
#     json.dump([[para.Dict() for para in bhc] for bhc in generations], fout, indent=4)

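## UMLS Matching (legacy)

Legacy evaluation code, kept for reference. It is not runnable as-is: it depends on names that are not defined in this notebook (for example `output`, `cuis_true` and `cuis_pred`), and some cells call `get_related_cuis` before its definition.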

# Hard-wrap each blank-line-separated paragraph of `output` at 10 words per line and
# write the result to disk for manual inspection.
# NOTE(review): `output` is not defined in this notebook; it presumably came from the
# (now commented-out) LLM generation step.
formatted_summaries = [
    "\n".join(
        " ".join(words[start : start + 10]) for start in range(0, len(words), 10)
    )
    for words in (summary.split(" ") for summary in output.split("\n\n"))
]

(DATA_DIR / "test_output.txt").write_text("\n\n".join(formatted_summaries))

def calc_matches_including_parents(
    cuis: set[str], cuis_compare: set[str], cat: CAT, verbose: bool = True
) -> List[str]:
    """Return the CUIs in ``cuis`` that match ``cuis_compare`` directly or via one UMLS hop.

    A CUI from ``cuis`` matches when any of these holds:
      * it is itself in ``cuis_compare``;
      * it is among the ``CHD``/``RN``-related CUIs of some CUI in ``cuis_compare``;
      * one of its own ``CHD``/``RN``-related CUIs is in ``cuis_compare``.
    Related CUIs come from ``get_related_cuis``, which makes one UMLS REST lookup per CUI.
    NOTE(review): these relations are treated as "parents"; confirm the UMLS direction.

    Args:
        cuis: CUIs to test (e.g. predictions).
        cuis_compare: Reference CUIs.
        cat: MedCAT model passed to ``get_related_cuis`` and used for names.
        verbose: Print diagnostic names for each unmatched CUI.

    Returns:
        Matched CUIs, in the iteration order of ``cuis``.
    """
    relation_labels = ["CHD", "RN"]
    compare_related = {
        related
        for compare_cui in cuis_compare
        for related in get_related_cuis(compare_cui, cat, relation_labels=relation_labels)
    }

    matched = []
    for cui in cuis:
        if cui in cuis_compare or cui in compare_related:
            matched.append(cui)
            continue
        # Fetch this CUI's relations once and reuse them for the diagnostics below.
        own_related = get_related_cuis(cui, cat, relation_labels=relation_labels)
        if own_related.intersection(cuis_compare):
            matched.append(cui)
        elif verbose:
            print(cui_to_name(cui, cat))
            print(cuis_to_names(compare_related, cat))
            print(cuis_to_names(own_related, cat))

    return matched

# Legacy evaluation: share of predicted CUIs that match the reference CUIs directly or
# through one CHD/RN UMLS relation in either direction.
# NOTE(review): `cuis_true` / `cuis_pred` are not defined in this notebook, and
# `get_related_cuis` is defined in a later cell; run that definition first.
matched = []
cuis_true_parents = {
    parent_cui
    for cui in cuis_true
    for parent_cui in get_related_cuis(cui, cat, relation_labels=["CHD", "RN"])
}
for cui_pred in cuis_pred:
    if (
        cui_pred in cuis_true
        or cui_pred in cuis_true_parents
        or get_related_cuis(cui_pred, cat, relation_labels=["CHD", "RN"]).intersection(cuis_true)
    ):
        matched.append(cui_pred)

# Divide by the number of predictions; 0.0 when there are none.
len(matched) / len(cuis_pred) if cuis_pred else 0.0

# Union the three legacy match sets and show them next to the reference and predicted names.
# NOTE(review): base_matches / parent_true_matches / parent_pred_matches are not defined
# in this notebook.
matches = base_matches.union(parent_true_matches, parent_pred_matches)
(
    cuis_to_names(matches, cat),
    cuis_to_names(cuis_true, cat),
    cuis_to_names(cuis_pred, cat),
)

def precision(
    cuis_true: set[str],
    cuis_pred: set[str],
    cat: CAT | None = None,
    verbose: bool = True,
) -> float:
    """Fraction of predicted CUIs that match the reference set after CHD/RN expansion.

    Both sides are expanded with their ``CHD``/``RN``-related CUIs (via
    ``get_related_cuis``). A prediction is a true positive when its expanded set
    intersects the expanded reference set.

    Args:
        cuis_true: Reference CUIs.
        cuis_pred: Predicted CUIs.
        cat: MedCAT model used for relation filtering and names. Defaults to the
            notebook-level ``cat`` so existing two-argument calls keep working.
        verbose: Print the expanded reference names and each matching prediction.

    Returns:
        ``true_positives / len(cuis_pred)``, or ``0.0`` when ``cuis_pred`` is empty.
    """
    if cat is None:
        # Inside this function `cat` is the parameter, so the notebook global is
        # looked up explicitly.
        cat = globals()["cat"]
    relation_labels = ["CHD", "RN"]

    true_expanded = set(cuis_true).union(
        related
        for cui in cuis_true
        for related in get_related_cuis(cui, cat, relation_labels=relation_labels)
    )
    if verbose:
        print(cuis_to_names(true_expanded, cat))

    true_positive = 0
    for cui_pred in cuis_pred:
        # `{cui_pred}` is a one-element set; `set(cui_pred)` would split the string into characters.
        pred_expanded = {cui_pred} | get_related_cuis(cui_pred, cat, relation_labels=relation_labels)
        shared = pred_expanded & true_expanded
        if shared:
            if verbose:
                print(cui_to_name(cui_pred, cat), cuis_to_names(shared, cat))
            true_positive += 1

    if not cuis_pred:
        return 0.0
    return true_positive / len(cuis_pred)

# Score the prediction set against the reference set. The set names match the matching
# cell above; `cui_pred` there is only the last loop item, a single CUI string.
precision(cuis_true, cuis_pred)

import requests

def get_related_cuis(
    cui: str,
    cat: CAT,
    relation_labels: List[str] | None = None,
    api_key: str | None = None,
    umls_base_url: str | None = None,
    page_size: int = 1000,
    timeout: float = 30.0,
) -> set[str]:
    """Fetch CUIs related to ``cui`` from the UMLS REST API, restricted to MedCAT's CDB.

    Pages through ``{umls_base_url}/content/current/CUI/{cui}/relations`` until the
    last page (per the response's ``pageCount``) has been read. It collects the CUI at
    the end of each ``relatedId`` URL and keeps only CUIs present in
    ``cat.cdb.cui2names``.

    Args:
        cui: Source concept.
        cat: MedCAT model whose CDB filters the results.
        relation_labels: UMLS relation labels to request (e.g. ``["CHD", "RN"]``).
            ``None`` sends an empty label filter.
        api_key: UMLS API key. If omitted, a notebook-level ``UMLS_API_KEY`` is used
            when defined, otherwise the ``UMLS_API_KEY`` environment variable.
        umls_base_url: REST base URL. If omitted, a notebook-level ``UMLS_BASE_URL`` is
            used when defined, otherwise ``https://uts-ws.nlm.nih.gov/rest``.
        page_size: Results requested per page.
        timeout: Seconds to wait for each HTTP response.

    Returns:
        Related CUIs known to ``cat``. On a non-200 response a message is printed and
        only the pages fetched so far are used.

    Raises:
        ValueError: If no API key can be found.
    """
    import os

    if api_key is None:
        api_key = globals().get("UMLS_API_KEY") or os.environ.get("UMLS_API_KEY")
    if not api_key:
        raise ValueError("No UMLS API key: pass api_key or set UMLS_API_KEY")
    if umls_base_url is None:
        umls_base_url = globals().get("UMLS_BASE_URL", "https://uts-ws.nlm.nih.gov/rest")

    url = f"{umls_base_url}/content/current/CUI/{cui}/relations"
    params = {
        "apiKey": api_key,
        "includeRelationLabels": ",".join(relation_labels) if relation_labels else "",
        "pageNumber": 1,
        "pageSize": page_size,
    }
    results = []
    while True:
        response = requests.get(url, params=params, timeout=timeout)
        if response.status_code != 200:
            print(f"Failed to retrieve concept information for CUI: {cui}. Error code {response.status_code}")
            break
        payload = response.json()
        results.extend(payload["result"])
        # Stop once the last reported page has been read (or if no page count is given).
        if params["pageNumber"] >= payload.get("pageCount", params["pageNumber"]):
            break
        params["pageNumber"] += 1

    related_cuis = {
        result["relatedId"].split("/")[-1] for result in results if "relatedId" in result
    }
    return {related for related in related_cuis if related in cat.cdb.cui2names}

def get_child_to_parent_cui(org_cuis: List[str], cat) -> Dict[str, str]:
    """Map each CUI in ``org_cuis`` to the end of its chain of related CUIs within ``org_cuis``.

    Step 1: each CUI is linked to one ``CHD``/``RN``-related CUI (via
    ``get_related_cuis``) that is also in ``org_cuis``. The CUI itself is excluded,
    and the lexicographically smallest candidate is chosen so the result is
    deterministic.
    Step 2: those links are followed until the chain ends or would revisit a CUI,
    so cyclic relations cannot loop forever.
    NOTE(review): the links are treated as child -> parent; confirm the UMLS direction.

    Args:
        org_cuis: CUIs to organise (any iterable; consumed once).
        cat: MedCAT model passed to ``get_related_cuis``.

    Returns:
        ``{child_cui: final_cui}`` for every CUI that has at least one link; CUIs
        without a related CUI in ``org_cuis`` are omitted.
    """
    # Materialise once, keeping input order and dropping duplicates.
    ordered = list(dict.fromkeys(org_cuis))
    candidates = set(ordered)

    child_to_parent: Dict[str, str] = {}
    for cui in ordered:
        parents = get_related_cuis(cui, cat, relation_labels=["CHD", "RN"]) & candidates
        parents.discard(cui)  # a self-link would make the chain below never advance
        if parents:
            child_to_parent[cui] = min(parents)

    child_to_root: Dict[str, str] = {}
    for child in child_to_parent:
        seen = {child}
        current = child_to_parent[child]
        # Walk up while the next hop exists and has not been visited (cycle guard).
        while current in child_to_parent and child_to_parent[current] not in seen:
            seen.add(current)
            current = child_to_parent[current]
        child_to_root[child] = current
    return child_to_root

# Collapse the first record's CUIs onto the end of their related-CUI chains and show the names.
# NOTE(review): `cui_to_entities` is not otherwise bound at notebook level; assumed to
# mean the first record's extract. Confirm.
cui_to_entities = extracts[0]
child_to_parent_cui = get_child_to_parent_cui(list(cui_to_entities.keys()), cat)
{cui_to_name(child_cui, cat): cui_to_name(parent_cui, cat) for child_cui, parent_cui in child_to_parent_cui.items()}