In [1]:
import os

import itertools
import pickle
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import math 

import sys
sys.path.append('../')
import utils
import wiki_utils
%matplotlib inline

In [2]:
import torchtext
from torchtext import data
import spacy
import numpy as np
from tqdm import tqdm 
 
from spacy.symbols import ORTH

from torchtext.datasets import WikiText2


In [3]:
my_tok = spacy.load('en')
 
def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]
 
TEXT = data.Field(lower=True, tokenize=spacy_tok)

torch.device = 'cpu'

In [4]:
train, valid, test = WikiText2.splits(TEXT)

In [5]:
TEXT.build_vocab(train, vectors="glove.6B.200d")

In [6]:
batch_size = 128
sequence_length = 30
grad_clip = 0.1
lr = 4.
best_val_loss = None
log_interval = 100
eval_batch_size = 128

In [7]:
train_loader, val_loader, test_loader = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=batch_size,
    bptt_len=sequence_length,
    device=torch.device,
    repeat=False)

In [8]:
batch, (torchtext_data) = next(enumerate(train_loader))

In [9]:
data = torchtext_data.text

In [10]:
targets = torchtext_data.target

In [11]:
data[:5, :3]

tensor([[   12,   432,   151],
        [   13,  1167,    14],
        [   12,     4,    16],
        [   15,   271, 17524],
        [ 3875,  5426,     5]])

In [12]:
targets.reshape(30, 128)[:5, :3]

tensor([[   13,  1167,    14],
        [   12,     4,    16],
        [   15,   271, 17524],
        [ 3875,  5426,     5],
        [ 3895,  1129,  4341]])

In [13]:
class RNNModel(nn.Module):

    def __init__(self, rnn_type, ntoken, ninp, nhid, nlayers, batch_size, dropout=0.5):
        super(RNNModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type == 'LSTM':
            self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        elif rnn_type == 'GRU':
            self.rnn = nn.GRU(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers
 

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, hidden=None):
        emb = self.drop(self.encoder(x))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters()).data
        if self.rnn_type == 'LSTM':
            return (weight.new(self.nlayers, bsz, self.nhid).zero_(),
                    weight.new(self.nlayers, bsz, self.nhid).zero_())
        else:
            return weight.new(self.nlayers, bsz, self.nhid).zero_()


In [14]:
def evaluate(data_loader):
    model.eval()
    total_loss = 0
    ntokens = weight_matrix.size(0)
    hidden = model.init_hidden(eval_batch_size)
    for i, (torchtext_data) in enumerate(data_loader):
        data, targets = torchtext_data.text, torchtext_data.target.view(-1)
        output, hidden = model(data)
        output_flat = output.view(-1, ntokens)
        total_loss += (len(data) * criterion(output_flat, targets).item())/eval_batch_size
    return total_loss / len(data_loader)


In [15]:
def train():
    model.train()
    total_loss = 0
    ntokens = weight_matrix.size(0)
    for batch, (torchtext_data) in enumerate(train_loader):
        data, targets = torchtext_data.text, torchtext_data.target.view(-1)
        model.zero_grad()
        output, hidden = model(data)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        for p in model.parameters():
            p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()

        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | loss {:5.2f} | ppl {:8.2f}'.format(
                epoch, batch, len(train_loader), lr, cur_loss, math.exp(cur_loss)))
            total_loss = 0


In [16]:
weight_matrix = TEXT.vocab.vectors
ntokens = weight_matrix.size(0)
ninp = weight_matrix.size(1)
model = RNNModel('LSTM', ntokens, ninp, 128, 2, batch_size, 0.3)
model.encoder.weight.data.copy_(weight_matrix)
criterion = nn.CrossEntropyLoss()

In [17]:
def generate(n=50, temp=1.):
    model.eval()
    x = torch.rand(1, 1).mul(ntokens).long()
    hidden = None
    out = []
    for i in range(n):
        output, hidden = model(x, hidden)
        s_weights = output.squeeze().data.div(temp).exp()
        s_idx = torch.multinomial(s_weights, 1)[0]
        x.data.fill_(s_idx)
        # s = corpus.dictionary.idx2symbol[s_idx]
        s = TEXT.vocab.itos[s_idx]
        out.append(s)
    return ' '.join(out)

In [18]:
with torch.no_grad():
    print('sample:\n', generate(50), '\n')

for epoch in range(1, 6):
    train()
    val_loss = evaluate(val_loader)
    print('-' * 89)
    print('| end of epoch {:3d} | valid loss {:5.2f} | valid ppl {:8.2f}'.format(
        epoch, val_loss, math.exp(val_loss)))
    print('-' * 89)
    if not best_val_loss or val_loss < best_val_loss:
        best_val_loss = val_loss
    else:
        # Anneal the learning rate if no improvement has been seen in the validation dataset.
        lr /= 4.0
    with torch.no_grad():
        print('sample:\n', generate(50), '\n')

sample:
 fisheries simpler wren cock barcelona bellcote storm real contracting 1641 lees covenant inserted rigging foreigner mineurs royal dodgy liam highs infect sud verge lowman athlete christman penance stanwix displays feathered compiling jam buddhist weakly served feather released grouped behavioral ryder indonesian ars 311 mini terrific record 940 casting stimulates ark 

| epoch   1 |   100/  583 batches | lr 4.00 | loss  8.22 | ppl  3697.59
| epoch   1 |   200/  583 batches | lr 4.00 | loss  7.36 | ppl  1576.49
| epoch   1 |   300/  583 batches | lr 4.00 | loss  7.13 | ppl  1246.74
| epoch   1 |   400/  583 batches | lr 4.00 | loss  6.94 | ppl  1029.31
| epoch   1 |   500/  583 batches | lr 4.00 | loss  6.82 | ppl   913.28
-----------------------------------------------------------------------------------------
| end of epoch   1 | valid loss  1.44 | valid ppl     4.21
-----------------------------------------------------------------------------------------
sample:
 presenting 

In [19]:
t1 = generate(10000, 1.)
t15 = generate(10000, 1.5)
t075 = generate(10000, 0.75)
with open('./generated075_words.txt', 'w') as outf:
    outf.write(t075)
with open('./generated1_words.txt', 'w') as outf:
    outf.write(t1)
with open('./generated15_words.txt', 'w') as outf:
    outf.write(t15)
