In [1]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
import IPython
import torch
import torch.nn as nn

from transformers import BertModel, BertConfig, BertTokenizer
from bertviz.bertviz import head_view

In [2]:
%%javascript
// Register CDN locations (protocol-relative URLs) for d3 3.4.8 and jQuery 2.0.0
// with RequireJS so that later JavaScript output in this notebook can require them.
// NOTE(review): presumably needed by bertviz's head_view rendering further down;
// confirm against the installed bertviz version before changing these pins.
require.config({
  paths: {
      d3: '//cdnjs.cloudflare.com/ajax/libs/d3/3.4.8/d3.min',
      jquery: '//ajax.googleapis.com/ajax/libs/jquery/2.0.0/jquery.min',
  }
});

<IPython.core.display.Javascript object>

In [3]:
# Run on CUDA when a GPU is visible to torch; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [4]:
class BertEnron(nn.Module):
    """Pretrained encoder with a classification head applied to the first token.

    Args:
        config: ``(model_class, tokenizer_class, pretrained_weights)`` triple.
            The encoder is built with ``model_class.from_pretrained(pretrained_weights)``;
            ``tokenizer_class`` is stored on the instance but not used by this module.
        nclasses: number of output logits produced by the head.
        mlp: if True, use a small MLP head (Linear -> BatchNorm1d -> ReLU -> Linear)
            instead of a single Linear layer.
        mlp_hidden: width of the MLP head's hidden layer (ignored when ``mlp=False``).
    """

    def __init__(self, config, nclasses, mlp=False, mlp_hidden=100):
        super().__init__()
        self.model_class, self.tokenizer_class, self.pretrained_weights = config
        self.lm_layer = self.model_class.from_pretrained(self.pretrained_weights)

        # Size the head from the encoder's own config instead of hard-coding 768
        # (bert-base); fall back to 768 if the encoder exposes no hidden_size.
        hidden_size = getattr(getattr(self.lm_layer, 'config', None), 'hidden_size', 768)

        if mlp:
            self.main = nn.Sequential(
                nn.Linear(hidden_size, mlp_hidden),
                nn.BatchNorm1d(mlp_hidden),
                nn.ReLU(),
                nn.Linear(mlp_hidden, nclasses),
            )
        else:
            self.main = nn.Sequential(
                nn.Linear(hidden_size, nclasses)
            )

    def forward(self, input_ids, attn_masks):
        """Return class logits of shape ``(batch, nclasses)``.

        Args:
            input_ids: token ids passed straight to the encoder.
            attn_masks: forwarded to the encoder as ``attention_mask``.
        """
        # Element 0 of the encoder output is taken as the last hidden states;
        # assumed (batch, seq, hidden), as the 3-D slice below requires.
        last_hidden_states = self.lm_layer(input_ids, attention_mask=attn_masks)[0]
        # First position of every sequence (the [CLS] slot for BERT-style inputs).
        cls = last_hidden_states[:, 0, :]
        return self.main(cls)

In [5]:
def show_head_view(model, tokenizer, sentence_a, sentence_b=None):
    """Render bertviz's attention head view for one sentence or a sentence pair.

    Args:
        model: BERT-style model whose output's last element is the attention
            stack (e.g. loaded with ``output_attentions=True``).
        tokenizer: tokenizer providing ``encode_plus`` and ``convert_ids_to_tokens``.
        sentence_a: first (or only) sentence.
        sentence_b: optional second sentence; ``None`` selects single-sentence mode.
    """
    inputs = tokenizer.encode_plus(sentence_a, sentence_b, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs['input_ids']
    # Visualization only: skip building an autograd graph for the forward pass.
    with torch.no_grad():
        # Test against None, not truthiness: the tokenizer received sentence_b
        # unchanged, so an empty string was still encoded as a pair.
        if sentence_b is not None:
            token_type_ids = inputs['token_type_ids']
            attention = model(input_ids, token_type_ids=token_type_ids)[-1]
            # Segment B starts at the first position tagged with token type 1.
            sentence_b_start = token_type_ids[0].tolist().index(1)
        else:
            attention = model(input_ids)[-1]
            sentence_b_start = None
    input_id_list = input_ids[0].tolist()  # batch index 0
    tokens = tokenizer.convert_ids_to_tokens(input_id_list)
    head_view(attention, tokens, sentence_b_start)

In [6]:
# Output directory for the exported fine-tuned encoder (BERT_PATH below points here).
# exist_ok makes the cell idempotent and avoids the check-then-create race.
os.makedirs('trained_bert', exist_ok=True)

In [11]:
# Rebuild the fine-tuned classifier (bert-base-uncased encoder, 78 output classes)
# and wrap it in DataParallel before loading: the strict load below reports all
# keys matched, so the checkpoint was presumably saved from a DataParallel-wrapped
# model (keys prefixed with "module.").
# NOTE: torch.load unpickles the file, so only load checkpoints from a trusted source.
enron_bert = nn.DataParallel(
    BertEnron((BertModel, BertTokenizer, 'bert-base-uncased'), nclasses=78)
).to(device)
enron_bert.load_state_dict(
    torch.load('model-3.pth', map_location='cpu')
)

<All keys matched successfully>

In [12]:
# Export only the fine-tuned encoder (`lm_layer` inside the DataParallel wrapper),
# without the classification head, so it can be reloaded with BertModel.from_pretrained.
BERT_PATH = './trained_bert/'
# Create the target here so this cell does not depend on the earlier mkdir cell;
# older transformers releases require the directory to already exist.
os.makedirs(BERT_PATH, exist_ok=True)
enron_bert.module.lm_layer.save_pretrained(BERT_PATH)

In [14]:
# Reload the exported encoder with attention outputs enabled (show_head_view reads
# them from the last element of the model output). Only the encoder weights were
# exported, so the tokenizer comes from the stock bert-base-uncased checkpoint.
bert_model = BertModel.from_pretrained(BERT_PATH, output_attentions=True)
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# state_dict() entries are whole tensors (weights, biases, buffers), not individual
# parameters, so report both numbers rather than mislabeling the tensor count.
n_tensors = len(bert_model.state_dict())
n_params = sum(p.numel() for p in bert_model.parameters())
print('Successfully loaded {} tensors ({:,} parameters).'.format(n_tensors, n_params))

Sucessfully loaded 199 parameters.


In [17]:
# Visualize the attention heads for a single, unpaired sentence.
example = "Is there going to be a conference call about all the regulatory issues?"
show_head_view(model=bert_model, tokenizer=bert_tokenizer, sentence_a=example)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>