In [1]:
!pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl.metadata (2.4 kB)
Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [1]:
import random
import numpy as np
import torch
from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer, AutoModel, AutoModelForQuestionAnswering, TrainingArguments, Trainer, default_data_collator
import matplotlib.pyplot as plt
from torchcrf import CRF

# Preprocessing Steps

In [2]:
squad = load_dataset("squad_v2")
train_subset = squad["train"].select(range(15000))
val_subset = squad["validation"]

In [3]:
tokenizer = AutoTokenizer.from_pretrained("SpanBERT/spanbert-base-cased", use_fast=True)

In [4]:
def binary_search(arr, target):
    lo, hi = 0, len(arr)
    while lo < hi:
        mid = (lo + hi) // 2
        if target < arr[mid]:
            hi = mid
        else:
            lo = mid + 1
    return lo

def prepare_features(dataset):
    tokenized_dataset = tokenizer(
        dataset["question"],
        dataset["context"],
        truncation="only_second",
        max_length=384,
        stride=128,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length"
    )

    overflow_to_sample_mapping = tokenized_dataset.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_dataset.pop("offset_mapping")
    
    tokenized_dataset["start_positions"] = []
    tokenized_dataset["end_positions"] = []

    idx = -1

    for offset in offset_mapping:
        idx += 1
        sample_idx = overflow_to_sample_mapping[idx]

        answer = dataset["answers"][sample_idx]

        if len(answer["text"]) == 0:
            # No answer found
            tokenized_dataset["start_positions"].append(0)
            tokenized_dataset["end_positions"].append(0)
            continue

        start_char_idx = answer["answer_start"][0]
        end_char_idx = start_char_idx + len(answer["text"][0])

        start_offsets = [token[0] for token in offset]
        end_offsets = [token[1] for token in offset]

        token_start_index = binary_search(start_offsets, start_char_idx) - 1
        token_start_index = max(0, token_start_index)

        token_end_index = binary_search(end_offsets, end_char_idx) - 1
        token_end_index = max(token_start_index, token_end_index)

        if offset[token_start_index][0] > start_char_idx or offset[token_end_index][1] < end_char_idx:
            tokenized_dataset["start_positions"].append(0)
            tokenized_dataset["end_positions"].append(0)
        else:
            tokenized_dataset["start_positions"].append(token_start_index)
            tokenized_dataset["end_positions"].append(token_end_index)
    
    return tokenized_dataset

In [5]:
train_dataset = train_subset.map(prepare_features, batched=True, remove_columns=train_subset.column_names)
val_dataset = val_subset.map(prepare_features, batched=True, remove_columns=val_subset.column_names)

Map:   0%|          | 0/11873 [00:00<?, ? examples/s]

# SpanBERT Training

In [10]:
model_spanbert = AutoModelForQuestionAnswering.from_pretrained("SpanBERT/spanbert-base-cased")

  return self.fget.__get__(instance, owner)()
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
training_args = TrainingArguments(
    output_dir="./results/SpanBERT",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    weight_decay=0.01,
    # fp16=True,  # enable mixed precision training for speed
    logging_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="exact_match",
    greater_is_better=True
)

In [9]:
def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100

In [10]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    start_predictions, end_predictions = predictions
    start_label_ids, end_label_ids = labels
    
    start_pred = np.argmax(start_predictions, axis=1)
    end_pred = np.argmax(end_predictions, axis=1)
    
    pred_answers = []
    ref_answers = []
    for i in range(len(start_pred)):
        pred_answers.append((start_pred[i], end_pred[i]))
        ref_answers.append((start_label_ids[i], end_label_ids[i]))

    score_all = exact_match_score(pred_answers, ref_answers)

    non_empty_indices = []
    for i, (s, e) in enumerate(zip(start_label_ids, end_label_ids)):
        if not (s == 0 and e == 0):
            non_empty_indices.append(i)

    if len(non_empty_indices) == 0:
        return {"exact_match_non_empty": 0.0, "exact_match": score_all}

    filtered_start_preds = np.argmax(start_predictions[non_empty_indices], axis=1)
    filtered_end_preds = np.argmax(end_predictions[non_empty_indices], axis=1)
    filtered_start_labels = start_label_ids[non_empty_indices]
    filtered_end_labels = end_label_ids[non_empty_indices]

    pred_answers = list(zip(filtered_start_preds, filtered_end_preds))
    ref_answers = list(zip(filtered_start_labels, filtered_end_labels))
    
    score = exact_match_score(pred_answers, ref_answers)
    return {"exact_match_non_empty": score, "exact_match": score_all}

In [11]:
trainer_spanbert = Trainer(
    model=model_spanbert,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    processing_class=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

In [14]:
trainer_spanbert.train()



Epoch,Training Loss,Validation Loss,Exact Match Non Empty,Exact Match
1,2.3459,2.181549,46.296061,26.491299
2,1.6314,2.122664,50.507352,34.842251
3,1.3089,2.045232,50.23507,37.663848
4,1.1167,2.207699,51.614631,41.022117
5,0.9635,2.264676,49.999092,40.1114
6,0.8608,2.447841,50.398439,41.542527




TrainOutput(global_step=2868, training_loss=1.5048309376715283, metrics={'train_runtime': 5038.01, 'train_samples_per_second': 18.208, 'train_steps_per_second': 0.569, 'total_flos': 1.7977347511815168e+16, 'train_loss': 1.5048309376715283, 'epoch': 6.0})

In [17]:
results = trainer_spanbert.evaluate()
exact_match_non_empty = results["eval_exact_match_non_empty"]
exact_match = results["eval_exact_match"]
print(f"Exact Match for all Examples: {exact_match:.4f}")
print(f"Exact Match for Non empty answers: {exact_match_non_empty:.4f}")
print(results)



{'eval_loss': 2.4478414058685303, 'eval_exact_match_non_empty': 50.39843891813396, 'eval_exact_match': 41.542527240201665, 'eval_runtime': 195.3942, 'eval_samples_per_second': 62.939, 'eval_steps_per_second': 3.936, 'epoch': 6.0}


In [19]:
exact_match_non_empty = results["eval_exact_match_non_empty"]
exact_match = results["eval_exact_match"]

print(f"Exact Match for all Examples: {exact_match:.4f}")
print(f"Exact Match for Non empty answers: {exact_match_non_empty:.4f}")

Exact Match for all Examples: 41.5425
Exact Match for Non empty answers: 50.3984


# SpanBERT-CRF Training

In [None]:
class SpanBERTCRF(torch.nn.Module):
    def __init__(self, num_tags):
        super(SpanBERTCRF, self).__init__()
        self.spanbert = AutoModelForQuestionAnswering.from_pretrained("SpanBERT/spanbert-base-cased")
        hidden_size = self.spanbert.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, input_ids, attention_mask=None, token_type_ids=None, start_positions=None, end_positions=None):
        outputs = self.spanbert(
            input_ids=input_ids, 
            attention_mask=attention_mask, 
            token_type_ids=token_type_ids,
            start_positions=start_positions,
            end_positions=end_positions,
            output_hidden_states=True,
            return_dict=True
        )

        sequence_output = outputs.hidden_states[-1]
        logits = self.classifier(sequence_output)

        start_logits = outputs.start_logits
        end_logits = outputs.end_logits

        is_training = start_positions is not None and end_positions is not None

        if is_training:
            crf_tags = torch.zeros(tuple(input_ids.size()), device=input_ids.device)
            
            for i in range(crf_tags.size()[0]):
                crf_tags[i, start_positions[i]:end_positions[i]+1] = 1
            
            crf_loss = -self.crf(logits, crf_tags, mask=attention_mask.bool(), reduction='mean')
            
            spanbert_loss = outputs.loss
            
            total_loss = spanbert_loss + crf_loss
            
            return {
                "loss": total_loss,
                "start_logits": start_logits,
                "end_logits": end_logits,
                "hidden_states": outputs.hidden_states
            }
        else:
            crf_tags = self.crf.decode(logits, mask=attention_mask.bool())
            
            return {
                "start_logits": start_logits,
                "end_logits": end_logits, 
                "crf_tags": crf_tags,
                "hidden_states": outputs.hidden_states
            }

In [11]:
train_dataset_crf = train_subset.map(prepare_features, batched=True, remove_columns=train_subset.column_names)
val_dataset_crf = val_subset.map(prepare_features, batched=True, remove_columns=val_subset.column_names)

In [14]:
model_spanbert_crf = SpanBERTCRF(num_tags=2)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [11]:
training_args = TrainingArguments(
    output_dir="./results/SpanBERT-CRF",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
    num_train_epochs=6,
    weight_decay=0.01,
    logging_steps=100,
    save_total_limit=2,
    load_best_model_at_end=True,
    metric_for_best_model="exact_match",
    greater_is_better=True,
)

In [12]:
class SpanBertCRFTrainer(Trainer):
    def compute_loss(self, model, inputs, num_items_in_batch = None, return_outputs=False):
        outputs = model(**inputs)
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]
        return (loss, outputs) if return_outputs else loss
        
    def prediction_step(self, model, inputs, prediction_loss_only=False, ignore_keys=None):
        with torch.no_grad():
            outputs = model(**inputs)
            
        if prediction_loss_only:
            return (outputs["loss"].detach() if isinstance(outputs, dict) else outputs[0].detach(), None, None)
            
        # Extract predictions
        start_logits = outputs["start_logits"].detach()
        end_logits = outputs["end_logits"].detach()
        
        # Extract labels
        start_positions = inputs["start_positions"]
        end_positions = inputs["end_positions"]
        
        return (
            outputs["loss"].detach() if "loss" in outputs else None,
            (start_logits, end_logits),
            (start_positions, end_positions)
        )

In [13]:
def exact_match_score(predictions, references):
    assert len(predictions) == len(references), "Lists must have the same length"
    matches = sum(p == r for p, r in zip(predictions, references))
    return matches / len(references) * 100

In [None]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    start_predictions, end_predictions = predictions
    start_label_ids, end_label_ids = labels
    
    start_pred = np.argmax(start_predictions, axis=1)
    end_pred = np.argmax(end_predictions, axis=1)
    
    pred_answers = []
    ref_answers = []
    for i in range(len(start_pred)):
        pred_answers.append((start_pred[i], end_pred[i]))
        ref_answers.append((start_label_ids[i], end_label_ids[i]))

    score_all = exact_match_score(pred_answers, ref_answers)

    non_empty_indices = []
    for i, (s, e) in enumerate(zip(start_label_ids, end_label_ids)):
        if not (s == 0 and e == 0):
            non_empty_indices.append(i)

    if len(non_empty_indices) == 0:
        return {"exact_match_non_empty": 0.0, "exact_match": score_all}

    filtered_start_preds = np.argmax(start_predictions[non_empty_indices], axis=1)
    filtered_end_preds = np.argmax(end_predictions[non_empty_indices], axis=1)
    filtered_start_labels = start_label_ids[non_empty_indices]
    filtered_end_labels = end_label_ids[non_empty_indices]

    pred_answers = list(zip(filtered_start_preds, filtered_end_preds))
    ref_answers = list(zip(filtered_start_labels, filtered_end_labels))
    
    score = exact_match_score(pred_answers, ref_answers)
    return {"exact_match_non_empty": score, "exact_match": score_all}

In [15]:
trainer_spanbert_crf = SpanBertCRFTrainer(
    model=model_spanbert_crf,
    args=training_args,
    train_dataset=train_dataset_crf,
    eval_dataset=val_dataset_crf,
    processing_class=tokenizer,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

In [19]:
trainer_spanbert_crf.train()



Epoch,Training Loss,Validation Loss,Exact Match Non Empty,Exact Match
1,14.096,12.782765,39.2341,19.107985
2,9.2472,10.729434,43.01245,25.613108
3,6.6786,10.372138,46.302832,29.39421
4,5.2687,10.64336,49.545289,30.451293
5,4.1172,11.105186,46.002318,28.841275
6,3.4708,11.901663,47.1256,28.337128




TrainOutput(global_step=2868, training_loss=8.069557434677911, metrics={'train_runtime': 7088.4083, 'train_samples_per_second': 12.941, 'train_steps_per_second': 0.405, 'total_flos': 0.0, 'train_loss': 8.069557434677911, 'epoch': 6.0})

In [28]:
results = trainer_spanbert_crf.evaluate()
print(results)

{'eval_loss': 10.643360137939453, 'eval_exact_match_non_empty': 49.5452895262298, 'eval_exact_match': 30.45129289315336, 'eval_runtime': 352.7288, 'eval_samples_per_second': 34.865, 'eval_steps_per_second': 2.18, 'epoch': 6.0}


In [20]:
exact_match_non_empty = results["eval_exact_match_non_empty"]
exact_match = results["eval_exact_match"]

print(f"Exact Match for all Examples: {exact_match:.4f}")
print(f"Exact Match for Non empty answers: {exact_match_non_empty:.4f}")

Exact Match for all Examples: 30.4513
Exact Match for Non empty answers: 49.5453
