# Preprocessing

In [2]:
# imports for the data + tokenizer setup
from datasets import load_dataset
from transformers import AutoTokenizer

# checkpoint to fine-tune; the tokenizer is loaded from the same name
model_checkpoint = "distilbert-base-cased"
# model_checkpoint = "bert-base-cased" # try it yourself

# the "train" and "validation" splits of this dataset are both used further down
raw_datasets = load_dataset("squad")
raw_datasets  # only rendered when this is the last line of a cell

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)


# find answer token idx
def find_answer_token_idx(
    ctx_start,
    ctx_end,
    ans_start_char,
    ans_end_char,
    offset):
  """Map an answer's character span to token positions inside one window.

  Args:
    ctx_start: index of the first context token in the window.
    ctx_end: index of the last context token in the window (inclusive).
    ans_start_char: character position where the answer starts.
    ans_end_char: character position just past the end of the answer.
    offset: per-token (start_char, end_char) pairs for the window.

  Returns:
    (start_idx, end_idx) token indices covering the answer, or (0, 0) when
    the answer is not fully contained in this window's context.
  """
  # the 'trick' is knowing what is in units of tokens and what is in
  # units of characters: offset[token] holds CHARACTER positions

  # answer is (at least partly) outside this window -> target is (0, 0)
  if offset[ctx_start][0] > ans_start_char or offset[ctx_end][1] < ans_end_char:
    return 0, 0

  # Only tokens in [ctx_start, ctx_end] are scanned, so offsets of tokens
  # outside the context can never be mistaken for an answer boundary.
  # When a boundary falls inside a token, the enclosing token is chosen
  # instead of leaving the index at 0.

  # last context token that starts at or before the answer's first character
  start_idx = ctx_start
  while start_idx < ctx_end and offset[start_idx + 1][0] <= ans_start_char:
    start_idx += 1

  # first context token that ends at or after the answer's last character
  end_idx = ctx_end
  while end_idx > ctx_start and offset[end_idx - 1][1] >= ans_end_char:
    end_idx -= 1

  return start_idx, end_idx


# tokenize the training data
# (i.e. expand question+context pairs into question+smaller context windows)

# Google used 384 for SQuAD
# max_length is handed to the tokenizer as the token budget of each window
max_length = 384
# stride is handed to the tokenizer together with return_overflowing_tokens=True;
# per the tokenizer docs it is the number of tokens consecutive windows overlap
stride = 128

def tokenize_fn_train(batch):
  """Window a batch of question/context pairs and attach answer targets.

  Every output row is one (question, context window) pair padded to
  max_length. 'start_positions' / 'end_positions' hold the token span
  returned by find_answer_token_idx for the first listed answer.
  """
  # some questions have leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # pad every window to max_length up front (no per-minibatch padding)
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # only needed to build the targets, so take them out of the model inputs
  window_offsets = inputs.pop("offset_mapping")
  window_to_sample = inputs.pop("overflow_to_sample_mapping")
  all_answers = batch['answers']

  token_spans = []
  for window, (offset, sample_idx) in enumerate(
      zip(window_offsets, window_to_sample)):
    gold = all_answers[sample_idx]

    # character span of the first answer
    char_start = gold['answer_start'][0]
    char_end = char_start + len(gold['text'][0])

    # context tokens are the ones whose sequence id is 1:
    # locate the first and the last of them
    seq_ids = inputs.sequence_ids(window)
    first_ctx = seq_ids.index(1)
    last_ctx = len(seq_ids) - 1 - seq_ids[::-1].index(1)

    token_spans.append(
      find_answer_token_idx(first_ctx, last_ctx, char_start, char_end, offset))

  inputs["start_positions"] = [span[0] for span in token_spans]
  inputs["end_positions"] = [span[1] for span in token_spans]
  return inputs

# rows of the result are context windows rather than original samples,
# so all original columns are dropped
train_dataset = raw_datasets["train"].map(
  function=tokenize_fn_train,
  batched=True,
  remove_columns=raw_datasets["train"].column_names,
)


# tokenize the validation set differently
# we won't need the targets since we will just compare with the original answer
# also: overwrite offset_mapping with Nones in place of question
def tokenize_fn_validation(batch):
  """Window a batch of validation question/context pairs for prediction.

  No answer targets are produced. Each output row keeps its offset_mapping
  with every non-context entry replaced by None, plus a 'sample_id' holding
  the id of the sample the row was cut from.
  """
  # some questions have leading and/or trailing whitespace
  stripped_questions = [text.strip() for text in batch["question"]]

  # pad every window to max_length up front (no per-minibatch padding)
  inputs = tokenizer(
    stripped_questions,
    batch["context"],
    max_length=max_length,
    truncation="only_second",
    stride=stride,
    return_overflowing_tokens=True,
    return_offsets_mapping=True,
    padding="max_length",
  )

  # window -> position of the originating sample within this batch
  window_to_sample = inputs.pop("overflow_to_sample_mapping")
  sample_ids = []

  # keep offsets only for context tokens (sequence id 1); everything else
  # becomes None, which is helpful later on when we compute metrics
  all_offsets = inputs["offset_mapping"]
  for window in range(len(inputs["input_ids"])):
    sample_ids.append(batch['id'][window_to_sample[window]])

    seq_ids = inputs.sequence_ids(window)
    all_offsets[window] = [
      span if seq_id == 1 else None
      for span, seq_id in zip(all_offsets[window], seq_ids)
    ]

  inputs['sample_id'] = sample_ids
  return inputs

# rows of the result are context windows rather than original samples,
# so all original columns are dropped
validation_dataset = raw_datasets["validation"].map(
  function=tokenize_fn_validation,
  batched=True,
  remove_columns=raw_datasets["validation"].column_names,
)

Map:   0%|          | 0/10570 [00:00<?, ? examples/s]

# Metrics

In [7]:
# now let's define a full compute_metrics function
# note: this will NOT be called from the trainer

from tqdm.autonotebook import tqdm

def compute_metrics(start_logits, end_logits, processed_dataset, orig_dataset):
  """Turn per-window start/end logits into text answers and score them.

  Relies on module-level names that must exist before this is called:
  `n_largest`, `max_answer_length`, `metric` and `tqdm`.

  Args:
    start_logits: start scores, indexed as start_logits[row] -> (T,) vector.
    end_logits: end scores, same layout as start_logits.
    processed_dataset: windowed dataset providing 'sample_id' and
      'offset_mapping' (None for non-context tokens) for every row.
    orig_dataset: original samples providing 'id', 'context' and 'answers'.

  Returns:
    The result of metric.compute(predictions=..., references=...).
  """
  # map sample_id ('56be4db0acb8001400a502ec') to row indices of processed data
  sample_id2idxs = {}
  for i, id_ in enumerate(processed_dataset['sample_id']):
    sample_id2idxs.setdefault(id_, []).append(i)

  predicted_answers = []
  for sample in tqdm(orig_dataset):

    sample_id = sample['id']
    context = sample['context']

    # update these as we loop through candidate answers
    best_score = float('-inf')
    # start from an empty string rather than None: if every candidate is
    # filtered out below, the prediction must still be a string
    best_answer = ''

    # now loop through the *expanded* input samples (fixed size context windows)
    # from here we will pick the highest probability start/end combination
    for idx in sample_id2idxs[sample_id]:
      start_logit = start_logits[idx] # (T,) vector
      end_logit = end_logits[idx] # (T,) vector

      # note: do NOT do the reverse: ['offset_mapping'][idx]
      offsets = processed_dataset[idx]['offset_mapping']

      # token indices sorted by descending logit
      start_indices = (-start_logit).argsort()
      end_indices = (-end_logit).argsort()

      for start_idx in start_indices[:n_largest]:
        for end_idx in end_indices[:n_largest]:

          # skip answers not contained in context window
          # recall: we set entries not pertaining to context to None earlier
          if offsets[start_idx] is None or offsets[end_idx] is None:
            continue

          # skip answers where end < start
          if end_idx < start_idx:
            continue

          # skip answers that are too long
          if end_idx - start_idx + 1 > max_answer_length:
            continue

          # see theory lecture for score calculation
          score = start_logit[start_idx] + end_logit[end_idx]
          if score > best_score:
            best_score = score

            # find positions of start and end characters
            # recall: offsets contains tuples for each token:
            # (start_char, end_char)
            first_ch = offsets[start_idx][0]
            last_ch = offsets[end_idx][1]

            best_answer = context[first_ch:last_ch]

    # save best answer
    predicted_answers.append({'id': sample_id, 'prediction_text': best_answer})

  # compute the metrics
  true_answers = [
    {'id': x['id'], 'answers': x['answers']} for x in orig_dataset
  ]
  return metric.compute(predictions=predicted_answers, references=true_answers)

# Train and Evaluate

In [8]:
# now load the model we want to fine-tune
# (imported in this cell so it does not depend on an import made elsewhere)
from transformers import AutoModelForQuestionAnswering

model = AutoModelForQuestionAnswering.from_pretrained(model_checkpoint)

Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-cased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
from transformers import TrainingArguments

# evaluation inside the training loop is switched off; scoring is done
# afterwards with compute_metrics, which is not called from the trainer
args = TrainingArguments(
    output_dir="finetuned-squad",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [11]:
from transformers import Trainer

# takes ~2.5h with bert on full dataset
# ~1h 15min with distilbert

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    # for a quick trial run, swap the line above for the one below
    # (1,000 rows picked after a seeded shuffle)
    # train_dataset=train_dataset.shuffle(seed=42).select(range(1_000)),
    # NOTE(review): args sets evaluation_strategy="no", so this is presumably
    # not used during train() - confirm
    eval_dataset=validation_dataset,
    tokenizer=tokenizer,
)
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss



KeyboardInterrupt



In [None]:
# run the trainer's model over every row (context window) of the validation set
trainer_output = trainer.predict(validation_dataset)

In [None]:
# take the predictions field by name; unpacking into `_` would shadow
# IPython's last-output variable for the rest of the session
predictions = trainer_output.predictions

In [None]:
# predictions is a pair: start logits and end logits
# (compute_metrics indexes each of them by processed-dataset row)
start_logits, end_logits = predictions

In [None]:
# score the predicted spans against the answers of the original validation samples
compute_metrics(
    start_logits=start_logits,
    end_logits=end_logits,
    processed_dataset=validation_dataset,
    orig_dataset=raw_datasets["validation"],
)

In [None]:
# save under 'my_saved_model'; the pipeline below loads the model from that path
trainer.save_model('my_saved_model')

In [None]:
import torch
from transformers import pipeline

# use the first GPU when one is available, otherwise fall back to CPU (-1)
# instead of failing on machines without CUDA
qa = pipeline(
  "question-answering",
  model='my_saved_model',
  device=0 if torch.cuda.is_available() else -1,
)

In [None]:
# quick manual check of the saved model on a hand-written example
context = "Today I went to the store to purchase a carton of milk."
question = "What did I buy?"

# last expression of the cell, so the pipeline's answer is displayed
qa(context=context, question=question)