In [None]:
import os
import pickle
from pathlib import Path
from typing import List

import numpy as np
import tiktoken
from dotenv import load_dotenv
from tqdm.notebook import tqdm

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.schemas.mimic import Note, Record

In [None]:
# Load variables from a .env file into the process environment. The Azure OpenAI
# endpoint and key are later read from the environment via os.getenv.
load_dotenv()

In [None]:
# Paths are resolved relative to the current working directory: input data lives
# in a sibling "data" directory, generated files go to "output" inside the cwd.
DATA_DIR = Path.cwd().parent / "data"
OUTPUT_DIR = Path.cwd() / "output"


# Pickled training records; loaded further down and converted to Record objects.
TRAINING_DATASET_PATH = DATA_DIR / "train_all_ds.pkl"
# NOTE(review): not referenced anywhere in the visible cells — confirm it is needed.
RANDOM_SEED = 23
# Azure engine name and API version. Together they select the tokenizer mapping in
# num_tokens_from_messages_azure_engine and configure the chat model.
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"
# Alternative configuration (GPT-3.5); swap with the two lines above to use it.
# AZURE_ENGINE = "gpt-35-turbo"
# AZURE_API_VERSION = "2023-07-01-preview"

In [None]:
def num_tokens_from_messages_azure_engine(
    messages: List[Message], azure_engine: str, azure_api_version: str
) -> int:
    """Count prompt tokens for messages sent to a given Azure engine/API version.

    The Azure engine name and API version are joined as "<engine>-<version>" and
    translated to the matching OpenAI model name, which is then handed to
    num_tokens_from_messages.

    Args:
        messages: Messages to count tokens for.
        azure_engine: Azure engine name, e.g. "gpt-4-32k".
        azure_api_version: Azure API version, e.g. "2023-07-01-preview".

    Returns:
        The token count reported by num_tokens_from_messages.

    Raises:
        NotImplementedError: If the engine/version pair has no known OpenAI model.
    """
    deployment_key = f"{azure_engine}-{azure_api_version}"
    openai_model_by_deployment = {
        "gpt-35-turbo-2023-07-01-preview": "gpt-3.5-turbo-0613",
        "gpt-4-32k-2023-07-01-preview": "gpt-4-32k-0613",
    }
    try:
        openai_model = openai_model_by_deployment[deployment_key]
    except KeyError:
        raise NotImplementedError(
            f"num_tokens_from_messages() is not implemented for model {deployment_key}."
        )
    return num_tokens_from_messages(messages, model=openai_model)


def num_tokens_from_messages(messages: List[Message], model: str) -> int:
    """Count the tokens a list of chat messages occupies for an OpenAI model.

    The tokenizer is looked up from the model name, falling back to cl100k_base
    (with a printed warning) when tiktoken does not know the model. Unpinned
    "gpt-3.5-turbo" / "gpt-4" names are counted as their -0613 snapshots, also
    with a printed warning.

    Every value of ``message.dict()`` is passed to the tokenizer.
    NOTE(review): this assumes all those values are strings the tokenizer
    accepts — confirm for Message.role and any optional fields.

    Raises:
        NotImplementedError: If the model name matches none of the known families.
    """
    try:
        tokenizer = tiktoken.encoding_for_model(model)
    except KeyError:
        print("Warning: model not found. Using cl100k_base encoding.")
        tokenizer = tiktoken.get_encoding("cl100k_base")

    models_with_0613_style_overhead = {
        "gpt-3.5-turbo-0613",
        "gpt-3.5-turbo-16k-0613",
        "gpt-4-0314",
        "gpt-4-32k-0314",
        "gpt-4-0613",
        "gpt-4-32k-0613",
    }
    if model in models_with_0613_style_overhead:
        per_message_overhead, per_name_overhead = 3, 1
    elif model == "gpt-3.5-turbo-0301":
        # every message follows <|start|>{role/name}\n{content}<|end|>\n, and when
        # a name is present the role is omitted (hence the negative adjustment)
        per_message_overhead, per_name_overhead = 4, -1
    elif "gpt-3.5-turbo" in model:
        print(
            "Warning: gpt-3.5-turbo may update over time. Returning num tokens assuming"
            " gpt-3.5-turbo-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-3.5-turbo-0613")
    elif "gpt-4" in model:
        print(
            "Warning: gpt-4 may update over time. Returning num tokens assuming"
            " gpt-4-0613."
        )
        return num_tokens_from_messages(messages, model="gpt-4-0613")
    else:
        raise NotImplementedError(
            f"""num_tokens_from_messages() is not implemented for model {model}.
            See https://github.com/openai/openai-python/blob/main/chatml.md for information
            on how messages are converted to tokens."""
        )

    # every reply is primed with <|start|>assistant<|message|>
    total_tokens = 3
    for message in messages:
        fields = message.dict()
        total_tokens += per_message_overhead
        total_tokens += sum(len(tokenizer.encode(text)) for text in fields.values())
        if "name" in fields:
            total_tokens += per_name_overhead
    return total_tokens

In [None]:
# Load the pickled training set and turn every raw record (a mapping of keyword
# arguments) into a Record.
# NOTE(security): pickle.load can execute arbitrary code while unpickling — only
# point TRAINING_DATASET_PATH at files from a trusted source.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
len(dataset)

## Token lengths

In [None]:
# Token count of each reference brief hospital course (BHC). Every BHC is wrapped
# in a single assistant message so the per-message overhead is included.
bhc_token_lengths = np.array(
    [
        num_tokens_from_messages_azure_engine(
            messages=[Message(role=Role.ASSISTANT, content=rec.discharge_summary.bhc)],
            azure_engine=AZURE_ENGINE,
            azure_api_version=AZURE_API_VERSION,
        )
        for rec in dataset
    ]
)
# Displayed as (mean, median, standard deviation, minimum, maximum).
(
    np.mean(bhc_token_lengths),
    np.median(bhc_token_lengths),
    np.std(bhc_token_lengths),
    np.min(bhc_token_lengths),
    np.max(bhc_token_lengths),
)

In [None]:
# Displayed as (median, minimum, maximum, 95th percentile) of the BHC token counts.
# Reuses `bhc_token_lengths` computed in the previous cell rather than
# re-tokenising the entire dataset to rebuild an identical array.
(
    np.median(bhc_token_lengths),
    np.min(bhc_token_lengths),
    np.max(bhc_token_lengths),
    np.percentile(bhc_token_lengths, 95),
)

In [None]:
# Per-record list of token counts, one entry per physician note. Each note is
# counted as its own user message.
note_token_lengths = [
    [
        num_tokens_from_messages_azure_engine(
            messages=[Message(role=Role.USER, content=physician_note.text)],
            azure_engine=AZURE_ENGINE,
            azure_api_version=AZURE_API_VERSION,
        )
        for physician_note in rec.physician_notes
    ]
    for rec in tqdm(dataset)
]

In [None]:
# Flatten the per-record lists into one array holding the token count of every
# individual note in the dataset.
single_note_token_lengths = np.array(
    [
        token_count
        for per_record_counts in note_token_lengths
        for token_count in per_record_counts
    ]
)
# Displayed as (median, minimum, maximum, 95th percentile).
(
    np.median(single_note_token_lengths),
    np.min(single_note_token_lengths),
    np.max(single_note_token_lengths),
    np.percentile(single_note_token_lengths, 95),
)

In [None]:
# Total note tokens per record (sum over that record's individual notes).
combined_note_token_lengths = np.array(list(map(sum, note_token_lengths)))
# Displayed as (median, minimum, maximum, 95th percentile).
(
    np.median(combined_note_token_lengths),
    np.min(combined_note_token_lengths),
    np.max(combined_note_token_lengths),
    np.percentile(combined_note_token_lengths, 95),
)

In [None]:
# Number of physician notes attached to each record.
num_notes = np.array([len(rec.physician_notes) for rec in tqdm(dataset)])
# Displayed as (median, minimum, maximum, 95th percentile).
(
    np.median(num_notes),
    np.min(num_notes),
    np.max(num_notes),
    np.percentile(num_notes, 95),
)

In [None]:
# NOTE(review): despite the name, no filtering happens here — this is an alias of
# the full dataset, and it is not referenced in the visible cells. Confirm whether
# a 95th-percentile filter was intended or whether this can be removed.
dataset_filtered_95 = dataset

In [None]:
# Fraction of records whose combined note tokens stay strictly below 31000.
int(np.count_nonzero(combined_note_token_lengths < 31000)) / len(
    combined_note_token_lengths
)

In [None]:
# Back-of-envelope cost of a single generation.
# NOTE(review): presumably 31K prompt tokens and 1K completion tokens, priced per
# 1K tokens at 0.047 and 0.094 respectively — confirm the rates and the currency.
rough_cost_per_generation = 31 * 0.047 + 1 * 0.094
rough_cost_per_generation

## LLM Generation 32K

In [None]:
# Chat client for the configured Azure engine, with temperature fixed at 0.
# NOTE(review): os.getenv returns None when a variable is unset, so a missing
# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY is not detected here — confirm how
# AzureOpenAIChatModel treats None.
# NOTE(review): timeout=20 is presumably in seconds — confirm it is long enough
# for prompts close to the 32K context size.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,
)

In [None]:
# System prompt shared by every generation. The brief hospital course of
# dataset[0] is embedded as the single worked example, so that record should not
# itself be used as a generation target.
# NOTE(review): the prompt reads "a patients discharge summary" (missing
# apostrophe); the text is left untouched because it is sent to the model.
SYSTEM_MESSAGE = Message(
    role=Role.SYSTEM,
    content=f"""You are a consultant doctor.
You are tasked with writing the brief hospital course summary section of a patients discharge summary.
To aid you in this task the user provides you with physician notes from the patient's visit.
Each note has a title of the format Physician Note [number]: [timestamp].
The summary should be roughly 500 words long.
You can only use the information in the notes to write the summary.

This is an example of a brief hospital course summary:
{dataset[0].discharge_summary.bhc}
""",
)

In [None]:
def generate_notes_string(notes: List[Note]):
    """Render physician notes as one prompt-ready string.

    Each note becomes a block headed "Physician Note <n>: <note.datetime>", where
    n counts from 1 in the order given, followed by the note text on the next
    line. Blocks are separated by a blank line; an empty list yields "".
    """
    note_blocks = []
    for position, note in enumerate(notes, start=1):
        header = f"Physician Note {position}: {note.datetime}"
        note_blocks.append(f"{header}\n{note.text}")
    return "\n\n".join(note_blocks)

In [None]:
# Generate brief hospital course summaries for the first few records that fit in
# the context window. dataset[0] is skipped because its BHC is embedded in
# SYSTEM_MESSAGE as the worked example. Each entry of `outputs` is
# (index into dataset, [system message, user message, model response]).
outputs = []
for dataset_idx, record in enumerate(dataset[1:], start=1):
    note_string = generate_notes_string(record.physician_notes)
    # Must be an f-string: with a plain string the literal text "{note_string}"
    # is sent to the model instead of the physician notes.
    # NOTE(review): the prompt contains the typo "hopsital"; it is left unchanged
    # so that this edit only fixes the interpolation.
    user_message_content = f"""Write a brief hopsital course summary from the following physician notes.

{note_string}"""
    user_message = Message(role=Role.USER, content=user_message_content)
    num_user_tokens = num_tokens_from_messages_azure_engine(
        [user_message], AZURE_ENGINE, AZURE_API_VERSION
    )
    # Only the user message is counted here; the system message and the reply
    # have to fit in whatever remains of the context window.
    if num_user_tokens > 31000:
        print(f"Skipping record {dataset_idx} because it has {num_user_tokens} tokens.")
        continue
    response = llm.query([SYSTEM_MESSAGE, user_message])
    outputs.append((dataset_idx, [SYSTEM_MESSAGE, user_message, response]))
    # Stop after four successful generations.
    if len(outputs) > 3:
        break

In [None]:
# Write one text file per generation containing the model-written BHC, the
# human-written BHC for the same record and the prompt that produced the output.
# Create the directory first: Path.write_text raises if the parent is missing.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
for idx, messages in outputs:
    # messages is [system message, user message, model response].
    prompts = "\n\n*****\n\n".join(
        [f"{message.role}:\n{message.content}" for message in messages[:2]]
    )
    file_output = (
        f"GPT BHC:\n{messages[-1].content}\n\n*****\n\nHuman"
        f" BHC:\n{dataset[idx].discharge_summary.bhc}\n\n*****\n\nPrompt:{prompts}"
    )
    # Explicit UTF-8 so non-ASCII characters in clinical text do not depend on
    # the platform's default encoding.
    (OUTPUT_DIR / f"{idx}.txt").write_text(file_output, encoding="utf-8")