In [2]:
import time

import itertools

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch

In [30]:
start = time.time()
print("Loading Model")

# AIMA -> explicar no contexto de busca
obs = ['In your inventory, you see: a metal pot (containing a substance called caesium) an orange a thermometer, currently reading a temperature of 10 degrees celsius ']

context = ['you have metal pot with caesium in your inventory']
#

a = list(itertools.product(obs, context))
a.sort(key=lambda x: x[1])
print(f"Combinations = {len(a)}")

max_length = 256

hg_model_hub_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"

config = AutoConfig.from_pretrained(hg_model_hub_name)
tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)

print(config.label2id)

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
model.to(device)

end = time.time()
print(f"Model loaded {end - start} - model {model.device}")

Loading Model
Combinations = 1
{'entailment': 0, 'neutral': 1, 'contradiction': 2}


Some weights of the model checkpoint at ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli were not used when initializing RobertaForSequenceClassification: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Model loaded 36.25479698181152 - model cpu


In [31]:
start = time.time()
tokenized_input_seq_pair = tokenizer.batch_encode_plus(a,
                                                       max_length=max_length,
                                                       return_token_type_ids=True, truncation=True, padding=True)

input_ids = torch.tensor(tokenized_input_seq_pair['input_ids'], device=device).long()  #.unsqueeze(0)
# remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
token_type_ids = torch.tensor(tokenized_input_seq_pair['token_type_ids'], device=device).long()  #.unsqueeze(0)
attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask'], device=device).long()  #.unsqueeze(0)

outputs = model(input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=None)
end = time.time()

print(f"Inference time: {end - start}")
logits = outputs[0]
probs = torch.softmax(logits, dim=1)
for i, pair in enumerate(a):
    predicted_probability = probs[i].tolist()  # batch_size only one
    print('----')
    print("Premise:", pair[0])
    print("Hypothesis:", pair[1])
    print("Entailment:", predicted_probability[int(config.label2id['entailment'])])
    print("Neutral:", predicted_probability[int(config.label2id['neutral'])])
    print("Contradiction:", predicted_probability[int(config.label2id['contradiction'])])

Inference time: 0.7378344535827637
----
Premise: In your inventory, you see: a metal pot (containing a substance called caesium) an orange a thermometer, currently reading a temperature of 10 degrees celsius 
Hypothesis: you have metal pot with caesium in your inventory
Entailment: 0.990742027759552
Neutral: 0.007689112797379494
Contradiction: 0.0015688496641814709


- $context = (c_1, c_2)$
- $beliefbase = (b_1, b_2)$
- $(b_1 \models c_1 \lor b_2 \models c_1) \land (b_1 \models c_2 \lor b_2 \models c_2)$

In [26]:
num_context = len(context)
num_beliefs = len(obs)

argmax_probs = probs.argmax(-1)
slice_idx = []
idx = 0
for i in range(num_context):
    slice_idx.append(argmax_probs[idx:(idx + num_beliefs)])
    idx = num_beliefs
slice_idx

[tensor([0, 0]), tensor([1, 0])]

In [28]:
#c1_or = torch.where(c1 == entailment_idx, True, False)
#c2_or = torch.where(c2 == entailment_idx, True, False)
entailment_idx = 0
context_or = [torch.where(c == entailment_idx, True, False).any().unsqueeze(0) for c in slice_idx]
all_ors = torch.concatenate(context_or)  # contiguous
entailment = all_ors.all()
if entailment:
    print(entailment.item())
else:
    print("not entailment")

True


In [29]:
all_ors

tensor([True, True])

In [27]:
context_or

[tensor([False]), tensor([True])]

In [71]:
entailment_probs = (probs[:, entailment_idx] * entailment_mask)
entailment_probs[entailment_probs != 0].mean()

tensor(0.9460, grad_fn=<MeanBackward0>)

In [72]:
type(entailment_probs)

torch.Tensor