In [1]:
import string
import numpy.random as random
import re

# Character vocabulary: a character's id is its position in this string
# (see char_index). Characters of the corpus that are not in string.printable
# have no id and would make all_chars.index() raise ValueError.
# note: we could instead build our own char base from the characters found in the file
all_chars = string.printable
n_chars = len(all_chars) # total number of characters (vocabulary size)

# Read the whole training corpus into memory as one string.
with open('input.txt', 'r') as file:
    text = file.read()

text_len = len(text) # corpus length in characters
print('text_len =', text_len)

text_len = 1115394


In [2]:
seq_len = 200

def random_seq():
    """Return a random contiguous slice of `seq_len` characters from `text`.

    The start offset is drawn uniformly from every position that still leaves
    room for a full-length slice.
    """
    # numpy's randint samples from [low, high), so the exclusive bound is one
    # past the last valid start offset (text_len - seq_len).
    n_starts = text_len - seq_len + 1
    begin = random.randint(0, n_starts)
    return text[begin:begin + seq_len]

print(random_seq())

re in all my life.

VINCENTIO:
What, you notorious villain, didst thou never see
thy master's father, Vincentio?

BIONDELLO:
What, my old worshipful old master? yes, marry, sir:
see where he looks out


In [3]:
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F

class CharSeqRNN(nn.Module):
    """Character-level sequence model: embedding -> single-layer LSTM -> linear decoder.

    Args:
        vocab_size: number of distinct symbols (characters in this notebook).
        embed_dim: size of the embedding vector fed to the LSTM.
        hidden_dim: size of the LSTM hidden and cell states.
    """

    def __init__(self, vocab_size, embed_dim, hidden_dim):
        # Name the class explicitly. super(self.__class__, self) resolves to the
        # *subclass* when this class is inherited from, which re-enters this
        # __init__ instead of reaching nn.Module's.
        super(CharSeqRNN, self).__init__()

        self.vocab_size = vocab_size # number of chars for this case
        self.hidden_dim = hidden_dim
        self.embed_dim = embed_dim # could be tied to hidden_dim to drop one hyper-parameter

        self.encode = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim, num_layers=1, batch_first=True) # we can try dropout
        self.decode = nn.Linear(hidden_dim, vocab_size)

    def forward(self, inp, hidden):
        """Run the model over a batch of index sequences.

        Args:
            inp: integer index tensor laid out batch-first (N x T).
            hidden: (h_0, c_0) LSTM state, e.g. from `init_hidden`.

        Returns:
            (output, hidden): `output` holds unnormalised scores over the
            vocabulary for every position (N x T x vocab_size); `hidden` is
            the LSTM state after the last position.
        """
        embedded = self.encode(inp)
        output, hidden = self.rnn(embedded, hidden)
        # Raw scores are returned on purpose; normalise at the call site
        # (e.g. F.log_softmax(output, dim=2)) when probabilities are needed.
        output = self.decode(output)
        return output, hidden

    def init_hidden(self, batch_size):
        """Return a zeroed (h_0, c_0) pair for `batch_size` sequences.

        The states are allocated from the model's own parameter storage, so
        they share its device and dtype: CUDA when the model has been moved
        with .cuda(), CPU otherwise.
        """
        weight = next(self.parameters()).data

        def zeros():
            state = weight.new(self.rnn.num_layers, batch_size, self.hidden_dim).zero_()
            return Variable(state)

        return (zeros(), zeros())

In [4]:
def char_index(chars):
    """Encode a string as a 1 x len(chars) CUDA LongTensor of vocabulary ids.

    Each character is mapped to its position in `all_chars`; a character that
    is not in `all_chars` raises ValueError.
    """
    ids = list(map(all_chars.index, chars))
    row = torch.LongTensor(ids).view(1, -1)
    return Variable(row).cuda()
print(char_index("abcDEF"))

Variable containing:
 10  11  12  39  40  41
[torch.cuda.LongTensor of size 1x6 (GPU 0)]



In [5]:
def training_batch(batch_size):
    """Sample `batch_size` random sequences and return (inputs, targets).

    Both tensors have one row per sequence. The target row is the input row
    shifted one character to the left, so targets[b, t] is the character that
    follows inputs[b, t] in the corpus.
    """
    # Encode each sampled sequence twice: without its last character (input)
    # and without its first character (target).
    encoded = [
        (char_index(seq[:-1]), char_index(seq[1:]))
        for seq in (random_seq() for _ in range(batch_size))
    ]
    inputs = torch.cat([pair[0] for pair in encoded], dim=0).cuda()
    targets = torch.cat([pair[1] for pair in encoded], dim=0).cuda()
    return inputs, targets

c_in, c_out = training_batch(1)
#print(c_in)

In [6]:
def run(init_str='A', length=200, temp=0.4):
    """Generate text from the global `model`, one sampled character at a time.

    Args:
        init_str: seed text; must contain at least one character.
        length: number of characters to generate after the seed.
        temp: sampling temperature; the scores are divided by it before the
            softmax, so lower values make the sampling more conservative.

    Returns:
        `init_str` followed by the `length` generated characters.
    """
    state = model.init_hidden(1)

    # Warm up the recurrent state on everything but the last seed character;
    # that last character becomes the first generation input.
    if len(init_str) > 1:
        _, state = model(char_index(init_str[:-1]), state)

    pieces = [init_str]
    step_in = char_index(init_str[-1])

    for _ in range(length):
        scores, state = model(step_in, state)

        probs = F.softmax(scores.view(-1)/temp, dim=0).data
        next_char = all_chars[torch.multinomial(probs, 1)[0]]
        pieces.append(next_char)
        # Feed the sampled character back in as the next input.
        step_in = char_index(next_char)
    return ''.join(pieces)

In [7]:
import time, math

def time_since(since):
    """Format the wall-clock time elapsed since `since` as '<minutes>m <seconds>s'.

    `since` is a timestamp as returned by time.time().
    """
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    seconds = elapsed - minutes * 60
    return '%dm %ds' % (minutes, seconds)

In [8]:
def train(batch_size):
    """Run one optimisation step on a freshly sampled batch and return its loss.

    Uses the notebook-level `model`, `criterion` and `optimizer`.

    Args:
        batch_size: number of random sequences in the batch.

    Returns:
        The batch loss as a plain Python number.
    """
    hidden = model.init_hidden(batch_size)
    model.zero_grad()
    c_in, c_out = training_batch(batch_size)

    output, hidden = model(c_in, hidden)
    # Flatten scores to (batch * time, n_chars) and targets to (batch * time,)
    # so every character position is scored as an independent classification.
    loss = criterion(output.view(-1, n_chars), c_out.view(-1))

    loss.backward()
    optimizer.step()

    # From PyTorch 0.4 on the loss is a 0-dim tensor and indexing it with [0]
    # fails; .item() is the supported accessor there. Older releases have no
    # .item() and return a one-element tensor, so fall back to [0].
    loss_data = loss.data
    if hasattr(loss_data, 'item'):
        return loss_data.item()
    return loss_data[0]
    

In [9]:
# Training schedule. An "epoch" here is a single call to train(), i.e. one
# randomly sampled mini-batch, not a full pass over the corpus.
epochs = 2000
print_fq = 100  # print the loss and a generated sample every print_fq steps
plot_fq = 10    # record one averaged loss value every plot_fq steps

# Model hyper-parameters; these globals are read by train(), run() and the
# checkpoint cell.
embed_dim = 128
hidden_dim = 128
batch_size = 64
model = CharSeqRNN(n_chars, embed_dim, hidden_dim).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)

# Expects unnormalised scores, which is what the model's forward() returns.
criterion = nn.CrossEntropyLoss()

start = time.time()
losses = []   # averaged losses, one entry per plot_fq steps
loss_avg = 0  # running sum of losses since the last entry appended to `losses`

for epoch in range(1, epochs+1):
    loss1 = train(batch_size)
    loss_avg += loss1
    if epoch % print_fq == 0:
        # Progress line: elapsed time, step number, percent done, latest batch loss.
        print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / epochs * 100, loss1))
        # Sample 150 characters seeded with a newline to eyeball progress.
        print(run('\n', 150, 0.5), '\n')

    if epoch % plot_fq == 0:
        losses.append(loss_avg / plot_fq)
        loss_avg = 0
    
#print('[%s (%d %d%%) %.4f]' % (time_since(start), epoch, epoch / epochs * 100, loss1))


[0m 6s (100 5%) 2.7967]

Aky?UCvNF ine the that se

 he yl wf facr h leer ter oos net corous the we hone win me are wo ihe wd ues a h

e tat ind linsm in he is mon yed din sig 

[0m 13s (200 10%) 2.4326]

ELb'll the dory this therd he ores me save that bave hare he mand ou the the wand, he ond men math the beare so the wat the as you ine me the mer an t 

[0m 19s (300 15%) 2.2405]

FENE:
The sart,
And wit the the ther mis and seres the werens and os the frit your the the ther rothe my me to in the fere srorst ant the werem the me 

[0m 25s (400 20%) 2.1572]

AS, are som all souly hat the sow sour I sin fore hat and couls the pare my and it uld fard thear the as at the for to is the hinging ond th the the s 

[0m 31s (500 25%) 2.0963]

Goud for the the the the me to me and will be preach the stere lith the and surienind shall more sied and to the more seell the mowe hours not wath th 

[0m 37s (600 30%) 2.0122]

Sis the that lords the the mand, yer in the wing the will the all thou the 

In [10]:
# Sample a longer passage (500 characters) from the trained model, seeded with
# a newline and using a low temperature (0.2).
print("After training")
print(run('\n', 500, 0.2))

After training

The seen the such the grace the string the sorrow the stranger for the sorrow the man the stranger the stand the lord,
The stranger for the so make this so man the shall the such the bear the provest the so man the such the such of the so make the pringed the courter the stranger to me and the sone to the such the such the such a son the bear a promise and the sone to the son and the stranger the stand the stranger me to me the proves to me the prince of the will the sorrow the so man the couse 


In [11]:
# Checkpoint the model: the constructor arguments are stored next to the
# learned weights so the network can be rebuilt as CharSeqRNN(*state['args'])
# before loading state['state_dict'].
state = dict(
    args=(n_chars, embed_dim, hidden_dim),
    state_dict=model.state_dict(),
)

torch.save(state, 'model.pth.tar')