# MIMIC Generation

This notebook ingests MIMIC physician notes and the Royal College of Physicians (London) guidelines.

These are converted into a prompt, which is sent to GPT-4-turbo.

The simplified json schema used in the prompt is saved to file.

The outputs are then saved to the `outputs/llm_responses` directory. For each admission, the output is the parsed JSON response and the message history with the LLM.

In [None]:
import json
import math
import os
import time
from copy import deepcopy
from pathlib import Path
from random import Random

import jsonschema
import pandas as pd
from dotenv import load_dotenv
from openai.error import RateLimitError
from tqdm.notebook import tqdm

from llm_discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from llm_discharge_summaries.openai_llm.message import Message, Role
from llm_discharge_summaries.openai_llm.prompts import (
    generate_rcp_system_message,
    generate_rcp_user_message,
)
from llm_discharge_summaries.schemas.mimic import PhysicianNote
from llm_discharge_summaries.schemas.rcp_guidelines import RCPGuidelines

In [None]:
# Input and output locations, resolved relative to the current working directory.
MIMIC_III_DIR = Path.cwd().joinpath(
    "inputs", "physionet.org", "files", "mimiciii", "1.4"
)
PHYSICIAN_NOTE_FPATH = MIMIC_III_DIR.joinpath("physician_notes_mimic.csv")
ONE_SHOT_EXAMPLE_DIR = Path.cwd().parent.joinpath(
    "llm_discharge_summaries", "schemas", "rcp_one_shot_example"
)
OUTPUT_DIR = Path.cwd().joinpath("outputs", "llm_responses")

# Model and API settings.
GPT_4_ENGINE = "gpt-4-turbo"
AZURE_API_VERSION = "2023-07-01-preview"
TOKENIZER_NAME = "cl100k_base"

# Parameters controlling how many admissions are sampled for evaluation.
NUMBER_CLINICAL_EVALUATORS = 15
NUM_EXAMPLES_PER_EVALUATOR = 5
RANDOM_SEED = 23

In [None]:
# Create the output directory together with any missing parent ("outputs/"),
# so a fresh checkout does not fail with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Load variables from a .env file so that the os.getenv lookups used when
# constructing the LLM client can find them.
load_dotenv()

## Pre-process schema

In [None]:
# Full JSON schema of the RCP guidelines model. The unmodified schema is kept
# because it is later used to validate the one-shot example response.
rcp_schema = RCPGuidelines.schema()

The `title` and `required` keys are removed because they contain redundant information.

Definition properties with a default value are also removed.

In [None]:
def remove_keys_recursive(d: list | dict, keys: set[str]):
    if isinstance(d, dict):
        for key in list(d.keys()):
            if key in keys:
                del d[key]
            else:
                remove_keys_recursive(d[key], keys)
    elif isinstance(d, list):
        for item in d:
            remove_keys_recursive(item, keys)
    return d


def remove_default_definition_properties(schema: dict):
    """Drop definition properties that declare a default value.

    For every entry under ``schema["definitions"]`` that has a
    ``"properties"`` mapping, that mapping is replaced by a copy containing
    only the properties without a ``"default"`` key. Definitions that have no
    ``"properties"`` mapping (for instance enum-style entries) are left
    unchanged, as is a schema that has no ``"definitions"`` section.

    Args:
        schema: JSON schema dictionary. It is modified in place.

    Returns:
        The same ``schema`` object that was passed in.
    """
    for section_dict in schema.get("definitions", {}).values():
        properties = section_dict.get("properties")
        if properties is None:
            # Nothing to filter; do not add an empty "properties" key either.
            continue
        section_dict["properties"] = {
            name: spec for name, spec in properties.items() if "default" not in spec
        }
    return schema


# Build the prompt-facing schema from a deep copy, so that rcp_schema itself
# stays intact: first strip the "title"/"required" keys, then drop the
# definition properties that carry a default.
simplified_rcp_schema = remove_default_definition_properties(
    remove_keys_recursive(deepcopy(rcp_schema), {"title", "required"})
)
# Save the simplified schema next to the LLM responses.
OUTPUT_DIR.joinpath("simplified_rcp_schema.json").write_text(
    json.dumps(simplified_rcp_schema, indent=4)
)

## Load 1 shot example

In [None]:
# The one-shot example files are JSON, so they are read explicitly as UTF-8
# instead of relying on the platform's default text encoding.
example_notes = [
    PhysicianNote(**note)
    for note in json.loads(
        (ONE_SHOT_EXAMPLE_DIR / "physician_notes.json").read_text(encoding="utf-8")
    )
]
example_response = json.loads(
    (ONE_SHOT_EXAMPLE_DIR / "discharge_summary.json").read_text(encoding="utf-8")
)
# Fail early if the example response does not conform to the full
# (unsimplified) RCP schema.
jsonschema.validate(example_response, rcp_schema)

## Generate Prompt

In [None]:
# Fixed part of every prompt: the system message built from the simplified
# schema, followed by the one-shot example as a user/assistant exchange.
system_message = generate_rcp_system_message(simplified_rcp_schema)
one_shot_user_message = generate_rcp_user_message(example_notes)
# The example answer is passed to the model as serialized JSON text.
one_shot_response_message = Message(
    role=Role.ASSISTANT, content=json.dumps(example_response)
)

## Randomly sample admission examples

In [None]:
# Physician notes table; the columns read later in this notebook are
# HADM_ID, DESCRIPTION, CHARTTIME and TEXT.
notes_df = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# Number of shared examples: half the number of evaluators, rounded down.
num_shared_examples = NUMBER_CLINICAL_EVALUATORS // 2
# Distinct admissions needed: one slot per evaluator per example, minus the
# shared examples (presumably each shared example fills two evaluators'
# slots - confirm).
eval_sample_size = (
    NUMBER_CLINICAL_EVALUATORS * NUM_EXAMPLES_PER_EVALUATOR
) - num_shared_examples
# Trailing bare expression so the notebook displays the value.
eval_sample_size

In [None]:
# Distinct HADM_ID values present in the notes table.
hadm_ids = notes_df["HADM_ID"].unique().tolist()
# Used for 1 round of qualitative evaluation
# A freshly seeded generator keeps the sample reproducible. The draw takes
# NUM_EXAMPLES_PER_EVALUATOR more ids than eval_sample_size, and the slice
# then discards the first NUM_EXAMPLES_PER_EVALUATOR of them.
# NOTE(review): presumably the discarded ids are reserved for another
# purpose - confirm.
eval_hadm_ids = Random(RANDOM_SEED).sample(
    hadm_ids, NUM_EXAMPLES_PER_EVALUATOR + eval_sample_size
)[NUM_EXAMPLES_PER_EVALUATOR:]

## Query LLM

In [None]:
# Chat model client. The endpoint and key come from environment variables;
# os.getenv returns None when a variable is unset.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=GPT_4_ENGINE,
    # NOTE(review): presumably chosen for near-deterministic output - confirm.
    temperature=0,
    # NOTE(review): presumably seconds - confirm against AzureOpenAIChatModel.
    timeout=20,
)

In [None]:
# For each sampled admission: build the prompt from its physician notes, query
# the LLM, then save the raw conversation and the parsed JSON response.
for hadm_id in tqdm(eval_hadm_ids):
    # One output sub-directory per admission, named by the integer HADM_ID.
    hadm_id_output_dir = OUTPUT_DIR / str(int(hadm_id))
    hadm_id_output_dir.mkdir(parents=True, exist_ok=True)

    # All notes belonging to this admission, in the order they appear in the
    # notes table.
    physician_notes = [
        PhysicianNote(
            hadm_id=row["HADM_ID"],
            title=row["DESCRIPTION"],
            timestamp=row["CHARTTIME"],
            text=row["TEXT"],
        )
        for _, row in notes_df[notes_df["HADM_ID"] == hadm_id].iterrows()
    ]

    user_message = generate_rcp_user_message(physician_notes)
    prompt = [
        system_message,
        one_shot_user_message,
        one_shot_response_message,
        user_message,
    ]

    # Very basic handling of rate limit errors: wait and retry until the
    # request succeeds. Any other exception propagates and stops the loop.
    while True:
        # Monotonic clock, so the measured duration is unaffected by
        # system clock adjustments.
        t0 = time.perf_counter()
        try:
            response = llm.query(prompt)
        except RateLimitError:
            print("Rate limit exceeded")
            time.sleep(20)
        else:
            time_taken = time.perf_counter() - t0
            break

    # Full message history (prompt messages, response, timing) separated by
    # lines of asterisks.
    raw_messages = ("\n" + "*" * 80 + "\n").join(
        [message.content for message in prompt]
        + [
            response.content,
            f"Time taken: {time_taken}",
        ]
    )
    # Written explicitly as UTF-8: notes and model output may contain
    # characters that the platform's default encoding cannot represent.
    (hadm_id_output_dir / "raw_messages.txt").write_text(
        raw_messages, encoding="utf-8"
    )

    # Handle prefix and suffixes e.g. '''json...''': keep only the text from
    # the first "{" to the last "}". If no such span exists, json.loads raises.
    json_start = response.content.find("{")
    json_end = response.content.rfind("}")
    (hadm_id_output_dir / "discharge_summary.json").write_text(
        json.dumps(json.loads(response.content[json_start : json_end + 1]), indent=4),
        encoding="utf-8",
    )