In [None]:
import json
import pickle
import re
from pathlib import Path
from typing import Callable, List, Optional

from dotenv import load_dotenv
from matplotlib import pyplot as plt
from sentence_transformers import SentenceTransformer

from discharge_summaries.openai_llm.message import Message, Role
from discharge_summaries.openai_llm.token_count import (
    num_tokens_from_messages_azure_engine,
)
from discharge_summaries.schemas.mimic import Record
from discharge_summaries.schemas.prsb_guidelines import Section

In [None]:
# Load variables from a local `.env` file into the process environment so that
# settings/secrets are not hard-coded in the notebook. By default, variables
# already present in the environment are not overridden.
load_dotenv()

In [None]:
# --- Paths (resolved relative to the notebook's working directory) ---
DATA_DIR = Path.cwd().parent / "data"
OUTPUT_DIR = Path.cwd() / "output"  # NOTE(review): not used in the cells shown
TRAINING_DATASET_PATH = DATA_DIR / "train_all_ds.pkl"
# `DATA_DIR.parent` is the notebook's parent directory, i.e. the project root.
GUIDELINES_JSON_PATH = (
    DATA_DIR.parent
    / "guidelines"
    / "eDischarge-Summary-v2.1-1st-Feb-21_extract_text_elements.json"
)

# --- Run settings ---
RANDOM_SEED = 23  # NOTE(review): not used in the cells shown

# Azure OpenAI deployment used for token counting.
AZURE_ENGINE = "gpt-4-32k"
AZURE_API_VERSION = "2023-07-01-preview"
# Alternative deployment:
# AZURE_ENGINE = "gpt-35-turbo"
# AZURE_API_VERSION = "2023-07-01-preview"

In [None]:
# Each pickled item is a mapping of `Record` fields.
# SECURITY: `pickle.load` can execute arbitrary code; only load dataset files
# produced by this project's own pipeline, never files from untrusted sources.
with open(TRAINING_DATASET_PATH, "rb") as in_file:
    dataset = [Record(**record) for record in pickle.load(in_file)]
assert dataset, f"No records found in {TRAINING_DATASET_PATH}"
# Prototype the chunking/retrieval pipeline on a single record.
sample = dataset[0]

In [None]:
def text_to_chunks(
    text: str,
    prefix: str,
    max_chunk_length: int = 128,
    count_tokens: Optional[Callable[[str], int]] = None,
) -> List[str]:
    r"""Split ``text`` into prefixed chunks that stay within a token budget.

    The text is first split into sections on blank lines (``"\n\n"``); a chunk
    never spans two sections. Each section is then split into "sentences" at a
    newline followed by a character that is not a space or lowercase letter,
    or at a whitespace character following ``?``, ``!`` or ``.``. Sentences are
    appended to the current chunk (joined with ``"\n"``) until adding the next
    one would push the chunk over ``max_chunk_length`` tokens; the chunk is
    then emitted and a new chunk is started with that sentence.

    Args:
        text: Text to chunk.
        prefix: String placed at the start of every chunk (e.g. a note header).
        max_chunk_length: Token budget per chunk, as measured by
            ``count_tokens``. A single sentence that exceeds the budget on its
            own is still emitted as its own chunk rather than dropped.
        count_tokens: Callable returning the token count of a string. Defaults
            to counting the string as one user message with
            ``num_tokens_from_messages_azure_engine`` for ``AZURE_ENGINE`` /
            ``AZURE_API_VERSION``.

    Returns:
        The chunks in document order. Sections without any non-blank sentence
        produce no chunk, so an empty ``text`` yields ``[]``.
    """

    def azure_token_count(content: str) -> int:
        return num_tokens_from_messages_azure_engine(
            [Message(content=content, role=Role.USER)],
            AZURE_ENGINE,
            AZURE_API_VERSION,
        )

    token_counter = count_tokens if count_tokens is not None else azure_token_count

    chunks: List[str] = []
    for section in text.split("\n\n"):
        chunk = prefix
        for sentence in re.split(r"\n(?=[^ a-z])|(?<=[?!.])\s", section):
            if not sentence.strip():
                # re.split yields empty pieces where two boundaries are adjacent.
                continue
            # Measure exactly the text that would be kept.
            candidate = f"{chunk}\n{sentence}"
            # Only flush a chunk that already holds a sentence; flushing a
            # prefix-only chunk would emit an empty chunk.
            if chunk != prefix and token_counter(candidate) > max_chunk_length:
                chunks.append(chunk)
                # Start the next chunk WITH the overflowing sentence (the old
                # code dropped it).
                candidate = f"{prefix}\n{sentence}"
            chunk = candidate
        if chunk != prefix:
            chunks.append(chunk)

    return chunks

In [None]:
# Chunk every physician note of the sample record. Each chunk is prefixed with
# the note's position in `sample.physician_notes` and its `datetime` value, so
# a chunk can be traced back to the note it came from.
physician_note_chunks = [
    chunk
    for note_number, physician_note in enumerate(sample.physician_notes)
    for chunk in text_to_chunks(
        physician_note.text,
        prefix=f"Physician Note {note_number}, {physician_note.datetime}\n",
    )
]

In [None]:
# Distribution of chunk sizes, counted with the same token-count call that
# `text_to_chunks` uses by default (each chunk as one user message).
chunk_token_counts = [
    num_tokens_from_messages_azure_engine(
        [Message(content=chunk, role=Role.USER)], AZURE_ENGINE, AZURE_API_VERSION
    )
    for chunk in physician_note_chunks
]
fig, ax = plt.subplots()
ax.hist(chunk_token_counts)
ax.set(
    title="Token counts of physician-note chunks",
    xlabel=f"Tokens per chunk ({AZURE_ENGINE})",
    ylabel="Number of chunks",
)
plt.show()

In [None]:
# Pretrained sentence-embedding model used to embed the physician-note chunks.
embedder = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")

In [None]:
# Embed every physician-note chunk in one call; `convert_to_tensor=True`
# returns the embeddings as a tensor rather than a NumPy array.
corpus_embeddings = embedder.encode(
    physician_note_chunks,
    show_progress_bar=True,
    convert_to_tensor=True,
)

In [None]:
# The guidelines file holds a JSON array of objects, each with the fields of
# one `Section`. Decode it explicitly as UTF-8 (the JSON interchange encoding):
# `read_text()` without `encoding` uses the platform's locale encoding, which
# garbles non-ASCII characters on e.g. Windows.
guidelines = [
    Section(**section_dict)
    for section_dict in json.loads(
        GUIDELINES_JSON_PATH.read_text(encoding="utf-8")
    )
]

In [None]:
# Spot-check the first parsed section: display its name and description.
guidelines[0].name, guidelines[0].description

In [None]:
def section_to_prompts(section: Section, starting_idx: int = 0) -> List[str]:
    """Build one prompt per element of a guideline section.

    Each prompt states the section's name and description, followed by one
    element's name, description, values and cardinality.

    Args:
        section: Guideline section whose ``elements`` are turned into prompts.
        starting_idx: Index of the first element to include; earlier elements
            are skipped. Defaults to 0 (all elements).

    Returns:
        One prompt string per element in ``section.elements[starting_idx:]``.
    """
    # TODO(review): the draft began an `if section.is_record:` branch but never
    # wrote its body (a SyntaxError), and used `section_description` and
    # `starting_idx` without defining them. Record sections are treated like
    # any other section until the intended record handling is confirmed.
    section_description = section.description
    # NOTE(review): `element.values` is interpolated with str(); confirm how
    # empty/missing values should render.
    return [
        f"""Section: {section.name}
Section Description: {section_description}
Element Name: {element.name}
Element Description: {element.description} {element.values}
Element Cardinality: {element.cardinality}"""
        for element in section.elements[starting_idx:]
    ]

In [None]:
# prompts = section_to_prompts(guidelines[4])

# print(prompts[0])

In [None]:
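# NOTE: `guideline_element_to_prompt` is not defined in this notebook (prompts are
# now built by `section_to_prompts`), and `util` / `torch` are not imported above;
# re-enabling this draft also needs `from sentence_transformers import util` and
# `import torch`.
# element_prompt = (guideline_element_to_prompt(guidelines[0].elements[1], guidelines[0]))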
# query_embedding = embedder.encode(element_prompt, convert_to_tensor=True)
# scores = util.dot_score(query_embedding, corpus_embeddings)[0]
# top_results = torch.topk(scores, k=5)
# print("Query:", element_prompt)
# print("\nTop 5 most similar sentences in corpus:")

# for score, idx in zip(top_results[0], top_results[1]):
#     print(physician_note_chunks[idx], "(Score: {:.4f})".format(score))
#     print()