In [1]:
from transformers import BertModel, BertTokenizer, BertConfig, BertForSequenceClassification
import os
import pandas as pd
import torch
import transformers
import time
import datetime
import numpy as np
import random
import os
from IPython.core.display import display, HTML

In [2]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
datasets = [
    {
        "name": "Stanford treebank",
        "prefix": "stanford_treebank",
        "train_path": "/data/sam/stanford_treebank/sst_train.csv",
        "dev_path": "/data/sam/stanford_treebank/sst_dev.csv",
        "test_path": "/data/sam/stanford_treebank/sst_test.csv",
        'classes': ['neg', 'pos']
    }
]

In [3]:
tokenizer = BertTokenizer.from_pretrained('/home/rajat/repos/practice/outputs_stanford_treebank_lr=2e-05_epochs=2')
model = BertForSequenceClassification.from_pretrained('/home/rajat/repos/practice/outputs_stanford_treebank_lr=2e-05_epochs=2')
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [4]:
def encode(data, tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'), **kwargs):
    input_ids = []
    token_type_ids = []
    attention_mask = []
    for x in data:
        tokenized_x = tokenizer.encode_plus(x,
                                            max_length=128,
                                            add_special_tokens = True,
                                            pad_to_max_length=True,
                                            padding_side='right',
                                            return_token_type_ids=True,
                                            return_attention_mask=True)
        input_ids.append(tokenized_x['input_ids'])
        token_type_ids.append(tokenized_x['token_type_ids'])
        attention_mask.append(tokenized_x['attention_mask'])
    return torch.tensor(input_ids, dtype=torch.long), torch.tensor(token_type_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.long)

def read_data():
    train = pd.read_csv(datasets[0]['train_path'],index_col=0)
    dev = pd.read_csv(datasets[0]['dev_path'],index_col=0)[0:100]
    test = pd.read_csv(datasets[0]['test_path'],index_col=0)[0:100]
    return train, dev, test

def get_batches(df, batch_size=4, **kwargs):
    x, y = list(df['original_text'].values), torch.tensor(list(df['classification'].apply(lambda y: 1 if y == 'pos' else 0).values), dtype=torch.long)
    input_ids, token_type_ids, attention_mask = encode(x, **kwargs)
    tensor_dataset = torch.utils.data.TensorDataset(input_ids, token_type_ids, attention_mask, y)
    tensor_randomsampler = torch.utils.data.RandomSampler(tensor_dataset)
    tensor_dataloader = torch.utils.data.DataLoader(tensor_dataset, sampler=tensor_randomsampler, batch_size=batch_size)
    return tensor_dataloader

train_data, dev, test = read_data()
batch_dev = get_batches(dev, batch_size=2, tokenizer=tokenizer)
batch_test = get_batches(test, batch_size=2, tokenizer=tokenizer)

In [5]:
def get_new_sentence_tuple(sentence, top_idx, max_length, ys):
    token_ids = sentence
    new_token_ids = []
    new_attention_masks = []
    new_token_type_ids = []
    new_y = ys 
    for token_idx, token_id in enumerate(token_ids):
        if token_id == 0:
            break
        elif token_idx not in top_idx:
            new_token_ids.append(token_id.item())
            new_attention_masks.append(1)
            new_token_type_ids.append(1)
    
    padding_tokens = max_length - len(new_token_ids)
    new_token_ids += padding_tokens * [0]
    new_attention_masks += padding_tokens * [0]
    new_token_type_ids += padding_tokens * [0]
    return (new_token_ids, new_token_type_ids, new_attention_masks, new_y)

def get_score(model, batch):
    model.eval()
    with torch.no_grad():
        batch_gpu = (t.to(device) for t in batch)
        input_ids, token_type_ids, attention_mask, labels = batch_gpu
        outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        pred = outputs[0]
        softmax = torch.nn.Softmax(dim=-1)(pred)
    return softmax

def get_length_without_special_tokens(sentence):
    length = 0
    for i in sentence:
        if i == 0:
            break
        else:
            length += 1
    return length 

def get_result_per_batch(batch_cpu, attention_cpu, tokenizer, threshold, ys, lmbd=0.25):
    """
    batch_cpu: batch_size * max_len
    attention_cpu(list): n_layers * batch_size, heads, max_len, max_len
    """
    scores = np.zeros((12, 12, 1))
    for layer_idx, layer in enumerate(attention_cpu):
        for head_idx in range(len(layer[0])):
            token_ids, token_type_ids, attention_masks = [], [], []
            for sent_idx, sentence in enumerate(batch_cpu):
                sent_len = get_length_without_special_tokens(sentence) # token_ids
                _, top_idx = layer[sent_idx][head_idx][0].topk(int(sent_len * threshold))
                new_token_ids, new_token_type_ids, new_attention_masks, new_y = get_new_sentence_tuple(sentence, top_idx, 128, ys) # max_length
                token_ids.append(new_token_ids)
                token_type_ids.append(new_token_type_ids)
                attention_masks.append(new_attention_masks)
            tensor_dataset = torch.utils.data.TensorDataset(torch.tensor(token_ids), torch.tensor(token_type_ids), torch.tensor(attention_masks), ys)
            new_batch = next(iter(torch.utils.data.DataLoader(tensor_dataset, batch_size=2)))
            pred = get_score(model, new_batch)
            score = np.mean([(min(x[0].item(), x[1].item())+lmbd)/(max(x[0].item(), x[1].item())+lmbd) for x in pred.cpu()])
            scores[layer_idx][head_idx] = score
        #print("processing layer: {0}".format(layer_idx))
    return scores

def get_results(model, batch):
    attn_lt = None
    model.eval()
    results = np.zeros((12, 12, len(batch))) # layers, head, sentence_score
    for batch_idx, batch_cpu in enumerate(batch):
        with torch.no_grad():
            batch_gpu = (t.to(device) for t in batch_cpu)
            input_ids, token_type_ids, attention_mask, labels = batch_gpu
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        attention = outputs[-1]  # 12 * (batch_size,12,128,128)
        attention_cpu = list(t.to("cpu") for t in attention)
        result_per_batch = get_result_per_batch(batch_cpu[0], attention_cpu, tokenizer, 0.25, batch_cpu[3].cpu()) # layer, head, sentences
        for layer in range(12):
            for head in range(12):
                results[layer][head][batch_idx] = result_per_batch[layer][head][0]
        if batch_idx % 5 == 0:
            print("processing batch: {0}".format(batch_idx))
    return torch.tensor(results) 

def get_average_results_per_layer_head(results):
    return torch.mean(results, dim=-1)

def get_best_layer_head_pair(layer_head_mean):
    head_value, head_indices = torch.max(layer_head_mean, dim=1)   
    layer, layer_index = torch.max(head_value, dim=0)
    layer_index = layer_index.item()
    head_index = head_indices[layer_index].item()
    return layer_index, head_index

def get_label_string(label_softmax):
    pred = torch.argmax(label_softmax, dim=0).item()
    if pred == 1:
        return 'positive'
    else:
        return 'negative'

def get_label_string_2(label):
    if label == 1:
        return 'positive'
    else:
        return 'negative'

def print_attention(attention_batch, input_ids_batch, labels_batch, pred_labels_batch, threshold=0.1):
    html_batch = []
    for input_ids, attention, true_label, pred_label in zip(input_ids_batch, attention_batch, labels_batch, pred_labels_batch):
        html = []
        len_input_ids = get_length_without_special_tokens(input_ids)
        input_ids = input_ids[:len_input_ids]
        attention = attention[:len_input_ids]
        pred_label_string = get_label_string(pred_label)
        true_label_string = get_label_string_2(true_label)
        for input_id, attention_value in zip(input_ids, attention):
            token = tokenizer.convert_ids_to_tokens(input_id.item())
            attention_value = attention_value.item()
            html.append('<span style="background-color: rgb(255,255,0,{0})">{1}</span>'.format(10 * attention_value, token))
        if pred_label_string != true_label_string:
            html.append('<span><b>(pred - {0}, true - {1})</b></span>'.format(pred_label_string, true_label_string))
        html_string = " ".join(html)
        display(HTML(html_string))
        html_batch.append(html_string)
    return html_batch

def print_attention_for_layer_head(model, batch, layer_idx, head_idx):
    model.eval()
    for batch_cpu in batch:
        with torch.no_grad():
            batch_gpu = (t.to(device) for t in batch_cpu)
            input_ids, token_type_ids, attention_mask, labels = batch_gpu
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            pred = outputs[0].cpu()
            softmax = torch.nn.Softmax(dim=-1)(pred)
            attention_batch = outputs[-1][layer_idx][:, head_idx, 0] # cls attention of the layer, head
            print_attention(attention_batch.cpu(), input_ids.cpu(), labels.cpu(), softmax.cpu())


In [6]:
res = get_results(model, batch_dev)

processing batch: 0
processing batch: 5
processing batch: 10
processing batch: 15
processing batch: 20
processing batch: 25
processing batch: 30
processing batch: 35
processing batch: 40
processing batch: 45


In [7]:
layer_head_mean = get_average_results_per_layer_head(res)

In [8]:
layer_idx, head_idx = get_best_layer_head_pair(layer_head_mean)

In [9]:
print_attention_for_layer_head(model, batch_test, layer_idx, head_idx)

In [10]:
layer_idx, head_idx

(11, 9)