In [2]:
import argparse
import os
import time
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torchtext import data as d
from torchtext import datasets
from torchtext.vocab import GloVe
from tqdm.notebook import tqdm

In [3]:
is_cuda = torch.cuda.is_available()
is_cuda

True

In [4]:
TEXT = d.Field(lower=True, batch_first=True,)

In [5]:
# make splits for data
train, valid, test = datasets.WikiText2.splits(TEXT,root='data')

downloading wikitext-2-v1.zip


wikitext-2-v1.zip: 100%|██████████| 4.48M/4.48M [00:00<00:00, 9.84MB/s]


extracting


In [6]:
batch_size=20
bptt_len=30
clip = 0.25
lr = 20
log_interval = 200

In [7]:
(len(valid[0].text)//batch_size)*batch_size

217640

In [8]:
len(valid[0].text)

217646

In [9]:
train[0].text = train[0].text[:(len(train[0].text)//batch_size)*batch_size]
valid[0].text = valid[0].text[:(len(valid[0].text)//batch_size)*batch_size]
test[0].text = test[0].text[:(len(valid[0].text)//batch_size)*batch_size]

In [10]:
len(valid[0].text)

217640

In [11]:
# print information about the data
print('train.fields', train.fields)
print('len(train)', len(train))
print('vars(train[0])', vars(train[0])['text'][0:10])

train.fields {'text': <torchtext.data.field.Field object at 0x7fe075072438>}
len(train) 1
vars(train[0]) ['<eos>', '=', 'valkyria', 'chronicles', 'iii', '=', '<eos>', '<eos>', 'senjō', 'no']


In [12]:
TEXT.build_vocab(train)

In [13]:
print('len(TEXT.vocab)', len(TEXT.vocab))

len(TEXT.vocab) 28913


In [15]:
train_iter, valid_iter, test_iter = d.BPTTIterator.splits((train, valid, test), batch_size=batch_size, bptt_len=bptt_len, device="cuda",repeat=False)

In [16]:
class RNNModel(nn.Module):
    def __init__(self,ntoken,ninp,nhid,nlayers,dropout=0.5,tie_weights=False):
        super().__init__()
        self.drop = nn.Dropout()
        self.encoder = nn.Embedding(ntoken,ninp)
        self.rnn = nn.LSTM(ninp,nhid,nlayers,dropout=dropout)
        self.decoder = nn.Linear(nhid,ntoken)
        if tie_weights:
            self.decoder.weight = self.encoder.weight
        
        self.init_weights()
        self.nhid = nhid
        self.nlayers = nlayers
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange,initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange,initrange)
        
    def forward(self,input,hidden): 
        
        emb = self.drop(self.encoder(input))
        output,hidden = self.rnn(emb,hidden)
        output = self.drop(output)
        s = output.size()
        decoded = self.decoder(output.view(s[0]*s[1],s[2]))
        return decoded.view(s[0],s[1],decoded.size(1)),hidden
    
    def init_hidden(self,bsz):
        weight = next(self.parameters()).data
        return(Variable(weight.new(self.nlayers,bsz,self.nhid).zero_()),Variable(weight.new(self.nlayers,bsz,self.nhid).zero_()))

In [17]:
criterion = nn.CrossEntropyLoss()

In [18]:
len(valid_iter.dataset[0].text)

217640

In [19]:
emsize = 200
nhid=200
nlayers=2
dropout = 0.2

ntokens = len(TEXT.vocab)
lstm = RNNModel(ntokens, emsize, nhid,nlayers, dropout, 'store_true')
if is_cuda:
    lstm = lstm.cuda()

In [24]:
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)

In [34]:
def evaluate(data_source):
    # Turn on evaluation mode which disables dropout.
    lstm.eval()
    total_loss = 0   
    hidden = lstm.init_hidden(batch_size)
    for batch in data_source:        
        data, targets = batch.text,batch.target.view(-1)
        output, hidden = lstm(data, hidden)
        output_flat = output.view(-1, ntokens)
        total_loss += len(data) * criterion(output_flat, targets).data
        hidden = repackage_hidden(hidden)
    return total_loss.item()/(len(data_source.dataset[0].text)//batch_size)

In [44]:
def trainf():
    # Turn on training mode which enables dropout.
    lstm.train()
    total_loss = 0
    start_time = time.time()
    hidden = lstm.init_hidden(batch_size)
    pbar = tqdm(total=len(train_iter))
    for  i,batch in enumerate(train_iter):
        data, targets = batch.text,batch.target.view(-1)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        lstm.zero_grad()
        output, hidden = lstm(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm_` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(lstm.parameters(), clip)
        for p in lstm.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()
        #pbar.set_description("epoch:{}, Loss: {:.4f}".format(
        #    epoch, train_loss/ (batch_idx +1)))
        pbar.update(1)
        #pbar.close()
        '''
        if i % log_interval == 0 and i > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            #(print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}'.format(epoch, i, len(train_iter), lr,elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss))))
            #pbar.set_description("epoch:{}| Loss: {:.4f} | ppl {:8.2f}".format(
                                  epoch, cur_loss , math.exp(cur_loss)))
            #pbar.update(1)
            total_loss = 0
            start_time = time.time()
        '''
    pbar.close()

In [45]:
# Loop over epochs.
best_val_loss = None
epochs = 40

for epoch in range(1, epochs+1):
    epoch_start_time = time.time()
    trainf()
    val_loss = evaluate(valid_iter)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
        'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                   val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0

HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 56.39s | valid loss  4.92 | valid ppl   136.51
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 55.83s | valid loss  4.91 | valid ppl   135.19
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 55.99s | valid loss  4.91 | valid ppl   136.14
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   4 | time: 55.97s | valid loss  4.79 | valid ppl   120.59
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   5 | time: 55.84s | valid loss  4.78 | valid ppl   118.99
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   6 | time: 55.91s | valid loss  4.77 | valid ppl   118.15
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   7 | time: 55.82s | valid loss  4.76 | valid ppl   117.10
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   8 | time: 55.93s | valid loss  4.76 | valid ppl   116.50
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch   9 | time: 55.90s | valid loss  4.76 | valid ppl   116.46
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  10 | time: 56.13s | valid loss  4.75 | valid ppl   115.87
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  11 | time: 56.14s | valid loss  4.75 | valid ppl   115.29
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  12 | time: 56.08s | valid loss  4.74 | valid ppl   114.91
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  13 | time: 56.18s | valid loss  4.74 | valid ppl   114.64
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  14 | time: 56.06s | valid loss  4.74 | valid ppl   114.42
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  15 | time: 55.97s | valid loss  4.74 | valid ppl   113.92
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  16 | time: 56.04s | valid loss  4.73 | valid ppl   113.78
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  17 | time: 55.81s | valid loss  4.73 | valid ppl   113.42
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  18 | time: 56.09s | valid loss  4.73 | valid ppl   113.12
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  19 | time: 56.19s | valid loss  4.73 | valid ppl   113.11
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  20 | time: 56.07s | valid loss  4.73 | valid ppl   113.05
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  21 | time: 56.02s | valid loss  4.73 | valid ppl   112.94
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  22 | time: 56.00s | valid loss  4.72 | valid ppl   112.45
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  23 | time: 56.10s | valid loss  4.72 | valid ppl   111.97
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  24 | time: 56.06s | valid loss  4.72 | valid ppl   111.77
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  25 | time: 56.53s | valid loss  4.72 | valid ppl   111.97
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  26 | time: 56.32s | valid loss  4.68 | valid ppl   108.18
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  27 | time: 56.24s | valid loss  4.68 | valid ppl   107.87
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  28 | time: 56.25s | valid loss  4.68 | valid ppl   107.63
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  29 | time: 56.51s | valid loss  4.68 | valid ppl   107.63
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  30 | time: 56.52s | valid loss  4.68 | valid ppl   107.43
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  31 | time: 56.47s | valid loss  4.68 | valid ppl   107.30
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  32 | time: 56.44s | valid loss  4.68 | valid ppl   107.32
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  33 | time: 56.76s | valid loss  4.66 | valid ppl   106.08
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  34 | time: 56.08s | valid loss  4.66 | valid ppl   106.02
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  35 | time: 56.29s | valid loss  4.66 | valid ppl   105.88
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  36 | time: 56.30s | valid loss  4.66 | valid ppl   105.88
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  37 | time: 56.28s | valid loss  4.66 | valid ppl   105.81
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  38 | time: 56.06s | valid loss  4.66 | valid ppl   105.78
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  39 | time: 56.08s | valid loss  4.66 | valid ppl   105.69
-----------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, max=3481.0), HTML(value='')))


-----------------------------------------------------------------------------------------
| end of epoch  40 | time: 56.41s | valid loss  4.66 | valid ppl   105.68
-----------------------------------------------------------------------------------------
