In [1]:
from collections import defaultdict
from dataclasses import dataclass
import functools
import gc
import gzip
import itertools
import os
from pathlib import Path
import random
import time
from typing import Callable, Dict, List, Generator, Tuple

# Project-local star imports stay ahead of the third-party imports so that the
# explicit names imported below keep taking precedence over anything they export.
from data_pre_process import *
from model import *
from data_loader import *

import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from tqdm._tqdm_notebook import tqdm_notebook as tqdm

import torch
from torch import nn, optim
import torch.cuda.amp
from torch.cuda.amp import GradScaler as scaler
from torch.utils.data import Dataset, Subset, DataLoader

from transformers import BertTokenizer, AdamW, BertModel, get_linear_schedule_with_warmup, BertPreTrainedModel

Please use `tqdm.notebook.*` instead of `tqdm._tqdm_notebook.*`
  from ipykernel import kernelapp as app


In [2]:
# Wall-clock reference; the final cell uses it to report the total runtime.
init_start_time = time.time()

# Pretrained checkpoint name. Lower-casing is derived from the name so the two
# settings cannot drift apart.
bert_model = 'bert-base-uncased'
do_lower_case = 'uncased' in bert_model

# Prefer the GPU, but fall back to the CPU so the notebook can still be run
# (slowly) on a machine without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative sharded dataset layout, kept for reference:
# data_dir_t = Path('data_2/v1.0/train')
# data_path_t = data_dir_t/'nq-train-00.jsonl.gz'
#
# data_dir_v = Path('data_2/v1.0/dev')
# data_path_v = data_dir_v/'nq-dev-00.jsonl.gz'

# Training split.
data_dir_t = Path('data')
data_path_t = data_dir_t / 'v1.0_train.jsonl.gz'

# Validation split.
data_dir_v = Path('data')
data_path_v = data_dir_v / 'v1.0_dev.jsonl.gz'

In [3]:
# Number of raw records handed to JsonlReader per chunk.
chunksize = 1000

# Pre-processing limits forwarded verbatim to `convert_data`.
# NOTE(review): presumably measured in word-piece tokens — confirm in data_pre_process.
max_seq_len = 384
max_question_len = 64
doc_stride = 128

# Reuse the checkpoint name and lower-casing flag from the config cell so the
# tokenizer always matches the model that is loaded later.
tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=do_lower_case)

# Training-mode converter (val=False); the validation cell rebinds `convert_func`
# with val=True.
convert_func = functools.partial(convert_data,
                                 tokenizer=tokenizer,
                                 max_seq_len=max_seq_len,
                                 max_question_len=max_question_len,
                                 doc_stride=doc_stride,
                                 val=False)

In [4]:
start = time.time()
# Split the decompressed bytes straight away instead of also binding them to a
# global: keeping the full buffer alive next to the line list roughly doubles the
# memory held for the whole training run.
with gzip.open(data_path_t, "rb") as f:
    x = f.read().splitlines()
# The reader converts the raw lines chunk by chunk with `convert_func`.
data_reader = JsonlReader(x, convert_func, chunksize=chunksize)
end = time.time()
print("Loading Data:", end - start, "seconds")

# Number of raw lines in the training file; used to size the LR schedule and the
# progress bar.
train_size = len(x)

Loading Data: 153.53386163711548 seconds


In [None]:
# --- optimisation schedule -------------------------------------------------
n_epochs = 1            # enters the estimate of total optimizer steps
lr = 2e-5               # learning rate handed to the optimizer
warmup = 0.05           # fraction of the optimizer steps spent warming up

# --- batching ----------------------------------------------------------------
batch_size = 16         # examples per DataLoader batch
accumulation_steps = 4  # batches accumulated before each optimizer step

# --- model head ----------------------------------------------------------------
# presumably one output per answer type (Result.class_labels lists five) — TODO confirm
num_labels = 5

In [None]:
# Load the pretrained QA model; the head size comes from the hyperparameter cell
# rather than a repeated literal.
model = BertForQuestionAnswering.from_pretrained(bert_model, num_labels=num_labels)
model = model.to(device)

# Parameters whose name contains one of these fragments get no weight decay.
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
param_optimizer = list(model.named_parameters())

# Single pass over the parameters, routing each one to its decay group.
decay_params, no_decay_params = [], []
for param_name, param in param_optimizer:
    if any(nd in param_name for nd in no_decay):
        no_decay_params.append(param)
    else:
        decay_params.append(param)
optimizer_grouped_parameters = [
    {'params': decay_params, 'weight_decay': 0.01},
    {'params': no_decay_params, 'weight_decay': 0.0}]

# NOTE(review): `train_size` counts raw lines of the training file; if one line
# expands to several training examples the real number of optimizer steps is
# larger than this estimate — confirm against the data pipeline.
train_optimization_steps = int(n_epochs * train_size / batch_size / accumulation_steps)
warmup_steps = int(train_optimization_steps * warmup)

optimizer = AdamW(optimizer_grouped_parameters, lr=lr, correct_bias=False)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=train_optimization_steps)

# NOTE(review): this gradient scaler is created but the training loop in this
# notebook never calls it (no autocast / scale / step through it).
s = torch.cuda.amp.GradScaler()
model.zero_grad()
model = model.train()

In [None]:
# Wall-clock budget in hours: once exceeded (checked after every chunk) training stops.
max_train_hours = 4

global_step = 0
start_time = time.time()
for examples in tqdm(data_reader, total=int(np.floor(train_size/chunksize))):
    # Each chunk of converted examples gets its own dataset / loader.
    train_dataset = TextDataset(examples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
    for x_batch, y_batch in train_loader:
        x_batch, attention_mask, token_type_ids = x_batch
        # Materialise the labels: a one-shot generator would be silently empty
        # if the loss function iterated over it more than once.
        y_batch = tuple(y.to(device) for y in y_batch)

        y_pred = model(x_batch.to(device),
                       attention_mask=attention_mask.to(device),
                       token_type_ids=token_type_ids.to(device))

        # NOTE(review): the loss is not divided by `accumulation_steps`, so the
        # gradient applied at each optimizer step is the sum over the
        # accumulated batches — confirm this is intended.
        loss = loss_fn(y_pred, y_batch)
        loss.backward()

        # Step the optimizer and scheduler only every `accumulation_steps` batches.
        if (global_step + 1) % accumulation_steps == 0:
            optimizer.step()
            scheduler.step()
            model.zero_grad()
        global_step += 1

    # A plain threshold cannot be skipped over, unlike a modulo window on the
    # truncated hour count, which a slow chunk could straddle.
    elapsed_hours = (time.time() - start_time) / 3600
    if elapsed_hours >= max_train_hours:
        break

# Always checkpoint after training, whether the loop hit the time budget or ran
# out of data; skip only when no batch was processed so that an existing
# checkpoint is not overwritten with untouched weights.
if global_step > 0:
    torch.save(model.state_dict(), 'bert_pytorch.bin')
    torch.save(optimizer.state_dict(), 'bert_pytorch_optimizer.bin')

# Drop the last chunk's objects. The names are unbound if the reader yielded
# nothing (e.g. the cell is re-run after the reader was exhausted).
try:
    del examples, train_dataset, train_loader
except NameError:
    pass
# Rebinding `x` also drops this cell's reference to the raw training lines.
x = gc.collect()

In [None]:
def eval_collate_fn(
    examples: "List[Example]",
    sep_token_id: int = 102
) -> "Tuple[List[torch.Tensor], List[Example]]":
    """Collate evaluation examples into padded model inputs.

    Parameters
    ----------
    examples : list of Example
        Each example must expose `input_ids` as a sequence supporting
        `.index()` (e.g. a list of ints).
    sep_token_id : int, optional
        Id of the separator token; 102 is [SEP] in the BERT vocabulary used here.

    Returns
    -------
    inputs : list of torch.Tensor
        `[tokens, attention_mask, token_type_ids]`, each of shape
        (len(examples), longest input). Tokens are right-padded with 0, the
        attention mask is True where the token id is > 0, and token type ids
        are 0 up to and including the first separator and 1 afterwards
        (padding positions stay 1).
    examples : list of Example
        The input list, passed through unchanged.

    Raises
    ------
    ValueError
        If `examples` is empty or an example contains no separator token.
    """
    max_len = max(len(example.input_ids) for example in examples)
    tokens = np.zeros((len(examples), max_len), dtype=np.int64)
    # Initialised to 1 so only the leading segment has to be zeroed per row.
    token_type_ids = np.ones((len(examples), max_len), dtype=np.int64)
    for row_idx, example in enumerate(examples):
        row = example.input_ids
        tokens[row_idx, :len(row)] = row
        # Look the separator up once per row rather than once per token.
        first_sep = row.index(sep_token_id)
        token_type_ids[row_idx, :first_sep + 1] = 0
    attention_mask = tokens > 0
    inputs = [torch.from_numpy(tokens),
              torch.from_numpy(attention_mask),
              torch.from_numpy(token_type_ids)]

    return inputs, examples


def eval_model(
    model: nn.Module,
    valid_loader: DataLoader,
    device: torch.device = torch.device('cuda')
) -> Dict[str, float]:
    """Run the model over the validation loader and score its predictions.

    Parameters
    ----------
    model : nn.Module
        Network to evaluate; it is moved to `device` and switched to eval mode.
    valid_loader : DataLoader
        Yields `(inputs, examples)` pairs, where `inputs` unpacks into
        input ids, attention mask and token type ids.
    device : torch.device, optional
        Where the forward passes are executed.

    Returns
    -------
    dict
        Whatever `Result.score()` reports: `long_score`, `short_score`
        and `overall_score`.
    """
    model.to(device)
    model.eval()
    result = Result()
    with torch.no_grad():
        for inputs, examples in tqdm(valid_loader):
            input_ids, attention_mask, token_type_ids = inputs
            y_preds = model(input_ids.to(device),
                            attention_mask.to(device),
                            token_type_ids.to(device))

            # Bring every head's output back to the CPU before post-processing.
            start_preds, end_preds, class_preds = [out.detach().cpu() for out in y_preds]

            # Best start / end position of each row, with its logit.
            start_logits, start_index = start_preds.max(dim=1)
            end_logits, end_index = end_preds.max(dim=1)

            # Span score relative to the score at position 0 ('[CLS]').
            cls_logits = start_preds[:, 0] + end_preds[:, 0]
            span_scores = start_logits + end_logits - cls_logits  # (batch_size,)
            span_indices = torch.stack((start_index, end_index), dim=1)  # (batch_size, 2)

            result.update(examples,
                          span_scores.numpy(),
                          span_indices.numpy(),
                          class_preds.numpy())

    return result.score()


class Result(object):
    """Keeps the best-scoring span prediction per example and scores them.

    Attributes are plain dicts keyed by `example.example_id`:
    `examples` (the Example that produced the best span), `results`
    (`[doc_start, (start, end) indices, class scores]`) and `best_scores`.
    Because `best_scores` defaults to 0.0 and updates require a strictly larger
    score, only spans with a score above zero are ever stored.
    """

    def __init__(self):
        self.examples = {}
        self.results = {}
        self.best_scores = defaultdict(float)
        # Order must match the class dimension of the model's class predictions.
        self.class_labels = ['LONG', 'NO', 'SHORT', 'UNKNOWN', 'YES']

    @staticmethod
    def is_valid_index(example: "Example", index: List[int]) -> bool:
        """Return whether the (start, end) pair is an acceptable span.

        A span is rejected when it ends before it starts, or when it starts at
        or before position `example.question_len + 2`.
        """
        start_index, end_index = index
        if start_index > end_index:
            return False
        # presumably the +2 skips the special tokens around the question — TODO confirm
        if start_index <= example.question_len + 2:
            return False
        return True

    def update(
        self,
        examples: "List[Example]",
        logits: np.ndarray,
        indices: np.ndarray,
        class_preds: np.ndarray
    ):
        """Update batch objects.

        Parameters
        ----------
        examples : list of Example
        logits : np.ndarray with shape (batch_size,)
            Score of each example's best span.
        indices : np.ndarray with shape (batch_size, 2)
            `start_index` and `end_index` pair of each example.
        class_preds : np.ndarray with shape (batch_size, num_classes)
            Class prediction scores of each example.
        """
        for i, example in enumerate(examples):
            # Keep a candidate only if its span is valid and it beats the best
            # score seen so far for the same example id.
            if self.is_valid_index(example, indices[i]) and \
               self.best_scores[example.example_id] < logits[i]:
                self.best_scores[example.example_id] = logits[i]
                self.examples[example.example_id] = example
                self.results[example.example_id] = [
                    example.doc_start, indices[i], class_preds[i]]

    def _generate_predictions(self) -> Generator[Dict, None, None]:
        """Yield one prediction dict per stored example.

        Each dict holds the example, the short answer mapped through
        `tokenized_to_original_index`, the first candidate that encloses the
        short answer as long answer (or [-1, -1]), and the raw class scores.
        """
        for example_id, (doc_start, index, class_pred) in self.results.items():
            example = self.examples[example_id]
            tokenized_to_original_index = example.tokenized_to_original_index
            short_start_index = tokenized_to_original_index[doc_start + index[0]]
            short_end_index = tokenized_to_original_index[doc_start + index[1]]
            long_start_index = -1
            long_end_index = -1
            for candidate in example.candidates:
                if candidate['start_token'] <= short_start_index and \
                   short_end_index <= candidate['end_token']:
                    long_start_index = candidate['start_token']
                    long_end_index = candidate['end_token']
                    break
            yield {
                'example': example,
                'long_answer': [long_start_index, long_end_index],
                'short_answer': [short_start_index, short_end_index],
                'yes_no_answer': class_pred
            }

    def end(self) -> Dict[str, Dict]:
        """Get predictions in submission format.

        Returns a dict with `<example_id>_long` and `<example_id>_short` keys.
        The long answer is `start:end`, or NaN when no candidate encloses the
        short span; the short answer is `start:end`, followed by ` YES` / ` NO`
        when that is the predicted class.
        """
        preds = {}
        for pred in self._generate_predictions():
            example = pred['example']
            long_start_index, long_end_index = pred['long_answer']
            short_start_index, short_end_index = pred['short_answer']
            class_label = self.class_labels[pred['yes_no_answer'].argmax()]

            long_answer = f'{long_start_index}:{long_end_index}' if long_start_index != -1 else np.nan
            short_answer = f'{short_start_index}:{short_end_index}'
            if class_label in ['YES', 'NO']:
                short_answer += ' ' + class_label
            preds[f'{example.example_id}_long'] = long_answer
            preds[f'{example.example_id}_short'] = short_answer
        return preds

    def score(self) -> Dict[str, float]:
        """Calculate F1 scores over all stored predictions.

        Only examples that received at least one accepted prediction (see
        `update`) take part in the score. Returns `long_score`, `short_score`
        and their mean as `overall_score`; all are 0.0 when nothing is stored.
        """

        def _safe_divide(x: int, y: int) -> float:
            """Compute x / y, but return 0 if y is zero.
            """
            if y == 0:
                return 0.
            else:
                return x / y

        def _compute_f1(answer_stats: List[List[bool]]) -> float:
            """Compute F1 from (has_answer, has_pred, is_correct) triples.
            """
            if not answer_stats:
                # Nothing to unzip: no predictions means no score.
                return 0.
            has_answer, has_pred, is_correct = list(zip(*answer_stats))
            precision = _safe_divide(sum(is_correct), sum(has_pred))
            recall = _safe_divide(sum(is_correct), sum(has_answer))
            f1 = _safe_divide(2 * precision * recall, precision + recall)
            return f1

        long_scores = []
        short_scores = []
        for pred in self._generate_predictions():
            example = pred['example']
            long_pred = pred['long_answer']
            short_pred = pred['short_answer']
            # Label predicted by the classification head.
            pred_class = self.class_labels[pred['yes_no_answer'].argmax()]

            # long score
            long_label = example.annotations['long_answer']
            has_answer = long_label['candidate_index'] != -1
            has_pred = long_pred[0] != -1 and long_pred[1] != -1
            is_correct = False
            if long_label['start_token'] == long_pred[0] and \
               long_label['end_token'] == long_pred[1]:
                is_correct = True
            long_scores.append([has_answer, has_pred, is_correct])

            # short score
            short_labels = example.annotations['short_answers']
            gold_yes_no = example.annotations['yes_no_answer']
            # Ground truth decides whether an answer exists; the prediction
            # decides whether one was produced. The predicted label is never
            # 'NONE' (see `class_labels`), so it must be tested against YES/NO.
            has_answer = gold_yes_no != 'NONE' or len(short_labels) != 0
            has_pred = pred_class in ['YES', 'NO'] or \
                (short_pred[0] != -1 and short_pred[1] != -1)
            is_correct = False
            if gold_yes_no in ['YES', 'NO']:
                is_correct = pred_class == gold_yes_no
            else:
                for short_label in short_labels:
                    if short_label['start_token'] == short_pred[0] and \
                       short_label['end_token'] == short_pred[1]:
                        is_correct = True
                        break
            short_scores.append([has_answer, has_pred, is_correct])

        long_score = _compute_f1(long_scores)
        short_score = _compute_f1(short_scores)
        return {
            'long_score': long_score,
            'short_score': short_score,
            'overall_score': (long_score + short_score) / 2
        }

In [5]:
start = time.time()
# Split the decompressed bytes straight away instead of also binding them to a
# global, so the full buffer is not kept alive next to the line list.
with gzip.open(data_path_v, "rb") as f:
    y = f.read().splitlines()
end = time.time()
print("Loading Data:", end - start, "seconds")

# Number of raw lines in the validation file.
val_size = len(y)

Loading Data: 50.362544775009155 seconds


In [7]:
eval_start_time = time.time()

# Same conversion settings as for training, switched to validation mode.
# NOTE(review): the recorded output of this cell is `KeyError: 'document_text'`;
# presumably the dev file's records lack the field the converter reads — confirm.
convert_func = functools.partial(
    convert_data,
    tokenizer=tokenizer,
    max_seq_len=max_seq_len,
    max_question_len=max_question_len,
    doc_stride=doc_stride,
    val=True,
)

data_reader_v = JsonlReader(y, convert_func, chunksize=chunksize)

# Only the first chunk from the reader is evaluated; it arrives as groups of
# examples, which are flattened into a single list.
valid_data = next(data_reader_v)
valid_data = [example for group in valid_data for example in group]

valid_dataset = Subset(valid_data, range(len(valid_data)))
valid_loader = DataLoader(
    valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=eval_collate_fn,
)
valid_scores = eval_model(model, valid_loader, device=device)

print(f'calculate validation score done in {(time.time() - eval_start_time) / 60:.1f} minutes.')

KeyError: 'document_text'

In [None]:
# Pull the three metrics out of the validation result in one go.
long_score, short_score, overall_score = (
    valid_scores[metric]
    for metric in ('long_score', 'short_score', 'overall_score')
)

# Report the metrics, then the runtime since the first configuration cell.
print('validation scores:')
print(f'\tlong score    : {long_score:.4f}')
print(f'\tshort score   : {short_score:.4f}')
print(f'\toverall score : {overall_score:.4f}')
print(f'all process done in {(time.time() - init_start_time) / 3600:.1f} hours.')