In [1]:
!pip install -U transformers

Collecting transformers
  Downloading transformers-3.0.2-py3-none-any.whl (769 kB)
[K     |████████████████████████████████| 769 kB 574 kB/s 
Collecting tokenizers==0.8.1.rc1
  Downloading tokenizers-0.8.1rc1-cp37-cp37m-manylinux1_x86_64.whl (3.0 MB)
[K     |████████████████████████████████| 3.0 MB 4.0 MB/s 
[31mERROR: allennlp 1.0.0 has requirement transformers<2.12,>=2.9, but you'll have transformers 3.0.2 which is incompatible.[0m
Installing collected packages: tokenizers, transformers
  Attempting uninstall: tokenizers
    Found existing installation: tokenizers 0.7.0
    Uninstalling tokenizers-0.7.0:
      Successfully uninstalled tokenizers-0.7.0
  Attempting uninstall: transformers
    Found existing installation: transformers 2.11.0
    Uninstalling transformers-2.11.0:
      Successfully uninstalled transformers-2.11.0
Successfully installed tokenizers-0.8.1rc1 transformers-3.0.2
You should consider upgrading via the '/opt/conda/bin/python3.7 -m pip insta

In [2]:
import random
import numpy as np
import pandas as pd
from tqdm import tqdm
tqdm.pandas()

from transformers import BertTokenizer, BertPreTrainedModel, BertConfig, BertModel
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

  from pandas import Panel


In [3]:
random.seed(2020)
np.random.seed(2020)
torch.manual_seed(2020)
torch.cuda.manual_seed_all(2020)

In [4]:
def load_json_as_pandas_df(filename):
    df = pd.read_json(f'/kaggle/input/squad-20/{filename}', orient='records')
    df = pd.DataFrame.from_records(df['data'])
    df = df.explode('paragraphs').reset_index(drop=True)
    df['context'] = df['paragraphs'].apply(lambda x: x['context'])
    df['qas'] = df['paragraphs'].apply(lambda x: x['qas'])
    df = df[['title', 'context', 'qas']]
    df = df.explode('qas').reset_index(drop=True)
    df = pd.concat([df, pd.DataFrame.from_records(df['qas'])], axis=1).drop(['qas'], axis=1)
    df.loc[df['is_impossible'], 'answers'] = df.loc[df['is_impossible'], 'plausible_answers']
    df = df.drop(['plausible_answers'], axis=1)
    df['answers'] = df['answers'].apply(lambda x: x[0] if len(x) else {'text': '', 'answer_start': 0})
    return df

In [5]:
train_df = load_json_as_pandas_df('train-v2.0.json')
train_df

Unnamed: 0,title,context,question,id,answers,is_impossible
0,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce start becoming popular?,56be85543aeaaa14008c9063,"{'text': 'in the late 1990s', 'answer_start': ...",False
1,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,What areas did Beyonce compete in when she was...,56be85543aeaaa14008c9065,"{'text': 'singing and dancing', 'answer_start'...",False
2,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,When did Beyonce leave Destiny's Child and bec...,56be85543aeaaa14008c9066,"{'text': '2003', 'answer_start': 526}",False
3,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In what city and state did Beyonce grow up?,56bf6b0f3aeaaa14008c9601,"{'text': 'Houston, Texas', 'answer_start': 166}",False
4,Beyoncé,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,In which decade did Beyonce become famous?,56bf6b0f3aeaaa14008c9602,"{'text': 'late 1990s', 'answer_start': 276}",False
...,...,...,...,...,...,...
130314,Matter,"The term ""matter"" is used throughout physics i...",Physics has broadly agreed on the definition o...,5a7e070b70df9f001a875439,"{'text': 'matter', 'answer_start': 485}",True
130315,Matter,"The term ""matter"" is used throughout physics i...",Who coined the term partonic matter?,5a7e070b70df9f001a87543a,"{'text': 'Alfvén', 'answer_start': 327}",True
130316,Matter,"The term ""matter"" is used throughout physics i...",What is another name for anti-matter?,5a7e070b70df9f001a87543b,"{'text': 'Gk. common matter', 'answer_start': ...",True
130317,Matter,"The term ""matter"" is used throughout physics i...",Matter usually does not need to be used in con...,5a7e070b70df9f001a87543c,"{'text': 'a specifying modifier', 'answer_star...",True


In [6]:
valid_df = load_json_as_pandas_df('dev-v2.0.json')
valid_df

Unnamed: 0,title,context,question,id,answers,is_impossible
0,Normans,The Normans (Norman: Nourmands; French: Norman...,In what country is Normandy located?,56ddde6b9a695914005b9628,"{'text': 'France', 'answer_start': 159}",False
1,Normans,The Normans (Norman: Nourmands; French: Norman...,When were the Normans in Normandy?,56ddde6b9a695914005b9629,"{'text': '10th and 11th centuries', 'answer_st...",False
2,Normans,The Normans (Norman: Nourmands; French: Norman...,From which countries did the Norse originate?,56ddde6b9a695914005b962a,"{'text': 'Denmark, Iceland and Norway', 'answe...",False
3,Normans,The Normans (Norman: Nourmands; French: Norman...,Who was the Norse leader?,56ddde6b9a695914005b962b,"{'text': 'Rollo', 'answer_start': 308}",False
4,Normans,The Normans (Norman: Nourmands; French: Norman...,What century did the Normans first gain their ...,56ddde6b9a695914005b962c,"{'text': '10th century', 'answer_start': 671}",False
...,...,...,...,...,...,...
11868,Force,"The pound-force has a metric counterpart, less...",What is the seldom used force unit equal to on...,5737aafd1c456719005744ff,"{'text': 'sthène', 'answer_start': 665}",False
11869,Force,"The pound-force has a metric counterpart, less...",What does not have a metric counterpart?,5ad28ad0d7d075001a4299cc,"{'text': 'pound-force', 'answer_start': 4}",True
11870,Force,"The pound-force has a metric counterpart, less...",What is the force exerted by standard gravity ...,5ad28ad0d7d075001a4299cd,"{'text': 'kilogram-force', 'answer_start': 82}",True
11871,Force,"The pound-force has a metric counterpart, less...",What force leads to a commonly used unit of mass?,5ad28ad0d7d075001a4299ce,"{'text': 'kilogram-force', 'answer_start': 195}",True


In [7]:
max_len = 512
epochs = 2
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=231508.0, style=ProgressStyle(descripti…




In [8]:
# question + context tokens length

# count    130319.000000
# mean        170.726632
# std          65.407215
# min          35.000000
# 25%         129.000000
# 50%         158.000000
# 75%         200.000000
# max         870.000000
# dtype: float64

# Around 200 examples have token length of more than 500

In [9]:
class SQuAD2Dataset(Dataset):
    def __init__(self, df, tokenizer, max_len):
        self.df = df
        self.tokenizer = tokenizer
        self.max_len = max_len
        
    def __len__(self):
        return len(self.df)
        
    def __getitem__(self, idx):
        context = self.df['context'][idx]
        question = self.df['question'][idx]
        answer = self.df['answers'][idx]['text']
        answer_start = self.df['answers'][idx]['answer_start']
        is_impossible = self.df['is_impossible'][idx]
        qas_id = self.df['id'][idx]
        
        input_text = '[CLS] ' + question + ' [SEP] ' + context + ' [SEP]'
        input_tokens = self.tokenizer.tokenize(input_text)
        
        start_position = 0
        end_position = 0
        
        if not is_impossible:
            answer_tokens = self.tokenizer.tokenize(answer)
            
            for i, tok in enumerate(input_tokens):
                if tok == answer_tokens[0] and input_tokens[min(len(input_tokens)-1, i+len(answer_tokens)-1)] == answer_tokens[-1]:
                    start_position = i
                    end_position = min(len(input_tokens)-1, i+len(answer_tokens)-1)
                    if start_position >= max_len or end_position >= max_len:
                        start_position = 0
                        end_position = 0

        token_type_ids = [0] * (len(self.tokenizer.tokenize(question)) + 2)
        token_type_ids = token_type_ids + [1] * (len(input_tokens) - len(token_type_ids))

        input_ids = self.tokenizer.convert_tokens_to_ids(input_tokens)
        attention_mask = [1] * len(input_tokens)

        # Truncate
        if len(input_ids) > self.max_len:
            input_ids = input_ids[:self.max_len]
            attention_mask = attention_mask[:self.max_len]
            token_type_ids = token_type_ids[:self.max_len]
            
        # Padding
        if len(input_ids) < self.max_len:
            input_ids = input_ids + [0] * (max_len - len(input_tokens))
            attention_mask = attention_mask + [0] * (max_len - len(input_tokens))
            token_type_ids = token_type_ids + [0] * (max_len - len(input_tokens))
        
        # convert lists to tensor
        input_ids = torch.tensor(input_ids, dtype=torch.long)
        attention_mask = torch.tensor(attention_mask, dtype=torch.float)
        token_type_ids = torch.tensor(token_type_ids, dtype=torch.long)
        start_position = torch.tensor(start_position, dtype=torch.long).unsqueeze(0)
        end_position = torch.tensor(end_position, dtype=torch.long).unsqueeze(0)
        
        return input_ids, attention_mask, token_type_ids, start_position, end_position

In [10]:
train_dataset = SQuAD2Dataset(train_df, tokenizer, max_len)
valid_dataset = SQuAD2Dataset(valid_df, tokenizer, max_len)

In [11]:
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
valid_dataloader = DataLoader(valid_dataset, batch_size=256, shuffle=False)

In [12]:
class SQuAD2Model(BertPreTrainedModel):
    def __init__(self, conf):
        super(SQuAD2Model, self).__init__(conf)
        self.bert = BertModel(conf)
        self.drop = nn.Dropout(0.1)
        self.fc = nn.Linear(768, 2)
    
    def forward(self, ids, mask, token_type_ids):
        _, _, out = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)

        out = self.drop(out[0])
        logits = self.fc(out)

        start_logits, end_logits = logits.split(1, dim=-1)

        start_logits = start_logits.squeeze(-1)
        end_logits = end_logits.squeeze(-1)

        return start_logits, end_logits

In [13]:
model_config = BertConfig.from_pretrained('bert-base-uncased')
model_config.output_hidden_states = True
model = SQuAD2Model(conf=model_config)
model.to(device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




SQuAD2Model(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)


In [14]:
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)
criterion = nn.CrossEntropyLoss()

In [15]:
for epoch in range(epochs):

    model.train()

    for step, batch in enumerate(tqdm(train_dataloader)):
        optimizer.zero_grad()

        input_ids, attention_masks, token_type_ids, start_positions, end_positions = batch
        input_ids = input_ids.to(device)
        attention_masks = attention_masks.to(device)
        token_type_ids = token_type_ids.to(device)
        start_positions = start_positions.squeeze().to(device)
        end_positions = end_positions.squeeze().to(device)

        outputs_start, outputs_end = model(input_ids, attention_masks, token_type_ids)

        loss_start = criterion(outputs_start, start_positions)
        loss_end = criterion(outputs_end, end_positions)
        total_loss = loss_start + loss_end
        total_loss.backward()

        optimizer.step()
        if (step+1) % 1000 == 0:
            print(f'Epoch: {epoch+1} || Step: {step+1} || Training Loss: {total_loss.item()}')

    model.eval()

    true_start = []
    true_end = []
    pred_start = []
    pred_end = []

    with torch.no_grad():

        for batch in tqdm(valid_dataloader):

            input_ids, attention_masks, token_type_ids, start_positions, end_positions = batch
            input_ids = input_ids.to(device)
            attention_masks = attention_masks.to(device)
            token_type_ids = token_type_ids.to(device)
            start_positions = start_positions.squeeze().to(device)
            end_positions = end_positions.squeeze().to(device)

            outputs_start, outputs_end = model(input_ids, attention_masks, token_type_ids)

            true_start.append(start_positions)
            true_end.append(end_positions)
            pred_start.append(outputs_start)
            pred_end.append(outputs_end)

    true_start = torch.cat(true_start)
    true_end = torch.cat(true_end)
    pred_start = torch.cat(pred_start)
    pred_end = torch.cat(pred_end)

    loss = criterion(pred_start, true_start) + criterion(pred_end, true_end)

    print(f'Validation Loss: {loss.item()}')
    
    torch.save({'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': loss.item()},
               f'model_{epoch+1}.pth')

 12%|█▏        | 1000/8145 [06:15<45:19,  2.63it/s]

Epoch: 1 || Step: 1000 || Training Loss: 7.7994561195373535


 25%|██▍       | 2000/8145 [12:32<42:34,  2.41it/s]

Epoch: 1 || Step: 2000 || Training Loss: 6.749929428100586


 37%|███▋      | 3000/8145 [18:50<31:50,  2.69it/s]

Epoch: 1 || Step: 3000 || Training Loss: 7.677448272705078


 49%|████▉     | 4000/8145 [25:11<30:38,  2.26it/s]

Epoch: 1 || Step: 4000 || Training Loss: 6.591978073120117


 61%|██████▏   | 5000/8145 [31:32<19:45,  2.65it/s]

Epoch: 1 || Step: 5000 || Training Loss: 7.486364841461182


 74%|███████▎  | 6000/8145 [37:52<14:44,  2.42it/s]

Epoch: 1 || Step: 6000 || Training Loss: 6.120998382568359


 86%|████████▌ | 7000/8145 [44:10<07:05,  2.69it/s]

Epoch: 1 || Step: 7000 || Training Loss: 7.568524360656738


 98%|█████████▊| 8000/8145 [50:25<00:54,  2.68it/s]

Epoch: 1 || Step: 8000 || Training Loss: 8.043097496032715


100%|██████████| 8145/8145 [51:19<00:00,  2.64it/s]
100%|██████████| 47/47 [03:08<00:00,  4.01s/it]


Validation Loss: 5.95525598526001


 12%|█▏        | 1000/8145 [06:17<44:27,  2.68it/s]

Epoch: 2 || Step: 1000 || Training Loss: 8.087492942810059


 25%|██▍       | 2000/8145 [12:35<38:22,  2.67it/s]

Epoch: 2 || Step: 2000 || Training Loss: 7.828147888183594


 37%|███▋      | 3000/8145 [18:52<31:36,  2.71it/s]

Epoch: 2 || Step: 3000 || Training Loss: 5.915700912475586


 49%|████▉     | 4000/8145 [25:11<28:01,  2.46it/s]

Epoch: 2 || Step: 4000 || Training Loss: 5.252338409423828


 61%|██████▏   | 5000/8145 [31:26<20:34,  2.55it/s]

Epoch: 2 || Step: 5000 || Training Loss: 6.153585433959961


 74%|███████▎  | 6000/8145 [37:41<13:07,  2.72it/s]

Epoch: 2 || Step: 6000 || Training Loss: 9.335138320922852


 86%|████████▌ | 7000/8145 [43:56<08:28,  2.25it/s]

Epoch: 2 || Step: 7000 || Training Loss: 8.138043403625488


 98%|█████████▊| 8000/8145 [50:12<00:53,  2.73it/s]

Epoch: 2 || Step: 8000 || Training Loss: 7.154268741607666


100%|██████████| 8145/8145 [51:06<00:00,  2.66it/s]
100%|██████████| 47/47 [03:08<00:00,  4.01s/it]


Validation Loss: 5.883515357971191
