In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
from copy import deepcopy
from pathlib import Path
from random import Random
from typing import Dict, List, Set, Union

import jsonschema
import pandas as pd
import tiktoken
from dotenv import load_dotenv

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.prompts import (
    generate_rcp_system_message,
    generate_rcp_user_message,
)
from discharge_summaries.schemas.mimic import PhysicianNote
from discharge_summaries.schemas.rcp_guidelines import RCPGuidelines

In [None]:
# Load environment variables from a local .env file (python-dotenv).
# Presumably this supplies AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY, which are
# read with os.getenv when the chat model is constructed — confirm your .env.
load_dotenv()

In [None]:
# --- Azure OpenAI / tokenizer settings ---
AZURE_ENGINE = "gpt-4"
AZURE_API_VERSION = "2023-07-01-preview"
TOKENIZER_NAME = "cl100k_base"

# --- Sampling of admissions ---
SAMPLE_SIZE = 5
RANDOM_SEED = 23

# --- Paths, all resolved relative to the current working directory ---
NOTEBOOK_DIR = Path.cwd()
DATA_DIR = NOTEBOOK_DIR.parent / "data"

EXAMPLE_DIR = NOTEBOOK_DIR / "examples"
OUTPUT_DIR = NOTEBOOK_DIR / "output"

SNOMED_DIR = DATA_DIR / "snomed"
PHRASE_MATCHER_FPATH = SNOMED_DIR / "snomed_phrase_matcher_full.pkl"

MIMIC_III_DIR = DATA_DIR.joinpath("physionet.org", "files", "mimiciii", "1.4")
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR / "physician_notes_mimic.csv"

In [None]:
# JSON schema produced by RCPGuidelines.schema(); used here to validate the
# worked example and, in the next cell, as the basis for the simplified schema.
rcp_schema = RCPGuidelines.schema()

# Read the example explicitly as UTF-8: Path.read_text() otherwise decodes with
# the platform's locale encoding, which can mis-read non-ASCII JSON text.
example = json.loads((EXAMPLE_DIR / "example.json").read_text(encoding="utf-8"))

# Raises jsonschema.ValidationError if the example does not conform to the schema.
jsonschema.validate(example, rcp_schema)

In [None]:
def remove_keys_recursive(d: Union[List, Dict], keys: Set[str]):
    """Delete every dict entry whose key is in ``keys``, at any nesting depth.

    The structure is modified in place and the same object is returned.
    Lists are walked element by element; dict values under surviving keys are
    walked recursively. Anything that is neither a dict nor a list is returned
    untouched.

    Args:
        d: Nested combination of dicts and lists (e.g. a parsed JSON document).
        keys: Key names to remove wherever they occur as dict keys.

    Returns:
        The same object ``d`` that was passed in, after pruning.
    """
    if isinstance(d, list):
        for element in d:
            remove_keys_recursive(element, keys)
        return d

    if not isinstance(d, dict):
        return d

    # Collect the matches first so the dict is not resized while iterating it.
    for unwanted in [name for name in d if name in keys]:
        del d[unwanted]

    # Only the surviving entries remain, so recurse into each of them.
    for value in d.values():
        remove_keys_recursive(value, keys)
    return d


# Strip the "title" and "required" entries at every nesting level. The pruning
# is done on a deep copy so that rcp_schema itself is left intact.
schema_copy = deepcopy(rcp_schema)
simplified_rcp_schema = remove_keys_recursive(schema_copy, {"title", "required"})
simplified_rcp_schema

## Load MIMIC

In [None]:
# Read the physician-note CSV into a DataFrame. Later cells rely on its
# HADM_ID, TEXT and CHARTTIME columns.
notes_df = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# Draw a reproducible sample of admission ids: a dedicated Random instance
# seeded with RANDOM_SEED picks SAMPLE_SIZE distinct ids.
# NOTE(review): Series.unique() keeps NaN as a value, so a missing HADM_ID
# could be sampled — confirm the CSV has no null admission ids.
rng = Random(RANDOM_SEED)
hadm_ids = notes_df["HADM_ID"].unique().tolist()
sample_hadm_ids = rng.sample(hadm_ids, k=SAMPLE_SIZE)
sample_hadm_ids

In [None]:
# Work with the first sampled admission and wrap each of its note rows in a
# PhysicianNote (text, admission id, chart time), preserving DataFrame row order.
sample_hadm_id = sample_hadm_ids[0]
is_sample_admission = notes_df["HADM_ID"] == sample_hadm_id

physician_notes = []
for _, note_row in notes_df[is_sample_admission].iterrows():
    physician_notes.append(
        PhysicianNote(
            text=note_row["TEXT"],
            hadm_id=note_row["HADM_ID"],
            timestamp=note_row["CHARTTIME"],
        )
    )
physician_notes

In [None]:
def deduplicate_note_lines(notes: List[PhysicianNote]) -> List[PhysicianNote]:
    """Drop blank lines and any line already seen in this or an earlier note.

    Notes are processed in the order given, and a line is kept only the first
    time its exact text appears anywhere in the list. Lines are delimited by
    ``"\\n"``. A note that ends up with no lines is left out of the result;
    every other note is emitted via ``note.copy(update=...)`` with its ``text``
    replaced by the surviving lines.

    Args:
        notes: Notes to filter, in the order that decides which occurrence of
            a repeated line is kept.

    Returns:
        The filtered notes, in the same relative order as the input.
    """
    emitted = set()
    result = []
    for note in notes:
        fresh_lines = []
        for candidate in note.text.split("\n"):
            # Skip empty lines and exact repeats of anything emitted so far.
            if not candidate or candidate in emitted:
                continue
            emitted.add(candidate)
            fresh_lines.append(candidate)
        if not fresh_lines:
            continue
        result.append(note.copy(update={"text": "\n".join(fresh_lines)}))
    return result


# Sort chronologically *before* de-duplicating. deduplicate_note_lines keeps
# the first occurrence of each line in input order, so sorting afterwards would
# let DataFrame row order (not chart time) decide which note retains a repeated
# line. The function preserves input order, so the result is still sorted.
chronological_notes = sorted(physician_notes, key=lambda note: note.timestamp)
deduplicated_notes = deduplicate_note_lines(chronological_notes)

# Number of notes before and after (notes left with no lines are dropped).
len(physician_notes), len(deduplicated_notes)

## Prompting

In [None]:
# Build the system prompt from the simplified schema and the schema-validated
# example, then print its text for inspection.
system_message = generate_rcp_system_message(simplified_rcp_schema, example)
print(system_message.content)

In [None]:
# Build the user prompt from the de-duplicated, timestamp-sorted notes rather
# than the raw physician_notes: the raw list still contains the repeated lines,
# and deduplicated_notes was otherwise computed but never used.
user_message = generate_rcp_user_message(deduplicated_notes)
print(user_message.content)

In [None]:
# Report the token length of each prompt (system first, then user) under the
# configured tiktoken encoding.
tokenizer = tiktoken.get_encoding(TOKENIZER_NAME)
for message in (system_message, user_message):
    token_count = len(tokenizer.encode(message.content))
    print(token_count)

In [None]:
# Endpoint and key come from the environment; os.getenv yields None when a
# variable is unset, and that None is passed straight through to the client.
azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_key = os.getenv("AZURE_OPENAI_KEY")

# temperature=0 and timeout=20 are handed to the project wrapper as-is
# (timeout presumably in seconds — confirm in AzureOpenAIChatModel).
llm = AzureOpenAIChatModel(
    engine=AZURE_ENGINE,
    api_version=AZURE_API_VERSION,
    api_base=azure_endpoint,
    api_key=azure_key,
    temperature=0,
    timeout=20,
)

In [None]:
# Query the chat model with the system prompt followed by the user prompt.
# The reply's .content is parsed as JSON when the result is saved.
response = llm.query([system_message, user_message])

In [None]:
# Make sure the output directory exists; Path.write_text raises
# FileNotFoundError if the parent directory is missing (e.g. on a fresh clone).
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Parse first so a non-JSON reply raises before any file is created, then
# re-serialise with indentation for readability.
response_json = json.loads(response.content)
output_fpath = OUTPUT_DIR / f"mimic_{int(sample_hadm_id)}.json"

# Last expression: write_text returns the number of characters written.
output_fpath.write_text(json.dumps(response_json, indent=4), encoding="utf-8")

In [None]:
# Display the output file name for the sampled admission (id cast to int).
f"mimic_{int(sample_hadm_id)}.json"