<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-text-data" data-toc-modified-id="Load-text-data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load text data</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Train" data-toc-modified-id="Train-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Train</a></span></li></ul></div>

In [54]:
reset -fs

In [55]:
import os
from io import open
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [56]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


In [57]:
torch.manual_seed(1111)

<torch._C.Generator at 0x7f5e680228b0>

# Load text data

In [58]:
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        
    def add_word(self, word):
        """
        Add word to 'self.idx2word' and 'self.word2idx'.
        """
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1 # starts from 0
        return self.word2idx[word]
    
    def __len__(self):
        return len(self.idx2word)

In [59]:
class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))
        
    def tokenize(self, path):
        """
        Tokenize a text file and add tokens to the dictionary.
        """
        assert os.path.exists(path)
        
        with open(path, 'r', encoding="utf8") as f:
            idx_all = []
            for line in f:
                words = line.split() + ['<eos>']
                idx_line = []
                for word in words:
                    self.dictionary.add_word(word)
                    idx_line.append(self.dictionary.word2idx[word])
                idx_all.append(torch.tensor(idx_line).type(torch.int64))
            ids = torch.cat(idx_all)               
        return ids

In [60]:
model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

In [61]:
corpus.dictionary.word2idx['<unk>']

9

In [62]:
def make_batch(data, n_seq):
    """
    Trim data and cleanly divide data into n_seq chunks.
    """
    nbatch = data.size(0) // n_seq
    
    # Trim off remainders.
    data = data.narrow(0, 0, nbatch * n_seq)
    
    # Evenly divide the data across the n_seq batches.
    # Shape : ([bptt, n_seq])
    return data.view(n_seq, -1).t().contiguous().to(device)

In [63]:
# def get_batch(source, i):
#     """
#     Subdivides the source into chunks of length bptt.
#     The chunks are along dimension 0 (length of each row is n_seq).
#     """
#     seq_len = min(bptt, len(source)-1-i)
#     data = source[i:i+seq_len]
#     target = source[i+1:i+1+seq_len].view(-1)
#     return data, target

In [64]:
# class wikiDataset(Dataset):
#     def __init__(self, data, n_seq):
#         self.data = data
#         self.n_seq = n_seq
    
#     def __len__(self):
#         return len(self.data)
    
#     def __getitem__(self, idx):
#         data = self.data[idx]
#         return make_batch(data, self.n_seq)

In [65]:
# wiki_train = wikiDataset(corpus.train, n_seq=80)

In [66]:
# train_loader = DataLoader(wiki_train, batch_size=35, shuffle=False)
# next(iter(train_loader))

In [67]:
n_seq = 20
eval_n_seq = 10

In [68]:
train_data = make_batch(corpus.train, n_seq)
val_data = make_batch(corpus.valid, eval_n_seq)
test_data = make_batch(corpus.test, eval_n_seq)

In [69]:
train_data

tensor([[    0,   284, 15178,  ...,  1352,  1335,    16],
        [    1,   357,    43,  ...,    46,    43,  2015],
        [    2,  1496,  7369,  ...,   380,    27, 33001],
        ...,
        [  357,   415,   173,  ...,   212,    78,  1575],
        [ 2520,     9,  3890,  ...,   208,    27,   808],
        [   33,    35,    19,  ...,  8832,  6091,   209]], device='cuda:0')

# Model

In [76]:
class NWPMModel(nn.Module):
    def __init__(self, ntoken, nemb, nhid, nlayers, dropout=0.5, tie_weights=False):
        super().__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, nemb)
        self.rnn = nn.GRU(nemb, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        
        if tie_weights:
            if nhid != nemb:
                raise ValueError('When using the tied flag, nhid must be equal to nemb')
            self.decoder.weight = self.encoder.weight
        
        self.init_weights()
        
        self.nhid = nhid
        self.nlayers = nlayers
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        decoded = decoded.view(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden
    
    def init_hidden(self, n_seq):
        weight = next(self.parameters())
        return weight.new_zeros(self.nlayers, n_seq, self.nhid)

# Train

In [79]:
def repackage_hidden(h):
    """
    Wraps hidden states in new Tensors, to detach them from their history.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

In [80]:
def get_batch(source, i):
    """
    Subdivides the source into chunks of length bptt.
    The chunks are along dimension 0 (length of each row is n_seq).
    """
    seq_len = min(bptt, len(source)-1-i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

In [82]:
def evaluate(dataset, ntokens=len(corpus.dictionary), bptt=bptt):
    model.eval() # disables dropout
    sum_loss = 0
    hidden = model.init_hidden(eval_n_seq)
    
    with torch.no_grad():
        for i in range(0, dataset.size(0)-1, bptt):
            X, y = get_batch(dataset, i)
            y_hat, hidden = model(X, hidden)
            hidden = repackage_hidden(hidden)
            sum_loss += len(X) * criterion(y_hat, y).item()
    return sum_loss / (len(dataset)-1)

In [98]:
def train_epochs(optimizer, scheduler, n_seq, ntokens=len(corpus.dictionary), epochs=20, log_interval=200):
    """
    Train for epochs.
    """
    best_val_loss = None
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        model.train()
        sum_loss = 0.
        hidden = model.init_hidden(n_seq)

        for batch, i in enumerate(range(0, train_data.size(0)-1, bptt)):
            X, y = get_batch(train_data, i)
            
            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            optimizer.zero_grad()
            hidden = repackage_hidden(hidden)
            y_hat, hidden = model(X, hidden)
            loss = criterion(y_hat, y)
            loss.backward()
            
            # 'clip_grad_norm' helps prevent the exploding gradient problem in RNNs/LSTMs.
            lr = scheduler.get_lr()[0]
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            for p in model.parameters():
                p.data.add_(-lr, p.grad)
            
            optimizer.step()
            scheduler.step()
            
            sum_loss += loss.item()

            if batch % log_interval == 0 and batch > 0:
                cur_loss = sum_loss / log_interval
                print(f'| {epoch=:3d} | {batch:5d}/{len(train_data)//bptt} batches |  {lr=:2.2f}  |  '
                f'loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}')
                sum_loss = 0
                start_time = time.time()
        
        val_loss = evaluate(val_data)
        print('-'*89)
        print(f'| end of {epoch=:2d} | time: {time.time()-epoch_start_time:5.2f}s | {val_loss=:5.2f} | '
              f'valid ppl {math.exp(val_loss):8.2f}')
        print('-' * 89)
        
        # Save the model if the validation loss is the best we've seen so far.
        if not best_val_loss or val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), 'model.pth')

    os.rename('model.pth', f'model_val_loss_{val_loss:.2f}_val_ppl_{math.exp(val_loss):.2f}.pt')

In [99]:
criterion = nn.NLLLoss()
ntokens = len(corpus.dictionary)
bptt = 35

In [108]:
model = NWPMModel(
    ntoken = ntokens,
    nemb = 650,
    nhid = 650,
    nlayers=2,
    dropout=0.2,
    tie_weights=True
).to(device)

In [109]:
optimizer = torch.optim.SGD(model.parameters(), lr=5, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=5,
                                                steps_per_epoch=len(train_data)//bptt, epochs=20)

In [110]:
train_epochs(optimizer=optimizer,
             scheduler=scheduler,
             n_seq=n_seq,
             epochs=20,
             log_interval=200)

| epoch=  1 |   200/2983 batches |  lr=0.20  |  loss  7.55 | ppl  1907.99
| epoch=  1 |   400/2983 batches |  lr=0.21  |  loss  6.71 | ppl   819.47
| epoch=  1 |   600/2983 batches |  lr=0.21  |  loss  6.46 | ppl   637.97
| epoch=  1 |   800/2983 batches |  lr=0.22  |  loss  6.34 | ppl   568.69
| epoch=  1 |  1000/2983 batches |  lr=0.24  |  loss  6.22 | ppl   504.36
| epoch=  1 |  1200/2983 batches |  lr=0.25  |  loss  6.16 | ppl   474.87
| epoch=  1 |  1400/2983 batches |  lr=0.27  |  loss  6.08 | ppl   438.23
| epoch=  1 |  1600/2983 batches |  lr=0.29  |  loss  6.08 | ppl   435.68
| epoch=  1 |  1800/2983 batches |  lr=0.32  |  loss  5.93 | ppl   375.07
| epoch=  1 |  2000/2983 batches |  lr=0.35  |  loss  5.91 | ppl   369.91
| epoch=  1 |  2200/2983 batches |  lr=0.38  |  loss  5.80 | ppl   328.99
| epoch=  1 |  2400/2983 batches |  lr=0.41  |  loss  5.80 | ppl   329.10
| epoch=  1 |  2600/2983 batches |  lr=0.45  |  loss  5.77 | ppl   320.38
| epoch=  1 |  2800/2983 batches |  lr

ValueError: Tried to step 59662 times. The specified number of total steps is 59660

In [111]:
# Load the best saved model.
with open('model_val_loss_5.01_val_ppl_149.39.pt', 'rb') as f:
    model = torch.load(f)
    # after load the rnn params are not a continuous chunk of memory
    # this makes them a continuous chunk, and will speed up forward pass
    # Currently, only rnn model supports flatten_parameters function.
    model.rnn.flatten_parameters()

In [112]:
# Run on validation data.
val_loss = evaluate(val_data)
print('=' * 89)
print(f'| End of training | {val_loss=:5.2f} | val ppl {math.exp(val_loss):8.2f}')

| End of training | val_loss= 5.01 | val ppl   149.39


In [113]:
# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print(f'| End of training | {test_loss=:5.2f} | test ppl {math.exp(test_loss):8.2f}')

| End of training | test_loss= 4.94 | test ppl   140.26


In [44]:
# model.load_state_dict(
#     torch.load(
#     model_data_filepath + 'word_language_model_quantize.pth',
#         map_location = torch.device('cpu')
#     )
# )
# model.eval()
# print(model)