# Self-attention model (BERT)

In [1]:
# Load BERT and retrieve its self-attention weights for a sentence pair.
#
# Names defined here and used by later cells:
#   attention        - per-layer attention tensors returned by the model
#   tokens           - token strings for the encoded pair (incl. [CLS]/[SEP])
#   sentence_b_start - index in `tokens` of the first token of sentence_b

import torch
from bertviz import head_view, model_view
from transformers import BertTokenizer, BertModel

model_version = 'bert-base-uncased'
model = BertModel.from_pretrained(model_version, output_attentions=True)
tokenizer = BertTokenizer.from_pretrained(model_version)
sentence_a = "The cat sat on the mat"
sentence_b = "The cat lay on the rug"
# Calling the tokenizer directly replaces the deprecated `encode_plus`; for a
# sentence pair it returns the same input_ids / token_type_ids.
inputs = tokenizer(sentence_a, sentence_b, return_tensors='pt')
input_ids = inputs['input_ids']
token_type_ids = inputs['token_type_ids']
# Inference only: no need to build an autograd graph.
with torch.no_grad():
    bert_outputs = model(input_ids, token_type_ids=token_type_ids)
# Read attentions by name; positional indexing (`[-1]`) depends on which
# optional outputs (hidden states, cross-attentions, ...) are enabled.
attention = bert_outputs.attentions
# The first position whose token type is 1 is where sentence_b begins.
sentence_b_start = token_type_ids[0].tolist().index(1)
input_id_list = input_ids[0].tolist()  # Batch index 0
tokens = tokenizer.convert_ids_to_tokens(input_id_list)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [2]:
# Display the token strings for the encoded sentence pair (includes the [CLS] and [SEP] markers).
tokens

['[CLS]',
 'the',
 'cat',
 'sat',
 'on',
 'the',
 'mat',
 '[SEP]',
 'the',
 'cat',
 'lay',
 'on',
 'the',
 'rug',
 '[SEP]']

In [3]:
# Shape of the attention tensor at layer index 1.
# NOTE(review): presumably (batch, heads, query_len, key_len); the last two dims (15)
# match len(tokens) — confirm against the transformers docs.
attention[1].shape

torch.Size([1, 12, 15, 15])

In [4]:
# Interactive per-head attention view; sentence_b_start marks where sentence B begins in `tokens`.
head_view(
    attention=attention,
    tokens=tokens,
    sentence_b_start=sentence_b_start,
)

<IPython.core.display.Javascript object>

In [5]:
# Print the shape of each per-layer attention tensor.
for layer_attention in attention:
    print(layer_attention.shape)

torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])
torch.Size([1, 12, 15, 15])


# Encoder-decoder model (BART)

In [7]:
# Load BART fine-tuned for summarization on CNN/Daily Mail and visualize the
# encoder, decoder, and cross attentions with bertviz's model_view.

import torch
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view

utils.logging.set_verbosity_error()  # Remove line to see warnings

# Initialize tokenizer and model. Be sure to set output_attentions=True.
model_name = "facebook/bart-large-cnn"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name, output_attentions=True)

# Source text for the encoder and a summary of it for the decoder.
encoder_sentence = "The House Budget Committee voted Saturday to pass a $3.5 trillion spending bill"
decoder_sentence = "The House Budget Committee passed a spending bill."

# Token ids of the encoder input (the source text).
encoder_input_ids = tokenizer(encoder_sentence, return_tensors="pt", add_special_tokens=True).input_ids

# Token ids of the decoder input (the summary).
# NOTE(review): these ids are fed to the decoder as-is (no shift-right /
# decoder start token prepended) — fine for visualizing attention, but confirm
# this is intended if the outputs are used for anything else.
decoder_input_ids = tokenizer(decoder_sentence, return_tensors="pt", add_special_tokens=True).input_ids

# Inference only: no need to build an autograd graph.
with torch.no_grad():
    outputs = model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,
)

<IPython.core.display.Javascript object>

In [15]:
# Shapes of the first-layer (index 0) encoder, decoder, and cross attention tensors,
# followed by the encoder token strings.
first_layer_attentions = (
    outputs.encoder_attentions[0],
    outputs.decoder_attentions[0],
    outputs.cross_attentions[0],
)
for layer0_attn in first_layer_attentions:
    print(layer0_attn.shape)
print(encoder_text)

torch.Size([1, 16, 18, 18])
torch.Size([1, 16, 11, 11])
torch.Size([1, 16, 11, 18])
['<s>', 'The', 'ĠHouse', 'ĠBudget', 'ĠCommittee', 'Ġvoted', 'ĠSaturday', 'Ġto', 'Ġpass', 'Ġa', 'Ġ$', '3', '.', '5', 'Ġtrillion', 'Ġspending', 'Ġbill', '</s>']
