In [89]:
import os, sys
import random
import json
import nltk 
import csv
import torch
import numpy as np
import nltk  # $ pip install nltk
from nltk.stem import PorterStemmer
from nltk.corpus import cmudict  # >>> nltk.download('cmudict')
from nltk.tokenize import word_tokenize
from spellchecker import SpellChecker

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

from beliefbank_data.utils import generate_assertion, generate_question, find_constraints

In [2]:
constraints_path = "beliefbank_data/constraints_v2.json"
facts_path = "beliefbank_data/silver_facts.json"

In [3]:
constraints = json.load(open(constraints_path))
facts = json.load(open(facts_path))

In [4]:
statements = [(entity, relation, label == 'yes')
              for entity, relations in facts.items() 
              for relation, label in relations.items()]
statements[:5]

[('american bison', 'IsA,mammal', True),
 ('american bison', 'IsA,american bison', True),
 ('american bison', 'IsA,animal', True),
 ('american bison', 'IsA,vertebrate', True),
 ('american bison', 'IsA,warm blooded animal', True)]

In [5]:
# Downloads a pretty large model
tokenizer = AutoTokenizer.from_pretrained("allenai/macaw-large")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/macaw-large")

Downloading:   0%|          | 0.00/2.75G [00:00<?, ?B/s]

In [152]:
# QA Model stuff
def format_question(question_list):
    question_list = ["$answer$ ; $mcoptions$ = (A) yes (B) no; $question$ = " + item \
         for item in question_list]
    return question_list

def predict(question_list):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    B = len(question_list)
    question_list = format_question(question_list)
    answer_list_all_yes = ["$answer$ = yes"] * B     # pass in list of "yes"

    input_ids = tokenizer.encode(question_list, max_length = 256, padding=True, truncation=True, return_tensors="pt")
    labels = tokenizer.encode(answer_list_all_yes, max_length = 15, padding=True, truncation=True, return_tensors="pt") # max_length is set to len("$answer$ = yes")

    # output = model.generate(input_ids, max_length=200)
    # answers = tokenizer.batch_decode(output, skip_special_tokens=True)
    fwd = model(input_ids=input_ids, labels=labels)

    # loss
    # loss = fwd.loss # - log(P(y|x))
    # confidence = torch.exp(-loss)
    logits = fwd.logits.reshape((B, 7, -1))
    logits = logits[:, 5, :] # Index of yes/no token in answer
    probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()
    # yes has input_id 4273, no has input_id 150
    confidence_yes = probs[..., 4273] 
    confidence_no = probs[..., 150]
    
    answers = np.array([(ans == "$answer$ = yes") for ans in answers])
    confidences = np.where(answers, confidence_yes, confidence_no)

    return answers, confidences


In [69]:
nli_tokenizer = AutoTokenizer.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")
nli_model = AutoModelForSequenceClassification.from_pretrained("ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli")

Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [80]:
def contradiction_matrix(sents, nli_tokenizer, nli_model):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    B = len(sents)
    prem = []
    hypo = []
    for i1 in range(B):
        for i2 in range(B):
            prem.append(sents[i1])
            hypo.append(sents[i2])

    tokenized = nli_tokenizer(prem, hypo, 
                              max_length=256, 
                              return_token_type_ids=True, 
                              truncation=True,
                              padding=True)
    
    input_ids = torch.Tensor(tokenized['input_ids']).to(device).long()
    token_type_ids = torch.Tensor(tokenized['token_type_ids']).to(device).long()
    attention_mask = torch.Tensor(tokenized['attention_mask']).to(device).long()
    
    nli_outputs = nli_model(input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids,
                            labels=None)
    predicted_probability = torch.softmax(nli_outputs.logits, dim=1)
    contra_matrix = predicted_probability[..., 2]
    contra_matrix = contra_matrix.reshape(B, B)
    return contra_matrix.detach().cpu().numpy()

In [85]:
# Correction methods
def do_nothing(predictions, confidences, contra_matrix):
    return predictions

def correction_1(predictions, confidences, contra_matrix):
    contra_matrix_sym = (contra_matrix + contra_matrix.T) / 2
    pass
    return predictions

In [104]:
def evaluate(predictions, answers):
    actual_answers = answers.copy()
    actual_answers[1] = not actual_answers[1]
    acc = np.count_nonzero(predictions == actual_answers)
    if answers[0] == answers[1]:
        con = 1 * (predictions[0] == predictions[1])
    else:
        con = 1 * (predictions[0] != predictions[1])
    total = len(predictions)
    return acc, con, total

In [153]:
# TODO: Batch these calculations (in particular when plugging into predict/QA model

acc_count = 0
con_count = 0
total_count = 0
for idx, base in enumerate(statements):
    entity, relation, true = base
    
    filter_dict = {
        'source': relation,
        'direction': 'forward',
    }
    selected_constraints = find_constraints(constraints, filter_dict=filter_dict)
    if len(selected_constraints) == 0:
        continue
    c = random.choice(selected_constraints)
    contra = (entity, c['target'], not (c['weight'] == 'yes_yes'))
    # print(base, contra)
    
    batch = [base, contra]
    
    questions, answers = zip(*[generate_question(*tup) for tup in batch])
    question_list = list(questions)
    answer_list = [ans == "Yes" for ans in answers]
    # print("Questions:", question_list)
    # print("Labels (for contradiction):", answer_list)
    
    predictions, confidences = predict(question_list)
    predictions = predictions.flatten()
    confidences = confidences.flatten()
    print("QA predictions:", predictions)
    print("QA confidences:", confidences)
    
    pred_batch = [(ent, rel, predictions[i]) for i, (ent, rel, true) in enumerate(batch)]
    assertions = [generate_assertion(*tup) for tup in pred_batch]
    # print("Assertions:", assertions)
    
    contra_matrix = contradiction_matrix(assertions, nli_tokenizer, nli_model)
    # print("Contradiction probability matrix:\n", contra_matrix)
    
    corrected = do_nothing(predictions, confidences, contra_matrix)
    acc, con, total = evaluate(corrected, answer_list)
    acc_count += acc
    con_count += con
    total_count += count
    # print(acc, con, count)
    
    if idx % 10 == 0:
        print(f"Iter {idx}: {acc_count}, {con_count} / {total_count}")
    
print(f"Accurate {acc_count} / {total_count} = {acc_count / total_count}")
print(f"Contradictions {con_count} / {total_count // 2} = {con_count * 2 / total_count}")

tensor([[   0, 1514, 3247, 3321, 3229, 3274, 4273,    1]])
QA predictions: [ True]
QA confidences: [0.9963271  0.67922425]


IndexError: index 1 is out of bounds for axis 0 with size 1