## Getting Started

In [None]:
import os
import torch

from transformers import BertTokenizer
from bertviz.transformers_neuron_view import BertModel as BertvizModel
from bertviz.transformers_neuron_view import BertTokenizer as BertvizTokenizer
from bertviz.neuron_view import show
from bertviz import model_view

from models.nets import get_end_to_end_net
from utils.utils import read_tsv

In [None]:
# Filesystem locations used by the cells below: the pretrained BERT directory
# (its vocab.txt is read from here) and the fine-tuned checkpoint file that is
# handed to torch.load.
bert_state_path, trained_model_path = (
    '../weights/biobert_large_v1.1_pubmed_torch',
    '../weights/end-to-end-1-fixvalid/best_model_13000',
)

In [None]:
# Build the end-to-end network on CPU, restore the fine-tuned checkpoint and
# create the matching tokenizer.
# NOTE(review): the positional arguments are forwarded to get_end_to_end_net
# as-is; their meaning (presumably layer sizes, class count and activation
# name) should be confirmed against models.nets.
net = get_end_to_end_net(
    bert_state_path,
    1024,
    [1024, 1024],
    8,
    'ReLU'
).cpu()
# map_location='cpu' remaps tensors that were saved from a CUDA device, so the
# checkpoint also loads on machines without a GPU (the model lives on CPU).
net.load_state_dict(torch.load(trained_model_path, map_location='cpu'))
# Switch to evaluation mode before running inference.
net.eval()
# Cased vocabulary: lower-casing is explicitly disabled.
tokenizer = BertTokenizer(os.path.join(bert_state_path, 'vocab.txt'), do_lower_case=False)

In [None]:
# Second copy of the network whose BERT encoder is replaced by bertviz's own
# BertModel implementation, as needed by the neuron view.
bertviz_net = get_end_to_end_net(
    bert_state_path,
    1024,
    [1024, 1024],
    8,
    'ReLU'
).cpu()
# Swap in the bertviz encoder, initialised from the pretrained BERT directory.
bertviz_net.bert = BertvizModel.from_pretrained(bert_state_path)
# map_location='cpu' remaps tensors that were saved from a CUDA device, so the
# checkpoint also loads on machines without a GPU.
# NOTE(review): strict=False silently skips keys that do not match between the
# checkpoint and the bertviz encoder — presumably intentional; inspect the
# return value of load_state_dict if the fine-tuned weights seem missing.
bertviz_net.load_state_dict(
    torch.load(trained_model_path, map_location='cpu'),
    strict=False,
)
# Switch to evaluation mode before running inference.
bertviz_net.eval()
# Cased vocabulary: lower-casing is explicitly disabled.
bertviz_tokenizer = BertvizTokenizer(os.path.join(bert_state_path, 'vocab.txt'), do_lower_case=False)

In [None]:
# Load the training split from its TSV file with the project's read_tsv helper.
# NOTE(review): the structure of `data` (presumably a sequence of rows) is
# defined by utils.utils.read_tsv — confirm there.
data = read_tsv('../data/merged/training/train.txt')

In [None]:
# Pick the example to visualize: element 2 of record 550.
# NOTE(review): presumably field 2 holds the raw sentence text — confirm
# against the TSV layout.
sentence = data[550][2]
sentence  # bare expression so the notebook displays the selected value

## Attention Viewer

In [None]:
# Optional: uncomment to visualize a custom sentence instead of the dataset sample.
# sentence = "Hello World"

In [None]:
# Render bertviz's neuron view for `sentence`, using the bertviz BERT encoder
# and its tokenizer; 'bert' is the model-type argument passed to show().
show(bertviz_net.bert, 'bert', bertviz_tokenizer, sentence)

In [None]:
# Encode the sentence as a tensor of token ids and recover the token strings
# that label the visualization.
ids = tokenizer.encode(sentence, return_tensors='pt')
tokens = tokenizer.convert_ids_to_tokens(ids[0])
# Only the attention values are needed, so run the forward pass without
# autograd: no graph is built and no activations are kept for backprop.
with torch.no_grad():
    output = net.bert(ids, output_attentions=True)
# Display bertviz's model view, restricted to layer 0.
model_view(output.attentions, tokens, include_layers=[0])