## Getting Started

In [None]:
!pip3 install bertviz

In [None]:
import os
import torch

from transformers import BertTokenizer
from bertviz.transformers_neuron_view import BertModel as BertvizModel
from bertviz.transformers_neuron_view import BertTokenizer as BertvizTokenizer
from bertviz.neuron_view import show
from bertviz import model_view

from models.nets import get_end_to_end_net
from utils.utils import read_tsv

### A more in-depth explanation of how these attention visualizations work is available [in this blog post](https://towardsdatascience.com/deconstructing-bert-part-2-visualizing-the-inner-workings-of-attention-60a16d86b5c1?gi=0205807bbbe7)

In [None]:
# Checkpoint locations: the trained end-to-end weights, and the BioBERT directory
# that is handed to `from_pretrained` and that also holds the `vocab.txt` used by
# the tokenizers.
trained_model_path = '../weights/end-to-end-1/best_model_10000'
bert_state_path = '../model_weights/biobert_large_v1.1_pubmed_torch'

# Training examples, loaded through the project's `read_tsv` helper.
data = read_tsv('../data/merged/training/train.txt')

In [None]:
# Pick the example to visualize: field 2 of record 550 in the loaded data
# (presumably the sentence text; confirm against the TSV layout).
example_record = data[550]
sentence = example_record[2]
sentence  # echo the selected value as the cell output

## Neuron View

The neuron view visualizes individual neurons in the query and key vectors and shows how they are used to compute attention.

In [None]:
# The neuron view is invoked differently than the head view or model view: it
# needs access to the model's query/key vectors, which are not returned through
# the Huggingface API. It is currently limited to the custom versions of BERT,
# GPT-2, and RoBERTa included with BertViz.

# Build the end-to-end network on the CPU, then swap its BERT encoder for the
# BertViz variant. The positional hyperparameters are passed through unchanged;
# presumably they must match the ones the checkpoint was trained with.
bertviz_net = get_end_to_end_net(
    bert_state_path,
    1024,
    [1024, 1024],
    8,
    'ReLU'
).cpu()
bertviz_net.bert = BertvizModel.from_pretrained(bert_state_path)

# map_location='cpu' lets a checkpoint that was saved from a GPU be deserialized
# on a machine without CUDA; the network itself already lives on the CPU.
# strict=False does not enforce an exact key match between the checkpoint and the
# module (presumably needed because the encoder was replaced above).
bertviz_net.load_state_dict(
    torch.load(trained_model_path, map_location='cpu'),
    strict=False
)
bertviz_net.eval()  # evaluation mode for inference

# BertViz's own tokenizer, built from the vocabulary shipped with the BERT weights.
bertviz_tokenizer = BertvizTokenizer(os.path.join(bert_state_path, 'vocab.txt'), do_lower_case=False)

In [None]:
# Render the interactive neuron view for the selected sentence, using the
# BertViz-specific encoder and tokenizer.
show(
    bertviz_net.bert,
    'bert',
    bertviz_tokenizer,
    sentence,
)

## Model View

The model view shows a bird's-eye view of attention across all layers and heads.

In [None]:
# The model view requires an unmodified instance of the BERT model and tokenizer:
# the modified BertViz versions lack functionality that the model view relies on
# (e.g. `return_tensors` in the tokenizer and `output_attentions` in the model).

# Initialize the model on the CPU. The positional hyperparameters are passed
# through unchanged; presumably they must match the ones the checkpoint was
# trained with.
net = get_end_to_end_net(
    bert_state_path,
    1024,
    [1024, 1024],
    8,
    'ReLU'
).cpu()

# map_location='cpu' lets a checkpoint that was saved from a GPU be deserialized
# on a machine without CUDA; the network itself already lives on the CPU.
net.load_state_dict(torch.load(trained_model_path, map_location='cpu'))
net.eval()  # evaluation mode for inference

# Stock Huggingface tokenizer, built from the vocabulary shipped with the BERT weights.
tokenizer = BertTokenizer(os.path.join(bert_state_path, 'vocab.txt'), do_lower_case=False)

In [None]:
# Encode with the stock tokenizer (the BertViz tokenizer does not support
# return_tensors) and recover the token strings for display.
ids = tokenizer.encode(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(ids[0])

# Inference only: disable autograd so no graph is recorded for the attention
# tensors. The stock `net.bert` is used because the BertViz net does not support
# output_attentions.
with torch.no_grad():
    output = net.bert(ids, output_attentions=True)

# To restrict the rendered layers, pass e.g. include_layers=[0] to model_view.
model_view(output.attentions, tokens)