In [1]:
import pandas as pd
import spacy
import pprint as pprint
import numpy as np
import os
import random
import torch
import math
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [2]:
# Select the compute device once: CUDA when torch reports it as available,
# otherwise the CPU. Later cells move the model and tensors with .to(device).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Echo the selected device so the notebook records where the run executed.
device

device(type='cuda')

In [4]:
# Run-level constants. Neither is referenced in the cells shown here;
# presumably they are used later for data paths / naming -- confirm.
PROCESSED_DATA_DIR = "processed_data"
MODEL_NAME = "BERT_Attention"
# Pretrained 'bert-base-uncased' encoder, moved to `device`, and the tokenizer
# loaded from the same checkpoint name so token ids match the encoder.
model = BertModel.from_pretrained('bert-base-uncased').to(device) 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# spaCy pipeline used by get_bert_embeddings for sentence and token splitting.
# NOTE(review): the 'en' shortcut name is presumably only valid on older spaCy
# releases -- confirm against the pinned spaCy version.
nlp = spacy.load('en')

In [5]:
def get_bert_embeddings(passage_text):
    """Return one BERT vector per spaCy token of ``passage_text``.

    The passage is split into sentences with the module-level spaCy pipeline
    ``nlp``, rewritten as ``[CLS] sent [SEP] sent [SEP] ...``, tokenized with
    ``tokenizer`` and passed through ``model`` in eval mode with gradients
    disabled. The ``[CLS]``/``[SEP]`` positions are dropped and the remaining
    sub-word vectors are aligned back to the spaCy tokens: each spaCy token
    gets the vector of the sub-word token at which its spelling starts.

    Args:
        passage_text: raw passage string.

    Returns:
        Tensor stacking one vector per spaCy token, in token order. Rows are
        taken from ``encoded_layers[-1]`` (presumably the final encoder
        layer -- confirm against the BertModel return convention).

    Raises:
        IndexError: if the sub-word stream is exhausted before a spaCy token
            has been spelled out (the two tokenizations cannot be aligned).
        AssertionError: if the number of collected vectors differs from the
            number of spaCy tokens.
    """
    passage_text_processed = nlp(passage_text)
    passage_with_separators = ' '.join(
        ['[CLS]'] + [sent.text + ' [SEP]' for sent in passage_text_processed.sents])
    passage_with_separators_tokenized = tokenizer.tokenize(passage_with_separators)
    model.eval()
    indexed_tokens = tokenizer.convert_tokens_to_ids(passage_with_separators_tokenized)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)

    with torch.no_grad():
        encoded_layers, _ = model(tokens_tensor)

    # Positions (and surface forms) of every token that is not a separator.
    indices = [i for i, w in enumerate(passage_with_separators_tokenized)
               if w not in ('[CLS]', '[SEP]')]
    nonseparator_tokens = [passage_with_separators_tokenized[i] for i in indices]
    nonseparators = torch.squeeze(encoded_layers[-1])[indices]

    attn_vectors_per_word = []
    i = 0                    # index of the next unread sub-word token
    carry_over = None        # text left over when a sub-word token spans two spaCy tokens
    had_carry_over = False   # True when the current token's vector was already appended
    last_index = len(nonseparator_tokens) - 1

    for token in passage_text_processed:
        word = token.text.lower()
        # Clamp the lookup: when the final spaCy token is trailing whitespace
        # or is spelled entirely by carry-over text, every sub-word token has
        # already been consumed and `i` is one past the end.
        first_attention_vector = nonseparators[min(i, last_index)]
        # The tokenizer emits nothing for whitespace, so any whitespace-only
        # spaCy token (' ', '\n', '  ', ...) just reuses the upcoming vector.
        if word.isspace():
            attn_vectors_per_word.append(first_attention_vector)
            continue
        current_word = ''
        if carry_over:
            current_word = carry_over
            carry_over = None
        # Consume sub-word tokens until they spell out the spaCy token.
        while current_word[:len(word)] != word:
            current_token = nonseparator_tokens[i]
            current_word += current_token[2:] if current_token[:2] == '##' else current_token
            i += 1
        if not had_carry_over:
            attn_vectors_per_word.append(first_attention_vector)
        else:
            had_carry_over = False
        if len(current_word) > len(word):
            # The sub-word token runs into the next spaCy token: give that
            # token the same vector now and remember the leftover text.
            attn_vectors_per_word.append(first_attention_vector)
            carry_over = current_word[len(word):]
            had_carry_over = True
    output = torch.stack(attn_vectors_per_word)
    assert len(passage_text_processed) == len(attn_vectors_per_word)
    return output

In [6]:
def get_featues_from_pairs(X_batch, embeddings):
    """Build one feature vector per span pair in ``X_batch``.

    Every row of ``X_batch`` unpacks to
    ``(doc_id, a_start, a_end, b_start, b_end)``. The two span embeddings are
    the sums of the rows of ``embeddings[doc_id]`` over the inclusive ranges
    ``a_start..a_end`` and ``b_start..b_end``. The feature for the pair is
    the concatenation ``[emb_a, emb_b, emb_a * emb_b]``.

    Args:
        X_batch: iterable of 5-element rows as described above.
        embeddings: container indexed by ``doc_id``; each item is sliced along
            its first axis and summed over it (assumes one row per token --
            confirm against how the embeddings are stored).

    Returns:
        Tensor with one row per pair; each row is three span-embedding widths
        long.
    """
    batch_embeddings = []
    for doc_id, a_start, a_end, b_start, b_end in X_batch:
        doc_emb = embeddings[doc_id]
        emb_a = torch.sum(doc_emb[a_start:a_end + 1], 0)
        emb_b = torch.sum(doc_emb[b_start:b_end + 1], 0)
        # Interaction term: element-wise product of the two spans. It has to
        # combine a with b; multiplying a by itself carries no pair signal.
        emb_prod = torch.mul(emb_a, emb_b)
        batch_embeddings.append(torch.cat((emb_a, emb_b, emb_prod), 0))
    return torch.stack(batch_embeddings)

In [7]:
def train_epoch(model, opt, criterion, batch_size, X_data, Y_data, embeddings, mode="train", threshold=0.7):
    """Run one shuffled pass over the data, training or evaluating ``model``.

    Args:
        model: module mapping a feature batch to scores.
        opt: optimizer; only used when ``mode == "train"`` (may be ``None``
            otherwise).
        criterion: loss called as ``criterion(y_hat, y_batch)``.
        batch_size: target minibatch size; the data is split into
            ``max(1, len(X_data) // batch_size)`` nearly equal minibatches.
        X_data: array of span-pair rows, indexable by an integer index array.
        Y_data: array of labels, indexable the same way.
            NOTE(review): labels are compared with ``y_hat`` element-wise, so
            each label batch is assumed to have the same shape as the model
            output -- confirm against the caller.
        embeddings: per-document embeddings passed to
            ``get_featues_from_pairs``.
        mode: ``"train"`` to update the model; any other value evaluates.
        threshold: scores strictly above this value count as class 1.

    Returns:
        ``(avg_loss, accuracy)``: mean of the per-minibatch losses and the
        fraction of correct thresholded predictions.

    Raises:
        ValueError: if ``X_data`` is empty.
    """
    is_train = (mode == "train")
    if is_train:
        model.train()
    else:
        model.eval()

    n_samples = len(X_data)
    if n_samples == 0:
        raise ValueError("train_epoch requires at least one sample")

    losses = []
    running_corrects = 0
    ones = 0
    zeros = 0

    shuffled_idx = np.random.permutation(n_samples)
    # np.array_split needs a positive integer section count; plain division
    # yields a float that truncates to 0 when n_samples < batch_size.
    n_batches = max(1, n_samples // batch_size)

    for minibatch_ids in np.array_split(shuffled_idx, n_batches):
        x_batch = get_featues_from_pairs(X_data[minibatch_ids], embeddings).to(device)
        y_batch = torch.as_tensor(Y_data[minibatch_ids], dtype=torch.float32).to(device)

        if is_train:
            opt.zero_grad()

        # Track gradients only when the parameters will be updated.
        with torch.set_grad_enabled(is_train):
            y_hat = model(x_batch)
            loss = criterion(y_hat, y_batch)

        y_preds = (y_hat > threshold).type(torch.float32)
        running_corrects += float(torch.sum(y_preds == y_batch).item())
        ones += torch.sum(y_preds == 1).item()
        zeros += torch.sum(y_preds == 0).item()

        if is_train:
            loss.backward()
            opt.step()
        losses.append(loss.item())

    print("ones and zeros", ones, zeros)
    accuracy = running_corrects * 1.0 / n_samples
    avg_loss = sum(losses) * 1.0 / len(losses)
    return avg_loss, accuracy

In [8]:
def predict(model, X, embeddings, batch_size=256, threshold=0.7):
    """Score every row of ``X`` with ``model`` in eval mode, without gradients.

    Args:
        model: module mapping a feature batch to scores.
        X: array of span-pair rows, indexable by an integer index array.
        embeddings: per-document embeddings passed to
            ``get_featues_from_pairs``.
        batch_size: target minibatch size; ``X`` is split in order into
            ``max(1, len(X) // batch_size)`` nearly equal minibatches.
        threshold: scores strictly above this value are predicted as 1.

    Returns:
        ``(y_preds_all, y_hats_all)``: the thresholded float32 predictions and
        the raw model outputs, concatenated over minibatches in input order.
        Both are empty tensors when ``X`` is empty.
    """
    model.eval()
    n_samples = len(X)
    if n_samples == 0:
        return torch.Tensor().to(device), torch.Tensor().to(device)

    # Integer section count; at least one so short inputs form a single batch.
    n_batches = max(1, n_samples // batch_size)

    pred_chunks = []
    hat_chunks = []
    for minibatch_ids in np.array_split(np.arange(n_samples), n_batches):
        x_batch = get_featues_from_pairs(X[minibatch_ids], embeddings).to(device)
        with torch.no_grad():
            y_hats = model(x_batch)
        hat_chunks.append(y_hats)
        pred_chunks.append((y_hats > threshold).type(torch.float32))

    # Scores and predictions are accumulated separately so the returned
    # scores contain only model outputs.
    return torch.cat(pred_chunks), torch.cat(hat_chunks)

In [9]:
def evaluate(model, X, Y, embeddings):
    """Compute the accuracy of ``model`` on ``X`` against the labels ``Y``.

    Args:
        model: module handed to ``predict``.
        X: inputs handed to ``predict``.
        Y: labels; converted to a float32 tensor on ``device`` and compared
            element-wise with the predictions.
        embeddings: per-document embeddings handed to ``predict``.

    Returns:
        ``(accuracy, y_preds, y_hats)`` where the last two are the tensors
        returned by ``predict`` and ``accuracy`` is the number of matching
        entries divided by the first dimension of the label tensor.
    """
    predictions, scores = predict(model, X, embeddings)
    targets = Variable(torch.tensor(Y).type(torch.float32)).to(device)
    n_matching = float(torch.sum(predictions == targets).item())
    n_labels = targets.size()[0]
    accuracy = n_matching * 1.0 / n_labels
    return accuracy, predictions, scores

In [10]:
class Net(nn.Module):
    """Three-layer feed-forward scorer over span-pair features.

    Layout: Linear(768*3 -> 512) -> ReLU -> Dropout(0.4) ->
    Linear(512 -> 64) -> ReLU -> Linear(64 -> 1) -> Sigmoid.
    """

    def __init__(self):
        """Create the layers; attribute names double as state-dict keys."""
        super().__init__()
        # First hidden block, with dropout applied to its activations.
        self.fc1 = nn.Linear(768*3, 512)
        self.relu1 = nn.ReLU()
        self.dout = nn.Dropout(0.4)
        # Second hidden block.
        self.fc2 = nn.Linear(512, 64)
        self.relu2 = nn.ReLU()
        # Single-unit output squashed into (0, 1).
        self.out = nn.Linear(64, 1)
        self.out_act = nn.Sigmoid()

    def forward(self, input_):
        """Map a batch of feature rows to one sigmoid score per row."""
        hidden = self.dout(self.relu1(self.fc1(input_)))
        hidden = self.relu2(self.fc2(hidden))
        return self.out_act(self.out(hidden))