In [1]:
import json
import pandas as pd
import re
import numpy as np
import torch
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification
from sklearn.metrics import accuracy_score, f1_score
import matplotlib.pyplot as plt
from IPython.display import clear_output
import time
from torch.nn import DataParallel

from sklearn.model_selection import train_test_split

In [2]:
from prepare_short_ans_dataset import ShortAnswerDataset, simplify_nq_example
from short_ans_model import ShortAnswerModel

In [3]:
def read_json(filename, encoding='utf-8'):
    """Load a JSON Lines file into a list of parsed records.

    Each non-blank line of ``filename`` is decoded with ``json.loads``. Blank
    lines (for example extra empty lines at the end of the file) are skipped
    rather than raising ``JSONDecodeError``.

    Args:
        filename: Path to a ``.jsonl`` file with one JSON document per line.
        encoding: Text encoding of the file. Defaults to UTF-8 so the result
            does not depend on the platform's locale encoding.

    Returns:
        list: The decoded JSON values, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(filename, 'r', encoding=encoding) as json_file:
        # Parse while streaming so the raw text of every line is not held in
        # memory alongside the parsed objects (matters for large NQ dumps).
        return [json.loads(line) for line in json_file if line.strip()]

In [4]:
# Natural Questions in JSON Lines form: the training dump (already simplified,
# judging by its file name) and the full dev set, which is later split further.
nq_train_path = "v1.0-simplified_simplified-nq-train.jsonl"
nq_dev_path = "v1.0-simplified_nq-dev-all.jsonl"

train_data = read_json(nq_train_path)
dev_data = read_json(nq_dev_path)

In [5]:
# Carve the official dev set into validation (60%) and held-out test (40%);
# random_state pins the split so membership is identical across runs.
val_data, test_data = train_test_split(
    dev_data,
    test_size=0.4,
    random_state=42,
)

In [6]:
# Report the number of raw examples in each split.
for split_name, split in (("Train", train_data), ("Validation", val_data), ("Test", test_data)):
    print(f"{split_name} size: {len(split)}")

Train size: 307373
Validation size: 4698
Test size: 3132


-------

In [8]:
# Cased WordPiece tokenizer matching the bert-base-cased checkpoint used for the model.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")

In [9]:
# Register HTML-like markers as whole tokens so the tokenizer does not split them
# into WordPiece fragments. NOTE(review): presumably these are the markers emitted
# by the NQ simplification/preprocessing; confirm against simplify_nq_example.
html_tags = ['</Td>', '<Td>', '</Tr>', '<Tr>', '<Th>', '</Th>',
             '<Li>', '</Li>', '<Ul>', '</Ul>', '<Table>', '</Table>']

# Column-span variants for spans 1..9, interleaved Td/Th per span (order fixes token ids).
col_spans = [
    f'<{cell}_colspan="{span}">'
    for span in range(1, 10)
    for cell in ('Td', 'Th')
]

tokenizer.add_tokens(html_tags)
tokenizer.add_tokens(col_spans)

18

In [10]:
# Seed torch first: the token-classification head, and the embedding rows added by
# resize_token_embeddings, are randomly initialised, so unseeded runs differ.
torch.manual_seed(42)

# Two output labels per token. NOTE(review): presumably "inside answer span" vs.
# "outside"; confirm against ShortAnswerDataset's label construction.
model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2)

# Grow the embedding matrix to cover the tokens added to the tokenizer.
model.resize_token_embeddings(len(tokenizer))

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

Embedding(29026, 768)

In [11]:
# Maximum tokenised sequence length; bert-base-cased supports at most 512 positions.
max_len = 500

# Arguments shared by every split. Judging by the file names, only the dev-derived
# splits still need simplifying; the training dump is already simplified.
dataset_args = (simplify_nq_example, tokenizer, max_len)

train_dataset = ShortAnswerDataset(train_data, *dataset_args, should_simplify=False)
val_dataset = ShortAnswerDataset(val_data, *dataset_args, should_simplify=True)
test_dataset = ShortAnswerDataset(test_data, *dataset_args, should_simplify=True)

100%|█████████████████████████████████████████████████████████████████████████| 307373/307373 [03:11<00:00, 1606.46it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 4698/4698 [00:44<00:00, 105.45it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 3132/3132 [00:30<00:00, 104.32it/s]


In [13]:
bs = 32

# Only the training loader should shuffle. A seeded generator makes the shuffle
# order reproducible across runs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=bs,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

# Evaluation order should not affect aggregate metrics. A fixed order keeps
# validation/test passes deterministic and comparable between checkpoints.
val_dataloader = DataLoader(val_dataset, batch_size=bs, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [14]:
# Batches per pass over each loader (len of a DataLoader is its batch count).
for label, loader in (("Training", train_dataloader), ("Validation", val_dataloader)):
    print(f"{label} Steps: {len(loader)}")

Training Steps: 5089
Validation Steps: 90


In [16]:
# Small learning rate for fine-tuning every BERT parameter, classification head included.
lr = 1e-5
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

In [17]:
# Prefer the GPU, but fall back to CPU so the notebook does not crash on
# machines without CUDA (training will be very slow there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
answer_model = ShortAnswerModel(model, device=device)

In [7]:
n_epochs = 5

# Fine-tune, evaluating on the validation loader. NOTE(review): checkpoint_step is
# presumably the number of training steps between checkpoints; confirm in ShortAnswerModel.
answer_model.train(train_dataloader, val_dataloader, n_epochs, optimizer, checkpoint_step=1000)

In [47]:
# Persist the fine-tuned weights. The tokenizer is saved alongside them because
# the tokens added earlier enlarged both the vocabulary and the embedding matrix.
# Reloading these weights needs a tokenizer with exactly those tokens, in the same order.
torch.save(model.state_dict(), 'short_model.pt')
tokenizer.save_pretrained('short_model_tokenizer')