In [1]:
import time

import itertools

from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
import torch

In [2]:
start = time.time()
print("Loading Model")

obs = ['This room is called the art studio.',
       'you see a large cupboard. The large cupboard door is closed.',
       'you see a table. On the table is: a glass cup (containing nothing).']

context = ['You are in the art studio', 'you see a table with a container on the top containing nothing']

a = list(itertools.product(obs, context))
a.sort(key = lambda x: x[1])
print(f"Combinations = {len(a)}")

max_length = 256

hg_model_hub_name = "alisawuffles/roberta-large-wanli"

config = AutoConfig.from_pretrained(hg_model_hub_name)
tokenizer = AutoTokenizer.from_pretrained(hg_model_hub_name)

print(config.label2id)

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device='cpu'

model = AutoModelForSequenceClassification.from_pretrained(hg_model_hub_name)
model.to(device)

end = time.time()
print(f"Model loaded {end - start} - model {model.device}")

Loading Model
Combinations = 6
{'contradiction': 0, 'entailment': 1, 'neutral': 2}
Model loaded 6.416053771972656 - model cpu


In [3]:
start = time.time()
tokenized_input_seq_pair = tokenizer.batch_encode_plus(a,
                                                 max_length=max_length,
                                                 return_token_type_ids=True, truncation=True, padding=True)

input_ids = torch.tensor(tokenized_input_seq_pair['input_ids'], device=device).long()#.unsqueeze(0)
# remember bart doesn't have 'token_type_ids', remove the line below if you are using bart.
token_type_ids = torch.tensor(tokenized_input_seq_pair['token_type_ids'], device=device).long()#.unsqueeze(0)
attention_mask = torch.tensor(tokenized_input_seq_pair['attention_mask'], device=device).long()#.unsqueeze(0)

outputs = model(input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                labels=None)
# Note:
# "id2label": {
#     "0": "entailment",
#     "1": "neutral",
#     "2": "contradiction"
# },
end = time.time()

print(f"Inference time: {end - start}")
logits = outputs[0]
probs = torch.softmax(logits, dim=1)
for i, pair in enumerate(a):
    predicted_probability = probs[i].tolist()  # batch_size only one
    print('----')
    print("Premise:", pair[0])
    print("Hypothesis:", pair[1])
    print("Entailment:", predicted_probability[int(config.label2id['entailment'])])
    print("Neutral:", predicted_probability[int(config.label2id['neutral'])])
    print("Contradiction:", predicted_probability[int(config.label2id['contradiction'])])

Inference time: 0.6672201156616211
----
Premise: This room is called the art studio.
Hypothesis: You are in the art studio
Entailment: 0.8970025181770325
Neutral: 0.10274694859981537
Contradiction: 0.0002505670709069818
----
Premise: you see a large cupboard. The large cupboard door is closed.
Hypothesis: You are in the art studio
Entailment: 0.000518354878295213
Neutral: 0.9990247488021851
Contradiction: 0.0004568826116155833
----
Premise: you see a table. On the table is: a glass cup (containing nothing).
Hypothesis: You are in the art studio
Entailment: 0.0004657926911022514
Neutral: 0.9987221360206604
Contradiction: 0.0008120540296658874
----
Premise: This room is called the art studio.
Hypothesis: you see a table with a container on the top containing nothing
Entailment: 0.0006510507664643228
Neutral: 0.9931994080543518
Contradiction: 0.006149493157863617
----
Premise: you see a large cupboard. The large cupboard door is closed.
Hypothesis: you see a table with a container on the 

- $context = (c_1, c_2)$
- $beliefbase = (b_1, b_2)$
- $(b_1 \models c_1 \lor b_2 \models c_1) \land (b_1 \models c_2 \lor b_2 \models c_2)$

In [4]:
num_context = len(context)
num_beliefs = len(obs)

argmax_probs = probs.argmax(-1)
c1 = argmax_probs[: num_beliefs]
c2 = argmax_probs[num_beliefs:]

In [5]:
argmax_probs, c1, c2

(tensor([1, 2, 2, 2, 2, 0]), tensor([1, 2, 2]), tensor([2, 2, 0]))

In [6]:
entailment_idx = int(config.label2id['entailment'])
print(entailment_idx)
if entailment_idx in c1 and entailment_idx in c2:
    print("Entailment")
else:
    print("Not entailment")

1
Not entailment


In [14]:
entailment_mask = torch.where(argmax_probs == entailment_idx, True, False)
entailment_mask.unsqueeze(1), probs

(tensor([[ True],
         [False],
         [False],
         [False],
         [False],
         [False]]),
 tensor([[2.5057e-04, 8.9700e-01, 1.0275e-01],
         [4.5688e-04, 5.1835e-04, 9.9902e-01],
         [8.1205e-04, 4.6579e-04, 9.9872e-01],
         [6.1495e-03, 6.5105e-04, 9.9320e-01],
         [1.2940e-01, 1.4320e-03, 8.6917e-01],
         [8.2104e-01, 2.0682e-02, 1.5828e-01]], grad_fn=<SoftmaxBackward0>))

In [16]:
all_entailment_prob = (probs * entailment_mask.unsqueeze(1))
all_entailment_prob

tensor([[2.5057e-04, 8.9700e-01, 1.0275e-01],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00],
        [0.0000e+00, 0.0000e+00, 0.0000e+00]], grad_fn=<MulBackward0>)

In [18]:
probs[:, entailment_idx].mean()

tensor(0.1535, grad_fn=<MeanBackward0>)