In [None]:
# Re-import edited project modules (e.g. discharge_summaries.*) automatically before
# each cell executes, so changes to the package are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from copy import deepcopy
from pathlib import Path
from random import Random
from typing import Dict, List, Set, Union

import jsonschema
import pandas as pd
import tiktoken
from dotenv import load_dotenv
from tqdm.notebook import tqdm

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.prompts import (
    generate_rcp_system_message,
    generate_rcp_user_message,
)
from discharge_summaries.schemas.mimic import PhysicianNote
from discharge_summaries.schemas.rcp_guidelines import RCPGuidelines
from discharge_summaries.utils.deduplicate import deduplicate_physician_notes

In [None]:
# Populate os.environ from a .env file; the Azure OpenAI endpoint and key used to build
# the chat clients are later read with os.getenv("AZURE_OPENAI_ENDPOINT"/"AZURE_OPENAI_KEY").
load_dotenv()

In [None]:
# Azure OpenAI engine names (passed as `engine=` to AzureOpenAIChatModel) and API version.
GPT_4_ENGINE = "gpt-4"
GPT_4_32K_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"

# tiktoken encoding used to measure prompt length before choosing which engine to call.
TOKENIZER_NAME = "cl100k_base"

# Notebook-local inputs (the JSON example fed into the system prompt) and outputs
# (one sub-directory per sampled admission).
EXAMPLE_DIR = Path.cwd() / "examples"
OUTPUT_DIR = Path.cwd() / "output"

# Shared data lives one directory above the notebook's working directory.
DATA_DIR = Path.cwd().parent / "data"

# NOTE(review): PHRASE_MATCHER_FPATH is not referenced by the cells in this notebook.
SNOMED_DIR = DATA_DIR / "snomed"
PHRASE_MATCHER_FPATH = SNOMED_DIR / "snomed_phrase_matcher_full.pkl"

MIMIC_III_DIR = DATA_DIR.joinpath("physionet.org", "files", "mimiciii", "1.4")
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR / "physician_notes_mimic.csv"

# How many admissions to summarise, and the seed that makes that sample reproducible.
SAMPLE_SIZE = 5
RANDOM_SEED = 23

In [None]:
# Full JSON schema generated from the RCPGuidelines model; kept unmodified so it can
# be used for validation.
rcp_schema = RCPGuidelines.schema()

# The example that is embedded in the system prompt must itself conform to the schema;
# validating it up front makes a stale example fail here instead of misguiding the model.
example_fpath = EXAMPLE_DIR / "example.json"
example = json.loads(example_fpath.read_text())
jsonschema.validate(instance=example, schema=rcp_schema)

In [None]:
def remove_keys_recursive(d: Union[List, Dict], keys: Set[str]):
    """Delete, in place and at every nesting depth, each dict entry whose key is in `keys`.

    Dicts and lists are walked recursively; any other value is returned untouched.
    Matching is purely by key name, regardless of where the key appears.

    Args:
        d: The (possibly nested) structure to clean. It is mutated in place.
        keys: Key names to remove wherever they occur.

    Returns:
        The same object that was passed in, after mutation.
    """
    if isinstance(d, list):
        for element in d:
            remove_keys_recursive(element, keys)
        return d

    if not isinstance(d, dict):
        return d

    # Collect first, then delete, so the dict is never resized while being iterated.
    matching = [name for name in d if name in keys]
    for name in matching:
        del d[name]

    # Only entries that survived the deletion are descended into.
    for child in d.values():
        remove_keys_recursive(child, keys)
    return d


# Strip the JSON-schema "title" and "required" entries, at every depth, from a deep copy
# of the schema (rcp_schema itself stays intact). The simplified copy is what gets
# embedded in the system prompt.
# NOTE(review): matching is by key name only, so a schema *property* literally named
# "title" or "required" would be dropped too, and the model is no longer told which
# fields are mandatory.
SCHEMA_KEYS_TO_STRIP = {"title", "required"}
simplified_rcp_schema = remove_keys_recursive(deepcopy(rcp_schema), SCHEMA_KEYS_TO_STRIP)
simplified_rcp_schema

In [None]:
tokenizer = tiktoken.get_encoding(TOKENIZER_NAME)


def _azure_chat_model(engine: str) -> AzureOpenAIChatModel:
    """Build an Azure OpenAI chat client for `engine` with temperature=0.

    The endpoint and key are read from the environment each time this is called.
    NOTE(review): `timeout=20` is presumably seconds; long JSON completions from the
    32k engine may need more — confirm against AzureOpenAIChatModel.
    """
    return AzureOpenAIChatModel(
        api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
        api_key=os.getenv("AZURE_OPENAI_KEY"),
        api_version=AZURE_API_VERSION,
        engine=engine,
        temperature=0,
        timeout=20,
    )


# Built in the same order as before: the 8k engine first, then the 32k engine.
gpt_4, gpt_4_32k = _azure_chat_model(GPT_4_ENGINE), _azure_chat_model(GPT_4_32K_ENGINE)

In [None]:
# Each row becomes one PhysicianNote below; the downstream cells use the HADM_ID,
# DESCRIPTION, CHARTTIME and TEXT columns.
notes_df = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# Notes without an admission id cannot be grouped into an admission. A NaN id would
# match no rows in `notes_df["HADM_ID"] == hadm_id` (NaN != NaN), and `int(hadm_id)`
# would raise ValueError when naming the output directory, so drop missing ids first.
# NOTE(review): if the column did contain NaNs, this changes the population and hence
# which admissions the seeded sample picks, compared with earlier runs.
hadm_ids = notes_df["HADM_ID"].dropna().unique().tolist()
# A dedicated Random instance keeps the sample independent of the global random state.
sample_hadm_ids = Random(RANDOM_SEED).sample(hadm_ids, SAMPLE_SIZE)
sample_hadm_ids

In [None]:
# Prompts with at least this many tokens (system + user message content only, excluding
# chat-format overhead) are routed to the 32k-context engine.
# NOTE(review): presumably chosen to leave room for the completion within gpt-4's
# context window — confirm.
GPT_4_MAX_PROMPT_TOKENS = 7000

# The system message depends only on the schema and the example, so build and
# measure it once rather than once per admission.
system_message = generate_rcp_system_message(simplified_rcp_schema, example)
system_prompt_tokens = len(tokenizer.encode(system_message.content))

for hadm_id in tqdm(sample_hadm_ids):
    hadm_id_output_dir = OUTPUT_DIR / str(int(hadm_id))
    hadm_id_output_dir.mkdir(parents=True, exist_ok=True)

    admission_notes = notes_df[notes_df["HADM_ID"] == hadm_id]
    physician_notes = [
        PhysicianNote(
            hadm_id=row["HADM_ID"],
            title=row["DESCRIPTION"],
            timestamp=row["CHARTTIME"],
            text=row["TEXT"],
        )
        for _, row in admission_notes.iterrows()
    ]

    # Remove duplicated notes, then order the remainder by timestamp.
    deduplicated_physician_notes = sorted(
        deduplicate_physician_notes(physician_notes), key=lambda note: note.timestamp
    )
    user_message = generate_rcp_user_message(deduplicated_physician_notes)

    prompt_length = system_prompt_tokens + len(tokenizer.encode(user_message.content))
    llm = gpt_4 if prompt_length < GPT_4_MAX_PROMPT_TOKENS else gpt_4_32k

    # Persist the prompt before the (slow, billable) API call so it survives a failure.
    (hadm_id_output_dir / "prompt.txt").write_text(
        system_message.content + "\n\n" + user_message.content
    )

    response = llm.query([system_message, user_message])

    try:
        discharge_summary = json.loads(response.content)
    except json.JSONDecodeError:
        # Keep the raw completion for inspection instead of discarding it, then fail loudly.
        (hadm_id_output_dir / "discharge_summary_raw.txt").write_text(response.content)
        raise

    # NOTE(review): the parsed output is not validated against rcp_schema here.
    (hadm_id_output_dir / "discharge_summary.json").write_text(
        json.dumps(discharge_summary, indent=4)
    )