In [1]:
%load_ext dotenv
%dotenv

In [2]:
import os
import pygsheets
import pandas as pd
import openai
from pathlib import Path
from tqdm import tqdm

In [3]:
# Google Sheets client; the cells below use it to open each worker's spreadsheet by key.
client = pygsheets.authorize()

In [4]:
# Worker identifiers; each one indexes a spreadsheet in the two mappings below.
workers_ids = ['W1', 'W2', 'W3', 'W4', 'W5', 'W6']
# Task-1 spreadsheet keys, indexed as [language][worker id] and opened with
# client.open_by_key. The 'Data' worksheet of each spreadsheet is read for the
# question, option and gold-mark columns.
# NOTE(review): 'ind' / 'sun' look like the ISO 639-3 codes for Indonesian and
# Sundanese — confirm.
id_to_key_task1 = {
    'ind': {
        'W1': '14dAiIsBLxBUHzXgFDznMjSXCJmFY6qnIWOADqITWx9E',
        'W2': '1jHWRjwkahxpA5T_BWBRaMQ4q9kyg1l0gGz42pC6-dxI',
        'W3': '1C-Yyc17WFWScC5eu5WttQc8P8vCFuLvh8MqlBKiwF8c',
        'W4': '1J0xqeC05H1RVPSERi8fCXWUTGJ4iH3SI0w59BoFh_KU',
        'W5': '1T9jbfP1iapNLS94QZMiYw5K6LrZoP01Qf1jvTWpvdkU',
        'W6': '1G0chmXUmWC-lrZLg9ToQTGCdQmnbP0DA204bIKJn9yc',
    },
    'sun': {
        'W1': '1R-bT3RMu41cx-fnac3mNdb3OT_kb6vof1eMdCbU8_jM',
        'W2': '1ugtKZhO4jLtxW_yEwrrPXZtry2pLaP-6QgiIHKldl5E',
        'W3': '1OzP1XwzU3c-rNXyqxU3JKrrXHeXzSu3D6FIcfym4fHE',
        'W4': '1dNwW9dL4YPBEypUHQ6_-Dse8XL851W0WUiCBb28hRBQ',   
        'W5': '16wXGP08nESdLg4am4wfngeMNbK81IJAo3p0udg_iVlU',
        'W6': '1ZQEBOPGAlc6f2IUwavEKpPHY7dQxV4lN5iSQm1WHnX4',
    }
}
# Task-2 spreadsheet keys, indexed the same way. Each of these spreadsheets is
# read through worksheets titled with another worker's id.
id_to_key_task2 = {
    'ind': {
        'W1': '1Wa33qUjeB0pq1QI87jUqHTXlQ88ADCxAFp77vFAhSVU',
        'W2': '1FbYeTu3ZK4vBLoPVpJqCFB8Rj6yzazYRpxRtmPXXf-k',
        'W3': '1UysOeI1QnU8sNXqwQ7FGnULvKCjt8Tm6eoFEdYeQ0xY',
        'W4': '1lKIAWvs9D8JoyNrZJzB6pU4KVtfWstH3tOEfM9dZMDc',
        'W5': '1nQsQPAoxeRr9QzybjrkpTn5-IuBBCByMNN8txkUVV4w',
        'W6': '1SCV9OBxvHwxQ31t6GFirO6et-raMbMpJdfLjn1_HsgU',
    },
    'sun': {
        'W1': '1-FLTsgge53Wgb3HmIlM-oCOWow4_kLEJ0bey8MaMxVI',
        'W2': '1-WLiRHFXlD5BawHdBkI2kLSlLttBXU4ufnTwLB8pufM',
        'W3': '1EVv6ktg-6ZC5e9UBFvPIVrZI0T2WAbx3OHau_T-OgSE',
        'W4': '1XyXhn_R3VuNcsmHCmCEwojFM1hUPGhQe6FDrfo3ALds',
        'W5': '1YsreB2g0AeDbiFOIu2JOUmbAkwugBRaaf8SMUA2mrg0',
        'W6': '1oU4K52UaKJT3EEvDdOs8o4tD97Wxl4SKSFQK193w2bQ',
    }
}

In [5]:
# OpenAI credentials are read from environment variables rather than hard-coded.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.organization = os.environ['OPENAI_UILAB_KEY']
model_name = "gpt-4"

# On-disk cache mapping each full input prompt to the cleaned answer already
# obtained for it, so re-running the notebook does not repeat API calls.
resp_history_filename = f"{model_name}_history_231106.csv"
resp_history_path = Path(resp_history_filename)
if resp_history_path.is_file():
    print("Response history found!")
    # Read every cell verbatim as text. With pandas' default NA handling an
    # empty cached answer (or a string such as "NA") comes back as float NaN,
    # which would then be returned in place of the answer that was stored.
    resp_history_df = pd.read_csv(resp_history_path, dtype=str, keep_default_na=False)
    response_history = dict(zip(resp_history_df['prompt'], resp_history_df['response']))
else:
    print("Response history not found. Initializing new one...")
    response_history = {}

Response history not found. Initializing new one...


In [6]:
def get_openai_chat_completion(input_prompt, model_name, resp_num=1):
    """Send `input_prompt` as a single user message and return the raw API response.

    Args:
        input_prompt: Prompt text, sent as the only message (role 'user').
        model_name: Forwarded as the `model` argument.
        resp_num: Forwarded as `n`.

    Returns:
        Whatever `openai.ChatCompletion.create` returns.
    """
    user_message = {'role': 'user', 'content': input_prompt}
    request_kwargs = {
        'model': model_name,
        'messages': [user_message],
        'temperature': 0.2,
        'n': resp_num,
    }
    return openai.ChatCompletion.create(**request_kwargs)

In [7]:
def get_input_prompt(question_prompt):
    """Append the fixed answer-format instruction to a question prompt.

    The returned text is what gets sent to the model and is also the key of the
    response cache, so the instruction wording must stay exactly as it is.

    Args:
        question_prompt: Question text followed by its answer options.

    Returns:
        `question_prompt`, a blank line, then the instruction.
    """
    answer_format_instruction = (
        "Give only one answer that most likely to be the correct answer "
        "with a prefix that says \"Answer:\" follows by the option letter. "
        "For example:\nAnswer: Z"
    )
    return "{}\n\n{}".format(question_prompt, answer_format_instruction)

In [8]:
def _clean_model_answer(raw_reply):
    """Reduce a raw model reply to the option letter it names.

    Handles 'Answer: B', 'Answer:B', 'Answer: B.' and 'Answer: B. <option text>'.
    Anything that cannot be reduced is returned upper-cased and stripped, so
    the caller can detect it by its length.
    """
    cleaned = raw_reply.strip().upper().replace('ANSWER:', '').strip()
    # Keep only what precedes the first '. ' (drops option text echoed after the letter).
    cleaned = cleaned.split('. ')[0]
    # A bare trailing period ('B.') has no following space, so remove it separately.
    return cleaned.rstrip('.').strip()


def get_openai_answer(input_prompt, model_name, resp_num=1):
    """Return the model's answer letter for `input_prompt`, using the on-disk cache.

    A prompt already present in `response_history` is answered from the cache
    without calling the API. Otherwise the model is queried; if the cleaned
    reply is longer than one character it is queried once more and the second
    reply is kept as-is. The result is stored in `response_history` and the
    whole history is rewritten to `resp_history_filename` after every new answer.

    Args:
        input_prompt: Full prompt text; also the cache key.
        model_name: Model identifier forwarded to the API call.
        resp_num: Number of completions requested; only the first one is read.

    Returns:
        The cleaned answer string (a single upper-case letter when the reply
        was well-formed).
    """
    if input_prompt in response_history:
        return response_history[input_prompt]

    resp = get_openai_chat_completion(input_prompt, model_name, resp_num)
    answer_cleaned = _clean_model_answer(resp.choices[0].message.content)

    if len(answer_cleaned) > 1:
        # Malformed reply: ask once more and keep whatever comes back.
        print('Answer len > 1, retry...')
        print('Answer before:', answer_cleaned)
        resp = get_openai_chat_completion(input_prompt, model_name, resp_num)
        answer_cleaned = _clean_model_answer(resp.choices[0].message.content)
        print('Answer after:', answer_cleaned)

    response_history[input_prompt] = answer_cleaned
    # Persist after every new answer so an interrupted run loses nothing.
    resp_history_df = pd.DataFrame({
        'prompt': list(response_history.keys()),
        'response': list(response_history.values()),
    })
    resp_history_df.to_csv(resp_history_filename, index=False)

    return answer_cleaned

#### Processing question prompts

In [9]:
answers_letters = ['A', 'B', 'C', 'D', 'E']
# Every question occupies this many consecutive sheet rows, one per answer option.
options_per_question = len(answers_letters)


def _chunk_rows(values, size):
    """Split `values` into consecutive lists of `size` items, dropping an incomplete tail."""
    return [values[start:start + size] for start in range(0, len(values) - size + 1, size)]


# Build, per language and per worker, a list of
# {'question_prompt': question + options, 'answer': gold letter} records.
prompts_by_lang = {}
for lang in ['ind', 'sun']:
    # The 'sun' sheets hold question / options / gold one column further right
    # (5, 6, 7) than the 'ind' sheets (4, 5, 6).
    col_offset = 0 if lang == 'ind' else 1
    prompts_by_id = {}
    for ref_id in tqdm(workers_ids, desc='Processing question prompt'):
        sh_task1 = client.open_by_key(id_to_key_task1[lang][ref_id])
        wks_task1 = sh_task1.worksheet('title', 'Data')

        # [1:] skips the header row. Only non-empty question cells are kept,
        # while options and gold marks are grouped purely by row position.
        # NOTE(review): this pairs the i-th non-empty question with the i-th
        # group of rows, which assumes no question cell is left blank — confirm.
        questions = [q for q in wks_task1.get_col(4 + col_offset)[1:] if q != '']
        options_group = _chunk_rows(wks_task1.get_col(5 + col_offset)[1:], options_per_question)
        gold_group = _chunk_rows(wks_task1.get_col(6 + col_offset)[1:], options_per_question)

        # zip() below stops at the shortest input, so fail loudly instead of
        # silently dropping questions that have no option/gold rows.
        if min(len(options_group), len(gold_group)) < len(questions):
            raise ValueError(
                f"{lang}/{ref_id}: {len(questions)} questions but only "
                f"{len(options_group)} option groups and {len(gold_group)} gold groups"
            )

        prompts = []
        for q_idx, (q, o, g) in enumerate(zip(questions, options_group, gold_group)):
            if 'TRUE' not in g:
                raise ValueError(f"{lang}/{ref_id}: question {q_idx + 1} has no option marked TRUE")
            prompts.append({
                'question_prompt': q + "".join(f"\n{o_item}" for o_item in o),
                # The position of the first TRUE mark selects the gold letter.
                'answer': answers_letters[g.index('TRUE')],
            })

        prompts_by_id[ref_id] = prompts

    prompts_by_lang[lang] = prompts_by_id

Processing question prompt: 100%|██████████| 6/6 [00:15<00:00,  2.66s/it]
Processing question prompt: 100%|██████████| 6/6 [00:14<00:00,  2.48s/it]


In [13]:
# Ask the model every question, per language and per worker sheet; the
# predictions are stored in the same order as the prompts.
answers_by_lang = {}
for lang in ['ind', 'sun']:
    lang_prompts = prompts_by_lang[lang]
    print('Language:', lang)
    lang_answers = {}
    for ref_id in workers_ids:
        progress = tqdm(lang_prompts[ref_id], desc=f"Processing {ref_id}")
        lang_answers[ref_id] = [
            get_openai_answer(
                get_input_prompt(prompt_item['question_prompt']),
                model_name,
                resp_num=1,
            )
            for prompt_item in progress
        ]
    answers_by_lang[lang] = lang_answers

Language: ind


Processing W1: 100%|██████████| 250/250 [00:00<00:00, 293472.15it/s]
Processing W2: 100%|██████████| 250/250 [00:00<00:00, 296626.87it/s]
Processing W3: 100%|██████████| 250/250 [00:00<00:00, 316217.13it/s]
Processing W4: 100%|██████████| 250/250 [00:00<00:00, 244594.36it/s]
Processing W5: 100%|██████████| 250/250 [00:00<00:00, 255438.73it/s]
Processing W6: 100%|██████████| 250/250 [00:00<00:00, 230811.36it/s]


Language: sun


Processing W1: 100%|██████████| 250/250 [00:00<00:00, 230709.79it/s]
Processing W2: 100%|██████████| 250/250 [00:00<00:00, 231575.97it/s]
Processing W3: 100%|██████████| 250/250 [00:00<00:00, 235264.98it/s]
Processing W4: 100%|██████████| 250/250 [00:00<00:00, 226425.39it/s]
Processing W5: 100%|██████████| 250/250 [00:00<00:00, 227852.24it/s]
Processing W6: 100%|██████████| 250/250 [01:23<00:00,  2.99it/s] 


In [14]:
# For each language and worker sheet, count the questions where the model's
# prediction differs from the gold answer letter.
conflict_num_by_lang = {
    lang: {
        ref_id: sum(
            pred != prompt_item['answer']
            for pred, prompt_item in zip(answers_by_lang[lang][ref_id], prompts_by_lang[lang][ref_id])
        )
        for ref_id in workers_ids
    }
    for lang in ['ind', 'sun']
}

In [15]:
# Per language and worker sheet: number of questions where the model's answer differs from the gold answer.
conflict_num_by_lang

{'ind': {'W1': 35, 'W2': 28, 'W3': 22, 'W4': 23, 'W5': 13, 'W6': 13},
 'sun': {'W1': 68, 'W2': 57, 'W3': 60, 'W4': 61, 'W5': 76, 'W6': 71}}

#### Analyze annotation conflict counts

In [None]:
for lang in ['ind', 'sun']:
    # The marks are read from the same column number in the task-1 and task-2 sheets.
    mark_col_num = 6 if lang == 'ind' else 7

    # Gold marks per reference worker, taken from that worker's task-1 'Data' worksheet.
    gold_ans = {}
    for ref_id in tqdm(workers_ids):
        wks_task1 = client.open_by_key(id_to_key_task1[lang][ref_id]).worksheet('title', 'Data')
        gold_ans[ref_id] = wks_task1.get_col(mark_col_num)[1:]

    # conflict_data[pred_id] lists, in workers_ids order, how many groups of
    # pred_id's marks disagree with each reference worker's gold marks.
    conflict_data = {}
    for pred_id in tqdm(workers_ids):
        sh_pred = client.open_by_key(id_to_key_task2[lang][pred_id])
        conflicts_per_ref = []
        for ref_id in workers_ids:
            if pred_id == ref_id:
                # The diagonal is fixed at 0 without reading any worksheet.
                conflicts_per_ref.append(0)
                continue

            pred_ans = sh_pred.worksheet('title', ref_id).get_col(mark_col_num)[1:]

            # Compare row by row, then collapse every 5 consecutive rows into one
            # status: a group is a CONFLICT as soon as any of its rows disagrees.
            # Rows left over after the last complete group are ignored.
            row_matches = [pred == gold for pred, gold in zip(pred_ans, gold_ans[ref_id])]
            complete_rows = len(row_matches) - len(row_matches) % 5
            stat = [
                'OK' if all(row_matches[start:start + 5]) else 'CONFLICT'
                for start in range(0, complete_rows, 5)
            ]

            ok_count = stat.count('OK')
            conflict_count = stat.count('CONFLICT')
            total_groups = ok_count + conflict_count
            # Show the offending pair before the assertion stops the cell.
            if total_groups != 250:
                print(pred_id, ref_id, ok_count, conflict_count)
            assert total_groups == 250
            conflicts_per_ref.append(conflict_count)
        conflict_data[pred_id] = conflicts_per_ref

    # Rows are the answering worker (pred_id), columns the reference worker (ref_id).
    conflict_df = pd.DataFrame.from_dict(conflict_data, orient='index', columns=workers_ids)
    print('Conflict data for', lang)
    print(conflict_df)