In [None]:
import os
import sys

# Make the project root (two directories up) importable so `from src.utils import ...` resolves.
src_path = os.path.abspath('../..')
print(src_path)
# Guard so re-running this cell does not append a duplicate sys.path entry each time.
if src_path not in sys.path:
    sys.path.append(src_path)

In [None]:
from src.utils import create_directory, raw_data_path, processed_data_path, set_seed

In [None]:
# Seed random number generators via the project helper from src.utils
# (which libraries it covers is defined there — see its implementation).
set_seed(seed=42)

In [None]:
import pandas as pd

In [None]:
# Directory under processed_data_path: qa_event_selected.jsonl is read from here
# and all paraphrased outputs / caches below are written here.
output_path = os.path.join(processed_data_path, "mimic4")

In [None]:
from pandarallel import pandarallel

# NOTE(review): no pandarallel-backed call (e.g. parallel_apply) is visible in this
# notebook section; this initialization may be unnecessary — confirm.
pandarallel.initialize()

## qa_event

In [None]:
# Rule-generated QA pairs, stored one JSON object per line.
qa_event_selected = pd.read_json(os.path.join(output_path, "qa_event_selected.jsonl"), lines = True)
qa_event_selected

In [None]:
# Alias, not a copy: any in-place change made through qa_event also changes qa_event_selected.
qa_event = qa_event_selected

In [None]:
# Distribution of event types among the QA pairs.
qa_event.event_type.value_counts()

In [None]:
# Number of distinct admissions; several QA pairs can share one hadm_id.
qa_event.hadm_id.nunique()

In [None]:
qa_event['count'] = qa_event.groupby('hadm_id').cumcount()
qa_event['hadm_id'] = qa_event['hadm_id'].astype(str) + '_' + qa_event['count'].astype(str)
qa_event = qa_event.drop(columns=['count'])
qa_event

In [None]:
qa_event.hadm_id.nunique()

### Paraphrase QA pairs with GPT

In [None]:
# System prompt asking the model to paraphrase a rule-generated QA pair.
# split_qa parses the reply by the "Question:" / "Answer:" labels requested in
# item 5, so keep those labels in sync if this prompt is edited.
system_content = """You are an AI assistant with expertise in medical knowledge. 

Your input consists of a question-answer pair created using predefined rules. 

Your primary task is to rephrase both the question and the answer to introduce variety in the wording while preserving their original meanings.

Objective:
1. Rewrite the question and answer using language that mimics how a physician might phrase them in a real-world setting.
2. Ensure the paraphrased text is grammatically correct.
3. Adjust capitalization as needed.
4. Maintain the original intent and meaning of the question-answer pair.
5. Format your response as follows:
- Question: [Your paraphrased question]
- Answer: [Your paraphrased answer]
6. Aim for brevity in both the question and answer."""

In [None]:
def wrap_user_content(row):
    """Render one QA row (anything exposing ``.q`` and ``.a``) as the user-message text.

    Returns two lines: ``Question: <q>`` and ``Answer: <a>``.
    """
    question_line = f"Question: {row.q}"
    answer_line = f"Answer: {row.a}"
    return "\n".join((question_line, answer_line))

In [None]:
print(wrap_user_content(qa_event.iloc[0]))

In [None]:
prompts = {}
for _, row in qa_event.iterrows():
    messages = [{"role": "system", "content": system_content},
                {"role": "user", "content": wrap_user_content(row)}]
    prompts[row.hadm_id] = messages
len(prompts)

In [None]:
prompts["24903681_5"]

In [None]:
max_response_tokens = 256

In [None]:
import asyncio
from openai import AsyncAzureOpenAI


# Credentials come from the environment rather than being typed into the notebook,
# where they would be saved and easily shared. Set AZURE_OPENAI_ENDPOINT,
# AZURE_OPENAI_API_KEY and OPENAI_API_VERSION before running this cell.
# (Passing empty strings, as before, would also override the SDK's own env fallback.)
async_client = AsyncAzureOpenAI(
    azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
    api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
    api_version=os.environ.get("OPENAI_API_VERSION"),
)

In [None]:
async def generate_chat_response(async_client, prompt, max_tokens=None, retry_delay=10):
    """Request one chat completion and return the reply text.

    Args:
        async_client: client exposing an awaitable ``chat.completions.create``
            (e.g. ``AsyncAzureOpenAI``).
        prompt: list of chat messages (``{"role": ..., "content": ...}`` dicts).
        max_tokens: completion token cap; when None, the notebook-level
            ``max_response_tokens`` is used.
        retry_delay: seconds to back off after a failed request before returning.

    Returns:
        The reply's ``message.content``, or ``-1`` if the request raised. Callers
        keep only ``str`` results and retry the rest.
    """
    if max_tokens is None:
        max_tokens = max_response_tokens
    chat_params = {
        "model": "gpt-3.5-turbo",
        "messages": prompt,
        "max_tokens": max_tokens,
        "temperature": 0.0,
    }
    try:
        response = await async_client.chat.completions.create(**chat_params)
    except Exception as e:
        print(f"Error in call_async: {e}")
        print(f"Sleep for {retry_delay}s...")
        # Non-blocking back-off: time.sleep here would freeze the event loop and
        # stall (and serialize) every other request gathered alongside this one.
        await asyncio.sleep(retry_delay)
        return -1
    return response.choices[0].message.content

In [None]:
import time


async def process_prompts(prompts):
    """Send every prompt concurrently through the module-level ``async_client``.

    Returns the replies as a list in the same order as `prompts`
    (``asyncio.gather`` preserves order); a failed request contributes whatever
    generate_chat_response returns on error.
    """
    pending = [generate_chat_response(async_client, prompt) for prompt in prompts]
    return await asyncio.gather(*pending)

In [None]:
# First 10 prompts (insertion order) for a cheap end-to-end trial run.
prompts_subset = dict(list(prompts.items())[:10])
len(prompts_subset)

In [None]:
def chunk_list(lst, chunk_size):
    """Lazily yield consecutive slices of `lst`, each `chunk_size` long.

    The final slice may be shorter. `lst` must support ``len`` and slicing.
    """
    starts = range(0, len(lst), chunk_size)
    yield from (lst[start:start + chunk_size] for start in starts)

In [None]:
from tqdm.asyncio import tqdm


async def process_prompts_in_batches(prompts, batch_size, repeat=3):
    """Query every prompt, retrying failed ones for up to `repeat` rounds.

    Args:
        prompts: dict mapping an id to its chat messages.
        batch_size: number of requests sent concurrently per batch.
        repeat: maximum number of passes. Each pass re-sends only the prompts
            that have not yet produced a ``str`` response.

    Returns:
        Dict mapping id -> response text for every prompt that succeeded.
        Ids that failed in every round are absent (and reported via print).
    """
    all_responses = {}

    for i in range(repeat):
        # Only (re)send prompts that do not have a successful response yet.
        prompts_k = [k for k in prompts if k not in all_responses]
        if not prompts_k:
            break  # everything answered; further rounds would send nothing

        print(f"round {i}")
        prev_n_responses = len(all_responses)

        # Batches are awaited one after another, so at most batch_size requests
        # are in flight at a time.
        prompt_k_batches = list(chunk_list(prompts_k, batch_size))
        for batch_k in tqdm(prompt_k_batches, desc="Processing Batches"):
            batch_v = [prompts[k] for k in batch_k]
            responses = await process_prompts(batch_v)
            # Failed requests come back as a non-str sentinel; leave them out so
            # the next round retries them.
            all_responses |= {k: v for k, v in zip(batch_k, responses) if isinstance(v, str)}
        print(f"get {len(all_responses) - prev_n_responses} new responses")

    n_missing = len(prompts) - len(all_responses)
    if n_missing:
        print(f"{n_missing} prompts still have no response after {repeat} rounds")
    return all_responses

In [None]:
# Trial run: push the 10-prompt subset through the batching/retry pipeline first.
# Choose an appropriate batch size
batch_size = 10  # Adjust based on your system and API limits

# Assuming we are in an async environment
# (top-level `await` works in a Jupyter/IPython cell, not in a plain .py script).
responses = await process_prompts_in_batches(prompts_subset, batch_size)
print(f"Processed {len(responses)} responses")

In [None]:
responses

In [None]:
def split_dict_equally(input_dict, chunks=5):
    """Partition a dict into `chunks` consecutive sub-dicts, preserving key order.

    Every part receives ``len(input_dict) // chunks`` items except the last one,
    which additionally absorbs the remainder. Parts may be empty when there are
    fewer items than chunks. ``chunks == 0`` raises ZeroDivisionError.
    """
    ordered_keys = list(input_dict)
    per_part, leftover = divmod(len(ordered_keys), chunks)

    parts = []
    offset = 0
    for part_no in range(chunks):
        is_last = part_no == chunks - 1
        width = per_part + leftover if is_last else per_part
        parts.append({key: input_dict[key] for key in ordered_keys[offset:offset + width]})
        offset += width
    return parts

In [None]:
split_prompts = split_dict_equally(prompts, 10)
len(split_prompts)

In [None]:
sum([len(i) for i in split_prompts])

In [None]:
import re


# Matches "Question: <q> ... Answer: <a>" as requested in system_content. The
# question group is lazy so a separator between the two labels (newline, " - ")
# is not captured into it; each captured group stays on one line because `.`
# does not match "\n".
_QA_PATTERN = re.compile(r"Question:\s*(.*?)\s*-?\s*Answer:\s*(.*)")


def split_qa(qa, verbose=False):
    """Split a model reply into its paraphrased question and answer.

    Args:
        qa: reply text expected to contain ``Question: ...`` followed by
            ``Answer: ...`` (optionally as ``- `` bullet lines).
        verbose: print the raw reply and the parsed parts.

    Returns:
        ``(question, answer)`` with surrounding whitespace stripped; the answer
        is the rest of the line after ``Answer:``.

    Raises:
        ValueError: if the reply does not contain the expected labels.
    """
    if verbose:
        print(qa)
    match = _QA_PATTERN.search(qa)
    if match is None:
        raise ValueError(f"Could not parse question/answer from response: {qa!r}")
    question = match.group(1).strip()
    answer = match.group(2).strip()
    if verbose:
        print("Question:", question)
        print("Answer:", answer)
    return question, answer

In [None]:
# Lookup from suffixed hadm_id (e.g. "24903681_5") to its event_type, used when the
# paraphrased pairs are written out. Only the one needed column is converted.
assert qa_event.hadm_id.is_unique, "hadm_id keys must be unique for a one-to-one lookup"
hadm_id_to_event_type = qa_event.set_index("hadm_id")["event_type"].to_dict()
len(hadm_id_to_event_type)

In [None]:
import pickle

# Persist the chunked prompts — presumably so the same chunk boundaries (one
# qa_event_{i}.jsonl file per chunk below) can be restored after a kernel restart.
# Only unpickle this file from a trusted location: pickle.load can execute code.
with open(os.path.join(output_path, "split_prompts_cache.pkl"), 'wb') as file:
    pickle.dump(split_prompts, file)

In [None]:
import json

for chunk_i, chunk_prompts in enumerate(split_prompts):
    print(f"Processing chunk {chunk_i} with {len(chunk_prompts)} prompts")
    
    # Choose an appropriate batch size
    batch_size = 10  # Adjust based on your system and API limits

    # Assuming we are in an async environment
    responses = await process_prompts_in_batches(chunk_prompts, batch_size)
    print(f"Processed {len(responses)} responses")
    
    responses_split = {}
    for hadm_id, qa in responses.items():
        responses_split[hadm_id] = split_qa(qa)
        
    print(f"After split: {len(responses_split)}")
    
    filename = f"qa_event_{chunk_i}.jsonl"
    print(f"filename: {filename}")
    
    with open(os.path.join(output_path, filename), "w") as file:
        for hadm_id, (q, a) in responses_split.items():
            # Convert the dictionary to a JSON string and write it to the file
            actual_hadm_id = int(hadm_id.split("_")[0])
            json_string = json.dumps({"hadm_id": actual_hadm_id, "q": q, "a": a, "event_type": hadm_id_to_event_type[hadm_id]})
            file.write(json_string + '\n')

In [None]:
print("done")

In [None]:
input_files = [f"qa_event_{i}.jsonl" for i in range(10)]
output_file = f"qa_event.jsonl"

# Open the output file in write mode
with open(os.path.join(output_path, output_file), 'w') as outfile:
    # Iterate over each input file
    for input_file in input_files:
        # Open the input file in read mode
        with open(os.path.join(output_path, input_file), 'r') as infile:
            # Read each line (each JSON object) in the input file
            for line in infile:
                # Write the JSON object to the output file
                outfile.write(line)