In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from copy import deepcopy
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Set, Union

import jsonschema
import pandas as pd
import tiktoken
from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.prompts import (
    generate_rcp_system_message,
    generate_rcp_user_message,
)
from discharge_summaries.schemas.mimic import PhysicianNote
from discharge_summaries.schemas.rcp_guidelines import RCPGuidelines

In [None]:
# Load variables from a .env file into the process environment; the Azure
# OpenAI endpoint and key are read with os.getenv when the chat model is built.
load_dotenv()

In [None]:
# Azure OpenAI settings handed to AzureOpenAIChatModel (engine / api_version).
AZURE_ENGINE = "gpt-4"
AZURE_API_VERSION = "2023-07-01-preview"
# Name of the tiktoken encoding used to count prompt tokens.
TOKENIZER_NAME = "cl100k_base"
# Both directories are resolved against the current working directory.
EXAMPLE_DIR = Path.cwd() / "examples"  # holds example.json
OUTPUT_DIR = Path.cwd() / "output"  # destination for the generated JSON

In [None]:
# Load the worked example and check it against the JSON schema generated from
# RCPGuidelines, so a malformed example is caught before it reaches a prompt.
example_path = EXAMPLE_DIR / "example.json"
example = json.loads(example_path.read_text())
rcp_schema = RCPGuidelines.schema()
jsonschema.validate(instance=example, schema=rcp_schema)

In [None]:
def remove_keys_recursive(d: Union[List, Dict], keys: Set[str]):
    """Delete every dict entry whose key is in ``keys``, at any nesting depth.

    The structure is modified in place and the same object is returned.
    Dicts and lists are walked recursively; any other value is returned
    untouched. A removed entry is dropped together with its whole subtree.

    Args:
        d: Nested combination of dicts and lists to clean.
        keys: Key names to delete wherever they occur in a dict.

    Returns:
        ``d`` itself, after the matching entries have been removed.
    """
    if isinstance(d, dict):
        # Collect the matches first so the dict is not resized mid-iteration.
        for unwanted in [name for name in d if name in keys]:
            del d[unwanted]
        children = d.values()
    elif isinstance(d, list):
        children = d
    else:
        # Scalars (and any other type) have nothing to descend into.
        return d

    for child in children:
        remove_keys_recursive(child, keys)
    return d


# Work on a deep copy so rcp_schema stays intact, then strip the "title" and
# "required" keys at every level (remove_keys_recursive edits in place).
simplified_rcp_schema = deepcopy(rcp_schema)
remove_keys_recursive(simplified_rcp_schema, {"title", "required"})
simplified_rcp_schema

## RCP Example

In [None]:
# The RCP practice workbook lives in ../data/rcp relative to the working directory.
rcp_workbook_path = (
    Path.cwd().parent
    / "data"
    / "rcp"
    / "5. Activity-practice discharge summary writing task_0.xlsx"
)

# header=4 takes the column labels from the fifth row of the "Notes" sheet; the
# two unnamed columns are then given meaningful names.
notes_df = pd.read_excel(rcp_workbook_path, sheet_name="Notes", header=4).rename(
    columns={"Unnamed: 0": "timestamp", "Unnamed: 1": "text"}
)
notes_df.head()

In [None]:
# Notes are separated in the sheet by two consecutive fully blank rows.
blank_rows = notes_df.isnull().all(axis=1)
# fill_value=False keeps the shifted mask boolean: the last row has no
# successor, so it can never be the first row of a blank pair.
consecutive_blank_rows = blank_rows & blank_rows.shift(-1, fill_value=False)

# Index labels are used as positions for iloc below, which relies on notes_df
# having the default 0..n-1 index.
segment_ends = list(consecutive_blank_rows[consecutive_blank_rows].index)
segment_ends.append(len(notes_df))  # the last note runs to the end of the sheet

split_dfs = []
start_index = 0
for end_index in segment_ends:
    segment = notes_df.iloc[start_index:end_index]
    # Runs of three or more blank rows, or a trailing blank pair, produce
    # empty/all-blank segments; skip them so every entry holds a real note.
    if not segment.dropna(how="all").empty:
        split_dfs.append(segment)
    # Skip over the two blank separator rows.
    start_index = end_index + 2

In [None]:
# Turn each block of rows into a single PhysicianNote.
notes = []

for note_rows in split_dfs:
    # The first timestamp cell of the block dates the whole note. Its leading
    # space-separated token (presumably the weekday name) is dropped before
    # parsing.
    raw_stamp = note_rows["timestamp"].iloc[0]
    stamp_without_day = raw_stamp.split(" ", 1)[1]
    parsed_stamp = datetime.strptime(stamp_without_day, "%d %b %Y %H:%M")

    # Join the non-empty text cells of the block, one per line.
    note_text = "\n".join(note_rows["text"].dropna())

    notes.append(
        PhysicianNote(
            timestamp=parsed_stamp.strftime("%Y-%m-%d %H:%M"),
            text=note_text,
            hadm_id="0",
        )
    )

In [None]:
def deduplicate_note_lines(notes: List[PhysicianNote]) -> List[PhysicianNote]:
    """Drop empty lines and lines already seen in this or an earlier note.

    Notes are processed in the order given, so the first occurrence of a line
    is the one that survives. A note left with no lines is omitted from the
    result. The input notes are not modified; each surviving note is a copy
    with its ``text`` replaced.

    Args:
        notes: Notes to clean, in the order in which lines should be kept.

    Returns:
        Copies of the notes that still contain at least one line.
    """
    already_seen = set()
    cleaned_notes = []
    for note in notes:
        kept_lines = []
        for line in note.text.split("\n"):
            if line == "" or line in already_seen:
                continue
            already_seen.add(line)
            kept_lines.append(line)
        if not kept_lines:
            # Everything in this note was blank or repeated elsewhere.
            continue
        cleaned_notes.append(note.copy(update={"text": "\n".join(kept_lines)}))
    return cleaned_notes


# Deduplicate in the original note order, then sort the survivors by timestamp.
deduplicated_notes = sorted(
    deduplicate_note_lines(notes), key=lambda note: note.timestamp
)
# Note counts before and after deduplication.
len(notes), len(deduplicated_notes)

## Prompting

In [None]:
# Build the system prompt from the simplified schema and the worked example,
# then print it for inspection.
system_message = generate_rcp_system_message(simplified_rcp_schema, example)
print(system_message.content)

In [None]:
# Build the user prompt from the deduplicated, timestamp-sorted notes. The raw
# `notes` list still contains the repeated lines that the deduplication step
# removes, and passing it here left `deduplicated_notes` unused.
user_message = generate_rcp_user_message(deduplicated_notes)
print(user_message.content)

In [None]:
# Print the token count of each prompt message (system first, then user).
tokenizer = tiktoken.get_encoding(TOKENIZER_NAME)
for prompt_message in (system_message, user_message):
    token_ids = tokenizer.encode(prompt_message.content)
    print(len(token_ids))

In [None]:
# Chat model client. The endpoint and key come from the environment; os.getenv
# returns None when a variable is missing, so a missing .env entry is not
# reported on this line.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,  # presumably seconds -- confirm against AzureOpenAIChatModel
)

In [None]:
# Send the system and user messages to the model; the reply's `.content` is
# parsed as JSON when the result is saved.
response = llm.query([system_message, user_message])

In [None]:
# Create the output directory on first run; write_text does not create missing
# parent directories and would raise FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Round-trip through json.loads so a reply that is not valid JSON raises here
# instead of being saved, and the stored file is pretty-printed.
(OUTPUT_DIR / "rcp_example_1.json").write_text(
    json.dumps(json.loads(response.content), indent=4)
)