In [48]:
from transformers import BertModel, BertTokenizer, BertConfig, BertForSequenceClassification
import os
import pandas as pd
import torch
import transformers
import time
import datetime
import numpy as np
import random
import os
from sklearn.metrics import classification_report
from IPython.core.display import display, HTML

In [49]:
# Set the seed value all over the place to make this reproducible.
seed_val = 42

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

output_path = '/home/rajat/repos/explanations/output'
tokenizer = BertTokenizer.from_pretrained(output_path)
model = BertForSequenceClassification.from_pretrained(output_path)

In [50]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
datasets = [
    {
        "name": "Stanford treebank",
        "prefix": "stanford_treebank",
        "train_path": "data/stanford_treebank/sst_train.csv",
        "dev_path": "data/stanford_treebank/sst_dev.csv",
        "test_path": "data/stanford_treebank/sst_test.csv",
        'classes': ['neg', 'pos']
    }
]

In [51]:
def read_data():
    train = pd.read_csv(datasets[0]['train_path'],index_col=0)
    dev = pd.read_csv(datasets[0]['dev_path'],index_col=0)
    test = pd.read_csv(datasets[0]['test_path'],index_col=0)
    return train, dev, test
train_data, dev, test = read_data()

In [52]:
def get_length_without_special_tokens(sentence):
    length = 0
    for i in sentence:
        if i == 0:
            break
        else:
            length += 1
    return length 

def get_rationales(rationale, text, tokenizer, max_length):
    new_rationales = []
    rationale = rationale[: max_length - 2]
    text = text.split()[: max_length - 2]
    for r, word in zip(rationale, text):
        r = float(r)
        ids = tokenizer.encode(word, add_special_tokens=False)
        new_rationales += [r] * len(ids)
    new_rationales = [0] + new_rationales + [0]
    new_rationales += [0] * (max_length - len(new_rationales))
    len(new_rationales)
    return new_rationales

def encode(data, rationales, tokenizer=BertTokenizer.from_pretrained('bert-base-uncased'), **kwargs):
    input_ids = []
    token_type_ids = []
    attention_mask = []
    new_rationales = []
    for rationale, x in zip(rationales, data):
        tokenized_x = tokenizer.encode_plus(x,
                                            max_length=128,
                                            add_special_tokens = True,
                                            pad_to_max_length=True,
                                            padding_side='right',
                                            return_token_type_ids=True,
                                            return_attention_mask=True,
                                            return_special_tokens_mask=True)
        new_rationales.append(get_rationales(rationale, x, tokenizer, 128))
        input_ids.append(tokenized_x['input_ids'])
        token_type_ids.append(tokenized_x['token_type_ids'])
        attention_mask.append(tokenized_x['attention_mask'])
    return torch.tensor(input_ids, dtype=torch.long), torch.tensor(token_type_ids, dtype=torch.long), torch.tensor(attention_mask, dtype=torch.long), torch.tensor(new_rationales, dtype=torch.float)

In [53]:
def get_batches(df, batch_size=4, **kwargs):
    x, y = list(df['text'].values), torch.tensor(list(df['classification'].apply(lambda y: 1 if y == 'pos' else 0).values), device=device, dtype=torch.long)
    formatted_rationales = []
    for rationale in df['rationale'].values:
        rationale = rationale[1:-1]
        temp = [float(r) for r in rationale.split(', ')] 
        formatted_rationales.append(temp)

    input_ids, token_type_ids, attention_mask, new_rationales = encode(x, formatted_rationales, **kwargs)
    tensor_dataset = torch.utils.data.TensorDataset(input_ids, token_type_ids, attention_mask, y, new_rationales)
    tensor_randomsampler = torch.utils.data.RandomSampler(tensor_dataset)
    tensor_dataloader = torch.utils.data.DataLoader(tensor_dataset, sampler=tensor_randomsampler, batch_size=batch_size)
    return tensor_dataloader

In [54]:
def train(batch, epochs=2):
    model.train()
    for e in range(epochs):
        for i, batch_tuple in enumerate(batch):
            batch_tuple = (t.to(device) for t in batch_tuple)
            input_ids, token_type_ids, attention_mask, labels, true_rationale = batch_tuple
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
            loss, logits, hidden_states_output, attention_mask_output = outputs
            true_rationale = torch.nn.functional.softmax(true_rationale, dim=1)
            attention_layer = attention_mask_output[11]  # Found from earlier exp of selecting attention layer, head
            loss_explanation = torch.nn.functional.binary_cross_entropy(attention_layer[:,9, 0], true_rationale)
            if i % 100 == 0:
                print("loss - {0}".format(loss))
            model.zero_grad()        
            optimizer.zero_grad()
            total_loss = loss + loss_explanation
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), parameters['max_grad_norm'])
            optimizer.step()
            scheduler.step()

In [55]:
batch_dev = get_batches(dev, batch_size=1, tokenizer=tokenizer)
batch_train = get_batches(train_data, batch_size=8, tokenizer=tokenizer)
batch_test = get_batches(test, batch_size=1, tokenizer=tokenizer)

epochs=3
parameters = {
    'learning_rate': 2e-5,
    'num_warmup_steps': 0,
    'num_training_steps': len(batch_train) * epochs,
    'max_grad_norm': 1
}
#model = BertModel.from_pretrained('bert-base-uncased', output_hidden_states=True, output_attentions=True)
model.to(device)
optimizer = transformers.AdamW(model.parameters(), lr=parameters['learning_rate'], correct_bias=False)
scheduler = transformers.get_linear_schedule_with_warmup(optimizer,
                                                         num_warmup_steps=parameters['num_warmup_steps'],
                                                         num_training_steps=parameters['num_training_steps'])


In [56]:
def evaluate(batch):
    results = []
    model.eval()
    for i, batch_cpu in enumerate(batch):
        batch_gpu = (t.to(device) for t in batch_cpu)
        input_ids, token_type_ids, attention_mask, labels, true_rationale = batch_gpu
        with torch.no_grad():
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=labels)
            loss, logits, hidden_states_output, attention_mask_output = outputs
            loss = loss.detach().cpu()
            logits =  logits.detach().cpu()
            labels_cpu = labels.detach().cpu()
            input_ids_cpu = input_ids.detach().cpu()
#             hidden_states_output = tuple(t.detach().cpu() for t in hidden_states_output)
#             attention_mask_output = tuple(t.detach().cpu() for t in attention_mask_output)
#             results.append(tuple([loss, logits, hidden_states_output, attention_mask_output]))
            results.append(tuple([loss, logits, labels_cpu, input_ids_cpu, None, None]))
    return results

In [57]:
train(batch_dev, epochs=epochs)

loss - 0.002777099609375
loss - 0.006067991256713867
loss - 0.0036361217498779297
loss - 0.0041658878326416016
loss - 0.0030498504638671875
loss - 0.002044200897216797
loss - 0.0037147998809814453
loss - 0.0003895759582519531
loss - 0.0015904903411865234
loss - 5.398853302001953
loss - 0.0014379024505615234
loss - 0.0008084774017333984
loss - 0.00067138671875
loss - 0.0001068115234375
loss - 0.00019311904907226562
loss - 0.00106048583984375
loss - 0.0003752708435058594
loss - 0.0004918575286865234
loss - 0.00020933151245117188
loss - 0.0002551078796386719
loss - 0.00018262863159179688
loss - 0.00014972686767578125
loss - 9.489059448242188e-05
loss - 9.012222290039062e-05
loss - 0.00010824203491210938
loss - 7.009506225585938e-05
loss - 5.6743621826171875e-05


In [58]:
def get_prediction(results):
    predictions = []
    true_labels = []
    for result in results:
        loss, logits, labels, input_ids, hidden_states_output, attention_mask_output = result
        prediction = torch.argmax(logits, dim=1).tolist()
        predictions += prediction
        true_labels += labels.tolist()
    return true_labels, predictions

In [59]:
results = evaluate(batch_dev)
y_true, y_pred = get_prediction(results)
target_names = ["negative", "positive"]
classification = classification_report(y_true, y_pred, target_names=target_names, digits=4)
print(classification)

              precision    recall  f1-score   support

    negative     1.0000    1.0000    1.0000       428
    positive     1.0000    1.0000    1.0000       444

    accuracy                         1.0000       872
   macro avg     1.0000    1.0000    1.0000       872
weighted avg     1.0000    1.0000    1.0000       872



In [60]:
results = evaluate(batch_test)
y_true, y_pred = get_prediction(results)
target_names = ["negative", "positive"]
classification = classification_report(y_true, y_pred, target_names=target_names, digits=4)
print(classification)

              precision    recall  f1-score   support

    negative     0.8707    0.9375    0.9029       912
    positive     0.9321    0.8603    0.8947       909

    accuracy                         0.8990      1821
   macro avg     0.9014    0.8989    0.8988      1821
weighted avg     0.9013    0.8990    0.8988      1821



In [63]:
def save():
    output_dir = './outputs_with_rationales/'
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    print("Saving model to {0}", output_dir)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)

In [64]:
save()

Saving model to {0} ./outputs_with_rationales/


In [65]:
def get_label_string(label_softmax):
    pred = torch.argmax(label_softmax, dim=0).item()
    if pred == 1:
        return 'positive'
    else:
        return 'negative'

def print_attention(attention_batch, input_ids_batch, labels_batch, pred_labels_batch, threshold=0.1):
    html_batch = []
    for input_ids, attention, true_label, pred_label in zip(input_ids_batch, attention_batch, labels_batch, pred_labels_batch):
        html = []
        len_input_ids = get_length_without_special_tokens(input_ids)
        input_ids = input_ids[:len_input_ids]
        attention = attention[:len_input_ids]
        pred_label_string = get_label_string(pred_label)
        true_label_string = get_label_string(true_label)
        for input_id, attention_value in zip(input_ids, attention):
            token = tokenizer.convert_ids_to_tokens(input_id.item())
            attention_value = attention_value.item()
            html.append('<span style="background-color: rgb(255,255,0,{0})">{1}</span>'.format(attention_value, token))
        if pred_label_string != true_label_string:
            html.append('<span><b>(pred - {0}, true - {1})</b></span>'.format(pred_label_string, true_label_string))
        html_string = " ".join(html)
        display(HTML(html_string))
        html_batch.append(html_string)
    return html_batch
def print_attention_for_layer_head(model, batch, layer_idx, head_idx):
    model.eval()
    for batch_cpu in batch:
        with torch.no_grad():
            batch_gpu = (t.to(device) for t in batch_cpu)
            input_ids, token_type_ids, attention_mask, labels = batch_gpu
            outputs = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
            pred = outputs[0].cpu()
            softmax = torch.nn.Softmax(dim=-1)(pred)
            attention_batch = outputs[-1][layer_idx][:, head_idx, 0] # cls attention of the layer, head
            print_attention(attention_batch.cpu(), input_ids.cpu(), labels.cpu(), softmax.cpu())


## Original Rationales

In [66]:
for batch in batch_test:
    batch = tuple(b.cpu() for b in batch)
    input_ids, token_type_ids, attention_mask, y, new_rationales = batch
    print_attention(new_rationales, input_ids, y, y)
