In [1]:
# login to huggingface and push model to hub
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
from datasets import load_dataset, DatasetDict
import torch
torch.set_printoptions(linewidth=1000000)
from transformers import BertTokenizerFast, BertForQuestionAnswering
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

dataset = load_dataset("squad_v2")
dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [3]:
def add_token_positions(example):
    answer_start_char = example['answers']['answer_start']
    if len(answer_start_char) > 0:
        answer_start_char = answer_start_char[0]
        answer_end_char = answer_start_char + len(example['answers']['text'][0])
        # print([example['context'][answer_start_char:answer_end_char]])
        context_encoding = tokenizer(example['context'])
        len_question = len(tokenizer(example['question'])['input_ids']) - 1
        example['start_positions'] = context_encoding.char_to_token(answer_start_char) + len_question
        example['end_positions'] = context_encoding.char_to_token(answer_end_char - 1) + len_question + 1
    else:
        example['start_positions'] = 0
        example['end_positions'] = 0

    return example

def fix_answer_indexes(example):
    real_answer = example['answers']['text']
    start_idx = example['answers']['answer_start'][0] if len(example['answers']['answer_start']) > 0 else 0
    # Get the real end index
    end_idx = start_idx + len(real_answer)

    # Deal with the problem of 1 or 2 more characters
    if example['context'][start_idx:end_idx] == real_answer:
        example['answers']['answer_end'] = end_idx
    # When the real answer is more by one character
    elif example['context'][start_idx-1:end_idx-1] == real_answer:
        example['answers']['answer_start'] = start_idx - 1
        example['answers']['answer_end'] = end_idx - 1
    # When the real answer is more by two characters
    elif example['context'][start_idx-2:end_idx-2] == real_answer:
        example['answers']['answer_start'] = start_idx - 2
        example['answers']['answer_end'] = end_idx - 2
    return example

# Apply the function to each example in the dataset
updated_dataset = DatasetDict()
updated_dataset['train'] = dataset['train'].map(add_token_positions)
updated_dataset['validation'] = dataset['validation'].map(add_token_positions)
# updated_dataset['train'] = updated_dataset['train'].map(fix_answer_indexes)
# updated_dataset['validation'] = updated_dataset['validation'].map(fix_answer_indexes)

In [4]:
for i in range(10):
    len_question = len(tokenizer(updated_dataset['train'][i]['question'])['input_ids']) -1
    print(updated_dataset['train'][i]['answers'])
    print(tokenizer.decode(tokenizer(updated_dataset['train'][i]['question'], updated_dataset['train'][i]['context'])['input_ids'][updated_dataset['train'][i]['start_positions']:updated_dataset['train'][i]['end_positions']]))

{'text': ['in the late 1990s'], 'answer_start': [269]}
in the late 1990s
{'text': ['singing and dancing'], 'answer_start': [207]}
singing and dancing
{'text': ['2003'], 'answer_start': [526]}
2003
{'text': ['Houston, Texas'], 'answer_start': [166]}
houston, texas
{'text': ['late 1990s'], 'answer_start': [276]}
late 1990s
{'text': ["Destiny's Child"], 'answer_start': [320]}
destiny's child
{'text': ['Dangerously in Love'], 'answer_start': [505]}
dangerously in love
{'text': ['Mathew Knowles'], 'answer_start': [360]}
mathew knowles
{'text': ['late 1990s'], 'answer_start': [276]}
late 1990s
{'text': ['lead singer'], 'answer_start': [290]}
lead singer


In [5]:
def tokenize_function(example):
    return tokenizer(example['question'], example['context'], padding='max_length', truncation=True)

tokenized_dataset = updated_dataset.map(tokenize_function, batched=True)

tokenized_dataset.set_format('torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'start_positions', 'end_positions'])
print(tokenized_dataset)

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers', 'start_positions', 'end_positions', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers', 'start_positions', 'end_positions', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 11873
    })
})


In [6]:
# for i in range(5):
#     print(tokenized_dataset['train'][i])
#     print(tokenized_dataset['train'][i]['input_ids'])
#     print(tokenizer.decode(tokenized_dataset['train'][i]['input_ids']))
#     print(tokenizer.decode(tokenized_dataset['train'][i]['input_ids'] [tokenized_dataset['train'][i]['start_positions']:tokenized_dataset['train'][i]['end_positions']]))
#     print(f"Start position: {tokenized_dataset['train'][i]['start_positions']}")
#     print(f"End position: {tokenized_dataset['train'][i]['end_positions']}")

train_loader = torch.utils.data.DataLoader(tokenized_dataset['train'], batch_size = 8)
val_loader = torch.utils.data.DataLoader(tokenized_dataset['validation'], batch_size = 8)

# for batch in train_loader:
#     print(batch)
#     print(batch['input_ids'])
#     print(batch['token_type_ids'])
#     print(batch['attention_mask'])
#     print(batch['start_positions'])
#     print(batch['end_positions'])
#     print(batch.keys())
#     break

In [7]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = BertForQuestionAnswering.from_pretrained("bert-base-uncased").to(device)
# parallel_model = torch.nn.DataParallel(model)
optimizer = torch.optim.AdamW(model.parameters(),lr = 5e-5)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
from tqdm.auto import tqdm
import time
epochs = 3
whole_train_eval_time = time.time()

train_losses = []
val_losses = []

print_every = 1000

for epoch in range(epochs):
  epoch_time = time.time()

  # Set model in train mode
  model.train()
    
  loss_of_epoch = 0

  print("############Train############")

  for batch_idx,batch in tqdm(enumerate(train_loader), total=len(train_loader)): 
    
    optimizer.zero_grad()
    inputs = {
      "input_ids": batch['input_ids'].to(device),
      'token_type_ids': batch['token_type_ids'].to(device),
      "attention_mask": batch['attention_mask'].to(device),
      "start_positions": batch['start_positions'].to(device),
      "end_positions": batch['end_positions'].to(device),
    }
    # print(inputs)
    outputs = model(**inputs)
    # print(outputs.keys())

    start_logits, end_logits = outputs[1], outputs[2]

    for j in range(len(start_logits)):
        start_prediction = torch.argmax(start_logits[j])
        end_prediction = torch.argmax(end_logits[j]) + 1  # Add 1 to include the end token
        start_gold = batch['start_positions'][j]
        end_gold = batch['end_positions'][j]
        answer = tokenizer.decode(batch['input_ids'][j][start_prediction:end_prediction])
        gold = tokenizer.decode(batch['input_ids'][j][start_gold:end_gold])
        # print(f"Input: {tokenizer.decode(batch['input_ids'][j])}")
        # print(f"Prediction: {answer}")
        # print(f"Start prediction: {start_prediction}")
        # print(f"End prediction: {end_prediction}")
        # print(f"Gold: {gold}")
        # print('-------------------------------------------')

    loss = outputs[0]
    print(f"Loss: {loss.item()}")
    loss_of_epoch += loss.item()
    # do a backwards pass 
    loss.backward()
    # update the weights
    optimizer.step()
    # Find the total loss
    

    if (batch_idx+1) % print_every == 0:
      print("Batch {:} / {:}".format(batch_idx+1,len(train_loader)),"\nLoss:", round(loss.item(),1),"\n")

  loss_of_epoch /= len(train_loader)
  train_losses.append(loss_of_epoch)

  ##########Evaluation##################

  # Set model in evaluation mode
  model.eval()

  print("############Evaluate############")

  loss_of_epoch = 0

  for batch_idx,batch in enumerate(val_loader):
    
    with torch.no_grad():

      inputs = {
        "input_ids": batch['input_ids'].to(device),
        'token_type_ids': batch['token_type_ids'].to(device),
        "attention_mask": batch['attention_mask'].to(device),
        "start_positions": batch['start_positions'].to(device),
        "end_positions": batch['end_positions'].to(device),
      }
      
      outputs = model(**inputs)
      loss = outputs[0]
      # Find the total loss
      loss_of_epoch += loss.item()

    if (batch_idx+1) % print_every == 0:
       print("Batch {:} / {:}".format(batch_idx+1,len(val_loader)),"\nLoss:", round(loss.item(),1),"\n")

  loss_of_epoch /= len(val_loader)
  val_losses.append(loss_of_epoch)

  # Print each epoch's time and train/val loss 
  print("\n-------Epoch ", epoch+1,
        "-------"
        "\nTraining Loss:", train_losses[-1],
        "\nValidation Loss:", val_losses[-1],
        "\nTime: ",(time.time() - epoch_time),
        "\n-----------------------",
        "\n\n")

print("Total training and evaluation time: ", (time.time() - whole_train_eval_time))

############Train############


  0%|          | 0/16290 [00:00<?, ?it/s]

Loss: 6.321177959442139
Loss: 5.959258556365967
Loss: 5.826147556304932
Loss: 5.607283592224121
Loss: 5.487421989440918
Loss: 5.299964427947998
Loss: 5.237456321716309
Loss: 5.078381538391113
Loss: 5.270635604858398
Loss: 5.078268051147461
Loss: 5.073866844177246
Loss: 5.12841272354126
Loss: 4.927915573120117
Loss: 4.724134922027588
Loss: 4.7871198654174805
Loss: 4.657887935638428
Loss: 4.6137871742248535
Loss: 4.598516464233398
Loss: 4.546854019165039
Loss: 3.905991792678833
Loss: 4.441003799438477
Loss: 4.584109783172607
Loss: 4.861584663391113
Loss: 4.467474937438965
Loss: 4.334637641906738
Loss: 3.849069118499756
Loss: 3.4166126251220703
Loss: 4.511617660522461
Loss: 3.987583637237549
Loss: 4.066779613494873
Loss: 3.7989211082458496
Loss: 3.2367262840270996
Loss: 3.1062989234924316
Loss: 4.295470237731934
Loss: 4.276442527770996
Loss: 3.7389750480651855
Loss: 4.397189617156982
Loss: 4.500052452087402
Loss: 3.565493583679199
Loss: 2.815634250640869
Loss: 3.120116949081421
Loss: 3.95

  0%|          | 0/16290 [00:00<?, ?it/s]

Loss: 7.289910316467285
Loss: 5.225017547607422
Loss: 5.814704895019531
Loss: 3.4934003353118896
Loss: 2.427252769470215
Loss: 1.7821866273880005
Loss: 1.8480767011642456
Loss: 1.0009338855743408
Loss: 1.4855232238769531
Loss: 1.1069490909576416
Loss: 1.968748688697815
Loss: 2.25260591506958
Loss: 1.4905749559402466
Loss: 0.8792949914932251
Loss: 2.205695867538452
Loss: 1.642472267150879
Loss: 1.1683790683746338
Loss: 2.0708413124084473
Loss: 1.8335556983947754
Loss: 1.384194254875183
Loss: 1.248022198677063
Loss: 0.8477322459220886
Loss: 1.1495296955108643
Loss: 2.146371841430664
Loss: 1.1479296684265137
Loss: 1.3729274272918701
Loss: 0.5744746923446655
Loss: 1.2123541831970215
Loss: 1.3824931383132935
Loss: 1.552369236946106
Loss: 1.1979566812515259
Loss: 0.6423500180244446
Loss: 0.7128428220748901
Loss: 1.187432885169983
Loss: 0.9609726667404175
Loss: 0.859878420829773
Loss: 1.021324634552002
Loss: 1.9590880870819092
Loss: 0.9828138947486877
Loss: 1.094046950340271
Loss: 0.877194762

  0%|          | 0/16290 [00:00<?, ?it/s]

Loss: 2.610358715057373
Loss: 0.7119404673576355
Loss: 1.9751924276351929
Loss: 1.142841100692749
Loss: 0.5403069853782654
Loss: 0.904380202293396
Loss: 2.1394381523132324
Loss: 0.5724571347236633
Loss: 1.0905506610870361
Loss: 1.1017348766326904
Loss: 1.6800869703292847
Loss: 1.7735508680343628
Loss: 1.6240124702453613
Loss: 0.8852486610412598
Loss: 1.3365750312805176
Loss: 1.7747706174850464
Loss: 1.2622809410095215
Loss: 2.894695997238159
Loss: 1.3970385789871216
Loss: 0.6768884658813477
Loss: 0.851983904838562
Loss: 0.5932559967041016
Loss: 0.9520767331123352
Loss: 1.6454360485076904
Loss: 1.1273269653320312
Loss: 1.4847466945648193
Loss: 0.2297317385673523
Loss: 0.7627726197242737
Loss: 0.7853822708129883
Loss: 1.3747012615203857
Loss: 1.0977115631103516
Loss: 0.1961214542388916
Loss: 0.47544822096824646
Loss: 0.524802565574646
Loss: 0.5807883739471436
Loss: 0.7267522811889648
Loss: 0.9920229911804199
Loss: 0.9427201747894287
Loss: 0.6114982962608337
Loss: 0.5151708722114563
Loss:

In [None]:
model.push_to_hub("pgajo/bert-base-uncased-squad2")