In [7]:
import json
import numpy as np
import torch
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizerFast, BertForTokenClassification
from sklearn.metrics import accuracy_score, f1_score
import time
from sklearn.model_selection import train_test_split

In [18]:
from prepare_short_ans_dataset import ShortAnswerDataset, TestShortAnswerDataset, simplify_nq_example
from short_ans_model import ShortAnswerModel

In [9]:
def read_json(filename):
    """Load a JSON Lines file into a list of decoded records.

    Every non-blank line of the file is parsed as one standalone JSON
    document.

    Args:
        filename: Path of the JSON Lines file to read.

    Returns:
        list: One decoded object per non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # Read as UTF-8 explicitly instead of relying on the platform's default
    # locale encoding.
    with open(filename, 'r', encoding='utf-8') as json_file:
        # Iterate over the handle so the raw text of the whole file is never
        # held in memory next to the parsed objects.
        for json_str in json_file:
            # Skip blank / whitespace-only lines (e.g. a stray empty line at
            # the end of the file) instead of failing inside json.loads.
            if json_str.strip():
                data.append(json.loads(json_str))
    return data

In [10]:
# Parse the whole JSON Lines file into an in-memory list, one record per line.
# NOTE(review): the file name suggests this is the simplified NQ dev split — confirm.
dev_data = read_json("v1.0-simplified_nq-dev-all.jsonl")

In [11]:
# Split the loaded records 60/40: the larger part becomes val_data and the
# held-out 40% becomes test_data. The fixed random_state keeps the split
# identical across runs.
val_data, test_data = train_test_split(dev_data, test_size=0.4, random_state=42)

---------

In [12]:
# Start from the pretrained cased BERT tokenizer.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-cased')

# Colspan variants of the table-cell tags for spans 1..9, interleaved as
# Td/Th pairs per span value.
# NOTE(review): the ids of added tokens presumably follow insertion order and
# the checkpoint loaded later was presumably trained with this order — keep
# the ordering of both add_tokens calls unchanged.
col_spans = [
    tag
    for i in range(1, 10)
    for tag in (f'<Td_colspan="{i}">', f'<Th_colspan="{i}">')
]

# Register the plain structural HTML tags first, then the colspan variants.
tokenizer.add_tokens([
    '</Td>', '<Td>',
    '</Tr>', '<Tr>',
    '<Th>', '</Th>',
    '<Li>', '</Li>',
    '<Ul>', '</Ul>',
    '<Table>', '</Table>',
])
# Last expression of the cell: its return value is what the notebook displays.
tokenizer.add_tokens(col_spans)

18

In [13]:
# Token-classification head with two labels on top of the pretrained
# bert-base-cased encoder.
model = BertForTokenClassification.from_pretrained('bert-base-cased', num_labels=2)
# Resize the embedding matrix to the tokenizer's current size so the tokens
# added above have embedding rows; the returned module is the cell's output.
model.resize_token_embeddings(len(tokenizer))

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForTokenClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cas

Embedding(29026, 768)

In [14]:
# Length limit handed to the dataset; the evaluation loop also uses it to
# offset per-chunk predictions.
# NOTE(review): presumably the maximum number of tokens per encoded chunk — confirm
# against TestShortAnswerDataset.
max_len = 500
# Build the test dataset from the held-out examples, simplifying each raw
# example with simplify_nq_example (should_simplify=True).
test_dataset = TestShortAnswerDataset(test_data, simplify_nq_example, tokenizer, max_len, should_simplify=True)

100%|██████████████████████████████████████████████████████████████████████████████| 3132/3132 [00:27<00:00, 112.35it/s]


In [15]:
# The evaluation loop reads index [0] of every collated field, i.e. it handles
# exactly one example per batch, so the batch size must stay at 1.
bs = 1
# Evaluation gains nothing from shuffling: a fixed order keeps runs repeatable
# and lets each prediction be traced back to its position in test_dataset.
test_dataloader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

In [16]:
def join_test_example_preds(pred_starts, pred_ends):
    """Collapse per-chunk span predictions into one "start:end" string.

    The first chunk whose predicted start is not -1 wins.

    Args:
        pred_starts: Predicted start positions, one per chunk; -1 marks
            "no answer" for that chunk.
        pred_ends: Predicted end positions, aligned with ``pred_starts``.

    Returns:
        str: "start:end" of the first chunk with an answer, or "-1:-1" when
        no chunk has one (including when there are no chunks at all).
    """
    # Index of the first chunk that produced an answer, if any.
    first_hit = next(
        (idx for idx, start in enumerate(pred_starts) if start != -1),
        None,
    )
    if first_hit is None:
        return f"{-1}:{-1}"
    return f"{pred_starts[first_hit]}:{pred_ends[first_hit]}"

In [17]:
def format_test_gold(start, end, long_start):
    """Format a gold span as "start:end" relative to ``long_start``.

    Args:
        start: Gold start position, or -1 when there is no gold span.
        end: Gold end position.
        long_start: Position subtracted from both ``start`` and ``end``.

    Returns:
        str: "-1:-1" when ``start`` is -1, otherwise the two positions
        shifted by ``long_start`` and joined with a colon.
    """
    if start != -1:
        rel_start = start - long_start
        rel_end = end - long_start
        return f"{rel_start}:{rel_end}"
    # No gold span for this example.
    return f"{-1}:{-1}"

In [20]:
# Deserialize the checkpoint onto the CPU regardless of the device it was
# saved from: this works on machines without a GPU and avoids a temporary
# second copy of the weights in GPU memory. The model is still on the CPU at
# this point, so load_state_dict copies CPU -> CPU.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load('short_model_2.pt', map_location='cpu'))

<All keys matched successfully>

In [21]:
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs (more slowly) on a machine without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Wrap the fine-tuned token classifier in the project's inference helper.
answer_model = ShortAnswerModel(model, device=device)

In [22]:
# Per-example "start:end" strings: model prediction and gold reference.
pred, gold = [], []

# Offset added to a chunk-local prediction per chunk index, hoisted out of the
# inner loop.
# NOTE(review): presumably this is the stride between consecutive chunks
# produced by TestShortAnswerDataset (150 positions not advanced) — confirm.
chunk_stride = max_len - 150

# Pure inference: disable autograd so no graph is built per forward pass.
# ShortAnswerModel may already do this internally; the guard is harmless then.
with torch.no_grad():
    for data in tqdm(test_dataloader):
        pred_starts, pred_ends = [], []
        encodings, question_len, short_info, long_info = data
        # One model call per chunk of the current example.
        for i, enc in enumerate(encodings):
            # The helper returns one (start, end) pair per batch item; the
            # batch holds a single example, hence [0].
            s, e = answer_model(enc, [question_len] * bs)[0]
            s, e = int(s), int(e)
            if s != -1:
                # Shift the chunk-local span by the chunk's offset.
                s += chunk_stride * i
                e += chunk_stride * i
            pred_starts.append(s)
            pred_ends.append(e)
        pred_ans = join_test_example_preds(pred_starts, pred_ends)
        # Gold short-answer span, expressed relative to the long-answer start.
        gold_ans = format_test_gold(int(short_info[0]['start_token']),
                                    int(short_info[0]['end_token']), int(long_info[0]['start_token']))
        pred.append(pred_ans)
        gold.append(gold_ans)

100%|███████████████████████████████████████████████████████████████████████████████| 1028/1028 [01:04<00:00, 15.82it/s]


In [23]:
# Every example contributes one "start:end" string label, so a prediction is
# counted as correct only when both positions match the gold string exactly.
# NOTE(review): with a single label per example, micro-averaged F1 should
# coincide with plain exact-match accuracy — confirm against the scikit-learn
# f1_score documentation before reporting this as an F score.
f = f1_score(gold, pred, average='micro')
print(f'Test F score: {f:.3f}')

Test F score: 0.445
