In [None]:
import os
import sys

# Put the directory two levels above the working directory on sys.path so the
# project package `src` (imported in the next cell) can be resolved.
src_path = os.path.abspath('../..')
print(src_path)
sys.path.append(src_path)

In [None]:
from src.utils import create_directory, raw_data_path, processed_data_path, set_seed

In [None]:
# Project helper; presumably seeds the global RNGs — confirm in src.utils.
set_seed(seed=42)

In [None]:
import pandas as pd

In [None]:
# Raw PhysioNet downloads (MIMIC-IV 2.2 and MIMIC-IV-Note 2.2) and the processed-output folder.
# NOTE(review): mimic_iv_path is not referenced in the cells shown here.
mimic_iv_path = os.path.join(raw_data_path, "physionet.org/files/mimiciv/2.2")
mimic_iv_note_path = os.path.join(raw_data_path, "physionet.org/files/mimic-iv-note/2.2")
output_path = os.path.join(processed_data_path, "mimic4")

In [None]:
# The cohort table is read from output_path (it is not written anywhere in these cells).
cohort = pd.read_csv(os.path.join(output_path, "cohort.csv"))
print(cohort.shape)
cohort.head()

In [None]:
# Parse the admission (hadm_*) and stay (stay_*) in/out timestamps into datetimes.
cohort["hadm_intime"] = pd.to_datetime(cohort["hadm_intime"])
cohort["hadm_outtime"] = pd.to_datetime(cohort["hadm_outtime"])
cohort["stay_intime"] = pd.to_datetime(cohort["stay_intime"])
cohort["stay_outtime"] = pd.to_datetime(cohort["stay_outtime"])

In [None]:
# Admission ids in the cohort; the discharge notes are restricted to these.
hadm_ids = set(cohort.hadm_id.unique().tolist())
len(hadm_ids)

In [None]:
from pandarallel import pandarallel

# Registers the .parallel_apply accessor used when cleaning hospital courses.
pandarallel.initialize()

## discharge

In [None]:
# Load every MIMIC-IV-Note discharge summary.
discharge = pd.read_csv(os.path.join(mimic_iv_note_path, "note/discharge.csv.gz"))
print(discharge.shape)
discharge.head()

In [None]:
# Keep only notes belonging to admissions in the cohort.
discharge = discharge[discharge.hadm_id.isin(hadm_ids)]
print(discharge.shape)
discharge.head()

In [None]:
# Inspect one raw note.
# NOTE(review): this prints a full clinical note; clear the output before sharing the notebook.
print(discharge.text.iloc[0])

In [None]:
# https://github.com/hanyin88/DRG-LLaMA/blob/main/data/MIMIC_Preprocessing.py
import re


def extract_HC(dc_summary_raw, min_words=40):
    """Extract the "Brief Hospital Course" section from discharge summaries.

    Adapted from DRG-LLaMA's MIMIC preprocessing. The patterns do not capture
    every hospital course; the approach is deliberately conservative to favour
    data quality. Patterns are tried from most to least conservative and the
    first match wins.

    Args:
        dc_summary_raw: DataFrame with at least ``hadm_id`` and ``text``
            columns. It is not modified.
        min_words: a hospital course must have strictly more than this many
            whitespace-separated words to be kept (default 40).

    Returns:
        A new DataFrame with columns ``hadm_id``, ``text`` and
        ``hospital_course`` (original index preserved), containing only notes
        whose hospital course was captured, is longer than ``min_words`` words,
        and is not an exact duplicate of an earlier row's hospital course.
    """
    # (compiled pattern, number of the group that holds the hospital-course text)
    patterns = [
        (re.compile(r"Brief Hospital Course.*\n*((?:\n.*)+?)(Medications on Admission|___  on Admission|___ on Admission)"), 1),
        (re.compile(r"Brief Hospital Course.*\n*((?:\n.*)+?)Discharge Medications"), 1),
        # Built with implicit literal concatenation on purpose: a backslash line
        # continuation *inside* the string would embed the next line's indentation
        # in the regex and make the pattern (practically) unmatchable.
        (re.compile(
            r"(Brief Hospital Course|rief Hospital Course|HOSPITAL COURSE)"
            r".*\n*((?:\n.*)+?)"
            r"(Medications on Admission|Discharge Medications|DISCHARGE MEDICATIONS|DISCHARGE DIAGNOSIS"
            r"|Discharge Disposition|___ Disposition|CONDITION ON DISCHARGE|DISCHARGE INSTRUCTIONS)"
        ), 2),
        (re.compile(r"(Mg-[12].|LACTATE-[12].|Epi-|Gap-|COUNT-|TRF-)___(.*\n*((?:\n.*)+?))(Medications on Admission)"), 2),
    ]

    def split_note(note):
        # Search each pattern once and return the course text of the first hit.
        for pattern, group in patterns:
            match = pattern.search(note)
            if match:
                return match.group(group)
        return None

    # assign() returns a copy, so the caller's frame is left untouched (and a
    # filtered slice passed in does not trigger SettingWithCopyWarning).
    dc_summary = dc_summary_raw.assign(hospital_course=dc_summary_raw["text"].apply(split_note))

    # Drop records whose hospital course none of the patterns captured.
    dc_summary = dc_summary[["hadm_id", "text", "hospital_course"]].dropna()

    # Whitespace word count: inexact because of special characters, but good enough for filtering.
    num_words = dc_summary["hospital_course"].map(lambda course: len(course.split()))
    dc_summary = dc_summary[num_words > min_words]

    # Keep only the first of identical hospital courses; per the upstream notes these
    # duplicates are mostly low-quality records. (Upstream reports a mean length of ~378 words.)
    return dc_summary.drop_duplicates(subset=["hospital_course"], keep="first")

In [None]:
# Keep only notes with a usable hospital-course section (adds a `hospital_course` column).
discharge = extract_HC(discharge)

In [None]:
discharge

In [None]:
# https://github.com/ji-youn-kim/EHRNoteQA/blob/master/src/preprocessing/preprocess.py
import re


def transform_string(s):
    """Normalise whitespace in a note while keeping its line structure.

    Adapted from EHRNoteQA's preprocessing.

    * Any run of whitespace that contains at least one newline becomes a
      single newline.
    * Any remaining run of two or more whitespace characters becomes a
      single space.
    * Leading and trailing whitespace is stripped.

    Args:
        s: note text.

    Returns:
        The normalised string.
    """
    # Match whitespace on both sides of the newline in one go. Splitting this into
    # `\n\s*|\s*\n` leaves the whitespace *after* the newline behind whenever there
    # was also whitespace before it, and the collapse below would then turn that
    # line break into a plain space.
    s = re.sub(r'\s*\n\s*', '\n', s)
    # Newlines are now isolated, so this only collapses horizontal runs (spaces, tabs, ...).
    s = re.sub(r'\s{2,}', ' ', s)
    return s.strip()

In [None]:
# Preview the whitespace normalisation on one extracted hospital course.
print(transform_string(discharge.hospital_course.iloc[1]))

In [None]:
# Normalise every hospital course in parallel (pandarallel); this text is what the model receives.
discharge["cleaned_hospital_course"] = discharge.hospital_course.parallel_apply(transform_string)
discharge.head()

In [None]:
print(discharge.cleaned_hospital_course.iloc[4])

## GPT

In [None]:
# System prompt for generating one clinical question-answer pair per note.
# NOTE(review): the user message (built from cleaned_hospital_course) contains only the
# extracted hospital-course section, not the full discharge summary this prompt refers to.
system_content = """You are an AI assistant specialized in analyzing ICU patients' data.

You are provided with a discharge summary of an ICU patient, which summarizes important clinical records and serves as an essential reference for the doctor’s clinical decision-making.

Your task is to generate a question-answer pair inquiring about the patient. 

Objective:
1. Formulate one question that a doctor will ask based on the provided discharge summary.
2. The answer should be found within the provided discharge summary. 
3. Refrain from formulating questions that can be answered without referring to the provided discharge summary.
4. Avoid questions that include sensitive personal information or "___".
5. Do not create questions that are too easy to answer. To answer your question, someone should have the clinical expertise equivalent to a doctor and must fully understand all provided discharge summaries.
6. Arrange your output in the following format:
- Question: [Your Question]
- Answer: [Your Answer]
7. Keep both the question and answer concise (within 256 tokens)."""

In [None]:
import tiktoken


def num_tokens_from_message(message):
    """Estimate the prompt tokens consumed by a chat request.

    Uses the per-message accounting from OpenAI's token-counting guide for
    gpt-3.5-turbo-style chat models: each message costs its encoded role and
    content plus 3 framing tokens, and the reply is primed with 3 more tokens.
    For the two-message ``[system, user]`` prompts built in this notebook this
    equals ``len(system) + len(user) + 11``, because the roles "system" and
    "user" are one token each in cl100k_base.

    Args:
        message: list of ``{"role": ..., "content": ...}`` chat messages
            (any length; other keys are ignored).

    Returns:
        Estimated number of prompt tokens (int).
    """
    encoding = tiktoken.encoding_for_model("gpt-35-turbo-0125")
    tokens_per_message = 3  # framing tokens around every message
    reply_priming = 3  # tokens that prime the assistant's reply
    num_tokens = reply_priming
    for msg in message:
        num_tokens += tokens_per_message
        num_tokens += len(encoding.encode(msg["role"]))
        num_tokens += len(encoding.encode(msg["content"]))
    return num_tokens

In [None]:
# One chat request per admission: the shared system prompt plus the cleaned hospital course.
prompts = {}
for _, row in discharge.iterrows():
    messages = [{"role": "system", "content": system_content},
                {"role": "user", "content": row.cleaned_hospital_course}]
    prompts[row.hadm_id] = messages
len(prompts)

In [None]:
# Spot-check a specific admission.
# NOTE(review): hard-coded hadm_id; raises KeyError if this admission was filtered out upstream.
prompts[29079034]

In [None]:
# Estimated prompt size (tokens) per admission, used for the length filter.
prompts_num_tokens = {}
for k, v in prompts.items():
    prompts_num_tokens[k] = num_tokens_from_message(v)

In [None]:
import numpy as np


# Summary of estimated prompt sizes; build the values array once instead of per statistic.
token_counts = np.array(list(prompts_num_tokens.values()))
print("mean: ", np.mean(token_counts))
print("std: ", np.std(token_counts))
print("min: ", np.min(token_counts))
print("max: ", np.max(token_counts))
for pct in (25, 50, 75):
    print(f"{pct}th Quantile: ", np.percentile(token_counts, pct))

In [None]:
# Reply budget (matches the "within 256 tokens" instruction in the system prompt) and the
# total prompt + reply token budget.
# NOTE(review): 16384 presumably reflects the deployed model's context window — confirm.
max_response_tokens = 256
token_limit = 16384

In [None]:
# Keep prompts whose estimated size plus the reply budget is strictly below token_limit;
# report the ones that are dropped.
prompts_filtered = {}
for k, v in prompts.items():
    if prompts_num_tokens[k] + max_response_tokens < token_limit:
        prompts_filtered[k] = v
    else:
        print(f"hadm id {k} is filtered due to length {prompts_num_tokens[k]}")

In [None]:
print(len(prompts))
print(len(prompts_filtered))

In [None]:
# NOTE(review): hard-coded hadm_id; raises KeyError if it was dropped by the length filter.
prompts_filtered[29079034]

In [None]:
import asyncio
from openai import AsyncAzureOpenAI


# Credentials are read from the environment so they never end up in the notebook file
# or its outputs. Required variables (the names the openai SDK's Azure client also uses):
#   AZURE_OPENAI_ENDPOINT, AZURE_OPENAI_API_KEY, OPENAI_API_VERSION
# os.environ[...] fails fast with a KeyError naming any variable that is missing, instead
# of an empty string silently overriding the SDK's own lookup and failing at request time.
async_client = AsyncAzureOpenAI(
    azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
    api_key=os.environ["AZURE_OPENAI_API_KEY"],
    api_version=os.environ["OPENAI_API_VERSION"],
)

In [None]:
async def generate_chat_response(async_client, prompt, max_tokens=None, retry_delay=10):
    """Request one chat completion and return the reply content.

    Args:
        async_client: client exposing an awaitable ``chat.completions.create``
            (the notebook passes an ``AsyncAzureOpenAI`` instance).
        prompt: list of chat messages.
        max_tokens: completion-length cap; when None, the notebook-level
            ``max_response_tokens`` is used.
        retry_delay: seconds to back off after a failed request before
            returning, so the caller's next retry round does not hit the API
            immediately.

    Returns:
        ``response.choices[0].message.content`` on success, or ``-1`` if the
        request raised. Callers treat any non-``str`` result as a failure.
    """
    if max_tokens is None:
        max_tokens = max_response_tokens
    chat_params = {
        "model": "gpt-3.5-turbo",
        "messages": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    try:
        response = await async_client.chat.completions.create(**chat_params)
    except Exception as e:
        # Best effort: log, back off, and signal failure so the rest of the batch continues.
        print(f"Error in call_async: {e}")
        print(f"Sleep for {retry_delay}s...")
        # asyncio.sleep yields to the event loop; time.sleep would block every other
        # in-flight request of the same asyncio.gather batch for the whole delay.
        await asyncio.sleep(retry_delay)
        return -1
    return response.choices[0].message.content

In [None]:
import time


async def process_prompts(prompts):
    """Send every prompt concurrently and return the replies in input order.

    Each element of the returned list is whatever ``generate_chat_response``
    produced for the prompt at the same position (reply content, or -1 when
    the request failed).
    """
    coroutines = [generate_chat_response(async_client, p) for p in prompts]
    return await asyncio.gather(*coroutines)

In [None]:
len(prompts_filtered)

In [None]:
# First 10 prompts (dict insertion order), used for a small trial run of the API calls.
prompts_filtered_subset = {k: prompts_filtered[k] for k in list(prompts_filtered.keys())[:10]}
len(prompts_filtered_subset)

In [None]:
def chunk_list(lst, chunk_size):
    """Lazily split a sliceable sequence into consecutive pieces.

    Every piece holds ``chunk_size`` items except possibly the last, which
    holds whatever remains.

    Args:
        lst: any sequence supporting ``len`` and slicing.
        chunk_size: number of items per piece.

    Yields:
        Slices of ``lst``, in order.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[start:start + chunk_size] for start in starts)

In [None]:
from tqdm.asyncio import tqdm


async def process_prompts_in_batches(prompts, batch_size, repeat=3):
    """Query the model for every prompt, retrying failures for up to ``repeat`` rounds.

    Args:
        prompts: mapping of key -> chat messages (keys are hadm_ids in this notebook).
        batch_size: number of prompts sent concurrently per batch.
        repeat: number of passes; each pass only re-sends prompts that have
            not yet produced a ``str`` reply.

    Returns:
        dict mapping key -> reply text for every prompt that succeeded in some round.
    """
    collected = {}

    for round_idx in range(repeat):
        print(f"round {round_idx}")
        n_before = len(collected)

        # Only prompts without a successful reply are (re-)sent.
        pending = [key for key in prompts if key not in collected]

        for keys in tqdm(list(chunk_list(pending, batch_size)), desc="Processing Batches"):
            replies = await process_prompts([prompts[key] for key in keys])
            # Failed calls come back as a non-str sentinel; keep text replies only.
            collected.update(
                (key, reply) for key, reply in zip(keys, replies) if type(reply) is str
            )

        print(f"get {len(collected) - n_before} new responses")

    return collected

In [None]:
# Choose an appropriate batch size
batch_size = 10  # Adjust based on your system and API limits

# Assuming we are in an async environment
responses = await process_prompts_in_batches(prompts_filtered_subset, batch_size)
print(f"Processed {len(responses)} responses")

In [None]:
responses

In [None]:
len(prompts_filtered)

In [None]:
# Choose an appropriate batch size
batch_size = 10  # Adjust based on your system and API limits

# Assuming we are in an async environment
responses = await process_prompts_in_batches(prompts_filtered, batch_size)
print(f"Processed {len(responses)} responses")

In [None]:
responses[28369884]

In [None]:
# Number of replies that still contain "___", which the system prompt tells the model to avoid.
c = sum("___" in qa for qa in responses.values())
c

In [None]:
import re


def split_qa(qa, verbose=False):
    """Parse a model reply into its question and answer.

    Two layouts are recognised, tried in this order:
      1. ``- Question: ...`` / ``- Answer: ...`` (the dashes are optional)
      2. ``**Question:** ...`` / ``**Answer:** ...`` (markdown bold)
    Only the first line of the answer is kept, and a question spanning several
    lines does not match.

    Args:
        qa: raw reply text.
        verbose: if True, print the reply and the parsed fields.

    Returns:
        ``(question, answer)`` tuple of strings.

    Raises:
        AttributeError: if neither layout matches (callers catch this to skip
            unparseable replies).
    """
    if verbose:
        print(qa)
    layouts = (
        r"-?\s*Question: (.*)\s*-?\s*Answer: (.*)",
        r"\*\*Question:\*\* (.*)\s*-?\s*\*\*Answer:\*\* (.*)",
    )
    match = None
    for layout in layouts:
        match = re.search(layout, qa)
        if match is not None:
            break
    # When no layout matched, `match` is None and .group raises AttributeError.
    question, answer = match.group(1), match.group(2)
    if verbose:
        print("Question:", question)
        print("Answer:", answer)
    return question, answer

In [None]:
# Parse each reply into (question, answer); unparseable replies are printed and skipped.
responses_split = {}
for hadm_id, qa in responses.items():
    try:
        responses_split[hadm_id] = split_qa(qa)
    except AttributeError:
        # split_qa signals "no recognised Question/Answer layout" with AttributeError.
        print(qa)
        print("=====================")

In [None]:
# NOTE(review): hard-coded hadm_id; raises KeyError if this reply was missing or unparseable.
responses_split[28369884]

In [None]:
len(responses_split)

In [None]:
import json


# Raw model replies as JSON Lines: one {"hadm_id": ..., "qa": <reply text>} object per line.
with open(os.path.join(output_path, "qa_note_orig.jsonl"), "w") as file:
    for hadm_id, qa in responses.items():
        # Serialise one record and terminate it with a newline (JSON Lines).
        json_string = json.dumps({"hadm_id": hadm_id, "qa": qa})
        file.write(json_string + '\n')

In [None]:
import json


# Parsed pairs as JSON Lines: one {"hadm_id": ..., "q": question, "a": answer} object per line.
# Only replies that split_qa could parse are included.
with open(os.path.join(output_path, "qa_note.jsonl"), "w") as file:
    for hadm_id, (q, a) in responses_split.items():
        # Serialise one record and terminate it with a newline (JSON Lines).
        json_string = json.dumps({"hadm_id": hadm_id, "q": q, "a": a})
        file.write(json_string + '\n')