## Load BERT Tokenizer & SQuADv2 and preprocess the dataset

In [1]:
import wandb
from datasets import load_dataset
from transformers import BertTokenizerFast, BertForQuestionAnswering, TrainingArguments, Trainer
import numpy as np
import torch
import torch.nn as nn
from torchcrf import CRF


torch.manual_seed(seed=42)
wandb.login()

dataset = load_dataset('rajpurkar/squad_v2')
dataset['train'] = dataset['train'].shuffle(seed=42).select(range(5000)) #select 32k samples for finetuning
tokenizer = BertTokenizerFast.from_pretrained('spanbert/spanbert-large-cased')

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mhimanshu22216[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [2]:
def preprocess_dataset(examples):
    questions = [q.strip() for q in examples['question']] #remove leading and trailing whitespaces
    inputs = tokenizer(questions, examples['context'], max_length=512,
                truncation='only_second', stride=128, return_overflowing_tokens=True,
                return_offsets_mapping=True, padding='max_length')

    offset_mapping = inputs.pop('offset_mapping') #mapping between tokens and original text positions
    sample_map = inputs.pop('overflow_to_sample_mapping') #mapping of tokenized examples with original examples

    #map answers to the tokenized context
    start_positions = []
    end_positions = []

    for i, offset in enumerate(offset_mapping):
        sample_idx = sample_map[i] #index in dataset
        answer = examples['answers'][sample_idx] #true answer for the question (ground truth)
        start_char = answer['answer_start'][0] if answer['answer_start'] else 0
        end_char = start_char + len(answer['text'][0]) if answer['text'] else 0

        sequence_ids = inputs.sequence_ids(i)

        #find the start and end of the context
        idx = 0
        while idx < len(sequence_ids) and sequence_ids[idx] != 1:
            idx += 1
        context_start = idx
        while idx < len(sequence_ids) and sequence_ids[idx] == 1:
            idx += 1
        context_end = idx - 1

        #if the answer is not fully inside the context, label it (0, 0)
        if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
            start_positions.append(0)
            end_positions.append(0)

        #else find the token indices that correspond to the start and end of the answer
        else:
            idx = context_start
            while idx <= context_end and offset[idx][0] <= start_char:
                idx += 1
            start_positions.append(idx - 1)
            idx = context_end
            while idx >= context_start and offset[idx][1] >= end_char:
                idx -= 1
            end_positions.append(idx + 1)

    inputs['start_positions'] = torch.tensor(start_positions)
    inputs['end_positions'] = torch.tensor(end_positions)
    return inputs


def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100 # Convert to percentage


def compute_em(preds):
    label = preds.label_ids
    pred = np.argmax(preds.predictions, axis=-1)
    predictions, references = [], []
    for i in range(len(label[0])):
        predictions.append(str(pred[0][i]) + ',' + str(pred[1][i]))
        references.append(str(label[0][i]) + ',' + str(label[1][i]))
    return exact_match_score(predictions, references)

In [3]:
tokenized_dataset = dataset.map(preprocess_dataset, batched=True, remove_columns=dataset["train"].column_names)

## Fine-tuning spanbert

In [None]:
spanbert = BertForQuestionAnswering.from_pretrained('SpanBERT/spanbert-large-cased')

wandb.init(entity='cv-himanshu', project='finetuning-spanberts')

In [None]:
training_args = TrainingArguments(output_dir='./results', eval_strategy='epoch',
                learning_rate=1e-5, num_train_epochs=6, weight_decay=0.01,
                per_device_train_batch_size=64, per_device_eval_batch_size=64,
                report_to='wandb', logging_dir='./logs', logging_steps=50)

trainer = Trainer(model=spanbert, args=training_args, processing_class=tokenizer,
    train_dataset=tokenized_dataset['train'], eval_dataset=tokenized_dataset['validation'])

trainer.train()

In [None]:
preds = trainer.predict(tokenized_dataset['validation'])

spanbert_em = compute_em(preds)
print(spanbert_em)
print(f"SpanBERT Exact Match Score: {spanbert_em:.2f} %")

wandb.finish()

## Fine-tuning SpanBERT_CRF

In [4]:
from typing import Optional, Tuple, Union
from transformers.modeling_outputs import QuestionAnsweringModelOutput
from transformers import BertPreTrainedModel, BertModel
from torch.nn import CrossEntropyLoss

class SpanBERTCRF(BertForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.crf = CRF(num_tags=config.num_labels, batch_first=True) #will learn transitions between start and end positions

    def forward(
       self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        r"""
        start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the start of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for position (index) of the end of the labelled span for computing the token classification loss.
            Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
            are not taken into account for computing the loss.
        """
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Get BERT embeddings
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]  # Shape: (batch_size, sequence_length, hidden_size)

        # Compute logits
        logits = self.qa_outputs(sequence_output)  # Shape: (batch_size, sequence_length, num_labels)
        start_logits, end_logits = logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()  # Shape: (batch_size, seq_length)
        end_logits = end_logits.squeeze(-1).contiguous()  # Shape: (batch_size, seq_length)

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # If we are on multi-GPU, split add a dimension
            if len(start_positions.size()) > 1:
                start_positions = start_positions.squeeze(-1)
            if len(end_positions.size()) > 1:
                end_positions = end_positions.squeeze(-1)
            # sometimes the start/end positions are outside our model inputs, we ignore these terms
            ignored_index = start_logits.size(1)
            start_positions = start_positions.clamp(0, ignored_index)
            end_positions = end_positions.clamp(0, ignored_index)

            batch_size, seq_length = start_logits.shape

            # Create label tensors for CRF
            start_labels = torch.zeros((batch_size, seq_length), dtype=torch.long, device=input_ids.device)
            end_labels = torch.zeros((batch_size, seq_length), dtype=torch.long, device=input_ids.device)

            for i in range(batch_size):
                start, end = start_positions[i].item(), end_positions[i].item()
                if 0 <= start < seq_length:
                    start_labels[i, start] = 1  # Mark the correct start position
                if 0 <= end < seq_length:
                    end_labels[i, end] = 1  # Mark the correct end position

            # Stack start and end logits
            stacked_logits = torch.stack([start_logits, end_logits], dim=-1)  # (batch_size, seq_length, 2)

            # Compute CRF loss separately for start and end logits
            start_loss = -self.crf(stacked_logits, start_labels, mask=attention_mask.bool(), reduction="mean")
            end_loss = -self.crf(stacked_logits, end_labels, mask=attention_mask.bool(), reduction="mean")
            total_loss = (start_loss + end_loss) / 2

        if not return_dict:
            output = (start_logits, end_logits) + outputs[2:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

In [5]:
model = SpanBERTCRF.from_pretrained('SpanBERT/spanbert-large-cased')

wandb.init(entity='cv-himanshu', project='finetuning-spanberts')

Some weights of SpanBERTCRF were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['crf.end_transitions', 'crf.start_transitions', 'crf.transitions', 'qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
training_args = TrainingArguments(output_dir='./results', eval_strategy='epoch',
                learning_rate=1e-5, num_train_epochs=2, weight_decay=0.01,
                per_device_train_batch_size=64, per_device_eval_batch_size=64,
                report_to='wandb', logging_dir='./logs', logging_steps=50,
                label_names=['start_positions', 'end_positions'],)

trainer = Trainer(model=model, args=training_args, processing_class=tokenizer, 
    train_dataset=tokenized_dataset['train'], eval_dataset=tokenized_dataset['validation'])

trainer.train()



Epoch,Training Loss,Validation Loss
1,No log,5.986413
2,13.753800,4.797614




TrainOutput(global_step=80, training_loss=10.642256927490234, metrics={'train_runtime': 386.1617, 'train_samples_per_second': 25.984, 'train_steps_per_second': 0.207, 'total_flos': 9318646205607936.0, 'train_loss': 10.642256927490234, 'epoch': 2.0})

In [7]:
preds = trainer.predict(tokenized_dataset['validation'])

spanbert_em = compute_em(preds)
print(spanbert_em)
print(f"SpanBERT Exact Match Score: {spanbert_em:.2f} %")

0.0
SpanBERT Exact Match Score: 0.00 %
