In [1]:
# Given clustering definitions, generate clusters N=4 times and aggregate/compare the results
import os
import glob
import openai
from dotenv import load_dotenv
from llm import call_openai_api
import pandas as pd
import concurrent.futures
from utils import ALL_SUBJECT_IDS

# Credentials: load_dotenv() runs first, then the OpenAI key is read from the
# process environment (None if OPENAI_API_KEY is not set).
load_dotenv()
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Model selection and run parameters.
model_name = "gpt-4"
# model_name = "gpt-4-32k"
NUM_REPEATS = 4  # how many times each QnA file is re-clustered
INPUT_DIR = "outputs/20231209/clusters/"
OUTPUT_DIR = "outputs/20231209/recluster/"
EXPECTED_NUM_SUBJECTS = len(ALL_SUBJECT_IDS)

# Input transcripts. The sort is lexicographic on the path string, so
# QnA_10.txt is listed before QnA_2.txt.
qna_files = sorted(glob.glob(INPUT_DIR + "/QnA_*.txt"))
# qna_files = qna_files[:2]
qna_files

['outputs/20231209/clusters/QnA_0.txt',
 'outputs/20231209/clusters/QnA_1.txt',
 'outputs/20231209/clusters/QnA_10.txt',
 'outputs/20231209/clusters/QnA_11.txt',
 'outputs/20231209/clusters/QnA_12.txt',
 'outputs/20231209/clusters/QnA_13.txt',
 'outputs/20231209/clusters/QnA_14.txt',
 'outputs/20231209/clusters/QnA_15.txt',
 'outputs/20231209/clusters/QnA_16.txt',
 'outputs/20231209/clusters/QnA_17.txt',
 'outputs/20231209/clusters/QnA_18.txt',
 'outputs/20231209/clusters/QnA_19.txt',
 'outputs/20231209/clusters/QnA_2.txt',
 'outputs/20231209/clusters/QnA_20.txt',
 'outputs/20231209/clusters/QnA_21.txt',
 'outputs/20231209/clusters/QnA_22.txt',
 'outputs/20231209/clusters/QnA_23.txt',
 'outputs/20231209/clusters/QnA_24.txt',
 'outputs/20231209/clusters/QnA_25.txt',
 'outputs/20231209/clusters/QnA_26.txt',
 'outputs/20231209/clusters/QnA_27.txt',
 'outputs/20231209/clusters/QnA_28.txt',
 'outputs/20231209/clusters/QnA_29.txt',
 'outputs/20231209/clusters/QnA_3.txt',
 'outputs/20231209/c

In [2]:
def create_prompt_from_qna_file(qna_file):
    """Turn a saved QnA transcript into a prompt that reuses its cluster definitions.

    The file is read in full, everything from the last occurrence of
    "subject_id" onwards is discarded, and the remaining text is rewritten so
    that the earlier response is presented as cluster definitions to be used.

    Args:
        qna_file: Path of the transcript file to read.

    Returns:
        The rewritten prompt text as a single string.
    """
    with open(qna_file, "r") as handle:
        text = handle.read()

    # Some transcripts capitalise the column header; fold it to lower case so
    # the search below finds it either way.
    text = text.replace("Subject_id", "subject_id")

    # Keep only what precedes the final "subject_id" (if there is one).
    cut_at = text.rfind("subject_id")
    if cut_at >= 0:
        text = text[:cut_at]

    # (old, new) pairs, applied strictly in this order.
    substitutions = (
        ("PROMPT:\n", ""),
        (
            "Start your response by defining each top and secondary cluster in tab-separated-values format, with columns: \n"
            "cluster_id  cluster_name    cluster_description",
            "",
        ),
        (
            "RESPONSE:\n",
            "Use the following cluster definitions (Do not repeat this in output):\n",
        ),
    )
    for old, new in substitutions:
        text = text.replace(old, new)
    return text


# print(create_prompt_from_qna_file(qna_files[-5]))  # Test

In [7]:
# Verify cluster assignments
def verify_cluster_assignments(output_fname):
    """Check that a re-clustering TSV lists every expected subject exactly once.

    The file is read as tab-separated values and its ``subject_id`` column is
    compared against ``ALL_SUBJECT_IDS`` / ``EXPECTED_NUM_SUBJECTS``.

    Args:
        output_fname: Path of the tab-separated file to check.

    Returns:
        A tuple ``(output_fname, errors)``. ``errors`` is a list of message
        strings and is empty when every check passes. Any exception raised
        while reading or checking the file (missing file, parse failure,
        missing ``subject_id`` column, ...) is caught and returned as a
        one-element list containing ``str(exception)``.
    """
    errors = []
    try:
        # Raises if the file cannot be read or parsed; reported by the handler below.
        df = pd.read_csv(output_fname, sep="\t")

        subject_ids = df["subject_id"]
        unique_subject_ids = subject_ids.unique()
        num_rows = len(df)
        num_unique = len(unique_subject_ids)

        # Check for the correct number of distinct subject IDs.
        if num_unique != EXPECTED_NUM_SUBJECTS:
            errors.append(f"{num_unique} elements not {EXPECTED_NUM_SUBJECTS}")

        # Repeated subject IDs are reported on their own. Folding this into the
        # count check above produced contradictory messages such as
        # "93 elements not 93" when the only problem was duplicated rows.
        if num_rows != num_unique:
            duplicated_ids = subject_ids[subject_ids.duplicated()].unique().tolist()
            errors.append(
                f"{num_rows} rows for {num_unique} unique subject_ids; "
                f"duplicated: {duplicated_ids}"
            )

        missing_subject_ids = set(ALL_SUBJECT_IDS) - set(unique_subject_ids)
        if len(missing_subject_ids) > 0:
            errors.append(f"missing subject_ids: {missing_subject_ids}")
        extra_subject_ids = set(unique_subject_ids) - set(ALL_SUBJECT_IDS)
        if len(extra_subject_ids) > 0:
            errors.append(f"has extra subject_ids: {extra_subject_ids}")

        # Check primary assignments
        # print(df.iloc[:, 1].unique())
        # print(df.iloc[:, 1].isna().sum())

        result = (output_fname, errors)
    except Exception as e:
        # Deliberately broad: a bad file becomes an error entry, not a crash.
        result = (output_fname, [str(e)])
    return result


# Define a function to process each repetition for a given file
def process_repetition(qna_file, repeat_num):
    """Produce (or re-use) one re-clustering output file for a QnA transcript.

    The output path is derived from the question number at the end of the
    input file name and from ``repeat_num``. If that file already exists and
    passes ``verify_cluster_assignments`` it is left untouched; otherwise the
    model is called and the file is (re)written.

    Args:
        qna_file: Path of the transcript; its name must end in ``_<int>.txt``.
        repeat_num: Repetition index used in the output file name.

    Returns:
        ``(output_fname, "Already exists and Verified OK")`` when a valid file
        was already present, otherwise the ``(output_fname, errors)`` tuple
        returned by ``verify_cluster_assignments`` for the newly written file.

    Raises:
        ValueError: If the trailing ``_<...>`` part of the name is not an integer.
    """
    q_num = int(qna_file.replace(".txt", "").split("_")[-1])
    output_fname = f"{OUTPUT_DIR}/Q{q_num:02d}_recluster.{repeat_num:02d}.tsv"
    if os.path.exists(output_fname):
        _, errors = verify_cluster_assignments(output_fname)
        if len(errors) > 0:
            # Fall through and regenerate the file.
            print("Exists but errors found:", output_fname, " ".join(errors))
        else:
            return (output_fname, "Already exists and Verified OK")

    # Create the output directory BEFORE calling the model: on a fresh checkout
    # the directory is absent, and failing at write time would throw away an
    # already-completed API response. exist_ok makes this safe when several
    # repetitions run concurrently.
    output_dir = os.path.dirname(output_fname)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    prompt = create_prompt_from_qna_file(qna_file)
    response = call_openai_api(prompt, model_name, [])
    response_message = response["choices"][0]["message"]["content"]

    # Remove everything before the first subject_id
    subject_id_index = response_message.find("subject_id")
    if subject_id_index != -1:
        response_message = response_message[subject_id_index:]

    with open(output_fname, "w") as output_file:
        output_file.write(response_message)

    return verify_cluster_assignments(output_fname)


# Function to process each file
def process_file(qna_file):
    """Run all repetitions for one QnA file on a thread pool and gather the outcomes.

    Repetitions are numbered 1..NUM_REPEATS and each is handed to
    ``process_repetition``. Outcomes are collected in completion order, so
    the order of the returned list is not tied to the repetition number.

    Args:
        qna_file: Path of the transcript passed through to each repetition.

    Returns:
        A list with one entry per repetition: the value returned by
        ``process_repetition``, or an error string if that call raised.
    """
    outcomes = []
    with concurrent.futures.ThreadPoolExecutor() as pool:
        # Remember which repetition number each pending future belongs to.
        pending = {}
        for rep in range(1, NUM_REPEATS + 1):
            pending[pool.submit(process_repetition, qna_file, rep)] = rep

        for finished in concurrent.futures.as_completed(pending):
            rep = pending[finished]
            try:
                outcome = finished.result()
            except Exception as exc:
                # A failed repetition is recorded as text so the others still count.
                outcome = f"Error in repetition {rep} of file '{qna_file}': {exc}"
            outcomes.append(outcome)
    return outcomes


# Main code: spread the QnA files over a thread pool and print each file's
# results as soon as that file is done (completion order, not input order).
with concurrent.futures.ThreadPoolExecutor(max_workers=12) as executor:
    # Map every submitted future back to the file it is processing.
    future_to_file = {}
    for qna_file in qna_files:
        future_to_file[executor.submit(process_file, qna_file)] = qna_file

    for future in concurrent.futures.as_completed(future_to_file):
        qna_file = future_to_file[future]
        try:
            file_results = future.result()
            for result in file_results:
                print(result)
        except Exception as e:
            # Keep going: one failing file must not abort the remaining ones.
            print(f"Error processing file {qna_file}: {e}")

['C1' 'C3' 'C4' 'C2' 'C5']
['C1' 'C3' 'C4' 'C2' 'C5']
['C1' 'C3' 'C4' 'C2' 'C5']
['C1' 'C3' 'C4' 'C2' 'C5']
['C1' 'C2' 'C3' 'C4' 'C2,C14' 'C5' 'C6' 'C8' 'C7' 'C9,C3' 'C16' 'C7,C14'
 'C10' 'C11' 'C13' 'C14,C3' 'C14' 'C12' 'C7,C3' 'C15' 'C8,C14' 'C2,C6']
['C1' 'C2' 'C3' 'C4' 'C2,C14' 'C5' 'C6' 'C8' 'C7' 'C3,C9' 'C16' 'C7,C14'
 'C10' 'C11' 'C13' 'C14,C3' 'C14' 'C12' 'C7,C3' 'C15' 'C8,C14' 'C6,C2']
['C1' 'C2' 'C3' 'C4' 'C2, C14' 'C5' 'C6' 'C8' 'C7' 'C3, C9' 'C16'
 'C7, C14' 'C13' 'C10' 'C11' 'C3, C14' 'C14' 'C12' 'C7, C3' 'C15'
 'C8, C14' 'C6, C2']
['C1' 'C2' 'C3' 'C4' 'C2,C14' 'C5' 'C6' 'C8' 'C7' 'C9,C3' 'C16' 'C7,C14'
 'C13' 'C3,C9' 'C10' 'C11' 'C3,C14' 'C14' 'C12' 'C7,C3' 'C15' 'C8,C14'
 'C6,C2']
['C1' 'C2']
['C1' 'C2']
['C1' 'C2']
['C1' 'C2']
['C1' 'C2' 'C4' 'C3' 'C5']
['C1' 'C2' 'C5' 'C4' 'C3']
['C1' 'C2' 'C4' 'C3' 'C5']
['C1,C3,C5' 'C2' 'C1,C2,C3' 'C7' 'C2,C4' 'C3' 'C2,C3' 'C1,C8' 'C5' 'C4,C8'
 'C1,C3,C5,C8' 'C1,C2' 'C3,C4' 'C8' 'C3,C5,C8' 'C1,C3' 'C1' 'C4' 'C5,C8'
 'C6' 'C6,C8' 'C1,

In [4]:
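# NOTE: this whole cell is commented out. It appears to be a serial (non-threaded)
# variant of the reclustering loop and is kept for reference only; it is not executed.
# # Process each QnA file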
# qna_files = glob.glob(INPUT_DIR + "/QnA_*.txt")
# qna_files.sort()
# for qna_file in qna_files[25]:
#     print(f"Processing {os.path.basename(qna_file)}...")
#     q_num = qna_file.replace(".txt", "").split("_")[-1]
#     with open(qna_file, "r") as file:
#         content = file.read()

#     prompt = create_prompt_from_qna_file(content)

#     for i in range(1, NUM_REPEATS + 1):
#         print(f"Reclustering repeat {i} of {NUM_REPEATS}:")
#         response = call_openai_api(prompt, model_name, [])
#         response_message = response["choices"][0]["message"]["content"]

#         # Remove content before "subject_id"
#         subject_id_index = response_message.find("subject_id")
#         if subject_id_index != -1:
#             response_message = response_message[subject_id_index:]

#         fname = f"{OUTPUT_DIR}/Q{q_num}_recluster.{i:02d}.tsv"
#         try:
#             with open(fname, "w") as output_file:
#                 output_file.write(response_message)
#                 print(f"Saved: {fname}")

#             # Load the saved file and check for unique elements
#             df = pd.read_csv(fname, sep="\t")
#             unique_subject_ids = df["subject_id"].nunique()
#             if unique_subject_ids != 93:
#                 print(
#                     f"Warning: File '{fname}' does not have 93 unique elements in 'subject_id' column. Found {unique_subject_ids} unique elements."
#                 )
#             else:
#                 print(
#                     f"Verified '{fname}' has 93 unique elements in 'subject_id' column."
#                 )

#         except Exception as e:
#             print(f"Error encountered: {e}. Moving to the next iteration.")

In [5]:
# Completion marker: printed once execution reaches this final cell.
print("Finished...")

Finished...
