In [None]:
# Root directory against which this notebook resolves secret.env and the data/ folder.
PROJECT_HOME = "."

# # For Colab: uncomment the lines below to point PROJECT_HOME at a Google Drive folder instead.

# PROJECT_HOME = "/content/drive/My Drive/Projects/LLM-MCI-detection"

# # Google Drive storage setup
# from google.colab import drive
# drive.mount('/content/drive', force_remount=True)

In [None]:
%%capture
%pip install python-dotenv openai

In [None]:
import os
import openai
import pandas as pd
from tqdm import tqdm

In [None]:
import dotenv
# Read the secret.env file under PROJECT_HOME so the API credentials are available to
# os.getenv below; assigning the return value keeps it out of the cell output.
_ = dotenv.load_dotenv(dotenv_path=os.path.join(PROJECT_HOME, './secret.env'))

In [None]:
# When True, the generation loops skip any row whose output .txt file already exists.
skip_if_exist = True

In [None]:
# True: build an openai.AzureOpenAI client and use the Azure model names;
# False: build a plain openai.OpenAI client.
use_azure = True

In [None]:
# True selects the GPT-4 model name; False selects the GPT-3.5 model name.
use_gpt4 = False

In [None]:
# Resolve the model identifier passed to the chat completion call. It depends on
# both the provider (Azure vs. OpenAI) and the requested model tier.
# NOTE(review): on Azure this presumably has to match the deployment name - confirm.
if use_azure:
    gpt_model_name = "gpt-4" if use_gpt4 else "gpt-35-turbo"
else:
    gpt_model_name = "gpt-4-turbo-preview" if use_gpt4 else "gpt-3.5-turbo"

In [None]:
# Build the API client once; credentials come from environment variables
# (populated from secret.env earlier in the notebook).
if not use_azure:
    client = openai.OpenAI(api_key=os.environ.get('OPENAI_API_KEY'))
else:
    client = openai.AzureOpenAI(
        azure_endpoint=os.environ.get("AZURE_OPENAI_ENDPOINT"),
        api_key=os.environ.get("AZURE_OPENAI_API_KEY"),
        api_version="2023-12-01-preview",
    )

In [None]:
def request_api(system_message, user_message):
    """Send one system/user message pair to the chat model and return the reply text.

    Uses the module-level ``client`` and ``gpt_model_name``.

    Args:
        system_message: Content of the "system" message.
        user_message: Content of the "user" message.

    Returns:
        The ``message.content`` of the first choice in the response.
    """
    conversation = [
        {"role": "system", "content": system_message},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        model=gpt_model_name,
        messages=conversation,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content

In [None]:
def _int_or_missing(value):
    """Return ``int(value)``, or the string "MISSING" when pandas treats the value as null."""
    return "MISSING" if pd.isnull(value) else int(value)


def extract_row_data(row):
    """Extract the prompt fields from one dataset row in human-readable form.

    Args:
        row: Mapping-like record (e.g. a pandas Series) providing the keys
            'label', 'age', 'gender', 'race', 'education', 'MMSE' and 'text'.

    Returns:
        Tuple ``(label, age, gender, race, education, MMSE, text)`` where

        * ``age``, ``education`` and ``MMSE`` are ints, or "MISSING" when null;
        * ``gender`` is "male" for code 1, "MISSING" when null, else "female";
        * ``race`` is "White" for code 1, "MISSING" when null, else "Non-White";
        * ``label`` and ``text`` are passed through unchanged.
    """
    label = row['label']
    text = row['text']

    age = _int_or_missing(row['age'])
    education = _int_or_missing(row['education'])
    MMSE = _int_or_missing(row['MMSE'])

    gender = row['gender']
    if pd.isnull(gender):
        # A missing value must not be silently reported as "female".
        gender = "MISSING"
    elif gender == 1:
        # Was `gender == "male"` (a no-op comparison), which left the raw code 1 in the prompt.
        gender = "male"
    else:
        gender = "female"

    race = row['race']
    if pd.isnull(race):
        # A missing value must not be silently reported as "Non-White".
        race = "MISSING"
    elif race == 1:
        race = "White"
    else:
        race = "Non-White"

    return label, age, gender, race, education, MMSE, text

In [None]:
def get_data_generation_prompt(row, generation_type):
    """Build the (system, user) prompt pair asking the model to synthesise a transcript.

    The row is first decoded with ``extract_row_data``. The prompt sentences are
    concatenated directly, with no separator inserted between them.

    Args:
        row: Dataset record accepted by ``extract_row_data``.
        generation_type: One of "observational", "cross-lingual" or "counterfactual".

    Returns:
        Tuple ``(system_message, user_message)`` of strings.

    Raises:
        ValueError: If ``generation_type`` is not one of the three known values.
        AssertionError: If the row label is not "MCI" for the observational and
            cross-lingual types, or not "NC"/"MCI" for the counterfactual type.
    """
    label, age, gender, race, education, MMSE, text = extract_row_data(row)

    # Check the label against the generation type; counterfactual flips NC <-> MCI.
    if generation_type in ("observational", "cross-lingual"):
        assert label == "MCI"
    elif generation_type == "counterfactual":
        assert label in ["NC", "MCI"]
        counterfactual_label = "MCI" if label == "NC" else "NC"
    else:
        raise ValueError("Unknown generation type: %s" % generation_type)

    # Shared preamble: task description, background and Step 1.
    system_parts = [
        "Use the following step-by-step instructions to respond to user inputs. The user inputs are related to the transcription of one test subject labelled %s " % label,
        "describing the Cookie Theft picture from the Boston Diagnostic Aphasia Exam.",
        "Other information of the test subject is provided, including, age, gender, race, education level (number of years), and Mini Mental State Examination (MMSE) score.",
        "Before the step-by-step instructions, some background information is listed as follows.",
        "This Cookie Theft picture description task is used to determine whether one is probable Alzheimer's disease (AD), mild cognitive impairment (MCI), or normal control (NC).",
        "The MMSE score measures one's cognitive function but needs adjustment for the education level.",
        "The step-by-step instructions are listed as follows.",
        "Step 1 - Explain the characteristics of this text and the reasons behind why this test subject is labelled %s." % label,
    ]

    # Type-specific steps, each ending with "... in two lines:".
    if generation_type == "counterfactual":
        system_parts.extend([
            "Step 2 - Given the explanations from Step 1, imagine what characteristics a subject labelled with %s would have, " % counterfactual_label,
            "while keeping the subject's age, gender, race, and education information unchanged.",
            "Step 3 - Given the reasons from Step 2, write a new counterfactual transcription labelled with %s in two lines:" % counterfactual_label,
        ])
    else:
        system_parts.append(
            "Step 2 - Given the explanations from Step 1, rephrase the original transcription to a similar but new transcription in two lines:"
        )

    # Output format shared by all generation types.
    system_parts.append("the first line only outputs the new transcription in no more than 150 words, with a prefix 'Text:'; ")
    system_parts.append("the second line outputs the explanations, with a prefix 'Explanations:'.")

    # Cross-lingual generation adds a translation step after the rephrasing.
    if generation_type == "cross-lingual":
        system_parts.append(
            "Step 3 - Given Step 2, only translate the text but not explanations into Chinese, with a prefix 'Chinese:'."
        )

    user_parts = (
        "The original transcription of the test subject is given as follows: %s." % text,
        "The label of this transcription is: %s." % label,
        "The test subject's age is %s, gender is %s, race is %s, education level (number of years) is %s, and MMSE score is %s." % (age, gender, race, education, MMSE),
    )

    return "".join(system_parts), "".join(user_parts)

# New MCI samples generated from existing MCI samples

In [None]:
# Generate N_fold_observational_generation rephrasings of every MCI transcript.
# Each response is saved to
# <PROJECT_HOME>/data/observational-generation/<run>/<model>/<row index>.txt
data = pd.read_csv(os.path.join(PROJECT_HOME, 'data', 'original.csv'))
observed_data = data[data['label']=='MCI']
N_fold_observational_generation = 5

for run_number in range(N_fold_observational_generation):
    output_dir_name = os.path.join(PROJECT_HOME, 'data', 'observational-generation', '%d' % run_number, gpt_model_name)
    os.makedirs(output_dir_name, exist_ok=True)

    for idx, original_row in tqdm(observed_data.iterrows(), total=len(observed_data), desc="Observational generation run %d" % run_number):
        output_path = os.path.join(output_dir_name, f'{idx}.txt')

        # Resume support: rows that already have an output file are not requested again.
        if skip_if_exist and os.path.exists(output_path):
            continue

        system_message, user_message = get_data_generation_prompt(original_row, "observational")

        try:
            text_content = request_api(system_message, user_message)
            # Validate before opening the file: opening with 'w' creates it, and an empty
            # file would make skip_if_exist treat this row as done on the next run.
            if not text_content:
                raise ValueError("model returned no text content")
            # Explicit UTF-8 so the output does not depend on the platform's locale encoding.
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(text_content)
        except Exception as exc:
            # Best-effort: report the failure and move on to the next row. Catching Exception
            # rather than using a bare except lets KeyboardInterrupt stop the run.
            tqdm.write(f"Observational generation run {run_number}: row {idx} failed: {exc!r}")

# New Chinese MCI samples generated from existing English MCI samples

In [None]:
# Generate N_fold_cross_lingual_generation rephrasings (with a Chinese translation) of
# every MCI transcript. Each response is saved to
# <PROJECT_HOME>/data/cross-lingual-generation/<run>/<model>/<row index>.txt
data = pd.read_csv(os.path.join(PROJECT_HOME, 'data', 'original.csv'))
observed_data = data[data['label']=='MCI']
N_fold_cross_lingual_generation = 5

for run_number in range(N_fold_cross_lingual_generation):
    output_dir_name = os.path.join(PROJECT_HOME, 'data', 'cross-lingual-generation', '%d' % run_number, gpt_model_name)
    os.makedirs(output_dir_name, exist_ok=True)

    for idx, original_row in tqdm(observed_data.iterrows(), total=len(observed_data), desc="Cross-lingual generation run %d" % run_number):
        output_path = os.path.join(output_dir_name, f'{idx}.txt')

        # Resume support: rows that already have an output file are not requested again.
        if skip_if_exist and os.path.exists(output_path):
            continue

        system_message, user_message = get_data_generation_prompt(original_row, "cross-lingual")

        try:
            text_content = request_api(system_message, user_message)
            # Validate before opening the file: opening with 'w' creates it, and an empty
            # file would make skip_if_exist treat this row as done on the next run.
            if not text_content:
                raise ValueError("model returned no text content")
            # Explicit UTF-8: the response contains Chinese text, which a non-UTF-8 locale
            # default encoding may be unable to encode.
            with open(output_path, 'w', encoding='utf-8') as f:
                f.write(text_content)
        except Exception as exc:
            # Best-effort: report the failure and move on to the next row. Catching Exception
            # rather than using a bare except lets KeyboardInterrupt stop the run.
            tqdm.write(f"Cross-lingual generation run {run_number}: row {idx} failed: {exc!r}")

# New counterfactual samples generated from existing observational samples

In [None]:
# Generate one counterfactual transcript for every NC transcript.
# Each response is saved to
# <PROJECT_HOME>/data/counterfactual-generation/<model>/<row index>.txt
data = pd.read_csv(os.path.join(PROJECT_HOME, 'data', 'original.csv'))
observed_data = data[data['label']=='NC']
output_dir_name = os.path.join(PROJECT_HOME, 'data', 'counterfactual-generation', gpt_model_name)
os.makedirs(output_dir_name, exist_ok=True)

for idx, original_row in tqdm(observed_data.iterrows(), total=len(observed_data), desc="Counterfactual generation"):
    output_path = os.path.join(output_dir_name, f'{idx}.txt')

    # Resume support: rows that already have an output file are not requested again.
    if skip_if_exist and os.path.exists(output_path):
        continue

    system_message, user_message = get_data_generation_prompt(original_row, "counterfactual")

    try:
        text_content = request_api(system_message, user_message)
        # Validate before opening the file: opening with 'w' creates it, and an empty
        # file would make skip_if_exist treat this row as done on the next run.
        if not text_content:
            raise ValueError("model returned no text content")
        # Explicit UTF-8 so the output does not depend on the platform's locale encoding.
        with open(output_path, 'w', encoding='utf-8') as f:
            f.write(text_content)
    except Exception as exc:
        # Best-effort: report the failure and move on to the next row. Catching Exception
        # rather than using a bare except lets KeyboardInterrupt stop the run.
        tqdm.write(f"Counterfactual generation: row {idx} failed: {exc!r}")