In [1]:
from transformers import DistilBertTokenizerFast, DistilBertForQuestionAnswering
from datasets import load_dataset
from datasets import disable_caching
disable_caching()

import torch
import numpy as np
import random

# we set up some seeds so that we can reproduce results
seed = 123
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [2]:
from torch.utils.data import DataLoader
from transformers import DistilBertTokenizer, DistilBertForQuestionAnswering
import transformers
transformers.logging.set_verbosity_error()

def load_model(device):
  tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased-distilled-squad')
  model = DistilBertForQuestionAnswering.from_pretrained("distilbert-base-uncased-distilled-squad")
  model.to(device)
  return model, tokenizer

def get_start_end(inputs, answer):
  m, n = len(inputs), len(answer)
  for i in range(m - n + 1):
    if all(inputs[i: i+n] == answer):
      return i, i + n - 1
  return 0, 0

def preprocess(dataset, tokenizer):
  input_ids, attention_masks, labels = [], [], []
  for i in range(len(dataset)):
    answer = tokenizer(dataset[i]['answers'][0]['span_text'], return_tensors = 'pt')['input_ids'].squeeze()[1:-1]
    question = dataset[i]['questions'][0]['input_text']
    context = dataset[i]['contexts']
    encoding = tokenizer.encode_plus(
      question,
      context,
      add_special_tokens      = True,
      max_length              = tokenizer.model_max_length,
      return_token_type_ids   = False,
      return_attention_mask   = True,
      return_tensors          = "pt",
      padding                 = "max_length",
      truncation              = True
    )
    input_id, attention_mask = encoding['input_ids'], encoding['attention_mask']
    start, end = get_start_end(input_id.squeeze(), answer)
    labels.append((start, end))
    input_ids.append(input_id)
    attention_masks.append(attention_mask)
  return input_ids, attention_masks, labels

def preprocess_and_tokenize(dataset, tokenizer, batch_size):
  input_ids, attention_masks, labels = preprocess(dataset, tokenizer)
  dataset = QADataset(input_ids, attention_masks, labels)
  return DataLoader(dataset, batch_size = batch_size)

In [3]:
class QADataset(torch.utils.data.Dataset):

  def __init__(self, input_ids, attention_masks, labels):

    self.input_ids = input_ids
    self.attention_masks = attention_masks
    self.labels = labels

  def __len__(self):

    return len(self.input_ids)

  def __getitem__(self, index):
    return {
      'input_ids': self.input_ids[index],
      'attention_mask': self.attention_masks[index],
      'labels': torch.tensor(self.labels[index], dtype=torch.long)
    }

In [4]:
from tqdm.auto import tqdm

def train_loop(model, num_epochs, optimizer, lr_scheduler, train_data_loader, validation_data_loader, device):
  train_losses, val_losses = [], []
  for epoch in range(num_epochs):
    # put the model in training mode (important that this is done each epoch,
    # since we put the model into eval mode during validation)
    model.train()

    print(f"Epoch {epoch + 1} training:")
    progress_bar = tqdm(range(len(train_data_loader)))
    train_loss, val_loss, num_train, num_val = 0, 0, 0, 0

    for i, batch in enumerate(train_data_loader):
      start, end = batch['labels'][:, 0], batch['labels'][:, 1]
      res = model(batch['input_ids'].to(device).squeeze(), batch['attention_mask'].to(device).squeeze(), start_positions = start.to(device), end_positions = end.to(device))
      loss = res.loss
      loss.backward()
      batch_size = len(batch['input_ids'])
      train_loss += loss.item() * batch_size
      num_train += batch_size
      optimizer.step()
      lr_scheduler.step()
      optimizer.zero_grad()

      progress_bar.update(1)
    train_loss /= num_train

    # print the epoch's average metrics
    print(f"Epoch {epoch+1} average training loss={train_loss}")

    # normally, validation would be more useful when training for many epochs
    print("Running validation:")
    for i, batch in enumerate(validation_data_loader):
      start, end = batch['labels'][:, 0], batch['labels'][:, 1]
      res = model(batch['input_ids'].to(device).squeeze(), batch['attention_mask'].to(device).squeeze(), start_positions = start.to(device), end_positions = end.to(device))
      loss = res.loss
      batch_size = len(batch['input_ids'])
      val_loss += loss.item() * batch_size
      num_val += batch_size
      progress_bar.update(1)
    val_loss /= num_val
    
    print(f"Epoch {epoch+1} average validation loss={val_loss}")
    train_losses.append(train_loss)
    val_losses.append(val_loss)
  return train_losses, val_losses

In [5]:
from collections import Counter

def evaluate_metrics(input_ids, s, e, start, end, precision, recall, f1):
  for input_id, s1, e1, start1, end1 in zip(input_ids, s, e, start, end):
    pred = input_id[s1: e1 + 1]
    ref = input_id[start1: end1 + 1]
    counter_pred = Counter(pred.detach().cpu().numpy())
    counter_ref = Counter(ref.detach().cpu().numpy())
    matches = 1e-8
    for word in counter_pred:
      matches += min(counter_pred[word], counter_ref[word])
    pred_length = len(pred) + 1e-8
    ref_length = len(ref) + 1e-8
    precision.append(matches/pred_length)
    recall.append(matches/ref_length)
    f1.append(2/(pred_length/matches+ref_length/matches))

def eval_loop(model, validation_data_loader, device):
  progress_bar = tqdm(range(len(validation_data_loader)))
  precision, recall, f1 = [], [], []
  for i, batch in enumerate(validation_data_loader):
    start, end = batch['labels'][:, 0].to(device), batch['labels'][:, 1].to(device)
    res = model(batch['input_ids'].to(device).squeeze(), batch['attention_mask'].to(device).squeeze())
    s, e = res.start_logits.argmax(axis = 1), res.end_logits.argmax(axis = 1)
    evaluate_metrics(batch['input_ids'].to(device).squeeze(), s, e, start, end, precision, recall, f1)
    progress_bar.update(1)
  precision = np.mean(np.array(precision))
  recall = np.mean(np.array(recall))
  f1 = np.mean(np.array(f1))
  return precision, recall, f1

In [6]:
from transformers import get_scheduler

def main():
  device = "cuda" if torch.cuda.is_available() else "cpu"
  batch_size = 24

  model, tokenizer = load_model(device)
  train = load_dataset('json', data_files = 'train.json')['train']
  validation = load_dataset('json', data_files = 'dev.json')['train']

  print('Start Preprocessing.')
  train_data_loader = preprocess_and_tokenize(train, tokenizer, batch_size)
  validation_data_loader = preprocess_and_tokenize(validation, tokenizer, batch_size)

  num_epochs = 3
  optimizer = torch.optim.AdamW(model.parameters(), lr = 1e-4)
  lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=50,
    num_training_steps=len(train_data_loader) * num_epochs
  )

  train_losses, val_losses = train_loop(model, num_epochs, optimizer, lr_scheduler, train_data_loader, validation_data_loader, device)
  precision, recall, f1_score  = eval_loop(model, validation_data_loader, device)
  
  print("PRECISION: ", precision)
  print("RECALL: ", recall)
  print("F1-SCORE: ", f1_score)

if __name__ == "__main__":
    main()

Using custom data configuration default-40f1f9282355cd35
Found cached dataset json (/users/zzhang99/.cache/huggingface/datasets/json/default-40f1f9282355cd35/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Using custom data configuration default-825013e737ca82fb
Found cached dataset json (/users/zzhang99/.cache/huggingface/datasets/json/default-825013e737ca82fb/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

Start Preprocessing.
Epoch 1 training:


  0%|          | 0/581 [00:00<?, ?it/s]

Epoch 1 average training loss=1.5888607696544703
Running validation:
Epoch 1 average validation loss=1.4448462908084576
Epoch 2 training:


  0%|          | 0/581 [00:00<?, ?it/s]

Epoch 2 average training loss=0.8521531998987115
Running validation:
Epoch 2 average validation loss=1.533890254598808
Epoch 3 training:


  0%|          | 0/581 [00:00<?, ?it/s]

Epoch 3 average training loss=0.4539193768883901
Running validation:
Epoch 3 average validation loss=1.6293585708302827


  0%|          | 0/37 [00:00<?, ?it/s]

PRECISION:  0.7673034473562863
RECALL:  0.7529600543503637
F1-SCORE:  0.6935632424949435
