In [12]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import string
import random
from random import randint
import torch.backends.cudnn as cudnn

In [13]:
# Read data into string and seperate train and valid data
with open('text8') as file:
    data = file.read()
valid_size = 1000
valid_text = data[:valid_size]
train_text = data[valid_size:]
train_size = len(train_text)
valid_size = len(valid_text)
print(train_size, train_text[:64])
print(valid_size, valid_text[:64])

99999000 ons anarchists advocate social relations based upon voluntary as
1000  anarchism originated as a term of abuse first used against earl


In [14]:
# Utility functions
vocab_size = len(string.ascii_lowercase) + 1 # 0 index for ' '
first_letter = ord(string.ascii_lowercase[0])

def char2id(char):
    if char in string.ascii_lowercase:
        return ord(char) - first_letter + 1
    elif char == ' ':
        return 0
    else:
        print('Unexpected character %s' % char)
        return 0

def id2char(dictid):
    if dictid > 0:
        return chr(dictid + first_letter - 1)
    else: return ' '

print(char2id('a'), char2id('z'), char2id(' '), char2id('ï'))
print(id2char(1), id2char(26), id2char(0))

Unexpected character ï
1 26 0 0
a z  


In [15]:
batch_size = 64
num_unrollings = 10

# Generate batches parallely across the text at equal intervals
# Each batch contains one character from each of the positions
# Positions are updated after generating every batch
# The next batch would therefore contain the next characters from all the chosen positions
# num_unrollings number of batches are processed at once
# Each character is represented as a one hot vector
class BatchGenerator(object):
    def __init__(self, text, batch_size, num_unrollings):
        self._text = text
        self._text_size = len(text)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._text_size // batch_size
        self._cursor = [offset*segment for offset in range(batch_size)]
        self._last_batch = self._next_batch()
        
    def _next_batch(self):
        batch = np.zeros(shape=(self._batch_size, vocab_size), dtype=np.float)
        for b in range(self._batch_size):
            batch[b, char2id(self._text[self._cursor[b]])] = 1.0
            self._cursor[b] = (self._cursor[b] + 1) % self._text_size
        return batch
    
    def _next(self):
        batches = [self._last_batch]
        for step in range(self._num_unrollings):
            batches.append(self._next_batch())
        self._last_batch = batches[-1]
        return batches

In [16]:
dtype = torch.FloatTensor
hidden_size = 1024 # 1024 for gpu 

class LSTM(nn.Module):
    
    def __init__(self, vocab_size, hidden_size):
        super(LSTM, self).__init__()
        # LSTM architecture
        
        # Layer1
        self.inputW1 = nn.Linear(vocab_size, hidden_size)
        self.inputU1 = nn.Linear(hidden_size, hidden_size) 
        self.forgetW1 = nn.Linear(vocab_size, hidden_size)
        self.forgetU1 = nn.Linear(hidden_size, hidden_size)
        self.outputW1 = nn.Linear(vocab_size, hidden_size)
        self.outputU1 = nn.Linear(hidden_size, hidden_size)
        self.cellW1 = nn.Linear(vocab_size, hidden_size)
        self.cellU1 = nn.Linear(hidden_size, hidden_size)
        
        # Layer2
        self.inputW2 = nn.Linear(hidden_size, hidden_size)
        self.inputU2 = nn.Linear(hidden_size, hidden_size) 
        self.forgetW2 = nn.Linear(hidden_size, hidden_size)
        self.forgetU2 = nn.Linear(hidden_size, hidden_size)
        self.outputW2 = nn.Linear(hidden_size, hidden_size)
        self.outputU2 = nn.Linear(hidden_size, hidden_size)
        self.cellW2 = nn.Linear(hidden_size, hidden_size)
        self.cellU2 = nn.Linear(hidden_size, hidden_size)
        
        # Softmax weight
        self.softW = nn.Linear(hidden_size, vocab_size)
        # Dropout
        self.dropout = nn.Dropout(0)
        
    def forward(self, one_hot_input, cell_prev, hidden_prev, batch_size):
        sig = nn.Sigmoid()
        tnh = nn.Tanh()
        
        # Extract cell and hidden data for each layer
        cell_prev = Variable(cell_prev)
        hidden_prev = Variable(hidden_prev)
        cell_prev1 = cell_prev[:batch_size, :]
        hidden_prev1 = hidden_prev[:batch_size, :]
        cell_prev2 = cell_prev[batch_size:, :]
        hidden_prev2 = hidden_prev[batch_size:, :]
        
        # Layer 1 computation
        input_gate = sig(self.inputW1(one_hot_input) + self.inputU1(hidden_prev1))
        forget_gate = sig(self.forgetW1(one_hot_input) + self.forgetU1(hidden_prev1))
        output_gate = sig(self.outputW1(one_hot_input) + self.outputU1(hidden_prev1))
        update = tnh(self.cellW1(one_hot_input) + self.cellU1(hidden_prev1))
        cell1 = (forget_gate * cell_prev1) + (input_gate * update)
        hidden1 = output_gate * tnh(cell1)
        
        # Layer 2 computation
        input_gate = sig(self.inputW2(hidden1) + self.inputU2(hidden_prev2))
        forget_gate = sig(self.forgetW2(hidden1) + self.forgetU2(hidden_prev2))
        output_gate = sig(self.outputW2(hidden1) + self.outputU2(hidden_prev2))
        update = tnh(self.cellW2(hidden1) + self.cellU2(hidden_prev2))
        cell2 = (forget_gate * cell_prev2) + (input_gate * update)
        hidden2 = output_gate * tnh(cell2)
        
        logits = self.softW(hidden2)
        logits = self.dropout(logits)
        cell = torch.cat((cell1, cell2), dim=0)
        hidden = torch.cat((hidden2, hidden2), dim=0)
        
        return cell.data, hidden.data, logits

lstm = LSTM(vocab_size, hidden_size)
lstm.cuda()

LSTM(
  (inputW1): Linear(in_features=27, out_features=1024)
  (inputU1): Linear(in_features=1024, out_features=1024)
  (forgetW1): Linear(in_features=27, out_features=1024)
  (forgetU1): Linear(in_features=1024, out_features=1024)
  (outputW1): Linear(in_features=27, out_features=1024)
  (outputU1): Linear(in_features=1024, out_features=1024)
  (cellW1): Linear(in_features=27, out_features=1024)
  (cellU1): Linear(in_features=1024, out_features=1024)
  (inputW2): Linear(in_features=1024, out_features=1024)
  (inputU2): Linear(in_features=1024, out_features=1024)
  (forgetW2): Linear(in_features=1024, out_features=1024)
  (forgetU2): Linear(in_features=1024, out_features=1024)
  (outputW2): Linear(in_features=1024, out_features=1024)
  (outputU2): Linear(in_features=1024, out_features=1024)
  (cellW2): Linear(in_features=1024, out_features=1024)
  (cellU2): Linear(in_features=1024, out_features=1024)
  (softW): Linear(in_features=1024, out_features=27)
  (dropout): Dropout(p=0)
)

In [17]:
class StateManager(object):
    def __init__(self):
        self.hidden_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        self.cell_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        
    def save_state(self, cell_state, hidden_state):
        self.cell_state = cell_state
        self.hidden_state = hidden_state
        
    def load_state(self):
        return self.cell_state, self.hidden_state
    
    def reset_state(self):
        self.hidden_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        self.cell_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        
sm = StateManager()

In [18]:
learning_rate = 0.1
train_batches = BatchGenerator(train_text, batch_size, num_unrollings)
optimizer = torch.optim.SGD(lstm.parameters(), lr=learning_rate)

def train():
    cell, hidden = sm.load_state()
    batches = train_batches._next()
    optimizer.zero_grad()
    lstm.zero_grad()
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    for u in range(num_unrollings):
        one_hot_input = Variable(torch.from_numpy(batches[u]).type(dtype), requires_grad=False)
        # cell, hidden, logits = lstm(one_hot_input, cell, hidden)
        cell, hidden, logits = lstm(one_hot_input.cuda(), cell.cuda(), hidden.cuda(), batch_size)
        labels = Variable(torch.from_numpy(np.argmax(batches[u+1], axis=1)))
        # loss += loss_function(logits, labels)
        loss += loss_function(logits.cuda(), labels.cuda())
    loss.backward()
    torch.nn.utils.clip_grad_norm(lstm.parameters(), 5)
    optimizer.step()
    sm.save_state(cell, hidden)
    return loss / num_unrollings

In [19]:
# Sampling implementation

def sample_distribution(distribution):
# Sample one element from a distribution assumed to be an array of normalized probabilities
    r = random.uniform(0, 1)
    s = 0
    for i in range(len(distribution)):
        s += distribution[i]
        if s >= r:
            return i
    return len(distribution) - 1

def sample():
    charid = randint(0, vocab_size-1)
    print(id2char(charid), end='')
    cell = torch.zeros(2*1, hidden_size).type(dtype)
    hidden = torch.zeros(2*1, hidden_size).type(dtype) 
    soft = nn.Softmax(dim=1)
    for i in range(100):
        one_hot = torch.zeros(1, vocab_size).type(dtype)
        one_hot[0, charid] = 1.0
        one_hot = Variable(one_hot, requires_grad=False)
        # cell, hidden, logits = lstm(one_hot, cell, hidden)
        cell, hidden, logits = lstm(one_hot.cuda(), cell.cuda(), hidden.cuda(), 1)
        output = soft(logits)
        charid = sample_distribution(output.data[0])
        print(id2char(charid), end='')
    print(' ')

In [20]:
# Beam search implementatioon

def beam_search(distributions, beam_size, alog_probs, sequences):
    
    distributions = torch.abs(torch.log(distributions))
    matrix = alog_probs + distributions
    indices = np.argsort(matrix.numpy(), axis=None)[:beam_size]
    indices = [(i//vocab_size, (i-(i//vocab_size)*vocab_size)) for i in indices]
    for i in range(beam_size):
        alog_probs[i] = matrix[indices[i][0], indices[i][1]]
    seq_ids = [i[0] for i in indices]
    char_ids = [i[1] for i in indices]
    j = 0
    temp = [None] * beam_size
    for i in seq_ids:
        temp[j] = sequences[i] + ',' + str(char_ids[j])
        j += 1
    sequences = temp
    return alog_probs, sequences, char_ids

def sample_beam():
    beam_size = 5
    alog_probs = torch.ones(beam_size, 1)
    charid = randint(0, vocab_size-1)
    sequences = [str(charid)] * beam_size
    last_indices = [charid] * beam_size
    distributions = torch.zeros(beam_size, vocab_size) 
    cell = torch.zeros(2*1, hidden_size).type(dtype)
    hidden = torch.zeros(2*1, hidden_size).type(dtype) 
    soft = nn.Softmax(dim=1)
    for i in range(100):
        for b in range(beam_size):
            one_hot = torch.zeros(1, vocab_size).type(dtype)
            one_hot[0, last_indices[b]] = 1.0
            one_hot = Variable(one_hot, requires_grad=False)
            # cell, hidden, logits = lstm(one_hot, cell, hidden)
            cell, hidden, logits = lstm(one_hot.cuda(), cell.cuda(), hidden.cuda(), 1)
            output = soft(logits)
            distributions[b] = output.data
        alog_probs, sequences, last_indices = beam_search(distributions, beam_size, alog_probs, sequences)     
            
    ar = sequences[0].split(',')
    ids = [int(i) for i in ar]
    for i in ids:
        print(id2char(i), end='')
    print(' ')

In [21]:
valid_batches = BatchGenerator(valid_text, 1, 1)

def valid_perplexity():
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    cell = torch.zeros(2*1, hidden_size).type(dtype)
    hidden = torch.zeros(2*1, hidden_size).type(dtype)
    for i in range(valid_size):
        batches = valid_batches._next()
        one_hot_input = Variable(torch.from_numpy(batches[0]).type(dtype), requires_grad=False)
        # cell, hidden, logits = lstm(one_hot_input, cell, hidden)
        cell, hidden, logits = lstm(one_hot_input.cuda(), cell.cuda(), hidden.cuda(), 1)
        labels = Variable(torch.from_numpy(np.argmax(batches[1], axis=1)))
        # loss += loss_function(logits, labels)
        loss += loss_function(logits.cuda(), labels.cuda())
    return torch.exp(loss / valid_size)

In [22]:
num_epochs = 2
num_iters = 160000 

cudnn.benchmark = True
cudnn.fasttest = True

for e in range(num_epochs):
    print('Epoch %d' % e)
    for i in range(num_iters):
        l = train()
        if i%1000 == 0: 
            print('Average loss at step %d: %.3f ' % (i,l))
            print('Minibatch perplexity: %.3f' % torch.exp(l))
            print('Validation perplexity: %.3f' % valid_perplexity())
            lstm.eval()
            sample()
            #sample_beam()
            print('')
            lstm.train()
    learning_rate = learning_rate / 10
    optimizer = torch.optim.SGD(lstm.parameters(), lr=learning_rate)

Epoch 0
Average loss at step 0: 3.297 
Minibatch perplexity: 27.018
Validation perplexity: 25.538
fxqfbojgpaqehrskalgykyeslybybladlpbtteftyyozwhaqdvjqhwfsxp parhvwootjlmuyonjpofxrubc mlxkystjss owzba 

Average loss at step 1000: 2.305 
Minibatch perplexity: 10.028
Validation perplexity: 10.602
go fha devt n desins mone era whe jad on ote wora thines es nt wuur the thith ende ta turinss cing al 

Average loss at step 2000: 2.254 
Minibatch perplexity: 9.523
Validation perplexity: 9.275
chrelagb hadt andy the on caped a grower gerijpto buseals samdeis an ane lraprens cand one chepmare p 

Average loss at step 3000: 2.089 
Minibatch perplexity: 8.076
Validation perplexity: 7.938
wer sive tro the ewens tom noze bistity opither pedewrice in onlich toum vord pooply ledante anderico 

Average loss at step 4000: 1.909 
Minibatch perplexity: 6.748
Validation perplexity: 7.152
ines bat outherva zreecop offare erus ix deser of nine arl desied and fot foubajs proted hecided hich 

Average loss at 

Average loss at step 42000: 1.392 
Minibatch perplexity: 4.024
Validation perplexity: 3.435
ates on the introduced science and played s sorre of hannward would ring i better matter farely in on 

Average loss at step 43000: 1.213 
Minibatch perplexity: 3.365
Validation perplexity: 3.436
e history of the her to gave dudies space of homitore was the latter champelless the himaley xames ov 

Average loss at step 44000: 1.281 
Minibatch perplexity: 3.599
Validation perplexity: 3.395
has influential artistic broaders named by tearly onder corresponds in largely general italy champion 

Average loss at step 45000: 1.190 
Minibatch perplexity: 3.286
Validation perplexity: 3.313
ded sees of premier blbc q qthmger medites and containing for subjeques destroyed for out of a set of 

Average loss at step 46000: 1.361 
Minibatch perplexity: 3.902
Validation perplexity: 3.355
vagened in addiated by their element and this criminal covean and mountained to secure primacilla fam 

Average loss at step

Average loss at step 84000: 1.351 
Minibatch perplexity: 3.863
Validation perplexity: 3.251
king her clouds hammozions called johnny city in the one of each state of gods to sq that is hockey b 

Average loss at step 85000: 1.268 
Minibatch perplexity: 3.555
Validation perplexity: 3.182
xing landland or much toronto a buildings on resolution of the header script or what he called the fi 

Average loss at step 86000: 1.259 
Minibatch perplexity: 3.524
Validation perplexity: 3.237
it following days to sherton a total to say believed on space in the feminian attempted the tradition 

Average loss at step 87000: 1.274 
Minibatch perplexity: 3.575
Validation perplexity: 3.251
n and ve s leep are frac two years to dominated he heard or tooch other country s itsorpherms began a 

Average loss at step 88000: 1.190 
Minibatch perplexity: 3.287
Validation perplexity: 3.174
ex praised historic to perfect pil in the age recovered on visaved to see adams with luther south slo 

Average loss at step

Average loss at step 126000: 1.171 
Minibatch perplexity: 3.224
Validation perplexity: 3.190
fin edgar and his first point his third staffed a more born the first settlematologist environments l 

Average loss at step 127000: 1.101 
Minibatch perplexity: 3.006
Validation perplexity: 3.159
ken morning comeo lines andrew ecceding water stress be publisher had a directors of thus in famous s 

Average loss at step 128000: 1.203 
Minibatch perplexity: 3.329
Validation perplexity: 3.198
x the crahma or condition non foes for the southern a means thiller s that crime eritreached dispute  

Average loss at step 129000: 1.271 
Minibatch perplexity: 3.565
Validation perplexity: 3.201
k in the new centaville film green available on the main says interesting self waves the variable def 

Average loss at step 130000: 1.161 
Minibatch perplexity: 3.193
Validation perplexity: 3.084
ject magnation the gonzo ensuke the chance of the product hears they had been constens while there is 

Average loss at

Average loss at step 8000: 1.257 
Minibatch perplexity: 3.514
Validation perplexity: 2.936
n is about the matte e eine jan tv one zero three design guerrilland properlant the remainder of the  

Average loss at step 9000: 1.312 
Minibatch perplexity: 3.714
Validation perplexity: 2.919
nomes people since it is a countries who would continue to rome second germany alone american point w 

Average loss at step 10000: 1.228 
Minibatch perplexity: 3.413
Validation perplexity: 2.951
zenazing as an emperors on worldwide the batman predeced early c earlier than two other subgrouper an 

Average loss at step 11000: 1.203 
Minibatch perplexity: 3.328
Validation perplexity: 2.949
bert is part of phones locally are part of the friends comet in turkey or confusion skin up ministry  

Average loss at step 12000: 1.265 
Minibatch perplexity: 3.542
Validation perplexity: 2.949
me itself asian scene comprehensed by top of there are no swords accepting alogna bdong itselffe one  

Average loss at step 1

Average loss at step 50000: 1.133 
Minibatch perplexity: 3.106
Validation perplexity: 2.910
cs where he was three in one eight zero six education on how cyclon hits was a sense of the legendary 

Average loss at step 51000: 1.307 
Minibatch perplexity: 3.693
Validation perplexity: 2.914
jection of the lorian officials therapyish and is among the cycle rio bang william italian murderer n 

Average loss at step 52000: 1.318 
Minibatch perplexity: 3.736
Validation perplexity: 2.925
quest can be in low friend and is pronounced only any overshoem builby vis on jews enjoyled boundarie 

Average loss at step 53000: 1.189 
Minibatch perplexity: 3.284
Validation perplexity: 2.880
ph free play of the user apside backs geeli where the level two zero zero zero inche is geographicall 

Average loss at step 54000: 1.194 
Minibatch perplexity: 3.300
Validation perplexity: 2.891
 marmiturgy thick is the founder of the market was carried a program the railier called one populatio 

Average loss at step

Average loss at step 92000: 1.313 
Minibatch perplexity: 3.717
Validation perplexity: 2.892
ba is zoted by beegheg first oppression markupse nicklary population would responsible for fm merchan 

Average loss at step 93000: 1.218 
Minibatch perplexity: 3.379
Validation perplexity: 2.897
get warrior range from brainnacism theologs gcd nstraize the sheep one four one and monoxymous eagles 

Average loss at step 94000: 1.154 
Minibatch perplexity: 3.171
Validation perplexity: 2.906
it in by squeezes it is oran under the post cases of the armenia in one eight eight nine the month be 

Average loss at step 95000: 1.338 
Minibatch perplexity: 3.812
Validation perplexity: 2.916
a the novels that the national slies around form and light address and september two zero zero zero f 

Average loss at step 96000: 1.183 
Minibatch perplexity: 3.264
Validation perplexity: 2.928
phaleps through organization progressive rotational greisgrassion who have no conscripted with some p 

Average loss at step

Average loss at step 134000: 1.350 
Minibatch perplexity: 3.856
Validation perplexity: 2.918
ging there he lett harry richard some one five six press on and clearing congo glames each youth comp 

Average loss at step 135000: 1.169 
Minibatch perplexity: 3.219
Validation perplexity: 2.912
y television neo tourists and former games each real primary reference to uplentitts magazines for th 

Average loss at step 136000: 1.341 
Minibatch perplexity: 3.823
Validation perplexity: 2.932
quences being demographic department used the whigs at left for high addiction commandary work in dec 

Average loss at step 137000: 1.141 
Minibatch perplexity: 3.129
Validation perplexity: 2.931
staking movement housing of measurement shagist once which sexiom only by them three spufative illust 

Average loss at step 138000: 1.209 
Minibatch perplexity: 3.351
Validation perplexity: 2.914
z may old civil war e state in apple in sixty goodoelite and justines one eight seven one three nine  

Average loss at

In [26]:
# model.eval() after reloading
torch.save(lstm.state_dict(), 'model')

In [28]:
lstm.load_state_dict(torch.load('model'))
lstm.eval()
sample()
print(valid_perplexity())

ca rna doek dan russifyians conclude the european civil satar was followed for direct sounds is gabon 
Variable containing:
 2.8524
[torch.cuda.FloatTensor of size 1 (GPU 0)]

