In [1]:
import os, sys
import random
import json
import nltk 
import csv
import torch
import numpy as np
import nltk  # $ pip install nltk
from nltk.stem import PorterStemmer
from nltk.corpus import cmudict  # >>> nltk.download('cmudict')
from nltk.tokenize import word_tokenize
from spellchecker import SpellChecker

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

from beliefbank_data.utils import generate_assertion, generate_question, find_constraints

# Load data

In [2]:
# All BeliefBank inputs live under a single data directory.
DATA_DIR = "beliefbank_data"
constraints_path = f"{DATA_DIR}/constraints_v2.json"
facts_path = f"{DATA_DIR}/silver_facts.json"

In [3]:
# Load the constraint graph and the silver facts.
# `with` closes each handle deterministically (a bare open() leaks it until garbage
# collection), and JSON is read as UTF-8 explicitly instead of the platform's locale default.
with open(constraints_path, encoding="utf-8") as constraints_file:
    constraints = json.load(constraints_file)
with open(facts_path, encoding="utf-8") as facts_file:
    facts = json.load(facts_file)

In [85]:
# Provenance: the dev/eval entity split was produced once by shuffling list(facts.keys())
# (unseeded random.shuffle), taking the first 65 entities as dev and the rest as eval,
# sorting each part, and writing them one-per-line to
# beliefbank_data/dev_entities.txt and beliefbank_data/eval_entities.txt.
# Re-running that recipe would overwrite the saved split, so we only reload it here.
with open("beliefbank_data/dev_entities.txt", "r") as split_file:
    dev_entities = [line.strip() for line in split_file]
print(dev_entities)

# The eval split can be reloaded the same way from beliefbank_data/eval_entities.txt.

# with open("beliefbank_data/eval_entities.txt", "r") as f:
#     eval_entities = [e.strip() for e in f.readlines()]
# print(eval_entities)

['american bison', 'baboon', 'birch', 'buck', 'bull', 'calf', 'camel', 'carnivore', 'carp', 'cheetah', 'chick', 'chimpanzee', 'cock', 'crocodile', 'dog', 'dolphin', 'domestic ass', 'duck', 'earthworm', 'elephant', 'european wolf spider', 'foxhound', 'frog', 'gazelle', 'gecko', 'german shepherd', 'giant panda', 'giraffe', 'gladiolus', 'hen', 'horse', 'hound', 'howler monkey', 'hummingbird', 'jaguar', 'lamb', 'leopard', 'lion', 'livestock', 'llama', 'magpie', 'midge', 'mink', 'mullet', 'myna', 'new world blackbird', 'orchid', 'owl', 'ox', 'penguin', 'peony', 'pigeon', 'poodle', 'puppy', 'rabbit', 'rat', 'reptile', 'robin', 'rose', 'salamander', 'starling', 'tiger', 'turkey', 'whale', 'zebra']


In [87]:
# Flatten the dev entities' facts into (entity, relation, is_true) triples, where
# is_true is True exactly when the stored label is 'yes'.
dev_entity_set = set(dev_entities)
statements = [
    (entity, relation, label == 'yes')
    for entity, relations in facts.items()
    if entity in dev_entity_set
    for relation, label in relations.items()
]
print(f"Number of facts: {len(statements)}")
statements[:5]

Number of facts: 9640


[('american bison', 'IsA,mammal', True),
 ('american bison', 'IsA,american bison', True),
 ('american bison', 'IsA,animal', True),
 ('american bison', 'IsA,vertebrate', True),
 ('american bison', 'IsA,warm blooded animal', True)]

# Load models

In [88]:
# Use the GPU when torch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cpu')

In [52]:
# Downloads a pretty large checkpoint (Macaw-large, a seq2seq QA model).
QA_MODEL_NAME = "allenai/macaw-large"
tokenizer = AutoTokenizer.from_pretrained(QA_MODEL_NAME)
# eval() switches off dropout so the yes/no scores are deterministic.
model = AutoModelForSeq2SeqLM.from_pretrained(QA_MODEL_NAME).to(device=device).eval()

In [47]:
# QA Model stuff
def format_question(question_list):
    """Wrap each raw question in Macaw's multiple-choice yes/no prompt.

    Args:
        question_list: iterable of question strings.

    Returns:
        A new list with the "$answer$ ; $mcoptions$ = (A) yes (B) no; $question$ = "
        prefix prepended to every question, in the same order.
    """
    prompt_prefix = "$answer$ ; $mcoptions$ = (A) yes (B) no; $question$ = "
    return [prompt_prefix + question for question in question_list]

def predict(question_list, yes_token_id=4273, no_token_id=150, answer_position=5):
    """Answer yes/no questions with the Macaw QA model by scoring the answer token.

    Every question is paired with the fixed target "$answer$ = yes". The decoder's
    distribution at the answer slot is then read out, and P(yes) is compared with P(no).

    Args:
        question_list: list of raw question strings (formatted via `format_question`).
        yes_token_id: vocabulary id of the "yes" token (4273 for the Macaw tokenizer).
        no_token_id: vocabulary id of the "no" token (150 for the Macaw tokenizer).
        answer_position: index of the yes/no token inside the tokenized
            "$answer$ = yes" target (5 for the Macaw tokenizer).

    Returns:
        (answers, confidences):
        - answers: bool array of shape (B,), True where P(yes) >= P(no).
        - confidences: float array of shape (B,), the probability of the chosen answer.

    Uses the module-level `tokenizer`, `model` and `device`.
    """
    B = len(question_list)
    if B == 0:
        # The tokenizer cannot batch an empty list; return empty results of the usual types.
        return np.zeros(0, dtype=bool), np.zeros(0, dtype=np.float32)

    question_list = format_question(question_list)
    answer_list_all_yes = ["$answer$ = yes"] * B  # identical target for every question

    inputs = tokenizer.batch_encode_plus(question_list, max_length=256, padding=True,
                                         truncation=True, return_tensors="pt")
    labels = tokenizer.batch_encode_plus(answer_list_all_yes, max_length=15, padding=True,
                                         truncation=True, return_tensors="pt")

    # Inference only: no_grad avoids building an autograd graph. All tensors must live on
    # the model's device; without .to(device) this crashes as soon as CUDA is used.
    with torch.no_grad():
        fwd = model(input_ids=inputs["input_ids"].to(device),
                    attention_mask=inputs["attention_mask"].to(device),
                    labels=labels["input_ids"].to(device))

    # fwd.logits is (B, target_len, vocab). Index it directly: the old reshape to
    # (B, 7, -1) was a no-op for a 7-token target and silently corrupted the vocab axis
    # for any other target length.
    # NOTE(review): assumes that, because `labels` teacher-forces the decoder, logits at
    # `answer_position` score the yes/no slot of the target — confirm if the prompt changes.
    logits = fwd.logits[:, answer_position, :]
    probs = torch.softmax(logits, dim=-1).cpu().numpy()
    confidence_yes = probs[..., yes_token_id]
    confidence_no = probs[..., no_token_id]

    answers = confidence_yes >= confidence_no
    confidences = np.where(answers, confidence_yes, confidence_no)
    return answers, confidences


In [53]:
# RoBERTa-large NLI classifier (fine-tuned on SNLI, MNLI, FEVER and ANLI rounds 1-3),
# used to score contradictions between the model's own assertions.
NLI_MODEL_NAME = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL_NAME)
nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL_NAME).to(device=device).eval()

Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [14]:
def contradiction_matrix(sents, nli_tokenizer, nli_model, device=None, contradiction_label=2):
    """Pairwise NLI contradiction probabilities within each row of sentences.

    Args:
        sents: array-like of strings with shape (N, B), or (B,) for a single row.
        nli_tokenizer: tokenizer callable as `nli_tokenizer(premises, hypotheses, ...)`.
        nli_model: sequence-classification model whose `.logits` are (pairs, num_labels).
        device: device for the input tensors; defaults to the device of `nli_model`'s
            parameters.
        contradiction_label: logit index of the "contradiction" class (2 for the
            ynie/roberta-large-*-nli checkpoint, as used originally).

    Returns:
        float numpy array of shape (N, B, B). Entry [i, j, k] is
        P(contradiction | premise=sents[i][j], hypothesis=sents[i][k]).
    """
    sents = np.asarray(sents)
    if sents.ndim == 1:
        sents = sents.reshape(1, -1)

    N, B = sents.shape
    if N == 0 or B == 0:
        # Nothing to compare; avoid calling the tokenizer on an empty batch.
        return np.zeros((N, B, B), dtype=np.float32)
    if device is None:
        device = next(nli_model.parameters()).device

    # All ordered (j, k) pairs per row, flattened in (i, j, k) order so that the final
    # reshape to (N, B, B) lines up with the indices.
    prem = [str(sents[i][j]) for i in range(N) for j in range(B) for k in range(B)]
    hypo = [str(sents[i][k]) for i in range(N) for j in range(B) for k in range(B)]

    tokenized = nli_tokenizer(prem, hypo,
                              max_length=256,
                              return_token_type_ids=True,
                              truncation=True,
                              padding=True,
                              return_tensors="pt")  # int64 tensors directly, no float round-trip

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        nli_outputs = nli_model(tokenized["input_ids"].to(device),
                                attention_mask=tokenized["attention_mask"].to(device),
                                token_type_ids=tokenized["token_type_ids"].to(device),
                                labels=None)
    predicted_probability = torch.softmax(nli_outputs.logits, dim=-1)
    contra = predicted_probability[:, contradiction_label]
    return contra.reshape(N, B, B).cpu().numpy()

In [15]:
# Correction methods
def do_nothing(predictions, confidences, contra_matrix):
    """Baseline correction: ignore confidences and contradictions.

    Returns:
        The very same `predictions` object that was passed in.
    """
    unchanged = predictions
    return unchanged

def correction_1(predictions, confidences, contra_matrix):
    """Work-in-progress correction method; currently returns `predictions` unchanged.

    Args:
        predictions: (N, B) array of QA answers.
        confidences: (N, B) array of QA confidences (not used yet).
        contra_matrix: (N, B, B) contradiction probabilities from `contradiction_matrix`,
            or a single (B, B) matrix.

    Returns:
        `predictions` (no correction is applied yet).
    """
    # Symmetrize over the two sentence axes only. `contra_matrix.T` reverses ALL axes of
    # an (N, B, B) array into (B, B, N), which does not broadcast against the original.
    contra_matrix_sym = (contra_matrix + np.swapaxes(contra_matrix, -1, -2)) / 2
    # TODO: use contra_matrix_sym together with confidences to flip low-confidence answers
    # that strongly contradict others.
    return predictions

In [68]:
def evaluate(predictions, answers):
    if predictions.ndim == 1:
        predictions = predictions.reshape(1, -1)
    answers = answers.reshape(predictions.shape)
    
    actual_answers = answers.copy()
    actual_answers[:, 1:] = np.logical_not(actual_answers[:, 1:])
    acc = np.sum(predictions == actual_answers)
    
    yes_no = (answers[:, 0] == answers[:, 1])
    pred_same = (predictions[:, 0] == predictions[:, 1])
    pred_diff = np.logical_not(pred_same)
    con = np.where(yes_no, pred_same, pred_diff)
    con = np.count_nonzero(con)
    
    total = predictions.size
    bsize = predictions.shape[0]
    return acc, con, total, bsize

In [71]:
# Evaluate QA accuracy and pairwise consistency on (fact, constrained fact) pairs.
# For every dev statement whose relation has a forward constraint, sample one constraint,
# build the implied "contra" statement, and score both through QA -> NLI -> correction.
SEED = 0                     # fixed so the sampled constraints are reproducible across runs
batch_size = 10              # (base, contra) pairs per QA/NLI forward pass
log_every = 2                # print running totals every `log_every` batches
correction_fn = do_nothing   # correction method applied to the QA predictions

random.seed(SEED)


def _sample_pairs(statements, constraints):
    """Pair each statement with one randomly chosen forward constraint.

    Returns a list of (idx, base, contra), where idx is the statement's index in
    `statements`. Statements whose relation has no forward constraint are skipped.
    """
    pairs = []
    for idx, base in enumerate(statements):
        entity, relation, _ = base
        filter_dict = {
            'source': relation,
            'direction': 'forward',
        }
        selected_constraints = find_constraints(constraints, filter_dict=filter_dict)
        if len(selected_constraints) == 0:
            continue
        c = random.choice(selected_constraints)
        # The contra statement carries the negated constraint label; `evaluate` flips it back.
        contra = (entity, c['target'], not (c['weight'] == 'yes_yes'))
        pairs.append((idx, base, contra))
    return pairs


def _score_batch(batch, correct):
    """Run QA, assertion generation, NLI, correction and evaluation on a flat batch.

    `batch` is [base_0, contra_0, base_1, contra_1, ...]. Returns `evaluate`'s tuple.
    """
    n_pairs = len(batch) // 2

    questions, answers = zip(*[generate_question(*tup) for tup in batch])
    answer_list = np.array([ans == "Yes" for ans in answers])

    predictions, confidences = predict(list(questions))
    predictions = predictions.flatten()
    confidences = confidences.flatten()

    # Restate each predicted answer as an assertion so the NLI model can compare them.
    pred_batch = [(ent, rel, predictions[i]) for i, (ent, rel, _) in enumerate(batch)]
    assertions = np.array([generate_assertion(*tup) for tup in pred_batch]).reshape(n_pairs, -1)
    contra_matrix = contradiction_matrix(assertions, nli_tokenizer, nli_model)

    corrected = correct(predictions.reshape(n_pairs, -1),
                        confidences.reshape(n_pairs, -1),
                        contra_matrix)
    return evaluate(corrected, answer_list)


acc_count = 0
con_count = 0
total_count = 0
num_pairs = 0
num_batches = 0
batch = []

pairs = _sample_pairs(statements, constraints)

# Chunk into batches. The last chunk may hold fewer than batch_size pairs and is still
# scored, so trailing statements are not silently dropped.
for start in range(0, len(pairs), batch_size):
    chunk = pairs[start:start + batch_size]
    batch = [stmt for _, base, contra in chunk for stmt in (base, contra)]

    acc, con, total, bsize = _score_batch(batch, correction_fn)
    acc_count += acc
    con_count += con
    total_count += total
    num_pairs += bsize
    num_batches += 1

    if num_batches % log_every == 0:
        last_idx = chunk[-1][0]  # index of the last statement included so far
        print(f"Iter {last_idx}: {num_batches} batches, {num_pairs} pairs")
        print(f"\tAccurate {acc_count} / {total_count} = {acc_count / total_count}")
        print(f"\tContradictions {con_count} / {num_pairs} = {con_count / num_pairs}")

if num_pairs:
    print(f"Accurate {acc_count} / {total_count} = {acc_count / total_count}")
    print(f"Contradictions {con_count} / {num_pairs} = {con_count / num_pairs}")
else:
    print("No statement had a matching forward constraint; nothing was evaluated.")

Iter 36: 2 batches, 20 pairs
	Accurate 33 / 40 = 0.825
	Contradictions 7 / 20 = 0.35
Iter 61: 4 batches, 40 pairs
	Accurate 53 / 80 = 0.6625
	Contradictions 15 / 40 = 0.375
Iter 81: 6 batches, 60 pairs
	Accurate 82 / 120 = 0.6833333333333333
	Contradictions 22 / 60 = 0.36666666666666664
Iter 101: 8 batches, 80 pairs
	Accurate 110 / 160 = 0.6875
	Contradictions 28 / 80 = 0.35
Iter 122: 10 batches, 100 pairs
	Accurate 142 / 200 = 0.71
	Contradictions 32 / 100 = 0.32
Iter 149: 12 batches, 120 pairs
	Accurate 171 / 240 = 0.7125
	Contradictions 41 / 120 = 0.3416666666666667
Iter 183: 14 batches, 140 pairs
	Accurate 200 / 280 = 0.7142857142857143
	Contradictions 50 / 140 = 0.35714285714285715
Iter 204: 16 batches, 160 pairs
	Accurate 229 / 320 = 0.715625
	Contradictions 61 / 160 = 0.38125
Iter 224: 18 batches, 180 pairs
	Accurate 260 / 360 = 0.7222222222222222
	Contradictions 70 / 180 = 0.3888888888888889
Iter 244: 20 batches, 200 pairs
	Accurate 290 / 400 = 0.725
	Contradictions 76 / 200 = 

KeyboardInterrupt: 