In [1]:
%load_ext dotenv
%dotenv

In [2]:
import openai
import os
import pandas as pd
import stanza
import time
from datasets import load_dataset
from tqdm import tqdm

In [3]:
# Only needs to be run once
# stanza.download('en')

In [31]:
# Dataset split to process; the same value is reused later when naming the output CSV.
split_name = 'train'
# Load the chosen split of the `commonsense_qa` dataset through the `datasets` library.
en_csqa = load_dataset('commonsense_qa', split=split_name)

Found cached dataset commonsense_qa (/mnt/nas2/kikiputri/cache/commonsense_qa/default/1.0.0/28d68f56649a7f0c23bc68eae850af914aa03f95f810011ae8cf58cc5ff5051b)


### Concept Relevancy Classifier

In [5]:
# English Stanza pipeline limited to the tokenizer and NER processors; the
# entity-extraction helpers below call it once per input string.
ner_pipeline = stanza.Pipeline('en', processors='tokenize,ner')

2023-07-03 12:20:28 INFO: Checking for updates to resources.json in case models have been updated.  Note: this behavior can be turned off with download_method=None or download_method=DownloadMethod.REUSE_RESOURCES


Downloading https://raw.githubusercontent.com/stanfordnlp/stanza-resources/main/resources_1.5.0.json:   0%|   …

2023-07-03 12:20:29 INFO: Loading these models for language: en (English):
| Processor | Package   |
-------------------------
| tokenize  | combined  |
| ner       | ontonotes |

2023-07-03 12:20:29 INFO: Using device: cuda
2023-07-03 12:20:29 INFO: Loading: tokenize
2023-07-03 12:20:33 INFO: Loading: ner
2023-07-03 12:20:34 INFO: Done loading processors!


In [6]:
def extract_names(sentence):
    """Return the text of every entity tagged ``PERSON`` in ``sentence``.

    The sentence is run through the module-level ``ner_pipeline``; entities are
    returned in the order the pipeline reports them. An empty list means no
    person entity was detected.
    """
    doc = ner_pipeline(sentence)
    person_mentions = []
    for entity in doc.ents:
        if entity.type == 'PERSON':
            person_mentions.append(entity.text)
    return person_mentions


def extract_locations(sentence):
    """Return the text of every entity tagged ``GPE`` or ``LOC`` in ``sentence``.

    The sentence is run through the module-level ``ner_pipeline``; entities are
    returned in the order the pipeline reports them. An empty list means no
    location entity was detected.
    """
    doc = ner_pipeline(sentence)
    location_mentions = []
    for entity in doc.ents:
        if entity.type == 'GPE' or entity.type == 'LOC':
            location_mentions.append(entity.text)
    return location_mentions

In [7]:
# Credentials come from environment variables (presumably populated by the
# `%dotenv` magic in the first cell); a missing variable raises KeyError here.
openai.api_key = os.environ['OPENAI_API_KEY']
openai.organization = os.environ['OPENAI_UILAB_KEY']
# In-memory cache: prompt text -> normalized model response. It lets
# get_openai_relevancy skip the API call for a prompt it has already answered.
response_history = {}

In [14]:
def get_input_prompt(concept, concept_type, location_name):
    """Build the yes/no question sent to the model for one concept.

    Args:
        concept: Text inserted into the question (a place, a name, or any
            other concept).
        concept_type: ``'location'`` or ``'name'`` select the dedicated
            wording; any other value falls back to the generic
            "commonly found" wording.
        location_name: Region the concept is checked against.

    Returns:
        The question followed by the instruction to answer only 'yes' or 'no'.
    """
    answer_instruction = "Answer with only 'yes' or 'no'."

    if concept_type == 'location':
        question = f"Is {concept} located in {location_name}?"
    elif concept_type == 'name':
        question = f"Is the name \"{concept}\" common in {location_name}?"
    else:
        # NOTE(review): "Does ... commonly found" is ungrammatical ("Is ...
        # commonly found" was likely intended). The wording is kept unchanged
        # because cached responses are keyed by the exact prompt text.
        question = f"Does {concept} commonly found in {location_name}?"

    return f"{question} {answer_instruction}"


def get_openai_chat_completion(input_prompt, model_name, temp=0.2):
    """Request a chat completion for a single user message.

    Args:
        input_prompt: Text sent as the content of the only (user) message.
        model_name: Model identifier passed to the API.
        temp: Sampling temperature passed to the API (default 0.2).

    Returns:
        Whatever ``openai.ChatCompletion.create`` returns; exceptions raised
        by that call propagate to the caller.
    """
    user_message = {'role': 'user', 'content': input_prompt}
    return openai.ChatCompletion.create(
        model=model_name,
        messages=[user_message],
        temperature=temp,
    )


def get_openai_relevancy(input_prompt, model_name, max_retries=5, wait_seconds=60):
    """Return the model's normalized answer to ``input_prompt``, using the cache.

    The answer is looked up in ``response_history`` first; only prompts that
    have not been seen are sent to the API. Rate-limit and service-unavailable
    errors are retried after a pause instead of aborting the whole run.

    Args:
        input_prompt: Full prompt text; also the cache key.
        model_name: Model identifier forwarded to get_openai_chat_completion.
        max_retries: How many times to retry after a transient error before
            giving up. Negative values are treated as 0.
        wait_seconds: Pause between attempts, in seconds.

    Returns:
        The response text, stripped and lower-cased. A response that is
        exactly "yes." or "no." has its trailing period removed; any other
        text is returned as-is.

    Raises:
        openai.error.RateLimitError, openai.error.ServiceUnavailableError:
            If the error persists after ``max_retries`` retries.
    """
    if input_prompt in response_history:
        return response_history[input_prompt]

    transient_errors = (
        openai.error.RateLimitError,
        openai.error.ServiceUnavailableError,
    )
    retries_allowed = max(max_retries, 0)

    for attempt in range(retries_allowed + 1):
        try:
            completion = get_openai_chat_completion(input_prompt, model_name)
            break
        except transient_errors:
            # Out of retries: surface the error rather than looping forever.
            if attempt == retries_allowed:
                raise
            time.sleep(wait_seconds)

    response = completion.choices[0].message.content.strip().lower()

    # Only the two exact forms are normalized, so unexpected free-text
    # answers stay visible in the cached history.
    if response in ("yes.", "no."):
        response = response[:-1]

    response_history[input_prompt] = response

    return response

In [32]:
# Show the first record to see which fields each example provides.
en_csqa[0]

{'id': '075e483d21c29a511267ef62bedc0461',
 'question': 'The sanctions against the school were a punishing blow, and they seemed to what the efforts the school had made to change?',
 'question_concept': 'punishing',
 'choices': {'label': ['A', 'B', 'C', 'D', 'E'],
  'text': ['ignore', 'enforce', 'authoritarian', 'yell at', 'avoid']},
 'answerKey': 'A'}

In [38]:
# Run configuration: the chat model to query and the region that relevancy is
# judged against.
model_name = "gpt-3.5-turbo"
target_location = "Indonesia"

option_idxs = ['option_a', 'option_b', 'option_c', 'option_d', 'option_e']
# One list per output column; each loop iteration appends exactly one value to
# every list, so all columns stay the same length.
relevancy_data = {
    column: []
    for column in ['q_id', 'q_concept', *option_idxs, 'names', 'answer']
}

for item in tqdm(en_csqa):
    relevancy_data['q_id'].append(item['id'])

    # Question concept: always asked with the generic ("other") wording.
    concept_prompt = get_input_prompt(item['question_concept'], "other", target_location)
    relevancy_data['q_concept'].append(get_openai_relevancy(concept_prompt, model_name))

    # Answer options: only options in which NER finds a location are sent to
    # the model; every other option is recorded as None.
    for option_idx, choice in zip(option_idxs, item['choices']['text']):
        option_rel = None
        if extract_locations(choice):
            option_prompt = get_input_prompt(choice, "location", target_location)
            option_rel = get_openai_relevancy(option_prompt, model_name)
        relevancy_data[option_idx].append(option_rel)

    # Person names found in the question text, stored as (name, response) pairs.
    relevancy_data['names'].append([
        (
            person,
            get_openai_relevancy(
                get_input_prompt(person, "name", target_location), model_name
            ),
        )
        for person in extract_names(item['question'])
    ])

    relevancy_data['answer'].append(item['answerKey'])

100%|██████████| 9741/9741 [09:07<00:00, 17.80it/s]  


In [39]:
# Turn the column-wise dict of lists into a table with one row per question.
relevancy_df = pd.DataFrame(relevancy_data)

In [40]:
# Display the relevancy table for a visual sanity check.
relevancy_df

Unnamed: 0,q_id,q_concept,option_a,option_b,option_c,option_d,option_e,names,answer
0,075e483d21c29a511267ef62bedc0461,yes,,,,,,[],A
1,61fe6e879ff18686d7552425a36344c8,yes,,,,,,"[(Sammy, no)]",B
2,4c1cb0e95b99f72d55c068ba0255c54d,yes,,,,,,[],A
3,02e821a3e53cb320790950aab4489e85,yes,,no,,,,[],D
4,23505889b94e880c3e89cff4ba119860,no,,,,,,[],C
...,...,...,...,...,...,...,...,...,...
9736,f1b2a30a1facff543e055231c5f90dd0,yes,,,,,,[],E
9737,a63b4d0c0b34d6e5f5ce7b2c2c08b825,yes,,,,,,[],D
9738,22d0eea15e10be56024fd00bb0e4f72f,yes,,,,,,[],A
9739,7c55160a4630de9690eb328b57a18dc2,yes,,,,,,"[(John, no)]",A


In [41]:
# Output directory for relevancy artifacts. The trailing slash is kept because
# another cell builds a file path from this value by plain string concatenation.
out_parent_dir = "/mnt/nas2/kikiputri/id-csqa/dataset/relevancy/"
# Write the step-1 relevancy table for the current split (no index column).
step1_csv_path = os.path.join(out_parent_dir, f"{split_name}_step1.csv")
relevancy_df.to_csv(step1_csv_path, index=False)

In [45]:
# Flatten the prompt -> response cache into a two-column table, one row per
# cached prompt, in insertion order.
resp_history_df = pd.DataFrame(
    list(response_history.items()),
    columns=['prompt', 'response'],
)

In [46]:
# Display the cached prompts and responses for a visual sanity check.
resp_history_df

Unnamed: 0,prompt,response
0,Does revolving door commonly found in Indonesi...,no
1,Does people commonly found in Indonesia? Answe...,yes
2,Does magazines commonly found in Indonesia? An...,yes
3,Does hamburger commonly found in Indonesia? An...,yes
4,Does farmland commonly found in Indonesia? Ans...,yes
...,...,...
2533,"Is the name ""Lud"" common in Indonesia? Answer ...",no
2534,Does dining room table commonly found in Indon...,yes
2535,Does restroom commonly found in Indonesia? Ans...,yes
2536,"Is the name ""Mike"" common in Indonesia? Answer...",no


In [47]:
# Persist the prompt/response cache next to the relevancy outputs (no index column).
history_csv_path = os.path.join(out_parent_dir, "gpt-3.5-history-230703.csv")
resp_history_df.to_csv(history_csv_path, index=False)