In [53]:
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, DistilBertConfig, BertForSequenceClassification, BertTokenizer, BertConfig
from datasets import load_dataset,load_metric
import numpy as np


from accelerate import Accelerator


accelerator = Accelerator()
device = accelerator.device

In [54]:
# texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
# texts = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
texts = "I don't like this movie."

# pretrained_name = "distilbert-base-uncased-finetuned-sst-2-english"
pretrained_name = "textattack/bert-base-uncased-SST-2"

# pred_config = DistilBertConfig.from_pretrained(pretrained_name)
# pred_tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)
# pred_model = DistilBertForSequenceClassification.from_pretrained(pretrained_name).to(device)

pred_config = BertConfig.from_pretrained(pretrained_name)
pred_tokenizer = BertTokenizer.from_pretrained(pretrained_name)
pred_model = BertForSequenceClassification.from_pretrained(pretrained_name).to(device)

inputs = pred_tokenizer(texts, return_tensors="pt")
with torch.no_grad():
    inputs = {key:val.to(device) for key,val in inputs.items()}
    logits = pred_model(**inputs).logits

predicted_class_id = logits.argmax().item()
print(pred_model.config.id2label[predicted_class_id])

print(inputs)

LABEL_0
{'input_ids': tensor([[ 101, 1045, 2123, 1005, 1056, 2066, 2023, 3185, 1012,  102]],
       device='cuda:0'), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}


In [55]:
from datasets import load_dataset
imdb = load_dataset("imdb")


In [72]:
from maskgen.text_models.text_maskgen_model2 import MaskGeneratingModel
pred_hidden_dim = pred_model.config.hidden_size
num_labels = pred_model.config.num_labels

mask_gen_model = MaskGeneratingModel(pred_model, hidden_size=pred_hidden_dim, num_classes=num_labels)
mask_gen_model.to(device)

mask_gen_model.load_state_dict(torch.load('text_mask_gen_model/mask_gen_model_0_70.pth'))
# mask_gen_model.load_state_dict(torch.load('trained/mask_gen_model12/mask_gen_model_2_90.pth'))
mask_gen_model.eval()
print()




In [75]:
# texts = "This movie was the best movie I have ever seen! some scenes were ridiculous, but acting was great."
texts = "I really didn't like this movie. Some of the actors were good, but overall the movie was boring."
# texts = "I hate that I love you."
# texts = "I don't like this movie."
# texts = "I really love this film."
# texts = 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as they have to always say "Gene Roddenberry\'s Earth..." otherwise people would not continue watching. Roddenberry\'s ashes must be turning in their orbit as this dull, cheap, poorly edited (watching it without advert breaks really brings this home) trudging Trabant of a show lumbers into space. Spoiler. So, kill off a main character. And then bring him back as another actor. Jeeez! Dallas all over again.'

# idx = 14
# texts = imdb["test"][idx]['text']

label = None
# label = torch.tensor([1]).to(device)

inputs = pred_tokenizer(texts, return_tensors="pt")
with torch.no_grad():
    inputs = {key:val.to(device) for key,val in inputs.items()}
    logits = pred_model(**inputs).logits

predicted_class_id = logits.argmax().item()
print(pred_model.config.id2label[predicted_class_id])

if label is None:
    label = logits.argmax().unsqueeze(0)


from captum.attr import visualization
expl = mask_gen_model.attribute_text(inputs['input_ids'], inputs['attention_mask'], label)[0][1:-1]
print(expl)
tokens = pred_tokenizer.convert_ids_to_tokens(inputs['input_ids'].flatten())[1:-1]
# normalize expl
expl = (expl - expl.min()) / (expl.max() - expl.min()) # - 0.5
vis_data_records = [visualization.VisualizationDataRecord(
                                 label * expl + (1 - label) * (-expl),
                                0,
                                0,
                                0,
                                0,
                                1,       
                                tokens,
                                1)]
                            
visualization.visualize_text(vis_data_records)

LABEL_0
tensor([0.4209, 0.5455, 0.7139, 0.4016, 0.7630, 0.4861, 0.4441, 0.3539, 0.4043,
        0.3413, 0.3233, 0.3420, 0.3153, 0.3212, 0.2967, 0.3743, 0.5592, 0.6849,
        0.5570, 0.3998, 0.6744, 0.7604, 0.4164], device='cuda:0')


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"i really didn ' t like this movie . some of the actors were good , but overall the movie was boring ."
,,,,


True Label,Predicted Label,Attribution Label,Attribution Score,Word Importance
0.0,0 (0.00),0.0,1.0,"i really didn ' t like this movie . some of the actors were good , but overall the movie was boring ."
,,,,


: 

In [61]:
label

tensor([1], device='cuda:0')

In [21]:
expl.shape

torch.Size([21])

In [None]:
len(tokens)

23