# MIMIC Generation

This notebook ingests MIMIC physician notes and the Royal College of Physicians (London) guidelines.

These are converted into a prompt and queried to GPT-4-turbo. 

The simplified json schema used in the prompt is saved to file.

The outputs are then saved to the `outputs/llm_responses` directory. Each output consists of the raw JSON and the message history with the LLM.

In [None]:
import json
import math
from copy import deepcopy
from pathlib import Path
from random import Random

import pandas as pd
from dotenv import load_dotenv
from transformers import GenerationConfig

from llm_discharge_summaries.schemas.mimic import PhysicianNote
from llm_discharge_summaries.schemas.rcp_guidelines import RCPGuidelines

In [None]:
# Input CSV of physician notes, resolved relative to the current working directory.
PHYSICIAN_NOTE_FPATH = Path.cwd().joinpath(
    "inputs",
    "physionet.org",
    "files",
    "mimiciii",
    "1.4",
    "physician_notes_mimic.csv",
)
# One-shot example files live in the package directory next to the working directory.
ONE_SHOT_EXAMPLE_DIR = Path.cwd().parent.joinpath(
    "llm_discharge_summaries", "schemas", "rcp_one_shot_example"
)
# Directory that receives the LLM responses.
OUTPUT_DIR = Path.cwd().joinpath("outputs", "llm_responses")

# NOTE(review): these three settings are not referenced by the cells visible in
# this notebook — confirm whether they are still needed.
GPT_4_ENGINE = "gpt-4-turbo"
AZURE_API_VERSION = "2023-07-01-preview"
TOKENIZER_NAME = "cl100k_base"

# Sizing of the clinical evaluation sample and the seed used to draw it.
NUMBER_CLINICAL_EVALUATORS = 15
NUM_EXAMPLES_PER_EVALUATOR = 5
RANDOM_SEED = 23

In [None]:
# Create the output directory together with any missing parent (`outputs/`),
# so a fresh checkout does not fail with FileNotFoundError.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# Load environment variables from a `.env` file, if one is found.
load_dotenv()

## Pre-process schema

In [None]:
# Schema of the RCP guidelines model; the cells below treat it as a nested dict
# with a top-level "definitions" entry (presumably a pydantic JSON schema — confirm).
rcp_schema = RCPGuidelines.schema()

The `title` and `required` fields are removed as they contain redundant information.

Definition properties with a default value are also removed.

In [None]:
def remove_keys_recursive(d: list | dict, keys: set[str]):
    if isinstance(d, dict):
        for key in list(d.keys()):
            if key in keys:
                del d[key]
            else:
                remove_keys_recursive(d[key], keys)
    elif isinstance(d, list):
        for item in d:
            remove_keys_recursive(item, keys)
    return d


def remove_default_definition_properties(schema: dict):
    """Drop every definition property that declares a ``default`` value.

    The schema is modified in place. Definitions without a ``properties``
    entry (for example enum definitions) are left untouched, and a schema
    without ``definitions`` is returned unchanged.

    Args:
        schema: JSON-schema-like dict, optionally holding a ``definitions``
            mapping of section name to section dict.

    Returns:
        The same ``schema`` object, after filtering.
    """
    for section_dict in schema.get("definitions", {}).values():
        properties = section_dict.get("properties")
        if properties is None:
            # Nothing to filter, and no "properties" key should be invented.
            continue
        section_dict["properties"] = {
            name: spec for name, spec in properties.items() if "default" not in spec
        }
    return schema


# Simplify a deep copy so that `rcp_schema` keeps its original contents:
# first strip the "title"/"required" keys, then the properties with defaults.
simplified_rcp_schema = remove_default_definition_properties(
    remove_keys_recursive(deepcopy(rcp_schema), {"title", "required"})
)

In [None]:
# Persist the simplified schema next to (not inside) the responses directory.
simplified_schema_json = json.dumps(simplified_rcp_schema, indent=4)
(OUTPUT_DIR.parent / "simplified_rcp_schema.json").write_text(simplified_schema_json)

## Load 1 shot example

In [None]:
# Parse the one-shot example notes file and wrap each entry in a PhysicianNote.
raw_example_notes = json.loads(
    (ONE_SHOT_EXAMPLE_DIR / "physician_notes.json").read_text()
)
example_notes = [PhysicianNote(**note_kwargs) for note_kwargs in raw_example_notes]

In [None]:
# Physician notes table; later cells read its HADM_ID, DESCRIPTION, CHARTTIME
# and TEXT columns.
notes_df = pd.read_csv(PHYSICIAN_NOTE_FPATH)

In [None]:
# Half the number of evaluators, rounded down.
# NOTE(review): presumably examples that are reviewed by more than one evaluator — confirm.
num_shared_examples = math.floor(NUMBER_CLINICAL_EVALUATORS / 2)
# Total review slots across all evaluators, minus the shared examples.
total_review_slots = NUMBER_CLINICAL_EVALUATORS * NUM_EXAMPLES_PER_EVALUATOR
eval_sample_size = total_review_slots - num_shared_examples
eval_sample_size

In [None]:
hadm_ids = notes_df["HADM_ID"].unique().tolist()

# Seeded draw so the same admissions are selected on every run.
seeded_rng = Random(RANDOM_SEED)
drawn_hadm_ids = seeded_rng.sample(
    hadm_ids, NUM_EXAMPLES_PER_EVALUATOR + eval_sample_size
)

# The first NUM_EXAMPLES_PER_EVALUATOR drawn ids are skipped
# (original note: "Used for 1 round of qualitative evaluation").
sample_hadm_ids = drawn_hadm_ids[NUM_EXAMPLES_PER_EVALUATOR:]

# NOTE(review): the remaining ids are split again, so `eval_hadm_ids` holds
# eval_sample_size - NUM_EXAMPLES_PER_EVALUATOR entries — confirm this is intended.
train_hadm_ids = sample_hadm_ids[:NUM_EXAMPLES_PER_EVALUATOR]
eval_hadm_ids = sample_hadm_ids[NUM_EXAMPLES_PER_EVALUATOR:]

## Query LLM

In [None]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Directory handed to `from_pretrained` as the model/tokenizer cache.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
cache_dir = "/data2/simon/auto-medical-discharge-summaries/.model_cache"
# Device the model and the tokenised prompts are placed on.
device = "cuda"
# Hugging Face model identifier of the instruct model queried below.
model_name = "mistralai/Mistral-7B-Instruct-v0.2"
# Stop early if no CUDA device is available.
assert torch.cuda.is_available()

pip install flash-attn --no-build-isolation

In [None]:
# Load the causal language model in float16 onto `device`, requesting the
# FlashAttention-2 attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    cache_dir=cache_dir,
    device_map=device,
)
# Tokenizer configured to pad on the left.
# NOTE(review): presumably so batched prompts of different lengths all end
# right before the generated tokens — confirm.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, padding_side="left", cache_dir=cache_dir
)
# Reuse the end-of-sequence token for padding
# (presumably because the tokenizer defines no pad token — confirm).
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# NOTE(review): the admission id is hard-coded here; `train_hadm_ids` is not
# used by the cells visible in this notebook — confirm this is intended.
single_admission_rows = notes_df[notes_df["HADM_ID"] == 103411]

# One PhysicianNote per row of the selected admission.
train_physician_notes = [
    PhysicianNote(
        hadm_id=note_row["HADM_ID"],
        title=note_row["DESCRIPTION"],
        timestamp=note_row["CHARTTIME"],
        text=note_row["TEXT"],
    )
    for _, note_row in single_admission_rows.iterrows()
]

In [None]:
# f"""The following document contains examples of the reason for admission extracted from a set of patient clinical notes
# This has been done by an expert clinician.

# Each example has the following 3 parts:
# Patient's Clinical Notes:
# [The patient's clinical notes ordered by ascending timestamp. Each note has a title of the format [Title]: [timestamp year-month-day hour:min].]

# Reason for admission:
# [The main reason why the patient was admitted to hospital, eg chest pain, breathlessness, collapse, etc. This should be symptoms and not the diagnosis]
# Examples are separated by
# ###
# Example 1:
# Patient's Clinical Notes:
# {_physician_notes_to_string(example_notes)}
# Reason for admission:
# Chest tightness pain, breathlessness, nausea and dizziness started at 6 am.
# ###
# """

In [None]:
# Parsed one-shot example discharge summary.
example_response_path = ONE_SHOT_EXAMPLE_DIR / "discharge_summary.json"
example_response = json.loads(example_response_path.read_text())

In [None]:
from llm_discharge_summaries.openai_llm.prompts import _physician_notes_to_string
from llm_discharge_summaries.schemas.rcp_guidelines import AdmissionDetails

In [None]:
# Collect (name, description, example value) for every AdmissionDetails
# property that does not declare a default value.
admission_properties = AdmissionDetails.schema()["properties"]
fields = []
for field_name, field_schema in admission_properties.items():
    if "default" in field_schema:
        continue
    fields.append(
        (
            field_name,
            field_schema["description"],
            example_response["admission_details"][field_name],
        )
    )
fields

In [None]:
def first_message_prompt(
    field_name: str,
    description: str,
    example_response: str | list[str],
    physician_notes: "list[PhysicianNote] | None" = None,
    max_note_chars: int | None = 4000,
):
    """Build the chat messages asking the model to extract one field from the notes.

    The conversation has four turns: the task description (user), a canned
    acknowledgement (assistant), the clinical notes (user), and an unfinished
    assistant sentence that the model is presumably meant to continue.

    Args:
        field_name: Schema field name; underscores are replaced by spaces and
            the result is lower-cased for use in the prompt text.
        description: Definition of the field, inserted lower-cased.
        example_response: Example value for the field. A list is joined with
            ``", "``; a string is used as is.
        physician_notes: Notes to include in the prompt. Defaults to the
            module-level ``train_physician_notes``.
        max_note_chars: Maximum number of characters of the rendered notes to
            include; ``None`` disables truncation.

    Returns:
        A list of ``{"role": ..., "content": ...}`` message dicts.
    """
    cleaned_field_name = field_name.replace("_", " ").lower()
    example_response_str = (
        ", ".join(example_response)
        if isinstance(example_response, list)
        else example_response
    )

    if physician_notes is None:
        physician_notes = train_physician_notes
    # Character-based truncation of the rendered notes to bound the prompt size.
    notes_text = _physician_notes_to_string(physician_notes)[:max_note_chars]

    # NOTE(review): the prompt wording below contains typos ("a expert",
    # "what you the"); it is left unchanged so model outputs stay comparable.
    return [
        {
            "role": "user",
            "content": f"""I am a expert clinician tasked with finding the {cleaned_field_name} for a patient.
I will give you a patient's clinical notes from their stay from oldest to most recent. Each note is separated by a blank line and starts with a title followed by a timestamp.
Then you must tell me what you the {cleaned_field_name} which is defined as {description.lower()}
Responses should be no longer than 40 words.
Expand any abbreviations used in the notes to their full medical terms.
An example response would be
After analyzing the clinical notes I have found that the {cleaned_field_name} is {example_response_str.lower()}""",
        },
        {
            "role": "assistant",
            "content": (
                "Of course, I can help with that. Please provide the patient's clinical"
                " notes."
            ),
        },
        {
            "role": "user",
            # The header line ends with a single space before the newline.
            "content": "These are the patient's clinical notes \n" + notes_text,
        },
        {
            "role": "assistant",
            "content": (
                f"After analyzing the patient notes the {cleaned_field_name} was"
            ),
        },
    ]


# One extraction conversation per (name, description, example) entry of `fields`.
first_prompt_messages = [first_message_prompt(*field_spec) for field_spec in fields]

In [None]:
# Show the opening (task description) message of every extraction conversation.
for conversation in first_prompt_messages:
    opening_message = conversation[0]
    print(opening_message["content"])
    print("--")

In [None]:
# Render each conversation with the chat template, then cut the leading BOS
# and trailing EOS token text off the rendered string.
# NOTE(review): presumably the EOS is removed so the last assistant message
# stays open for the model to continue — confirm.
bos_len = len(tokenizer.bos_token)
eos_len = len(tokenizer.eos_token)
message_strs = []
for messages in first_prompt_messages:
    templated = tokenizer.apply_chat_template(messages, tokenize=False)
    message_strs.append(templated[bos_len:-eos_len])

# Tokenise all prompts as one padded batch and move it to the model's device.
tokens = tokenizer(message_strs, return_tensors="pt", padding=True).to(device)
tokens.input_ids.shape

In [None]:
# Generation settings: sampling disabled, at most 1000 new tokens per prompt,
# with the tokenizer's EOS/pad ids passed through explicitly.
generation_config = GenerationConfig(
    max_new_tokens=1000,
    do_sample=False,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
# Run generation for the whole batch of extraction prompts.
generated_ids = model.generate(**tokens, generation_config=generation_config)
# Decode back to text without special tokens; the next cell keeps only the
# part after the last "[/INST]" marker of each decoded string.
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Keep only the text after the last "[/INST]" marker of each decoded output,
# with surrounding whitespace removed.
doctor_assistant_messages = [
    decoded_text.rpartition("[/INST]")[2].strip() for decoded_text in output
]
doctor_assistant_messages

In [None]:
def second_message_prompt(
    field_name: str,
    description: str,
    example_response: str | list[str],
    doctor_assistant_message: str,
):
    yaml_string_array_type = """array
    items:
    type: string"""

    yaml_schema = f"""type: object
properties:
{field_name}:
    description: {description}
    type: {"string" if type(example_response) == str else yaml_string_array_type}"""

    return [
        {
            "role": "user",
            "content": f"""I am an administrator converting doctors notes to a json according to the following yaml schema
```yaml
{yaml_schema}
```

Only respond with the schema compliant json object.
Expand any abbreviations used by the doctor to their full medical terms.
Do not include an explanation.
Do not include any additional properties even if additional information is available.
An example response would be
```json
{{'{field_name}': {example_response}}}
```
""",
        },
        {
            "role": "assistant",
            "content": (
                "Of course, I can help with that. Please provide a doctors note."
            ),
        },
        {
            "role": "user",
            "content": f"""This is the doctor's note
{doctor_assistant_message}""",
        },
        {
            "role": "assistant",
            "content": f"""```json
{{{field_name}: """,
        },
    ]


# Pair every field with the free-text answer generated for it and build the
# JSON-conversion conversation for each pair.
second_message_prompts = [
    second_message_prompt(*field_spec, doctor_note)
    for field_spec, doctor_note in zip(fields, doctor_assistant_messages)
]

In [None]:
# Show the opening message of every JSON-conversion conversation.
# The loop variable deliberately does not reuse the name `second_message_prompt`:
# rebinding it would overwrite the prompt-building function of that name and
# break any later call to it.
for prompt_messages in second_message_prompts:
    print(prompt_messages[0]["content"])
    print("--")

In [None]:
# Render each JSON-conversion conversation with the chat template and cut the
# leading BOS / trailing EOS token text off the rendered string.
# NOTE(review): presumably the EOS is removed so the last assistant message
# stays open for the model to continue — confirm.
leading_chars = len(tokenizer.bos_token)
trailing_chars = len(tokenizer.eos_token)
templated_prompts = (
    tokenizer.apply_chat_template(messages, tokenize=False)
    for messages in second_message_prompts
)
message_strs = [prompt[leading_chars:-trailing_chars] for prompt in templated_prompts]

# Tokenise all prompts as one padded batch and move it to the model's device.
tokens = tokenizer(message_strs, return_tensors="pt", padding=True).to(device)
tokens.input_ids.shape

In [None]:
# Generation settings: sampling disabled, at most 1000 new tokens per prompt,
# with the tokenizer's EOS/pad ids passed through explicitly.
generation_config = GenerationConfig(
    max_new_tokens=1000,
    do_sample=False,
    use_cache=True,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)
# Run generation for the whole batch of JSON-conversion prompts.
generated_ids = model.generate(**tokens, generation_config=generation_config)
# Decode back to text without special tokens; the next cell keeps only the
# part after the last "[/INST]" marker of each decoded string.
output = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

In [None]:
# Keep only the text after the last "[/INST]" marker of each decoded output,
# with surrounding whitespace removed.
schema_assistant_messages = []
for decoded_text in output:
    *_, final_segment = decoded_text.split("[/INST]")
    schema_assistant_messages.append(final_segment.strip())
schema_assistant_messages

In [None]:
def extract_text_between_backticks(text, prefix, suffix):
    """Return the first span of ``text`` running from ``prefix`` through ``suffix``.

    Despite the name, the delimiters are arbitrary strings. The returned
    span includes both delimiters in full.

    Args:
        text: String to search.
        prefix: Opening delimiter.
        suffix: Closing delimiter; the first occurrence after ``prefix`` is used.

    Returns:
        The matched span including both delimiters, or ``None`` when either
        delimiter is not found.
    """
    start = text.find(prefix)
    if start == -1:
        return None  # Opening delimiter not found
    end = text.find(suffix, start + len(prefix))
    if end == -1:
        return None  # Closing delimiter not found
    # Include the whole suffix, not just its first character.
    return text[start : end + len(suffix)]

In [None]:
schema_assistant_messages

In [None]:
for message in schema_assistant_messages:
    extracted_text = extract_text_between_backticks(message, "{", "}")
    if extracted_text is None:
        # The model produced no `{...}` span; report it instead of failing
        # with AttributeError on `None.replace`.
        print(f"No JSON object found in response: {message!r}")
        continue

    # Best-effort conversion of the model's loose JSON into strict JSON:
    # single quotes become double quotes, a closing quote is added before every
    # colon, and an opening quote is added after "{" and after each ",\n".
    # NOTE(review): this assumes unquoted keys and values without ":" — confirm.
    formatted_json_string = (
        extracted_text.replace("'", '"')
        .replace(":", '":')
        .replace(",\n", ',\n "')
        .replace("{", '{"')
    )

    print(formatted_json_string)

In [None]:
# Show the opening (task description) message of every extraction conversation.
for conversation in first_prompt_messages:
    task_description = conversation[0]["content"]
    print(task_description)
    print("--")

In [None]:
# `output` is the list of decoded strings returned by `tokenizer.batch_decode`,
# so it has no `.split`; take the text after the last "[/INST]" of each entry.
responses = [field_output.split("[/INST]")[-1] for field_output in output]
for response in responses:
    print(response)

In [None]:
{"title": "AdmissionDetails", "type": "string", "description": "thing"}

In [None]:
{"AdmissionDetails": "thing"}