## Load Tokenizer & SQuADv2 and preprocess the dataset

In [None]:
#importing required libraries and modules
import wandb
from datasets import load_dataset
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer
import numpy as np
import torch
import torch.nn as nn
from torchcrf import CRF
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from collections import defaultdict

#set seed for reproducibility and wandb login
torch.manual_seed(seed=42)
wandb.login()

#load the dataset and tokenizer
dataset = load_dataset('rajpurkar/squad_v2')
dataset['train'] = dataset['train'].shuffle(seed=42).select(range(60000)) #select 60k samples for finetuning
tokenizer = BertTokenizerFast.from_pretrained('spanbert/spanbert-large-cased')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhimanshu22216[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
def preprocess_dataset(examples):
    """
    This function preprocesses the dataset by tokenizing the context and question,
    and finding the start and end positions of the answer in the context.
    The code below is heavily inspired by the Hugging Face NLP course for Question Answering.
    link to the course: https://huggingface.co/learn/nlp-course/en/chapter7/7
    """
    questions = [q.strip() for q in examples['question']] #remove leading and trailing whitespaces
    inputs = tokenizer(questions, examples['context'], max_length=512,
                truncation='only_second', stride=128, return_overflowing_tokens=True,
                return_offsets_mapping=True, padding='max_length')

    offset_mapping = inputs.pop('offset_mapping') #mapping between tokens and original text positions
    sample_map = inputs.pop('overflow_to_sample_mapping') #mapping of tokenized examples with original examples

    #map answers to the tokenized context
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i] #index in dataset
        answer = examples['answers'][sample_idx] #true answer for the question (ground truth)
        start_char = answer['answer_start'][0] if answer['answer_start'] else 0
        end_char = start_char + len(answer['text'][0]) if answer['text'] else 0

        sequence_ids = inputs.sequence_ids(i)

        #find the start and end of the context
        idx = 0
        while idx < len(sequence_ids) and sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        #if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)

        #else find the token indices that correspond to the start and end of the answer
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs['start_positions'] = torch.tensor(start_positions)
    inputs['end_positions'] = torch.tensor(end_positions)
    return inputs


def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100 # Convert to percentage


def compute_em(preds):
    """
    This function is used to post-process the predictions and compute the exact match score for SpanBERT model.
    This function converts the start and end positions of the answer to a string
    (instead of transforming to actual text from the context) and then computes the exact match score.
    This should work as good as the variation where we convert the token indices to text
    because its just a string comparison so either they are same or not.
    """
    label = preds.label_ids #actual start and end positions of the answer
    pred = np.argmax(preds.predictions, axis=-1) #predicted start and end positions of the answer
    predictions, references = [], []
    for i in range(len(label[0])):
        predictions.append(str(pred[0][i]) + ',' + str(pred[1][i]))
        references.append(str(label[0][i]) + ',' + str(label[1][i]))
    return exact_match_score(predictions, references)

In [None]:
#preprocess the dataset
tokenized_dataset = dataset.map(preprocess_dataset, batched=True, remove_columns=dataset["train"].column_names)
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 60217
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'],
        num_rows: 11985
    })
})

## Fine-tuning spanbert

In [None]:
#load model and initialize wandb run
spanbert = BertForQuestionAnswering.from_pretrained('SpanBERT/spanbert-large-cased')

wandb.init(entity='cv-himanshu', project='finetuning-spanberts')

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
#define training arguments and train with Trainer API
training_args = TrainingArguments(output_dir='./results', eval_strategy='epoch',
                learning_rate=1e-5, num_train_epochs=6, weight_decay=0.01,
                per_device_train_batch_size=64, per_device_eval_batch_size=64,
                report_to='wandb', logging_dir='./logs', logging_steps=50,
                label_names=['start_positions', 'end_positions'],)

trainer = Trainer(model=spanbert, args=training_args, processing_class=tokenizer,
    train_dataset=tokenized_dataset['train'], eval_dataset=tokenized_dataset['validation'])

trainer.train()



Epoch,Training Loss,Validation Loss
1,1.428,1.221892
2,1.0008,0.946458
3,0.7999,0.879596
4,0.6721,0.894592
5,0.5628,0.910687
6,0.5348,0.928524




TrainOutput(global_step=2826, training_loss=0.993496185486315, metrics={'train_runtime': 6198.6139, 'train_samples_per_second': 58.288, 'train_steps_per_second': 0.456, 'total_flos': 3.3554369366983066e+17, 'train_loss': 0.993496185486315, 'epoch': 6.0})

In [6]:
#predict on validation set and compute the exact match score
preds = trainer.predict(tokenized_dataset['validation'])

spanbert_em = compute_em(preds)
print(spanbert_em)
print(f"SpanBERT Exact Match Score: {spanbert_em:.2f} %")

wandb.finish()



70.27951606174385
SpanBERT Exact Match Score: 70.28 %


0,1
eval/loss,█▂▁▁▂▂
eval/runtime,▁▃▁███
eval/samples_per_second,█▆█▁▁▁
eval/steps_per_second,█▇█▂▁▁
test/loss,▁
test/runtime,▁
test/samples_per_second,▁
test/steps_per_second,▁
train/epoch,▁▁▁▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇████

0,1
eval/loss,0.92852
eval/runtime,70.7082
eval/samples_per_second,169.499
eval/steps_per_second,1.329
test/loss,0.92852
test/runtime,70.8578
test/samples_per_second,169.142
test/steps_per_second,1.327
total_flos,3.3554369366983066e+17
train/epoch,6.0


## Fine-tuning SpanBERT_CRF

In [None]:
class SpanBERTCRF(BertForQuestionAnswering):
    """
    This is a custom model class that extends the SpanBERT model and adds a CRF layer on top of it.
    The CRF layer is used to predict the start and end positions of the answer in the context.
    It also has some additional hidden layers to learn better representations.
    This code is heavily inspired by the BertForQuestionAnswering class in the Hugging Face Transformers library.
    link to the code: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bert/modeling_bert.py 
    """
    def __init__(self, config):
        """
        Initialize the model with SpanBERT and additional hidden layers.
        """
        super().__init__(config)
        # Additional hidden layers
        self.fc1 = nn.Linear(config.hidden_size, config.hidden_size // 2)
        self.fc2 = nn.Linear(config.hidden_size // 2, config.hidden_size // 4)
        
        self.start_classifier = nn.Linear(config.hidden_size // 4, self.num_labels)
        self.end_classifier = nn.Linear(config.hidden_size // 4, self.num_labels)
        
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.start_crf = CRF(num_tags=config.num_labels, batch_first=True)  # CRF for start positions
        self.end_crf = CRF(num_tags=config.num_labels, batch_first=True)  # CRF for end positions

    def forward( self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        """
        Defines the forward pass for the model.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        #get the outputs from SpanBERT model
        outputs = self.bert( input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        #get the output of last hidden layer from SpanBERT
        sequence_output = outputs[0]  # Shape: (batch_size, sequence_length, hidden_size)

        #pass the output through additional hidden layers
        hidden = self.fc1(sequence_output)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.fc2(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)

        #get the logits for start and end positions
        start_logits = self.start_classifier(hidden)  # (batch_size, seq_length, num_labels)
        end_logits = self.end_classifier(hidden)  # (batch_size, seq_length, num_labels)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            batch_size, seq_length = start_logits.shape[:2]

            # Create label tensors for CRF
            start_labels = torch.zeros((batch_size, seq_length), dtype=torch.long, device=input_ids.device)
            end_labels = torch.zeros((batch_size, seq_length), dtype=torch.long, device=input_ids.device)

            for i in range(batch_size):
                start, end = start_positions[i].item(), end_positions[i].item()
                if 0 <= start < seq_length:
                    start_labels[i, start] = 1 # Mark the correct start position
                if 0 <= end < seq_length:
                    end_labels[i, end] = 1 # Mark the correct end position

            # Compute CRF loss separately for start and end logits
            start_loss = -self.start_crf(start_logits, start_labels, mask=attention_mask.bool(), reduction="mean")
            end_loss = -self.end_crf(end_logits, end_labels, mask=attention_mask.bool(), reduction="mean")
            total_loss = (start_loss + end_loss) / 2

        #perform crf decoding to get the start and end positions
        start_predictions = self.start_crf.decode(start_logits, mask=attention_mask.bool())
        end_predictions = self.end_crf.decode(end_logits, mask=attention_mask.bool())

        #convert the predictions to one-hot encoded tensors
        bs, sq = start_logits.shape[:2]
        start_preds = torch.zeros((bs, sq), device=input_ids.device)
        end_preds = torch.zeros((bs, sq), device=input_ids.device)

        for i, (ss, es) in enumerate(zip(start_predictions, end_predictions)):
            temp = ss.index(1) if 1 in ss else 0
            start_preds[i, temp] = 1
            temp = es.index(1) if 1 in es else 0
            end_preds[i, temp] = 1

        #return the total loss and start and end predictions
        return (total_loss, start_preds, end_preds) if total_loss is not None else (start_preds, end_preds)


def preprocess_evalset(examples):
    """
    This function is used to preprocess the evaluation dataset by tokenizing the context and question.
    """
    inputs = tokenizer(examples['question'], examples['context'], max_length=512,
                truncation='only_second', stride=128, return_overflowing_tokens=True,
                return_offsets_mapping=True, padding='max_length')

    sample_map = inputs.pop('overflow_to_sample_mapping', list(range((len(inputs['input_ids'])))))
    inputs['example_id'] = [examples['id'][i] for i in sample_map]
    return inputs


def post_process_to_compute_em(features, preds):
    """
    This function is used to post-process the predictions and compute the exact match score for SpanBERTCRF model.
    """
    start_logits, end_logits = preds.predictions #predicted start and end positions of the answer
    label = preds.label_ids #actual start and end positions of the answer
    predictions = {}
    references = {}
    id_array = []
    example_to_features = defaultdict(list)
    #mapping the features to the example ids, also storing the references
    for i, feature in enumerate(features):
        example_to_features[feature['example_id']].append((i, feature))
        references[feature['example_id']] = str(label[0][i]) + ',' + str(label[1][i])
        id_array.append(feature['example_id'])

    #iterate over the examples and find the best start and end positions
    for example_id, feature_list in example_to_features.items():
        best_score = -float("inf")
        best_answer = ""
        #iterate over the features and find the best start and end positions
        for idx, features in feature_list:
            offset = features['offset_mapping']
            start_logit = start_logits[idx]
            end_logit = end_logits[idx]
            #iterate over the start and end positions and find the best answer
            for start_index in range((len(start_logit))):
                for end_index in range(start_index, len(end_logit)):
                    if offset[start_index] == (0, 0) or offset[end_index] == (0, 0):
                        #unanswerable question so skip
                        continue
                    #compute the score for the answer
                    score = start_logit[start_index] + end_logit[end_index]
                    if score > best_score:
                        best_score = score
                        best_answer = f'{start_index},{end_index}'
        predictions[example_id] = best_answer

    #create reference and prediction lists using the dictionaries
    final_preds, final_refs = [], []
    for id in id_array:
        final_preds.append(predictions[id])
        final_refs.append(references[id])
    return exact_match_score(final_preds, final_refs)

In [None]:
#preprocess the evaluation set, load the model and initialize wandb run
final_val_set = dataset['validation'].map(preprocess_evalset, batched=True, remove_columns=dataset["validation"].column_names)
final_val_set

spanbert_crf = SpanBERTCRF.from_pretrained('SpanBERT/spanbert-large-cased')

wandb.init(entity='cv-himanshu', project='finetuning-spanberts')

Some weights of SpanBERTCRF were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['end_classifier.bias', 'end_classifier.weight', 'end_crf.end_transitions', 'end_crf.start_transitions', 'end_crf.transitions', 'fc1.bias', 'fc1.weight', 'fc2.bias', 'fc2.weight', 'qa_outputs.bias', 'qa_outputs.weight', 'start_classifier.bias', 'start_classifier.weight', 'start_crf.end_transitions', 'start_crf.start_transitions', 'start_crf.transitions']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
#define training arguments and train with Trainer API
training_args_crf = TrainingArguments(output_dir='./results', eval_strategy='epoch',
                learning_rate=1e-5, num_train_epochs=6, weight_decay=0.01,
                per_device_train_batch_size=64, per_device_eval_batch_size=64,
                report_to='wandb', logging_dir='./logs', logging_steps=50,
                label_names=['start_positions', 'end_positions'],)

trainer_crf = Trainer(model=spanbert_crf, args=training_args_crf, processing_class=tokenizer, 
    train_dataset=tokenized_dataset['train'], eval_dataset=tokenized_dataset['validation'])

trainer_crf.train()



Epoch,Training Loss,Validation Loss
1,6.3761,6.26981
2,2.614,2.195243
3,1.857,1.657076
4,1.4016,1.471706
5,1.2337,1.386216
6,1.1536,1.426081




TrainOutput(global_step=2826, training_loss=8.340884785743276, metrics={'train_runtime': 8941.6355, 'train_samples_per_second': 40.407, 'train_steps_per_second': 0.316, 'total_flos': 3.362731018478346e+17, 'train_loss': 8.340884785743276, 'epoch': 6.0})

In [None]:
#predict on validation set and compute the exact match score
preds = trainer_crf.predict(tokenized_dataset['validation'])

spanbert_crf_em = post_process_to_compute_em(final_val_set, preds)
print(spanbert_crf_em)
print(f"SpanBERT-CRF Exact Match Score: {spanbert_crf_em:.2f} %")

wandb.finish()

66.03254067584481
SpanBERT-CRF Exact Match Score: 66.03 %


0,1
eval/loss,█▂▁▁▁▁
eval/runtime,▁▁▂▃▆█
eval/samples_per_second,██▇▆▃▁
eval/steps_per_second,███▆▃▁
test/loss,▁
test/runtime,▁
test/samples_per_second,▁
test/steps_per_second,▁
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▇▇▇▇▇█████
train/global_step,▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█████

0,1
eval/loss,1.42608
eval/runtime,128.6654
eval/samples_per_second,93.149
eval/steps_per_second,0.731
test/loss,1.42608
test/runtime,128.0138
test/samples_per_second,93.623
test/steps_per_second,0.734
total_flos,3.362731018478346e+17
train/epoch,6.0
