In [1]:
import os,sys,time,math,textwrap

import numpy as np

import torch
import torch.nn as nn

import dataset, transformer

root = 'data'  # passed as the first argument to dataset.WikiText2 when the splits are loaded

In [2]:
# --- Hyper-parameters and run configuration --------------------------------
lr = .00035          # learning rate applied once the warm-up epoch is over
context = 150        # sequence length handed to the dataset and the model
batch_size = 32
log_interval = 50    # number of training batches between progress lines

heads = 10           # forwarded to transformer.Transformer
depth = 16           # forwarded to transformer.Transformer

torch.manual_seed(0)
# Use the GPU when one is present, otherwise fall back to the CPU so the
# notebook can still be executed (slowly) on a machine without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
# Build the three WikiText2 splits (train, valid, test - in that order) with
# the same data root and context length.
train_data, valid_data, test_data = (
    dataset.WikiText2(root, context, split)
    for split in (
        dataset.DatasetSplit.train,
        dataset.DatasetSplit.valid,
        dataset.DatasetSplit.test,
    )
)

In [4]:
def evaluate(data):
    """Return the mean per-token loss of the global ``model`` on ``data``.

    Uses the notebook globals ``model``, ``criterion``, ``batch_size``,
    ``device`` and ``train_data`` (for the vocabulary size); they must all be
    defined before the first call. The model is switched to eval mode for the
    duration of the pass and put back into training mode afterwards, even if
    the pass is interrupted. A blank line is printed on success.

    Args:
        data: dataset of ``(x, y)`` pairs that ``DataLoader`` can batch.

    Returns:
        float: loss averaged over every target token in ``data``.

    Raises:
        ValueError: if ``data`` produces no batches.
    """
    loader = torch.utils.data.DataLoader(dataset=data,batch_size=batch_size,shuffle=False)
    vocab_size = train_data.word_count()
    total_loss = 0.
    total_targets = 0

    model.eval()
    try:
        with torch.no_grad():
            for x, y in loader:
                # DataLoader stacks samples along dim 0; swap the two dims
                # before feeding the model (presumably (seq, batch) - confirm
                # against transformer.Transformer).
                x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
                targets = y.contiguous().view(-1)
                yhat = model(x).view(-1, vocab_size)
                # The notebook's criterion is nn.NLLLoss() with its default
                # mean reduction, so each batch loss is re-weighted by its
                # token count; a shorter final batch then no longer skews the
                # overall average.
                total_loss += criterion(yhat, targets).item() * targets.numel()
                total_targets += targets.numel()
    finally:
        model.train()

    if total_targets == 0:
        raise ValueError('evaluate() was given a dataset that produced no batches')

    print()
    return total_loss / total_targets

In [5]:
# The positional arguments follow transformer.Transformer's signature, which is
# not visible in this notebook. NOTE(review): 400, 40 and 900 look like layer
# widths (400 == heads * 40) - confirm against transformer.py and consider
# naming them alongside the other hyper-parameters.
model = transformer.Transformer(context, train_data.word_count(), 400, 40, 900, heads, depth, tied_weights=True).to(device)
# numel() yields an exact Python int per parameter; np.prod over a shape would
# return the float 1.0 for a 0-dim parameter and turn the whole count into a float.
count = sum(parm.numel() for parm in model.parameters() if parm.requires_grad)
print('Initialized graph with {} parameters'.format(count))

Initialized graph with 35198479 parameters


In [6]:
criterion = nn.NLLLoss()
curr_lr = .0001        # learning rate for the warm-up epoch (epoch 0)
clip = .25             # ceiling on the gradient norm (see clip_grad_norm_ below)
best_val_loss = None   # lowest validation loss so far; None until the first checkpoint
epochs = 10
save = 'model.pt'      # file that receives the best model seen so far

train_loader = torch.utils.data.DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
print('Initiating training, {} iterations/epoch.'.format(len(train_loader)))

try:
    optimizer = torch.optim.Adam(model.parameters(), lr=curr_lr)
    # Every pass starts by validating (and possibly checkpointing) the current
    # weights. The extra pass with epoch == epochs does only that, so the
    # weights produced by the last training epoch are also validated and can
    # be saved; without it that epoch's work could never reach the checkpoint.
    for epoch in range(epochs + 1):
        t0 = time.time()
        # float() accepts both a plain number and a 0-dim tensor, and keeps the
        # stored best loss off the GPU.
        val_loss = float(evaluate(valid_data))
        print('-' * 100)
        print('| checkpoint | epoch {:3d} | time: {:5.2f}s | validation loss {:5.2f} | '
                'validation perplexity {:8.2f}'.format(epoch, (time.time() - t0),
                                                       val_loss, math.exp(val_loss)))
        print('-' * 100)

        # Compare against None explicitly: a truthiness test would treat a
        # recorded loss of exactly 0.0 as "no checkpoint yet".
        if best_val_loss is None or val_loss < best_val_loss:
            with open(save, 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss

        if epoch == epochs:
            break  # final validation-only pass; nothing left to train

        print('epoch\t\tms/batch\tlr\tloss\tperplexity')

        model.train()
        total_loss = 0.
        t0 = time.time()
        if epoch == 1: optimizer.param_groups[0]['lr'] = curr_lr = lr # finished warmup
        for i, (x,y) in enumerate(train_loader):
            # At every multiple of log_interval, total_loss holds the sum of
            # the preceding log_interval batch losses.
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - t0
                print('{:3d} ({:2.1f}%)\t{:5.2f}\t\t{:1.3}\t{:5.2f}\t{:8.2f}'.format(
                    epoch, 100*i/float(len(train_loader)),
                    elapsed * 1000 / log_interval, curr_lr, cur_loss, math.exp(cur_loss)))
                total_loss = 0
                t0 = time.time()

            # Same layout swap as in evaluate(): DataLoader batches along dim 0.
            x, y = x.permute(1,0).to(device), y.permute(1,0).to(device)
            model.zero_grad()
            yhat = model(x).view(-1, train_data.word_count())
            loss = criterion(yhat, y.contiguous().view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()

            total_loss += loss.item()

except KeyboardInterrupt:
    # Interrupting the cell stops training; the best checkpoint written so far stays on disk.
    print('Graceful Exit')

Initiating training, 436 iterations/epoch.

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   0 | time: 15.42s | validation loss 10.41 | validation perplexity 33279.18
----------------------------------------------------------------------------------------------------
epoch		ms/batch	lr	loss	perplexity
  0 (11.5%)	997.33		0.0001	 9.66	15746.70
  0 (22.9%)	996.18		0.0001	 7.18	 1313.35
  0 (34.4%)	1000.42		0.0001	 7.13	 1253.74
  0 (45.9%)	999.23		0.0001	 7.14	 1260.41
  0 (57.3%)	1000.66		0.0001	 7.12	 1241.91
  0 (68.8%)	1001.52		0.0001	 7.14	 1256.69
  0 (80.3%)	1001.61		0.0001	 7.11	 1230.16
  0 (91.7%)	1000.73		0.0001	 7.12	 1233.04

----------------------------------------------------------------------------------------------------
| checkpoint | epoch   1 | time: 15.33s | validation loss  6.89 | validation perplexity   979.96
----------------------------------------------------------------------------------

In [7]:
# Reload the model object that the training cell wrote to `save`, then report
# its loss and perplexity on the test split.
print('Restoring best checkpointed model...')
with open(save, 'rb') as checkpoint_file:
    # torch.load unpickles a whole model object here; that is acceptable only
    # because the file was written by this notebook's own torch.save call.
    model = torch.load(checkpoint_file)

test_loss = evaluate(test_data)
rule = '=' * 89
print(rule)
test_perplexity = math.exp(test_loss)
print('| end of training | test loss {:5.2f} | test perplexity {:8.2f}'.format(test_loss, test_perplexity))
print(rule)

Restoring best checkpointed model...

| end of training | test loss 10.41 | test perplexity 33279.15


In [8]:
# Heading for the generated text, followed by an 89-character rule.
print('\nUncurated samples', '-' * 89, sep='\n')

def sample():
    """Generate ``context`` words from the global ``model``, one token at a time.

    Uses the notebook globals ``model``, ``train_data``, ``context`` and
    ``device``. Generation starts from one token drawn uniformly from the
    vocabulary (not part of the returned text); each subsequent token is drawn
    from the exponentiated last-position model output and appended to the
    history that is fed back to the model. The model is left in eval mode.

    Returns:
        str: the sampled words joined by spaces and wrapped to 80 columns.
    """
    words = []
    model.eval()
    history = torch.randint(train_data.word_count(), (1, 1), dtype=torch.long).to(device)
    # Sampling needs no gradients; without no_grad() every step would retain an
    # autograd graph over the ever-growing history.
    with torch.no_grad():
        for _ in range(context):
            output = model(history)
            # exp() turns the output into sampling weights - the model is
            # trained with NLLLoss, so it presumably emits log-probabilities.
            word_weights = output[-1].squeeze().exp().cpu()
            word_idx = torch.multinomial(word_weights, 1)  # shape (1,), dtype long
            # Append the new token directly as a long tensor instead of
            # round-tripping the index through a float tensor.
            history = torch.cat([history, word_idx.view(1, 1).to(device)], 0)

            words.append(train_data.idx2word[word_idx.item()])

    return '\n'.join(textwrap.wrap(' '.join(words),80))

# Print five independent samples, each prefixed with its index in parentheses.
for sample_no in range(5):
    label = '({})'.format(sample_no)
    print(label, sample())


Uncurated samples
-----------------------------------------------------------------------------------------
(0) pursuit su scrapping monism examples involved Prolog Wax inexpensive Fitz Marcia
authored Kakapo EMI 350 overseas Daddy carriages faithfully 1664 mutualistic
Waddell lukewarm machete Ferraris language need troublesome coasters Lanesboro
penetration Publications copied universal fanbase moratorium beak Regarding
River Ware Hisashi IEDs hoshi buildup Dana cheated conventions Chucky eight
jammed Revival NZ combatant wielding diets Repair contribution culprits Anne
Keys pectoral Enrique realising Aramburu strands Literary Yeah Proponents
Meiklejohn Superman Furthermore preview open Brussels AM hexafluoroplatinate
contain Congolese overs Gastão propulsive lawsuit Angle Reclamation all Heir
penultimate octaves breaks Yankees Yorktown Geffen Zoo legally schism orderings
Tulane Prohaska viewership Nganno Allāh epidemic Sugar Biology ministry Mighty
threat goalkeeper 106 investors Tu