In [2]:
import pandas as pd
import spacy
import pprint as pprint
import numpy as np
import os
import random
import torch
import math
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [4]:
device

device(type='cuda')

In [31]:
PROCESSED_DATA_DIR = "processed_data"
MODEL_NAME = "BERT_Attention"
model = BertModel.from_pretrained('bert-base-uncased').to(device) 
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
nlp = spacy.load('en')

In [32]:
def get_bert_embeddings(passage_text):
    passage_text_processed = nlp(passage_text)
    passage_with_separators = ' '.join(['[CLS]'] + [sent.text + ' [SEP]' for sent in passage_text_processed.sents])
    passage_with_separators_tokenized = tokenizer.tokenize(passage_with_separators)    
    model.eval()
    indexed_tokens = tokenizer.convert_tokens_to_ids(passage_with_separators_tokenized)
    tokens_tensor = torch.tensor([indexed_tokens]).to(device)

    with torch.no_grad():
        encoded_layers, _ = model(tokens_tensor)

    indices = [i for i, w in enumerate(passage_with_separators_tokenized) if (w not in ['[CLS]', '[SEP]'])]
    nonseparator_tokens = [w for i, w in enumerate(passage_with_separators_tokenized) if (w not in ['[CLS]', '[SEP]'])]
    nonseparators = torch.squeeze(encoded_layers[-1])[indices][:]

    attn_vectors_per_word = []
    encountered_words = []
    i = 0
    carry_over = None
    had_carry_over = False
    
    for w_i, word in enumerate(passage_text_processed):
        word = word.text.lower()
        first_attention_vector = nonseparators[i]
        current_word = ''
        if word == ' ':
            attn_vectors_per_word.append(first_attention_vector)
            continue
        if carry_over:
            current_word = carry_over
            carry_over = None
        while current_word[:len(word)] != word:
            current_token = nonseparator_tokens[i]
            current_word += (current_token if (current_token[:2] != '##') else current_token[2:])
            i += 1
        encountered_words.append(current_word)
        if not had_carry_over:
            attn_vectors_per_word.append(first_attention_vector)
        else:
            had_carry_over = False
        if len(current_word) > len(word):
            attn_vectors_per_word.append(first_attention_vector)
            carry_over = current_word[len(word):]
            had_carry_over = True
    #output = torch.stack(attn_vectors_per_word)
    assert len([word for word in passage_text_processed]) == len(attn_vectors_per_word)
    return attn_vectors_per_word

In [33]:
train_passages = np.load("processed_data/train_passage_list.npy")

In [34]:
train_embeddings = []
for i,passage in enumerate(train_passages):
    print("processing doc", i)
    if(i in long_passages_and_exceptions):
        continue
    embeddings = get_bert_embeddings(str(passage))
    train_embeddings.append(embeddings)

processing doc 0
processing doc 1
processing doc 2
processing doc 3
processing doc 4
processing doc 5
processing doc 6
processing doc 7
processing doc 8
processing doc 9
processing doc 10
processing doc 11
processing doc 12
processing doc 13
processing doc 14
processing doc 15
processing doc 16
processing doc 17
processing doc 18
processing doc 19
processing doc 20
processing doc 21
processing doc 22
processing doc 23
processing doc 24
processing doc 25
processing doc 26
processing doc 27
processing doc 28
processing doc 29
processing doc 30
processing doc 31
processing doc 32
processing doc 33
processing doc 34
processing doc 35
processing doc 36
processing doc 37
processing doc 38
processing doc 39
processing doc 40
processing doc 41
processing doc 42
processing doc 43
processing doc 44
processing doc 45
processing doc 46
processing doc 47
processing doc 48
processing doc 49
processing doc 50
processing doc 51
processing doc 52
processing doc 53
processing doc 54
processing doc 55
pr

processing doc 440
processing doc 441
processing doc 442
processing doc 443
processing doc 444
processing doc 445
processing doc 446
processing doc 447
processing doc 448
processing doc 449
processing doc 450
processing doc 451
processing doc 452
processing doc 453
processing doc 454
processing doc 455
processing doc 456
processing doc 457
processing doc 458
processing doc 459
processing doc 460
processing doc 461
processing doc 462
processing doc 463
processing doc 464
processing doc 465
processing doc 466
processing doc 467
processing doc 468
processing doc 469
processing doc 470
processing doc 471
processing doc 472
processing doc 473
processing doc 474
processing doc 475
processing doc 476
processing doc 477
processing doc 478
processing doc 479
processing doc 480
processing doc 481
processing doc 482
processing doc 483
processing doc 484
processing doc 485
processing doc 486
processing doc 487
processing doc 488
processing doc 489
processing doc 490
processing doc 491
processing d

processing doc 873
processing doc 874
processing doc 875
processing doc 876
processing doc 877
processing doc 878
processing doc 879
processing doc 880
processing doc 881
processing doc 882
processing doc 883
processing doc 884
processing doc 885
processing doc 886
processing doc 887
processing doc 888
processing doc 889
processing doc 890
processing doc 891
processing doc 892
processing doc 893
processing doc 894
processing doc 895
processing doc 896
processing doc 897
processing doc 898
processing doc 899
processing doc 900
processing doc 901
processing doc 902
processing doc 903
processing doc 904
processing doc 905
processing doc 906
processing doc 907
processing doc 908
processing doc 909
processing doc 910
processing doc 911
processing doc 912
processing doc 913
processing doc 914
processing doc 915
processing doc 916
processing doc 917
processing doc 918
processing doc 919
processing doc 920
processing doc 921
processing doc 922
processing doc 923
processing doc 924
processing d

In [35]:
test_passages = np.load("processed_data/test_passage_list.npy")

In [37]:
test_embeddings = []
for i,passage in enumerate(test_passages):
    print("processing doc", i)
    embeddings = get_bert_embeddings(str(passage))
    test_embeddings.append(embeddings)

processing doc 0
processing doc 1
processing doc 2
processing doc 3
processing doc 4
processing doc 5
processing doc 6
processing doc 7
processing doc 8
processing doc 9
processing doc 10
processing doc 11
processing doc 12
processing doc 13
processing doc 14
processing doc 15
processing doc 16
processing doc 17
processing doc 18
processing doc 19
processing doc 20
processing doc 21
processing doc 22
processing doc 23
processing doc 24
processing doc 25
processing doc 26
processing doc 27
processing doc 28
processing doc 29
processing doc 30
processing doc 31
processing doc 32
processing doc 33
processing doc 34
processing doc 35
processing doc 36
processing doc 37
processing doc 38
processing doc 39
processing doc 40
processing doc 41
processing doc 42
processing doc 43
processing doc 44
processing doc 45
processing doc 46
processing doc 47
processing doc 48
processing doc 49
processing doc 50
processing doc 51
processing doc 52
processing doc 53
processing doc 54
processing doc 55
pr

In [39]:
print(len(train_embeddings), len(test_embeddings), len(train_embeddings[0][0]), len(test_embeddings[0][0]))

1189 133 768 768


In [24]:
training_positive_pairs = np.load("./processed_data/train_positive_pairs.npy")
training_negative_pairs = np.load("./processed_data/train_positive_pairs.npy")
test_positive_pairs = np.load("./processed_data/test_positive_pairs.npy")
test_negative_pairs = np.load("./processed_data/test_positive_pairs.npy")

In [25]:
X = []
Y = []
X_test = []
Y_test = []

In [26]:
for i, doc_pairs in enumerate(training_positive_pairs):
    for coref_pair in doc_pairs:
        sample = [i, coref_pair[0][0], coref_pair[0][1], coref_pair[1][0], coref_pair[1][1]]
        X.append(sample)
        Y.append([1])
for i, doc_pairs in enumerate(training_negative_pairs):
    for coref_pair in doc_pairs:
        sample = [i, coref_pair[0][0], coref_pair[0][1], coref_pair[1][0], coref_pair[1][1]]
        X.append(sample)
        Y.append([0])

In [27]:
for i, doc_pairs in enumerate(test_positive_pairs):
    for coref_pair in doc_pairs:
        sample = [i, coref_pair[0][0], coref_pair[0][1], coref_pair[1][0], coref_pair[1][1]]
        X_test.append(sample)
        Y_test.append([1])
for i, doc_pairs in enumerate(test_negative_pairs):
    for coref_pair in doc_pairs:
        sample = [i, coref_pair[0][0], coref_pair[0][1], coref_pair[1][0], coref_pair[1][1]]
        X_test.append(sample)
        Y_test.append([0])

In [28]:
shuffled_idx = list(np.random.permutation(len(X)))
X = [X[i] for i in shuffled_idx]
Y = [Y[i] for i in shuffled_idx]

In [29]:
shuffled_idx = list(np.random.permutation(len(X_test)))
X_test = [X_test[i] for i in shuffled_idx]
Y_test = [Y_test[i] for i in shuffled_idx]