In [None]:
import json
import os
import pickle
import re
from pathlib import Path
from typing import List

from dotenv import load_dotenv
from matplotlib import pyplot as plt
from sentence_transformers import SentenceTransformer

from discharge_summaries.openai_llm.chat_models import AzureOpenAIChatModel
from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.openai_llm.token_count import (
    num_tokens_from_messages_azure_engine,
)
from discharge_summaries.schemas.mimic import Note, Record

In [None]:
# Load variables from a local .env file into the process environment; the
# AZURE_OPENAI_ENDPOINT / AZURE_OPENAI_KEY values read via os.getenv when the
# chat model is constructed are presumably supplied this way.
load_dotenv()

In [None]:
# Paths are resolved relative to the current working directory, so the notebook
# must be started from its own folder for them to point at the right places.
DATA_DIR = Path.cwd().parent / "data"
OUTPUT_DIR = Path.cwd() / "output"


# Pickled training records (loaded into `dataset`).
TRAINING_DATASET_PATH = DATA_DIR / "train_all_ds.pkl"
# NOTE(review): not referenced in the cells visible here — presumably used for
# sampling/shuffling further down; confirm.
RANDOM_SEED = 23
# Azure OpenAI engine and API version, used both for the chat model and for
# token counting.
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"
# Alternative engine (uncomment to switch):
# AZURE_ENGINE = "gpt-35-turbo"
# AZURE_API_VERSION = "2023-07-01-preview"

# JSON schema of the discharge summary guidelines that is embedded in the
# system prompt.
GUIDELINES_JSON_SCHEMA_PATH = (
    Path.cwd().parent / "guidelines" / "eDischarge-Summary-v2.1-1st-Feb-21_schema.json"
)

In [None]:
# NOTE(security): pickle.load can execute arbitrary code while deserialising;
# only point TRAINING_DATASET_PATH at a file you created or otherwise trust.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]

# The first record is the worked example used by the cells below.
sample = dataset[0]

# Kept as the final expression so the notebook actually displays the dataset
# size (a bare expression in the middle of a cell is evaluated and discarded).
len(dataset)

Re-serialise the guidelines JSON schema without indentation to save prompt tokens.

In [None]:
# Parse the schema file, then dump it back to a single-line string so the
# pretty-printing whitespace of the file on disk does not end up in the prompt.
with GUIDELINES_JSON_SCHEMA_PATH.open() as schema_file:
    guidelines_json_schema_json = json.load(schema_file)
guidelines_json_schema_str = json.dumps(guidelines_json_schema_json)

In [None]:
# Chat-model client for the Azure OpenAI engine named by AZURE_ENGINE.
# The endpoint and key are read from the environment; os.getenv returns None
# for an unset variable, so a missing entry is passed through as None here.
llm = AzureOpenAIChatModel(
    api_base=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_KEY"),
    api_version=AZURE_API_VERSION,
    engine=AZURE_ENGINE,
    temperature=0,  # presumably chosen for (near-)deterministic output — TODO confirm
    timeout=20,  # presumably seconds — TODO confirm against AzureOpenAIChatModel
)

In [None]:
# System prompt: sets the model's role, restricts it to the supplied physician
# notes, declares the note title format ("Physician Note [number]: [timestamp]")
# and embeds the single-line guidelines JSON schema the answer must follow.
# NOTE(review): "patients" in the prompt is presumably meant to be "patient's";
# left unchanged here because editing the prompt text can change model output.
SYSTEM_MESSAGE = Message(
    role=Role.SYSTEM,
    content=f"""You are a consultant doctor tasked with writing a patients discharge summary.
Only the information in the physician notes provided by the user can be used for this task.
Each physician note has a title of the format Physician Note [number]: [timestamp].

The discharge summary must be written in accordance with the following json schema.
{guidelines_json_schema_str}
If the information is not present to fill in a field, answer it with an empty string.
""",
)

In [None]:
def generate_notes_string(notes: List[Note]):
    """Render notes as one string, each under a numbered, timestamped title.

    Every note becomes ``Physician Note <n>: <datetime>`` (numbering starts at
    1) followed by the note text on the next line; consecutive notes are
    separated by a blank line. An empty list yields an empty string.
    """
    rendered_notes = []
    for number, note in enumerate(notes, start=1):
        rendered_notes.append(
            f"Physician Note {number}: {note.datetime}\n{note.text}"
        )
    return "\n\n".join(rendered_notes)

## Semantic chunking and embedding

In [None]:
def text_to_chunks(text: str, prefix: str, max_chunk_length: int = 128) -> List[str]:
    """Split ``text`` into prefixed chunks bounded by a token budget.

    The text is first split into sections on blank lines; chunks never span
    two sections. Each section is split into sentences, which are appended to
    the current chunk (one per line, after ``prefix``) until adding the next
    sentence would push the chunk over ``max_chunk_length`` tokens. The chunk
    is then emitted and the sentence starts a new chunk.

    Args:
        text: Raw text to split.
        prefix: Header placed at the start of every chunk.
        max_chunk_length: Token budget per chunk, as measured by
            ``num_tokens_from_messages_azure_engine`` on a single user message
            holding the chunk (prefix included).

    Returns:
        The chunks in document order. No sentence is dropped: a single
        sentence that exceeds the budget on its own is kept as its own
        (over-length) chunk. Chunks that would contain only the prefix are
        not emitted, so empty input returns an empty list.
    """

    def _token_count(content: str) -> int:
        # Measure the chunk the same way it would be sent: as one user message.
        return num_tokens_from_messages_azure_engine(
            [Message(content=content, role=Role.USER)],
            AZURE_ENGINE,
            AZURE_API_VERSION,
        )

    chunks: List[str] = []
    for section in text.split("\n\n"):
        chunk = prefix
        # Sentence boundaries: a newline followed by anything other than a
        # space or lowercase letter, or whitespace right after ? ! . (or a
        # literal "|", which the character class also contains).
        for sentence in re.split("\n(?=[^ a-z])|(?<=[?|!|.])\\s", section):
            if not sentence.strip():
                # Blank fragments carry no content and would only produce
                # header-only chunks.
                continue
            candidate = f"{chunk}\n{sentence}"
            # Only flush when the chunk already holds content; otherwise a
            # header-only chunk would be emitted for an over-long sentence.
            if chunk != prefix and _token_count(candidate) > max_chunk_length:
                chunks.append(chunk)
                # Carry the sentence over into the next chunk instead of
                # discarding it.
                candidate = f"{prefix}\n{sentence}"
            chunk = candidate
        if chunk != prefix:
            chunks.append(chunk)

    return chunks

In [None]:
# Chunk every physician note of the sample record. Each chunk is headed with
# the note title in the same 1-based "Physician Note [number]: [timestamp]"
# form that the system prompt declares and generate_notes_string produces, so
# a chunk can be traced back to the note it came from.
physician_note_chunks = [
    chunk
    for note_number, note in enumerate(sample.physician_notes, start=1)
    for chunk in text_to_chunks(
        note.text, f"Physician Note {note_number}: {note.datetime}\n"
    )
]

In [None]:
# Distribution of chunk sizes, measured in tokens the same way as when the
# chunks were built (one user message per chunk).
chunk_token_counts = [
    num_tokens_from_messages_azure_engine(
        [Message(content=chunk, role=Role.USER)], AZURE_ENGINE, AZURE_API_VERSION
    )
    for chunk in physician_note_chunks
]

plt.hist(chunk_token_counts)
# Label the figure so it can be read on its own when skimming the notebook.
plt.xlabel("Tokens per chunk")
plt.ylabel("Number of chunks")
plt.title("Token counts of physician note chunks")
plt.show()

In [None]:
# Sentence-embedding model used to embed the note chunks.
# NOTE(review): presumably fetched from the model hub on first use and cached
# locally afterwards — confirm for offline environments.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# Embed every physician note chunk; convert_to_tensor=True asks for the result
# as a tensor (presumably one row per chunk, in the same order as
# physician_note_chunks — TODO confirm).
corpus_embeddings = embedder.encode(
    physician_note_chunks, convert_to_tensor=True, show_progress_bar=True
)