In [1]:
import string
import numpy.random as random
import re

# Vocabulary: every printable ASCII character. char_index() looks characters up
# with str.index, so any character outside this set raises ValueError.
# NOTE: we could instead build the vocabulary from the file itself,
# e.g. ''.join(sorted(set(text))).
all_chars = string.printable
n_chars = len(all_chars)  # total number of characters (vocabulary size)

# Read the whole training corpus into memory. An explicit encoding makes the
# decoding independent of the platform's default locale encoding.
with open('input.txt', 'r', encoding='utf-8') as file:
    text = file.read()

text_len = len(text)
print('text_len =', text_len)

text_len = 1115394


In [2]:
seq_len = 200  # characters per sampled training window


def random_seq():
    """Return a random contiguous window of ``seq_len`` characters from ``text``.

    numpy's ``randint(low, high)`` draws from the half-open range [low, high),
    so passing ``last_start + 1`` as the upper bound makes ``last_start``
    (= text_len - seq_len) the largest possible offset; the window therefore
    never runs past the end of the corpus.
    """
    last_start = text_len - seq_len
    offset = random.randint(0, last_start + 1)
    return text[offset:offset + seq_len]


print(random_seq())

ws of your husband.

VIRGILIA:
O, good madam, there can be none yet.

VALERIA:
Verily, I do not jest with you; there came news from
him last night.

VIRGILIA:
Indeed, madam?

VALERIA:
In earnest, it's


In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class CharSeqRNN(nn.Module):
    """Character-level language model: embedding -> LSTM -> linear decoder.

    Inputs are batches of character indices laid out as (batch, time)
    (the LSTM is built with ``batch_first=True``). Outputs are unnormalized
    logits over the vocabulary, shape (batch, time, vocab_size).

    Args:
        vocab_size: number of distinct characters.
        embed_dim: size of each character embedding.
        hidden_dim: size of the LSTM hidden and cell state.
        num_layers: number of stacked LSTM layers (default 1, the original
            architecture).
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim, num_layers=1):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as CharSeqRNN is subclassed.
        super(CharSeqRNN, self).__init__()

        self.vocab_size = vocab_size  # number of chars for this case
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim  # could equal hidden_dim to drop one hyperparameter
        self.num_layers = num_layers

        self.encode = nn.Embedding(vocab_size, embed_dim)
        # nn.LSTM's dropout argument only acts between layers, i.e. for num_layers > 1.
        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=num_layers, batch_first=True)
        self.decode = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inp, hidden):
        """Run the model over a batch of index sequences.

        Args:
            inp: LongTensor of character indices, shape (batch, time).
            hidden: (h, c) LSTM state, each (num_layers, batch, hidden_dim),
                e.g. from ``init_hidden``.

        Returns:
            (logits, hidden): logits of shape (batch, time, vocab_size) and the
            updated (h, c) state. No softmax is applied here; the loss
            (CrossEntropyLoss) or the sampler applies it.
        """
        embedded = self.encode(inp)
        output, hidden = self.rnn(embedded, hidden)
        return self.decode(output), hidden

    def init_hidden(self, batch_size):
        """Return a zero (h, c) LSTM state for ``batch_size`` sequences.

        The zeros are allocated with the same tensor type as the model's
        parameters, so the state follows the model to the GPU or to another
        dtype instead of always being a CPU float32 tensor.
        """
        weight = next(self.parameters()).data
        shape = (self.num_layers, batch_size, self.hidden_dim)
        return (Variable(weight.new(*shape).zero_()),
                Variable(weight.new(*shape).zero_()))

In [4]:
def char_index(chars):
    """Encode a string as a (1, len(chars)) LongTensor of vocabulary indices.

    Each character is replaced by its position in ``all_chars``; a character
    missing from the vocabulary makes ``str.index`` raise ValueError.
    The tensor is returned wrapped in a ``Variable``.
    """
    indices = list(map(all_chars.index, chars))
    row = torch.LongTensor(indices).view(1, -1)
    return Variable(row)

print(char_index("abcDEF"))

Variable containing:
 10  11  12  39  40  41
[torch.LongTensor of size (1,6)]



In [5]:
def training_batch(batch_size):
    """Sample ``batch_size`` random windows and pair them with next-char targets.

    For each window the input is every character but the last and the target
    is every character but the first, so ``targets[:, t]`` is the character
    that follows ``inputs[:, t]``.

    Returns:
        (inputs, targets): two index tensors stacked along dim 0, one row per
        sampled window.
    """
    inputs, targets = [], []
    for _ in range(batch_size):
        window = random_seq()
        inputs.append(char_index(window[:-1]))   # characters 0 .. n-2
        targets.append(char_index(window[1:]))   # characters 1 .. n-1
    return torch.cat(inputs, dim=0), torch.cat(targets, dim=0)

# Sanity check: build a single-sample batch.
c_in, c_out = training_batch(1)

In [6]:
def run(init_str='A', length=200, temp=0.4):
    """Generate text from the global ``model`` by temperature sampling.

    Args:
        init_str: non-empty priming string. All but its last character are fed
            through the model to warm up the hidden state; the last character
            seeds the first sampling step.
        length: number of characters to generate after ``init_str``.
        temp: softmax temperature. Logits are divided by it, so values below 1
            sharpen the distribution (more conservative text) and values above
            1 flatten it.

    Returns:
        ``init_str`` followed by ``length`` sampled characters.

    Raises:
        ValueError: if ``init_str`` is empty (nothing to seed generation with).
    """
    if not init_str:
        raise ValueError('init_str must contain at least one character')

    hidden = model.init_hidden(1)
    if len(init_str) > 1:
        # Prime the hidden state on the prefix; its outputs are not needed.
        _, hidden = model(char_index(init_str[:-1]), hidden)

    pred = init_str
    inp = char_index(init_str[-1])
    for _ in range(length):
        output, hidden = model(inp, hidden)
        # Cut the autograd history every step so the graph does not keep
        # growing with the length of the generated text.
        hidden = tuple(h.detach() for h in hidden)

        # inp is a single character, so output holds one step of logits;
        # flatten it to a 1-D distribution over the vocabulary.
        output_dist = F.softmax(output.view(-1) / temp, dim=0).data
        idx = int(torch.multinomial(output_dist, 1)[0])
        pred_char = all_chars[idx]
        pred += pred_char
        inp = char_index(pred_char)
    return pred

In [7]:
import time, math

def time_since(since):
    """Format the wall-clock time elapsed since the timestamp ``since`` as 'Xm Ys'.

    ``since`` is a ``time.time()`` value; the seconds part is truncated to a
    whole number by the ``%d`` format.
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    return '%dm %ds' % (minutes, elapsed - 60 * minutes)

In [8]:
def train(batch_size):
    """Run one optimization step on a freshly sampled batch.

    Uses the global ``model``, ``criterion`` and ``optimizer``. The LSTM
    state starts from zeros for every batch, so nothing is carried between
    calls.

    Args:
        batch_size: number of random windows in the batch.

    Returns:
        The batch loss as a Python number.
    """
    hidden = model.init_hidden(batch_size)
    model.zero_grad()
    c_in, c_out = training_batch(batch_size)

    output, _ = model(c_in, hidden)
    # Flatten (batch, time, vocab) logits and (batch, time) targets so the
    # criterion sees one prediction per character.
    loss = criterion(output.view(-1, n_chars), c_out.view(-1))

    loss.backward()
    optimizer.step()

    # Indexing a 0-dim tensor (loss.data[0]) is an error in modern PyTorch;
    # item() is the supported accessor. Fall back for old releases without it.
    return loss.item() if hasattr(loss, 'item') else loss.data[0]
    

In [9]:
# Training configuration. Each "epoch" here is a single optimization step on
# one freshly sampled batch, not a full pass over the corpus.
epochs = 2000
print_fq = 100   # report progress and a text sample every print_fq steps
plot_fq = 10     # average the loss over plot_fq steps for the loss curve

embed_dim = 128
hidden_dim = 128
batch_size = 64

# Seed both RNGs used from here on (numpy for window sampling in random_seq,
# torch for weight initialization and multinomial text sampling) so a rerun
# reproduces the same training curve and samples.
seed = 0
random.seed(seed)
torch.manual_seed(seed)

model = CharSeqRNN(n_chars, embed_dim, hidden_dim)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

criterion = nn.CrossEntropyLoss()

start = time.time()
losses = []    # mean loss of each consecutive group of plot_fq steps
loss_avg = 0   # running sum of losses since the last append to `losses`

for epoch in range(1, epochs + 1):
    loss1 = train(batch_size)
    loss_avg += loss1
    if epoch % print_fq == 0:
        # loss1 is the loss of this single batch, not a running average.
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / epochs * 100, loss1))
        print(run('\n', 150, 0.5), '\n')

    if epoch % plot_fq == 0:
        losses.append(loss_avg / plot_fq)
        loss_avg = 0


[4m 16s (100 5%) 2.4376]

Whine dee hos thee to wine pan inbet the thou thas that the the wis ang sothen ind dout the bors all that de path he lot weran the rot hie son nous co 

[8m 41s (200 10%) 2.1566]

REHKE INGTE:
And this the cance mand the with rise mot har this herat is the sher the the my thou the of and lo prate shath beat on the will so madis  

[13m 22s (300 15%) 2.0317]

I I freresence hat the thim of hat the seed of well her sofe it the lord and with the the the here my morterd,
And shat lith the gordon and frow here  

[17m 57s (400 20%) 1.9411]

Aw my be that not.

MENCETIO:
For the the shat my the stir.

CLAUSTES:
Sern the maded you poors, lone for fall the so searter Cander and dongue the hi 

[22m 0s (500 25%) 1.8832]

Hor to say be not your his to my to hersenty,

First our stand that our so the some my be him and and cone is the reath the parding but the shall.

GL 

[26m 2s (600 30%) 1.8275]

Then a moor should to me the she will in this purvest,
Whild he sinder t

In [10]:
print("After training")
# A low temperature (0.2) makes sampling close to greedy.
sample = run(init_str='\n', length=500, temp=0.2)
print(sample)

After training

The stands and the son the bears to the cousin,
That a stand with the country to the partion of the country.

PETRUCHIO:
I will not stand and the straition to the partion to the such all the words,
And the sen the such the stand of the words,
And the complest to the stand and such made and the father
That the father with the world with the prince,
The sen the world and the common made and the stand and the partion.

BRUTUS:
Why, the send the consul of the house to the seem the prince
That with t


In [14]:
# Save model parameters together with the constructor arguments, so the model
# can be rebuilt with CharSeqRNN(*state['args']) followed by
# load_state_dict(state['state_dict']). The character vocabulary is not
# stored: the loader must use the same all_chars ordering.
import os

checkpoint_path = 'pretrained/model.pth.tar'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

state = {
    'args': (n_chars, embed_dim, hidden_dim),
    'state_dict': model.state_dict(),
}

torch.save(state, checkpoint_path)