<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Load-text-data" data-toc-modified-id="Load-text-data-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Load text data</a></span></li><li><span><a href="#Model" data-toc-modified-id="Model-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>Model</a></span></li><li><span><a href="#Train" data-toc-modified-id="Train-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Train</a></span></li></ul></div>

In [1]:
reset -fs

In [2]:
import os
from io import open
import time
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cpu')

In [4]:
torch.manual_seed(1)

<torch._C.Generator at 0x7f472996f3b0>

# Load text data

In [5]:
class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []
        
    def add_word(self, word):
        """
        Add word to 'self.idx2word' and 'self.word2idx'.
        """
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1 # starts from 0
        return self.word2idx[word]
    
    def __len__(self):
        return len(self.idx2word)

In [6]:
class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'train.txt'))
        self.valid = self.tokenize(os.path.join(path, 'valid.txt'))
        self.test = self.tokenize(os.path.join(path, 'test.txt'))
        
    def tokenize(self, path):
        """
        Tokenize a text file and add tokens to the dictionary.
        """
        assert os.path.exists(path)
        
        with open(path, 'r', encoding="utf8") as f:
            idx_all = []
            for line in f:
                words = line.split() + ['<eos>']
                idx_line = []
                for word in words:
                    self.dictionary.add_word(word)
                    idx_line.append(self.dictionary.word2idx[word])
                idx_all.append(torch.tensor(idx_line).type(torch.int64))
            ids = torch.cat(idx_all)               
        return ids

In [7]:
model_data_filepath = 'data/'

corpus = Corpus(model_data_filepath + 'wikitext-2')

In [8]:
corpus.dictionary.word2idx['<unk>']

9

In [18]:
def make_batch(data, n_seq):
    """
    Trim data and cleanly divide data into n_seq chunks.
    """
    nbatch = data.size(0) // n_seq
    
    # Trim off remainders.
    data = data.narrow(0, 0, nbatch * n_seq)
    
    # Evenly divide the data across the n_seq batches.
    # Shape : ([bptt, n_seq])
    return data.view(-1, n_seq).to(device)

In [10]:
len(corpus.train)

2088628

In [64]:
n_seq = 80
eval_n_seq = 10

In [65]:
train_data = make_batch(corpus.train, n_seq)[:1000]
val_data = make_batch(corpus.valid, eval_n_seq)[:1000]
test_data = make_batch(corpus.test, eval_n_seq)[:1000]

In [66]:
train_data.shape, val_data.shape

(torch.Size([1000, 80]), torch.Size([1000, 10]))

# Model

In [45]:
class NWPMModel(nn.Module):
    def __init__(self, ntoken, nemb, nhid, nlayers, dropout=0.5, tie_weights=False):
        super().__init__()
        self.ntoken = ntoken
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, nemb)
        self.rnn = nn.GRU(nemb, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)
        
        if tie_weights:
            if nhid != nemb:
                raise ValueError('When using the tied flag, nhid must be equal to nemb')
            self.decoder.weight = self.encoder.weight
        
        self.init_weights()
        
        self.nhid = nhid
        self.nlayers = nlayers
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output)
        decoded = decoded.view(-1, self.ntoken)
        return F.log_softmax(decoded, dim=1), hidden
    
    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return weight.new_zeros(self.nlayers, bsz, self.nhid)

In [69]:
criterion = nn.NLLLoss()
ntokens = len(corpus.dictionary)
bptt = 35

In [70]:
model = NWPMModel(
    ntoken = ntokens,
    nemb = 200,
    nhid = 200,
    nlayers=5,
    tie_weights=True
)

# Train

In [71]:
def repackage_hidden(h):
    """
    Wraps hidden states in new Tensors, to detach them from their history.
    """
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

In [72]:
def get_batch(source, i):
    """
    Subdivides the source into chunks of length bptt.
    The chunks are along dimension 0 (length of each row is bptt).
    """
    seq_len = min(bptt, len(source)-1-i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

In [83]:
def evaluate(dataset, ntokens=ntokens, bptt=bptt):
    model.eval() # disables dropout
    sum_loss = 0
    hidden = model.init_hidden(eval_n_seq)
    
    with torch.no_grad():
        for i in range(0, dataset.size(0)-1, bptt):
            X, y = get_batch(dataset, i)
            y_hat, hidden = model(X, hidden)
            hidden = repackage_hidden(hidden)
            sum_loss += len(X) * criterion(y_hat, y).item()
    return sum_loss / (len(dataset)-1)

In [84]:
def train_epochs(ntokens=ntokens, epochs=20, log_interval=200, lr=1):
    """
    Train for epochs.
    """
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        model.train()
        sum_loss = 0.
        start_time = time.time()
        hidden = model.init_hidden(batch_size)

        for batch, i in enumerate(range(0, train_data.size(0)-1, bptt)):
            X, y = get_batch(train_data, i)
            # Starting each batch, we detach the hidden state from how it was previously produced.
            # If we didn't, the model would try backpropagating all the way to start of the dataset.
            model.zero_grad()
            hidden = repackage_hidden(hidden)
            y_hat, hidden = model(X, hidden)
            loss = criterion(y_hat, y)
            loss.backward()

            # 'clip_grad_norm' helps prevent the exploding gradient problem in RNNs/LSTMs.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            for p in model.parameters():
                p.data.add_(-lr, p.grad)

            sum_loss += loss.item()

            if batch % log_interval == 0 and batch > 0:
                cur_loss = sum_loss / log_interval
                elapsed = time.time() - start_time
                print(f'| {epoch=:3d} | {batch:5d}/{len(train_data)//bptt} batches |  {lr=:2.2f}  |  '
                f'{elapsed/log_interval :5.2f} s/batch | loss {cur_loss:5.2f} | ppl {math.exp(cur_loss):8.2f}')
                sum_loss = 0
                start_time = time.time()
        
        val_loss = evaluate(val_data)
        print('-'*89)
        print(f'| end of {epoch=:3d} | time: {time.time()-epoch_start_time:5.2f}s | {val_loss=:5.2f} | '
              f'valid ppl {math.exp(val_loss):8.2f}')
        print('-' * 89)
        
        # Save the model if the validation loss is the best we've seen so far.
        best_val_loss = None
        if not best_val_loss or val_loss < best_val_loss:
            with open('model.pth', 'wb') as f:
                torch.save(model, f)
            best_val_loss = val_loss
        else:
            # Anneal the learning rate if no improvement has been seen in the validation dataset.
            lr /= 4.0

In [85]:
lr = 1
epochs = 2

In [86]:
train_epochs(epochs=epochs, lr=lr, log_interval=7)

| epoch=  1 |     7/28 batches |  lr=1.00  |   2.01 s/batch | loss  8.59 | ppl  5384.07
| epoch=  1 |    14/28 batches |  lr=1.00  |   1.84 s/batch | loss  7.36 | ppl  1572.27
| epoch=  1 |    21/28 batches |  lr=1.00  |   2.16 s/batch | loss  7.42 | ppl  1664.02
| epoch=  1 |    28/28 batches |  lr=1.00  |   1.86 s/batch | loss  7.41 | ppl  1656.82
-----------------------------------------------------------------------------------------
| end of epoch=  1 | time: 59.31s | val_loss= 7.27 | valid ppl  1439.79
-----------------------------------------------------------------------------------------
| epoch=  2 |     7/28 batches |  lr=1.00  |   2.66 s/batch | loss  8.54 | ppl  5113.96
| epoch=  2 |    14/28 batches |  lr=1.00  |   2.09 s/batch | loss  7.30 | ppl  1484.96
| epoch=  2 |    21/28 batches |  lr=1.00  |   2.16 s/batch | loss  7.37 | ppl  1588.92
| epoch=  2 |    28/28 batches |  lr=1.00  |   1.90 s/batch | loss  7.36 | ppl  1572.45
--------------------------------------------

In [87]:
# Load the best saved model.
with open('model.pth', 'rb') as f:
    model = torch.load(f)
    # after load the rnn params are not a continuous chunk of memory
    # this makes them a continuous chunk, and will speed up forward pass
    # Currently, only rnn model supports flatten_parameters function.
    model.rnn.flatten_parameters()

In [88]:
# Run on test data.
test_loss = evaluate(test_data)
print('=' * 89)
print(f'| End of training | {test_loss=:5.2f} | test ppl {math.exp(test_loss):8.2f}')

| End of training | test_loss= 7.21 | test ppl  1359.52


In [None]:
# model.load_state_dict(
#     torch.load(
#     model_data_filepath + 'word_language_model_quantize.pth',
#         map_location = torch.device('cpu')
#     )
# )
# model.eval()
# print(model)