In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
import string
import random
from random import randint
from fastText import load_model
import torch.backends.cudnn as cudnn

In [2]:
# Read data into string, split into words and finally separate into train and valid sets
valid_size = 1000

with open('text8') as file:
    data = file.read()
words = data.split()
valid_words = words[:valid_size]
train_words = words[valid_size:]
train_size = len(train_words)

print(valid_size, valid_words[:10])
print(train_size, train_words[:10])

1000 ['anarchism', 'originated', 'as', 'a', 'term', 'of', 'abuse', 'first', 'used', 'against']
17004207 ['american', 'individualist', 'anarchism', 'benjamin', 'tucker', 'in', 'one', 'eight', 'two', 'five']


In [3]:
# Use fasttext excecutable in fastText-0.1.0 folder to build bin (and vec) file containing word embeddings
# Load bin file into fasttext model
ft_model = load_model('text8_ft.bin')

In [4]:
# dictionary -> word : embedding
# index_dictionary -> index : word
# reverse_index_dictionary -> word : index
dictionary = {'UNK' : ft_model.get_word_vector('UNK')}
index_dictionary = {0 : 'UNK'}
all_words = ft_model.get_words()
for word in all_words:
    dictionary[word] = ft_model.get_word_vector(word)
    index_dictionary[len(index_dictionary)] = word
reverse_index_dictionary = dict(zip(index_dictionary.values(), index_dictionary.keys()))

print('Vocabulary size : %d' % len(dictionary))

Vocabulary size : 71291


In [5]:
vocab_size = len(dictionary)
num_dimensions = 300
batch_size = 64
num_unrollings = 10

# Generate batches parallely across the text at equal intervals
# Each batch contains one word from each of the positions
# Positions are updated after generating every batch
# The next batch would therefore contain the next words from all the chosen positions
# num_unrollings number of batches are processed at once
# Each word is represented as an embedding
class BatchGenerator(object):
    def __init__(self, corpus, batch_size, num_unrollings):
        self._corpus = corpus
        self._corpus_size = len(corpus)
        self._batch_size = batch_size
        self._num_unrollings = num_unrollings
        segment = self._corpus_size // batch_size
        self._cursor = [offset*segment for offset in range(batch_size)]
        self._last_batch, self._last_indices = self._next_batch()
        
    def _next_batch(self):
        #print('')
        batch = np.zeros(shape=(self._batch_size, num_dimensions), dtype=np.float)
        word_indices = np.zeros(shape=(self._batch_size,), dtype=np.int64)
        for b in range(self._batch_size):
            word = self._corpus[self._cursor[b]]
            if word not in dictionary:
                batch[b] = dictionary['UNK']
                word_indices[b] = reverse_index_dictionary['UNK']
            else:
                batch[b] = dictionary[word]
                word_indices[b] = reverse_index_dictionary[word]
            #print(word, end='')
            #print(' ', end='')
            self._cursor[b] = (self._cursor[b] + 1) % self._corpus_size
        #print('')
        return batch, word_indices
    
    def _next(self):
        batches = [self._last_batch]
        indices = [self._last_indices]
        for step in range(self._num_unrollings):
            batch, word_indices = self._next_batch()
            batches.append(batch)
            indices.append(word_indices)
        self._last_batch = batches[-1]
        self._last_indices = indices[-1]
        return batches, indices
    
train_batches = BatchGenerator(train_words, batch_size, num_unrollings)
valid_batches = BatchGenerator(valid_words, 1, 1)

In [6]:
dtype = torch.FloatTensor
hidden_size = 1024

class LSTM(nn.Module):
    
    def __init__(self, num_dimensions, hidden_size):
        super(LSTM, self).__init__()
        # LSTM architecture
        
        # Layer 1
        self.inputW1 = nn.Linear(num_dimensions, hidden_size)
        self.inputU1 = nn.Linear(hidden_size, hidden_size)
        self.forgetW1 = nn.Linear(num_dimensions, hidden_size)
        self.forgetU1 = nn.Linear(hidden_size, hidden_size)
        self.outputW1 = nn.Linear(num_dimensions, hidden_size)
        self.outputU1 = nn.Linear(hidden_size, hidden_size)
        self.cellW1 = nn.Linear(num_dimensions, hidden_size)
        self.cellU1 = nn.Linear(hidden_size, hidden_size)
        
        # Layer 2
        self.inputW2 = nn.Linear(hidden_size, hidden_size)
        self.inputU2 = nn.Linear(hidden_size, hidden_size)
        self.forgetW2 = nn.Linear(hidden_size, hidden_size)
        self.forgetU2 = nn.Linear(hidden_size, hidden_size)  
        self.outputW2 = nn.Linear(hidden_size, hidden_size)
        self.outputU2 = nn.Linear(hidden_size, hidden_size)     
        self.cellW2 = nn.Linear(hidden_size, hidden_size)
        self.cellU2 = nn.Linear(hidden_size, hidden_size)
        
        # Softmax weight
        self.softW = nn.Linear(hidden_size, vocab_size)
        # Dropout
        self.dropout = nn.Dropout(0)
        
    def forward(self, embeddings, cell_prev, hidden_prev, batch_size):
        sig = nn.Sigmoid()
        tnh = nn.Tanh()
        
        # Extract cell and hidden data for each layer
        cell_prev = Variable(cell_prev)
        hidden_prev = Variable(hidden_prev)
        cell_prev1 = cell_prev[:batch_size, :]
        hidden_prev1 = hidden_prev[:batch_size, :]
        cell_prev2 = cell_prev[batch_size:, :]
        hidden_prev2 = hidden_prev[batch_size:, :]
        
        # Layer 1 computation
        input_gate = sig(self.inputW1(embeddings) + self.inputU1(hidden_prev1))
        forget_gate = sig(self.forgetW1(embeddings) + self.forgetU1(hidden_prev1))
        output_gate = sig(self.outputW1(embeddings) + self.outputU1(hidden_prev1))
        update = tnh(self.cellW1(embeddings) + self.cellU1(hidden_prev1))
        cell1 = (forget_gate * cell_prev1) + (input_gate * update)
        hidden1 = output_gate * tnh(cell1)
        
        # Layer 2 computation
        input_gate = sig(self.inputW2(hidden1) + self.inputU2(hidden_prev2))
        forget_gate = sig(self.forgetW2(hidden1) + self.forgetU2(hidden_prev2))
        output_gate = sig(self.outputW2(hidden1) + self.outputU2(hidden_prev2))
        update = tnh(self.cellW2(hidden1) + self.cellU2(hidden_prev2))
        cell2 = (forget_gate * cell_prev2) + (input_gate * update)
        hidden2 = output_gate * tnh(cell2)
        
        # Output
        logits = self.softW(hidden2)
        logits = self.dropout(logits)
        cell = torch.cat((cell1, cell2), dim=0)
        hidden = torch.cat((hidden2, hidden2), dim=0)
        
        return cell.data, hidden.data, logits

lstm = LSTM(num_dimensions, hidden_size)
lstm.cuda()

LSTM(
  (inputW1): Linear(in_features=300, out_features=1024)
  (inputU1): Linear(in_features=1024, out_features=1024)
  (forgetW1): Linear(in_features=300, out_features=1024)
  (forgetU1): Linear(in_features=1024, out_features=1024)
  (outputW1): Linear(in_features=300, out_features=1024)
  (outputU1): Linear(in_features=1024, out_features=1024)
  (cellW1): Linear(in_features=300, out_features=1024)
  (cellU1): Linear(in_features=1024, out_features=1024)
  (inputW2): Linear(in_features=1024, out_features=1024)
  (inputU2): Linear(in_features=1024, out_features=1024)
  (forgetW2): Linear(in_features=1024, out_features=1024)
  (forgetU2): Linear(in_features=1024, out_features=1024)
  (outputW2): Linear(in_features=1024, out_features=1024)
  (outputU2): Linear(in_features=1024, out_features=1024)
  (cellW2): Linear(in_features=1024, out_features=1024)
  (cellU2): Linear(in_features=1024, out_features=1024)
  (softW): Linear(in_features=1024, out_features=71291)
  (dropout): Dropout(p=0)


In [7]:
class StateManager(object):
    def __init__(self):
        self.hidden_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        self.cell_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        
    def save_state(self, cell_state, hidden_state):
        self.cell_state = cell_state
        self.hidden_state = hidden_state
        
    def load_state(self):
        return self.cell_state, self.hidden_state
    
    def reset_state(self):
        self.hidden_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        self.cell_state = torch.zeros(2*batch_size, hidden_size).type(dtype)
        
sm = StateManager()

In [8]:
learning_rate = 0.1
optimizer = torch.optim.SGD(lstm.parameters(), lr=learning_rate)

def train():
    cell, hidden = sm.load_state()
    batches, indices = train_batches._next()
    optimizer.zero_grad()
    lstm.zero_grad()
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    for u in range(num_unrollings):
        embeddings = Variable(torch.from_numpy(batches[u]).type(dtype), requires_grad=False)
        cell, hidden, logits = lstm(embeddings.cuda(), cell.cuda(), hidden.cuda(), batch_size)
        labels = Variable(torch.from_numpy(indices[u+1]))
        loss += loss_function(logits.cuda(), labels.cuda())
    loss.backward()
    torch.nn.utils.clip_grad_norm(lstm.parameters(), 5) #1.25
    optimizer.step()
    sm.save_state(cell, hidden)
    return loss / num_unrollings

In [9]:
# Generate a random character and feed to model
# Take predicted character and feed it back again to generate subsequent characters

def sample_distribution(distribution):
# Sample one element from a distribution assumed to be an array of normalized probabilities
    r = random.uniform(0, 1)
    s = 0
    for i in range(len(distribution)):
        s += distribution[i]
        if s >= r:
            return i
    return len(distribution) - 1

def sample():
    word_id = randint(0, vocab_size-1)
    print(index_dictionary[word_id], end='')
    print(' ', end='')
    cell = torch.zeros(2*1, hidden_size).type(dtype)
    hidden = torch.zeros(2*1, hidden_size).type(dtype) 
    soft = nn.Softmax(dim=1)
    for i in range(12):
        numpy_vector = dictionary[index_dictionary[word_id]].reshape(1, num_dimensions)
        word_vector = torch.from_numpy(numpy_vector).type(dtype)
        word_vector = Variable(word_vector, requires_grad=False)
        cell, hidden, logits = lstm(word_vector.cuda(), cell.cuda(), hidden.cuda(), 1)
        output = soft(logits)
        word_id = sample_distribution(output.data[0])
        print(index_dictionary[word_id], end='')
        print(' ', end='')
    print(' ')

In [10]:
def valid_perplexity():
    loss_function = nn.CrossEntropyLoss()
    loss = 0
    cell = torch.zeros(2*1, hidden_size).type(dtype)
    hidden = torch.zeros(2*1, hidden_size).type(dtype)
    for i in range(valid_size):
        batches, indices = valid_batches._next()
        word_vectors = Variable(torch.from_numpy(batches[0]).type(dtype), requires_grad=False)
        cell, hidden, logits = lstm(word_vectors.cuda(), cell.cuda(), hidden.cuda(), 1)
        labels = Variable(torch.from_numpy(indices[1]))
        loss += loss_function(logits.cuda(), labels.cuda())
    return torch.exp(loss / valid_size)

In [29]:
num_iters = 300001 

cudnn.benchmark = True
cudnn.fasttest = True

for i in range(num_iters):
    lstm.train()
    l = train()
    if i%1000 == 0: 
        print('Average loss at step %d: %.3f ' % (i,l))
        print('Minibatch perplexity: %.3f' % torch.exp(l))
        print('Validation perplexity: %.3f' % valid_perplexity())
        lstm.eval()
        sample()
        print('')

Average loss at step 0: 11.177 
Minibatch perplexity: 71481.344
Validation perplexity: 70671.266
coloane chairperson include artcyclopedia osteoporosis tsetse gait begun guru recent dracula pluck tmc  

Average loss at step 1000: 7.463 
Minibatch perplexity: 1742.593
Validation perplexity: 2060.778
consultum symbols idiomatic relaunch perpendicular kroll significance dorothy restrictions for proposed fast intrinsic  

Average loss at step 2000: 7.020 
Minibatch perplexity: 1118.893
Validation perplexity: 1484.531
entrees proves affan encephalopathy voted olympian primordial in frequenting target wales question space  

Average loss at step 3000: 6.964 
Minibatch perplexity: 1057.770
Validation perplexity: 1520.071
politican suzerain combe alignments ank navy plagiarized jigoro leader pashtun authorized the construct  

Average loss at step 4000: 6.856 
Minibatch perplexity: 949.657
Validation perplexity: 1000.979
abate piran svante errol distinction nanyang overpasses tabular hussain f

Validation perplexity: 280.340
covers logged over weeks the same year in minnesota in one nine eight  

Average loss at step 44000: 5.339 
Minibatch perplexity: 208.406
Validation perplexity: 298.770
docherty winners with unprotected trust rights such as captain of a series of  

Average loss at step 45000: 5.778 
Minibatch perplexity: 323.087
Validation perplexity: 281.803
lymec expansionist reaction british fantasy basic diagram standards such as intoxicated substances or  

Average loss at step 46000: 5.295 
Minibatch perplexity: 199.413
Validation perplexity: 286.552
infer heathers entice UNK the latest star probe to print the screen is  

Average loss at step 47000: 5.057 
Minibatch perplexity: 157.063
Validation perplexity: 299.067
kirov deconstructionists classical norma ideals at glasnevin memorial university in philadelphia pennsylvania a  

Average loss at step 48000: 5.371 
Minibatch perplexity: 215.185
Validation perplexity: 300.758
villard kodagu vorticity master composed 

Average loss at step 88000: 4.927 
Minibatch perplexity: 137.913
Validation perplexity: 229.436
velupillai wha amputated after world war i is to say i do fall  

Average loss at step 89000: 4.710 
Minibatch perplexity: 111.015
Validation perplexity: 230.306
spur rainfall the study of spring europa is the character one nine nine  

Average loss at step 90000: 4.930 
Minibatch perplexity: 138.420
Validation perplexity: 217.610
reuniting zuider universes produced while producing safe seed before they had been added  

Average loss at step 91000: 4.742 
Minibatch perplexity: 114.687
Validation perplexity: 247.108
examination gulls swinburne lincoln has listed below the idea UNK that designs of  

Average loss at step 92000: 5.161 
Minibatch perplexity: 174.419
Validation perplexity: 226.900
albus annihilate yw b is still it the first modern guide to work  

Average loss at step 93000: 5.079 
Minibatch perplexity: 160.684
Validation perplexity: 210.935
hso saskatchewan fptp delivers the UNK

peep blind goalie lineman udo max c one nine five seven one nine  

Average loss at step 134000: 4.742 
Minibatch perplexity: 114.613
Validation perplexity: 200.300
tif greyscale supported this technology was imported by mumps such as mozilla engelbart  

Average loss at step 135000: 4.593 
Minibatch perplexity: 98.776
Validation perplexity: 195.932
eigenstates gasses jahwist works encompassing six contrasting chiefly aspects of the physical laws  

Average loss at step 136000: 4.742 
Minibatch perplexity: 114.620
Validation perplexity: 219.652
gains polytheism split into the category may end in the one nine seven  

Average loss at step 137000: 4.783 
Minibatch perplexity: 119.463
Validation perplexity: 192.846
meta wresting the correct requirements of the arguments obtained by and the private  

Average loss at step 138000: 4.732 
Minibatch perplexity: 113.489
Validation perplexity: 199.431
eniac proletarian actinide believed that it was designed to convince their designs by  

Avera

Average loss at step 179000: 4.654 
Minibatch perplexity: 105.016
Validation perplexity: 200.582
retrovirus pyrimidines el cooh in mariana there are two major raves in zoology  

Average loss at step 180000: 4.715 
Minibatch perplexity: 111.652
Validation perplexity: 208.377
oprah cecilia who founded the indigenous scholarly helped mississippi basin for public health  

Average loss at step 181000: 4.465 
Minibatch perplexity: 86.935
Validation perplexity: 203.718
features apoptosis arrest manslaughter bsi makers cartoons depicting the gig ovations is on  

Average loss at step 182000: 4.531 
Minibatch perplexity: 92.875
Validation perplexity: 201.461
doughnut dog man broken down his hand called UNK to conceive his true  

Average loss at step 183000: 4.637 
Minibatch perplexity: 103.282
Validation perplexity: 210.160
channels scholz s consultative book distanced on for the computer interaction block algorithmic  

Average loss at step 184000: 4.741 
Minibatch perplexity: 114.536
Valid

Validation perplexity: 202.679
chimpanzee species including pigments that survived at the costanzo laboratory museum the the  

Average loss at step 225000: 4.617 
Minibatch perplexity: 101.171
Validation perplexity: 174.492
ports together in several ways to reach the frozen in new zealand s  

Average loss at step 226000: 4.443 
Minibatch perplexity: 85.034
Validation perplexity: 193.638
reconquering hazara disappeared out of earlier mongolian UNK at the beginning of the  

Average loss at step 227000: 4.623 
Minibatch perplexity: 101.849
Validation perplexity: 191.569
flammable fingerboard shaped and demonstrated that a composer would have capablanca to never  

Average loss at step 228000: 4.793 
Minibatch perplexity: 120.713
Validation perplexity: 200.085
thereabouts of the people s movement but the lack of excessive punishment likely  

Average loss at step 229000: 4.332 
Minibatch perplexity: 76.116
Validation perplexity: 193.777
brel settled out of the old swedish province but t

Validation perplexity: 182.369
saari legal controversy in the linguistics of phonology and which in english would  

Average loss at step 271000: 4.477 
Minibatch perplexity: 87.990
Validation perplexity: 182.674
creole carriages specialists have begun economic jurist employing methods for finding us under  

Average loss at step 272000: 4.534 
Minibatch perplexity: 93.125
Validation perplexity: 182.799
furry playoff week one six one seven to attack UNK shores read this  

Average loss at step 273000: 4.162 
Minibatch perplexity: 64.181
Validation perplexity: 193.749
afonso meanwhile he is currently involved in negotiations for the congo s refugee  

Average loss at step 274000: 4.077 
Minibatch perplexity: 58.989
Validation perplexity: 179.456
extensional the development and development new facilities that is prone to developing more  

Average loss at step 275000: 4.275 
Minibatch perplexity: 71.898
Validation perplexity: 194.372
vibrato vertebra referred to as a green UNK joe ray ha

Validation perplexity: 193.810
thankful many voices and UNK UNK jenkins UNK who was merely a man  

Average loss at step 317000: 4.278 
Minibatch perplexity: 72.088
Validation perplexity: 200.693
hene els the cog with the ussr and china one particular perception science  

Average loss at step 318000: 4.533 
Minibatch perplexity: 93.021
Validation perplexity: 188.589
pestis unc tuberculosis involves an opposing but not medical and prescription political cults  

Average loss at step 319000: 4.402 
Minibatch perplexity: 81.626
Validation perplexity: 179.728
warplanes surge rapid development in mobile operation mobile telephone police paved nine m  

Average loss at step 320000: 4.308 
Minibatch perplexity: 74.270
Validation perplexity: 188.861
dragonetti perspiration leonardo also assistants replaced UNK by switching has been revised entry  

Average loss at step 321000: 4.476 
Minibatch perplexity: 87.841
Validation perplexity: 180.317
exploding the vast amounts of bosnian and turkish 

Validation perplexity: 196.721
cliff slides straight games like the locks of the bell white chicken and  

Average loss at step 362000: 4.412 
Minibatch perplexity: 82.467
Validation perplexity: 194.544
pharisees but were not all individuals for the cult cult by most paranoia  

Average loss at step 363000: 4.497 
Minibatch perplexity: 89.780
Validation perplexity: 190.492
fgth for its role as a composer s poking mother in his sleep  

Average loss at step 364000: 4.396 
Minibatch perplexity: 81.151
Validation perplexity: 170.745
lace sponge musei big body dora a hidden hand can you translate and  

Average loss at step 365000: 4.305 
Minibatch perplexity: 74.045
Validation perplexity: 194.357
inquest sharing holes far wrong some cartoons manufacture in some of the cia  

Average loss at step 366000: 4.586 
Minibatch perplexity: 98.147
Validation perplexity: 193.952
golfo couch s tavern it allows for use such as camouflage camouflage and  

Average loss at step 367000: 4.522 
Minibatch

Average loss at step 408000: 4.287 
Minibatch perplexity: 72.738
Validation perplexity: 188.876
vapour bath and the mud to carnegie service gallery and museum of human  

Average loss at step 409000: 4.530 
Minibatch perplexity: 92.725
Validation perplexity: 182.111
conversing trip with them have superb things happen to play back again but  

Average loss at step 410000: 4.241 
Minibatch perplexity: 69.459
Validation perplexity: 188.318
guatemalan armada in the philippines fleet grudgingly edward vi played a significant role  

Average loss at step 411000: 4.041 
Minibatch perplexity: 56.899
Validation perplexity: 170.388
stoked in some instances in the harry potter books in one nine five  

Average loss at step 412000: 4.340 
Minibatch perplexity: 76.686
Validation perplexity: 182.300
advent the phrase madagascar is a national currency helping codes by rowing straits  

Average loss at step 413000: 4.247 
Minibatch perplexity: 69.926
Validation perplexity: 186.533
anesthesia laying in

Average loss at step 454000: 4.297 
Minibatch perplexity: 73.461
Validation perplexity: 182.203
hit one zero six nine short period at unknown records which have theorized  

Average loss at step 455000: 4.330 
Minibatch perplexity: 75.966
Validation perplexity: 177.287
pickled and hunt UNK machine called barris UNK in manchester beer family stephan  

Average loss at step 456000: 4.456 
Minibatch perplexity: 86.128
Validation perplexity: 175.436
hardware chorded by ignorant abilities it is primarily considered to be rotated off  

Average loss at step 457000: 4.227 
Minibatch perplexity: 68.504
Validation perplexity: 181.386
ve mustafa left europe as an activist it is common in current number  

Average loss at step 458000: 4.405 
Minibatch perplexity: 81.853
Validation perplexity: 172.809
evidential step much if we have already caught around with dimensions that had  

Average loss at step 459000: 4.541 
Minibatch perplexity: 93.830
Validation perplexity: 190.258
stifling harassment a

awm for patients but none of whom caeiro did not pursue life after  

Average loss at step 500000: 4.424 
Minibatch perplexity: 83.417
Validation perplexity: 194.862
councilor UNK s recommendation for five zero years put the suspicion of separatist  

Average loss at step 501000: 4.408 
Minibatch perplexity: 82.138
Validation perplexity: 186.556
xlr dragon the ark is published by john neale online a r de  

Average loss at step 502000: 4.285 
Minibatch perplexity: 72.617
Validation perplexity: 191.616
buildup the pocket is longer than any law army spending and serving at  

Average loss at step 503000: 4.155 
Minibatch perplexity: 63.765
Validation perplexity: 194.256
haemophilus hanasi in the past two zero zero three zero zero four chief  

Average loss at step 504000: 4.169 
Minibatch perplexity: 64.643
Validation perplexity: 185.592
width one zero zero international blue is the most descriptive and confusing depiction  

Average loss at step 505000: 4.373 
Minibatch perplexity: 79.2

Average loss at step 592000: 4.331 
Minibatch perplexity: 76.053
Validation perplexity: 187.058
singleton s career he was an accomplished in the academy s event and  

Average loss at step 593000: 4.129 
Minibatch perplexity: 62.106
Validation perplexity: 195.409
clayton frost frank mckean ian fleming and poets jerry cornelius j j edwards  

Average loss at step 594000: 4.091 
Minibatch perplexity: 59.817
Validation perplexity: 189.847
chalcedonians turks killed spellman UNK ca one nine eight three odoacer v becomes  

Average loss at step 595000: 4.221 
Minibatch perplexity: 68.095
Validation perplexity: 186.008
solvable probability the process waned from thinking that word human beings could regard  

Average loss at step 596000: 4.155 
Minibatch perplexity: 63.737
Validation perplexity: 192.193
rime a guest job for adjectives tour titles three the c one three  

Average loss at step 597000: 4.150 
Minibatch perplexity: 63.443
Validation perplexity: 169.763
zalman fossett stanford mc

KeyboardInterrupt: 

In [40]:
# model.eval() after reloading
torch.save(lstm.state_dict(), 'model')

In [11]:
lstm.load_state_dict(torch.load('model'))
lstm.eval()
sample()
valid_perplexity()