In [61]:
import numpy as np
import torch
# Macaw-large, PTLM 
# https://github.com/allenai/macaw
# This was used in the BeliefBank Paper
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Downloads a pretty large model
tokenizer = AutoTokenizer.from_pretrained("allenai/macaw-large")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/macaw-large")

Example output for a simple question

In [83]:
input_string = "$answer$ ; $mcoptions$ = (A) yes (B) no; $question$ = Is a robin a virus?"
input_ids = tokenizer.encode(input_string, return_tensors="pt")
output = model.generate(input_ids, max_length=200)

tokenizer.batch_decode(output, skip_special_tokens=True)


['$answer$ = no']

In [None]:
df = pd.read_csv("beliefbank_data/calibration_questions.csv", header=None)
df

Unnamed: 0,0
0,Is an albatross a bird?|Yes
1,Is an albatross a seabird?|Yes
2,Is an albatross an animal?|Yes
3,Is an albatross a eukaryotic_organism?|Yes
4,Is an albatross a pelagic_bird?|Yes
...,...
1067,Is a daffodil a palm tree?|No
1068,Is a daffodil a crustacean?|No
1069,Is a daffodil a jellyfish?|No
1070,Is a daffodil an invertebrate?|No


In [31]:
def load_file(file_name):
    with open(file_name, 'r') as file:
        return [line.strip().split(sep="|") for line in file]
        
print(load_file('beliefbank_data/calibration_questions.csv')[0:2])


[['Is an albatross a bird?', 'Yes'], ['Is an albatross a seabird?', 'Yes']]


In [98]:
def create_question_answer_list(file_name, n):
     # n = # of (q, a) pairs to use

    q_and_a = load_file(file_name)
    questions, answers = np.split(np.array(q_and_a), 2, axis=1)
    questions = ["$answer$ ; $mcoptions$ = (A) yes (B) no; $question$ = " + item + "?"\
         for sublist in questions for item in sublist]
    answers = [item for sublist in answers for item in sublist]

    question_list = list(questions)[:n]
    answer_list = list(answers)[:n]
    # print(question_list, answer_list)
    return question_list, answer_list


In [101]:
# run MACAW on file with questions and answers
def batch_eval(file_name, n):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    question_list, answer_list = create_question_answer_list(file_name, n)

    inputs_dict = tokenizer.batch_encode_plus(question_list, max_length = 200, padding=True, truncation=True, return_tensors="pt")
    input_ids = inputs_dict.input_ids.to(device)

    output = model.generate(input_ids, max_length=200)
    answers = tokenizer.batch_decode(output, skip_special_tokens=True)
    return answers


In [111]:
ans = batch_eval("beliefbank_data/calibration_questions.csv", 5)
print(ans)

['$answer$ = yes', '$answer$ = yes', '$answer$ = yes', '$answer$ = yes', '$answer$ = yes']



## Basic Evaluation

In [122]:
def macaw_evaluate(n):
    macaw_pred = batch_eval("beliefbank_data/calibration_questions.csv", n)
    macaw_pred = [item[len('$answer$ = '):] for item in macaw_pred] # remove '$answer$ = '
    _ , truth = create_question_answer_list("beliefbank_data/calibration_questions.csv", n)
    print(macaw_pred, truth)

    correct = 0
    for idx in range(n):
        if(macaw_pred[idx].lower() == truth[idx].lower()):
            correct += 1
    return correct / n # proportion of correct macaw preds
    

In [123]:
macaw_evaluate(100)