In [None]:
import asyncio
import json
import os
import pickle
import re
from collections import defaultdict
from datetime import datetime
from pathlib import Path

from dotenv import load_dotenv

# from langchain.chains import LLMChain
# from langchain.chat_models import AzureChatOpenAI
from medcat.cat import CAT
from spacy.pipeline import Sentencizer
from tqdm.notebook import tqdm

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel, ChatModel
from discharge_summaries.openai_llm.prompts import generate_diagnosis_summary_prompt

# from discharge_summaries.prompts.diagnosis_summary import diagnosis_summary_prompt
from discharge_summaries.schemas.medcat import MedCATSpan
from discharge_summaries.schemas.mimic import Note, Record
from discharge_summaries.schemas.output import Paragraph

from discharge_summaries.openai_llm.message import Message, Role

In [None]:
# Filesystem locations used by the notebook. Everything is resolved relative to
# the current working directory, so the kernel must be started from the
# notebook's own folder for the "../data" and "../models" lookups to resolve.
DATA_DIR = Path.cwd().parent / "data"
OUTPUT_DIR = Path.cwd() / "output_medcat"
DATA_PATH = DATA_DIR / "train.pkl"

# MedCAT model pack (zip archive) loaded later with CAT.load_model_pack.
MODEL_PATH = Path.cwd().parent / "models" / "umls_sm_pt2ch_533bab5115c6c2d6.zip"

In [None]:
# Load the pickled training split and turn every stored entry into a Record.
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point DATA_PATH at files from a trusted source.
with open(DATA_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
# Show how many records were loaded.
len(dataset)

In [None]:
# Work on a single record (index 1) for the rest of the notebook.
sample = dataset[1]
# Model inputs: the raw text of each physician note attached to the record.
inputs = [note.text for note in sample.physician_notes]
# Reference labels: the heading of each paragraph in the discharge summary's
# `bhc_paragraphs` list.
gts = [para.heading for para in sample.discharge_summary.bhc_paragraphs]

In [None]:
# Number of physician notes available for the selected sample.
len(inputs)

In [None]:
# Display the reference paragraph headings for the selected sample.
gts

In [None]:
# Load the MedCAT model pack and drop the "Status" component from its pipeline.
cat = CAT.load_model_pack(MODEL_PATH)
cat.pipe.force_remove("Status")

# UMLS semantic type ids whose concepts should be kept by the entity linker.
type_ids_filter = [
    "T020",
    "T190",
    "T049",
    "T019",
    "T047",
    "T050",
    "T033",
    "T037",
    "T048",
    "T191",
    "T046",
    "T184",
] + ["T005", "T007"]

# Expand the semantic types into the concrete set of CUIs they contain and
# restrict linking to that set.
cui_filters = {
    cui
    for type_id in type_ids_filter
    for cui in cat.cdb.addl_info["type_id2cuis"][type_id]
}
cat.cdb.config.linking["filters"]["cuis"] = cui_filters

# Treat a newline as a sentence boundary in addition to spaCy's defaults.
# A new list is built on purpose: `list.append` returns None, so passing its
# result would hand `punct_chars=None` to the sentencizer and would also mutate
# the shared class-level default list (growing it on every re-run of the cell).
sentence_punct_chars = [*Sentencizer.default_punct_chars, "\n"]
cat.pipe._nlp.add_pipe("sentencizer", config={"punct_chars": sentence_punct_chars})
cat.pipe._nlp.enable_pipe("sentencizer")

In [None]:
def annotate_medcat_entities(
    text: str, cat: CAT, num_context_sentences=1
) -> list[MedCATSpan]:
    """Run MedCAT over ``text`` and return its entities with sentence context.

    Each entity is paired with a context string made of the sentence it
    occurs in plus up to ``num_context_sentences`` sentences on either side.
    Near the start or end of the document the window is shifted (not
    shrunk), so it still spans ``2 * num_context_sentences + 1`` sentences
    whenever the document has that many; otherwise the whole document is used.

    Args:
        text: Text to annotate.
        cat: MedCAT pipeline; called directly on ``text``.
        num_context_sentences: Sentences of context on each side.

    Returns:
        One ``MedCATSpan`` per entity, in document order. Empty when the
        pipeline yields a falsy document.
    """
    doc = cat(text)
    if not doc:
        return []

    sentences = list(doc.sents)
    sentence_count = len(sentences)
    window_size = 2 * num_context_sentences + 1

    spans: list[MedCATSpan] = []
    for position, sentence in enumerate(sentences):
        if window_size > sentence_count:
            # Document is shorter than the window: use every sentence.
            window = sentences
        else:
            # Centre the window on the sentence, clamped to the document edges.
            start = min(
                max(position - num_context_sentences, 0),
                sentence_count - window_size,
            )
            window = sentences[start : start + window_size]

        context_str = " ".join(str(item) for item in window)
        spans.extend(
            MedCATSpan.from_spacy_span(ent, cat, context=context_str)
            for ent in sentence.ents
        )
    return spans


def group_entities_by_cui(entities: list[MedCATSpan]) -> dict[str, list[MedCATSpan]]:
    """Bucket entities by their ``cui`` attribute.

    Returns a ``defaultdict(list)`` mapping each CUI to its entities in the
    order they appear in ``entities``.
    """
    grouped: dict[str, list[MedCATSpan]] = defaultdict(list)
    for span in entities:
        bucket = grouped[span.cui]
        bucket.append(span)
    return grouped


def string_to_word_set(string: str) -> set[str]:
    """Return the distinct lower-cased ``\\w+`` tokens found in ``string``."""
    lowered = string.lower()
    return {match.group(0) for match in re.finditer(r"\w+", lowered)}


def de_duplicate_entities_based_on_content_overlap(
    entities: list[MedCATSpan], threshold=0.75
) -> list[MedCATSpan]:
    """Drop entities whose context largely repeats an already-kept context.

    Entities are processed in order. An entity is discarded when the share of
    a previously kept entity's context words that also occur in this entity's
    context is strictly greater than ``threshold``. The measure is asymmetric:
    the denominator is always the size of the *kept* entity's word set.

    Args:
        entities: Entities to filter; each must expose a ``context`` string.
        threshold: Overlap ratio above which an entity counts as a duplicate.

    Returns:
        The retained entities, in their original relative order.
    """

    def overlap_ratio(candidate_words: set[str], kept_words: set[str]) -> float:
        # A kept context without any word tokens (empty or punctuation only)
        # would otherwise cause a ZeroDivisionError. Two word-less contexts
        # are treated as identical; a word-less kept context shares nothing
        # with a non-empty candidate.
        if not kept_words:
            return 0.0 if candidate_words else 1.0
        return len(candidate_words & kept_words) / len(kept_words)

    kept: list[tuple[MedCATSpan, set[str]]] = []
    for entity in entities:
        entity_words = string_to_word_set(entity.context)
        is_duplicate = any(
            overlap_ratio(entity_words, kept_words) > threshold
            for _, kept_words in kept
        )
        if not is_duplicate:
            kept.append((entity, entity_words))

    return [entity for entity, _ in kept]


def extract_cui_to_entities(
    physician_notes: list[Note], cat: CAT
) -> dict[str, list[MedCATSpan]]:
    """Annotate every note with MedCAT and return de-duplicated entities per CUI.

    Args:
        physician_notes: Notes to annotate. Each item may be a ``Note`` (its
            ``text`` attribute is used) or an already-extracted text string;
            the annotator itself only accepts strings. No filtering by note
            category is applied.
        cat: MedCAT pipeline used for annotation.

    Returns:
        Mapping from CUI to the entities found for it across all notes, after
        removing entities whose context mostly repeats an earlier one.
    """
    entities = [
        ent
        for note in physician_notes
        # Fall back to the item itself when it has no `.text` (plain string).
        for ent in annotate_medcat_entities(getattr(note, "text", note), cat)
    ]

    grouped = group_entities_by_cui(entities)

    return {
        cui: de_duplicate_entities_based_on_content_overlap(cui_entities)
        for cui, cui_entities in grouped.items()
    }

In [None]:
# Annotate the sample's note texts and group the de-duplicated entities by CUI.
medcat_cui_to_entities = extract_cui_to_entities(inputs, cat)

In [None]:
# Display the CUI -> entities mapping produced by MedCAT.
medcat_cui_to_entities

In [None]:
# Populate os.environ from a local .env file, if one exists. `load_dotenv` was
# imported but never called, so credentials kept in .env were not visible to
# os.getenv. Variables already set in the environment are not overridden.
load_dotenv()

# Azure OpenAI chat client used for every LLM call below. Temperature 0 keeps
# the responses as repeatable as the service allows.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version="2023-05-15",
    engine="gpt-35-turbo",
    temperature=0,
    timeout=20,
)

In [None]:
def generate_validation_prompt(
    term: str,
) -> list[Message]:
    """Build the chat messages asking whether ``term`` is a medical disorder.

    Returns a two-element list: a system message describing the task and the
    expected Yes/No reply format, followed by the user message carrying the
    term.
    """
    # NOTE(review): the question wording ("Is the following term is ...") is
    # kept exactly as-is because it is part of the prompt sent to the model.
    system_content = """You are an expert medical assistant.
Your task is to decide if a user provided term is a medical disorder.
---
User messages follow the format.
Is the following term is a medical disorder?
Term: $[term]
---
Assistant messages follow the format.
$[Yes/No]
"""
    user_content = f"Is the following term is a medical disorder?\nTerm: {term}"

    return [
        Message(role=Role.SYSTEM, content=system_content),
        Message(role=Role.USER, content=user_content),
    ]

In [None]:
# Ask the LLM, one CUI at a time, whether the concept's name is a medical
# disorder. Results are keyed by CUI.
cui_to_llm_output = {}

# Only the CUI is used here; the entity lists (`ents`) are not needed for the
# name-based validation prompt.
for cui, ents in tqdm(medcat_cui_to_entities.items()):
    name = cat.cdb.get_name(cui)
    cui_to_llm_output[cui] = llm.query(generate_validation_prompt(name))

In [None]:
# Display the per-CUI validation answers produced by the previous cell.
# (`outputs` is only assigned further down the notebook, so referencing it
# here fails on a fresh kernel.)
cui_to_llm_output

In [None]:
def generate_validation_prompt(
    disorder_entities: list[tuple[str, list[MedCATSpan]]],
) -> list[Message]:
    """Build the chat messages asking which disorders belong in the summary.

    NOTE(review): this definition reuses the name of the single-term
    ``generate_validation_prompt`` defined earlier in the notebook and shadows
    it once this cell has run — consider giving one of them a distinct name.

    Args:
        disorder_entities: ``(disorder name, entities)`` pairs. The ``context``
            of each entity is used as an EHR extract for that disorder.

    Returns:
        A two-element list: the system message describing the task and
        formats, followed by the user message listing every disorder with
        its extracts.
    """
    system_message = Message(
        role=Role.SYSTEM,
        content="""You are an expert medical assistant aiding a user to write a patient's discharge summary.
Your task is to decide which medical disorders to include in the summary.
This is done by analyzing extracts from the patient's medical records which explicitly mention each disorder.
The included disorders should be those most critical to the patient's health.
---
User messages follow the format.
Which of the following terms should be included in the discharge summary?
Disorder: $[disorder]
EHR Extracts: $[new line separated extracts from the patient's medical records mentioning the disorder]
---
Assistant messages follow the format.
$[List of disorders to include in the discharge summary]
""",
    )

    # One "Disorder / EHR Extracts" section per disorder, separated by a blank
    # line, with the extracts newline-separated as the system message promises.
    user_content = "\n\n".join(
        "Disorder: {}\nEHR Extracts: {}".format(
            disorder, "\n".join(ent.context for ent in ents)
        )
        for disorder, ents in disorder_entities
    )

    user_message = Message(
        role=Role.USER,
        content=f"""Which of the following terms should be included in the discharge summary?
{user_content}""",
    )

    return [system_message, user_message]

In [None]:
# Keep only the CUIs that the LLM validated as medical disorders (answer starts
# with "Yes"). The original expression referenced an undefined name and built
# a set rather than a filtered mapping.
# NOTE(review): assumes `llm.query` returns a message exposing `.content`, as
# the `aquery` results used elsewhere in this notebook do — confirm.
filterd_cui_to_entities = {
    cui: ents
    for cui, ents in medcat_cui_to_entities.items()
    if cui_to_llm_output[cui].content.strip().lower().startswith("yes")
}

In [None]:
# NOTE(review): `outputs` is first assigned in a later cell (the per-note
# extraction), so this cell raises NameError under Restart & Run All —
# move it below that cell or remove it.
outputs

In [None]:
def generate_extract_prompt(physician_note: str) -> list[Message]:
    """Build the chat messages asking the LLM to list disorders in one note.

    Returns a two-element list: a system message instructing the model to
    produce UMLS preferred-term disorders using only the supplied text,
    followed by the user message carrying ``physician_note``.
    """
    system_content = """You are an expert medical assistant aiding a user to write a patient's discharge summary.
Your task is to write a list of confirmed medical disorders experienced by the patient during their stay to include in the discharge summary.
This list should use UMLS preferred terms.
You may only use information provided by the user.
---
User messages follow the format.
Physician Note: $[physician note]
---
Assistant messages follow the format.
$[List of UMLS preferred term disorders]
"""
    user_content = f"Physician Note: {physician_note}"

    return [
        Message(role=Role.SYSTEM, content=system_content),
        Message(role=Role.USER, content=user_content),
    ]

In [None]:
# Ask the LLM to extract disorders from each note. The top-level `await` relies
# on the notebook's already-running event loop (see the link below), and the
# notes are processed one after another: each call is awaited before the next
# one starts.
outputs = [
    # https://stackoverflow.com/questions/55409641/asyncio-run-cannot-be-called-from-a-running-event-loop-when-using-jupyter-no
    await llm.aquery(generate_extract_prompt(note))
    for note in tqdm(inputs)
]

In [None]:
# Display the raw LLM responses, one per physician note.
outputs

In [None]:
# Concatenate all LLM responses and run MedCAT over the combined text to map
# the free-text disorder lists back onto CUIs.
joined_output_context = "\n".join([output.content for output in outputs])
output_cuis = {str(span._.cui) for span in cat(joined_output_context).ents}
# Display the set of CUIs detected in the LLM output.
output_cuis

In [None]:
# CUIs MedCAT detects directly in the physician note texts (each note is
# annotated separately and the results are merged into one set).
input_cuis = {str(span._.cui) for note in tqdm(inputs) for span in cat(note).ents}
# Display the set of CUIs detected in the inputs.
input_cuis

In [None]:
def cuis_to_names(cuis: list[str], cat: CAT) -> list[str]:
    """Return the concept-database names for ``cuis``, sorted ascending.

    ``cuis`` is only iterated, so any iterable of CUI strings (the notebook
    passes sets) is accepted. Names are looked up with ``cat.cdb.get_name``.
    """
    names = [cat.cdb.get_name(cui) for cui in cuis]
    names.sort()
    return names

In [None]:
# Compare how many distinct CUIs were found in the notes vs. in the LLM output.
len(input_cuis), len(output_cuis)

In [None]:
# Concepts detected in the LLM output that MedCAT did not detect in the notes.
cuis_to_names(output_cuis - input_cuis, cat)

In [None]:
# `gt_medcat_cuis` was never defined in this notebook, so the comparison below
# failed on a fresh kernel. Reconstructed here by running MedCAT over the
# reference paragraph headings, mirroring how `input_cuis` is built.
# NOTE(review): confirm this matches the originally intended ground-truth CUIs.
gt_medcat_cuis = {str(span._.cui) for heading in gts for span in cat(heading).ents}

# Ground-truth concepts that MedCAT did not detect in the physician notes.
cuis_to_names(set(gt_medcat_cuis) - input_cuis, cat)

In [None]:
# Ground-truth concepts that are missing from the LLM output.
# NOTE(review): confirm `gt_medcat_cuis` is defined before this cell runs.
cuis_to_names(set(gt_medcat_cuis) - output_cuis, cat)

In [None]:
# Leftover from interactive attribute exploration: the incomplete expression
# `sample.` was a SyntaxError that stopped Restart & Run All at this cell.

In [None]:
# Sorted concept names for every CUI detected in the LLM output.
cuis_to_names(output_cuis, cat)

In [None]:
# Keep the sorted concept names of the LLM-output CUIs for later use, then
# display them.
output_cui_names = cuis_to_names(output_cuis, cat)
output_cui_names

In [None]:
def generate_reduce_cuis_prompt(physician_note: str) -> list[Message]:
    """Build the chat messages asking the LLM to shorten a disorder list.

    Args:
        physician_note: Despite the (retained) parameter name, this is the
            text of the disorder list to be reduced; it is inserted after the
            "Disorder list:" prefix that the system message declares.

    Returns:
        A two-element list: the system message describing the task and
        formats, followed by the user message carrying the disorder list.
    """
    system_message = Message(
        role=Role.SYSTEM,
        content="""You are an expert medical assistant aiding a user to write a patient's discharge summary.
Your task is to rewrite a given list of confirmed medical disorders experienced by the patient during their stay to include in the discharge summary.
To do this, you must reduce the list to only include the most important disorders.
The final list must only contain terms provided by the user.
---
User messages follow the format.
Disorder list: $[list of disorders]
---
Assistant messages follow the format.
$[List of UMLS preferred term disorders]
""",
    )
    # The user message must follow the "Disorder list:" format declared above;
    # it previously used the "Physician Note:" prefix copied from the
    # extraction prompt.
    user_message = Message(
        role=Role.USER,
        content=f"""Disorder list: {physician_note}""",
    )

    return [system_message, user_message]

In [None]:
# Display the text of the first physician note.
inputs[0]

In [None]:
# Display the extraction prompt messages built for the first physician note.
generate_extract_prompt(inputs[0])

In [None]:
# Number of physician notes for the selected sample.
len(inputs)

In [None]:
# Category of the first physician note. `inputs` holds plain text strings
# (note.text), which have no `.category`; read it from the Note object instead.
sample.physician_notes[0].category

In [None]:
# Print the LLM response for the first physician note. The bare name `output`
# only exists inside a comprehension and is undefined at cell scope.
# NOTE(review): index 0 chosen to match the surrounding `inputs[0]` cells — confirm.
print(outputs[0].content)

In [None]:
# Display the reference paragraph headings again for side-by-side comparison.
gts

In [None]:
# Print the first physician note as plain text (line breaks rendered).
print(inputs[0])

In [None]:
# Print the `bhc` field of the sample's reference discharge summary.
print(sample.discharge_summary.bhc)