In [4]:
import csv
import logging
import os
import random
import sys
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
#from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from bertviz.bertviz import attention, visualization
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

# Configure the root logger once for the whole notebook: INFO level and above,
# each record prefixed with a timestamp, its level and the emitting logger name.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [53]:
# Logger for this notebook's own messages.
logger = logging.getLogger(__name__)

# Directory the fine-tuned BERT classifier is loaded from.
bert_classifier_model_dir = "./output/"

# All inference in this notebook runs on the CPU.
# NOTE(review): n_gpu is fixed at 1 although the device is the CPU; in the
# visible cells it is only used in the log line below — confirm whether 0 was meant.
device = torch.device("cpu")
n_gpu = 1
logger.info("device: {}, n_gpu {}".format(device, n_gpu))

# Sequence classifier (two labels) used to score sentences.
model_cls = BertForSequenceClassification.from_pretrained(
    bert_classifier_model_dir,
    num_labels=2,
)
# Tokenizer built from the stock 'bert-base-uncased' vocabulary, lower-casing input.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# Move the classifier to the target device and switch it to evaluation mode.
model_cls.to(device).eval()

# bertviz's BertModel, loaded from the same directory, exposes the attention
# data of every head; it is also put on the device and into evaluation mode.
model = BertModel.from_pretrained(bert_classifier_model_dir).to(device)
# Left as the final expression so the cell still displays the module structure.
model.eval()

06/29/2020 20:26:35 - INFO - __main__ -   device: cpu, n_gpu 1
06/29/2020 20:26:35 - INFO - pytorch_pretrained_bert.modeling -   loading archive file ./output/
06/29/2020 20:26:35 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "type_vocab_size": 2,
  "vocab_size": 30522
}

06/29/2020 20:26:38 - INFO - bertviz.bertviz.pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/rajat/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084
06/29/2020 20:26:38 - INFO - bertviz.bertviz.pytorch_pretrained_bert.modelin

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
        )

In [54]:
# Every tokenised sentence is zero-padded up to this many ids in
# run_multiple_examples; the count includes the [CLS] and [SEP] tokens.
max_seq_len=128 # Maximum sequence length, in token ids
# dim=-1 normalises over the LAST dimension of the input, not over the batch;
# run_multiple_examples applies it to the classifier output.
sm = torch.nn.Softmax(dim=-1) ## Softmax over the last dimension

In [55]:
def run_multiple_examples(input_sentences, bs=2, return_preds=False):
    """
    Run sentences through the attention model (and optionally the classifier).

    Each sentence is tokenised, wrapped in [CLS] ... [SEP], truncated so that
    it fits in ``max_seq_len`` ids and zero-padded up to ``max_seq_len``.
    The sentences are then fed to the models in batches of ``bs``.

    input_sentences: list of strings (must not be empty)
    bs : batch_size : int, must be >= 1
    return_preds : bool. When False (default) only the attention data is
        returned and the classifier is not run. When True the classifier is
        run as well and its softmax outputs are returned alongside.

    Returns:
        The attention data produced by ``model`` for the LAST batch only;
        attention of earlier batches is not kept. With ``return_preds=True``
        the result is a tuple ``(predictions, attention)`` where
        ``predictions`` has one list of class probabilities per sentence,
        in input order.

    Raises:
        ValueError: if ``bs`` is smaller than 1 or no sentence is given.
    """
    if bs < 1:
        raise ValueError("bs must be a positive integer, got {}".format(bs))

    ## Prepare data for classification
    ids = []
    segment_ids = []
    input_masks = []
    # Two positions are reserved for the [CLS] and [SEP] markers.
    max_text_tokens = max(max_seq_len - 2, 0)
    for sen in input_sentences:
        # Truncate over-long inputs: without this the padding length becomes
        # negative, rows end up with different lengths and tensor creation fails.
        text_tokens = tokenizer.tokenize(sen)[:max_text_tokens]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        temp_ids = tokenizer.convert_tokens_to_ids(tokens)
        padding = [0] * (max_seq_len - len(temp_ids))

        ids.append(temp_ids + padding)
        # Mask is 1 on real tokens and 0 on padding.
        input_masks.append([1] * len(temp_ids) + padding)
        # Single-sentence input: every position belongs to segment 0.
        segment_ids.append([0] * len(temp_ids) + padding)

    if not ids:
        raise ValueError("input_sentences must contain at least one sentence")

    ## Convert input lists to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    pred_lt = []
    attn = None
    # Stepping by bs never yields an empty batch, even when the number of
    # sentences is an exact multiple of bs; the last slice is simply shorter.
    for start in range(0, len(ids), bs):
        stop = start + bs
        batch_ids = ids[start:stop].to(device)
        batch_segment_ids = segment_ids[start:stop].to(device)
        batch_input_masks = input_masks[start:stop].to(device)

        with torch.no_grad():
            if return_preds:
                preds = sm(model_cls(batch_ids, batch_segment_ids, batch_input_masks))
                pred_lt.extend(preds.tolist())
            # The third output of bertviz's BertModel is the attention data.
            _, _, attn = model(batch_ids, batch_segment_ids, batch_input_masks)

    if return_preds:
        return pred_lt, attn
    return attn

In [56]:
def read_file(input_file, length, max_rows=1):
    """Read the text column of a comma separated value file.

    The first row is treated as a header and skipped. From each following
    row the fourth column (index 3) is taken and cut to at most ``length``
    characters.

    input_file: path of the CSV file
    length: maximum number of characters kept per row
    max_rows: maximum number of data rows to read. Defaults to 1, i.e. only
        the first data row is returned. Pass None to read every row.

    Returns a list of strings; empty when the file has no data rows.
    Raises IndexError if a row that is read has fewer than four columns.
    """
    lines = []
    # newline="" is what the csv module expects, so that newlines embedded in
    # quoted fields are parsed correctly.
    with open(input_file, "r", newline="") as f:
        reader = csv.reader(f, delimiter=",")
        # Skip the header; the default avoids StopIteration on an empty file.
        next(reader, None)
        for line in reader:
            if max_rows is not None and len(lines) >= max_rows:
                break
            lines.append(line[3][:length])
    return lines

In [57]:
# Load text from the reddit dev.csv file, keeping at most 128 characters per row.
# `data` is the name the following cells use; `pos_data` refers to the same list.
data = pos_data = read_file('/data/rajat/datasets/reddit/dev.csv', 128)

In [58]:
# Keep the attention data returned by run_multiple_examples for the loaded
# sentences; it is indexed in a later cell.
a = run_multiple_examples(data)

tensor([[  101,  7929,  2903,  2054,  2017,  2215,  2017,  8239,  4485, 16985,
          2094,  2111,  2066,  2017,  2024,  1037,  2350,  2112,  1997,  2339,
         14398,  2145,  6526,  1999,  1996,  2034,  2173, 10930,   102,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,  

In [59]:
# NOTE(review): presumably entry 8 is a layer, then example 0, head 1 and
# query token 0, giving one attention distribution over the padded sequence.
# Confirm the axis order against bertviz's BertModel.
a[8]['attn_probs'][0][1][0]

tensor([0.0024, 0.0109, 0.0027, 0.0026, 0.0033, 0.0043, 0.0229, 0.3890, 0.2886,
        0.1067, 0.0562, 0.0248, 0.0025, 0.0009, 0.0013, 0.0012, 0.0011, 0.0003,
        0.0002, 0.0002, 0.0069, 0.0008, 0.0005, 0.0004, 0.0006, 0.0014, 0.0006,
        0.0515, 0.0153, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 

In [50]:
# Display the text loaded by read_file (already cut to the requested length).
data

['ok believe what you want you fucking shit tard people like you are a major part of why racism still exists in the first place yo']

In [49]:
# NOTE(review): `p` is not defined in any visible cell, so this raises a
# NameError on a fresh kernel. Presumably an alias of an imported package;
# confirm which one, or drop this cell.
p.__version__


'0.6.1'