In [0]:
from google.colab import drive
drive.mount('/content/drive')

In [0]:
import pickle
from collections import Counter
from tqdm import tqdm
import itertools
import pandas as pd
from itertools import islice
import numpy as np
from keras.preprocessing.sequence import pad_sequences
import random
!pip install transformers
import torch
from transformers import BertTokenizer, BertModel, BertConfig
from torch.nn import CrossEntropyLoss
from transformers import AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler

# Setting device on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

torch.backends.cudnn.deterministic = True

# Set the random seed manually for reproducibility.
torch.manual_seed(1234)

Using TensorFlow backend.


Using device: cuda

Tesla P100-PCIE-16GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


<torch._C.Generator at 0x7f6dfda399f0>

In [0]:
path = "drive/My Drive/FiQA/"

In [0]:
from evaluate import *

In [0]:
def take(n, iterable):
    "Return first n items of the iterable as a list"
    return list(islice(iterable, n))

def remove_empty(test_set):
    for index, row in enumerate(test_set):
        for doc in row[1]:
            if doc in empty_docs:
                del test_set[index]
    return test_set

def load_pickle(path):
    with open(path, 'rb') as f:
        return pickle.load(f)

def save_pickle(path, data):
    with open(path, 'wb') as handle:
        pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

def pad_seq(seq, max_seq_len):
    # Pad each seq to be the same length to process in batch.
    # pad_token = 0
    if len(seq) >= max_seq_len:
        seq = seq[:max_seq_len]
    else:
        seq += [0]*(max_seq_len - len(seq))
    return seq

In [0]:
# dict mapping of token to idx
vocab = load_pickle(path + 'vocab_full.pickle')
# dict mapping of docid to doc text
docid_to_text = load_pickle(path + 'label_ans.pickle')

# dict mapping of qid to question text
qid_to_text = load_pickle(path + 'qid_text.pickle')

train_qid_rel = load_pickle(path + "qid_rel_train.pickle")
test_qid_rel = load_pickle(path + "qid_rel_test.pickle")
valid_qid_rel = load_pickle(path + "qid_rel_valid.pickle")

train_set = load_pickle(path + 'data/data_train_50.pickle')
valid_set = load_pickle(path + 'data/data_valid_50.pickle')

test_set = load_pickle(path + 'data/data_test_500_rel.pickle')
test_set_full = load_pickle(path + 'data/data_test_500.pickle')

empty_docs = load_pickle(path+'empty_docs.pickle')

In [0]:
train_set = [x for x in train_set if x[1] not in empty_docs]
valid_set = [x for x in valid_set if x[1] not in empty_docs]

test_set = remove_empty(test_set)
test_set_full = remove_empty(test_set_full)

print("Number of training samples: {}".format(len(train_set)))
print("Number of validation samples: {}".format(len(valid_set)))
print("Number of test samples: {}".format(len(test_set)))

Number of training samples: 283707
Number of validation samples: 31582
Number of test samples: 330


In [0]:
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

Loading BERT tokenizer...


In [0]:
label_to_ans = load_pickle(path+"data-bert/label_to_ans.pickle")
qid_to_text = load_pickle(path+"data-bert/qid_to_text.pickle")

In [0]:
def add_question_token(q_tokens):
    c = ["[CLS]"]
    s = ["[SEP]"]
    q_tokens = c + q_tokens
    q_tokens = q_tokens + s

    return q_tokens

def add_ans_token(a_tokens):
    s = ["[SEP]"]
    a_tokens = a_tokens + s

    return a_tokens

def clip(lst):
    max_seq_len = 512
    if len(lst) > max_seq_len:
        lst = lst[:max_seq_len]
    else:
        lst = lst
    
    return lst

def get_input_ids(sequences, max_seq_len):
    # Tokenize all of the sentences and map the tokens to thier word IDs.
    input_ids = []

    for seq in sequences:
        # `encode` will:
        #   (1) Tokenize the sentence.
        #   (2) Map tokens to their IDs.
        encoded_seq = tokenizer.convert_tokens_to_ids(seq)
        
        # Add the encoded sentence to the list.
        input_ids.append(encoded_seq)

    input_ids = pad_sequences(input_ids, maxlen=max_seq_len, dtype="long", 
                          value=0, truncating="post", padding="post")
    return input_ids

def get_att_mask(input_ids):
    # Create attention masks
    attention_masks = []

    # For each sentence...
    for sent in input_ids:
        
        # Create the attention mask.
        #   - If a token ID is 0, then it's padding, set the mask to 0.
        #   - If a token ID is > 0, then it's a real token, set the mask to 1.
        att_mask = [int(token_id > 0) for token_id in sent]
        
        # Store the attention mask for this sentence.
        attention_masks.append(att_mask)

    return attention_masks

In [0]:
def get_sequence_df(dataset):
    df = pd.DataFrame(dataset)
    df = df.rename(columns={0: 'qid', 1: 'pos', 2:'neg'})
    df_pos = df[['qid', 'pos']]
    df_pos = df_pos.rename(columns={'pos': 'docid'})
    df_pos['label'] = df_pos.apply(lambda x: 1, axis=1)
    df_pos = df_pos.drop_duplicates()

    df_neg = df[['qid', 'neg']]
    df_neg = df_neg.rename(columns={'neg': 'docid'})
    df_neg['label'] = df_neg.apply(lambda x: 0, axis=1)
    data_df = pd.concat([df_pos, df_neg]).sort_values(by=['qid'])

    data_df['question'] = data_df['qid'].apply(lambda x: qid_to_text[x])
    data_df['ans_cand'] = data_df['docid'].apply(lambda x: label_to_ans[x])
    data_df['ques_token'] = data_df['question'].apply(lambda x: add_question_token(x))
    data_df['ans_cand'] = data_df['ans_cand'].apply(lambda x: add_ans_token(x))

    data_df = data_df[['qid', 'docid', 'label', 'ans_cand','ques_token']]
    data_df['seq'] = data_df['ques_token'] + data_df['ans_cand']

    data_df['seq_clipped'] = data_df['seq'].apply(clip)
    # train['len'] = train['seq_clipped'].apply(lambda x: len(x))

    return data_df

In [0]:
def get_pairwise_sequence_df(dataset):
    df = pd.DataFrame(dataset)
    df = df.rename(columns={0: 'qid', 1: 'pos_id', 2:'neg_id'})
    df['pos_label'] = df.apply(lambda x: 1, axis=1)
    df['neg_label'] = df.apply(lambda x: 0, axis=1)

    df['question'] = df['qid'].apply(lambda x: qid_to_text[x])
    df['pos_ans'] = df['pos_id'].apply(lambda x: label_to_ans[x])
    df['neg_ans'] = df['neg_id'].apply(lambda x: label_to_ans[x])

    df['ques_token'] = df['question'].apply(lambda x: add_question_token(x))
    df['pos_ans'] = df['pos_ans'].apply(lambda x: add_ans_token(x))
    df['neg_ans'] = df['neg_ans'].apply(lambda x: add_ans_token(x))

    df = df[['qid', 'pos_id', 'neg_id', 'pos_label', 'neg_label', 'pos_ans', 'neg_ans', 'ques_token']]
    df['pos_seq'] = df['ques_token'] + df['pos_ans']
    df['neg_seq'] = df['ques_token'] + df['neg_ans']

    df['pos_seq_clipped'] = df['pos_seq'].apply(clip)
    df['neg_seq_clipped'] = df['neg_seq'].apply(clip)

    return df

## **Pairwise**

In [0]:
trainset = get_pairwise_sequence_df(train_set)
validset = get_pairwise_sequence_df(valid_set)

# Get the lists of sentences and their labels.
train_pos_seq = trainset.pos_seq_clipped.values
train_neg_seq = trainset.neg_seq_clipped.values
train_pos_labels = trainset.pos_label.values
train_neg_labels = trainset.neg_label.values

valid_pos_seq = validset.pos_seq_clipped.values
valid_neg_seq = validset.neg_seq_clipped.values
valid_pos_labels = validset.pos_label.values
valid_neg_labels = validset.neg_label.values

print(len(train_pos_seq))
print(len(valid_pos_seq))

# train_pos_seq = train_pos_seq[:300]
# train_neg_seq = train_neg_seq[:300]
# train_pos_labels = train_pos_labels[:300]
# train_neg_labels = train_neg_labels[:300]

# valid_pos_seq = valid_pos_seq[:30]
# valid_neg_seq = valid_neg_seq[:30]
# valid_pos_labels = valid_pos_labels[:30]
# valid_neg_labels = valid_neg_labels[:30]

max_seq_len = 512

train_pos_input = get_input_ids(train_pos_seq, max_seq_len)
train_neg_input = get_input_ids(train_neg_seq, max_seq_len)
valid_pos_input = get_input_ids(valid_pos_seq, max_seq_len)
valid_neg_input = get_input_ids(valid_neg_seq, max_seq_len)

train_pos_mask = get_att_mask(train_pos_input)
train_neg_mask = get_att_mask(train_neg_input)
valid_pos_mask = get_att_mask(valid_pos_input)
valid_neg_mask = get_att_mask(valid_neg_input)

283707
31582


In [0]:
# save_pickle(path+'/data-bert/train_pos_labels.pickle', train_pos_labels)
# save_pickle(path+'/data-bert/train_neg_labels.pickle', train_neg_labels)
# save_pickle(path+'/data-bert/valid_pos_labels.pickle', valid_pos_labels)
# save_pickle(path+'/data-bert/valid_neg_labels.pickle', valid_neg_labels)

save_pickle(path+'/data-bert/train_pos_input_512.pickle', train_pos_input)
save_pickle(path+'/data-bert/train_neg_input_512.pickle', train_neg_input)
save_pickle(path+'/data-bert/valid_pos_input_512.pickle', valid_pos_input)
save_pickle(path+'/data-bert/valid_neg_input_512.pickle', valid_neg_input)

save_pickle(path+'/data-bert/train_pos_mask_512.pickle', train_pos_mask)
save_pickle(path+'/data-bert/train_neg_mask_512.pickle', train_neg_mask)
save_pickle(path+'/data-bert/valid_pos_mask_512.pickle', valid_pos_mask)
save_pickle(path+'/data-bert/valid_neg_mask_512.pickle', valid_neg_mask)

In [0]:
train_pos_labels = load_pickle(path+'/data-bert/train_pos_labels.pickle')
train_neg_labels = load_pickle(path+'/data-bert/train_neg_labels.pickle')
valid_pos_labels = load_pickle(path+'/data-bert/valid_pos_labels.pickle')
valid_neg_labels = load_pickle(path+'/data-bert/valid_neg_labels.pickle')

train_pos_input = load_pickle(path+'/data-bert/train_pos_input_512.pickle')
train_neg_input = load_pickle(path+'/data-bert/train_neg_input_512.pickle')
valid_pos_input = load_pickle(path+'/data-bert/valid_pos_input_512.pickle')
valid_neg_input = load_pickle(path+'/data-bert/valid_neg_input_512.pickle')

train_pos_mask = load_pickle(path+'/data-bert/train_pos_mask_512.pickle')
train_neg_mask = load_pickle(path+'/data-bert/train_neg_mask_512.pickle')
valid_pos_mask = load_pickle(path+'/data-bert/valid_pos_mask_512.pickle')
valid_neg_mask = load_pickle(path+'/data-bert/valid_neg_mask_512.pickle')

In [0]:
train_pos_inputs = torch.tensor(train_pos_input)
train_neg_inputs = torch.tensor(train_neg_input)
valid_pos_inputs = torch.tensor(valid_pos_input)
valid_neg_inputs = torch.tensor(valid_neg_input)

train_pos_labels = torch.tensor(train_pos_labels)
train_neg_labels = torch.tensor(train_neg_labels)
valid_pos_labels = torch.tensor(valid_pos_labels)
valid_neg_labels = torch.tensor(valid_neg_labels)

train_pos_masks = torch.tensor(train_pos_mask)
train_neg_masks = torch.tensor(train_neg_mask)
valid_pos_masks = torch.tensor(valid_pos_mask)
valid_neg_masks = torch.tensor(valid_neg_mask)

In [0]:
print(len(train_pos_inputs))
print(len(valid_pos_inputs))

283707
31582


In [0]:
# The DataLoader needs to know our batch size for training, so we specify it 
# here.
# For fine-tuning BERT on a specific task, the authors recommend a batch size of
# 16 or 32.

batch_size = 8

# Create the DataLoader for our training set.
train_data = TensorDataset(train_pos_inputs, train_pos_masks, train_pos_labels, train_neg_inputs, train_neg_masks, train_neg_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set.
validation_data = TensorDataset(valid_pos_inputs, valid_pos_masks, valid_pos_labels, valid_neg_inputs, valid_neg_masks, valid_neg_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

In [0]:
print(len(train_dataloader))
print(len(validation_dataloader))

35464
3948


In [0]:
import torch.nn as nn

class BertPairwiseClassifier(nn.Module):
    def __init__(self, bert):
        
        super().__init__()

        self.config = BertConfig()
        self.num_labels = self.config.num_labels
        self.bert = bert
        self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        return logits

In [0]:
bert = BertModel.from_pretrained('bert-base-uncased')

model = BertPairwiseClassifier(bert)

# Tell pytorch to run this model on the GPU.
model.to(device)

BertPairwiseClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_af

In [0]:
def pairwise_loss(pos_scores, neg_scores):

    cross_entropy_loss = -torch.log(pos_scores) - torch.log(1 - neg_scores)

    margin = 0.2

    hinge_loss = torch.max(torch.tensor(0, dtype=torch.float).to(device), margin - pos_scores + neg_scores)

    loss = (0.5 * cross_entropy_loss + 0.5 * hinge_loss)

    return loss

In [0]:
def train_pairwise(model, train_dataloader, optimizer):

    # Store the average loss after each epoch so we can plot them.
    loss_values = []

    # Reset the total loss for this epoch.
    total_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(tqdm(train_dataloader)):

        # `batch` contains three pytorch tensors:
        #   [0]: input ids 
        #   [1]: attention masks
        #   [2]: labels 
        pos_input = batch[0].to(device)
        pos_mask = batch[1].to(device)
        pos_labels = batch[2].to(device)

        neg_input = batch[3].to(device)
        neg_mask = batch[4].to(device)
        neg_labels = batch[5].to(device)

        model.zero_grad()        

        pos_scores = torch.sigmoid(model(pos_input, token_type_ids=None, attention_mask=pos_mask, labels=pos_labels))[:,1]
        neg_scores = torch.sigmoid(model(neg_input, token_type_ids=None, attention_mask=neg_mask, labels=neg_labels))[:,1]

        loss = pairwise_loss(pos_scores, neg_scores).mean()
        
        # Track the number of batches
        nb_eval_steps += 1

        # Accumulate the training loss over all of the batches
        total_loss += loss.item()
    
        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and take a step using the computed gradient.
        optimizer.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)            
    
    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    return avg_train_loss

In [0]:
def validate_pairwise(model, validation_dataloader):

    model.eval()

    # Tracking variables 
    total_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0

    # Evaluate data for one epoch
    for batch in tqdm(validation_dataloader):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        
        # Unpack the inputs from our dataloader
        pos_input, pos_mask, pos_labels, neg_input, neg_mask, neg_labels = batch
        
        # Telling the model not to compute or store gradients, saving memory and
        # speeding up validation
        with torch.no_grad():        
            pos_scores = torch.sigmoid(model(pos_input, token_type_ids=None, attention_mask=pos_mask, labels=pos_labels))[:,1]
            neg_scores = torch.sigmoid(model(neg_input, token_type_ids=None, attention_mask=neg_mask, labels=neg_labels))[:,1]

        loss = pairwise_loss(pos_scores, neg_scores).mean()

        # Track the number of batches
        nb_eval_steps += 1

        total_loss += loss.item()

    avg_loss = total_loss / len(validation_dataloader) 

    return avg_loss

In [0]:
optimizer = AdamW(model.parameters(), lr=0.001)

# Lowest validation lost
best_valid_loss = float('inf')

n_epochs = 2

for epoch in range(n_epochs):

    # Evaluate training loss
    train_loss = train_pairwise(model, train_dataloader, optimizer)
    # Evaluate validation loss
    valid_loss = validate_pairwise(model, validation_dataloader)
    
    # At each epoch, if the validation loss is the best
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), path + 'model/' + str(epoch+1)+'_model-bert-pairwise.pt')

    print("\n\n Epoch {}:".format(epoch+1))
    print("\t Train Loss: {}".format(round(train_loss, 3)))
    print("\t Validation Loss: {}\n".format(round(valid_loss, 3)))

  2%|▏         | 616/35464 [08:58<8:27:20,  1.14it/s]

In [0]:
torch.save(model.state_dict(), path + 'model/2_model-bert-pairwise.pt')

In [0]:
print('Memory Usage:')
print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

Memory Usage:
Allocated: 15.1 GB
Cached:    15.2 GB


## **Pointwise**

In [0]:
validset = get_sequence_df(valid_set)

In [0]:
trainset = get_sequence_df(train_set)
validset = get_sequence_df(valid_set)

# Get the lists of sentences and their labels.
train_sequences = trainset.seq_clipped.values
train_labels = trainset.label.values

valid_sequences = validset.seq_clipped.values
valid_labels = validset.label.values

print(len(train_sequences))
print(len(valid_sequences))

train_sequences = train_sequences[:3000]
train_labels = train_labels[:3000]

valid_sequences = valid_sequences[:300]
valid_labels = valid_labels[:300]

max_seq_len = 512

train_input = get_input_ids(train_sequences, max_seq_len)
valid_input = get_input_ids(valid_sequences, max_seq_len)

train_att_mask = get_att_mask(train_input)
valid_att_mask = get_att_mask(valid_input)

298401
33143


In [0]:
# # train_labels = trainset.label.values
# # valid_labels = validset.label.values

# # save_pickle(path+'/data-bert/train_labels.pickle', train_labels)
# # save_pickle(path+'/data-bert/valid_labels.pickle', valid_labels)

# save_pickle(path+'/data-bert/train_input_512.pickle', train_input)
# save_pickle(path+'/data-bert/valid_input_512.pickle', valid_input)
# save_pickle(path+'/data-bert/train_mask_512.pickle', train_att_mask)
# save_pickle(path+'/data-bert/valid_mask_512.pickle', valid_att_mask)

In [0]:
# train_input = load_pickle(path+'/data-bert/train_input.pickle')
# valid_input = load_pickle(path+'/data-bert/valid_input.pickle')
# train_att_mask = load_pickle(path+'/data-bert/train_mask.pickle')
# valid_att_mask = load_pickle(path+'/data-bert/valid_mask.pickle')

train_input = load_pickle(path+'/data-bert/train_input_512.pickle')
valid_input = load_pickle(path+'/data-bert/valid_input_512.pickle')
train_att_mask = load_pickle(path+'/data-bert/train_mask_512.pickle')
valid_att_mask = load_pickle(path+'/data-bert/valid_mask_512.pickle')

train_labels = load_pickle(path+'/data-bert/train_labels.pickle')
valid_labels = load_pickle(path+'/data-bert/valid_labels.pickle')

In [0]:
train_input = train_input[:1000]
train_labels = train_labels[:1000]
train_att_mask = train_att_mask[:1000]

valid_input = valid_input[:100]
valid_labels = valid_labels[:100]
valid_att_mask = valid_att_mask[:100]

In [0]:
print(len(train_input))
print(len(valid_input))

1000
100


In [0]:
# Convert all inputs and labels into torch tensors, the required datatype 
# for our model.
train_inputs = torch.tensor(train_input)
validation_inputs = torch.tensor(valid_input)

train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(valid_labels)

train_masks = torch.tensor(train_att_mask)
validation_masks = torch.tensor(valid_att_mask)

In [0]:
# The DataLoader needs to know our batch size for training, so we specify it 
# here.
# For fine-tuning BERT on a specific task, the authors recommend a batch size of
# 16 or 32.

batch_size = 8

# Create the DataLoader for our training set.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set.
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

In [0]:
print(len(train_dataloader))
print(len(validation_dataloader))

125
13


In [0]:
import numpy as np

# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

## **Model**

In [0]:
import torch.nn as nn

class BertClassifier(nn.Module):
    def __init__(self, bert):
        
        super().__init__()

        # self.config = BertConfig.from_pretrained("/content/drive/My Drive/FiQA/model/fin_model/config.json")
        self.config = BertConfig()
        self.num_labels = self.config.num_labels
        self.bert = bert
        # self.dropout = nn.Dropout(self.config.hidden_dropout_prob)
        self.dropout = nn.Dropout(0.5)
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, inputs_embeds=None, labels=None):

        outputs = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, head_mask=head_mask, inputs_embeds=inputs_embeds)

        pooled_output = outputs[1]

        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

In [0]:
# from transformers import BertForSequenceClassification, AdamW, BertConfig

# # Load BertForSequenceClassification, the pretrained BERT model with a single 
# # linear classification layer on top. 
# model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
# model.to(device)

In [0]:
bert = BertModel.from_pretrained('bert-base-uncased')

# model_path = "/content/drive/My Drive/FiQA/model/fin_model"
# bert = BertModel.from_pretrained(model_path)

model = BertClassifier(bert)

# Tell pytorch to run this model on the GPU.
model.to(device)

BertClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=Tru

In [0]:
def train(model, train_dataloader, optimizer, scheduler):

    # Store the average loss after each epoch so we can plot them.
    loss_values = []

    # Reset the total loss for this epoch.
    total_loss = 0
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    model.train()

    # For each batch of training data...
    for step, batch in enumerate(tqdm(train_dataloader)):

        # `batch` contains three pytorch tensors:
        #   [0]: input ids 
        #   [1]: attention masks
        #   [2]: labels 
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        model.zero_grad()        

        # Perform a forward pass (evaluate the model on this training batch).
        # This will return the loss (rather than the model output) because we
        # have provided the `labels`.
        outputs = model(b_input_ids, 
                    token_type_ids=None, 
                    attention_mask=b_input_mask, 
                    labels=b_labels)
        # The call to `model` always returns a tuple, so we need to pull the 
        # loss value out of the tuple.
        loss = outputs[0]

        logits = outputs[1]

        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()

        # Calculate the accuracy for this batch of test sentences.
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        # Accumulate the total accuracy.
        eval_accuracy += tmp_eval_accuracy

        # Track the number of batches
        nb_eval_steps += 1

        # Accumulate the training loss over all of the batches
        total_loss += loss.item()
    
        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and take a step using the computed gradient.
        optimizer.step()

        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)            
    
    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    acc = eval_accuracy/nb_eval_steps

    return avg_train_loss, acc

In [0]:
def validate(model, validation_dataloader):

    model.eval()

    # Tracking variables 
    total_loss = 0
    eval_loss, eval_accuracy = 0, 0
    nb_eval_steps, nb_eval_examples = 0, 0

    # Evaluate data for one epoch
    for batch in tqdm(validation_dataloader):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        
        # Unpack the inputs from our dataloader
        b_input_ids, b_input_mask, b_labels = batch
        
        # Telling the model not to compute or store gradients, saving memory and
        # speeding up validation
        with torch.no_grad():        
            outputs = model(b_input_ids, 
                            token_type_ids=None, 
                            attention_mask=b_input_mask,
                            labels=b_labels)
        
        loss = outputs[0]

        logits = outputs[1]

        # Move logits and labels to CPU
        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        
        # Calculate the accuracy for this batch of test sentences.
        tmp_eval_accuracy = flat_accuracy(logits, label_ids)
        
        # Accumulate the total accuracy.
        eval_accuracy += tmp_eval_accuracy

        # Track the number of batches
        nb_eval_steps += 1

        total_loss += loss.item()

    acc = eval_accuracy/nb_eval_steps
    avg_loss = total_loss / len(validation_dataloader) 

    return avg_loss, acc

In [0]:
from transformers import AdamW


# Note: AdamW is a class from the huggingface library (as opposed to pytorch) 
# I believe the 'W' stands for 'Weight Decay fix"
optimizer = AdamW(model.parameters(),
                  lr = 2e-5, # args.learning_rate - default is 5e-5, our notebook had 2e-5
                  eps = 1e-8 # args.adam_epsilon  - default is 1e-8.
                )

In [0]:
from transformers import get_linear_schedule_with_warmup

# Number of training epochs (authors recommend between 2 and 4)
epochs = 2

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0, # Default value in run_glue.py
                                            num_training_steps = total_steps)

In [0]:
# Lowest validation lost
best_valid_loss = float('inf')

for epoch in range(epochs):

    # Evaluate training loss
    train_loss, train_acc = train(model, train_dataloader, optimizer, scheduler)
    
    # Evaluate validation loss
    valid_loss, valid_acc = validate(model, validation_dataloader)
    
    # At each epoch, if the validation loss is the best
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
    torch.save(model.state_dict(), path + str(epoch+1)+'_model-bert-full.pt')

    print("\n\n Epoch {}:".format(epoch+1))
    print("\t Train Loss: {} | Train Accuracy: {}%".format(round(train_loss, 3), round(train_acc*100, 2)))
    print("\t Validation Loss: {} | Validation Accuracy: {}%\n".format(round(valid_loss, 3), round(valid_acc*100, 2)))

100%|██████████| 125/125 [00:58<00:00,  2.13it/s]
100%|██████████| 13/13 [00:01<00:00,  7.32it/s]
  0%|          | 0/125 [00:00<?, ?it/s]



 Epoch 1:
	 Train Loss: 0.17 | Train Accuracy: 96.8%
	 Validation Loss: 0.099 | Validation Accuracy: 98.08%



100%|██████████| 125/125 [00:58<00:00,  2.13it/s]
100%|██████████| 13/13 [00:01<00:00,  7.33it/s]




 Epoch 2:
	 Train Loss: 0.103 | Train Accuracy: 97.4%
	 Validation Loss: 0.291 | Validation Accuracy: 87.5%



In [0]:
for row in test_set:
    row[2] = [x for x in row[2] if x is not 0]

for row in test_set_full:
    row[2] = [x for x in row[2] if x is not 0]

In [0]:
def get_rank(model, test_set, qid_rel, max_seq_len):

    qid_pred_rank = {}

    model.eval()

    for i, seq in enumerate(tqdm(test_set)):
        
        qid, label, cands = seq[0], seq[1], seq[2]

        q_text = add_question_token(qid_to_text[qid])

        cands_id = np.array(cands)

        scores = []

        for docid in cands:

            ans_text = add_ans_token(label_to_ans[docid])

            seq_text = clip(q_text + ans_text)

            encoded_seq = tokenizer.convert_tokens_to_ids(seq_text)

            input_ids = pad_seq(encoded_seq, max_seq_len)

            att_mask = torch.tensor([[int(token_id > 0) for token_id in input_ids]]).to(device)
            
            input_ids = torch.tensor([input_ids]).to(device)

            with torch.no_grad():
            # Forward pass, calculate logit predictions
                outputs = model(input_ids, token_type_ids=None, attention_mask=att_mask)

            logits = outputs[0]

            pred = torch.sigmoid(logits)

            # Move logits and labels to CPU
            pred = pred.detach().cpu().numpy()

            scores.append(pred[:,1][0])

        print(scores)

        # Get the indices of the sorted similarity scores
        sorted_index = np.argsort(scores)[::-1]

        # Get the docid from the sorted indices
        ranked_ans = cands_id[sorted_index]

        # Dict - key: qid, value: ranked list of docids
        qid_pred_rank[qid] = ranked_ans

    return qid_pred_rank

In [0]:
toy_test_label = dict(itertools.islice(test_qid_rel.items(), 5))
toy_test = test_set[:5]

In [0]:
# toy_test = [[1, [14255], [84963, 354716, 14255, 522619]]]

In [0]:
model.load_state_dict(torch.load(path+'1_model-bert-full.pt'))

qid_pred_rank = get_rank(model, toy_test, toy_test_label, max_seq_len=512)

 20%|██        | 1/5 [00:10<00:43, 10.87s/it]

[0.09448364, 0.091464505, 0.09531899, 0.0935531, 0.08623368, 0.0944412, 0.08836689, 0.09304453, 0.088268794, 0.10072528, 0.094937, 0.093893334, 0.094075166, 0.094597846, 0.08860202, 0.093044005, 0.09131927, 0.098692544, 0.09282827, 0.09076655, 0.09082153, 0.09292794, 0.09316235, 0.09599658, 0.0929803, 0.09787936, 0.09260589, 0.09704332, 0.09380808, 0.09536452, 0.085752115, 0.09798474, 0.09313285, 0.085965045, 0.09459156, 0.094858445, 0.09067017, 0.091676995, 0.09342291, 0.09825895, 0.09526684, 0.09755189, 0.097160704, 0.08859482, 0.09224049, 0.093643315, 0.081098124, 0.09534079, 0.101478726, 0.090585925, 0.09185903, 0.090043984, 0.092732616, 0.09962214, 0.09344325, 0.09312794, 0.084370345, 0.09385696, 0.0914924, 0.09742052, 0.100261, 0.08776788, 0.091749705, 0.09152291, 0.09483297, 0.09555992, 0.08925647, 0.096978545, 0.09085019, 0.09232709, 0.09558287, 0.09908514, 0.09034418, 0.08477161, 0.08711484, 0.0859758, 0.09212534, 0.092701495, 0.0954817, 0.09008179, 0.09259034, 0.08682558, 0.0

 40%|████      | 2/5 [00:21<00:32, 10.87s/it]

[0.089038074, 0.087245695, 0.09686424, 0.090947896, 0.092642576, 0.089993894, 0.093893245, 0.090898484, 0.09217901, 0.08610029, 0.09310555, 0.093487635, 0.091983184, 0.08404333, 0.09254097, 0.09248821, 0.09485202, 0.0915587, 0.0966721, 0.0961607, 0.087108955, 0.09980025, 0.0997337, 0.089376, 0.089792706, 0.091721185, 0.09007152, 0.10155326, 0.09090053, 0.10064456, 0.10233374, 0.096182086, 0.095291115, 0.08861255, 0.0946176, 0.09427746, 0.08927234, 0.093150005, 0.08318458, 0.09034387, 0.09614644, 0.09459395, 0.094437726, 0.093337454, 0.093813345, 0.0964419, 0.102040835, 0.095333956, 0.09024939, 0.09653397, 0.089940116, 0.09411199, 0.096570194, 0.09758782, 0.09678069, 0.091985755, 0.08375582, 0.097223885, 0.091557294, 0.08654513, 0.09883658, 0.09191178, 0.089872785, 0.08798662, 0.098162286, 0.09278356, 0.08697244, 0.0869087, 0.097522974, 0.087528646, 0.08916158, 0.10115958, 0.09528215, 0.091390185, 0.088485666, 0.082588404, 0.09563435, 0.087906964, 0.097468555, 0.09543912, 0.08826826, 0.

 60%|██████    | 3/5 [00:32<00:21, 10.86s/it]

[0.08202664, 0.08568311, 0.09017658, 0.09526386, 0.096064754, 0.09406523, 0.08122158, 0.08740375, 0.09736456, 0.09248821, 0.085536025, 0.092340924, 0.090646796, 0.09027831, 0.08648409, 0.08956306, 0.092648506, 0.08822748, 0.094515994, 0.09280088, 0.093344755, 0.08665109, 0.092667185, 0.0940455, 0.08530649, 0.0976262, 0.089112744, 0.09485264, 0.09445631, 0.09315919, 0.09150314, 0.090109564, 0.08813301, 0.09674544, 0.09239764, 0.08385653, 0.08625316, 0.08744082, 0.08598483, 0.090867676, 0.08790739, 0.09706946, 0.08963014, 0.09425612, 0.09472933, 0.095610715, 0.09271945, 0.09560556, 0.10084954, 0.08711935, 0.09293156, 0.09433718, 0.08842018, 0.089276105, 0.092068575, 0.0874609, 0.0907906, 0.09506944, 0.096134074, 0.09049435, 0.09294961, 0.09199798, 0.084740415, 0.098958835, 0.08937444, 0.08042217, 0.092728786, 0.093305245, 0.09156391, 0.07172713, 0.07710424, 0.09171755, 0.08841608, 0.09191569, 0.08755051, 0.08577477, 0.09135618, 0.09266591, 0.09588268, 0.0903715, 0.09597814, 0.101081364, 

 80%|████████  | 4/5 [00:43<00:10, 10.86s/it]

[0.08825335, 0.101886, 0.092105426, 0.08922268, 0.09045948, 0.09309741, 0.08942796, 0.100345545, 0.09602147, 0.098036185, 0.08480027, 0.09468512, 0.09117491, 0.09034348, 0.09715811, 0.08716859, 0.092287794, 0.09813702, 0.0976157, 0.09295471, 0.089064606, 0.08913281, 0.087866575, 0.09903978, 0.084161416, 0.0975043, 0.09257793, 0.091103576, 0.08756962, 0.08769317, 0.093012266, 0.093756296, 0.09539414, 0.09252364, 0.09320732, 0.08504195, 0.08974177, 0.098553464, 0.10199982, 0.09665653, 0.08894987, 0.09428364, 0.08838431, 0.08442191, 0.090665646, 0.08710642, 0.095091455, 0.09346481, 0.08461243, 0.09154452, 0.100219674, 0.090355866, 0.09303701, 0.09172611, 0.091345705, 0.098321326, 0.09169278, 0.096176766, 0.085551746, 0.09531286, 0.087229915, 0.097901665, 0.090055674, 0.0923282, 0.084133625, 0.104487225, 0.09031674, 0.090192296, 0.09840611, 0.09317138, 0.08864275, 0.08518538, 0.095237195, 0.08590501, 0.09496602, 0.09142436, 0.09197719, 0.08614815, 0.0937525, 0.09265928, 0.080283955, 0.0919

100%|██████████| 5/5 [00:54<00:00, 10.86s/it]

[0.08976066, 0.09207014, 0.09648833, 0.09693424, 0.09231177, 0.09587623, 0.09568729, 0.09172134, 0.0854301, 0.085862175, 0.09642867, 0.088303484, 0.088249706, 0.09505611, 0.09269714, 0.09839677, 0.0833062, 0.08633762, 0.09099159, 0.09132982, 0.0929807, 0.08458787, 0.099441245, 0.09223015, 0.090655856, 0.10531307, 0.09141298, 0.112626895, 0.09821563, 0.083793834, 0.09458689, 0.09187455, 0.09185038, 0.08913043, 0.079459906, 0.08939326, 0.08727894, 0.087101, 0.091798656, 0.087874606, 0.10029497, 0.096677616, 0.08835464, 0.08555743, 0.0893991, 0.09152963, 0.09093151, 0.08821685, 0.084498346, 0.0912962, 0.089443825, 0.09526552, 0.09559642, 0.09019402, 0.092288226, 0.09382763, 0.09221077, 0.0926451, 0.0975671, 0.09019159, 0.086794145, 0.09204574, 0.095602676, 0.09369971, 0.08994653, 0.09076586, 0.09545356, 0.09875387, 0.09183202, 0.09166267, 0.096439086, 0.08584152, 0.090720855, 0.09433995, 0.089876145, 0.090831034, 0.09440039, 0.091119945, 0.092928275, 0.09070891, 0.08970008, 0.09155034, 0.




In [0]:
print(test_set[0])

[1, [14255], [84963, 354716, 522619, 418999, 322064, 141738, 303078, 355897, 71987, 219313, 310612, 257168, 410431, 541809, 466718, 89190, 66356, 329209, 283505, 596289, 362060, 113632, 113776, 362069, 129355, 381151, 81343, 46791, 360925, 294738, 157233, 263521, 509862, 90290, 598547, 397608, 442968, 81599, 68969, 434619, 71569, 476980, 365558, 251392, 118615, 292748, 133701, 246461, 327002, 510692, 368263, 304452, 224167, 234436, 98636, 107794, 388713, 524879, 399762, 174321, 282958, 245447, 345070, 361978, 183612, 566417, 539511, 19640, 420295, 525149, 302049, 12729, 223170, 41793, 356884, 288995, 472824, 597053, 40628, 546277, 262960, 275312, 112793, 248448, 363495, 223697, 406789, 349674, 550345, 11132, 265527, 189642, 18934, 527776, 314161, 576985, 146388, 156554, 424720, 109546, 531442, 144190, 399882, 469043, 89008, 494000, 296345, 193367, 299971, 234743, 189765, 354314, 128861, 431010, 158409, 523564, 177074, 257738, 213041, 283113, 187227, 533808, 444899, 507107, 38249, 20217

In [0]:
qid_pred_rank

{1: array([513362, 388042, 479542, 191848, 366074, 462184,  96158, 399762,
         21846, 141738, 413681, 264659, 399115, 444273, 304452, 361637,
        156063, 377621, 441038,  18934, 546509, 507371, 235823,  75005,
        133701, 550345, 146388, 219313, 220621, 539511,  99745, 254158,
        328853,  81886, 192843, 319234, 109250, 381151, 431010, 333954,
        363495, 109546,  16911, 205791, 177074, 173212,  98072, 197175,
        243822, 357717,  14255, 265527, 376129, 540325, 253541, 197205,
        174025, 593760, 245011,  12729,  18539,  69317, 524879, 209997,
        183612, 216077,  72053, 458079, 373481, 529853, 197352, 297938,
         28764, 166183, 364938, 216783, 344093, 253210, 158738,  89190,
        571232, 338096, 378484, 157233,  68969,  24421, 356884, 476980,
        494000,  87113, 576295, 427849, 286245, 583757, 588211, 128861,
         79397, 278460, 424001,  65771, 424008, 176054, 510181, 187384,
        216008, 322906, 410431,  74688, 520922, 172745, 46152

In [0]:
k = 500

num_q = len(toy_test)

MRR, average_ndcg, precision = evaluate(qid_pred_rank, toy_test_label, k)
# MRR, average_ndcg, precision = evaluate(qid_pred_rank, test_qid_rel, k)

print("\n\nAverage nDCG@{} for {} queries: {}\n".format(k, num_q, average_ndcg))

print("MRR@{} for {} queries: {}\n".format(k, num_q, MRR))

print("Average Precision@{}: {}".format(1, precision))



Average nDCG@500 for 5 queries: 0.1352267643706147

MRR@500 for 5 queries: 0.008511731568069007

Average Precision@1: 0.0


In [0]:
save_pickle(path+'rank/2_bert_test_full.pickle', qid_pred_rank)