In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import re

import numpy as np
from collections import Counter
import os
from argparse import Namespace

In [2]:
train_file='hp.txt'
seq_size=32
batch_size=16
embedding_size=64
lstm_size=64
gradients_norm=5
initial_words=['Hermoine', 'said']
predict_top_k=2
checkpoint_path='checkpoint'

In [3]:
text = open(train_file, 'r').read()
text = re.sub(r'[^\w\s]', '', text).replace("\n", "").replace("\'", "").split() 

word_counts = Counter(text)
sorted_vocab = sorted(word_counts, key=word_counts.get, reverse=True)
int_to_vocab = {k: w for k, w in enumerate(sorted_vocab)}
vocab_to_int = {w: k for k, w in int_to_vocab.items()}
n_vocab = len(int_to_vocab)

int_text = [vocab_to_int[w] for w in text]
num_batches = int(len(int_text) / (seq_size * batch_size))
in_text = int_text[:num_batches * batch_size * seq_size]
out_text = np.zeros_like(in_text)
out_text[:-1] = in_text[1:]
out_text[-1] = in_text[0]
in_text = np.reshape(in_text, (batch_size, -1))
out_text = np.reshape(out_text, (batch_size, -1))

In [6]:
in_text.shape

(16, 4832)

In [47]:
def get_batches(in_text, out_text, batch_size, seq_size):
    num_batches = np.prod(in_text.shape) // (seq_size * batch_size)
    for i in range(0, num_batches * seq_size, seq_size):
        yield in_text[:, i:i+seq_size], out_text[:, i:i+seq_size]

In [5]:
class RNNModule(nn.Module):
    def __init__(self, n_vocab, seq_size, embedding_size, lstm_size):
        super(RNNModule, self).__init__()
        self.seq_size = seq_size
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size,
                            lstm_size,
                            batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.dense(output)
        return logits, state

    def zero_state(self, batch_size):
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))

In [17]:
def predict(device, net, words, n_vocab, vocab_to_int, int_to_vocab, top_k=5):
    net.eval()
    #words = ['I', 'am']

    state_h, state_c = net.zero_state(1)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for w in words:
        ix = torch.tensor([[vocab_to_int[w]]]).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

    _, top_ix = torch.topk(output[0], k=top_k)
    choices = top_ix.tolist()
    choice = np.random.choice(choices[0])

    words.append(int_to_vocab[choice])

    for _ in range(100):
        ix = torch.tensor([[choice]]).to(torch.int64).to(device)
        output, (state_h, state_c) = net(ix, (state_h, state_c))

        _, top_ix = torch.topk(output[0], k=top_k)
        choices = top_ix.tolist()
        choice = np.random.choice(choices[0])
        words.append(int_to_vocab[choice])

    print(' '.join(words))


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [23]:
net = RNNModule(n_vocab, seq_size, embedding_size, lstm_size)
net = net.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(net.parameters(), lr=0.02)

iteration = 0

for e in range(50):
    batches = get_batches(in_text, out_text, batch_size, seq_size)
    state_h, state_c = net.zero_state(batch_size)
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for x, y in batches:
        iteration += 1
        net.train()

        optimizer.zero_grad()

        x = torch.tensor(x).to(torch.int64).to(device)
        y = torch.tensor(y).to(torch.int64).to(device)

        #x = torch.tensor(train).to(torch.int64)

        logits, (state_h, state_c) = net(x, (state_h, state_c))
        loss = criterion(logits.transpose(1, 2), y)

        loss_value = loss.item()

        loss.backward()

        state_h = state_h.detach()
        state_c = state_c.detach()

        _ = torch.nn.utils.clip_grad_norm_(net.parameters(), gradients_norm)

        optimizer.step()

        if iteration % 1000 == 0:
            print('Epoch: {}/{}'.format(e, 50),
                  'Iteration: {}'.format(iteration),
                  'Loss: {}'.format(loss_value))

Epoch: 0/10 Iteration: 100 Loss: 7.261590480804443
Epoch: 1/10 Iteration: 200 Loss: 6.536898612976074
Epoch: 1/10 Iteration: 300 Loss: 6.098989486694336
Epoch: 2/10 Iteration: 400 Loss: 5.675844192504883
Epoch: 3/10 Iteration: 500 Loss: 5.274921894073486
Epoch: 3/10 Iteration: 600 Loss: 4.948337078094482
Epoch: 4/10 Iteration: 700 Loss: 4.640113353729248
Epoch: 5/10 Iteration: 800 Loss: 4.368916988372803
Epoch: 5/10 Iteration: 900 Loss: 3.9334826469421387
Epoch: 6/10 Iteration: 1000 Loss: 3.9848506450653076
Epoch: 7/10 Iteration: 1100 Loss: 3.5826988220214844
Epoch: 7/10 Iteration: 1200 Loss: 3.673696517944336
Epoch: 8/10 Iteration: 1300 Loss: 3.5829453468322754
Epoch: 9/10 Iteration: 1400 Loss: 3.1579761505126953
Epoch: 9/10 Iteration: 1500 Loss: 3.2351605892181396
Epoch: 10/10 Iteration: 1600 Loss: 3.2852554321289062
Epoch: 11/10 Iteration: 1700 Loss: 3.145927667617798
Epoch: 11/10 Iteration: 1800 Loss: 2.9105069637298584
Epoch: 12/10 Iteration: 1900 Loss: 3.042435884475708
Epoch: 13

In [30]:
 predict(device, net, ["Harry","Potter"], n_vocab, vocab_to_int, int_to_vocab, top_k=5)

Harry Potter and looked right person lived for another goblin. inside Harrys school just like a few in there with one wrong voice by the ceiling here facing an cold Voldemort, that now," that he was nowhere as he left. "Dont try and snatching back for so long as horrible as Seeker next. Then hed be in a letter? of great surprise. Harry showed in the silver doors ten wand, as of them looked as if someone had no stamp. moved full came with his hand into sweets. Ron, by Harry looked behind it. All the Fat Lady there was all the way
