# BigBirdModel

Tokeniser: BigBirdTokenizer.from_pretrained('google/bigbird-roberta-base')

Model: BigBirdForQuestionAnswering

Pre-trained model: [google/bigbird-roberta-base](https://huggingface.co/google/bigbird-roberta-base)

Reasoning: 
BigBird is a sparse-attention-based transformer that extends Transformer-based models such as BERT to much longer sequences. Since the contexts here are very long sequences (more than 4096 in length), the BigBird model, which relies on block sparse attention, might be more efficient and feasible than BERT.

In [2]:
import json

### Data Prep

In [3]:
from collections.abc import Iterable

# Helper that flattens arbitrarily nested context lists so the pieces can be joined into a single string

def flatten(xs):
    """Lazily yield the leaf items of an arbitrarily nested iterable.

    Strings and bytes are treated as leaves (they are not split into
    characters); every other iterable is descended into recursively.
    """
    for item in xs:
        is_leaf = isinstance(item, (str, bytes)) or not isinstance(item, Iterable)
        if is_leaf:
            yield item
            continue
        # Nested container: emit its leaves in order before moving on.
        for leaf in flatten(item):
            yield leaf

In [4]:
# To convert data into 3 lists: contexts, questions and answers respectively

def read_data(path):
    """Load a QA dataset from a JSON file into three parallel lists.

    The file must contain a list of records, each with the keys
    'context' (a possibly nested list of strings), 'question' and 'answer'.

    Args:
        path: Path of the JSON file to read.

    Returns:
        A tuple (contexts, questions, answers). Each context is the
        concatenation of all strings found in the record's 'context' field.

    Raises:
        KeyError: If a record lacks one of the expected keys.
    """
    # JSON text is UTF-8 by specification; being explicit avoids decoding
    # errors on platforms whose default encoding is something else.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    contexts = []
    questions = []
    answers = []

    for group in data:
        # Collapse the nested context structure into one flat string.
        contexts.append(''.join(flatten(group['context'])))
        questions.append(group['question'])
        answers.append(group['answer'])

    return contexts, questions, answers

In [5]:
# Load the training and development splits as parallel lists of contexts, questions and answers
train_contexts, train_questions, train_answers = read_data('train_set.json')
val_contexts, val_questions, val_answers = read_data('dev_set.json')

In [6]:
# Training BigBird for QA needs answer_start (start index of the answer in the context) and answer_end (end index of the answer in the context).
# Return a list of dictionaries of the form {'text': answer_string, 'answer_start': start_idx, 'answer_end': end_idx}

def update_train_answers(answers, contexts):
    """Locate every answer inside its context and record the character span.

    Args:
        answers: List of answer strings.
        contexts: List of context strings, parallel to `answers`.

    Returns:
        A list with one dict per (answer, context) pair, of the form
        {'text': answer, 'answer_start': start, 'answer_end': end}, where
        `start` is the index of the first occurrence of the answer in the
        context and `end` is the exclusive end index (start + len(answer)).
        Answers that do not occur in the context (e.g. yes/no answers) get
        both indices set to len(context) + 1, i.e. a position outside the
        sequence.
    """
    located = []
    for answer, context in zip(answers, contexts):
        start_idx = context.find(answer)
        if start_idx == -1:
            # There are some yes/no answers not found in the context, hence
            # assign them to a position outside the sequence.
            # Based on documentation: "Position outside of the sequence are
            # not taken into account for computing the loss."
            outside = len(context) + 1
            located.append({'text': answer, 'answer_start': outside, 'answer_end': outside})
            # Do not fall through: the span logic below is only meaningful
            # for answers that were actually found.
            continue
        # str.find returns the exact start of the match, so the span always
        # reproduces the answer text and no off-by-n correction is needed.
        located.append({'text': answer,
                        'answer_start': start_idx,
                        'answer_end': start_idx + len(answer)})
    return located

In [7]:
# Attach character-level start/end indices to every training and validation answer
train_answers_n = update_train_answers(train_answers,train_contexts)
val_answers_n = update_train_answers(val_answers,val_contexts)

### Tokenize Inputs

In [8]:
from transformers import BigBirdTokenizer

# Tokenise input using the BigBird tokeniser from pretrained 'google/bigbird-roberta-base'.
# The fast (Rust-backed) tokeniser is required because the character-to-token
# alignment used later (BatchEncoding.char_to_token) is not available for the
# pure-Python tokeniser.
from transformers import BigBirdTokenizerFast

tokenizer = BigBirdTokenizerFast.from_pretrained('google/bigbird-roberta-base')

Downloading:   0%|          | 0.00/846k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/775 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/760 [00:00<?, ?B/s]

In [None]:
# Encode every (context, question) pair, truncating/padding to a common length.
# The answers are NOT tokenizer input: a third positional argument would be
# interpreted as another tokenizer parameter, so they are attached later as
# start/end token positions instead.
# NOTE(review): BigBirdForQuestionAnswering appears to assume a question-first
# layout when it builds its default question mask; confirm that encoding the
# context first is intended before training.
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

In [None]:
# Add the start and end position tokens which is required for the BigBirdForQuestionAnswering model
def add_token_answers(encodings, answers):
    """Add token-level answer positions to `encodings` in place.

    Adds the 'start_positions' and 'end_positions' entries required to train
    BigBirdForQuestionAnswering.

    Args:
        encodings: Tokenizer output exposing `char_to_token(batch_index,
            char_index)` and `update(dict)`.
        answers: List of dicts with integer 'answer_start' (inclusive) and
            'answer_end' (exclusive) character indices, one per encoded
            example.

    Returns:
        None. `encodings` is updated with two lists of token indices. The end
        position is the token holding the last answer character that survived
        tokenisation. Answers whose start cannot be mapped to a token
        (truncated away or absent from the context) get
        `tokenizer.model_max_length` for BOTH positions, so the whole example
        is treated as outside the sequence.
    """
    ignored_position = tokenizer.model_max_length
    start_positions = []
    end_positions = []
    for i, answer in enumerate(answers):
        start_char = answer['answer_start']
        end_char = answer['answer_end']

        start_token = encodings.char_to_token(i, start_char)
        if start_token is None:
            # The answer passage has been truncated (or was never in the
            # context); mark start and end consistently as ignored.
            start_positions.append(ignored_position)
            end_positions.append(ignored_position)
            continue

        # 'answer_end' is exclusive, so the last answer character is at
        # end_char - 1. Walk backwards (bounded by the answer start) past
        # characters that map to no token, e.g. whitespace or truncated text.
        end_token = None
        for char_index in range(end_char - 1, start_char - 1, -1):
            end_token = encodings.char_to_token(i, char_index)
            if end_token is not None:
                break
        if end_token is None:
            # Empty answer span: collapse onto the start token.
            end_token = start_token

        start_positions.append(start_token)
        end_positions.append(end_token)

    encodings.update({
        'start_positions': start_positions,
        'end_positions': end_positions
    })

In [None]:
# Adds 'start_positions' / 'end_positions' to the training encodings in place
add_token_answers(train_encodings,train_answers_n)

In [None]:
# Adds 'start_positions' / 'end_positions' to the validation encodings in place
add_token_answers(val_encodings,val_answers_n)

In [None]:
# Sanity check: list the fields now present in the training encodings
print(train_encodings.keys())

In [None]:
import torch


class NLDataset(torch.utils.data.Dataset):
    """Dataset view over tokenizer encodings.

    Indexing returns a dict with one tensor per encoding field for the
    requested example; the length is the number of encoded input sequences.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings.input_ids)

    def __getitem__(self, idx):
        example = {}
        for field, column in self.encodings.items():
            # Pick this example's row from each field and tensorise it.
            example[field] = torch.tensor(column[idx])
        return example

# Wrap the training and validation encodings so they can be fed to a DataLoader
train_dataset = NLDataset(train_encodings)
val_dataset = NLDataset(val_encodings)

### Fine-tuning

In [25]:
from transformers import BigBirdForQuestionAnswering

# Initialise BigBirdForQuestionAnswering from the pre-trained 'google/bigbird-roberta-base' checkpoint
model = BigBirdForQuestionAnswering.from_pretrained('google/bigbird-roberta-base')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_layer_norm.weight', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_layer_norm.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this mode

In [26]:
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm

In [27]:
# setup GPU/CPU: use CUDA when it is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# move model over to detected device
model.to(device)

DistilBertForQuestionAnswering(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0): TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
            

In [None]:
# Initialise the AdamW optimizer (Adam with decoupled weight decay, which
# reduces the chance of overfitting). The PyTorch implementation is used
# because the AdamW class shipped with transformers is deprecated and applies
# no weight decay unless one is passed explicitly.
optim = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=0.01)

# initialize data loader for training data
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

for epoch in range(3):
    # set model to train mode (enables dropout etc.) at the start of every epoch
    model.train()
    # setup loop (we use tqdm for the progress bar)
    loop = tqdm(train_loader, leave=True)
    # the description is constant for the whole epoch, so set it once
    loop.set_description(f'Epoch {epoch}')

    for batch in loop:
        # reset gradients accumulated during the previous step
        optim.zero_grad()
        # pull all the tensor batches required for training
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # run the model on the batch; passing the positions makes it return the loss
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # extract loss
        loss = outputs.loss
        # backpropagate: compute the gradient of the loss for every trainable parameter
        loss.backward()
        # update parameters
        optim.step()
        # show the current loss in the progress bar
        loop.set_postfix(loss=loss.item())

Epoch 0:   0%|          | 12/5653 [02:30<18:23:10, 11.73s/it, loss=5.38]

In [None]:
# Persist the fine-tuned model together with its tokenizer so both can be reloaded from one directory
model_path = 'models/bigbird-custom'
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)

In [None]:
# switch model out of training mode
model.eval()

val_loader = DataLoader(val_dataset, batch_size=16)

# per-batch exact-match accuracies for the start and the end positions
acc = []

# Token-overlap counters. They must accumulate over the WHOLE validation set,
# so they are initialised once here rather than inside the batch loop.
total_tp = 0
total_fp = 0
total_fn = 0

# initialize loop for progress bar
loop = tqdm(val_loader)
# loop through batches
for batch in loop:
    # we don't need to calculate gradients as we're not training
    with torch.no_grad():
        # pull batched items from loader
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_true = batch['start_positions'].to(device)
        end_true = batch['end_positions'].to(device)
        # make predictions
        outputs = model(input_ids, attention_mask=attention_mask)
        # pull preds out
        start_pred = torch.argmax(outputs['start_logits'], dim=1)
        end_pred = torch.argmax(outputs['end_logits'], dim=1)

    # Calculate ACCURACY
    # calculate accuracy for both and append to accuracy list
    acc.append(((start_pred == start_true).sum()/len(start_pred)).item())
    acc.append(((end_pred == end_true).sum()/len(end_pred)).item())

    # Calculate token-overlap counts for precision / recall / F1
    spans = zip(start_pred.tolist(), end_pred.tolist(),
                start_true.tolist(), end_true.tolist())
    for pred_start, pred_end, true_start, true_end in spans:
        # End positions are token indices that belong to the span, so the
        # ranges are inclusive; a predicted end before the start is empty.
        pred_tokens = set(range(pred_start, pred_end + 1))
        true_tokens = set(range(true_start, true_end + 1))
        total_tp += len(pred_tokens & true_tokens)
        total_fp += len(pred_tokens - true_tokens)
        total_fn += len(true_tokens - pred_tokens)

# calculate average accuracy in total (guarding against an empty loader)
acc = sum(acc)/len(acc) if acc else 0.0
# guard the ratios against empty denominators
precision = total_tp/(total_tp+total_fp) if (total_tp+total_fp) else 0.0
recall = total_tp/(total_tp+total_fn) if (total_tp+total_fn) else 0.0
f1 = 2*precision*recall/(precision+recall) if (precision+recall) else 0.0

In [None]:
# Show true vs. predicted positions side by side.
# Note: these variables hold only the LAST batch processed by the evaluation loop.
print("T/F\tstart\tend\n")
for i in range(len(start_true)):
    print(f"true\t{start_true[i]}\t{end_true[i]}\n"
          f"pred\t{start_pred[i]}\t{end_pred[i]}\n")

In [None]:
# Report the validation metrics in the order: accuracy, recall, precision
print(acc)
print(recall)
print(precision)

### Load Pre-trained model 

In [None]:
# Reload the fine-tuned model from the directory it was saved to earlier
model = BigBirdForQuestionAnswering.from_pretrained('models/bigbird-custom')

### Results
Accuracy: 