In [None]:
from datasets import load_dataset
from transformers import BertForQuestionAnswering, BertTokenizerFast

# import evaluate
# import numpy as np
# import torch

In [None]:
# Load model and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizerFast.from_pretrained(model_name)
model = BertForQuestionAnswering.from_pretrained(model_name)

In [9]:
# squad v1
dataset = load_dataset("squad")

In [None]:
def preprocess_function(examples):
    # Tokenize both questions and contexts
    tokenized_examples = tokenizer(
        examples["question"],
        examples["context"],
        # Only truncate the context (second sequence)
        truncation="only_second",
        max_length=384,  # Maximum length for BERT
        stride=128,  # Overlap between chunks when splitting long contexts
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding="max_length",
    )

    # Since one example might give us several features,
    # we need a map from feature to example
    sample_mapping = tokenized_examples.pop("overflow_to_sample_mapping")
    offset_mapping = tokenized_examples.pop("offset_mapping")

    # Initialize lists to store the labels
    tokenized_examples["start_positions"] = []
    tokenized_examples["end_positions"] = []

    for i, offsets in enumerate(offset_mapping):
        # Get the corresponding example index
        sample_index = sample_mapping[i]

        # Get the answer
        answers = examples["answers"][sample_index]
        start_char = answers["answer_start"][0]
        end_char = start_char + len(answers["text"][0])

        # Find the start and end token positions
        sequence_ids = tokenized_examples.sequence_ids(i)

        # Find the start and end of the context
        context_start = 0
        while sequence_ids[context_start] != 1:
            context_start += 1
        context_end = len(sequence_ids) - 1
        while sequence_ids[context_end] != 1:
            context_end -= 1

        # If the answer is out of the span, set to the CLS token
        if (start_char < offsets[context_start][0]) or (
            end_char > offsets[context_end][1]
        ):
            tokenized_examples["start_positions"].append(0)
            tokenized_examples["end_positions"].append(0)
        else:
            # Find the start and end tokens
            start_token = context_start
            while (start_token <= context_end) and (
                offsets[start_token][0] <= start_char
            ):
                start_token += 1
            tokenized_examples["start_positions"].append(start_token - 1)

            end_token = context_end
            while (
                (end_token >= context_start) and (offsets[end_token][1])
            ) >= end_char:
                end_token -= 1
            tokenized_examples["end_positions"].append(end_token + 1)

    return tokenized_examples


# Apply the preprocessing to the dataset
tokenized_dataset = dataset.map(
    preprocess_function, batched=True, remove_columns=dataset["train"].column_names
)

Map:   0%|          | 0/87599 [00:00<?, ? examples/s]

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

In [15]:
tokenized_dataset["train"][1]

{'input_ids': [101,
  2054,
  2003,
  1999,
  2392,
  1997,
  1996,
  10289,
  8214,
  2364,
  2311,
  1029,
  102,
  6549,
  2135,
  1010,
  1996,
  2082,
  2038,
  1037,
  3234,
  2839,
  1012,
  10234,
  1996,
  2364,
  2311,
  1005,
  1055,
  2751,
  8514,
  2003,
  1037,
  3585,
  6231,
  1997,
  1996,
  6261,
  2984,
  1012,
  3202,
  1999,
  2392,
  1997,
  1996,
  2364,
  2311,
  1998,
  5307,
  2009,
  1010,
  2003,
  1037,
  6967,
  6231,
  1997,
  4828,
  2007,
  2608,
  2039,
  14995,
  6924,
  2007,
  1996,
  5722,
  1000,
  2310,
  3490,
  2618,
  4748,
  2033,
  18168,
  5267,
  1000,
  1012,
  2279,
  2000,
  1996,
  2364,
  2311,
  2003,
  1996,
  13546,
  1997,
  1996,
  6730,
  2540,
  1012,
  3202,
  2369,
  1996,
  13546,
  2003,
  1996,
  24665,
  23052,
  1010,
  1037,
  14042,
  2173,
  1997,
  7083,
  1998,
  9185,
  1012,
  2009,
  2003,
  1037,
  15059,
  1997,
  1996,
  24665,
  23052,
  2012,
  10223,
  26371,
  1010,
  2605,
  2073,
  1996,
  6261,
  2984,