In [None]:
from pathlib import Path

import pandas as pd
import torch
from sentence_transformers import SentenceTransformer, util

In [None]:
# MIMIC-III v1.4 CSVs, resolved relative to the parent of the notebook's working directory.
MIMIC_III_DIR = Path.cwd().parent.joinpath(
    "data", "physionet.org", "files", "mimiciii", "1.4"
)
# ICD-9 dictionary, per-admission ground-truth codes, and clinical notes.
ICD_DEFINITIONS_PATH, ICD_GTS_PATH, NOTES_PATH = (
    MIMIC_III_DIR / csv_name
    for csv_name in ("D_ICD_DIAGNOSES.csv", "DIAGNOSES_ICD.csv", "NOTEEVENTS.csv")
)

In [None]:
# ICD-9 code -> title dictionary. ICD9_CODE is read as text so numeric-looking
# codes keep their leading zeros and compare equal to the codes in DIAGNOSES_ICD.
icd_definitions = pd.read_csv(ICD_DEFINITIONS_PATH, dtype={"ICD9_CODE": str})

In [None]:
# Sentence-embedding model used to compare free-text diagnoses with ICD-9 titles.
embedder = SentenceTransformer(model_name_or_path="all-mpnet-base-v2")

In [None]:
# Embed every ICD-9 long title once. Later cells index icd_definitions by the
# position of an embedding, relying on row i here matching row i of the frame.
corpus_embeddings = embedder.encode(
    sentences=icd_definitions["LONG_TITLE"].values,
    convert_to_tensor=True,
)

In [None]:
# Sanity check: the three ICD-9 definitions closest to a free-text query.
query_embedding = embedder.encode("Cholera", convert_to_tensor=True)
similarity_matrix = util.cos_sim(query_embedding, corpus_embeddings)
cos_scores = similarity_matrix[0]  # single query -> take its only row
top_results = cos_scores.topk(k=3)

In [None]:
# One line per hit: ICD-9 code, long title and cosine similarity.
# topk indices are row positions in icd_definitions (same order as the
# LONG_TITLE values encoded into corpus_embeddings).
for score, idx in zip(top_results.values, top_results.indices):
    row = icd_definitions.iloc[idx.item()]
    print(f"{row['ICD9_CODE']} {row['LONG_TITLE']} (Score: {score.item():.4f})")

In [None]:
# Ground-truth ICD-9 codes per admission (HADM_ID). ICD9_CODE is read as text so
# codes keep leading zeros and are not split into mixed int/str types across
# parsing chunks; this lets them match icd_definitions["ICD9_CODE"].
icd_gts = pd.read_csv(ICD_GTS_PATH, dtype={"ICD9_CODE": str})

In [None]:
# Peek at the ground-truth table instead of rendering the whole frame.
icd_gts.head()

In [None]:
# Example admission used to eyeball the ground-truth codes of a single HADM_ID.
EXAMPLE_HADM_ID = 172335
icd_gts[icd_gts["HADM_ID"] == EXAMPLE_HADM_ID]

In [None]:
# Load the entire notes table into memory; it is narrowed to one category below.
full_df = pd.read_csv(filepath_or_buffer=NOTES_PATH)

In [None]:
# Keep only discharge-summary notes; no other note category is used in this notebook.
discharge_summaries = full_df.loc[full_df["CATEGORY"].eq("Discharge summary")]

In [None]:
# Drop notes flagged as erroneous (ISERROR == 1, per the MIMIC-III NOTEEVENTS
# docs) and exact duplicate rows. ISERROR is NaN for unflagged notes, and
# NaN != 1, so those are kept. Both steps are idempotent, so re-running this
# cell is safe.
discharge_summaries = (
    discharge_summaries
    .loc[discharge_summaries["ISERROR"].ne(1)]
    .drop_duplicates()
)

In [None]:
# Preview the first few discharge summaries.
discharge_summaries.head(n=5)

In [None]:
# Fixed seed so Restart & Run All always selects the same discharge summary
# (and therefore sends the same note to the LLM).
SAMPLE_SEED = 42
sample = discharge_summaries.sample(n=1, random_state=SAMPLE_SEED)
sample

In [None]:
# Raw note text of the single sampled discharge summary.
sample_ds = sample["TEXT"].iloc[0]

In [None]:
# Admission id of the sampled note, used to fetch its ground-truth codes from icd_gts.
sample_hadm_id = sample["HADM_ID"].iloc[0]

In [None]:
import os

from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role

In [None]:
# Load variables from a local .env file (if present) into the process
# environment; the Azure credentials are read from there with os.getenv.
load_dotenv()

# Engine and API version passed to the Azure OpenAI chat client.
AZURE_ENGINE, AZURE_API_VERSION = "gpt-4", "2023-07-01-preview"

In [None]:
# Fail fast if the credentials were not loaded, instead of silently passing
# None as the endpoint/key to the client.
_missing_azure_vars = [
    var_name
    for var_name in ("AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_KEY")
    if not os.getenv(var_name)
]
if _missing_azure_vars:
    raise RuntimeError(
        "Missing Azure OpenAI environment variables: "
        + ", ".join(_missing_azure_vars)
        + " (set them in .env or the environment)."
    )

# temperature=0 requests the least random output; timeout units are not
# visible here (presumably seconds — TODO confirm in AzureOpenAIChatModel).
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,
)

In [None]:
# System prompt: the model acts as a consultant and returns one
# "<diagnosis>: <description>" line per diagnosis; a later cell splits the
# reply on newlines and embeds each line. Plain string (no f-prefix) because
# it contains no placeholders.
SYSTEM_MESSAGE = Message(
    role=Role.SYSTEM,
    content="""You are a consultant doctor tasked with writing a list of diagnoses from a discharge summary.
This list will then be used to assign ICD-9 codes to the patient.
General rules for accurate selection of diagnoses apply:
• Include the minimum number of diagnoses which accurately reflect the patient's condition
• Include every diagnoses or reason for encounter which affects the care, or influences
health status during the consultant episode, which is available in the classification and
supported by the medical record.
• Do not include background information or chronic problems which are no longer active
and which do not influence the health care being provided in the relevant consultant
episode. It is not always intended that symptoms or history be coded. Just because a
condition can be coded does not mean it should be coded each time the patient is
admitted. Any uncertainty around issues of relevance or inactive problems should be
discussed with the responsible consultant.

The user message will have the format:
Write a list of diagnoses for the following discharge summary:
<discharge summary>

Your response must be of the format:
<diagnosis 1>: <Short description of diagnosis 1>
<diagnosis 2>: <Short description of diagnosis 2>
...
<diagnosis n>: <Short description of diagnosis n>
""",
)

In [None]:
# Chat prompt = fixed system instructions + a user turn in the exact
# "Write a list of diagnoses ..." format the system prompt announces.
user_message_content = (
    "Write a list of diagnoses for the following discharge summary:\n{}".format(sample_ds)
)
prompt_messages = [SYSTEM_MESSAGE, Message(role=Role.USER, content=user_message_content)]

prompt_messages

In [None]:
# Send the assembled prompt to the Azure OpenAI chat model (network call).
response = llm.query(prompt_messages)

In [None]:
# print() so the newline-separated diagnosis list renders as separate lines
# rather than as an escaped string repr.
print(response.content)

In [None]:
# One diagnosis per non-blank line of the reply. Blank lines (e.g. a trailing
# newline) are dropped so they are not embedded and matched to a spurious ICD code.
gpt_diagnoses = [
    line.strip() for line in response.content.splitlines() if line.strip()
]

In [None]:
# Embed each GPT diagnosis and keep its single nearest ICD-9 definition.
# Rows of cos_scores follow gpt_diagnoses; columns follow icd_definitions rows.
query_embedding = embedder.encode(gpt_diagnoses, convert_to_tensor=True)
cos_scores = util.cos_sim(query_embedding, corpus_embeddings)
top_results = cos_scores.topk(k=1)

In [None]:
# Best-matching ICD-9 definition for each GPT diagnosis, in diagnosis order.
icd_definitions.iloc[top_results.indices.cpu().numpy().ravel()]

In [None]:
# Ground-truth ICD-9 codes of the sampled admission, with their short title
# when the code exists in the dictionary (otherwise just the code).
for code in icd_gts.loc[icd_gts["HADM_ID"] == sample_hadm_id, "ICD9_CODE"]:
    titles = icd_definitions.loc[icd_definitions["ICD9_CODE"] == code, "SHORT_TITLE"]
    if titles.empty:
        print(code)
    else:
        print(code, titles.iloc[0])

In [None]:
# ICD-9 codes predicted via embedding similarity, printed as "<code> <short title>"
# to line up with the ground-truth listing. top_results holds *row positions*
# into icd_definitions (corpus_embeddings was encoded from its LONG_TITLE
# column in row order), not ICD-9 codes, so rows are looked up positionally.
for row_position in top_results[1].cpu().flatten().numpy().tolist():
    matched_row = icd_definitions.iloc[row_position]
    print(matched_row["ICD9_CODE"], matched_row["SHORT_TITLE"])