In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
import pickle
import random
from collections import Counter
from tqdm import tqdm
import itertools
import pandas as pd
from itertools import islice
import numpy as np
import random
import torch
from torch.nn import CrossEntropyLoss
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn.functional import softmax

# Setting device on GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print()

if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0)/1024**3,1), 'GB')

torch.backends.cudnn.deterministic = True

# Set the random seed manually for reproducibility.
torch.manual_seed(1234)

Using device: cuda

Tesla P100-PCIE-16GB
Memory Usage:
Allocated: 0.0 GB
Cached:    0.0 GB


<torch._C.Generator at 0x7fc8577480d0>

In [3]:
!pip install transformers
from transformers import BertTokenizer, BertForSequenceClassification, AdamW, get_linear_schedule_with_warmup



In [0]:
path = "drive/My Drive/FiQA/"

In [0]:
from evaluate import *
from utils import *

In [0]:
# Dictionary - key: qid, value: list of positive docid
train_qid_rel = load_pickle(path + "new-data/qid_rel_train.pickle")
test_qid_rel = load_pickle(path + "new-data/qid_rel_test.pickle")
valid_qid_rel = load_pickle(path + "new-data/qid_rel_valid.pickle")

# List of lists:
# Each element is a list containing [qid, positive docid, negative docid]
# train_set = load_pickle(path + 'new-data/data_50/train_set_50.pickle')
# valid_set = load_pickle(path + 'new-data/data_50/valid_set_50.pickle')
# train_set = load_pickle(path + 'new-data/data_25/train_set_25.pickle')
# valid_set = load_pickle(path + 'new-data/data_25/valid_set_25.pickle')
train_set = load_pickle(path + 'new-data/data_10/train_set_10.pickle')
valid_set = load_pickle(path + 'new-data/data_10/valid_set_10.pickle')
# train_set = load_pickle(path + 'new-data/train_set.pickle')
# valid_set = load_pickle(path + 'new-data/valid_set.pickle')

# List of lists:
# Each element is a list contraining [qid, list of pos docid, list of candidate docid]
# Contains candidates with all pos docids
test_set = load_pickle(path + 'new-data/data_50/test_set_50.pickle')
# Contains candidates retrieved by BM25
# May be missing pos docids in candidates
test_set_full = load_pickle(path + 'new-data/data_50/test_set_full_50.pickle')

# Dictionary mapping docid and qid to raw text
docid_to_text = load_pickle(path + 'new-data/docid_to_text.pickle')
qid_to_text = load_pickle(path + 'new-data/qid_to_text.pickle')

In [12]:
print("Number of training samples: {}".format(len(train_set)))
print("Number of validation samples: {}".format(len(valid_set)))
print("Number of test samples: {}".format(len(test_set)))

Number of training samples: 56810
Number of validation samples: 6320
Number of test samples: 333


In [13]:
# Example of the training set [qid, pos docid, neg docid]
print(train_set[:10])

[[0, 18850, 214003], [0, 18850, 473658], [0, 18850, 468016], [0, 18850, 319793], [0, 18850, 318321], [0, 18850, 490176], [0, 18850, 573077], [0, 18850, 523540], [0, 18850, 257835], [0, 18850, 499286]]


In [14]:
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
# tokenizer = BertTokenizer.from_pretrained('bert-large-uncased', do_lower_case=True)

Loading BERT tokenizer...


HBox(children=(IntProgress(value=0, description='Downloading', max=231508, style=ProgressStyle(description_wid…




In [0]:
def get_input(questions, answers, max_seq_len):
    """
    Returns input objects for training:
        input_ids: List of lists
                Each element contains a list of padded/clipped numericalized
                tokens of the sequences including [CLS] and [SEP] tokens
                e.g. [[101, 2054, 2003, 102, 2449, 1029, 102], ...]
        token_type_ids: List of lists
                Each element contains a list of segment token indices to 
                indicate first and second portions of the inputs. 
                0 corresponds to a question token, 1 corresponds an answer token
                e.g. [[0, 0, 0, 0, 1, 1, 1], ...]
        att_masks: List of lists
                Each element contains a list of mask values
                Mask to avoid performing attention on padding token indices. 
                1 for tokens that are NOT MASKED, 0 for MASKED tokens.
                e.g. [[1, 1, 1, 1, 1, 1, 1], ...]
    -----------------
    questions: List of strings
            Each element contains a question string
    answers: List of strings
            Each element contains an asnwer string
    max_seq_len: int
            Maximum sequence length
    """
    input_ids = []
    token_type_ids = []
    att_masks = []

    for i in tqdm(range(len(questions))):
        a = questions[i]
        b = answers[i]

        # Tokenize the questions and answers, apply padding, and trim the vectors
        # to the max_seq_len
        encoded_seq = tokenizer.encode_plus(a, b, 
                                            max_length=max_seq_len, 
                                            pad_to_max_length=True, 
                                            return_token_type_ids=True,
                                            return_attention_mask = True)

        input_id = encoded_seq['input_ids']
        token_type_id = encoded_seq['token_type_ids']
        att_mask = encoded_seq['attention_mask']

        assert len(input_id) == max_seq_len, "Input id dimension incorrect!"
        assert len(token_type_id) == max_seq_len, "Token type id dimension incorrect!"
        assert len(att_mask) == max_seq_len, "Attention mask dimension incorrect!"

        input_ids.append(input_id)
        token_type_ids.append(token_type_id)
        att_masks.append(att_mask)

    return input_ids, token_type_ids, att_masks

In [0]:
# Function to calculate the accuracy of our predictions vs labels
def flat_accuracy(preds, labels):
    # Get the column with the higher probability
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

## **Pairwise**

In [0]:
def get_pairwise_sequence_df(dataset):
    """
    Converts training and validation data into a df with relevancy labels
    and map the qid and docid to text.
    
    Returns data_df: df with columns qid, pos docid,
            neg docid, pos label, neg_label, question (text), 
            pos answer (text), neg answer (text)
    ---------------
    dataset: train or validation set in the form of list of lists
    """
    df = pd.DataFrame(dataset)
    df = df.rename(columns={0: 'qid', 1: 'pos_id', 2:'neg_id'})
    df['pos_label'] = df.apply(lambda x: 1, axis=1)
    df['neg_label'] = df.apply(lambda x: 0, axis=1)

    df['question'] = df['qid'].apply(lambda x: qid_to_text[x])
    df['pos_ans'] = df['pos_id'].apply(lambda x: docid_to_text[x])
    df['neg_ans'] = df['neg_id'].apply(lambda x: docid_to_text[x])

    return df

In [22]:
trainset = get_pairwise_sequence_df(train_set)
train_questions = trainset.question.values
train_pos_answers = trainset.pos_ans.values
train_neg_answers = trainset.neg_ans.values

train_pos_labels = trainset.pos_label.values
train_neg_labels = trainset.neg_label.values

train_pos_input, train_pos_type_id, train_pos_att_mask = get_input(train_questions, train_pos_answers, 256)
train_neg_input, train_neg_type_id, train_neg_att_mask = get_input(train_questions, train_neg_answers, 256)

100%|██████████| 56810/56810 [04:17<00:00, 220.78it/s]
100%|██████████| 56810/56810 [05:37<00:00, 168.57it/s]


In [23]:
validset = get_pairwise_sequence_df(valid_set)
valid_questions = validset.question.values
valid_pos_answers = validset.pos_ans.values
valid_neg_answers = validset.neg_ans.values

valid_pos_labels = validset.pos_label.values
valid_neg_labels = validset.neg_label.values

valid_pos_input, valid_pos_type_id, valid_pos_att_mask = get_input(valid_questions, valid_pos_answers, 256)
valid_neg_input, valid_neg_type_id, valid_neg_att_mask = get_input(valid_questions, valid_neg_answers, 256)

100%|██████████| 6320/6320 [00:29<00:00, 216.00it/s]
100%|██████████| 6320/6320 [00:28<00:00, 218.77it/s]


In [0]:
# save_pickle(path+'/data-bert/train_pos_labels_10.pickle', train_pos_labels)
# save_pickle(path+'/data-bert/train_neg_labels_10.pickle', train_neg_labels)
# save_pickle(path+'/data-bert/valid_pos_labels_10.pickle', valid_pos_labels)
# save_pickle(path+'/data-bert/valid_neg_labels_10.pickle', valid_neg_labels)

# save_pickle(path+'/data-bert/train_pos_input_256_10.pickle', train_pos_input)
# save_pickle(path+'/data-bert/train_neg_input_256_10.pickle', train_neg_input)
# save_pickle(path+'/data-bert/valid_pos_input_256_10.pickle', valid_pos_input)
# save_pickle(path+'/data-bert/valid_neg_input_256_10.pickle', valid_neg_input)

# save_pickle(path+'/data-bert/train_pos_type_id_256_10.pickle', train_pos_type_id)
# save_pickle(path+'/data-bert/train_neg_type_id_256_10.pickle', train_neg_type_id)
# save_pickle(path+'/data-bert/valid_pos_type_id_256_10.pickle', valid_pos_type_id)
# save_pickle(path+'/data-bert/valid_neg_type_id_256_10.pickle', valid_neg_type_id)

# save_pickle(path+'/data-bert/train_pos_mask_256_10.pickle', train_pos_att_mask)
# save_pickle(path+'/data-bert/train_neg_mask_256_10.pickle', train_neg_att_mask)
# save_pickle(path+'/data-bert/valid_pos_mask_256_10.pickle', valid_pos_att_mask)
# save_pickle(path+'/data-bert/valid_neg_mask_256_10.pickle', valid_neg_att_mask)

In [0]:
train_pos_labels = load_pickle(path+'/data-bert/train_pos_labels_10.pickle')
train_neg_labels = load_pickle(path+'/data-bert/train_neg_labels_10.pickle')
valid_pos_labels = load_pickle(path+'/data-bert/valid_pos_labels_10.pickle')
valid_neg_labels = load_pickle(path+'/data-bert/valid_neg_labels_10.pickle')

train_pos_input = load_pickle(path+'/data-bert/train_pos_input_256_10.pickle')
train_neg_input = load_pickle(path+'/data-bert/train_neg_input_256_10.pickle')
valid_pos_input = load_pickle(path+'/data-bert/valid_pos_input_256_10.pickle')
valid_neg_input = load_pickle(path+'/data-bert/valid_neg_input_256_10.pickle')

train_pos_type_id = load_pickle(path+'/data-bert/train_pos_type_id_256_10.pickle')
train_neg_type_id = load_pickle(path+'/data-bert/train_neg_type_id_256_10.pickle')
valid_pos_type_id = load_pickle(path+'/data-bert/valid_pos_type_id_256_10.pickle')
valid_neg_type_id = load_pickle(path+'/data-bert/valid_neg_type_id_256_10.pickle')

train_pos_mask = load_pickle(path+'/data-bert/train_pos_mask_256_10.pickle')
train_neg_mask = load_pickle(path+'/data-bert/train_neg_mask_256_10.pickle')
valid_pos_mask = load_pickle(path+'/data-bert/valid_pos_mask_256_10.pickle')
valid_neg_mask = load_pickle(path+'/data-bert/valid_neg_mask_256_10.pickle')

In [0]:
# train_pos_labels = train_pos_labels[:100]
# train_neg_labels = train_neg_labels[:100]
# train_pos_input = train_pos_input[:100]
# train_neg_input = train_neg_input[:100]
# train_pos_type_id = train_pos_type_id[:100]
# train_neg_type_id = train_neg_type_id[:100]
# train_pos_mask = train_pos_mask[:100]
# train_neg_mask = train_neg_mask[:100]

# valid_pos_labels = valid_pos_labels[:10]
# valid_neg_labels = valid_neg_labels[:10]
# valid_pos_input = valid_pos_input[:10]
# valid_neg_input = valid_neg_input[:10]
# valid_pos_type_id = valid_pos_type_id[:10]
# valid_neg_type_id = valid_neg_type_id[:10]
# valid_pos_mask = valid_pos_mask[:10]
# valid_neg_mask = valid_neg_mask[:10]

In [0]:
# Convert lists to PyTorch tensors
train_pos_inputs = torch.tensor(train_pos_input)
train_neg_inputs = torch.tensor(train_neg_input)
valid_pos_inputs = torch.tensor(valid_pos_input)
valid_neg_inputs = torch.tensor(valid_neg_input)

train_pos_labels = torch.tensor(train_pos_labels)
train_neg_labels = torch.tensor(train_neg_labels)
valid_pos_labels = torch.tensor(valid_pos_labels)
valid_neg_labels = torch.tensor(valid_neg_labels)

train_pos_type_ids = torch.tensor(train_pos_type_id)
train_neg_type_ids = torch.tensor(train_neg_type_id)
valid_pos_type_ids = torch.tensor(valid_pos_type_id)
valid_neg_type_ids = torch.tensor(valid_neg_type_id)

train_pos_masks = torch.tensor(train_pos_mask)
train_neg_masks = torch.tensor(train_neg_mask)
valid_pos_masks = torch.tensor(valid_pos_mask)
valid_neg_masks = torch.tensor(valid_neg_mask)

In [0]:
# Create DataLoaders to train the model in batches

batch_size = 16

# Create the DataLoader for our training set.
train_data = TensorDataset(train_pos_inputs, train_pos_type_ids, train_pos_masks, train_pos_labels, train_neg_inputs, train_neg_type_ids, train_neg_masks, train_neg_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

# Create the DataLoader for our validation set.
validation_data = TensorDataset(valid_pos_inputs, valid_pos_type_ids, valid_pos_masks, valid_pos_labels, valid_neg_inputs, valid_neg_type_ids, valid_neg_masks, valid_neg_labels)
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)

In [0]:
def pairwise_loss(pos_scores, neg_scores):
    """
    Pairwise learning approach introduced in https://arxiv.org/pdf/1905.07588.pdf
    """

    cross_entropy_loss = -torch.log(pos_scores) - torch.log(1 - neg_scores)

    margin = 1

    hinge_loss = torch.max(torch.tensor(0, dtype=torch.float).to(device), margin - pos_scores + neg_scores)

    loss = (0.5 * cross_entropy_loss + 0.5 * hinge_loss)

    return loss

In [0]:
def train_pairwise(model, train_dataloader, optimizer, scheduler):

    # Store the average loss after each epoch so we can plot them.
    loss_values = []

    # Reset the loss and accuracy for each epoch
    total_loss = 0
    nb_train_steps = 0
    train_accuracy = 0

    # Set model in training mode
    model.train()

    # For each batch of training data...
    for step, batch in enumerate(tqdm(train_dataloader)):

        # batch contains eight PyTorch tensors:
        pos_input = batch[0].to(device)
        pos_type_id = batch[1].to(device)
        pos_mask = batch[2].to(device)
        pos_labels = batch[3].to(device)

        neg_input = batch[4].to(device)
        neg_type_id = batch[5].to(device)
        neg_mask = batch[6].to(device)
        neg_labels = batch[7].to(device)

        # Zero gradients
        model.zero_grad()

        # Compute predictinos for postive and negative QA pairs
        pos_outputs = model(pos_input, token_type_ids=pos_type_id, attention_mask=pos_mask, labels=pos_labels)
        neg_outputs = model(neg_input, token_type_ids=neg_type_id, attention_mask=neg_mask, labels=neg_labels)

        # Get the logits from the model for positive and negative QA pairs
        pos_logits = pos_outputs[1]
        neg_logits = neg_outputs[1]

        # Get the column of the relevant scores and apply activation function
        pos_scores = softmax(pos_logits, dim=1)[:,1]
        neg_scores = softmax(neg_logits, dim=1)[:,1]
        
        # Compute pairwise loss and get the mean of each batch
        loss = pairwise_loss(pos_scores, neg_scores).mean()

        # Move logits and labels to CPU
        p_logits = pos_logits.detach().cpu().numpy()
        p_labels = pos_labels.to('cpu').numpy()
        n_logits = neg_logits.detach().cpu().numpy()
        n_labels = neg_labels.to('cpu').numpy()

        # Calculate the accuracy for each batch
        tmp_pos_accuracy = flat_accuracy(p_logits, p_labels)
        tmp_neg_accuracy = flat_accuracy(n_logits, n_labels)

        # Accumulate the total accuracy.
        train_accuracy += tmp_pos_accuracy
        train_accuracy += tmp_neg_accuracy
        
        # Track the number of batches (2 for pos and neg accuracies)
        nb_train_steps += 2

        # Accumulate the training loss over all of the batches
        total_loss += loss.item()
    
        # Perform a backward pass to calculate the gradients.
        loss.backward()

        # Clip the norm of the gradients to 1.0.
        # This is to help prevent the "exploding gradients" problem.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # Update parameters and take a step using the computed gradient.
        optimizer.step()

        # Update scheduler
        scheduler.step()

    # Calculate the average loss over the training data.
    avg_train_loss = total_loss / len(train_dataloader)            
    
    # Store the loss value for plotting the learning curve.
    loss_values.append(avg_train_loss)

    # Compute accuracy for each epoch
    acc = train_accuracy/nb_train_steps

    return avg_train_loss, acc, loss_values

In [0]:
def validate_pairwise(model, validation_dataloader):

    # Set model in evaluation mode
    model.eval()

    # Tracking variables 
    total_loss = 0
    nb_eval_steps = 0
    eval_accuracy = 0

    # Evaluate data for one epoch
    for batch in tqdm(validation_dataloader):
        
        # Add batch to GPU
        batch = tuple(t.to(device) for t in batch)
        
        # Unpack the inputs from our dataloader
        pos_input, pos_type_id, pos_mask, pos_labels, neg_input, neg_type_id, neg_mask, neg_labels = batch
        
        # Telling the model not to compute or store gradients, saving memory and
        # speeding up validation
        with torch.no_grad():
            # Compute predictinos for postive and negative QA pairs
            pos_outputs = model(pos_input, token_type_ids=pos_type_id, attention_mask=pos_mask, labels=pos_labels)
            neg_outputs = model(neg_input, token_type_ids=neg_type_id, attention_mask=neg_mask, labels=neg_labels)

            # Get logits
            pos_logits = pos_outputs[1]
            neg_logits = neg_outputs[1]

            # Apply activation function
            pos_scores = softmax(pos_logits, dim=1)[:,1]
            neg_scores = softmax(neg_logits, dim=1)[:,1]
        
        loss = pairwise_loss(pos_scores, neg_scores).mean()

        # Move logits and labels to CPU
        p_logits = pos_logits.detach().cpu().numpy()
        p_labels = pos_labels.to('cpu').numpy()
        n_logits = neg_logits.detach().cpu().numpy()
        n_labels = neg_labels.to('cpu').numpy()

        # Calculate the accuracy for this batch of test sentences.
        tmp_pos_accuracy = flat_accuracy(p_logits, p_labels)
        tmp_neg_accuracy = flat_accuracy(n_logits, n_labels)

        # Accumulate the total accuracy.
        eval_accuracy += tmp_pos_accuracy
        eval_accuracy += tmp_neg_accuracy

        # Track the number of batches
        nb_eval_steps += 2

        total_loss += loss.item()

    avg_loss = total_loss / len(validation_dataloader)
    acc = eval_accuracy/nb_eval_steps

    return avg_loss, acc

## **Training**

In [33]:
# Load BertForSequenceClassification, the pretrained BERT model with a single linear classification layer on top

# model_path = "/content/drive/My Drive/FiQA/model/fin_model"
# model = BertForSequenceClassification.from_pretrained(model_path, cache_dir=None, num_labels=2)

model = BertForSequenceClassification.from_pretrained("bert-base-uncased", cache_dir=None, num_labels=2)
# model = BertForSequenceClassification.from_pretrained("bert-large-uncased", cache_dir=None, num_labels=2)
model.to(device)

HBox(children=(IntProgress(value=0, description='Downloading', max=361, style=ProgressStyle(description_width=…




HBox(children=(IntProgress(value=0, description='Downloading', max=440473133, style=ProgressStyle(description_…




BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [0]:
optimizer = AdamW(model.parameters(), lr = 2e-6, eps = 1e-8)

# Number of training epochs (authors recommend between 2 and 4)
epochs = 3

# Total number of training steps is number of batches * number of epochs.
total_steps = len(train_dataloader) * epochs

# Create the learning rate scheduler.
scheduler = get_linear_schedule_with_warmup(optimizer, 
                                            num_warmup_steps = 0,
                                            num_training_steps = total_steps)

In [0]:
# Lowest validation lost
best_valid_loss = float('inf')

for epoch in range(epochs):

    # Evaluate training loss
    # train_loss, train_acc, loss_values = train(model, train_dataloader, optimizer, scheduler)
    train_loss, train_acc, loss_values = train_pairwise(model, train_dataloader, optimizer, scheduler)
    # Evaluate validation loss
    # valid_loss, valid_acc = validate(model, validation_dataloader)
    valid_loss, valid_acc = validate_pairwise(model, validation_dataloader)
    
    # At each epoch, if the validation loss is the best
    # if valid_loss < best_valid_loss:
    #     best_valid_loss = valid_loss
    torch.save(model.state_dict(), path + 'model/' + str(epoch+1)+'_pairwise_10.pt')

    print("\n\n Epoch {}:".format(epoch+1))
    print("\t Train Loss: {} | Train Accuracy: {}%".format(round(train_loss, 3), round(train_acc*100, 2)))
    print("\t Validation Loss: {} | Validation Accuracy: {}%\n".format(round(valid_loss, 3), round(valid_acc*100, 2)))

 27%|██▋       | 953/3551 [12:45<34:42,  1.25it/s]

## **Evalulation**

In [0]:
def get_rank(model, test_set, qid_rel, max_seq_len):
    """
    Returns a dictionary - key: qid, value: list of ranked candidates
    -------------------
    model - PyTorch model
    test_set - List of lists:
            Each element is a list contraining 
            [qid, list of pos docid, list of candidate docid]
    qid_rel: Dictionary
            key: qid, value: list of relevant answer id
    max_seq_len: int
            Maximum sequence length
    """

    # Initiate empty dictionary
    qid_pred_rank = {}

    # Set model to evaluation mode
    model.eval()

    # For each element in the test set
    for i, seq in enumerate(tqdm(test_set)):
        
        # question id, list of rel answers, list of candidates
        qid, label, cands = seq[0], seq[1], seq[2]

        # Map question id to text
        q_text = qid_to_text[qid]

        # Convert list to numpy array
        cands_id = np.array(cands)

        # Empty list for the probability scores of relevancy
        scores = []

        # For each answer in the candidates
        for docid in cands:

            # Map the docid to text
            ans_text = docid_to_text[docid]

            # Create inputs for the model
            encoded_seq = tokenizer.encode_plus(q_text, ans_text, 
                                            max_length=max_seq_len, 
                                            pad_to_max_length=True, 
                                            return_token_type_ids=True,
                                            return_attention_mask = True)

            # Numericalized, padded, clipped seq with special tokens
            input_ids = torch.tensor([encoded_seq['input_ids']]).to(device)
            # Specify question seq and answer seq
            token_type_ids = torch.tensor([encoded_seq['token_type_ids']]).to(device)
            # Sepecify which position is part of the seq which is padded
            att_mask = torch.tensor([encoded_seq['attention_mask']]).to(device)

            # Don't calculate gradients
            with torch.no_grad():
            # Forward pass, calculate logit predictions for each QA pair
                outputs = model(input_ids, token_type_ids=token_type_ids, attention_mask=att_mask)

            # Get the predictions
            logits = outputs[0]

            # Apply activation function
            pred = softmax(logits, dim=1)
            # pred = torch.sigmoid(logits)

            # Move logits and labels to CPU
            pred = pred.detach().cpu().numpy()

            # Append relevant scores to list (where label = 1)
            scores.append(pred[:,1][0])

        print(scores)

        # Get the indices of the sorted similarity scores
        sorted_index = np.argsort(scores)[::-1]

        # Get the list of docid from the sorted indices
        ranked_ans = cands_id[sorted_index]

        # Dict - key: qid, value: ranked list of docids
        qid_pred_rank[qid] = ranked_ans

    return qid_pred_rank

In [0]:
toy_test_label = dict(itertools.islice(test_qid_rel.items(), 2))
toy_test = test_set[:2]
# toy_test = [[14, [398960], [84963, 14255, 398960]],
#             [68, [19183], [107584, 562777, 19183]],
#             [70, [327002], [107584, 327002, 19183]]]

In [0]:
model.load_state_dict(torch.load(path+'model/1_pairwise_25.pt'))

# qid_pred_rank = get_rank(model, test_set, test_qid_rel, max_seq_len=512)
qid_pred_rank = get_rank(model, toy_test, toy_test_label, max_seq_len=256)


  0%|          | 0/2 [00:00<?, ?it/s][A
 50%|█████     | 1/2 [00:07<00:07,  7.39s/it][A

[0.9999527, 0.999949, 0.18187124, 0.999884, 0.000103418584, 0.99862635, 0.9999441, 0.99994326, 0.9999306, 0.999944, 0.9999107, 0.011065468, 0.99990654, 0.9999043, 0.9996063, 0.99993455, 0.999925, 0.9999478, 0.9997577, 0.9284304, 0.124180555, 0.9999361, 0.9197168, 0.9999304, 0.999887, 0.99994993, 0.9999486, 0.00038539234, 0.999936, 0.9999198, 0.98532575, 0.9999027, 0.9998845, 0.99889565, 0.99994695, 0.89547914, 0.9347405, 0.99992275, 0.004050158, 0.0004175653, 0.70191604, 0.9997336, 0.61686075, 0.9996581, 0.99994326, 0.992488, 0.9998895, 0.99988997, 0.0010666391, 0.8459604, 0.9999151, 0.9999553, 0.99987435, 0.99994063, 0.9998692, 0.9999106, 0.77538884, 0.99957675, 0.9996093, 0.9997527, 0.99994993, 0.9999403, 0.8784555, 0.99993014, 0.99838126, 0.9999491, 0.9998518, 0.9999471, 0.9999548, 0.999882, 0.99994826, 0.99995434, 0.99989235, 0.7867769, 0.99981743, 0.74969816, 0.9998274, 0.9997148, 0.00039994254, 0.99994624, 0.087119706, 0.9999229, 0.99973744, 0.99985087, 0.9978452, 0.99994147, 0.9


100%|██████████| 2/2 [00:15<00:00,  7.49s/it][A
[A

[0.9999287, 0.0072370716, 0.99994373, 0.9999317, 0.9999299, 0.99995506, 0.99957365, 0.9999552, 0.00057716575, 0.9998795, 0.9999534, 0.99996006, 0.9999405, 0.9999261, 0.99993515, 0.7964893, 0.99969685, 0.9999058, 0.9999336, 0.8458747, 0.9999441, 0.9999472, 0.99995685, 0.99995244, 0.9999547, 0.99401265, 0.9999547, 0.99992085, 0.9999304, 0.9998442, 0.99995935, 0.99789006, 0.99994314, 0.999928, 0.9999423, 0.99995923, 0.99995387, 0.9999374, 0.9999492, 0.9999448, 0.999943, 0.99992156, 0.99994683, 0.9996908, 0.99995124, 0.9999485, 0.9999329, 0.00013012807, 0.99995685, 0.0009288644, 0.99994457, 0.9998565, 0.022961553, 0.99992335, 0.99994254, 0.9998714, 0.9999542, 0.9999484, 0.9996592, 0.99994934, 0.9999589, 0.99993706, 0.9998497, 0.99993575, 0.9999453, 0.99995613, 0.9999505, 0.9998099, 0.011141944, 0.99994516, 0.9999552, 0.9999478, 0.99994814, 0.9999567, 0.99995184, 0.99995375, 0.9999583, 0.00016959306, 0.9999527, 0.9999472, 0.9999584, 0.9999287, 0.99984765, 0.9999304, 0.00031214024, 0.9999368

In [0]:
k = 10

num_q = len(test_set)

# MRR, average_ndcg, precision, rank_pos = evaluate(qid_pred_rank, test_qid_rel, k)
MRR, average_ndcg, precision, rank_pos = evaluate(qid_pred_rank, toy_test_label, k)

print("\n\nAverage nDCG@{} for {} queries: {}\n".format(k, num_q, average_ndcg))

print("MRR@{} for {} queries: {}\n".format(k, num_q, MRR))

print("Average Precision@{}: {}".format(1, precision))



Average nDCG@10 for 333 queries: 0.4510568402073561

MRR@10 for 333 queries: 0.44166666666666665

Average Precision@1: 0.4


In [0]:
save_pickle(path+'rank/rank_bert_256_full.pickle', qid_pred_rank)