In [None]:
import json
import os
import random
import re
from pathlib import Path
from typing import Dict, List, Union

import jsonref
import numpy as np
import pandas as pd
import tiktoken
from dotenv import load_dotenv
from rank_bm25 import BM25Okapi

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.schemas.prsb_guidelines import DischargeSummary
from discharge_summaries.structured_data_extractors.mimic import (
    MIMICStructuredDataExtractor,
)

In [None]:
# Load variables from a local .env file into the process environment; the
# Azure OpenAI endpoint and key are read with os.getenv when the LLM client
# is constructed further down.
load_dotenv()

In [None]:
# Root of the MIMIC-III v1.4 files, resolved relative to the parent of the
# current working directory.
MIMIC_III_DIR = (
    Path.cwd().parent / "data" / "physionet.org" / "files" / "mimiciii" / "1.4"
)
# Engine name and API version handed to AzureOpenAIChatModel below.
# NOTE(review): presumably the engine is an Azure deployment name - confirm it exists.
AZURE_ENGINE = "gpt-3-turbo-16k"
AZURE_API_VERSION = "2023-07-01-preview"
# Tokenizer used to count schema and prompt tokens before querying the model.
TOKENIZER = tiktoken.get_encoding("cl100k_base")

In [None]:
# Physician notes and reference discharge summaries, read from CSV files
# located under MIMIC_III_DIR. Later cells rely on the HADM_ID, CHARTTIME and
# TEXT columns.
# NOTE(review): presumably these CSVs are pre-extracted from NOTEEVENTS - confirm.
physician_notes_df = pd.read_csv(MIMIC_III_DIR / "physician_notes.csv")
discharge_summary_df = pd.read_csv(MIMIC_III_DIR / "discharge_summaries.csv")

In [None]:
# Extractor over the MIMIC-III directory; used below to build a discharge
# summary object for a single admission.
structured_data_extractor = MIMICStructuredDataExtractor(MIMIC_III_DIR)

In [None]:
# Pick the admission to work on: shuffle the unique HADM_IDs in place with a
# fixed seed and take the first one, so the selection is repeatable for the
# same input CSV.
hadm_ids = discharge_summary_df["HADM_ID"].unique()
random.Random(23).shuffle(hadm_ids)
hadm_id = hadm_ids[0]
hadm_id

In [None]:
# Physician notes belonging to the selected admission. Filter on `hadm_id`
# (rather than re-reading `hadm_ids[0]`) so this cell always agrees with the
# other cells, which are all keyed on `hadm_id`.
physician_notes_hadm_id_df = physician_notes_df[
    physician_notes_df["HADM_ID"] == hadm_id
]
# Number of notes available for this admission.
len(physician_notes_hadm_id_df)

In [None]:
# Build the discharge summary object for the selected admission from the
# structured data and convert it to a dictionary via its `.dict()` method.
structured_data_summary = structured_data_extractor.complete_prsb_discharge_summary(
    hadm_id
)
structured_data_summary_dict = structured_data_summary.dict()
# Take the medications and procedures sections out of the dictionary so they
# are not passed to remove_previously_filled_schema_fields; they are kept in
# their own variables instead.
medications_structured_data = structured_data_summary_dict.pop(
    "medications_and_medical_devices"
)
procedures_structured_data = structured_data_summary_dict.pop("procedures")

In [None]:
def remove_key_from_schema(schema: Union[Dict, List, str], remove_key: str):
    """Recursively delete every occurrence of ``remove_key`` from ``schema`` in place.

    Dictionaries lose the key (if present) and have all remaining values
    visited; lists have every element visited; any other value is left alone.
    Because matching is purely by key name, a schema *property* that happens
    to be called ``remove_key`` is removed as well.

    Args:
        schema: Nested structure of dicts/lists (or a leaf value) to clean.
        remove_key: Dictionary key to strip at every nesting level.

    Returns:
        None. ``schema`` is modified in place.
    """
    if isinstance(schema, list):
        for element in schema:
            remove_key_from_schema(element, remove_key)
        return
    if not isinstance(schema, dict):
        # Leaf value (string, number, ...): nothing to strip.
        return
    schema.pop(remove_key, None)
    for child in schema.values():
        remove_key_from_schema(child, remove_key)

In [None]:
def json_ref_dict_to_dict(json_ref_dict: Dict) -> Dict:
    """Replace ``jsonref.JsonRef`` values with plain dictionaries, recursively.

    Values whose exact type is ``jsonref.JsonRef`` are copied into a ``dict``
    and processed; values whose exact type is ``dict`` are processed in place.
    Every other value (including lists) is left untouched.

    Args:
        json_ref_dict: Dictionary to convert; it is modified in place.

    Returns:
        The same ``json_ref_dict`` object, after conversion.
    """
    # Iterate over a snapshot of the keys; only existing keys are reassigned.
    for key in list(json_ref_dict):
        value = json_ref_dict[key]
        value_type = type(value)
        if value_type == jsonref.JsonRef:
            json_ref_dict[key] = json_ref_dict_to_dict(dict(value))
        elif value_type == dict:
            json_ref_dict[key] = json_ref_dict_to_dict(value)
    return json_ref_dict

In [None]:
# Load the DischargeSummary JSON schema through jsonref, then turn the
# JsonRef values into plain dictionaries.
json_schema = jsonref.loads(DischargeSummary.schema_json(), jsonschema=True)
json_schema = json_ref_dict_to_dict(json_schema)
# Drop the top-level "definitions" entry and every "required" key.
json_schema.pop("definitions")
remove_key_from_schema(json_schema, "required")
# Keep top level title
remove_key_from_schema(json_schema["properties"], "title")

In [None]:
def remove_previously_filled_schema_fields(schema: Dict, filled_schema: Dict) -> Dict:
    """Remove from ``schema`` the fields that already have a value in ``filled_schema``.

    For each ``key, value`` pair of ``filled_schema``:

    * if ``key`` is not in ``schema`` it is skipped, so extra keys in the
      filled data, or keys already removed by a previous run of the cell, no
      longer raise ``KeyError``;
    * if ``value`` is a dict, the function recurses into
      ``schema[key]["properties"]`` and keeps only the remaining properties
      whose sub-schema is truthy;
    * otherwise the field is deleted from ``schema`` when ``value`` is truthy
      (empty strings/lists, ``None``, ``0`` and ``False`` count as unfilled).

    Args:
        schema: Mapping of field name to JSON sub-schema; modified in place.
        filled_schema: Mapping of field name to the value already extracted.

    Returns:
        The same ``schema`` object with filled fields removed.

    Raises:
        KeyError: If ``value`` is a dict but ``schema[key]`` has no
            ``"properties"`` entry.
    """
    for key, value in filled_schema.items():
        if key not in schema:
            # Nothing to remove: not part of the schema, or already removed.
            continue
        if isinstance(value, dict):
            remaining_properties = remove_previously_filled_schema_fields(
                schema[key]["properties"], value
            )
            schema[key]["properties"] = {
                property_key: property_value
                for property_key, property_value in remaining_properties.items()
                if property_value
            }
        elif value:
            del schema[key]
    return schema


# Drop from the schema every field that structured_data_summary_dict already
# holds a (truthy) value for.
json_schema["properties"] = remove_previously_filled_schema_fields(
    json_schema["properties"], structured_data_summary_dict
)

## Unfillable

In [None]:
# Sections listed under the "Unfillable" heading; each one is deleted from the
# schema properties (a missing section raises KeyError).
unfillable_sections = {
    "discharge_details",
    "distribution_list",
    "person_completing_record",
    "allergies_and_adverse_reactions",
    "diagnoses",
}

for section in unfillable_sections:
    del json_schema["properties"][section]

In [None]:
# Sections grouped as "snomed search" sections; each one is deleted from the
# schema properties (a missing section raises KeyError).
snomed_search_sections = {
    "procedures",
    "medications_and_medical_devices",
    "investigation_results",
    "assessment_scale",
}  # handle allergies separately
for section in snomed_search_sections:
    del json_schema["properties"][section]

In [None]:
# Remove the clinical_summary section from the schema as well. Being the last
# expression of the cell, the removed sub-schema is displayed as the output.
json_schema["properties"].pop("clinical_summary")

In [None]:
# Display the schema that remains after all the removals above.
json_schema

In [None]:
def extract_unique_note_sections(physician_notes_df: pd.DataFrame) -> pd.DataFrame:
    """Split notes into sections and keep the first occurrence of each section text.

    Exact duplicate (CHARTTIME, TEXT) rows are dropped and the index is reset,
    then notes are visited in CHARTTIME order. Each note is split at a newline
    that is followed by a non-empty line containing a colon after its first
    character. A section is recorded only if its exact text has not been seen
    in an earlier section.

    Args:
        physician_notes_df: Frame with at least ``CHARTTIME`` and ``TEXT`` columns.

    Returns:
        DataFrame with columns ``note_idx`` (row position of the note after
        de-duplication, before sorting), ``start_char_idx`` / ``end_char_idx``
        (character offsets of the section within the note text) and ``text``.
    """
    # There is a lot of near-duplicate text between notes; only exact repeats are caught.
    deduplicated_notes = (
        physician_notes_df[["CHARTTIME", "TEXT"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    chronological_notes = deduplicated_notes.sort_values("CHARTTIME")
    section_boundary = re.compile("\n(?=^[^\n].*?:)", flags=re.MULTILINE)

    seen_texts = set()
    records = []
    for note_idx, note in chronological_notes.iterrows():
        offset = 0
        for section_text in section_boundary.split(note["TEXT"]):
            section_end = offset + len(section_text)
            if section_text not in seen_texts:
                seen_texts.add(section_text)
                records.append((note_idx, offset, section_end, section_text))
            # Skip the newline that the split consumed.
            offset = section_end + 1
    return pd.DataFrame.from_records(
        records, columns=["note_idx", "start_char_idx", "end_char_idx", "text"]
    )

In [None]:
# Split this admission's notes into unique sections and display the result.
sections_df = extract_unique_note_sections(physician_notes_hadm_id_df)
sections_df

In [None]:
import spacy

# Load the spaCy pipeline used for tokenising sections and queries. If the
# model is not installed (spacy.load raises OSError) it is downloaded and
# loaded again. The same components are disabled on both paths so the
# pipeline is configured identically whether or not a download was needed.
spacy_model_name = "en_core_web_lg"
disabled_components = ["tok2vec", "parser", "ner"]

try:
    nlp = spacy.load(spacy_model_name, disable=disabled_components)
except OSError:
    spacy.cli.download(spacy_model_name)
    nlp = spacy.load(spacy_model_name, disable=disabled_components)

In [None]:
def tokenise_texts(
    texts: List[str], spacy_pipeline: "spacy.language.Language"
) -> List[List[str]]:
    """Turn each text into a list of lemmas suitable for keyword retrieval.

    Every text is run through ``spacy_pipeline.pipe``; for each resulting
    document the lemma of a token is kept only when the token is alphabetic,
    is not punctuation and is not a stop word. Texts are not lower-cased here.

    Args:
        texts: Raw texts to tokenise.
        spacy_pipeline: Pipeline object exposing ``pipe(texts)``.

    Returns:
        One list of lemma strings per input text, in input order (a text with
        no surviving tokens yields an empty list).
    """
    return [
        [
            token.lemma_
            for token in doc
            if not token.is_punct and not token.is_stop and token.is_alpha
        ]
        for doc in spacy_pipeline.pipe(texts)
    ]

In [None]:
# Lemma lists for every note section, in the same order as the rows of sections_df.
tokenised_sections = tokenise_texts(sections_df["text"].tolist(), nlp)

In [None]:
# BM25 index over the tokenised sections; the scores it returns are used
# positionally against sections_df in the retrieval cells below.
bm25 = BM25Okapi(tokenised_sections)

In [None]:
# For every section left in the schema, build a keyword query from the section
# name and description (plus the names and descriptions of its properties when
# it is an object), score all note sections with BM25 and print the ten
# best-scoring sections, highest score first.
# NOTE(review): query_dict is created here but never populated in this cell.
query_dict = {}
for section_name, section_values in json_schema["properties"].items():
    query_parts = [f'{section_name.replace("_", " ")} {section_values["description"]}']
    if section_values["type"] == "object":
        for property_key, property_value in section_values["properties"].items():
            query_parts.append(
                f"{property_key.replace('_', ' ')} {property_value['description']}"
            )
    query_string = " ".join(query_parts)
    tokenised_query = tokenise_texts([query_string], nlp)[0]

    doc_scores = bm25.get_scores(tokenised_query)
    # Indices of the ten highest scores, best first.
    top_n_idxs = np.argsort(doc_scores)[::-1][:10]
    retrieved_sections = "\n\n".join(
        f"Extract {rank}\n{text}"
        for rank, text in enumerate(sections_df.iloc[top_n_idxs]["text"], start=1)
    )

    print(section_name)
    print(retrieved_sections)

In [None]:
# Ad-hoc free-text query, tokenised with the same function as the indexed
# sections.
query = (
    "Individual requirement that a person has. These may be a communication, cultural,"
    " cognitive or mobility need."
)
tokenised_query = tokenise_texts([query], nlp)[0]

In [None]:
# Score every section against the query and keep the indices of the ten
# highest scores, best first.
doc_scores = np.array(bm25.get_scores(tokenised_query))
top_n_idxs = np.argsort(doc_scores)[-10:][::-1]
top_n_idxs

In [None]:
# Print the retrieved sections, numbered from 1 in rank order.
for idx, text in enumerate(sections_df.iloc[top_n_idxs]["text"]):
    print(f"Extract {idx + 1}\n{text}\n")

In [None]:
# Scratch example: run a lower-cased sample sentence through the pipeline.
# NOTE(review): tokenise_texts does not lower-case its input, unlike this cell.
test = "Some text is interesting about paracetamol [[12312/12]] [Hospital 123]."
doc = nlp(test.lower())

In [None]:
# Apply the same token filter as tokenise_texts to the sample document and
# display the surviving lemmas.
cleaned = [
    token.lemma_
    for token in doc
    if not token.is_punct and not token.is_stop and token.is_alpha
]
cleaned

In [None]:
# Number of tokens taken up by the serialised schema (it is embedded in the
# system message).
len(TOKENIZER.encode(json.dumps(json_schema)))

In [None]:
def generate_notes_string(physician_notes_df: pd.DataFrame):
    """Render the notes as one string, omitting sections already shown in an earlier note.

    Exact duplicate (CHARTTIME, TEXT) rows are dropped and the index is reset,
    then notes are visited in CHARTTIME order. Each note is split into
    sections at a newline followed by a non-empty line containing a colon
    after its first character; a section is emitted only the first time its
    exact text is seen. Every note gets a ``Physician Note N: CHARTTIME``
    header even if none of its sections are new. ``N`` is the note's row
    position after de-duplication (before sorting) plus one.

    Args:
        physician_notes_df: Frame with at least ``CHARTTIME`` and ``TEXT`` columns.

    Returns:
        The rendered notes joined by blank lines.
    """
    # There is a lot of near-duplicate text between notes; only exact repeats are caught.
    unique_notes = (
        physician_notes_df[["CHARTTIME", "TEXT"]]
        .drop_duplicates()
        .reset_index(drop=True)
    )
    section_boundary = re.compile("\n(?=^[^\n].*?:)", flags=re.MULTILINE)

    seen_sections = set()
    rendered_notes = []
    for idx, note in unique_notes.sort_values("CHARTTIME").iterrows():
        fresh_sections = []
        for note_section in section_boundary.split(note["TEXT"]):
            if note_section in seen_sections:
                continue
            seen_sections.add(note_section)
            fresh_sections.append("\n" + note_section)
        new_sections = "".join(fresh_sections)
        rendered_notes.append(
            f"Physician Note {idx+1}: {note['CHARTTIME']}{new_sections}"
        )
    return "\n\n".join(rendered_notes)

In [None]:
def create_system_message(json_schema: Dict) -> Message:
    """Build the system prompt that instructs the model and embeds the JSON schema.

    Args:
        json_schema: Schema the generated discharge summary must follow; it is
            serialised with ``json.dumps`` and inserted into the prompt.

    Returns:
        A ``Message`` with the system role and the instruction text as content.
    """
    schema_json = json.dumps(json_schema)
    instructions = (
        "You are a consultant doctor tasked with writing a patients discharge summary.\n"
        "Only the information in the physician notes provided by the user can be used for this task.\n"
        "Each physician note has a title of the format Physician Note [number]: [timestamp].\n"
        "\n"
        "The discharge summary must be written in accordance with the following json schema.\n"
        f"{schema_json}\n"
        "If the information is not present to fill in a field, answer it with an empty string or list.\n"
    )
    return Message(role=Role.SYSTEM, content=instructions)

In [None]:
# Chat model client. The endpoint and key come from environment variables
# (os.getenv returns None when a variable is unset); engine and API version
# come from the constants defined near the top of the notebook.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,
)

In [None]:
def generate_discharge_summary(
    physician_note_df: pd.DataFrame,
    json_schema: Dict,
    llm: AzureOpenAIChatModel,
    max_prompt_tokens=15000,
) -> List[Message]:
    """Prompt the model for a discharge summary built from the physician notes.

    A system message (instructions plus schema) and a user message (the
    rendered notes) are assembled, their contents are token-counted with
    ``TOKENIZER`` and the count is printed. If the count exceeds
    ``max_prompt_tokens`` the model is not queried.

    Args:
        physician_note_df: Notes for one admission (``CHARTTIME``/``TEXT`` columns).
        json_schema: Schema the model output must follow.
        llm: Chat model whose ``query`` method receives the prompt messages.
        max_prompt_tokens: Upper bound on the summed token count of the
            message contents.

    Returns:
        The prompt messages followed by the value returned by ``llm.query``.

    Raises:
        ValueError: If the prompt token count exceeds ``max_prompt_tokens``.
    """
    system_message = create_system_message(json_schema)
    user_message = Message(
        role=Role.USER,
        content=(
            "Generate the discharge summary json given the following physician"
            f" notes\n\n{generate_notes_string(physician_note_df)}"
        ),
    )
    prompt_messages = [system_message, user_message]

    # Only the message contents are counted.
    num_prompt_tokens = 0
    for message in prompt_messages:
        num_prompt_tokens += len(TOKENIZER.encode(message.content))
    print(num_prompt_tokens)

    if num_prompt_tokens > max_prompt_tokens:
        raise ValueError(
            f"Prompt has {num_prompt_tokens} tokens, which is greater than the max of"
            f" {max_prompt_tokens}."
        )

    llm_response = llm.query(prompt_messages)
    return [*prompt_messages, llm_response]

In [None]:
# Query the model for the selected admission. The result is the system
# message, the user message and then the model's reply.
messages = generate_discharge_summary(physician_notes_hadm_id_df, json_schema, llm)

In [None]:
# Per-admission output directory under ./output. `exist_ok=True` makes the
# cell safe to re-run and avoids the check-then-create race of testing
# `exists()` first.
output_path = Path.cwd() / "output" / f"mimic_hadm_id_{int(hadm_id)}"
output_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Save the schema that was embedded in the prompt.
(output_path / "json_schema.json").write_text(json.dumps(json_schema, indent=4))

In [None]:
# Save the full text of the notes used for this admission. The notes are
# de-duplicated and re-indexed exactly as generate_notes_string does, so the
# "Physician Note N" numbers in this file match the numbers in the prompt
# (the raw frame index would be the row position in the source CSV instead).
saved_notes_df = (
    physician_notes_hadm_id_df[["CHARTTIME", "TEXT"]]
    .drop_duplicates()
    .reset_index(drop=True)
    .sort_values("CHARTTIME")
)
combined_notes = "\n\n".join(
    f"Physician Note {idx+1}: {note['CHARTTIME']}\n{note['TEXT']}"
    for idx, note in saved_notes_df.iterrows()
)
# Explicit encoding: note text is free text and must not depend on the
# platform default encoding.
(output_path / "physician_notes.txt").write_text(combined_notes, encoding="utf-8")

In [None]:
# Parse the model reply (the last message) as JSON and save it pretty-printed.
# json.loads raises if the reply is not valid JSON, in which case nothing is
# written.
(output_path / "discharge_summary.json").write_text(
    json.dumps(json.loads(messages[-1].content), indent=4)
)

In [None]:
# Save the prompt messages (everything except the final model reply).
(output_path / "prompts.json").write_text(
    json.dumps([message.dict() for message in messages[:-1]], indent=4)
)

In [None]:
# Save the reference discharge summary for the same admission; if several
# rows match, the first one is used.
mimic_hadm_id_ds = discharge_summary_df[discharge_summary_df["HADM_ID"] == hadm_id][
    "TEXT"
].iloc[0]
# Explicit encoding: the summary is free text and must not depend on the
# platform default encoding.
(output_path / "mimic_discharge_summary.txt").write_text(
    mimic_hadm_id_ds, encoding="utf-8"
)

In [None]:
# Display the medications section that was popped from the structured data earlier.
medications_structured_data