In [None]:
import json
import os
import pickle
import re
from pathlib import Path
from typing import Dict, List

import torch
from dotenv import load_dotenv
from matplotlib import pyplot as plt
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer, util
from tqdm.notebook import tqdm

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.openai_llm.token_count import (
    num_tokens_from_messages_azure_engine,
)
from discharge_summaries.schemas.mimic import Note, Record

In [None]:
# Load environment variables from a .env file; AZURE_OPENAI_ENDPOINT and
# AZURE_OPENAI_KEY are read from the environment when building the chat model.
load_dotenv()

In [None]:
# --- Paths (resolved relative to the notebook's working directory) ---
PROJECT_ROOT = Path.cwd().parent
DATA_DIR = PROJECT_ROOT / "data"
OUTPUT_DIR = Path.cwd() / "output"
TRAINING_DATASET_PATH = DATA_DIR / "train_all_ds.pkl"
GUIDELINES_JSON_SCHEMA_PATH = (
    PROJECT_ROOT / "guidelines" / "eDischarge-Summary-v2.1-1st-Feb-21_schema.json"
)

RANDOM_SEED = 23

# --- Azure OpenAI deployment (used for token counting and chat completions) ---
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"
# Alternative deployment:
# AZURE_ENGINE = "gpt-35-turbo"
# AZURE_API_VERSION = "2023-07-01-preview"

# --- Prompt budget: Llama 2 context window minus 1000 tokens ---
# NOTE(review): the 1000-token margin presumably reserves room for the
# completion — confirm.
LLAMA_2_CONTEXT_LENGTH = 4096
MAX_PROMPT_LENGTH = LLAMA_2_CONTEXT_LENGTH - 1000

In [None]:
# NOTE: pickle.load can execute arbitrary code — only load trusted, locally
# produced dataset files.
with TRAINING_DATASET_PATH.open("rb") as pickle_file:
    raw_records = pickle.load(pickle_file)
dataset = [Record(**raw_record) for raw_record in raw_records]
len(dataset)

## Semantic retrieval: matching physician-note chunks to guideline elements

In [None]:
class TextChunk(BaseModel):
    """A token-bounded extract of a physician note, ready to be embedded."""

    text: str  # header ("Physician Note Extract" + timestamp) then sentences
    timestamp: str  # copied from the source note's ``datetime``
    token_length: int  # token count of ``text`` sent as one user message


def _count_tokens(text: str) -> int:
    """Token length of ``text`` as a single user message for AZURE_ENGINE."""
    return num_tokens_from_messages_azure_engine(
        [Message(content=text, role=Role.USER)], AZURE_ENGINE, AZURE_API_VERSION
    )


# Split on a newline that starts a line not beginning with a space or a
# lowercase letter, or on whitespace that follows ?, ! or . .
_SENTENCE_SPLIT_PATTERN = re.compile(r"\n(?=[^ a-z])|(?<=[?!.])\s")


def text_to_chunks(note: Note, max_chunk_length: int = 128) -> List[TextChunk]:
    """Split a physician note into sentence-aligned chunks for embedding.

    The note text is split into sections on blank lines and each section into
    sentences. Sentences are packed greedily, in order, into chunks that each
    start with a "Physician Note Extract" header carrying the note timestamp.
    A chunk never spans two sections.

    Args:
        note: The physician note to split.
        max_chunk_length: Token budget per chunk, header included. A single
            sentence that on its own exceeds the budget is emitted as its own
            (over-budget) chunk rather than being discarded.

    Returns:
        The chunks in note order. Sections containing only whitespace produce
        no chunk.
    """
    prefix = f"Physician Note Extract\nTimestamp: {note.datetime}"
    prefix_length = _count_tokens(prefix)  # identical for every section

    def make_chunk(text: str, length: int) -> TextChunk:
        return TextChunk(text=text, timestamp=note.datetime, token_length=length)

    chunks: List[TextChunk] = []
    for section in note.text.split("\n\n"):
        chunk_text, chunk_length = prefix, prefix_length

        for sentence in _SENTENCE_SPLIT_PATTERN.split(section):
            if not sentence.strip():
                # Empty fragments would otherwise yield header-only chunks.
                continue

            candidate_text = f"{chunk_text}\n{sentence}"
            candidate_length = _count_tokens(candidate_text)
            if candidate_length <= max_chunk_length or chunk_text == prefix:
                # The sentence fits. If the chunk is still empty, accept even
                # an oversized sentence so that its text is kept.
                chunk_text, chunk_length = candidate_text, candidate_length
            else:
                # The chunk is full: emit it, then start the next chunk with
                # this sentence so the sentence is not lost.
                chunks.append(make_chunk(chunk_text, chunk_length))
                chunk_text = f"{prefix}\n{sentence}"
                chunk_length = _count_tokens(chunk_text)

        if chunk_text != prefix:
            chunks.append(make_chunk(chunk_text, chunk_length))

    return chunks

In [None]:
# Number of training samples to chunk (a subset keeps the token counting
# tractable); set to None to chunk the full dataset.
N_CHUNKED_SAMPLES = 1000

# One flat list of chunks per sample, covering all of that sample's physician notes.
dataset_note_chunks = [
    [chunk for note in sample.physician_notes for chunk in text_to_chunks(note)]
    for sample in tqdm(dataset[:N_CHUNKED_SAMPLES], desc="Chunking notes")
]

In [None]:
# Inspect the first five chunks of the first sample.
dataset_note_chunks[0][:5]

In [None]:
# Distribution of chunk token lengths (header included) over all chunked samples.
fig, ax = plt.subplots()
ax.hist(
    [chunk.token_length for note_chunks in dataset_note_chunks for chunk in note_chunks]
)
ax.set(
    title="Note chunk token lengths",
    xlabel="Tokens per chunk",
    ylabel="Number of chunks",
)
plt.show()

In [None]:
# Guideline JSON schema. JSON text is UTF-8 (RFC 8259), so decode it explicitly
# instead of relying on the platform's locale encoding.
guidelines_json_schema_json = json.loads(
    GUIDELINES_JSON_SCHEMA_PATH.read_text(encoding="utf-8")
)

In [None]:
def _iter_element_schemas(schema_json: Dict):
    """Yield one single-property object schema per guideline element.

    Each yielded schema uses its parent section's title and description as its
    own description, so the element can be embedded with that context.
    """
    for section_title, section_dict in schema_json["properties"].items():
        for element_title, element_dict in section_dict["properties"].items():
            yield {
                "type": "object",
                "description": f"{section_title}. {section_dict['description']}",
                "properties": {element_title: element_dict},
            }


element_schemas_json = list(_iter_element_schemas(guidelines_json_schema_json))
len(element_schemas_json), element_schemas_json[:4]

In [None]:
def element_json_schema_to_str(element_json_schema: Dict) -> str:
    """Render a single-element guideline schema as plain text for embedding.

    Args:
        element_json_schema: An object schema whose ``description`` holds the
            section context and whose ``properties`` contains exactly one
            element.

    Returns:
        The section description on the first line. For an array element, this
        is followed by one ``"<title>. <description>"`` line per sub-field.
        For any other element, it is followed by one
        ``"<title>. <description>."`` line ending in exactly one period.

    Raises:
        AssertionError: if ``properties`` does not hold exactly one element.
    """
    section_str = element_json_schema["description"]
    properties = element_json_schema["properties"]
    assert len(properties) == 1, f"expected exactly one element, got {len(properties)}"

    element_title, element_dict = next(iter(properties.items()))

    if element_dict.get("type") == "array":
        items = element_dict["items"]
        # Standard JSON Schema nests an array's fields under items["properties"].
        # Also accept an items mapping that maps field titles straight to schemas.
        fields = items.get("properties", items)
        element_str = "\n".join(
            f"{field_title}. {field_dict['description']}"
            for field_title, field_dict in fields.items()
        )
    else:
        # Strip any existing terminal period so the line never ends in "..".
        description = element_dict["description"].rstrip(".")
        element_str = f"{element_title}. {description}."

    return f"{section_str}\n{element_str}"


# Plain-text rendering of every guideline element, in element_schemas_json order.
guideline_element_strs = list(map(element_json_schema_to_str, element_schemas_json))

# Spot-check two rendered elements.
for preview_idx in (0, 3):
    print(guideline_element_strs[preview_idx])

In [None]:
# Sentence-embedding model shared by guideline elements and note chunks; their
# embeddings are compared with util.dot_score in the scoring loop.
embedder = SentenceTransformer("msmarco-distilbert-base-tas-b")

In [None]:
# Embed every guideline element once, returned as a tensor (convert_to_tensor)
# so it can be scored against each sample's note-chunk embeddings.
guideline_element_embeddings = embedder.encode(
    guideline_element_strs, convert_to_tensor=True, show_progress_bar=True
)

In [None]:
# For each sample and guideline element, compute the cumulative softmax mass
# covered by the top-1..top-n most similar note chunks. Slots beyond a
# sample's chunk count keep the 1.0 fill, because a full softmax sums to 1.
n = 20
dataset_similarity_scores_cumsum = torch.ones(
    (len(dataset_note_chunks), len(element_schemas_json), n)
)

softmax = torch.nn.Softmax(dim=1)
# Wrap the list itself (not enumerate) so tqdm knows the total and shows an ETA.
for idx, note_chunks in enumerate(tqdm(dataset_note_chunks, desc="Scoring samples")):
    if not note_chunks:
        # Nothing to embed; encoding an empty batch may fail depending on the
        # sentence-transformers version.
        # NOTE(review): this row keeps its 1.0 fill and still feeds the pooled
        # mean/std — consider excluding such samples.
        continue
    note_chunk_embeddings = embedder.encode(
        [chunk.text for chunk in note_chunks], convert_to_tensor=True
    )
    # Rows: guideline elements; columns: this sample's chunks.
    similarity_scores = util.dot_score(
        guideline_element_embeddings, note_chunk_embeddings
    ).cpu()
    # Normalise over chunks, then sort each element's probabilities descending.
    similarity_scores_normalized, _ = softmax(similarity_scores).sort(
        dim=1, descending=True
    )
    similarity_scores_cumsum = similarity_scores_normalized.cumsum(dim=1)
    copy_n = min(similarity_scores_cumsum.shape[-1], n)
    dataset_similarity_scores_cumsum[idx, :, :copy_n] = similarity_scores_cumsum[
        :, :copy_n
    ]
dataset_similarity_scores_cumsum.shape

In [None]:
# Pool every (sample, guideline element) pair: each row of the merged tensor is
# one cumulative top-k curve of length n.
dataset_similarity_scores_cumsum_merged = dataset_similarity_scores_cumsum.flatten(
    start_dim=0, end_dim=1
)
dataset_similarity_scores_mean = torch.mean(dataset_similarity_scores_cumsum_merged, dim=0)
dataset_similarity_scores_std = torch.std(dataset_similarity_scores_cumsum_merged, dim=0)

dataset_similarity_scores_mean.shape, dataset_similarity_scores_std.shape

In [None]:
# Mean (±1 std) cumulative softmax mass covered by the top-n note chunks,
# pooled over all (sample, guideline element) pairs. The x-axis counts
# retrieved chunks from 1, so x = n means "top n chunks".
top_k = range(1, len(dataset_similarity_scores_mean) + 1)
fig, ax = plt.subplots()
ax.plot(top_k, dataset_similarity_scores_mean, label="Mean")
ax.fill_between(
    top_k,
    # A cumulative probability lies in [0, 1], so clip the band to that range.
    (dataset_similarity_scores_mean + dataset_similarity_scores_std).clamp(max=1.0),
    (dataset_similarity_scores_mean - dataset_similarity_scores_std).clamp(min=0.0),
    color="gray",
    alpha=0.5,
    label="±1 std",
)
ax.set(
    title="Softmax mass covered by the top-n note chunks",
    xlabel="n",
    ylabel="Mean Cumulative Similarity Score",
)
ax.legend()
plt.show()

In [None]:
# Chat model used to draft discharge summaries. os.environ[...] fails fast with
# a KeyError naming the missing variable; os.getenv would instead silently pass
# None as the endpoint or key.
llm = AzureOpenAIChatModel(
    api_base=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_KEY"],
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,
    timeout=20,
)

In [None]:
# Inline the guideline schema where the prompt refers to it. Compact JSON
# (no indentation) keeps the schema's token cost in the system prompt down.
guidelines_json_schema_str = json.dumps(guidelines_json_schema_json)

# System prompt: role, allowed information source, note title format, and the
# JSON schema the summary must follow.
SYSTEM_MESSAGE = Message(
    role=Role.SYSTEM,
    content=f"""You are a consultant doctor tasked with writing a patients discharge summary.
Only the information in the physician notes provided by the user can be used for this task.
Each physician note has a title of the format Physician Note [number]: [timestamp].

The discharge summary must be written in accordance with the following json schema.
{guidelines_json_schema_str}
If the information is not present to fill in a field, answer it with an empty string.
""",
)