In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up before each cell executes, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import pickle
from pathlib import Path

import pandas as pd
from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.schemas.mimic import BHC, PhysicianNote
from discharge_summaries.snomed.lookup import SnomedLookup
from discharge_summaries.snomed.retriever import SnomedRetriever
from discharge_summaries.writers.bhc import BHCWriter

In [None]:
# Load environment variables via python-dotenv. The AZURE_OPENAI_ENDPOINT and
# AZURE_OPENAI_KEY values read with os.getenv further down are presumably
# supplied this way (verify against your local .env file).
load_dotenv()

In [None]:
# --- Input data locations -------------------------------------------------
# Data folders live under "<parent of the current working directory>/data",
# so every path below depends on where the kernel was started.
MIMIC_III_DIR = Path.cwd().parent.joinpath(
    "data", "physionet.org", "files", "mimiciii", "1.4"
)
BHC_FPATH = MIMIC_III_DIR.joinpath("BHCS.json")
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR.joinpath("physician_notes_mimic.csv")

SNOMED_DIR = Path.cwd().parent.joinpath("data", "snomed")
TUNED_PHRASE_MATCHER_FPATH = SNOMED_DIR.joinpath("tuned_snomed_phrase_matcher.pkl")

# --- Files and folders next to the notebook (current working directory) ----
PROMPT_MESSAGE_FPATH = Path.cwd().joinpath("prompt_message.txt")
EXAMPLE_DIR = Path.cwd().joinpath("example")

# --- Azure OpenAI deployment settings ---------------------------------------
AZURE_ENGINE = "gpt-35-turbo-32k"
AZURE_API_VERSION = "2023-07-01-preview"

In [None]:
# Parse the JSON file into plain dicts first, then build one BHC object per
# record by unpacking each dict as keyword arguments.
bhc_records = json.loads(BHC_FPATH.read_text())
bhcs = [BHC(**record) for record in bhc_records]

In [None]:
# Read the physician notes CSV into a DataFrame. The cells below rely on its
# TEXT, CHARTTIME and HADM_ID columns.
physician_notes = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# Display the column labels of the loaded notes as a quick schema check.
physician_notes.columns

In [None]:
# Use the BHC at index 10 as the worked example, then wrap every notes row whose
# HADM_ID equals that BHC's hadm_id in a PhysicianNote.
sample_bhc = bhcs[10]
same_admission = physician_notes["HADM_ID"] == sample_bhc.hadm_id
sample_physician_notes = [
    PhysicianNote(
        hadm_id=note_row["HADM_ID"],
        timestamp=note_row["CHARTTIME"],
        text=note_row["TEXT"],
    )
    for _, note_row in physician_notes[same_admission].iterrows()
]

In [None]:
# Show the selected reference BHC record.
print(sample_bhc)

In [None]:
# Show the first matching physician note. This raises IndexError if no note
# rows matched the sample BHC's hadm_id.
print(sample_physician_notes[0])

### Load Models

In [None]:
# Load the tuned SNOMED phrase matcher from its pickle file.
# NOTE(security): unpickling can execute arbitrary code, so this file must come
# from a trusted source.
# The context manager closes the file handle as soon as loading finishes;
# passing a bare `.open("rb")` to pickle.load would leave it open.
with TUNED_PHRASE_MATCHER_FPATH.open("rb") as phrase_matcher_file:
    snomed_phrase_matcher = pickle.load(phrase_matcher_file)

In [None]:
# Build the SNOMED lookup from the contents of the SNOMED data directory.
snomed_lookup = SnomedLookup.load(SNOMED_DIR)

In [None]:
# Configure the Azure OpenAI chat model used by the BHC writer.
# os.getenv returns None when a variable is unset, so a missing
# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY is not reported here; how the model
# class handles None is not visible from this notebook (TODO confirm).
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,  # presumably seconds; confirm against AzureOpenAIChatModel
)

In [None]:
# Combine the phrase matcher and the SNOMED lookup into a retriever.
# token_window_size=25 presumably sets how many tokens of context surround each
# match; confirm against SnomedRetriever.
snomed_retriever = SnomedRetriever(
    snomed_phrase_matcher, snomed_lookup, token_window_size=25
)

In [None]:
# Assemble the BHC writer from the LLM and the retriever. The first three BHCs
# are passed in as well (presumably as examples for the prompt; confirm against
# BHCWriter), and EXAMPLE_DIR is given as the logging directory.
bhc_writer = BHCWriter(llm, snomed_retriever, bhcs[:3], logging_dir=EXAMPLE_DIR)

In [None]:
# Hard-coded finding names passed as the second argument to the retriever call
# in the next cell. This is a set, so duplicates collapse and iteration order is
# not guaranteed.
filtered_findings = {
    "Respiratory failure",
    "Sepsis",
    "Fever",
    "Bandemia",
    "ARF (Acute Respiratory Failure)",
    "C. difficile colitis",
    "Unresponsiveness",
    "Diarrhea",
    "Elevated CK (Creatine Kinase)",
    "Renal failure",
    "Hypertension",
    "Psych history",
    "Chronic pain",
    "Cancer (squamous cell cancer of the head and neck)",
}

In [None]:
# Run the retriever on the raw note texts for the hand-picked findings.
# This reaches into the writer's private `_snomed_retriever` attribute
# (presumably the same retriever handed to BHCWriter; confirm).
note_texts = [note.text for note in sample_physician_notes]
finding_to_extract_spans = bhc_writer._snomed_retriever(note_texts, filtered_findings)

In [None]:
# Generate a BHC for the sample physician notes by calling the writer.
# Presumably this calls the Azure OpenAI model configured above, so it needs
# valid credentials and network access (confirm against BHCWriter).
bhc = bhc_writer(sample_physician_notes)

In [None]:
# Print the full text of the reference BHC (`sample_bhc`), not of the generated
# `bhc`, e.g. for comparison with the saved output.
print(sample_bhc.full_text)

In [None]:
# Save the generated BHC as pretty-printed JSON in EXAMPLE_DIR.
# Create the directory first: Path.write_text raises FileNotFoundError when the
# parent folder is missing, and this notebook never creates EXAMPLE_DIR itself.
EXAMPLE_DIR.mkdir(parents=True, exist_ok=True)
(EXAMPLE_DIR / "gpt_bhc.json").write_text(json.dumps(bhc.dict(), indent=4))