In [None]:
import json
import os
import pickle
from pathlib import Path
from typing import List

from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.openai_llm.token_count import (
    num_tokens_from_messages_azure_engine,
)
from discharge_summaries.schemas.mimic import Note, Record

In [None]:
# Load environment variables from a .env file, if one is present. The Azure client
# cell further down reads AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_KEY via os.getenv.
load_dotenv()

In [None]:
# All paths are derived from Path.cwd(), so they depend on the directory the
# kernel was started in: inputs come from ../data, outputs go to ./output.
DATA_DIR = Path.cwd().parent / "data"
OUTPUT_DIR = Path.cwd() / "output"


# Pickled training records loaded at the start of the notebook.
TRAINING_DATASET_PATH = DATA_DIR / "train_all_ds.pkl"
# NOTE(review): RANDOM_SEED is not referenced in the visible cells — confirm it is still needed.
RANDOM_SEED = 23
# Azure OpenAI engine name and API version; both are passed to the chat model
# and to the token counter.
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"
# Alternative engine settings, left commented out for manual switching.
# AZURE_ENGINE = "gpt-35-turbo"
# AZURE_API_VERSION = "2023-07-01-preview"

# JSON schema of the discharge summary guidelines; its contents are embedded in
# the system prompt.
GUIDELINES_JSON_SCHEMA_PATH = (
    Path.cwd().parent / "guidelines" / "eDischarge-Summary-v2.1-1st-Feb-21_schema.json"
)

# NOTE(review): the two PRSB example paths are not referenced in the visible cells — confirm usage.
PRSB_EXAMPLE_FPATH = DATA_DIR / "prsb_example.txt"
PRSB_EXAMPLE_GT_FPATH = DATA_DIR / "prsb_example_gt.txt"

In [None]:
# NOTE(review): pickle can execute arbitrary code while loading; only point
# TRAINING_DATASET_PATH at files from a trusted source.
with TRAINING_DATASET_PATH.open("rb") as pickle_file:
    dataset = [Record(**raw_record) for raw_record in pickle.load(pickle_file)]
# Keep just the first five records for this experiment; the cell displays the count.
dataset = dataset[:5]
len(dataset)

In [None]:
# Use the first loaded record as the worked example for the rest of the notebook.
sample = dataset[0]
# NOTE(review): "physican_notes" is misspelled ("physician"); renaming it also
# requires updating the later cell that passes it to generate_discharge_summary.
physican_notes = sample.physician_notes
# Wrapped in a one-element list so it can be passed to generate_discharge_summary,
# which takes a list of notes (presumably the summary is itself a Note — confirm
# against the Record schema).
discharge_summary = [sample.discharge_summary]

Re-serialise the guidelines JSON schema without indentation to reduce the number of prompt tokens.

In [None]:
# Parse the schema file and dump it again without an `indent` argument, which
# drops the file's indentation and newlines and so shortens the prompt.
# The file is read explicitly as UTF-8 (the standard JSON encoding) instead of
# relying on the platform's default text encoding.
guidelines_json_schema_json = json.loads(
    GUIDELINES_JSON_SCHEMA_PATH.read_text(encoding="utf-8")
)
guidelines_json_schema_str = json.dumps(guidelines_json_schema_json)

In [None]:
# Chat client used for the completions requested in this notebook.
# NOTE(review): os.getenv returns None for an unset variable, so a missing
# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY is passed through as None here —
# confirm how AzureOpenAIChatModel reacts to that.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,  # presumably chosen to keep completions as repeatable as possible — TODO confirm
    timeout=20,  # NOTE(review): unit is not visible here (presumably seconds) — confirm it suits long prompts
)

In [None]:
# System prompt sent with every request: it sets the model's role, restricts it
# to the physician notes supplied by the user, describes the note title format,
# and embeds the guidelines JSON schema that the output must follow.
SYSTEM_MESSAGE = Message(
    role=Role.SYSTEM,
    content=f"""You are a consultant doctor tasked with writing a patients discharge summary.
Only the information in the physician notes provided by the user can be used for this task.
Each physician note has a title of the format Physician Note [number]: [timestamp].

The discharge summary must be written in accordance with the following json schema.
{guidelines_json_schema_str}
If the information is not present to fill in a field, answer it with an empty string.
""",
)

In [None]:
# Display the token count of the system prompt on its own; the rest of the
# prompt token limit enforced in generate_discharge_summary is left for the notes.
num_tokens_from_messages_azure_engine([SYSTEM_MESSAGE], AZURE_ENGINE, AZURE_API_VERSION)

In [None]:
def generate_notes_string(notes: List[Note]):
    """Render physician notes as a single string for the user prompt.

    Each note becomes a block headed ``Physician Note <n>: <datetime>``, numbered
    from 1 in input order, followed by the note's text on the next line. Blocks
    are separated by one blank line. An empty list produces an empty string.

    Args:
        notes: Notes to render; only their ``datetime`` and ``text`` attributes
            are read.

    Returns:
        The concatenated notes as one string.
    """
    rendered_notes = []
    for position, note in enumerate(notes, start=1):
        header = f"Physician Note {position}: {note.datetime}"
        rendered_notes.append(f"{header}\n{note.text}")
    return "\n\n".join(rendered_notes)


def generate_discharge_summary(
    notes: List[Note], max_prompt_tokens=31000
) -> List[Message]:
    """Ask the LLM for a discharge summary built from the given notes.

    The prompt is the shared system message plus one user message containing
    the rendered notes. The prompt's token count is checked before the model is
    queried.

    Args:
        notes: Notes to include in the user message.
        max_prompt_tokens: Upper bound on the prompt's token count
            (presumably sized to leave room for the completion within the
            engine's context window — TODO confirm).

    Returns:
        The prompt messages followed by the result of ``llm.query`` as the
        last element.

    Raises:
        ValueError: If the prompt's token count exceeds ``max_prompt_tokens``.
    """
    user_message = Message(
        role=Role.USER,
        content=(
            "Generate the discharge summary json given the following physician"
            f" notes\n\n{generate_notes_string(notes)}"
        ),
    )
    conversation = [SYSTEM_MESSAGE, user_message]

    # Check the size before querying so an oversized prompt is never sent.
    token_total = num_tokens_from_messages_azure_engine(
        conversation, AZURE_ENGINE, AZURE_API_VERSION
    )
    if token_total > max_prompt_tokens:
        raise ValueError(
            f"Prompt has {token_total} tokens, which is greater than the max of"
            f" {max_prompt_tokens}."
        )

    reply = llm.query(conversation)
    return [*conversation, reply]


def save_output(output: List[Message], output_fpath: Path):
    """Write a generated conversation to ``output_fpath`` as indented JSON.

    The last message is treated as the model completion and its content is
    parsed as JSON. All preceding messages are stored, via their ``dict()``
    representation, as the prompts that produced it.

    Args:
        output: Prompt messages followed by the completion message, in the
            order returned by ``generate_discharge_summary``.
        output_fpath: Destination file. Missing parent directories are created.

    Raises:
        json.JSONDecodeError: If the completion's content is not valid JSON.
        IndexError: If ``output`` is empty.
    """
    completion = json.loads(output[-1].content)
    file_output = {
        "Completion": completion,
        "Prompts": [message.dict() for message in output[:-1]],
    }
    # The target folder may not exist yet (e.g. on a fresh checkout), in which
    # case write_text would fail with FileNotFoundError.
    output_fpath.parent.mkdir(parents=True, exist_ok=True)
    output_fpath.write_text(json.dumps(file_output, indent=4))

In [None]:
# Generate a summary from the sample's physician notes and save the parsed
# completion together with the prompts that produced it.
output = generate_discharge_summary(physican_notes)
save_output(output, OUTPUT_DIR / "0_mimic_example.json")

In [None]:
# Run the same prompt over the sample's existing discharge summary instead of
# the physician notes. The "_gt" suffix suggests the result is used as a
# ground-truth reference for comparison — TODO confirm.
output = generate_discharge_summary(discharge_summary)
save_output(output, OUTPUT_DIR / "0_mimic_example_gt.json")