In [51]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import string
import random
from random import randint
import torch.backends.cudnn as cudnn

In [52]:
# Load the whole text8 corpus as one string and split it: the leading
# `valid_size` characters form the validation text, the rest is training text.
with open('text8') as corpus_file:
    data = corpus_file.read()
valid_size = 1000
valid_text, train_text = data[:valid_size], data[valid_size:]
# Sizes are taken from the actual slices.
train_size, valid_size = len(train_text), len(valid_text)
print(train_size, train_text[:64])
print(valid_size, valid_text[:64])

99999000 ons anarchists advocate social relations based upon voluntary as
1000  anarchism originated as a term of abuse first used against earl


In [53]:
# Vocabulary helpers: id 0 is ' ', ids 1..26 are 'a'..'z'.
vocab_size = 1 + len(string.ascii_lowercase)  # extra slot: id 0 is reserved for ' '
first_letter = ord(string.ascii_lowercase[0])


def char2id(char):
    """Map a character to its vocabulary id.

    ' ' maps to 0 and 'a'..'z' map to 1..26. Any other character is reported
    on stdout and mapped to 0 (the space id).
    """
    if char == ' ':
        return 0
    if char not in string.ascii_lowercase:
        print('Unexpected character %s' % char)
        return 0
    return ord(char) - first_letter + 1


def id2char(dictid):
    """Inverse of char2id: a positive id maps to chr(ord('a') + id - 1), anything else to ' '."""
    return chr(dictid + first_letter - 1) if dictid > 0 else ' '


print(char2id('a'), char2id('z'), char2id(' '), char2id('ï'))
print(id2char(1), id2char(26), id2char(0))

Unexpected character ï
1 26 0 0
a z  


In [54]:
batch_size = 64
num_unrollings = 10


class BatchGenerator(object):
    """Produce one-hot character batches read in parallel from evenly spaced text offsets.

    The text is split into `batch_size` equal segments, with one cursor at the
    start of each. Every `_next_batch` call reads the character under each
    cursor (one row per cursor) and advances every cursor by one, wrapping at
    the end of the text. `_next` returns `num_unrollings + 1` consecutive
    batches. The first of them is the last batch of the previous call, so
    consecutive windows overlap by one step (the training code uses batch u+1
    as the targets for batch u).
    """

    def __init__(self, text, batch_size, num_unrollings):
        if not text:
            # Otherwise the first _next_batch fails with an obscure IndexError.
            raise ValueError('BatchGenerator needs a non-empty text')
        self._text = text
        self._text_size = len(text)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._text_size // batch_size
        self._cursor = [offset * segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()

    def _next_batch(self):
        """Return a (batch_size, vocab_size) float64 one-hot array and advance every cursor."""
        # np.float was merely an alias of the builtin float (float64) and was
        # removed in NumPy 1.24; np.float64 is the identical dtype.
        batch = np.zeros(shape=(self._batch_size, vocab_size), dtype=np.float64)
        for row in range(self._batch_size):
            batch[row, char2id(self._text[self._cursor[row]])] = 1.0
            self._cursor[row] = (self._cursor[row] + 1) % self._text_size
        return batch

    def _next(self):
        """Return num_unrollings + 1 consecutive batches; the first is the previous call's last."""
        batches = [self._last_batch]
        batches.extend(self._next_batch() for _ in range(self._num_unrollings))
        self._last_batch = batches[-1]
        return batches

In [55]:
hidden_size = 1024  # LSTM width; 1024 for gpu
# Host (CPU) float tensor type for inputs and states; call sites move them with .cuda().
dtype = torch.FloatTensor

class LSTM(nn.Module):
    """Single-layer LSTM cell with a linear projection from hidden state to vocabulary logits.

    The input, forget and output gates and the candidate cell update each have
    their own projections:
      - an input projection ``*W`` (vocab_size -> hidden_size);
      - a recurrent projection ``*U`` (hidden_size -> hidden_size).
    Every nn.Linear carries its own bias.
    """

    def __init__(self, vocab_size, hidden_size):
        super(LSTM, self).__init__()
        # Input gate
        self.inputW = nn.Linear(vocab_size, hidden_size)
        self.inputU = nn.Linear(hidden_size, hidden_size)
        # Forget gate
        self.forgetW = nn.Linear(vocab_size, hidden_size)
        self.forgetU = nn.Linear(hidden_size, hidden_size)
        # Output gate
        self.outputW = nn.Linear(vocab_size, hidden_size)
        self.outputU = nn.Linear(hidden_size, hidden_size)
        # Candidate cell update
        self.cellW = nn.Linear(vocab_size, hidden_size)
        self.cellU = nn.Linear(hidden_size, hidden_size)
        # Hidden state -> logits over the vocabulary
        self.softW = nn.Linear(hidden_size, vocab_size)
        # p=0 makes this a no-op. It is kept so the module structure and repr are unchanged.
        self.dropout = nn.Dropout(0)

    def forward(self, one_hot_input, cell_prev, hidden_prev, detach_state=True):
        """Advance the LSTM by one time step.

        Args:
            one_hot_input: input tensor whose last dim is vocab_size. Callers in
                this notebook pass (batch, vocab_size) one-hot rows.
            cell_prev: previous cell state, last dim hidden_size.
            hidden_prev: previous hidden state, last dim hidden_size.
            detach_state: when True (the default and the original behaviour),
                the incoming state is detached and the returned cell/hidden
                carry no autograd history. Gradients of a step's loss then do
                not reach earlier time steps, i.e. there is no BPTT. Pass False
                to keep the state in the graph. The caller must then detach it
                itself at the end of each truncation window.

        Returns:
            (cell, hidden, logits); logits has last dim vocab_size.
        """
        if detach_state:
            cell_prev = cell_prev.detach()
            hidden_prev = hidden_prev.detach()
        input_gate = torch.sigmoid(self.inputW(one_hot_input) + self.inputU(hidden_prev))
        forget_gate = torch.sigmoid(self.forgetW(one_hot_input) + self.forgetU(hidden_prev))
        output_gate = torch.sigmoid(self.outputW(one_hot_input) + self.outputU(hidden_prev))
        update = torch.tanh(self.cellW(one_hot_input) + self.cellU(hidden_prev))
        cell = forget_gate * cell_prev + input_gate * update
        hidden = output_gate * torch.tanh(cell)
        logits = self.dropout(self.softW(hidden))
        if detach_state:
            return cell.detach(), hidden.detach(), logits
        return cell, hidden, logits

# Build the model and move its parameters to the GPU. nn.Module.cuda() returns
# the module itself, so `lstm` is the moved model and the cell shows its summary.
lstm = LSTM(vocab_size, hidden_size).cuda()
lstm

LSTM(
  (inputW): Linear(in_features=27, out_features=1024)
  (inputU): Linear(in_features=1024, out_features=1024)
  (forgetW): Linear(in_features=27, out_features=1024)
  (forgetU): Linear(in_features=1024, out_features=1024)
  (outputW): Linear(in_features=27, out_features=1024)
  (outputU): Linear(in_features=1024, out_features=1024)
  (cellW): Linear(in_features=27, out_features=1024)
  (cellU): Linear(in_features=1024, out_features=1024)
  (softW): Linear(in_features=1024, out_features=27)
  (dropout): Dropout(p=0)
)

In [56]:
class StateManager(object):
    """Holds the LSTM (cell, hidden) state carried from one `train()` call to the next.

    Both tensors start as zeros of shape (batch_size, hidden_size) and type
    `dtype`, all taken from the module-level settings.
    """

    def __init__(self):
        self.reset_state()

    def save_state(self, cell_state, hidden_state):
        """Store the state reached at the end of a training window."""
        self.cell_state, self.hidden_state = cell_state, hidden_state

    def load_state(self):
        """Return the stored state as a (cell, hidden) pair."""
        return self.cell_state, self.hidden_state

    def reset_state(self):
        """Replace both states with fresh zero tensors."""
        self.hidden_state = torch.zeros(batch_size, hidden_size).type(dtype)
        self.cell_state = torch.zeros(batch_size, hidden_size).type(dtype)

sm = StateManager()

In [57]:
learning_rate = 0.1
train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
optimizer = torch.optim.SGD(lstm.parameters(), lr=learning_rate)

def train():
    """Run one SGD step over the next window of `num_unrollings` characters.

    Steps:
      - Load the recurrent state saved by the previous call.
      - Feed the window to the model one step at a time.
      - Sum each step's cross-entropy against the next character.
      - Clip the total gradient norm to 5 and step the optimizer.
      - Store the final state for the next call.

    Returns:
        The mean per-step loss as a detached 0-d tensor.

    NOTE(review): `lstm` (as defined in this notebook) returns detached
    cell/hidden states from each step. The gradient of a step's loss therefore
    does not reach earlier steps of the window (no backprop through time).
    """
    cell, hidden = sm.load_state()
    batches = train_batches._next()
    # The optimizer was built from lstm.parameters(), so this clears every model gradient.
    optimizer.zero_grad()
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    for u in range(num_unrollings):
        one_hot_input = torch.from_numpy(batches[u]).type(dtype)
        # Targets are the ids of the characters one step ahead.
        labels = torch.from_numpy(np.argmax(batches[u + 1], axis=1))
        cell, hidden, logits = lstm(one_hot_input.cuda(), cell.cuda(), hidden.cuda())
        loss += loss_function(logits, labels.cuda())
    loss.backward()
    # clip_grad_norm (no trailing underscore) is deprecated in favour of the in-place clip_grad_norm_.
    torch.nn.utils.clip_grad_norm_(lstm.parameters(), 5)
    optimizer.step()
    sm.save_state(cell, hidden)
    # Detach so the caller's reference does not keep this step's graph alive.
    return (loss / num_unrollings).detach()

# Sampling implementation

def sample_distribution(distribution):
    """Draw an index from `distribution`, a sequence of normalized probabilities.

    Inverse-CDF sampling: one uniform draw is compared against the running sum
    of probabilities, and the first index whose cumulative sum reaches it is
    returned. If rounding leaves the total below the draw, the last index
    (len(distribution) - 1) is returned instead.
    """
    threshold = random.uniform(0, 1)
    cumulative = 0
    for index, probability in enumerate(distribution):
        cumulative += probability
        if cumulative >= threshold:
            return index
    return len(distribution) - 1

def sample():
    """Print 101 characters produced by ancestral sampling from the model.

    Behaviour:
      - A uniformly random character id seeds the sequence.
      - Each further id is drawn with sample_distribution from the softmax of
        the model's logits and fed back as the next input.
      - The state starts at zero with batch size 1.

    NOTE(review): `sample` is defined again in the next cell (beam search), so
    this version is replaced before the training loop calls `sample()`.
    """
    charid = randint(0, vocab_size-1)
    print(id2char(charid), end='')
    cell = torch.zeros(1, hidden_size).type(dtype)
    hidden = torch.zeros(1, hidden_size).type(dtype)
    softmax = nn.Softmax(dim=1)
    for _step in range(100):
        one_hot = torch.zeros(1, vocab_size).type(dtype)
        one_hot[0, charid] = 1.0
        cell, hidden, logits = lstm(one_hot.cuda(), cell.cuda(), hidden.cuda())
        probabilities = softmax(logits)
        charid = sample_distribution(probabilities.data[0])
        print(id2char(charid), end='')
    print(' ')

In [58]:
# Beam-search decoding: start from a random character, then repeatedly keep the
# `beam_size` lowest negative-log-likelihood continuations.

def beam_search(distributions, beam_size, alog_probs, sequences):
    """Advance a beam search by one character.

    Args:
        distributions: CPU tensor (num_beams, num_chars); row b holds the
            next-character probabilities for beam b.
        beam_size: number of candidates to keep.
        alog_probs: CPU tensor (num_beams, 1) of accumulated negative
            log-probabilities. It is updated in place and also returned.
        sequences: list of num_beams comma-separated character-id strings.

    Returns:
        (alog_probs, sequences, char_ids):
          - alog_probs: the updated scores;
          - sequences: the extended sequences, each "<parent>,<char_id>";
          - char_ids: the character id appended to each sequence.

    A beam whose sequence repeats an earlier beam's is not expanded, provided
    enough other candidates remain. Otherwise identical beams (e.g. every beam
    at the first step) would pick identical continuations, and the search
    would collapse to greedy decoding.
    """
    # -log(p) >= 0 for probabilities, so lower scores are more likely.
    scores = alog_probs + torch.abs(torch.log(distributions))
    num_chars = scores.shape[1]

    # Rows whose sequence already appeared in an earlier row.
    seen = set()
    duplicate_rows = []
    for row, seq in enumerate(sequences):
        if seq in seen:
            duplicate_rows.append(row)
        else:
            seen.add(seq)
    candidates = scores.clone()
    if duplicate_rows and len(seen) * num_chars >= beam_size:
        candidates[duplicate_rows] = float('inf')

    # Stable sort so ties are broken deterministically (lowest flat index first).
    best_flat = np.argsort(candidates.numpy(), axis=None, kind='stable')[:beam_size]
    picks = [divmod(int(flat), num_chars) for flat in best_flat]

    for beam, (row, col) in enumerate(picks):
        alog_probs[beam] = scores[row, col]
    char_ids = [col for _, col in picks]
    new_sequences = [sequences[row] + ',' + str(col) for row, col in picks]
    return alog_probs, new_sequences, char_ids

def sample():
    """Print a 101-character sequence decoded by beam search from a random seed character.

    Behaviour:
      - All `beam_size` beams start from the same random character and a zero state.
      - Each step runs the model once on the whole beam (one row per beam).
      - `beam_search` picks the best `beam_size` continuations.
      - The per-beam LSTM state is reordered so each surviving beam continues
        from its parent's state.
      - The lowest-score sequence is printed at the end.
    """
    beam_size = 5
    charid = randint(0, vocab_size-1)
    sequences = [str(charid)] * beam_size
    alog_probs = torch.ones(beam_size, 1)
    last_indices = [charid] * beam_size
    # One (cell, hidden) row per beam: beams must not share recurrent state.
    cell = torch.zeros(beam_size, hidden_size).type(dtype)
    hidden = torch.zeros(beam_size, hidden_size).type(dtype)
    soft = nn.Softmax(dim=1)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for _ in range(100):
            one_hot = torch.zeros(beam_size, vocab_size).type(dtype)
            for beam, char_id in enumerate(last_indices):
                one_hot[beam, int(char_id)] = 1.0
            cell, hidden, logits = lstm(one_hot.cuda(), cell.cuda(), hidden.cuda())
            distributions = soft(logits).cpu()
            parent_sequences = sequences
            alog_probs, sequences, last_indices = beam_search(
                distributions, beam_size, alog_probs, sequences)
            # beam_search builds each new sequence as "<parent>,<char_id>".
            # Recover the parent row so the state follows its beam. Identical
            # sequences came from identical inputs and the same initial state,
            # so taking the first match is safe.
            parents = [parent_sequences.index(seq.rsplit(',', 1)[0]) for seq in sequences]
            parent_rows = torch.tensor(parents, device=cell.device)
            cell = cell.index_select(0, parent_rows)
            hidden = hidden.index_select(0, parent_rows)

    for char_id in sequences[0].split(','):
        print(id2char(int(char_id)), end='')
    print(' ')

In [59]:
valid_batches = BatchGenerator(valid_text, 1, 1)

def valid_perplexity():
    """Return exp(mean next-character cross-entropy) over `valid_size` validation steps.

    Behaviour:
      - Walks the validation text one character at a time with batch size 1.
      - Starts from a zero state and carries the recurrent state forward.
      - Advances the shared `valid_batches` cursor by `valid_size` steps per call.
    """
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    cell = torch.zeros(1, hidden_size).type(dtype)
    hidden = torch.zeros(1, hidden_size).type(dtype)
    # No backward pass here; avoid building an autograd graph spanning every step.
    with torch.no_grad():
        for _ in range(valid_size):
            # valid_batches was built with num_unrollings=1, so _next() yields [input, target].
            inputs, targets = valid_batches._next()
            one_hot_input = torch.from_numpy(inputs).type(dtype)
            labels = torch.from_numpy(np.argmax(targets, axis=1))
            cell, hidden, logits = lstm(one_hot_input.cuda(), cell.cuda(), hidden.cuda())
            loss += loss_function(logits, labels.cuda())
    return torch.exp(loss / valid_size)

In [60]:
num_iters = 50001 #50001, no dropout, lr=0.1

# NOTE(review): the model has no convolutions, so this likely has no effect; kept as before.
# The former `cudnn.fasttest = True` was a misspelling that only set an unused attribute.
cudnn.benchmark = True

for i in range(num_iters):
    lstm.train()
    l = train()
    if i % 1000 == 0:
        # Report and sample with the model in eval mode (validation included).
        lstm.eval()
        print('Average loss at step %d: %.3f ' % (i, l))
        print('Minibatch perplexity: %.3f' % torch.exp(l))
        print('Validation perplexity: %.3f' % valid_perplexity())
        sample()
        print('')

Average loss at step 0: 3.287 
Minibatch perplexity: 26.768
Validation perplexity: 23.265
b                                                                                                     

Average loss at step 1000: 2.218 
Minibatch perplexity: 9.192
Validation perplexity: 9.839
n he he th anere or fo fe here t then fo fo he anenen athererinin h f the herenene here the ther f th 

Average loss at step 2000: 2.222 
Minibatch perplexity: 9.224
Validation perplexity: 8.936
by aronononon aner t for th hex f the fofof zere t ze f throne ze f zeronexix f f foner frenin ore fo 

Average loss at step 3000: 2.073 
Minibatch perplexity: 7.948
Validation perplexity: 7.703
kenererengreronin onix w anis tine on ere hthenereerereinin t fonere t h an one tex inengenonis o ani 

Average loss at step 4000: 1.926 
Minibatch perplexity: 6.862
Validation perplexity: 7.147
h w wiger inde ane ze ane se s sthe on onina ine th on th h onin h oreredingalee stindighin thininin  

Average loss at step 5000:

Average loss at step 42000: 1.489 
Minibatch perplexity: 4.435
Validation perplexity: 3.748
kenedes orin therralolat oneredec s aregerareys trin s anereridingenenedd tighanonere on anore ones o 

Average loss at step 43000: 1.341 
Minibatch perplexity: 3.821
Validation perplexity: 3.661
ky anovinenanthanilix renonthochilintarin sty tenon ve milinanes al teddayes d tusuelyelyonareyoferat 

Average loss at step 44000: 1.339 
Minibatch perplexity: 3.815
Validation perplexity: 3.653
n werabenanelagrer e ano anothintherongesigellinarelilereryo ar elarenelirerinonanother s an rinereth 

Average loss at step 45000: 1.276 
Minibatch perplexity: 3.581
Validation perplexity: 3.635
jan oreredi a o tinger one o areanous ouristai oly o s anocoglan ayoneronar s cone anonin w atharindi 

Average loss at step 46000: 1.426 
Minibatch perplexity: 4.164
Validation perplexity: 3.658
housenedanofon alyered conanenecof olanec a wandalenas wanewan ononindithis senedenofonof thinenedano 

Average loss at step