In [None]:
# Root directory of the project. Every path used below (the "data" folder,
# the generated-text folders) is built relative to it. "." resolves against
# the notebook's current working directory.
PROJECT_HOME = "."

# --- Google Colab ---
# To run on Colab, uncomment the lines below: they point PROJECT_HOME at the
# project folder on Google Drive and mount the drive.

# PROJECT_HOME = "/content/drive/My Drive/Projects/LLM-MCI-detection"

# Google Drive storage setup
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [None]:
import os
import pandas as pd

In [None]:
# Directory that receives the aggregated CSV files: <PROJECT_HOME>/data.
output_dir = os.path.join(PROJECT_HOME, "data")
# exist_ok=True keeps this cell re-runnable when the directory already exists.
os.makedirs(output_dir, exist_ok=True)

In [None]:
# Models whose generated samples are aggregated. Each name is also the name of
# the sub-directory that holds that model's generated .txt files.
llm_model_names = [
    "gemma-2-9B",
    "llama-3.1-8B",
    "gpt-35-turbo",
    "gpt-4",
]

In [None]:
# Fold counts for the generation types that are produced in several runs.
# A count of N means run i is looked up in a sub-directory named "i"
# (0 .. N-1) under the corresponding "<type>-generation" folder.
N_fold_observational_generation = 5
# NOTE(review): intended for the cross-lingual runs; confirm that the
# aggregation loop reads this constant for that generation type.
N_fold_cross_lingual_generation = 5

In [None]:
# Source table of the original samples. Generated files are matched to its
# rows by position: the integer in the file name is used with .iloc.
original_csv_path = os.path.join(PROJECT_HOME, "data", "original.csv")
original_data = pd.read_csv(original_csv_path)

In [None]:
def _extract_generated_text(non_empty_lines, line_header, source_path):
    """Pull the generated text out of one LLM response.

    Two layouts are recognised (markdown bold markers ``**`` are ignored):

        Text: [new text data]

        Text:
        [new text data]

    Args:
        non_empty_lines: The response lines, already stripped and with blank
            lines removed.
        line_header: Prefix that marks the text line ("Text" or "Chinese").
        source_path: Path of the file being parsed; used in error messages only.

    Returns:
        The extracted text, or None when no header line is present.

    Raises:
        ValueError: If the header is found but nothing follows it.
    """
    for line_number, raw_line in enumerate(non_empty_lines):
        # Drop the bold markers first and strip afterwards, so that a line
        # such as "** Text:** ..." is still recognised as a header.
        line = raw_line.replace("**", "").lstrip()  # "**Text:**" => "Text:"
        if not line.startswith(line_header):
            continue

        # Split on the FIRST colon only: the generated text may itself contain
        # colons ("He said: ...", "3:00") and must not be truncated.
        _, separator, remainder = line.partition(":")
        if not separator:
            # Starts with the header word but is not a "Header:" line.
            continue

        text = remainder.strip()
        if text:
            return text

        # Header on its own line: the text is the next non-empty line.
        if line_number + 1 >= len(non_empty_lines):
            raise ValueError("Cannot find transcription: %s" % source_path)
        return non_empty_lines[line_number + 1].strip()

    return None


# Number of numbered run sub-directories per generation type. Some generation
# types have multiple runs to get a balanced dataset; None means a single run
# whose files sit directly under "<type>-generation/".
generation_fold_counts = {
    "observational": N_fold_observational_generation,
    "cross-lingual": N_fold_cross_lingual_generation,
    "counterfactual": None,
}

# Fixed column order, so that a header row is written even when no sample is found.
record_columns = ["index", "label", "age", "gender", "race", "education", "text"]

for llm_model_name in llm_model_names:
    for generation_type, n_folds in generation_fold_counts.items():
        generation_root = os.path.join(PROJECT_HOME, "data", "%s-generation" % generation_type)
        if n_folds is None:
            base_dirs = [generation_root]
        else:
            base_dirs = [os.path.join(generation_root, "%d" % i) for i in range(n_folds)]

        records = []
        for base_dir in base_dirs:
            # A run (or this model's folder inside a run) may be missing; skip it
            # instead of failing in os.listdir.
            model_dir = os.path.join(base_dir, llm_model_name)
            if not os.path.isdir(model_dir):
                continue

            # Sorted so that the row order of the output CSV is reproducible
            # (os.listdir returns entries in arbitrary order).
            generated_file_names = sorted(
                file_name for file_name in os.listdir(model_dir) if file_name.endswith(".txt")
            )

            for txt_file_name in generated_file_names:
                txt_file_path = os.path.join(model_dir, txt_file_name)

                # File names are "<row position in original.csv>.txt".
                file_idx = int(txt_file_name.split(".")[0])
                original_row = original_data.iloc[file_idx]

                # Original label
                original_label = original_row['label']
                assert original_label in ["NC", "MCI"], "Unexpected label: %r" % (original_label,)

                # Label based on how samples are generated: counterfactual
                # samples carry the opposite label of their source row.
                if generation_type == "counterfactual":
                    label = "MCI" if original_label == "NC" else "NC"
                else:
                    label = original_label

                line_header = "Chinese" if generation_type == "cross-lingual" else "Text"

                # Explicit UTF-8: cross-lingual files contain Chinese text and the
                # platform default encoding is not UTF-8 everywhere.
                with open(txt_file_path, encoding="utf-8") as txt_file:
                    non_empty_lines = [line.strip() for line in txt_file if line.strip()]

                text = _extract_generated_text(non_empty_lines, line_header, txt_file_path)
                if text is None:
                    print("No data for %s" % txt_file_path)
                    continue

                records.append({
                    "index": file_idx,
                    "label": label,
                    # Controlled variables, copied from the source row
                    "age": original_row['age'],
                    "gender": original_row['gender'],
                    "race": original_row['race'],
                    "education": original_row['education'],
                    "text": text
                })

        new_df = pd.DataFrame(records, columns=record_columns)
        print("# of %s generation samples (%s): %d" % (generation_type, llm_model_name, len(new_df)))
        new_df.to_csv(os.path.join(output_dir, "%s_generation_%s.csv" % (generation_type, llm_model_name)), index=False)