In [2]:
import time

import itertools

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch

In [13]:
start = time.time()
print("Loading Model")

# AIMA -> explicar no contexto de busca
obs = ['You see the cupboard(which is closed)']

context = ['You see the cupboard(which is open)', 'You see the door to kitchen(which is open)', 'You see a table. On the table is: a bowl (containing nothing).']
#

a = list(itertools.product(obs, context))
a.sort(key = lambda x: x[1])
print(f"Combinations = {len(a)}")

max_length = 256

hg_model_hub_name = "alisawuffles/roberta-large-wanli"

config = AutoConfig.from_pretrained(hg_model_hub_name)
tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)

print(config.label2id)

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device='cpu'

model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
model.to(device)

end = time.time()
print(f"Model loaded {end - start} - model {model.device}")

Loading Model
Combinations = 3
{'contradiction': 0, 'entailment': 1, 'neutral': 2}
Model loaded 4.482424020767212 - model cpu


In [14]:
start = time.time()
tokenized_input_seq_pair = tokenizer.batch_encode_plus(a,
                                                 max_length=max_length,
                                                 return_token_type_ids=True, truncation=True, padding=True)

input_ids = torch.tensor(tokenized_input_seq_pair['input_ids'], device=device).long()#.unsqueeze(0)
# remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
token_type_ids = torch.tensor(tokenized_input_seq_pair['token_type_ids'], device=device).long()#.unsqueeze(0)
attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask'], device=device).long()#.unsqueeze(0)

outputs = model(input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=None)
end = time.time()

print(f"Inference time: {end - start}")
logits = outputs[0]
probs = torch.softmax(logits, dim=1)
for i, pair in enumerate(a):
    predicted_probability = probs[i].tolist()  # batch_size only one
    print('----')
    print("Premise:", pair[0])
    print("Hypothesis:", pair[1])
    print("Entailment:", predicted_probability[int(config.label2id['entailment'])])
    print("Neutral:", predicted_probability[int(config.label2id['neutral'])])
    print("Contradiction:", predicted_probability[int(config.label2id['contradiction'])])

Inference time: 0.2976675033569336
----
Premise: You see the cupboard(which is closed)
Hypothesis: You see a table. On the table is: a bowl (containing nothing).
Entailment: 0.0023102350533008575
Neutral: 0.9412702322006226
Contradiction: 0.056419484317302704
----
Premise: You see the cupboard(which is closed)
Hypothesis: You see the cupboard(which is open)
Entailment: 0.005573701113462448
Neutral: 0.006544825155287981
Contradiction: 0.9878815412521362
----
Premise: You see the cupboard(which is closed)
Hypothesis: You see the door to kitchen(which is open)
Entailment: 0.007574471645057201
Neutral: 0.5222721099853516
Contradiction: 0.47015348076820374


- $context = (c_1, c_2)$
- $beliefbase = (b_1, b_2)$
- $(b_1 \models c_1 \lor b_2 \models c_1) \land (b_1 \models c_2 \lor b_2 \models c_2)$

In [9]:
num_context = len(context)
num_beliefs = len(obs)

argmax_probs = probs.argmax(-1)
slice_idx = []
idx = 0
for i in range(num_context):
    slice_idx.append(argmax_probs[idx:(idx + num_beliefs)])
    idx = num_beliefs
slice_idx

[tensor([1, 2, 2]), tensor([2, 0, 0]), tensor([2, 0, 0])]

In [11]:
#c1_or = torch.where(c1 == entailment_idx, True, False)
#c2_or = torch.where(c2 == entailment_idx, True, False)
entailment_idx = 1
context_or = [torch.where(c == entailment_idx, True, False).any().unsqueeze(0) for c in slice_idx]
all_ors = torch.concatenate(context_or) # contiguous
entailment = all_ors.all()
if entailment:
    print(entailment.item())
else:
    print("not entailment")

not entailment


In [91]:
all_ors

tensor([True, True, True])

In [61]:
entailment_mask = torch.where(argmax_probs == entailment_idx, True, False)


tensor([[2.5057e-04, 8.9700e-01, 1.0275e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5948e-04, 9.9499e-01, 4.4468e-03]], grad_fn=<MulBackward0>)

In [71]:
entailment_probs = (probs[:, entailment_idx] * entailment_mask)
entailment_probs[entailment_probs != 0].mean()

tensor(0.9460, grad_fn=<MeanBackward0>)

In [72]:
type(entailment_probs)

torch.Tensor