In [1]:
%load_ext dotenv
%dotenv

In [2]:
import json
import os
import pygsheets
import pandas as pd
import openai
from collections import Counter
from pathlib import Path
from tqdm import tqdm

### Initialize sheet credentials

In [3]:
# Authorize the pygsheets client; the cells below use it (client.open_by_key) to open
# the annotation spreadsheets. Called with no arguments, so the credential source is
# whatever pygsheets uses by default — not visible in this notebook.
client = pygsheets.authorize()

In [4]:
# Identifiers of the six annotators; used as keys in every per-worker structure below.
workers_ids = ['W1', 'W2', 'W3', 'W4', 'W5', 'W6']
# Spreadsheet keys of the task-1 sheets, indexed as [language][worker id].
# 'ind' and 'sun' are the two language codes used throughout the notebook.
# The processing cells read each sheet's 'Data' worksheet: category, concept,
# question, options and gold-answer columns.
id_to_key_task1 = {
    'ind': {
        'W1': '14dAiIsBLxBUHzXgFDznMjSXCJmFY6qnIWOADqITWx9E',
        'W2': '1jHWRjwkahxpA5T_BWBRaMQ4q9kyg1l0gGz42pC6-dxI',
        'W3': '1C-Yyc17WFWScC5eu5WttQc8P8vCFuLvh8MqlBKiwF8c',
        'W4': '1J0xqeC05H1RVPSERi8fCXWUTGJ4iH3SI0w59BoFh_KU',
        'W5': '1T9jbfP1iapNLS94QZMiYw5K6LrZoP01Qf1jvTWpvdkU',
        'W6': '1G0chmXUmWC-lrZLg9ToQTGCdQmnbP0DA204bIKJn9yc',
    },
    'sun': {
        'W1': '1R-bT3RMu41cx-fnac3mNdb3OT_kb6vof1eMdCbU8_jM',
        'W2': '1ugtKZhO4jLtxW_yEwrrPXZtry2pLaP-6QgiIHKldl5E',
        'W3': '1OzP1XwzU3c-rNXyqxU3JKrrXHeXzSu3D6FIcfym4fHE',
        'W4': '1dNwW9dL4YPBEypUHQ6_-Dse8XL851W0WUiCBb28hRBQ',   
        'W5': '16wXGP08nESdLg4am4wfngeMNbK81IJAo3p0udg_iVlU',
        'W6': '1ZQEBOPGAlc6f2IUwavEKpPHY7dQxV4lN5iSQm1WHnX4',
    }
}
# Spreadsheet keys of the task-2 sheets, indexed as [language][worker id].
# Each spreadsheet is opened once per worker and then read through worksheets that
# are titled with the *other* workers' ids.
id_to_key_task2 = {
    'ind': {
        'W1': '1Wa33qUjeB0pq1QI87jUqHTXlQ88ADCxAFp77vFAhSVU',
        'W2': '1FbYeTu3ZK4vBLoPVpJqCFB8Rj6yzazYRpxRtmPXXf-k',
        'W3': '1UysOeI1QnU8sNXqwQ7FGnULvKCjt8Tm6eoFEdYeQ0xY',
        'W4': '1lKIAWvs9D8JoyNrZJzB6pU4KVtfWstH3tOEfM9dZMDc',
        'W5': '1nQsQPAoxeRr9QzybjrkpTn5-IuBBCByMNN8txkUVV4w',
        'W6': '1SCV9OBxvHwxQ31t6GFirO6et-raMbMpJdfLjn1_HsgU',
    },
    'sun': {
        'W1': '1-FLTsgge53Wgb3HmIlM-oCOWow4_kLEJ0bey8MaMxVI',
        'W2': '1-WLiRHFXlD5BawHdBkI2kLSlLttBXU4ufnTwLB8pufM',
        'W3': '1EVv6ktg-6ZC5e9UBFvPIVrZI0T2WAbx3OHau_T-OgSE',
        'W4': '1XyXhn_R3VuNcsmHCmCEwojFM1hUPGhQe6FDrfo3ALds',
        'W5': '1YsreB2g0AeDbiFOIu2JOUmbAkwugBRaaf8SMUA2mrg0',
        'W6': '1oU4K52UaKJT3EEvDdOs8o4tD97Wxl4SKSFQK193w2bQ',
    }
}

#### Processing data (in CommonsenseQA format)

In [5]:
answers_letters = ['A', 'B', 'C', 'D', 'E']
# Maps the category labels found in the sheets to the English labels stored in the dataset.
cat_id_to_en = {
    'Kuliner': 'culinary',
    'Tempat': 'place',
    'Budaya': 'culture',
    'Sejarah': 'history',
    'Aktivitas': 'activity',
}


def _non_empty(cells):
    """Return the cells that are not the empty string, in their original order."""
    return [cell for cell in cells if cell != '']


def _groups_of_five(cells):
    """Split cells into consecutive lists of five; a trailing incomplete group is dropped."""
    usable = len(cells) - len(cells) % 5
    return [cells[start:start + 5] for start in range(0, usable, 5)]


# Build data_by_lang[lang][worker id] -> list of question items read from each
# worker's task-1 'Data' worksheet.
data_by_lang = {}
for lang in ['ind', 'sun']:
    # The 'sun' sheets carry one extra column (4), so question/options/gold sit one
    # column further right than in the 'ind' sheets.
    question_col_num, options_col_num, gold_col_num = (4, 5, 6) if lang == 'ind' else (5, 6, 7)

    data_by_id = {}
    for ref_id in tqdm(workers_ids, desc=f'Processing {lang}'):
        wks_task1 = client.open_by_key(id_to_key_task1[lang][ref_id]).worksheet('title', 'Data')

        # [1:] skips the first cell of every column (presumably the header row).
        categories = _non_empty(wks_task1.get_col(2)[1:])
        question_concepts = _non_empty(wks_task1.get_col(3)[1:])
        # Column 4 is only read for 'sun'; its values replace the column-3 concepts below.
        question_concepts_trans = _non_empty(wks_task1.get_col(4)[1:]) if lang == "sun" else []
        questions = _non_empty(wks_task1.get_col(question_col_num)[1:])

        # Options and gold flags occupy five consecutive rows per question.
        options_group = _groups_of_five(wks_task1.get_col(options_col_num)[1:])
        gold_group = _groups_of_five(wks_task1.get_col(gold_col_num)[1:])

        concepts = question_concepts if lang == 'ind' else question_concepts_trans

        data = []
        for i, question in enumerate(questions):
            # Drop the first three characters of each option (presumably an "A. "-style prefix).
            answer_texts = [option[3:] for option in options_group[i]]
            # The creator's answer is the option whose gold cell reads 'TRUE'.
            answer = answers_letters[gold_group[i].index('TRUE')]
            data.append({
                'category': cat_id_to_en[categories[i]],
                'question_concepts': concepts[i],
                'question': question,
                'choices': {
                    'label': answers_letters,
                    'text': answer_texts
                },
                'answer_creator': answer,
                # The fields below are filled in by later cells.
                'answer_majority': '',
                'answers': {},
                'answers_uncertainty': {},
                'question_ambiguity': {},
                'option_ambiguity': {},
                'reason': {}
            })

        data_by_id[ref_id] = data

    data_by_lang[lang] = data_by_id

Processing ind: 100%|██████████| 6/6 [00:18<00:00,  3.07s/it]
Processing sun: 100%|██████████| 6/6 [02:08<00:00, 21.38s/it]


In [6]:
# Merge the task-2 annotations into data_by_lang. Worker `ref_id`'s task-2 spreadsheet
# has one worksheet per other worker (`pred_id`); what is read there is attached to
# `pred_id`'s items under the key `ref_id`.
for lang in ['ind', 'sun']:
    print(f'Processing {lang}')
    # Every column of interest sits one position further right in the 'sun' sheets.
    col_shift = 0 if lang == 'ind' else 1
    for ref_id in workers_ids:
        sh_task2 = client.open_by_key(id_to_key_task2[lang][ref_id])
        for pred_id in tqdm(workers_ids, desc=f'Processing {ref_id}'):
            if pred_id == ref_id:
                continue
            wks_task2 = sh_task2.worksheet('title', pred_id)

            # One non-empty cell per question; [1:] skips the first cell of the column.
            q_ambiguity = [cell for cell in wks_task2.get_col(8 + col_shift)[1:] if cell != '']
            uncertainty = [cell for cell in wks_task2.get_col(10 + col_shift)[1:] if cell != '']

            # Reasons: keep the first row of every complete block of five rows.
            reasons = wks_task2.get_col(11 + col_shift)[1:]
            reason_single = [
                reasons[start]
                for start in range(0, len(reasons) - len(reasons) % 5, 5)
            ]

            # Answers: within each block of five rows, the row holding 'TRUE' gives the letter.
            answers = wks_task2.get_col(6 + col_shift)[1:]
            answers_single = [
                answers_letters[answers[start:start + 5].index('TRUE')]
                for start in range(0, len(answers) - len(answers) % 5, 5)
            ]

            # Option ambiguity: keep all five rows of each block.
            options = wks_task2.get_col(9 + col_shift)[1:]
            options_group = [
                options[start:start + 5]
                for start in range(0, len(options) - len(options) % 5, 5)
            ]

            for i in range(len(q_ambiguity)):
                target_item = data_by_lang[lang][pred_id][i]
                target_item['answers'][ref_id] = answers_single[i]
                target_item['answers_uncertainty'][ref_id] = uncertainty[i]
                target_item['question_ambiguity'][ref_id] = q_ambiguity[i]
                target_item['option_ambiguity'][ref_id] = options_group[i]
                target_item['reason'][ref_id] = reason_single[i]

Processing ind


Processing W1: 100%|██████████| 6/6 [00:09<00:00,  1.62s/it]
Processing W2: 100%|██████████| 6/6 [00:11<00:00,  1.84s/it]
Processing W3: 100%|██████████| 6/6 [01:52<00:00, 18.77s/it]
Processing W4: 100%|██████████| 6/6 [00:09<00:00,  1.61s/it]
Processing W5: 100%|██████████| 6/6 [00:12<00:00,  2.08s/it]
Processing W6: 100%|██████████| 6/6 [00:09<00:00,  1.66s/it]


Processing sun


Processing W1: 100%|██████████| 6/6 [00:09<00:00,  1.55s/it]
Processing W2: 100%|██████████| 6/6 [02:14<00:00, 22.39s/it]
Processing W3: 100%|██████████| 6/6 [00:09<00:00,  1.54s/it]
Processing W4: 100%|██████████| 6/6 [00:10<00:00,  1.68s/it]
Processing W5: 100%|██████████| 6/6 [00:09<00:00,  1.57s/it]
Processing W6: 100%|██████████| 6/6 [00:12<00:00,  2.04s/it]


In [7]:
def get_majority_element(lst):
    """Return the most frequent element of ``lst``.

    When several elements share the highest count, the one that appears first
    in ``lst`` is returned. An empty ``lst`` raises ``ValueError`` (from ``max``).
    """
    tally = Counter(lst)
    winner, _ = max(tally.items(), key=lambda pair: pair[1])
    return winner

In [10]:
# Fill in 'answer_majority' for every item from the answers collected in the task-2 merge.
# The items are updated in place, so data_by_lang sees the new field as well;
# data_by_lang_majority holds fresh lists that reference those same item dicts.
data_by_lang_majority = {'ind': {}, 'sun': {}}
for lang in ['ind', 'sun']:
    for ref_id in tqdm(workers_ids, desc='Processing majority question'):
        worker_items = data_by_lang[lang][ref_id]
        for item in worker_items:
            votes = list(item['answers'].values())
            item['answer_majority'] = get_majority_element(votes)
        data_by_lang_majority[lang][ref_id] = list(worker_items)

Processing majority question: 100%|██████████| 6/6 [00:00<00:00, 498.01it/s]
Processing majority question: 100%|██████████| 6/6 [00:00<00:00, 705.56it/s]


In [11]:
# Spot-check one processed item: language 'ind', worker key 'W1', index 2.
data_by_lang['ind']['W1'][2]

{'category': 'culinary',
 'question_concepts': 'aduk',
 'question': 'Apa yang biasanya digunakan untuk mengaduk kopi?',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['Sedotan', 'Bungkus kopi', 'Garpu', 'Sumpit', 'Sendok']},
 'answer_creator': 'C',
 'answer_majority': 'E',
 'answers': {'W2': 'E', 'W3': 'E', 'W4': 'E', 'W5': 'A', 'W6': 'E'},
 'answers_uncertainty': {'W2': 'certain',
  'W3': 'certain',
  'W4': 'certain',
  'W5': 'certain',
  'W6': 'certain'},
 'question_ambiguity': {'W2': 'clear',
  'W3': 'clear',
  'W4': 'clear',
  'W5': 'clear',
  'W6': 'clear'},
 'option_ambiguity': {'W2': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W3': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W4': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W5': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W6': ['clear', 'clear', 'clear', 'clear', 'clear']},
 'reason': {'W2': '', 'W3': '', 'W4': '', 'W5': '', 'W6': ''}}

In [14]:
# Spot-check one processed item: language 'sun', worker key 'W1', index 0.
data_by_lang['sun']['W1'][0]

{'category': 'culinary',
 'question_concepts': 'adab dahar',
 'question': 'Dihandap ieu anu termasuk kana adab dahar anu umum di daerah Sunda nyaeta?',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['Nganggo panangan kenca',
   'Nyandak kaemaman anu tebih',
   'Dahar bari nyarios',
   'Nganggo panangan katuhu',
   'Dahar bari diuk']},
 'answer_creator': 'D',
 'answer_majority': 'D',
 'answers': {'W2': 'D', 'W3': 'D', 'W4': 'D', 'W5': 'D', 'W6': 'D'},
 'answers_uncertainty': {'W2': 'certain',
  'W3': 'certain',
  'W4': 'certain',
  'W5': 'certain',
  'W6': 'certain'},
 'question_ambiguity': {'W2': 'clear',
  'W3': 'clear',
  'W4': 'clear',
  'W5': 'clear',
  'W6': 'clear'},
 'option_ambiguity': {'W2': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W3': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W4': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W5': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W6': ['clear', 'clear', 'clear', 'clear', 'clear']},
 'reason'

In [16]:
# Write the unfiltered per-worker data to one JSON file per language.
for lang in ('ind', 'sun'):
    raw_path = f'../../dataset/v2_human/raw_{lang}.json'
    with open(raw_path, mode='w', encoding='utf-8') as raw_file:
        json.dump(data_by_lang[lang], raw_file)

#### Processing filtered data

In [17]:
import re

In [18]:
def get_certainty_based_on_reason(reason, certainty):
    """Re-derive an annotator's certainty label from their free-text reason.

    The reason is lower-cased. If it contains one or more ``[uncertain...]`` /
    ``[uncertainty...]`` tagged segments, only those segments (each running up to
    the next ``[`` or the end of the text) are examined; otherwise the whole
    reason is.

    Args:
        reason: Free-text reason string.
        certainty: Label to fall back to when the examined text does not
            contain ``'googl'``.

    Returns:
        ``'uncertain'`` if the examined text contains ``'googl'`` together with
        any of ``'tidak ada'``, ``'susah'``, ``'sulit'`` or ``'belum menemukan'``;
        ``'certain'`` if it contains ``'googl'`` without any of them;
        ``certainty`` unchanged otherwise.
    """
    text = reason.lower()
    tagged_segments = re.findall(r"\[uncertai(?:n|nty)[^\]]*\].*?(?=\[|$)", text)
    if tagged_segments:
        text = ''.join(tagged_segments).strip()

    # No search-engine mention: keep the label the annotator gave.
    if 'googl' not in text:
        return certainty

    # Indonesian phrases describing a search that was hard or turned up nothing.
    for marker in ('tidak ada', 'susah', 'sulit', 'belum menemukan'):
        if marker in text:
            return 'uncertain'
    return 'certain'

In [19]:
# Reload the raw dump and drop items that enough annotators flagged as problematic.
# workers_ids is (re)defined here, so this cell does not depend on its earlier value.
data_by_lang = {}
filtered_data_by_lang = {}
num_q_ambiguous_by_lang, num_op_ambiguous_by_lang, num_uncertain_by_lang = {}, {}, {}
workers_ids = ['W1', 'W2', 'W3', 'W4', 'W5', 'W6']

for lang in ['ind', 'sun']:
    with open(f'../../dataset/v2_human/raw_{lang}.json', encoding='utf-8') as fp:
        data_by_lang[lang] = json.load(fp)

    filtered_data = []
    num_q_ambiguous_by_wid, num_op_ambiguous_by_wid, num_uncertain_by_wid = {}, {}, {}
    for ref_id in tqdm(workers_ids, desc='Filtering question'):
        q_ambiguous_num = op_ambiguous_num = uncertain_num = 0
        for item in data_by_lang[lang][ref_id]:
            # Question-level: at least 3 annotators marked the question 'ambiguous'.
            q_votes = sum(1 for flag in item['question_ambiguity'].values() if flag == 'ambiguous')
            is_q_ambiguous = q_votes >= 3

            # Option-level: an option counts as ambiguous when at least 3 annotators say so;
            # the item is flagged when at least 3 of its options are ambiguous.
            per_annotator_flags = list(item['option_ambiguity'].values())
            ambiguous_options = [
                sum(1 for flag in option_flags if flag == 'ambiguous') >= 3
                for option_flags in zip(*per_annotator_flags)
            ]
            is_op_ambiguous = sum(ambiguous_options) >= 3

            # Uncertainty: labels other than 'certain' are re-derived from the reason text;
            # the item is flagged when at least 4 annotators end up 'uncertain'.
            # The two dicts are paired by position, i.e. this relies on equal key order.
            effective_labels = [
                'certain' if label == 'certain' else get_certainty_based_on_reason(reason, label)
                for label, reason in zip(item['answers_uncertainty'].values(), item['reason'].values())
            ]
            is_uncertain = sum(1 for label in effective_labels if label == 'uncertain') >= 4

            # A dropped item is counted once, under the first criterion that applies.
            if is_q_ambiguous:
                q_ambiguous_num += 1
                print('Ambiguous question:')
                print(item)
            elif is_op_ambiguous:
                op_ambiguous_num += 1
            elif is_uncertain:
                uncertain_num += 1
                print('Uncertain question:')
                print(item)
            else:
                filtered_data.append(item)

        num_q_ambiguous_by_wid[ref_id] = q_ambiguous_num
        num_op_ambiguous_by_wid[ref_id] = op_ambiguous_num
        num_uncertain_by_wid[ref_id] = uncertain_num

    filtered_data_by_lang[lang] = filtered_data
    num_q_ambiguous_by_lang[lang] = num_q_ambiguous_by_wid
    num_op_ambiguous_by_lang[lang] = num_op_ambiguous_by_wid
    num_uncertain_by_lang[lang] = num_uncertain_by_wid

Filtering question: 100%|██████████| 6/6 [00:00<00:00, 275.72it/s]


Ambiguous question:
{'category': 'place', 'question_concepts': 'pulau', 'question': 'Pulau yang terkenal akan kotanya yang besar adalah?', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['Pulau Bali', 'Pulau Jawa', 'Pulau Lombok', 'Pulau Sumatera', 'Nusa Tenggara Barat']}, 'answer_creator': 'B', 'answer_majority': 'B', 'answers': {'W1': 'B', 'W2': 'B', 'W3': 'B', 'W4': 'B', 'W5': 'B'}, 'answers_uncertainty': {'W1': 'certain', 'W2': 'certain', 'W3': 'uncertain', 'W4': 'certain', 'W5': 'uncertain'}, 'question_ambiguity': {'W1': 'ambiguous', 'W2': 'clear', 'W3': 'ambiguous', 'W4': 'clear', 'W5': 'ambiguous'}, 'option_ambiguity': {'W1': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W2': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W3': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W4': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W5': ['clear', 'ambiguous', 'clear', 'ambiguous', 'clear']}, 'reason': {'W1': '[ambiguous] besar secara ukuran lokasi antarkota di pulau-pulau 

Filtering question: 100%|██████████| 6/6 [00:00<00:00, 299.50it/s]

Uncertain question:
{'category': 'culture', 'question_concepts': 'Dayak', 'question': 'Naon bagian ti kabudayaan Dayak nu loba dijelaskeun di buku sakola?', 'choices': {'label': ['A', 'B', 'C', 'D', 'E'], 'text': ['Imah tradisionalna', 'Cara hirupna', 'Pakean tradisionalna', 'Cara urang Dayak ngaranan budak', 'Cara nyawahna']}, 'answer_creator': 'C', 'answer_majority': 'C', 'answers': {'W1': 'C', 'W3': 'A', 'W4': 'C', 'W5': 'B', 'W6': 'C'}, 'answers_uncertainty': {'W1': 'uncertain', 'W3': 'uncertain', 'W4': 'certain', 'W5': 'uncertain', 'W6': 'uncertain'}, 'question_ambiguity': {'W1': 'clear', 'W3': 'clear', 'W4': 'clear', 'W5': 'clear', 'W6': 'clear'}, 'option_ambiguity': {'W1': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W3': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W4': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W5': ['clear', 'clear', 'clear', 'clear', 'clear'], 'W6': ['clear', 'clear', 'clear', 'clear', 'clear']}, 'reason': {'W1': '[Uncertain] Tidak yakin saat meng




In [20]:
# Number of items that survived filtering, as (ind, sun).
len(filtered_data_by_lang['ind']), len(filtered_data_by_lang['sun'])

(1498, 1499)

In [21]:
# Per language and worker key: items dropped because the question was flagged ambiguous.
num_q_ambiguous_by_lang

{'ind': {'W1': 0, 'W2': 0, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 1},
 'sun': {'W1': 0, 'W2': 0, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 0}}

In [22]:
# Per language and worker key: items dropped because their options were flagged ambiguous
# (counted only when the question itself was not flagged).
num_op_ambiguous_by_lang

{'ind': {'W1': 0, 'W2': 0, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 0},
 'sun': {'W1': 0, 'W2': 0, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 0}}

In [23]:
# Per language and worker key: items dropped for annotator uncertainty
# (counted only when neither ambiguity criterion applied).
num_uncertain_by_lang

{'ind': {'W1': 0, 'W2': 0, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 1},
 'sun': {'W1': 0, 'W2': 1, 'W3': 0, 'W4': 0, 'W5': 0, 'W6': 0}}

In [24]:
# Spot-check one filtered item; the filtered data is a flat list per language.
filtered_data_by_lang['ind'][2]

{'category': 'culinary',
 'question_concepts': 'aduk',
 'question': 'Apa yang biasanya digunakan untuk mengaduk kopi?',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['Sedotan', 'Bungkus kopi', 'Garpu', 'Sumpit', 'Sendok']},
 'answer_creator': 'C',
 'answer_majority': 'E',
 'answers': {'W2': 'E', 'W3': 'E', 'W4': 'E', 'W5': 'A', 'W6': 'E'},
 'answers_uncertainty': {'W2': 'certain',
  'W3': 'certain',
  'W4': 'certain',
  'W5': 'certain',
  'W6': 'certain'},
 'question_ambiguity': {'W2': 'clear',
  'W3': 'clear',
  'W4': 'clear',
  'W5': 'clear',
  'W6': 'clear'},
 'option_ambiguity': {'W2': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W3': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W4': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W5': ['clear', 'clear', 'clear', 'clear', 'clear'],
  'W6': ['clear', 'clear', 'clear', 'clear', 'clear']},
 'reason': {'W2': '', 'W3': '', 'W4': '', 'W5': '', 'W6': ''}}

In [25]:
# Write the filtered item lists to one JSON file per language.
for lang in ('ind', 'sun'):
    filtered_path = f'../../dataset/v2_human/filtered_{lang}.json'
    with open(filtered_path, mode='w', encoding='utf-8') as filtered_file:
        json.dump(filtered_data_by_lang[lang], filtered_file)

#### Processing question prompt

In [94]:
answers_letters = ['A', 'B', 'C', 'D', 'E']
prompts_by_lang, data_by_lang = {}, {}
mode = 'v2_filtered'

# Per-mode settings: (input file template, item field used as the reference answer).
#   'v1'          - flat item list, reference answer in 'answerKey'
#   'v2_raw'      - items grouped by worker id, reference answer in 'answer_creator'
#   'v2_filtered' - flat item list, reference answer in 'answer_majority'
mode_settings = {
    'v1': ('../../dataset/v1_synthetic/filtered_both_test_{lang}.json', 'answerKey'),
    'v2_raw': ('../../dataset/v2_human/raw_{lang}.json', 'answer_creator'),
    'v2_filtered': ('../../dataset/v2_human/filtered_{lang}.json', 'answer_majority'),
}
# Fail early on an unsupported mode. Without this check the cell would either stop
# with a NameError on the input file name or silently reuse one from an earlier run.
if mode not in mode_settings:
    raise ValueError(f"Unknown mode {mode!r}; expected one of {sorted(mode_settings)}")
input_template, answer_field = mode_settings[mode]


def _build_prompts(items, answer_field):
    """Turn question items into prompt records.

    Each record holds 'question_prompt' (the question followed by one
    "<label>. <text>" line per choice) and 'answer' (the item's ``answer_field`` value).
    """
    prompts = []
    for item in items:
        prompt_lines = [item['question']]
        prompt_lines.extend(
            f"{label}. {text}"
            for label, text in zip(item['choices']['label'], item['choices']['text'])
        )
        prompts.append({
            'question_prompt': "\n".join(prompt_lines),
            'answer': item[answer_field]
        })
    return prompts


# Flat datasets are stored under the single pseudo-worker key 'W1' so the cells
# below can iterate over workers_ids in every mode.
# NOTE: this rebinds the notebook-wide workers_ids.
if mode in ['v1', 'v2_filtered']:
    workers_ids = ['W1']
else:
    workers_ids = ['W1', 'W2', 'W3', 'W4', 'W5', 'W6']

for lang in ['ind', 'sun']:
    input_filename = input_template.format(lang=lang)
    with open(input_filename, encoding='utf-8') as fp:
        data_by_lang[lang] = json.load(fp)

    if mode == 'v2_raw':
        prompts_by_id = {}
        for ref_id in tqdm(workers_ids, desc='Processing question prompt'):
            prompts_by_id[ref_id] = _build_prompts(data_by_lang[lang][ref_id], answer_field)
        prompts_by_lang[lang] = prompts_by_id
    else:
        prompts_by_lang[lang] = {'W1': _build_prompts(data_by_lang[lang], answer_field)}

In [95]:
# Spot-check one generated prompt record ('question_prompt' plus reference 'answer').
prompts_by_lang['sun']['W1'][2]

{'question_prompt': 'Naon nami kaemaman anu didamel tina tipung ketan sareng gula kawung anu diaduk dugika kentel?\nA. Dodol\nB. Awug\nC. Comro\nD. Misro\nE. Gegetuk',
 'answer': 'A'}

#### Get ChatGPT Answers

In [114]:
# Credentials come from environment variables; the organization value is read from
# OPENAI_UILAB_KEY.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.organization = os.environ['OPENAI_UILAB_KEY']
model_name = "gpt-4-1106-preview"

# Prompt -> cleaned-response cache, persisted as a CSV named after the model so that
# re-running the notebook does not query the API again for prompts already answered.
resp_history_filename = f"{model_name}_history_231121.csv"
resp_history_path = Path(resp_history_filename)
if not resp_history_path.is_file():
    print("Response history not found. Initializing new one...")
    response_history = {}
else:
    print("Response history found!")
    resp_history_df = pd.read_csv(resp_history_filename)
    response_history = dict(zip(resp_history_df['prompt'], resp_history_df['response']))

Response history found!


In [115]:
def get_openai_chat_completion(input_prompt, model_name, resp_num=1, temp=0.1):
    """Send ``input_prompt`` as a single user message and return the raw response.

    Args:
        input_prompt: Text used as the content of the user message.
        model_name: Model identifier passed as ``model``.
        resp_num: Number of completions requested (passed as ``n``).
        temp: Sampling temperature.

    Returns:
        Whatever ``openai.ChatCompletion.create`` returns.

    NOTE(review): ``openai.ChatCompletion`` looks like the pre-1.0 SDK interface —
    confirm the pinned ``openai`` version before upgrading.
    """
    user_message = {'role': 'user', 'content': input_prompt}
    request_kwargs = {
        'model': model_name,
        'messages': [user_message],
        'temperature': temp,
        'n': resp_num,
    }
    return openai.ChatCompletion.create(**request_kwargs)

In [116]:
def get_input_prompt(question_prompt):
    """Append the fixed answer-format instruction to ``question_prompt``.

    The result is the question prompt, a blank line, then the instruction.
    The instruction text must stay byte-identical: the full prompt is the key
    of the response-history cache.
    """
    instruction = (
        'Give only one answer that most likely to be the correct answer '
        'with a prefix that says "Answer:" follows by the option letter. '
        'For example:\nAnswer: Z'
    )
    return "{}\n\n{}".format(question_prompt, instruction)

In [117]:
def get_openai_answer(input_prompt, model_name, resp_num=1):
    """Return the model's answer for ``input_prompt``, using the response cache.

    A prompt already present in ``response_history`` is answered from the cache
    without an API call. Otherwise the model is queried and the first choice is
    stripped, upper-cased, has ``'ANSWER: '`` removed and is cut at the first
    ``'. '``. If the result is longer than one character, the query is repeated
    once; the second result is kept whatever its length. The final answer is
    stored in ``response_history`` and the whole history is rewritten to
    ``resp_history_filename``.

    Args:
        input_prompt: Full prompt text; also the cache key.
        model_name: Model identifier forwarded to the API helper.
        resp_num: Number of completions requested; only the first is used.

    Returns:
        The cleaned answer string.
    """
    if input_prompt in response_history:
        return response_history[input_prompt]

    def _query_and_clean():
        # One API round-trip followed by the answer clean-up.
        completion = get_openai_chat_completion(input_prompt, model_name, resp_num)
        raw_text = completion.choices[0].message.content.strip().upper()
        # split() returns the whole string at index 0 when '. ' is absent.
        return raw_text.replace('ANSWER: ', '').split('. ')[0]

    answer_cleaned = _query_and_clean()
    if len(answer_cleaned) > 1:
        print('Answer len > 1, retry...')
        print('Answer before:', answer_cleaned)
        answer_cleaned = _query_and_clean()
        print('Answer after:', answer_cleaned)

    # Persist after every new answer so an interrupted run loses nothing.
    response_history[input_prompt] = answer_cleaned
    history_table = pd.DataFrame({
        'prompt': list(response_history.keys()),
        'response': list(response_history.values()),
    })
    history_table.to_csv(resp_history_filename, index=False)

    return answer_cleaned

In [118]:
# Collect the model's answer for every prompt:
# answers_by_lang[lang][worker key] -> list of answers, in prompt order.
answers_by_lang = {}
for lang in ['ind', 'sun']:
    prompts_by_id = prompts_by_lang[lang]
    print('Language:', lang)
    answers_by_id = {}
    for ref_id in workers_ids:
        answers_by_id[ref_id] = [
            get_openai_answer(
                get_input_prompt(prompt_item['question_prompt']),
                model_name,
                resp_num=1,
            )
            for prompt_item in tqdm(prompts_by_id[ref_id], desc=f"Processing {ref_id}")
        ]
    answers_by_lang[lang] = answers_by_id

Language: ind


Processing W1: 100%|██████████| 1498/1498 [00:00<00:00, 661863.20it/s]


Language: sun


Processing W1: 100%|██████████| 1499/1499 [00:00<00:00, 404533.63it/s]


In [119]:
# Compare the model's answers with the reference answers.
# conflict_num_by_lang: number of mismatches; acc_by_lang: accuracy in percent.
conflict_num_by_lang, acc_by_lang = {}, {}
for lang in ['ind', 'sun']:
    conflict_num_by_id, acc_by_id = {}, {}
    for ref_id in workers_ids:
        predictions = answers_by_lang[lang][ref_id]
        references = prompts_by_lang[lang][ref_id]
        conflict_num = sum(
            1 for pred, gold in zip(predictions, references) if pred != gold['answer']
        )
        total = len(predictions)
        conflict_num_by_id[ref_id] = conflict_num
        acc_by_id[ref_id] = (total - conflict_num) / total * 100
    conflict_num_by_lang[lang] = conflict_num_by_id
    acc_by_lang[lang] = acc_by_id

In [120]:
# Show which model the figures below were computed for.
model_name

'gpt-4-1106-preview'

In [121]:
# Number of prompts where the model's answer differs from the reference answer.
conflict_num_by_lang

{'ind': {'W1': 104}, 'sun': {'W1': 233}}

In [122]:
# Model accuracy in percent, per language and worker key.
acc_by_lang

{'ind': {'W1': 93.05740987983978}, 'sun': {'W1': 84.45630420280187}}

#### Analyze annotator conflict counts

In [None]:
# Pairwise annotator/creator conflict matrix, one per language.
#
# This cell keeps its own annotator list instead of reading the global `workers_ids`:
# the prompt-building cell rebinds `workers_ids` to ['W1'] in 'v1'/'v2_filtered' mode,
# which would silently shrink this analysis to a single 1x1 zero matrix.
annotator_ids = ['W1', 'W2', 'W3', 'W4', 'W5', 'W6']
options_per_question = 5      # every question occupies five consecutive sheet rows
expected_question_num = 250   # complete question blocks expected in every worksheet

for lang in ['ind', 'sun']:
    # The TRUE/FALSE answer column has the same position in task-1 and task-2 sheets;
    # it is one column further right in the 'sun' sheets.
    answer_col_num = 6 if lang == 'ind' else 7

    # Gold answer flags of every question creator ([1:] skips the first cell of the column).
    gold_ans = {}
    for ref_id in tqdm(annotator_ids):
        sh_task1 = client.open_by_key(id_to_key_task1[lang][ref_id])
        wks_task1 = sh_task1.worksheet('title', 'Data')
        gold_ans[ref_id] = wks_task1.get_col(answer_col_num)[1:]

    conflict_data = {}
    for pred_id in tqdm(annotator_ids):
        all_conflict_counts = []
        sh_pred = client.open_by_key(id_to_key_task2[lang][pred_id])
        for ref_id in annotator_ids:
            if pred_id == ref_id:
                # Nobody annotates their own questions.
                all_conflict_counts.append(0)
                continue

            # pred_id's answers to ref_id's questions.
            pred_ans = sh_pred.worksheet('title', ref_id).get_col(answer_col_num)[1:]

            # A question is a conflict when any of its five rows differs from the gold row.
            row_matches = [pred == gold for pred, gold in zip(pred_ans, gold_ans[ref_id])]
            complete_rows = len(row_matches) - len(row_matches) % options_per_question
            question_ok = [
                all(row_matches[start:start + options_per_question])
                for start in range(0, complete_rows, options_per_question)
            ]

            ok_count = sum(question_ok)
            conflict_count = len(question_ok) - ok_count
            assert ok_count + conflict_count == expected_question_num, (
                f'{pred_id} vs {ref_id} ({lang}): expected {expected_question_num} questions, '
                f'got {ok_count} OK + {conflict_count} CONFLICT'
            )
            all_conflict_counts.append(conflict_count)
        conflict_data[pred_id] = all_conflict_counts

    # Rows: annotator (pred_id); columns: question creator (ref_id).
    conflict_df = pd.DataFrame.from_dict(conflict_data, orient='index', columns=annotator_ids)
    print('Conflict data for', lang)
    print(conflict_df)