## Load the SpanBERT tokenizer and SQuAD v2, then preprocess the dataset

In [None]:
import wandb
from datasets import load_dataset
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer
import numpy as np
import torch
import torch.nn as nn
from torchcrf import CRF
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from collections import defaultdict

# Seed torch's RNG for reproducibility (the dataset shuffle below is seeded separately).
torch.manual_seed(seed=42)
wandb.login()

dataset = load_dataset('rajpurkar/squad_v2')
# Fine-tune on a fixed, seeded 48k-example subset of the training split.
dataset['train'] = dataset['train'].shuffle(seed=42).select(range(48000))
# Use exactly the same hub id (including casing) as the model checkpoints loaded
# later in this notebook, so tokenizer and weights come from one repository/cache entry.
tokenizer = BertTokenizerFast.from_pretrained('SpanBERT/spanbert-large-cased')

In [None]:
def preprocess_dataset(examples):
    """Tokenize a batch of SQuAD v2 examples into model windows with span labels.

    Long contexts are split into overlapping windows of at most 512 tokens
    (stride 128), so one example can yield several rows. Each row gets
    ``start_positions`` / ``end_positions``: the token indices of the answer when
    the whole answer lies inside that window's context, otherwise ``(0, 0)``.
    Unanswerable questions (empty ``answers``) are always labelled ``(0, 0)``.

    Args:
        examples: batch dict with ``question``, ``context`` and ``answers``
            columns, as passed by ``Dataset.map(batched=True)``.

    Returns:
        The tokenizer output (without offset/overflow mappings) plus the
        ``start_positions`` and ``end_positions`` label tensors.
    """
    questions = [q.strip() for q in examples['question']]  # remove leading/trailing whitespace
    inputs = tokenizer(questions, examples['context'], max_length=512,
                       truncation='only_second', stride=128, return_overflowing_tokens=True,
                       return_offsets_mapping=True, padding='max_length')

    offset_mapping = inputs.pop('offset_mapping')  # token -> (char_start, char_end) in its own text
    sample_map = inputs.pop('overflow_to_sample_mapping')  # window -> index of its source example

    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        answer = examples['answers'][sample_map[i]]  # ground-truth answer of the source example

        # Unanswerable questions have no answer text: label them (0, 0) directly.
        # Falling through with start_char = end_char = 0 would pass the
        # "answer inside context" test below whenever the window's context starts
        # at character 0 (typically the first window), mislabelling the first
        # context token as the answer.
        if not answer['text'] or not answer['answer_start']:
            start_positions.append(0)
            end_positions.append(0)
            continue

        start_char = answer['answer_start'][0]
        end_char = start_char + len(answer['text'][0])

        sequence_ids = inputs.sequence_ids(i)

        # Locate the context tokens (sequence id 1) within this window
        idx = 0
        while idx < len(sequence_ids) and sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        # Answer not fully inside this window's context -> label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)
            continue

        # Otherwise find the tokens covering the answer's first and last characters
        idx = context_start
        while idx <= context_end and offset[idx][0] <= start_char:
            idx += 1
        start_positions.append(idx - 1)

        idx = context_end
        while idx >= context_start and offset[idx][1] >= end_char:
            idx -= 1
        end_positions.append(idx + 1)

    inputs['start_positions'] = torch.tensor(start_positions)
    inputs['end_positions'] = torch.tensor(end_positions)
    return inputs


def exact_match_score(predictions, references):
    """Return the percentage (0-100) of positions where prediction == reference.

    Args:
        predictions: sequence of predicted answers (any ``==``-comparable values).
        references: sequence of gold answers aligned with ``predictions``.

    Returns:
        Exact-match accuracy as a float percentage; 0.0 for empty inputs.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    # Explicit check instead of ``assert``: assertions are stripped under ``python -O``.
    if len(predictions) != len(references):
        raise ValueError("Lists must have the same length")
    if not references:
        return 0.0  # avoid ZeroDivisionError on an empty evaluation set
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100  # convert to percentage


def compute_em(preds):
    """Exact-match percentage of predicted vs. gold (start, end) token positions.

    Compares window by window: each row's argmax start/end is rendered as a
    ``"start,end"`` string and matched against the gold label string.

    Args:
        preds: prediction output whose ``predictions`` holds start and end score
            arrays and whose ``label_ids`` holds the gold start and end positions.
    """
    gold_starts, gold_ends = preds.label_ids[0], preds.label_ids[1]
    best = np.argmax(preds.predictions, axis=-1)  # argmax over the token axis
    pred_starts, pred_ends = best[0], best[1]

    n_rows = len(gold_starts)
    predicted = [','.join((str(pred_starts[k]), str(pred_ends[k]))) for k in range(n_rows)]
    gold = [','.join((str(gold_starts[k]), str(gold_ends[k]))) for k in range(n_rows)]
    return exact_match_score(predicted, gold)

In [None]:
# Tokenize every split; the raw SQuAD columns are dropped so only model inputs and labels remain.
tokenized_dataset = dataset.map(
    preprocess_dataset,
    batched=True,
    remove_columns=dataset["train"].column_names,
)
tokenized_dataset

## Fine-tuning SpanBERT

In [None]:
# Baseline: SpanBERT-large with the stock BertForQuestionAnswering head.
spanbert = BertForQuestionAnswering.from_pretrained(
    'SpanBERT/spanbert-large-cased'
)

# Open a W&B run for this experiment (the Trainer is configured with report_to='wandb').
wandb.init(project='finetuning-spanberts', entity='cv-himanshu')

In [None]:
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    learning_rate=1e-5,
    num_train_epochs=6,
    weight_decay=0.01,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    report_to='wandb',
    logging_dir='./logs',
    logging_steps=50,
    # Both label columns are declared so prediction outputs carry them in label_ids
    label_names=['start_positions', 'end_positions'],
)

trainer = Trainer(
    model=spanbert,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
)

trainer.train()

In [None]:
# Score the baseline window-by-window on the tokenized validation split.
preds = trainer.predict(tokenized_dataset['validation'])
spanbert_em = compute_em(preds)

print(spanbert_em)
print(f"SpanBERT Exact Match Score: {spanbert_em:.2f} %")

# Close the baseline's W&B run before the CRF experiment starts a new one.
wandb.finish()

## Fine-tuning SpanBERT-CRF

In [None]:
class SpanBERTCRF(BertForQuestionAnswering):
    """BERT QA model whose start/end tokens are tagged by an MLP head and two CRFs.

    Each token receives ``config.num_labels`` tag scores from a small MLP
    (hidden -> hidden/2 -> hidden/4 -> num_labels) for the start tagger and
    another for the end tagger. Training marks the gold start (resp. end) token
    with tag 1 and every other token with tag 0 and minimises the CRF negative
    log-likelihood. At inference the Viterbi path is decoded and the first token
    tagged 1 is the predicted position (0 if none).

    NOTE(review): tag 1 must exist, i.e. ``config.num_labels >= 2`` — confirm
    when using a non-default config.
    """

    def __init__(self, config):
        super().__init__(config)
        # Additional hidden layers: hidden -> hidden/2 -> hidden/4
        self.fc1 = nn.Linear(config.hidden_size, config.hidden_size // 2)
        self.fc2 = nn.Linear(config.hidden_size // 2, config.hidden_size // 4)

        # Per-token tag scores for the start and end taggers
        self.start_classifier = nn.Linear(config.hidden_size // 4, self.num_labels)
        self.end_classifier = nn.Linear(config.hidden_size // 4, self.num_labels)

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.start_crf = CRF(num_tags=config.num_labels, batch_first=True)  # CRF for start positions
        self.end_crf = CRF(num_tags=config.num_labels, batch_first=True)  # CRF for end positions

    @staticmethod
    def _position_tags(positions, seq_length, device):
        """Return a (batch, seq_length) long tensor with tag 1 at each position < seq_length, else 0."""
        positions = positions.to(device=device, dtype=torch.long)
        tags = torch.zeros((positions.size(0), seq_length), dtype=torch.long, device=device)
        rows = torch.arange(positions.size(0), device=device)
        in_range = positions < seq_length  # positions clamped to seq_length get no tag
        tags[rows[in_range], positions[in_range]] = 1
        return tags

    @staticmethod
    def _first_tagged(tags):
        """Index of the first tag equal to 1 in a decoded CRF path, or 0 if there is none."""
        return tags.index(1) if 1 in tags else 0

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.

        Returns:
            `(loss, start_preds, end_preds)` when both position tensors are given, otherwise
            `(start_preds, end_preds)`. Each `*_preds` is a float tensor of shape
            `(batch_size, sequence_length)` that is one-hot at the first token its CRF tagged 1
            (position 0 if none). `return_dict` is forwarded to the encoder, but this method
            always returns a tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # (batch_size, sequence_length, hidden_size)

        hidden = self.dropout(self.relu(self.fc1(sequence_output)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))

        start_logits = self.start_classifier(hidden)  # (batch_size, seq_length, num_labels)
        end_logits = self.end_classifier(hidden)  # (batch_size, seq_length, num_labels)

        batch_size, seq_length = start_logits.shape[:2]
        # Take the device from the logits: input_ids is None when inputs_embeds is used.
        device = start_logits.device

        # The CRFs need an explicit mask; treat every token as valid when none is given
        # instead of crashing on ``None.bool()``.
        if attention_mask is None:
            crf_mask = torch.ones((batch_size, seq_length), dtype=torch.bool, device=device)
        else:
            crf_mask = attention_mask.bool()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, positions may carry an extra trailing dimension
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)
            # Positions outside the inputs are clamped to seq_length and then receive no tag-1 label
            start_positions = start_positions.clamp(0, seq_length)
            end_positions = end_positions.clamp(0, seq_length)

            start_labels = self._position_tags(start_positions, seq_length, device)
            end_labels = self._position_tags(end_positions, seq_length, device)

            # CRF forward returns the log-likelihood; negate it for a loss and average both taggers
            start_loss = -self.start_crf(start_logits, start_labels, mask=crf_mask, reduction="mean")
            end_loss = -self.end_crf(end_logits, end_labels, mask=crf_mask, reduction="mean")
            total_loss = (start_loss + end_loss) / 2

        start_paths = self.start_crf.decode(start_logits, mask=crf_mask)
        end_paths = self.end_crf.decode(end_logits, mask=crf_mask)

        start_preds = torch.zeros((batch_size, seq_length), device=device)
        end_preds = torch.zeros((batch_size, seq_length), device=device)
        for i, (start_tags, end_tags) in enumerate(zip(start_paths, end_paths)):
            start_preds[i, self._first_tagged(start_tags)] = 1
            end_preds[i, self._first_tagged(end_tags)] = 1

        return (total_loss, start_preds, end_preds) if total_loss is not None else (start_preds, end_preds)


def preprocess_evalset(examples):
    """Tokenize validation examples into windows tagged with their source example id.

    Uses the same question cleaning and tokenizer settings as
    ``preprocess_dataset``, so the rows produced here line up one-to-one, in
    order, with the windows the model is evaluated on. ``offset_mapping`` is
    kept so token positions can be mapped back to context characters.

    Args:
        examples: batch dict with ``id``, ``question`` and ``context`` columns.

    Returns:
        The tokenizer output plus an ``example_id`` column (one id per window).
    """
    # Strip questions exactly as preprocess_dataset does: post-processing pairs
    # these rows with model predictions by position, so both tokenizations must agree.
    questions = [q.strip() for q in examples['question']]
    inputs = tokenizer(questions, examples['context'], max_length=512,
                       truncation='only_second', stride=128, return_overflowing_tokens=True,
                       return_offsets_mapping=True, padding='max_length')

    # Window -> index of the example it came from (identity if nothing overflowed)
    sample_map = inputs.pop('overflow_to_sample_mapping', list(range(len(inputs['input_ids']))))
    inputs['example_id'] = [examples['id'][i] for i in sample_map]
    return inputs


def post_process_to_compute_em(features, preds):
    """Collapse per-window start/end scores into one answer span string per example.

    Args:
        features: rows produced by ``preprocess_evalset`` (one per tokenized window),
            each with ``example_id``, ``offset_mapping`` and, when the tokenizer emits
            them, ``token_type_ids``.
        preds: output of ``Trainer.predict`` on the *same* windows in the same order;
            ``preds.predictions`` is ``(start_scores, end_scores)`` and
            ``preds.label_ids`` is ``(start_positions, end_positions)``.

    Spans are reported as ``"char_start,char_end"`` offsets into the context, so
    predictions and references taken from different (overlapping) windows of the
    same example are directly comparable. ``"0,0"`` means "no answer".

    Prediction: among context tokens only (non-special offsets, and segment 1 when
    ``token_type_ids`` is present), pick the (start <= end) pair with the highest
    ``start + end`` score over all windows (ties: earliest window, then earliest
    start, then earliest end). Predict "no answer" if the smallest null score
    (score at index 0 for start plus end) over the windows is strictly higher.

    Returns:
        ``(predictions, references)``: two dicts mapping example id -> span string.
    """
    no_answer = '0,0'
    start_scores, end_scores = preds.predictions
    start_labels, end_labels = preds.label_ids

    example_to_features = defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature['example_id']].append((idx, feature))

    predictions = {}
    references = {}
    for example_id, feature_list in example_to_features.items():
        best_score = -np.inf
        best_span = no_answer
        min_null_score = np.inf
        reference = no_answer

        for idx, feature in feature_list:
            offsets = np.asarray(feature['offset_mapping'])  # (seq_len, 2)
            # Candidate tokens: not special/padding (offset (0, 0)); offsets are lists
            # after Dataset storage, so compare numerically rather than to a tuple.
            valid = (offsets[:, 0] != 0) | (offsets[:, 1] != 0)
            token_type_ids = feature.get('token_type_ids')
            if token_type_ids is not None:
                valid &= np.asarray(token_type_ids) == 1  # keep context (segment 1) tokens only

            start = np.asarray(start_scores[idx])
            end = np.asarray(end_scores[idx])

            # Score every (start, end) pair with start <= end and both tokens valid
            pair_scores = start[:, None] + end[None, :]
            allowed = np.triu(valid[:, None] & valid[None, :])
            pair_scores = np.where(allowed, pair_scores, -np.inf)
            s, e = np.unravel_index(np.argmax(pair_scores), pair_scores.shape)
            if pair_scores[s, e] > best_score:
                best_score = pair_scores[s, e]
                best_span = f'{offsets[s][0]},{offsets[e][1]}'

            min_null_score = min(min_null_score, start[0] + end[0])

            # Any window containing the gold answer gives the same character span
            ref_s, ref_e = int(start_labels[idx]), int(end_labels[idx])
            if (ref_s, ref_e) != (0, 0):
                reference = f'{offsets[ref_s][0]},{offsets[ref_e][1]}'

        predictions[example_id] = best_span if best_score >= min_null_score else no_answer
        references[example_id] = reference
    return predictions, references

In [None]:
# CRF variant: built from the same SpanBERT checkpoint as the baseline.
spanbert_crf = SpanBERTCRF.from_pretrained(
    'SpanBERT/spanbert-large-cased'
)

# Fresh W&B run for the CRF experiment (the baseline's run was finished above).
wandb.init(project='finetuning-spanberts', entity='cv-himanshu')

In [None]:
# Same hyperparameters as the baseline, but separate output/log directories:
# both runs train on the same data with the same batch size and epochs, so
# sharing './results' would let this run's step-numbered checkpoint folders
# overwrite the baseline's, and the two runs' logs would be mixed in './logs'.
training_args_crf = TrainingArguments(
    output_dir='./results_crf',
    eval_strategy='epoch',
    learning_rate=1e-5,
    num_train_epochs=6,
    weight_decay=0.01,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    report_to='wandb',
    logging_dir='./logs_crf',
    logging_steps=50,
    label_names=['start_positions', 'end_positions'],
)

trainer_crf = Trainer(
    model=spanbert_crf,
    args=training_args_crf,
    processing_class=tokenizer,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['validation'],
)

trainer_crf.train()

In [None]:
# Re-tokenize the raw validation split, this time keeping example ids and offset
# mappings so window-level predictions can be grouped back per example.
final_val_set = dataset['validation'].map(
    preprocess_evalset,
    batched=True,
    remove_columns=dataset["validation"].column_names,
)

preds = trainer_crf.predict(tokenized_dataset['validation'])

hopefully_final_predictions, references = post_process_to_compute_em(
    final_val_set,
    preds,
)

In [None]:
# Example-level exact match for the CRF model. Report the CRF score itself;
# this cell previously printed the baseline's spanbert_em under the baseline label.
example_ids = list(references)
spanbert_crf_em = exact_match_score(
    predictions=[hopefully_final_predictions[example_id] for example_id in example_ids],
    references=[references[example_id] for example_id in example_ids],
)
print(spanbert_crf_em)
print(f"SpanBERT-CRF Exact Match Score: {spanbert_crf_em:.2f} %")

wandb.finish()