In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# reloaded before each cell executes, so edits to imported source files
# (e.g. the `dataset` / `utils` modules imported below) take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from tqdm import tqdm_notebook
from torch.utils.data import IterableDataset
from torch.optim.lr_scheduler import ExponentialLR
from transformers import (BertTokenizer,
                          BertModel, 
                          BertForMaskedLM, 
                          BertForNextSentencePrediction, 
                          AdamW)

from dataset import *
from utils import *

In [3]:
# Run configuration shared by the cells below.
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# When True, the model-loading cell wraps the model in nn.DataParallel.
parallel = False

In [4]:
# Tokenizer for the 'bert-base-uncased' checkpoint; it is handed to the batch
# collator in the next cell.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [5]:
# Batch collator (project helper brought in by one of the star imports), built
# from the tokenizer and the target device.
# NOTE(review): presumably it tokenizes the raw samples and places the resulting
# tensors on `device` -- confirm in the module that defines NSPBatchCollator.
collator = NSPBatchCollator(tokenizer, device=device)

In [6]:
# Project dataset built from 'wikihow.csv' with shuffling enabled. The tqdm
# callable is passed in, presumably to render a progress bar -- confirm in
# WikihowNSP.
# NOTE(review): `tqdm_notebook` is deprecated in recent tqdm releases in favour
# of `tqdm.notebook.tqdm` / `tqdm.auto.tqdm`.
dataset = WikihowNSP('wikihow.csv', shuffle=True, tqdm=tqdm_notebook)

In [7]:
# Batch iterator over the dataset: 32 samples per batch, assembled by the
# collator. The training loop below unpacks each batch into four tensors.
loader = dataset.loader(collator, batch_size=32)

In [8]:
# Resume from the checkpoint that the training loop writes to 'bert.pt'.
# To start from the stock pretrained weights instead, use the commented line.
# raw_model = BertForNextSentencePrediction.from_pretrained('bert-base-uncased')
#
# SECURITY: torch.load unpickles a whole model object and can therefore execute
# arbitrary code -- only load checkpoints written by this notebook or another
# trusted source.
# NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
# full-model pickles; pass weights_only=False there if the load fails.
#
# map_location lets the checkpoint load even if it was saved from a different
# device than the one configured above.
raw_model = torch.load('bert.pt', map_location=device)

# Always place the model on `device`. Previously this only happened when
# `parallel` was True, so the single-device path could leave the weights on a
# different device than the batches (the collator is built with device=device).
raw_model = raw_model.to(device)

# `raw_model` stays un-wrapped so that checkpoints saved from it never contain
# the DataParallel wrapper; `model` is what the training loop calls.
model = nn.DataParallel(raw_model) if parallel else raw_model

In [9]:
# AdamW over every model parameter: learning rate 1e-6, weight decay 0.01.
# NOTE(review): this is the AdamW imported from `transformers`; newer
# transformers releases deprecate it in favour of torch.optim.AdamW -- check
# before upgrading the library.
optimizer = AdamW(model.parameters(), lr=1e-6, weight_decay=0.01)

In [10]:
# Exponential decay: each scheduler.step() multiplies the learning rate by
# 0.999. The training loop calls step() once per batch, not once per epoch.
scheduler = ExponentialLR(optimizer, gamma=0.999)

In [11]:
def reduce_metrics(seq):
    """Summarise a window of logged training steps as one line of text.

    Args:
        seq: non-empty iterable of ``(loss, lr)`` pairs, where ``loss`` is a
            scalar-like value and ``lr`` is an indexable collection of
            learning rates.

    Returns:
        A string holding the mean loss over the window and the first learning
        rate of the most recent pair, the latter in scientific notation.
    """
    # Transpose the (loss, lr) records into two parallel tuples.
    losses, rates = zip(*seq)
    mean_loss = np.mean(losses)
    latest_rate = rates[-1][0]
    return f'loss: {mean_loss}, lr: {latest_rate:.2e}'

In [None]:
# Fine-tune the next-sentence-prediction model for 8 passes over the loader,
# logging (loss, learning rate) every step and checkpointing the un-wrapped
# model to 'bert.pt'.
# NOTE(review): Logger is a project helper; presumably it collects the values
# passed to step() and writes reduce_fn's summary to train.log every `period`
# steps -- confirm in the module that defines it.
with Logger(file='train.log',
            reduce_fn=reduce_metrics,
            header=('=' * 50),
            overwrite=False,
            period=50) as logger:
    # Training mode is loop-invariant (nothing below switches to eval mode),
    # so set it once instead of on every batch.
    model.train()

    for epoch in range(8):
        for it, (input_ids,
                 token_type_ids,
                 attention_mask,
                 next_sentence_label) in enumerate(loader):

            model.zero_grad()
            outputs = model(input_ids=input_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            next_sentence_label=next_sentence_label)
            # The loss is the first element of the model output; indexing
            # (rather than unpacking into two names) works for tuple and
            # ModelOutput returns alike. mean() collapses it to a scalar,
            # which matters if DataParallel returns one loss per replica.
            loss = outputs[0].mean()
            loss.backward()

            # Log the learning rate used for this update. get_last_lr() is the
            # supported accessor; get_lr() called outside scheduler.step()
            # warns and, for ExponentialLR, reports the rate with one extra
            # gamma factor applied.
            logger.step([loss.detach().cpu().numpy(),
                         scheduler.get_last_lr()])
            optimizer.step()
            scheduler.step()

            # Periodic checkpoint of the un-wrapped model, so the file can be
            # re-loaded by the model-loading cell whether or not DataParallel
            # is in use.
            if it % 100 == 0:
                torch.save(raw_model, 'bert.pt')

    # Final checkpoint: without it, the updates made after the last
    # multiple-of-100 batch would be lost when training finishes.
    torch.save(raw_model, 'bert.pt')

HBox(children=(FloatProgress(value=0.0, max=1585695.0), HTML(value='')))