In [None]:
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import time

In [None]:
# Read the whole training corpus into memory as a single string.
# Alternative input file (presumably the larger corpus -- verify it exists locally):
# with open("data/shakespeare.txt", "r") as data_file:
with open("data/tinyshakespeare.txt", "r") as data_file:
    data = data_file.read()

In [None]:
# Number of characters in the corpus; the training cell derives the batch count from it.
data_size = len(data)
print(data_size)

In [None]:
# The vocabulary is the set of distinct characters that occur in the corpus.
alphabet = set(data)
# Its size fixes the length of the one-hot input vectors and the width of the
# model's input and output layers.
alphabet_size = len(alphabet)
print(alphabet_size)

In [None]:
# Assign every symbol a stable integer id by enumerating the alphabet in
# sorted order, then derive the reverse lookup by inverting that mapping.
symbol_to_id = {
    symbol: symbol_id
    for symbol_id, symbol in enumerate(sorted(alphabet))
}
id_to_symbol = {
    symbol_id: symbol
    for symbol, symbol_id in symbol_to_id.items()
}

In [None]:
def one_hot_encoding(symbol):
    """Encode ``symbol`` as a one-hot vector over the alphabet.

    Returns a 1-D tensor of length ``alphabet_size`` that is zero everywhere
    except for a 1 at the index ``symbol_to_id[symbol]``.

    Raises ``KeyError`` if ``symbol`` is not a key of ``symbol_to_id``.
    """
    position = symbol_to_id[symbol]
    encoded = torch.zeros(alphabet_size)
    encoded[position] = 1
    return encoded

In [None]:
def label_tensor(symbol):
    """Wrap the integer id of ``symbol`` in a one-element ``LongTensor``.

    The training loop passes the result to the loss function as the target
    class. Raises ``KeyError`` if ``symbol`` is not a key of ``symbol_to_id``.
    """
    class_index = symbol_to_id[symbol]
    return torch.LongTensor([class_index])

In [None]:
# Width of the recurrent hidden state.
hidden_size = 100


class MinCharRNN(nn.Module):
    """A single tanh recurrent cell followed by a linear read-out.

    One call to ``forward`` advances the network by one symbol: it mixes the
    current input with the previous hidden state and produces one score per
    alphabet entry together with the new hidden state.
    """

    def __init__(self):
        super(MinCharRNN, self).__init__()

        # Three affine maps: input -> hidden, hidden -> hidden, hidden -> output.
        self.input_to_hidden = nn.Linear(alphabet_size, hidden_size)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, alphabet_size)

    def forward(self, input_symbol, hidden_state):
        """Run one recurrent step.

        Args:
            input_symbol: input with ``alphabet_size`` features; the callers
                in this notebook pass the output of ``one_hot_encoding``.
            hidden_state: previous hidden state with ``hidden_size`` features;
                the callers in this notebook pass a ``(1, hidden_size)`` value.

        Returns:
            A pair ``(output, hidden_state)``: the un-normalised scores over
            the alphabet and the updated hidden state.
        """
        input_contribution = self.input_to_hidden(input_symbol)
        recurrent_contribution = self.hidden_to_hidden(hidden_state)
        next_hidden_state = F.tanh(input_contribution + recurrent_contribution)
        return self.hidden_to_output(next_hidden_state), next_hidden_state

In [None]:
def initialize_weights(m):
    """Re-initialise the weight matrix of a linear layer in place.

    Intended to be passed to ``Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and so is the bias of linear layers.
    Weights are drawn uniformly from [-0.01, 0.01].
    """
    if not isinstance(m, nn.Linear):
        return
    lower_bound, upper_bound = -0.01, 0.01
    nn.init.uniform(m.weight, lower_bound, upper_bound)

In [None]:
learning_rate = 1e-1

# Build the network and immediately overwrite the default weights of its
# linear layers with small uniform values.
model = MinCharRNN()
model.apply(initialize_weights)

# Cross-entropy over the alphabet, optimised with Adagrad. The optimizer only
# keeps references to the parameters, so it sees the re-initialised values.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)

# Show the model structure as the cell output.
model

In [None]:
# Optional: uncomment to load previously saved parameters into the model.
# model.load_state_dict(torch.load("data/min-char-rnn.torch"))

In [None]:
epochs = 10
# Number of consecutive symbols the network is unrolled over per optimisation step.
batch_size = 25
# Every step reads `batch_size` inputs plus one look-ahead label, and
# consecutive steps start `batch_size` symbols apart. The highest index touched
# is therefore `batches * batch_size`, which must stay below `data_size`.
batches = (data_size - 1) // batch_size
log_every = 1000
gradient_clipping = 5

initial_state = torch.zeros((1, hidden_size))

for epoch in range(epochs):
    # reset the state before every epoch
    last_hidden_state = initial_state

    accumulated_loss = 0.0
    start_time = time.time()

    for batch_id in range(batches):
        # reuse the hidden state from last batch
        hidden_state = Variable(last_hidden_state)

        # clear the gradient information from the past batch
        optimizer.zero_grad()

        # loss for the entire sequence
        v_loss = Variable(torch.zeros(1))

        # offset of this batch's first symbol inside the corpus
        batch_start = batch_id * batch_size

        # for every symbol in the batch
        # try to predict the next symbol
        # and measure the loss
        for sequence_id in range(batch_size):
            position = batch_start + sequence_id
            symbol = one_hot_encoding(data[position])
            # the target is the symbol that immediately follows the input in
            # the corpus, so it must use the same batch offset as the input
            label = label_tensor(data[position + 1])

            v_symbol = Variable(symbol)
            v_label = Variable(label)

            v_prediction, hidden_state = model(v_symbol, hidden_state)
            v_loss += loss_function(v_prediction, v_label)

        # backpropagate through time
        v_loss.backward()

        # gradient clipping to avoid exploding gradients
        for parameter in model.parameters():
            parameter.grad.data.clamp_(-gradient_clipping, gradient_clipping)

        # update parameters
        optimizer.step()

        # some logging; float() keeps the running total a plain Python number
        # regardless of what indexing the loss tensor returns
        accumulated_loss += float(v_loss.data[0])

        if batch_id % log_every == log_every - 1:
            mean_loss = accumulated_loss / float(log_every)
            end_time = time.time()
            elapsed_time = end_time - start_time
            accumulated_loss = 0.0
            start_time = end_time
            print("Batch: {:6d} Loss: {:.4f} Time: {:.2f} seconds".format(batch_id + 1, mean_loss, elapsed_time))

        # take the hidden state out of the variable
        # to avoid backpropagating the next batch to this one
        last_hidden_state = hidden_state.data

In [None]:
# Optional: uncomment to save the trained parameters to disk.
# torch.save(model.state_dict(), "data/min-char-rnn.torch")

In [None]:
sample_size = 200

# Start generation from a newline symbol and an all-zero hidden state.
symbol = Variable(one_hot_encoding("\n"), volatile=True)
hidden_state = Variable(torch.zeros((1, hidden_size)), volatile=True)

sampled_symbols = []
for sample_id in range(sample_size):
    # scores over the alphabet for the symbol that follows `symbol`
    prediction, hidden_state = model(symbol, hidden_state)
    probabilities = F.softmax(prediction)

    # draw the next symbol from the predicted distribution
    symbol_id = int(np.random.choice(range(alphabet_size), p=probabilities.data.squeeze(0).numpy()))
    next_symbol = id_to_symbol[symbol_id]
    sampled_symbols.append(next_symbol)

    # feed the sampled symbol back in as the next input, one-hot encoded like
    # the training inputs, rather than the raw prediction scores
    symbol = Variable(one_hot_encoding(next_symbol), volatile=True)

sample = "".join(sampled_symbols)

print(sample)