In [None]:
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

In [None]:
# Load the training corpus as one string.
# The larger "data/shakespeare.txt" file can be substituted for the path below.
with open("data/tinyshakespeare.txt", "r") as data_file:
    data = data_file.read()

In [None]:
# Number of characters in the loaded text.
data_size = len(data)
print(data_size)

In [None]:
# The distinct characters of the text form the vocabulary.
alphabet = set(data)
alphabet_size = len(alphabet)
print(alphabet_size)

In [None]:
# Assign consecutive integer ids to the symbols in sorted order so the mapping
# is the same on every run, and keep the inverse mapping for decoding.
id_to_symbol = dict(enumerate(sorted(alphabet)))
symbol_to_id = {symbol: symbol_id for symbol_id, symbol in id_to_symbol.items()}

In [None]:
def batch_one_hot_encoding(symbols):
    """Encode a sequence of symbols as rows of a one-hot matrix.

    Args:
        symbols: sized sequence of symbols, each a key of ``symbol_to_id``.

    Returns:
        A ``(len(symbols), alphabet_size)`` tensor of zeros with a single 1 in
        every row, at the column given by the symbol's id.

    Raises:
        KeyError: if a symbol is not present in ``symbol_to_id``.
    """
    encoded = torch.zeros((len(symbols), alphabet_size))
    for row, column in enumerate(symbol_to_id[item] for item in symbols):
        encoded[row, column] = 1
    return encoded

In [None]:
def extract_batch_labels(one_hot_encoded):
    """Recover class ids from one-hot rows.

    Takes the position of the maximum along dimension 1 and returns those
    positions with an extra dimension inserted at index 1, i.e. a column of
    ids for a two-dimensional input.
    """
    _, positions = torch.max(one_hot_encoded, 1)
    labels = positions.unsqueeze(1)
    return labels

In [None]:
# Width of the recurrent hidden state.
hidden_size = 100

class MinCharRNN(nn.Module):
    """Single-layer tanh recurrent cell with a linear read-out.

    One call to ``forward`` advances the network by one symbol:

        h_t = tanh(W_xh x_t + W_hh h_{t-1})
        y_t = W_hy h_t

    where each ``W`` is an ``nn.Linear`` layer (so each also carries a bias).
    Layer sizes come from the module-level ``alphabet_size`` and
    ``hidden_size`` at construction time.
    """

    def __init__(self):
        super(MinCharRNN, self).__init__()

        self.input_to_hidden = nn.Linear(alphabet_size, hidden_size)
        self.hidden_to_hidden = nn.Linear(hidden_size, hidden_size)
        self.hidden_to_output = nn.Linear(hidden_size, alphabet_size)

    def forward(self, input_symbol, hidden_state):
        """Advance the network by one step.

        Args:
            input_symbol: tensor whose last dimension is ``alphabet_size``.
            hidden_state: previous state, last dimension ``hidden_size``.

        Returns:
            ``(output, hidden_state)``: the un-normalised scores over the
            alphabet and the new hidden state.
        """
        # torch.tanh computes the same values as F.tanh, which several PyTorch
        # releases flag as deprecated.
        hidden_state = torch.tanh(
            self.input_to_hidden(input_symbol) + self.hidden_to_hidden(hidden_state)
        )
        output = self.hidden_to_output(hidden_state)
        return output, hidden_state

In [None]:
def initialize_weights(m):
    """Re-initialise the weights of a linear layer uniformly in [-0.01, 0.01].

    Intended to be passed to ``nn.Module.apply``. Modules that are not
    ``nn.Linear`` are left untouched, and biases keep their existing values.

    Args:
        m: the module to initialise in place.
    """
    if isinstance(m, nn.Linear):
        # uniform_ is the in-place initialiser; the un-suffixed nn.init.uniform
        # is a deprecated alias that emits a warning.
        nn.init.uniform_(m.weight, -0.01, 0.01)

In [None]:
# True when a CUDA device is available; the cells below use this flag to move
# the model and the tensors to the GPU.
use_cuda = torch.cuda.is_available()

In [None]:
# Step size passed to the Adagrad optimizer below.
learning_rate = 1e-1

model = MinCharRNN()

# Move the parameters to the GPU before the optimizer is built over them.
if use_cuda:
    model = model.cuda()
    
# CrossEntropyLoss takes the raw (un-normalised) scores produced by the model.
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adagrad(model.parameters(), lr=learning_rate)

# Run initialize_weights on the model and every one of its submodules.
model.apply(initialize_weights)

In [None]:
# Training loop: truncated back-propagation through time over consecutive
# windows of the text, carrying the hidden state from one window to the next.
epochs = 2
batch_size = 16
# Every window reads batch_size input symbols plus one extra symbol that is the
# target of the last input, and consecutive windows start batch_size symbols
# apart. This is therefore the number of complete windows that fit in the
# text; dividing by (batch_size + 1) would leave the tail of the text unused.
batches = (data_size - 1) // batch_size
log_every = 1000
gradient_clipping = 5

initial_state = torch.zeros((1, hidden_size))
if use_cuda:
    initial_state = initial_state.cuda()

for epoch in range(epochs):
    # reset the state before every epoch
    last_hidden_state = initial_state

    accumulated_loss = 0.0

    for batch_id in range(batches):
        # reuse the state from the last batch; it was detached from the
        # previous graph, so gradients stop at the window boundary
        hidden_state = last_hidden_state

        optimizer.zero_grad()

        v_loss = 0.0

        batch_start = batch_id * batch_size
        batch_end = batch_start + batch_size + 1
        batch_symbols = batch_one_hot_encoding(data[batch_start:batch_end])
        if use_cuda:
            batch_symbols = batch_symbols.cuda()
        batch_labels = extract_batch_labels(batch_symbols)

        for sequence_id in range(batch_size):
            # the target of each symbol is the symbol that follows it
            symbol = batch_symbols[sequence_id]
            label = batch_labels[sequence_id + 1]

            v_prediction, hidden_state = model(symbol, hidden_state)
            v_loss += loss_function(v_prediction, label)

        v_loss.backward()

        # clamp every gradient element to [-gradient_clipping, gradient_clipping]
        nn.utils.clip_grad_value_(model.parameters(), gradient_clipping)

        optimizer.step()

        # .item() copies the scalar loss to a Python float from any device
        accumulated_loss += v_loss.item()

        if batch_id % log_every == log_every - 1:
            mean_loss = accumulated_loss / float(log_every)
            accumulated_loss = 0.0
            print("Batch: {:6d} Loss: {:.4f}".format(batch_id + 1, mean_loss))

        # keep only the values of the state; the graph is not carried over
        last_hidden_state = hidden_state.detach()