In [None]:
import json
import os
import time
from copy import deepcopy
from pathlib import Path
from random import Random
from typing import Dict, List, Set, Union

import jsonschema
import pandas as pd
import tiktoken
from dotenv import load_dotenv
from openai.error import RateLimitError
from tqdm.notebook import tqdm

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.openai_llm.prompts import (
    generate_rcp_system_message,
    generate_rcp_user_message,
)
from discharge_summaries.schemas.mimic import PhysicianNote
from discharge_summaries.schemas.rcp_guidelines import RCPGuidelines

In [None]:
# Load variables from a local .env file into the process environment; the
# Azure OpenAI endpoint and key are later read from it via os.getenv.
load_dotenv()

In [None]:
# Value passed as `engine` when constructing AzureOpenAIChatModel.
GPT_4_ENGINE = "gpt-4-turbo"

# Value passed as `api_version` when constructing AzureOpenAIChatModel.
AZURE_API_VERSION = "2023-07-01-preview"
# Name of the tiktoken encoding used to count prompt tokens.
TOKENIZER_NAME = "cl100k_base"
# Directory holding the one-shot example (input notes + expected JSON response).
EXAMPLE_DIR = Path.cwd() / "examples"
# Root of the generated output; one sub-directory is created per admission id.
OUTPUT_DIR = Path.cwd() / "output_eval_v3"

# NOTE(review): SNOMED_DIR / PHRASE_MATCHER_FPATH are not referenced in the
# visible part of this notebook — confirm whether they are still needed.
SNOMED_DIR = Path.cwd().parent / "data" / "snomed"
PHRASE_MATCHER_FPATH = SNOMED_DIR / "snomed_phrase_matcher_full.pkl"

# Paths are resolved relative to the current working directory, so the
# notebook must be started from its own folder for them to be valid.
MIMIC_III_DIR = (
    Path.cwd().parent / "data" / "physionet.org" / "files" / "mimiciii" / "1.4"
)
# CSV of physician notes loaded into `notes_df`.
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR / "physician_notes_mimic.csv"

# Number of admission ids drawn for the train sample.
TRAIN_SAMPLE_SIZE = 5
# Number of admission ids intended for the evaluation sample.
EVAL_SAMPLE_SIZE = 68
# Seed for random.Random so the sampled admission ids are reproducible.
RANDOM_SEED = 23

In [None]:
# Schema dict produced from the RCPGuidelines model; it is used to validate the
# example response with jsonschema and, in simplified form, in the system prompt.
rcp_schema = RCPGuidelines.schema()

In [None]:
# One-shot example: the input notes and the expected structured response.
example_notes_path = EXAMPLE_DIR / "example_1_notes.json"
example_response_path = EXAMPLE_DIR / "example_1.json"

example_notes = [
    PhysicianNote(**note_fields)
    for note_fields in json.loads(example_notes_path.read_text())
]
example_response = json.loads(example_response_path.read_text())

# Fail early if the hand-written example does not conform to the schema.
jsonschema.validate(instance=example_response, schema=rcp_schema)

In [None]:
def remove_keys_recursive(d: Union[List, Dict], keys: Set[str]):
    """Strip every occurrence of the given keys from a nested dict/list structure.

    The structure is modified in place and also returned for convenience.

    Args:
        d: A dict or list, possibly nested. Any other value is returned as is.
        keys: Dictionary keys to delete wherever they appear.

    Returns:
        The same object that was passed in, after the deletions.
    """
    if isinstance(d, list):
        for element in d:
            remove_keys_recursive(element, keys)
        return d

    if not isinstance(d, dict):
        # Scalars (and any other type) have nothing to strip.
        return d

    # Drop the unwanted keys at this level first ...
    for unwanted in [name for name in d if name in keys]:
        del d[unwanted]
    # ... then descend into whatever is left.
    for child in d.values():
        remove_keys_recursive(child, keys)
    return d


def remove_default_definition_properties(schema: Dict):
    """Drop every property that declares a ``default`` from the schema's definitions.

    The schema is modified in place and also returned. Definitions that have no
    ``"properties"`` entry (for example enum definitions) are left untouched, and
    a schema without a ``"definitions"`` section is returned unchanged.

    Args:
        schema: A JSON-schema dict, optionally containing a ``"definitions"``
            mapping of section name to section schema.

    Returns:
        The same ``schema`` object, with defaulted properties removed.
    """
    for section_dict in schema.get("definitions", {}).values():
        properties = section_dict.get("properties")
        if properties is None:
            # Nothing to filter (e.g. an enum definition).
            continue
        section_dict["properties"] = {
            name: spec for name, spec in properties.items() if "default" not in spec
        }
    return schema


# Build the prompt-facing schema on a deep copy so `rcp_schema` itself stays
# intact: strip "title"/"required" everywhere, then drop defaulted properties.
simplified_rcp_schema = remove_default_definition_properties(
    remove_keys_recursive(deepcopy(rcp_schema), keys={"title", "required"})
)
simplified_rcp_schema

In [None]:
# Fixed part of every prompt: the system message built from the simplified
# schema, followed by one worked example (user notes -> assistant JSON reply).
system_message = generate_rcp_system_message(simplified_rcp_schema)
one_shot_user_message = generate_rcp_user_message(example_notes)
# The example answer is serialised to a JSON string and sent as an assistant turn.
one_shot_response_message = Message(
    role=Role.ASSISTANT, content=json.dumps(example_response)
)

In [None]:
# Used further down to count the tokens in each prompt.
tokenizer = tiktoken.get_encoding(TOKENIZER_NAME)

# Endpoint and key are read from the environment; os.getenv returns None when
# a variable is unset.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=GPT_4_ENGINE,
    temperature=0,
    timeout=20,  # presumably seconds — confirm against AzureOpenAIChatModel
    max_retries=0,  # rate-limit retries are handled manually in the generation loop
)

In [None]:
# Physician notes table; the HADM_ID, DESCRIPTION, CHARTTIME and TEXT columns
# are read from it later in the notebook.
notes_df = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# All distinct admission ids present in the notes table.
hadm_ids = notes_df["HADM_ID"].unique().tolist()
train_hadm_ids = Random(RANDOM_SEED).sample(hadm_ids, TRAIN_SAMPLE_SIZE)
# Draw TRAIN + EVAL ids with the same seed and discard the leading TRAIN ids,
# which are expected to coincide with `train_hadm_ids`. The size comes from
# EVAL_SAMPLE_SIZE so that changing the constant actually changes the sample.
eval_hadm_ids = Random(RANDOM_SEED).sample(
    hadm_ids, TRAIN_SAMPLE_SIZE + EVAL_SAMPLE_SIZE
)[TRAIN_SAMPLE_SIZE:]
# Sanity check: this should display an empty set (no train/eval overlap).
set(train_hadm_ids).intersection(set(eval_hadm_ids))

In [None]:
# For every evaluation admission: build the prompt from its physician notes,
# query the LLM, and persist both the raw exchange and the parsed JSON summary.
for hadm_id in tqdm(eval_hadm_ids):
    # One output directory per admission id.
    hadm_id_output_dir = OUTPUT_DIR / str(int(hadm_id))
    hadm_id_output_dir.mkdir(parents=True, exist_ok=True)

    physician_notes = [
        PhysicianNote(
            hadm_id=row["HADM_ID"],
            title=row["DESCRIPTION"],
            timestamp=row["CHARTTIME"],
            text=row["TEXT"],
        )
        for _, row in notes_df[notes_df["HADM_ID"] == hadm_id].iterrows()
    ]

    user_message = generate_rcp_user_message(physician_notes)
    prompt = [
        system_message,
        one_shot_user_message,
        one_shot_response_message,
        user_message,
    ]

    # Token count of the message contents only, recorded for logging.
    prompt_length = sum(len(tokenizer.encode(message.content)) for message in prompt)

    # Retry indefinitely on rate limiting; any other error propagates.
    generation_complete = False
    while not generation_complete:
        try:
            # perf_counter is monotonic, so the measured latency is not
            # affected by system clock adjustments.
            t0 = time.perf_counter()
            response = llm.query(prompt)
            time_taken = time.perf_counter() - t0
            generation_complete = True
        except RateLimitError:
            print("Rate limit exceeded")
            time.sleep(20)

    raw_messages = ("\n" + "*" * 80 + "\n").join(
        [message.content for message in prompt]
        + [
            response.content,
            f"Time taken: {time_taken}",
            f"Prompt length: {prompt_length}",
        ]
    )
    # Explicit UTF-8 so non-ASCII characters in the notes or the model output
    # are written correctly regardless of the platform's default encoding.
    (hadm_id_output_dir / "raw_messages.txt").write_text(
        raw_messages, encoding="utf-8"
    )

    # Handle prefix and suffixes e.g. '''json...'''
    # The raw exchange is already on disk at this point, so a response that
    # fails to parse below can still be inspected afterwards.
    json_start = response.content.find("{")
    json_end = response.content.rfind("}")
    (hadm_id_output_dir / "discharge_summary.json").write_text(
        json.dumps(json.loads(response.content[json_start : json_end + 1]), indent=4),
        encoding="utf-8",
    )

In [None]:
# Debugging aid: show the prompt left over from the last loop iteration.
prompt

In [None]:
# Debugging aid: show the raw model output from the last loop iteration.
response.content

In [None]:
# Debugging aid: duration of the last successful LLM query.
time_taken