In [20]:
import pickle
import csv
import logging
import os
import random
import sys

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
#from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from bertviz.bertviz import attention, visualization
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

In [2]:
# Runtime configuration: module logger, fine-tuned checkpoint location, compute device.
logger = logging.getLogger(__name__)
bert_classifier_model_dir = "models/"  # directory holding the fine-tuned BERT classifier checkpoint
n_gpu = torch.cuda.device_count()
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info("device: {}, n_gpu {}".format(device, n_gpu))

In [3]:
# Input file paths: <data_dir>/<dataset>/{train,dev,test}/{trump,en}.txt plus the reference files.
data_dir = os.getcwd()
dataset = "data"  # dataset sub-directory under data_dir (e.g. amazon / yelp / imagecaption)
train_0 = os.path.join(data_dir, "{}/train/trump.txt".format(dataset))
train_1 = os.path.join(data_dir, "{}/train/en.txt".format(dataset))
test_0 = os.path.join(data_dir, "{}/test/trump.txt".format(dataset))
test_1 = os.path.join(data_dir, "{}/test/en.txt".format(dataset))
dev_0 = os.path.join(data_dir, "{}/dev/trump.txt".format(dataset))
dev_1 = os.path.join(data_dir, "{}/dev/en.txt".format(dataset))
# Built the same way as the other paths (previously had a redundant "./" segment).
reference_0 = os.path.join(data_dir, "{}/reference_0.txt".format(dataset))
reference_1 = os.path.join(data_dir, "{}/reference_1.txt".format(dataset))

In [4]:
# Output file paths, all under <data_dir>/<dataset>/processed_files_with_bert_with_best_head/
data_dir = os.getcwd()
dataset = "data" # amazon / yelp / imagecaption
(train_0_out, train_1_out,
 test_0_out, test_1_out,
 dev_0_out, dev_1_out) = [
    os.path.join(
        data_dir,
        "{}/processed_files_with_bert_with_best_head/sentiment_{}_{}.txt".format(dataset, split, label),
    )
    for split in ("train", "test", "dev")
    for label in (0, 1)
]
reference_0_out, reference_1_out = [
    os.path.join(
        data_dir,
        "{}/processed_files_with_bert_with_best_head/reference_{}.txt".format(dataset, label),
    )
    for label in (0, 1)
]

In [5]:
## Model for performing Classification (fine-tuned checkpoint in bert_classifier_model_dir)
# NOTE(review): model_cls is not referenced by any later cell shown in this notebook.
model_cls = BertForSequenceClassification.from_pretrained(
    bert_classifier_model_dir,
    num_labels=2,
)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model_cls = model_cls.to(device)  # Module.to returns the same module
model_cls.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): BertLayerNorm()
              (dropout): Dropout(p=0.1, inplace=False)
            )
    

In [6]:
## Model to get the attention weights of all the heads
# Same checkpoint directory as the classifier, loaded with the bertviz BertModel;
# run_attn_examples reads the per-layer attention from its forward output.
model = BertModel.from_pretrained(bert_classifier_model_dir)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model = model.to(device)  # Module.to returns the same module
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): BertLayerNorm()
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): BertLayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
        )
        (intermediate): BertIntermediate(
          (dense): Linear(in_features

In [7]:
max_seq_len = 512  # maximum input length; the checkpoint has 512 position embeddings
# Softmax over the last dimension (dim=-1), not over the batch.
# NOTE(review): `sm` is not used by any cell shown in this notebook.
sm = torch.nn.Softmax(dim=-1)

In [8]:
# Ids that prepare_data never selects for deletion: special tokens, sentence-final
# punctuation and very frequent function words.
common_words = (
    "is are was were has have had a an the this that these those there how i we "
    "he she it they them their his him her us our and in my your you will shall"
).split()
common_words_tokens = tokenizer.convert_tokens_to_ids(common_words)
special_and_punct_ids = tokenizer.convert_tokens_to_ids(["[CLS]", "[SEP]", ".", "?", "!"])
not_to_remove_ids = special_and_punct_ids + common_words_tokens

In [16]:
def read_file(file_path):
    """Return the lines of ``file_path`` without their line terminators.

    The file is decoded as ISO-8859-1, which never fails to decode.

    Lines are split only on newlines (``\\n``, ``\\r\\n``, ``\\r`` via universal-newline
    mode). ``str.splitlines()`` was used before, but it also breaks on ``\\x85``,
    ``\\x0c``, ``\\x1c``-``\\x1e``, ``\\u2028`` etc. Under ISO-8859-1 decoding, any UTF-8
    character with a 0x85 byte (e.g. "Å", many emoji) then split one sentence into
    two and misaligned the corpus.
    """
    with open(file_path, encoding="ISO-8859-1") as fp:
        # Each iterated line ends with at most one "\n" (already translated from \r\n / \r).
        return [line.rstrip("\n") for line in fp]

In [10]:
def create_output_file(original_sentences, processed_sentences, output_file, sentiment="<POS>"):
    """Write one line per (original, processed) sentence pair to ``output_file``.

    Each line has the form
    ``<sentiment> <CON_START> <processed> <START> <original> <END>``.
    Pairs where either sentence is None are skipped. As with ``zip``, trailing
    items of the longer list are ignored. ``output_file`` is overwritten. Its parent
    directory is created if missing, so a long preprocessing run does not fail at
    the final write.
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w") as fp:
        for sen1, sen2 in zip(original_sentences, processed_sentences):
            if sen1 is not None and sen2 is not None:
                fp.write(sentiment + " <CON_START> " + sen2 + " <START> " + sen1 + " <END>\n")

In [11]:
def create_ref_output_file(processed_sentences, output_file, sentiment="<POS>"):
    """Write one reference line per processed sentence to ``output_file``.

    Each line has the form ``<sentiment> <CON_START> <processed> <START>``. None
    entries are skipped. ``output_file`` is overwritten, and its parent directory is
    created if missing (consistent with ``create_output_file``).
    """
    out_dir = os.path.dirname(output_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(output_file, "w") as fp:
        for sen in tqdm(processed_sentences):
            if sen is not None:
                fp.write(sentiment + " <CON_START> " + sen + " <START>\n")

In [12]:
def concate_files(inp_files, out_files):
    """Concatenate the text files listed in ``inp_files``, in order, into ``out_files``.

    Despite its plural name, ``out_files`` is a single path; it is overwritten.
    """
    with open(out_files, "w") as dst:
        for path in inp_files:
            with open(path) as src:
                dst.writelines(src)

In [13]:
def run_attn_examples(input_sentences, layer, head, bs=128):
    """
    Returns Attention weights for selected Layer and Head along with ids and tokens
    of the input_sentence.

    Each sentence is tokenized, wrapped as [CLS] ... [SEP], padded with id 0 up to
    ``max_seq_len``, and run through ``model`` in chunks of ``bs`` sentences.

    Args:
        input_sentences: sequence of raw sentence strings.
        layer: index into the per-layer attention list returned by ``model``.
        head: attention head to keep.
        bs: number of sentences per forward pass.

    Returns:
        attention_weights: one CPU tensor per sentence, taken as
            ``attn[layer]['attn_probs'][j][head][0]``. This is presumably the
            attention row of the [CLS] position over all ``max_seq_len`` positions
            (TODO confirm against the bertviz model).
        ids_to_decode: per-sentence token-id lists, padded with 0 to ``max_seq_len``.
        tokens_to_decode: per-sentence token lists including [CLS]/[SEP], unpadded.
        All three lists are empty when ``input_sentences`` is empty.
    """
    n_sentences = len(input_sentences)
    attention_weights = [None] * n_sentences
    ids_to_decode = [None] * n_sentences
    tokens_to_decode = [None] * n_sentences
    if n_sentences == 0:
        # torch.tensor([]) would be 1-D and cannot be fed to the model.
        return attention_weights, ids_to_decode, tokens_to_decode

    ## BERT pre-processing
    ids, segment_ids, input_masks = [], [], []
    for j, sen in enumerate(tqdm(input_sentences)):
        text_tokens = tokenizer.tokenize(sen)
        # Truncating to max_seq_len-4 guarantees every padded row ends with at least
        # one 0 (pad) id. Code that finds the end of a sentence with ids.index(0)
        # depends on that.
        if len(text_tokens) >= max_seq_len - 2:
            text_tokens = text_tokens[:max_seq_len - 4]
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        token_ids = tokenizer.convert_tokens_to_ids(tokens)
        n_pad = max_seq_len - len(token_ids)
        padded_ids = token_ids + [0] * n_pad

        tokens_to_decode[j] = tokens
        ids_to_decode[j] = padded_ids  # padded on purpose (see truncation note above)
        ids.append(padded_ids)
        input_masks.append([1] * len(token_ids) + [0] * n_pad)
        segment_ids.append([0] * max_seq_len)

    # Convert Ids to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)

    # Step by bs so no empty trailing batch is built when n_sentences % bs == 0.
    for start in trange(0, n_sentences, bs):
        end = start + bs
        batch_ids = ids[start:end].to(device)
        batch_segment_ids = segment_ids[start:end].to(device)
        batch_input_masks = input_masks[start:end].to(device)
        with torch.no_grad():
            _, _, attn = model(batch_ids, batch_segment_ids, batch_input_masks)
        layer_probs = attn[layer]['attn_probs']
        for offset in range(len(layer_probs)):
            attention_weights[start + offset] = layer_probs[offset][head][0].to('cpu')

    return attention_weights, ids_to_decode, tokens_to_decode

In [14]:
def prepare_data(aw, ids_to_decode, tokens_to_decode):
    """
    Remove the most-attended words from every sentence.

    For sentence ``i``:
      1. Rank its real (non-padding) positions by ``aw[i]``, highest first.
      2. Drop positions whose id is in ``not_to_remove_ids`` (special tokens,
         punctuation, common words) or whose token is a ``##`` word-piece
         continuation, so a deletion always starts at the beginning of a word.
      3. Keep the top 1 candidate if 1-3 remain, the top 2 if 0 or 4-7 remain,
         otherwise the top 3.
      4. Delete each chosen word together with the ``##`` pieces that follow it,
         then detokenize the remaining ids and strip [CLS]/[SEP].

    Args:
        aw: per-sentence attention weights (as returned by ``run_attn_examples``);
            each ``aw[i]`` is assumed to be a 1-D tensor indexed by token position.
        ids_to_decode: per-sentence token ids. The real tokens end at the first
            0 (padding) id, or at the end of the list when there is no padding.
        tokens_to_decode: per-sentence token strings aligned with the real ids.

    Returns:
        List of processed sentence strings, one per entry of ``aw``.
    """
    out_sen = [None] * len(aw)
    for i in trange(len(aw)):
        sen_ids = ids_to_decode[i]
        sen_tokens = tokens_to_decode[i]
        # Unpadded ids used to raise ValueError here (.index(0)).
        n_real = sen_ids.index(0) if 0 in sen_ids else len(sen_ids)

        # Rank only the real positions. Ranking the whole padded vector could pull in
        # a padding position, whose index is out of range for sen_tokens.
        _, ranked = aw[i][:n_real].topk(n_real)
        candidates = [
            pos for pos in ranked.tolist()
            if sen_ids[pos] not in not_to_remove_ids and "##" not in sen_tokens[pos]
        ]

        if 0 < len(candidates) < 4:
            n_remove = 1
        elif len(candidates) < 8:
            n_remove = 2
        else:
            n_remove = 3
        to_remove = set(candidates[:n_remove])

        kept_ids = []
        pos = 0
        while pos < n_real:
            if pos in to_remove:
                # Also skip the "##" continuation pieces of the removed word.
                while pos + 1 < n_real and "##" in sen_tokens[pos + 1]:
                    pos += 1
            else:
                kept_ids.append(sen_ids[pos])
            pos += 1

        words = tokenizer.convert_ids_to_tokens(kept_ids)
        text = " ".join(words).replace(" ##", "").replace("[CLS]", "").replace("[SEP]", "")
        out_sen[i] = text.strip()

    return out_sen

In [18]:
# Load every split as a list of sentences (one per line).
train_0_data, train_1_data = read_file(train_0), read_file(train_1)
dev_0_data, dev_1_data = read_file(dev_0), read_file(dev_1)
test_0_data, test_1_data = read_file(test_0), read_file(test_1)
# ref_0_data = read_file(reference_0)
# ref_1_data = read_file(reference_1)

In [22]:
# train_0 (trump.txt) split, tagged <NEG>; layer 0 / head 1 attention picks the words to drop.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    train_0_data, layer=0, head=1, bs=128
)
train_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=train_0_data,
    processed_sentences=train_0_out_sen,
    output_file=train_0_out,
    sentiment="<NEG>",
)


  0%|          | 0/13711 [00:00<?, ?it/s][A
  0%|          | 64/13711 [00:00<00:21, 639.32it/s][A
  1%|          | 134/13711 [00:00<00:20, 655.85it/s][A
  1%|▏         | 202/13711 [00:00<00:20, 658.65it/s][A
  2%|▏         | 278/13711 [00:00<00:19, 678.68it/s][A
  3%|▎         | 348/13711 [00:00<00:19, 677.71it/s][A
  3%|▎         | 420/13711 [00:00<00:19, 689.75it/s][A
  4%|▎         | 496/13711 [00:00<00:18, 704.36it/s][A
  4%|▍         | 577/13711 [00:00<00:17, 732.51it/s][A
  5%|▍         | 649/13711 [00:00<00:18, 725.16it/s][A
  5%|▌         | 720/13711 [00:01<00:18, 719.56it/s][A
  6%|▌         | 806/13711 [00:01<00:17, 756.49it/s][A
  6%|▋         | 882/13711 [00:01<00:17, 749.22it/s][A
  7%|▋         | 957/13711 [00:01<00:18, 695.03it/s][A
  7%|▋         | 1028/13711 [00:01<00:18, 696.93it/s][A
  8%|▊         | 1107/13711 [00:01<00:17, 722.31it/s][A
  9%|▊         | 1180/13711 [00:01<00:17, 721.76it/s][A
  9%|▉         | 1254/13711 [00:01<00:17, 724.36it/s][A

 75%|███████▌  | 10349/13711 [00:15<00:05, 647.23it/s][A
 76%|███████▌  | 10415/13711 [00:15<00:05, 583.11it/s][A
 76%|███████▋  | 10475/13711 [00:15<00:05, 566.94it/s][A
 77%|███████▋  | 10533/13711 [00:15<00:05, 565.96it/s][A
 77%|███████▋  | 10591/13711 [00:15<00:05, 524.45it/s][A
 78%|███████▊  | 10645/13711 [00:15<00:06, 493.15it/s][A
 78%|███████▊  | 10717/13711 [00:16<00:05, 543.25it/s][A
 79%|███████▉  | 10806/13711 [00:16<00:04, 614.06it/s][A
 79%|███████▉  | 10873/13711 [00:16<00:04, 605.91it/s][A
 80%|███████▉  | 10937/13711 [00:16<00:04, 599.64it/s][A
 80%|████████  | 11004/13711 [00:16<00:04, 616.05it/s][A
 81%|████████  | 11073/13711 [00:16<00:04, 635.28it/s][A
 81%|████████  | 11140/13711 [00:16<00:04, 641.47it/s][A
 82%|████████▏ | 11206/13711 [00:16<00:03, 645.11it/s][A
 82%|████████▏ | 11273/13711 [00:16<00:03, 651.26it/s][A
 83%|████████▎ | 11339/13711 [00:16<00:03, 628.32it/s][A
 83%|████████▎ | 11403/13711 [00:17<00:03, 624.00it/s][A
 84%|████████▎

KeyboardInterrupt: 

In [None]:
# train_1 (en.txt) split, tagged <POS>.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    train_1_data, layer=0, head=1, bs=128
)
train_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=train_1_data,
    processed_sentences=train_1_out_sen,
    output_file=train_1_out,
    sentiment="<POS>",
)

In [None]:
# dev_0 (trump.txt) split, tagged <NEG>.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    dev_0_data, layer=0, head=1, bs=128
)
dev_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=dev_0_data,
    processed_sentences=dev_0_out_sen,
    output_file=dev_0_out,
    sentiment="<NEG>",
)

In [None]:
# dev_1 (en.txt) split, tagged <POS>.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    dev_1_data, layer=0, head=1, bs=128
)
dev_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=dev_1_data,
    processed_sentences=dev_1_out_sen,
    output_file=dev_1_out,
    sentiment="<POS>",
)

In [None]:
# test_1 (en.txt) split, tagged <POS>.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    test_1_data, layer=0, head=1, bs=128
)
test_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=test_1_data,
    processed_sentences=test_1_out_sen,
    output_file=test_1_out,
    sentiment="<POS>",
)

In [None]:
# test_0 (trump.txt) split, tagged <NEG>.
aw, ids_to_decode, tokens_to_decode = run_attn_examples(
    test_0_data, layer=0, head=1, bs=128
)
test_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
create_output_file(
    original_sentences=test_0_data,
    processed_sentences=test_0_out_sen,
    output_file=test_0_out,
    sentiment="<NEG>",
)

In [21]:
# Needs ref_1_data, which exists only once `ref_1_data = read_file(reference_1)` is uncommented.
# aw, ids_to_decode, tokens_to_decode = run_attn_examples(ref_1_data, layer=0, head=1, bs=128)
# ref_1_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
# create_ref_output_file(ref_1_out_sen, reference_1_out, sentiment="<NEG>")

NameError: name 'ref_1_data' is not defined

In [None]:
# Needs ref_0_data, which exists only once `ref_0_data = read_file(reference_0)` is uncommented.
# aw, ids_to_decode, tokens_to_decode = run_attn_examples(ref_0_data, layer=0, head=1, bs=128)
# ref_0_out_sen = prepare_data(aw, ids_to_decode, tokens_to_decode)
# create_ref_output_file(ref_0_out_sen, reference_0_out, sentiment="<POS>")