In [1]:
from tqdm import tqdm
import json
import numpy as np
from transformers import BertModel, BertConfig, BertTokenizer, BertForQuestionAnswering, BertPreTrainedModel
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import re

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

**DATASET, COLLATOR & DATALOADER**

In [2]:
'''
Dataloader returns a tuple of 
(IDs of tokens for BERT input i.e. [CLS]<question>[SEP]<candidate>[SEP],
attention mask,
token type ids required for 2-sentence input for BERT of type [<[CLS]>00000000<[SEP]>11111111<[SEP]>],
maximum sequence length of each batch)
'''

class NQDataset(Dataset):
  """Thin map-style dataset over a pre-built sequence of sample keys.

  The dataset performs no processing of its own: indexing simply hands back
  the stored entry, leaving all heavy lifting to the collate function.
  """

  def __init__(self, ids):
    """Remember the sequence of keys that will be served by index."""
    self.ids = ids

  def __getitem__(self, index):
    """Return the stored entry at position ``index`` unchanged."""
    return self.ids[index]

  def __len__(self):
    """Number of stored entries."""
    return len(self.ids)

class Collator(object):
  """Collate function turning (example id, candidate index) pairs into BERT inputs.

  Every pair is encoded as ``[CLS] <question> [SEP] <candidate> [SEP]`` and the
  batch is right-padded with zeros to the longest sequence it contains.
  """

  def __init__(self, data_dict, new_token_dict, tokenizer, max_seq_len=384, max_ques_len=64):
    """Store the lookup tables and length limits used while encoding.

    Args:
      data_dict: mapping from example id to a dict with the keys
        'question_text', 'document_text' and 'long_answer_candidates'.
      new_token_dict: mapping from an HTML tag word (e.g. '<P>') to the
        replacement string that is fed to the tokenizer instead.
      tokenizer: object providing ``tokenize`` and ``convert_tokens_to_ids``.
      max_seq_len: upper bound on the encoded sequence length, special
        tokens included.
      max_ques_len: upper bound on the number of question tokens kept.
    """
    self.data_dict = data_dict
    self.new_token_dict = new_token_dict
    self.tokenizer = tokenizer
    self.max_seq_len = max_seq_len
    self.max_ques_len = max_ques_len

  def get_sample(self, data_id, candidate_idx):
    """Encode one (question, long-answer candidate) pair.

    Args:
      data_id: key of the example in ``data_dict``.
      candidate_idx: index into that example's 'long_answer_candidates'.

    Returns:
      Tuple ``(input_ids, candidate_start, candidate_end, seq_len)`` where
      ``input_ids`` is a list of token ids, the start/end values are the
      candidate's 'start_token'/'end_token' fields and ``seq_len`` is
      ``len(input_ids)``.
    """
    data = self.data_dict[data_id]
    question_tokens = self.tokenizer.tokenize(data['question_text'])[:self.max_ques_len]
    # NOTE: the whole document is re-split for every candidate of the example.
    data_words = data['document_text'].split()

    # Budget left for the candidate once [CLS], the question and both [SEP]s are placed.
    max_ans_len = self.max_seq_len - len(question_tokens) - 3
    candidate = data['long_answer_candidates'][candidate_idx]
    candidate_start = candidate['start_token']
    candidate_end = candidate['end_token']
    candidate_words = data_words[candidate_start:candidate_end]

    # Known tags are swapped for their placeholder; every other tag-like word becomes '<'.
    for i, word in enumerate(candidate_words):
      if re.match(r'<.+>', word):
        candidate_words[i] = self.new_token_dict.get(word, '<')

    # Add whole words only: stop at the first word whose pieces would overflow the budget.
    candidate_tokens = []
    for word in candidate_words:
      tokens = self.tokenizer.tokenize(word)
      if (len(candidate_tokens) + len(tokens)) > max_ans_len:
        break
      candidate_tokens.extend(tokens)

    input_tokens = ['[CLS]'] + question_tokens + ['[SEP]'] + candidate_tokens + ['[SEP]']
    input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)

    return input_ids, candidate_start, candidate_end, len(input_ids)

  def __call__(self, batch_ids):
    """Encode and pad a batch of (example id, candidate index) pairs.

    Args:
      batch_ids: sequence of ``(data_id, candidate_idx)`` tuples.

    Returns:
      Tuple ``(input_ids, attention_mask, token_type_ids, batch_max_seq_len)``.
      The first three are tensors of shape ``(len(batch_ids), batch_max_seq_len)``
      (int64, bool, int64); the last is the longest sequence length in the batch.

    Raises:
      ValueError: if ``batch_ids`` is empty (there is no maximum length).
    """
    batch_size = len(batch_ids)
    encoded = [self.get_sample(data_id, candidate_idx)[0] for data_id, candidate_idx in batch_ids]

    batch_max_seq_len = max(len(input_ids) for input_ids in encoded)
    batch_input_ids = np.zeros((batch_size, batch_max_seq_len), dtype=np.int64)
    batch_token_type_ids = np.zeros((batch_size, batch_max_seq_len), dtype=np.int64)
    batch_attention_mask = np.zeros((batch_size, batch_max_seq_len), dtype=bool)

    # The [SEP] id is the same for every row, so look it up once.
    sep_id = self.tokenizer.convert_tokens_to_ids('[SEP]')

    for row, input_ids in enumerate(encoded):
      seq_len = len(input_ids)
      batch_input_ids[row, :seq_len] = input_ids
      # Mark real tokens by length rather than by id, so padding is the only thing masked out.
      batch_attention_mask[row, :seq_len] = True
      # Segment 0 runs up to and including the first [SEP]; the rest of the
      # real tokens form segment 1. Padding keeps the zero it was initialised with.
      first_sep = input_ids.index(sep_id)
      batch_token_type_ids[row, first_sep + 1:seq_len] = 1

    return torch.from_numpy(batch_input_ids), torch.from_numpy(batch_attention_mask), torch.from_numpy(batch_token_type_ids), batch_max_seq_len

**MODEL**

In [3]:
class BertForQuestionAnswering(BertPreTrainedModel):
  """BERT encoder topped with a span head and a sequence-level classifier.

  The span head maps every token state to two logits (start, end); the
  classifier maps the pooled output to ``config.num_labels`` logits.
  """

  def __init__(self, config):
    """Build the encoder and both heads from ``config``, then initialise weights."""
    super().__init__(config)
    self.num_labels = config.num_labels
    self.bert = BertModel(config)
    # Two outputs per token: one start logit and one end logit.
    self.qa_outputs = nn.Linear(config.hidden_size, 2)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)
    self.classifier = nn.Linear(config.hidden_size, config.num_labels)
    self.init_weights()

  def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
    """Run the encoder and both heads.

    Returns:
      Tuple ``(start_logits, end_logits, classifier_logits)``. The first two
      have the trailing size-1 dimension squeezed away; the third comes from
      the classifier applied to the dropped-out pooled output.
    """
    encoder_out = self.bert(
        input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
        position_ids=position_ids,
        head_mask=head_mask,
    )
    token_states, pooled = encoder_out[0], encoder_out[1]

    # Split the (…, 2) span logits into separate start / end tensors.
    span_logits = self.qa_outputs(token_states)
    start_logits, end_logits = (piece.squeeze(-1) for piece in span_logits.split(1, dim=-1))

    classifier_logits = self.classifier(self.dropout(pooled))

    return start_logits, end_logits, classifier_logits

**EVAL & PREDICT**

In [4]:
'''
This function does a single pass through the dev dataset and
returns the softmax of the trained model's final classifier layer,
i.e. for every (question, candidate) pair the probability that the
candidate is the long answer for that question.
'''

def val(model, eval_dataloader, exid_candid_sorted, max_seq_len, num_labels, batch_size):
    """Run one inference pass and collect the class probabilities per candidate.

    Args:
        model: module whose call returns ``(start_logits, end_logits, class_logits)``.
        eval_dataloader: iterable yielding
            ``(input_ids, attention_mask, token_type_ids, batch_max_seq_len)``
            in the same order as ``exid_candid_sorted``.
        exid_candid_sorted: sequence with one entry per scored candidate; only
            its length is used, to size the result.
        max_seq_len: unused; kept so existing callers keep working.
        num_labels: number of classifier outputs (columns of the result).
        batch_size: unused; row offsets are derived from the actual batch
            sizes. Kept so existing callers keep working.

    Returns:
        ``np.float32`` array of shape ``(len(exid_candid_sorted), num_labels)``
        holding the softmax of the classifier logits, row ``i`` belonging to
        ``exid_candid_sorted[i]``.
    """
    model.eval()

    # Send inputs to wherever the model's weights live instead of assuming CUDA,
    # so the function also works on CPU-only machines.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    class_probs = np.zeros((len(exid_candid_sorted), num_labels), dtype=np.float32)

    row = 0
    with torch.no_grad():
        # Wrapping the dataloader itself lets tqdm show the total and an ETA.
        for batch_input_ids, batch_attention_mask, batch_token_type_ids, _ in tqdm(eval_dataloader):
            batch_input_ids = batch_input_ids.to(device)
            batch_attention_mask = batch_attention_mask.to(device)
            batch_token_type_ids = batch_token_type_ids.to(device)

            # Only the classification head feeds the result; span logits are ignored.
            _, _, class_logits = model(batch_input_ids, batch_attention_mask, batch_token_type_ids)

            # Advance by the real batch size so a short final batch needs no special case.
            n_rows = class_logits.shape[0]
            class_probs[row:row + n_rows] = F.softmax(class_logits, dim=1).cpu().numpy()
            row += n_rows

    return class_probs

In [8]:
'''
This function takes the class probabilities for every combination of a
question and one of its long answer candidates.
It assigns an answer to each question, converts it to the required
prediction format, and finally dumps the result to a JSON file.
'''

def predict(class_probs, ids, exid_candid_sorted, data_dict, output_path='predictions.json'):
    """Pick the best long-answer candidate per example and write the predictions file.

    Args:
        class_probs: 2-D array indexed as ``class_probs[i, 1]``; row ``i`` holds
            the scores of ``exid_candid_sorted[i]`` and column 1 is read as the
            'LONG ANSWER' score (column 0 => 'NO ANSWER').
        ids: example ids, in the order the predictions should be written.
        exid_candid_sorted: sequence of ``(example id, candidate index)`` pairs.
        data_dict: mapping from example id to a dict whose
            'long_answer_candidates' entries carry 'start_token' / 'end_token'.
        output_path: where the JSON is written; defaults to the previously
            hard-coded 'predictions.json'.

    Returns:
        The dict that was serialised, ``{'predictions': [...]}``.

    Raises:
        KeyError: if a pair refers to an example id that is not in ``ids``.
    """
    # Best candidate seen so far per example. Examples without any candidate
    # keep the -1 placeholders and a score of -1.0.
    best = {doc_id: {'score': -1.0, 'start_token': -1, 'end_token': -1} for doc_id in ids}

    # The candidate with the highest long-answer score wins; on ties the
    # earliest one in exid_candid_sorted is kept (strict comparison).
    for i, (doc_id, candidate_idx) in enumerate(tqdm(exid_candid_sorted)):
        long_ans_score = class_probs[i, 1]
        entry = best[doc_id]
        if long_ans_score > entry['score']:
            candidate = data_dict[doc_id]['long_answer_candidates'][candidate_idx]
            entry['score'] = long_ans_score
            entry['start_token'] = candidate['start_token']
            entry['end_token'] = candidate['end_token']

    # Assemble the expected predictions format. Short answers and yes/no
    # answers are not predicted, so they are always emitted as placeholders.
    predictions = []
    for doc_id in ids:
        entry = best[doc_id]
        predictions.append({
            'example_id': doc_id,
            'long_answer': {'start_byte': -1, 'end_byte': -1, 'start_token': entry['start_token'], 'end_token': entry['end_token']},
            'long_answer_score': str(entry['score']),
            'short_answers': [{'start_byte': -1, 'end_byte': -1, 'start_token': -1, 'end_token': -1}],
            'short_answers_score': str(-1.0),
            'yes_no_answer': 'NONE'
        })
    final_dict = {'predictions': predictions}

    # Dump to JSON file
    with open(output_path, 'w') as f:
        json.dump(final_dict, f)

    return final_dict
In [9]:
def main() -> None:
    """Run long-answer inference end to end on the dev sample.

    Reads the JSONL input, enumerates every (example, long-answer candidate)
    pair, orders the pairs by an estimated length, loads the tokenizer and
    model from ``models/``, scores all pairs with ``val`` and hands the
    scores to ``predict`` to produce the predictions file.
    """
    # Create structures to hold inputs
    input_file = 'data/dev/simplified-dev-sample.no-annot.jsonl'
    
    exid_candid = [] # list of (example id, candidate index) tuples, one per candidate of every example
    candidate_lens = [] # estimated length of each pair, parallel to exid_candid
    exid_candid2candlen = {} # (example id, candidate index) -> estimated length. NOTE(review): filled below but never read in this function
    ids = [] # example ids in file order
    data_dict = {} # per-example fields needed later: document text, question text, candidates


    # One JSON object per line; `n` (the line counter) is not used.
    with open(input_file) as f:
      for n, line in tqdm(enumerate(f)):
        data = json.loads(line)
        data_id = data['example_id']
        ids.append(data_id)

        data_dict[data_id] = {'document_text': data['document_text'],
                              'question_text': data['question_text'],
                              'long_answer_candidates': data['long_answer_candidates']}

        # Whitespace word count of the question, not its wordpiece count.
        question_len = len(data['question_text'].split())

        for i, candidate in enumerate(data['long_answer_candidates']):
          exid_candid.append((data_id, i))
          # Length estimate = question words + candidate span in document tokens.
          candidate_len = question_len + candidate['end_token'] - candidate['start_token']
          candidate_lens.append(candidate_len)
          exid_candid2candlen[(data_id, i)] = candidate_len
            
    # Sorting the list of (example id, candidate id) by candidate length for faster inference:
    # pairs of similar length share a batch, so each batch needs less padding.
    sorting_idx = np.argsort(np.array(candidate_lens))

    exid_candid_sorted = []
    for idx in sorting_idx:
      exid_candid_sorted.append(exid_candid[idx])
    
    # Hyperparameters
    max_seq_len = 360
    max_question_len = 64
    batch_size = 10
    
    # HTML tags mapped to the placeholder strings that stand in for them;
    # the placeholder values (not the tags) are what gets added to the vocab below.
    new_tokens = {'<P>':'qw1',
                  '<Table>':'qw2',
                  '<Tr>':'qw3',
                  '<Ul>':'qw4',
                  '<Ol>':'qw5',
                  '<Fl>':'qw6',
                  '<Li>':'qw7',
                  '<Dd>':'qw8',
                  '<Dt>':'qw9'}
    
    # Instantiating model
    model_path = "models/"
    config_file = BertConfig.from_pretrained(model_path)
    config_file.num_labels = 2       # 2 labels for 'long answer' and 'no answer'
    config_file.vocab_size = 30531   # 30522 + 9 HTML tokens
    tokenizer = BertTokenizer.from_pretrained(model_path, do_lower_case=True)
    tokenizer.add_tokens(list(new_tokens.values()))

    model = BertForQuestionAnswering.from_pretrained(model_path, config=config_file)
    # Make the embedding table match the tokenizer after the placeholders were added.
    model.resize_token_embeddings(len(tokenizer))
    
    val_dataset = NQDataset(exid_candid_sorted)
    val_collate = Collator(data_dict=data_dict, 
                            new_token_dict=new_tokens,
                            tokenizer=tokenizer,
                            max_seq_len=max_seq_len,
                            max_ques_len=max_question_len)
    # shuffle=False keeps batches aligned with exid_candid_sorted, which val and predict index by position.
    val_dataloader = DataLoader(dataset=val_dataset,
                                 collate_fn=val_collate,
                                 batch_size=batch_size,
                                 shuffle=False,
                                 num_workers=8,
                                 pin_memory=True)
    
    class_probs = val(model.to(DEVICE), val_dataloader, exid_candid_sorted, max_seq_len, config_file.num_labels, batch_size)
    
    predict(class_probs=class_probs, ids=ids, exid_candid_sorted=exid_candid_sorted, data_dict=data_dict)

In [10]:
# Run the whole pipeline: read the dev sample, score every candidate, write predictions.json.
main()

200it [00:00, 2254.92it/s]
2742it [08:36,  5.31it/s]
27419it [00:00, 1105102.74it/s]
