### https://towardsdatascience.com/how-to-fine-tune-a-q-a-transformer-86f91ec92997

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Mounted at /content/gdrive


### Import Libraries

In [2]:
!pip install transformers


In [3]:
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm


In [4]:
df = pd.read_csv('/content/gdrive/MyDrive/tmproj/legal_squad_data.csv')
df.head()

Unnamed: 0.1,Unnamed: 0,context_id,qna_id,context,question,answer_start,answer
0,0,3196,56cee398aab44d1400b88bfb,"In 1785, the assembly of the Congress of the C...",In what year did New York become the United St...,3,1785
1,1,3196,56cee398aab44d1400b88bfc,"In 1785, the assembly of the Congress of the C...",Who was the United States' first President?,313,George Washington
2,2,3196,56cee398aab44d1400b88bfd,"In 1785, the assembly of the Congress of the C...",In what building did the Supreme Court of the ...,517,Federal Hall
3,3,3196,56cee398aab44d1400b88bfe,"In 1785, the assembly of the Congress of the C...",On what street did the writing of the Bill of ...,533,Wall Street
4,4,3196,56cee398aab44d1400b88bff,"In 1785, the assembly of the Congress of the C...",What was the second largest city in the United...,578,Philadelphia


In [5]:
y = [{'text': row['answer'], 'answer_start': row['answer_start']} for idx, row in df.iterrows()]
y[:4]

[{'answer_start': 3, 'text': '1785'},
 {'answer_start': 313, 'text': 'George Washington'},
 {'answer_start': 517, 'text': 'Federal Hall'},
 {'answer_start': 533, 'text': 'Wall Street'}]

In [6]:
df.loc[:,['context_id', 'qna_id', 'context', 'question']]

Unnamed: 0,context_id,qna_id,context,question
0,3196,56cee398aab44d1400b88bfb,"In 1785, the assembly of the Congress of the C...",In what year did New York become the United St...
1,3196,56cee398aab44d1400b88bfc,"In 1785, the assembly of the Congress of the C...",Who was the United States' first President?
2,3196,56cee398aab44d1400b88bfd,"In 1785, the assembly of the Congress of the C...",In what building did the Supreme Court of the ...
3,3196,56cee398aab44d1400b88bfe,"In 1785, the assembly of the Congress of the C...",On what street did the writing of the Bill of ...
4,3196,56cee398aab44d1400b88bff,"In 1785, the assembly of the Congress of the C...",What was the second largest city in the United...
...,...,...,...,...
4852,86289,57344c34acc1501500babdc3,"Nevertheless, although a distinction between l...",Where did synods prohibit all hunting at?
4853,86289,57344c34acc1501500babdc4,"Nevertheless, although a distinction between l...",What did Benedict XIV declare about decrees pr...
4854,86289,5735ffae012e2f140011a115,"Nevertheless, although a distinction between l...",Who can prohibit hunting to the clerics?
4855,86289,5735ffae012e2f140011a116,"Nevertheless, although a distinction between l...",Declaration that decrees are not severe was do...


In [7]:
X_train, X_test, y_train, y_test = train_test_split(df.loc[:,['context_id', 'qna_id', 'context', 'question']], y, test_size=0.33, random_state=42)
train_contexts, train_questions, train_answers = list(X_train['context']), list(X_train['question']), y_train
val_contexts, val_questions, val_answers = list(X_test['context']), list(X_test['question']), y_test
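Several questions share the same passage (note the repeated `context_id`), so a row-level split can put one context in both the training and validation sets, which may make validation scores look better than they are. A minimal sketch of splitting by context instead, assuming the rows are dicts with a `context_id` key; `split_by_context` is a hypothetical helper and the toy rows are made up:

```python
import random

def split_by_context(rows, test_size=0.33, seed=42):
    """Hypothetical helper: split rows so every question of a context lands on one side.

    `rows` is assumed to be a list of dicts that each carry a 'context_id' key.
    """
    # collect distinct contexts in a stable order, then shuffle them reproducibly
    context_ids = sorted({row['context_id'] for row in rows})
    random.Random(seed).shuffle(context_ids)
    # hold out roughly `test_size` of the contexts (not of the rows)
    n_test = max(1, round(len(context_ids) * test_size))
    test_ids = set(context_ids[:n_test])
    train = [row for row in rows if row['context_id'] not in test_ids]
    test = [row for row in rows if row['context_id'] in test_ids]
    return train, test

rows = [
    {'context_id': 3196, 'question': 'q1'},
    {'context_id': 3196, 'question': 'q2'},
    {'context_id': 86289, 'question': 'q3'},
    {'context_id': 86289, 'question': 'q4'},
    {'context_id': 100, 'question': 'q5'},
]
train, test = split_by_context(rows)
print(len(train), len(test))
```

With the real DataFrame, `df.to_dict('records')` gives rows in this shape; scikit-learn's `GroupShuffleSplit` with `groups=df['context_id']` implements the same idea.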

In [8]:
train_contexts[0]

'Switzerland was the last Western republic to grant women the right to vote. Some Swiss cantons approved this in 1959, while at the federal level it was achieved in 1971 and, after resistance, in the last canton Appenzell Innerrhoden (one of only two remaining Landsgemeinde) in 1990. After obtaining suffrage at the federal level, women quickly rose in political significance, with the first woman on the seven member Federal Council executive being Elisabeth Kopp, who served from 1984–1989, and the first female president being Ruth Dreifuss in 1999.'

In [9]:
train_questions[0]

'Who did Switzerland finally grant the right to vote to following the rest of the Western republic?'

In [10]:
train_answers[0]

{'answer_start': 51, 'text': 'women'}
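`answer_start` is a character offset into the context, so the gold answer can be recovered by slicing. A quick check on the first training example, with the context abridged to the opening words that contain the answer:

```python
context = ('Switzerland was the last Western republic to grant women '
           'the right to vote.')
answer = {'answer_start': 51, 'text': 'women'}

# answer_start counts characters, so the (exclusive) end index is start + len(text)
start = answer['answer_start']
end = start + len(answer['text'])
print(context[start:end])  # → women
```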

In [11]:
def add_end_idx(answers, contexts):
    # loop through each answer-context pair
    for answer, context in zip(answers, contexts):
        # gold_text is the answer we expect to find in the context
        gold_text = answer['text']
        # we already know the start (character) index
        start_idx = answer['answer_start']
        # and ideally this would be the (exclusive) end index...
        end_idx = start_idx + len(gold_text)

        # ...however, SQuAD answer offsets are sometimes off by a character or two
        if context[start_idx:end_idx] == gold_text:
            # the offsets are correct
            answer['answer_end'] = end_idx
        else:
            # the answer is shifted 1-2 characters to the right, so try moving it left
            for n in [1, 2]:
                if context[start_idx-n:end_idx-n] == gold_text:
                    answer['answer_start'] = start_idx - n
                    answer['answer_end'] = end_idx - n
                    break
            else:
                # no match found: keep the original offsets so 'answer_end' always exists
                answer['answer_end'] = end_idx

# and apply the function to our two answer lists
add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)
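`add_end_idx` only looks one and two characters to the left of the given start. A minimal standalone sketch of a variant that tries shifts in both directions and returns `None` when the answer cannot be located, so such examples can be counted or dropped; `align_answer` and `max_shift` are hypothetical names, not part of the original notebook:

```python
def align_answer(context, text, start, max_shift=2):
    """Return (start, end) character offsets of `text` in `context`, or None.

    Tries the given start first, then shifts of -1, +1, ..., -max_shift, +max_shift.
    """
    shifts = [0] + [s for n in range(1, max_shift + 1) for s in (-n, n)]
    for shift in shifts:
        s = start + shift
        if s >= 0 and context[s:s + len(text)] == text:
            return s, s + len(text)
    return None

context = 'In 1785, the assembly met.'
print(align_answer(context, '1785', 3))  # → (3, 7)
print(align_answer(context, '1785', 4))  # → (3, 7)  (offset one character too far right)
print(align_answer(context, '1790', 3))  # → None
```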

In [12]:
from transformers import DistilBertTokenizerFast
# initialize the tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
# tokenize
train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

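With `truncation=True` on a context-question pair, the tokenizer's default strategy is `longest_first`: it trims tokens from whichever sequence is longer until the pair, plus `[CLS]` and two `[SEP]` tokens, fits the model's maximum length (512 for `distilbert-base-uncased`). Because the context is usually the longer sequence, its tail is cut first, which is how an answer can disappear. A simplified pure-Python sketch of the idea on lists of token strings (an illustration, not the library's implementation):

```python
def truncate_longest_first(first, second, max_length, num_special=3):
    """Simplified `longest_first` truncation for a sequence pair.

    `num_special` accounts for the layout [CLS] first [SEP] second [SEP].
    """
    budget = max_length - num_special
    if budget < 0:
        raise ValueError('max_length is too small for the special tokens')
    first, second = list(first), list(second)
    while len(first) + len(second) > budget:
        # drop the last token of whichever sequence is currently longer
        if len(first) > len(second):
            first.pop()
        else:
            second.pop()
    return ['[CLS]'] + first + ['[SEP]'] + second + ['[SEP]']

context = ['the', 'assembly', 'met', 'in', 'new', 'york', 'in', '1785']
question = ['where', 'did', 'it', 'meet']
result = truncate_longest_first(context, question, max_length=12)
# the context loses its last three tokens, including '1785'
print(result)
```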




In [13]:
def add_token_positions(encodings, answers):
    # initialize lists to contain the token indices of answer start/end
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        # map character positions to token positions with char_to_token;
        # answer_end is exclusive, so look up the answer's last character instead
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))

        # a None position means the answer was cut off by truncation; point it at
        # model_max_length so the model's loss ignores that position
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length
    # update our encodings object with the new token-based start/end positions
    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

# apply function to our data
add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)
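`char_to_token` maps a character index to the token whose character span covers it, using the offsets the fast tokenizer records for each token. A pure-Python sketch of that lookup, with hand-written offsets for the tokens `in`, `1785`, `,`, `the` (assumed values, not tokenizer output), shows why the exclusive `answer_end` can land on a space (`None`) or on the token after the answer, such as a comma:

```python
def char_to_token(offsets, char_idx):
    """Return the index of the token whose [start, end) span covers char_idx, else None."""
    for token_idx, (start, end) in enumerate(offsets):
        if start <= char_idx < end:
            return token_idx
    return None

context = 'In 1785, the'
# character spans of the tokens 'in', '1785', ',', 'the' (written out by hand)
offsets = [(0, 2), (3, 7), (7, 8), (9, 12)]

answer_start, answer_end = 3, 7                # gold answer '1785'; end is exclusive
print(char_to_token(offsets, answer_start))    # → 1
print(char_to_token(offsets, answer_end))      # → 2 (the comma, one past the answer)
print(char_to_token(offsets, answer_end - 1))  # → 1 (last character of the answer)
print(char_to_token(offsets, 2))               # → None (a space)
```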

In [14]:
import torch

class SquadDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}

    def __len__(self):
        return len(self.encodings.input_ids)

# build datasets for both our training and validation sets
train_dataset = SquadDataset(train_encodings)
val_dataset = SquadDataset(val_encodings)

In [22]:
from transformers import DistilBertForQuestionAnswering
model = DistilBertForQuestionAnswering.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForQuestionAnswering: ['vocab_transform.weight', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_projector.weight', 'vocab_projector.bias']
- This IS expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForQuestionAnswering from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForQuestionAnswering were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['qa_outputs.weight', 'qa_outputs.bias']
You should probably TRAIN this mode
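The newly initialized `qa_outputs` layer produces one start logit and one end logit per token. At inference time an answer is commonly read off by choosing the pair (start, end) with the highest combined score, subject to start ≤ end and a length cap; real pipelines also skip question and padding tokens. A pure-Python sketch of that decoding step, with made-up logits and a hypothetical `max_answer_len` parameter:

```python
def best_span(start_logits, end_logits, max_answer_len=30):
    """Return (start, end, score) maximizing start_logits[s] + end_logits[e]
    over spans with s <= e < s + max_answer_len."""
    best = (0, 0, float('-inf'))
    for s, s_score in enumerate(start_logits):
        for e in range(s, min(s + max_answer_len, len(end_logits))):
            score = s_score + end_logits[e]
            if score > best[2]:
                best = (s, e, score)
    return best

start_logits = [0.1, 2.0, 0.3, 1.5, -1.0]
end_logits = [0.2, 0.1, 1.8, 0.4, 3.0]
print(best_span(start_logits, end_logits))                    # → (1, 4, 5.0)
print(best_span(start_logits, end_logits, max_answer_len=2))  # → (3, 4, 4.5)
```

The predicted token span is then mapped back to characters (for example through the tokenizer's offsets) to extract the answer text.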

In [None]:
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm

# setup GPU/CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
# move model over to detected device
model.to(device)
# initialize the AdamW optimizer (Adam with decoupled weight decay, a regularizer);
# torch.optim.AdamW applies weight_decay=0.01 by default
optim = AdamW(model.parameters(), lr=5e-5)

# initialize data loader for training data
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

for epoch in range(3):
    # set model to train mode (enables dropout)
    model.train()
    # setup loop (we use tqdm for the progress bar)
    loop = tqdm(train_loader, leave=True)
    for batch in loop:
        # reset the gradients accumulated in the previous step
        optim.zero_grad()
        # move the tensor batches required for training to the device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        start_positions = batch['start_positions'].to(device)
        end_positions = batch['end_positions'].to(device)
        # train model on batch and return outputs (incl. loss)
        outputs = model(input_ids, attention_mask=attention_mask,
                        start_positions=start_positions,
                        end_positions=end_positions)
        # extract loss
        loss = outputs.loss
        # backpropagate: compute the gradient of the loss for every trainable parameter
        loss.backward()
        # update parameters
        optim.step()
        # print relevant info to progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())



  0%|          | 0/204 [00:00<?, ?it/s][A[A

Epoch 0:   0%|          | 0/204 [00:04<?, ?it/s][A[A

Epoch 0:   0%|          | 0/204 [00:04<?, ?it/s, loss=6.41][A[A

Epoch 0:   0%|          | 1/204 [00:04<14:04,  4.16s/it, loss=6.41][A[A

Epoch 0:   0%|          | 1/204 [00:07<14:04,  4.16s/it, loss=6.41][A[A

Epoch 0:   0%|          | 1/204 [00:07<14:04,  4.16s/it, loss=6.15][A[A

Epoch 0:   1%|          | 2/204 [00:07<13:29,  4.01s/it, loss=6.15][A[A

Epoch 0:   1%|          | 2/204 [00:11<13:29,  4.01s/it, loss=6.15][A[A

Epoch 0:   1%|          | 2/204 [00:11<13:29,  4.01s/it, loss=6.02][A[A

Epoch 0:   1%|▏         | 3/204 [00:11<12:53,  3.85s/it, loss=6.02][A[A

Epoch 0:   1%|▏         | 3/204 [00:14<12:53,  3.85s/it, loss=6.02][A[A

Epoch 0:   1%|▏         | 3/204 [00:14<12:53,  3.85s/it, loss=5.88][A[A

Epoch 0:   2%|▏         | 4/204 [00:14<12:33,  3.77s/it, loss=5.88][A[A

Epoch 0:   2%|▏         | 4/204 [00:18<12:33,  3.77s/it, loss=5.88][A[A

Epo
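When `start_positions` and `end_positions` are passed, the model's loss is the mean of two cross-entropy terms, one over the start logits and one over the end logits; this is the `loss` shown in the progress bar. A pure-Python sketch for a single example, using made-up logits; with all-equal logits over L positions each term reduces to log(L):

```python
import math

def cross_entropy(logits, target):
    """Negative log-softmax probability of the target index (numerically stable)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]

def qa_loss(start_logits, end_logits, start_position, end_position):
    # average the start and end losses: (start_loss + end_loss) / 2
    return (cross_entropy(start_logits, start_position)
            + cross_entropy(end_logits, end_position)) / 2

# uniform logits over 512 positions: each term equals log(512)
uniform = [0.0] * 512
print(round(qa_loss(uniform, uniform, 10, 12), 4))  # → 6.2383

# a confident, correct prediction gives a loss close to zero
print(qa_loss([10.0, 0.0, 0.0], [0.0, 10.0, 0.0], 0, 1) < 1e-3)  # → True
```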