## Imports

In [1]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy as np

## Utils

In [2]:
def loadData(filePath, batchSize, sequenceSize):
    """Read a whitespace-separated corpus and encode it for next-word training.

    Every distinct word gets an integer id, ordered by descending frequency
    (id 0 is the most frequent word). The encoded text is truncated to a whole
    number of ``batchSize * sequenceSize`` chunks and reshaped to ``batchSize``
    rows. The label of each input is the word that follows it; the label of
    the very last input wraps around to the first word of the text.

    Args:
        filePath: Path of the text file to read.
        batchSize: Number of rows of the returned arrays.
        sequenceSize: Number of words per training window.

    Returns:
        dict with the keys ``intToWords`` (id -> word), ``wordsToInt``
        (word -> id), ``batchs`` and ``labels`` (int64 arrays with
        ``batchSize`` rows) and ``numberOfWords`` (vocabulary size, counted
        on the full, untruncated text).

    Raises:
        ValueError: If ``batchSize`` or ``sequenceSize`` is not positive, or
            if the file holds fewer than ``batchSize * sequenceSize`` words.
    """
    from collections import Counter

    if batchSize <= 0 or sequenceSize <= 0:
        raise ValueError('batchSize and sequenceSize must be positive integers')

    # Load data
    with open(filePath) as file:
        words = file.read().split()

    # Create support dictionaries (most frequent word first)
    wordsCounter = Counter(words)
    sortedWords = sorted(wordsCounter, key=wordsCounter.get, reverse=True)
    intToWords = dict(enumerate(sortedWords))
    wordsToInt = {word: indice for indice, word in intToWords.items()}
    numberOfWords = len(intToWords)

    # Keep only as many words as fill complete batches
    wordsPerBatch = batchSize * sequenceSize
    numberOfBatchs = len(words) // wordsPerBatch
    if numberOfBatchs == 0:
        # Checked explicitly instead of relying on an IndexError raised by
        # indexing an empty list further down.
        raise ValueError('Invalid amount of words to generate the batchs / sequences')

    # Network input, i.e. words as integers. The dtype is pinned to int64
    # because the arrays are later turned into index / target tensors and the
    # platform default integer is not 64 bit everywhere.
    batchs = np.array(
        [wordsToInt[word] for word in words[:numberOfBatchs * wordsPerBatch]],
        dtype=np.int64
    )

    # The target of each input is the next word: shift everything one step to
    # the left; the last position wraps around to the first word.
    labels = np.roll(batchs, -1)

    return dict(
        intToWords=intToWords,
        wordsToInt=wordsToInt,
        batchs=batchs.reshape(batchSize, -1),
        labels=labels.reshape(batchSize, -1),
        numberOfWords=numberOfWords
    )

In [3]:
def getBatchs(batch, labels, batchSize, sequenceSize):
    """Yield consecutive (inputs, targets) windows of ``sequenceSize`` columns.

    The number of windows is the total element count of ``batch`` divided
    (integer division) by ``sequenceSize * batchSize``; each window is the
    same column slice taken from ``batch`` and from ``labels``.
    """
    totalElements = np.prod(batch.shape)
    windowCount = totalElements // (sequenceSize * batchSize)

    start = 0
    for _ in range(windowCount):
        stop = start + sequenceSize
        yield batch[:, start:stop], labels[:, start:stop]
        start = stop

## Model

In [4]:
class LSTM(nn.Module):
    """Word-level language model: embedding -> single LSTM -> linear layer."""

    def __init__(self, numberOfWords, sequenceSize, embeddingSize, lstmSize):
        """Build the layers.

        Args:
            numberOfWords: Vocabulary size; number of embedding rows and of
                output logits.
            sequenceSize: Stored on the instance; not used by the layers.
            embeddingSize: Size of each word embedding vector.
            lstmSize: Size of the LSTM hidden state.
        """
        super().__init__()

        self.sequenceSize = sequenceSize
        self.lstmSize = lstmSize

        self.embedding = nn.Embedding(numberOfWords, embeddingSize)
        self.lstm = nn.LSTM(embeddingSize,
                            lstmSize,
                            batch_first=True)
        self.dense = nn.Linear(lstmSize, numberOfWords)

    def forward(self, state, previousState):
        """Run one step of the network.

        Args:
            state: Integer word ids, batch dimension first.
            previousState: ``(hidden, memory)`` tuple fed to the LSTM.

        Returns:
            ``(logits, newState)`` where ``logits`` holds one score per
            vocabulary word for every input position and ``newState`` is the
            ``(hidden, memory)`` tuple produced by the LSTM.
        """
        embed = self.embedding(state)

        output, newState = self.lstm(embed, previousState)

        logits = self.dense(output)

        return logits, newState

    def resetState(self, batchSize):
        """Return zeroed ``(hidden, memory)`` states for ``batchSize`` rows.

        A real tuple is returned (not a one-shot generator) so the result can
        be unpacked, indexed, reused or passed straight to ``forward``. The
        tensors are created on the same device as the model's weights.
        """
        device = self.dense.weight.device

        # Leading dimension is 1: one layer, one direction.
        return tuple(
            torch.zeros(1, batchSize, self.lstmSize, device=device)
            for _ in range(2)
        )

## Training

In [5]:
# --- Data windowing ---
sequenceSize = 32           # words per training window
batchSize = 16              # rows processed in parallel

# --- Network size ---
embeddingSize = 64          # width of each word embedding
lstmSize = 64               # width of the LSTM hidden state

# --- Optimisation ---
cuda = False                # set to True to request the GPU
epochs = 5                  # passes over the corpus
learnRating = 0.1           # learning rate handed to the optimizer
gradientsNorm = 5           # maximum gradient norm used for clipping

# --- Text generation ---
initialWords = ['I', 'am']  # words used to prime the generator
top = 5                     # sample among this many best-scoring words

In [6]:
# Read and encode the corpus once; the cells below read everything from this dict.
data = loadData(
    filePath='data.raw',
    batchSize=batchSize,
    sequenceSize=sequenceSize
)

In [7]:
# Build the network with the vocabulary size found in the corpus.
model = LSTM(
    data.get('numberOfWords'),
    sequenceSize,
    embeddingSize,
    lstmSize
)

# `torch.cuda.is_available` is a function and has to be called: the bare
# function object is always truthy, which made the availability check a no-op.
if cuda and torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=learnRating)

criterion = nn.CrossEntropyLoss()

# Global step counter, incremented by the training loop.
iteration = 0

In [8]:
def _sampleTopWord(output, top):
    """Pick, uniformly at random, the id of one of the `top` best-scoring words."""
    _, topIndices = torch.topk(output[0], k=top)

    return int(np.random.choice(topIndices.tolist()[0]))


def predict(model, initialWords, numberOfWords, wordsToInt, intToWords, top=5, numberOfSteps=100):
    """Generate text with `model` and print it.

    The model is primed with `initialWords`, one word at a time. One word is
    then sampled from the last output, and `numberOfSteps` more words are
    produced by feeding each sampled word back into the model, so
    `numberOfSteps + 1` words are appended in total.

    Args:
        model: Network exposing `resetState(batchSize)` and returning
            `(logits, (hidden, memory))` when called.
        initialWords: Non-empty sequence of words present in `wordsToInt`.
        numberOfWords: Unused; kept for backward compatibility.
        wordsToInt: Mapping word -> integer id.
        intToWords: Mapping integer id -> word.
        top: How many best-scoring words to sample from at each step.
        numberOfSteps: Number of feedback steps after the first sampled word.

    Raises:
        ValueError: If `initialWords` is empty.
        KeyError: If an initial word is not in `wordsToInt`.
    """
    if not initialWords:
        raise ValueError('initialWords must contain at least one word')

    words = list(initialWords)

    # Follow the device of the model instead of a global flag, so inputs and
    # state always live where the weights are.
    firstParameter = next(model.parameters(), None)
    device = firstParameter.device if firstParameter is not None else torch.device('cpu')

    # Switch to evaluation mode for generation and restore the previous mode
    # afterwards.
    wasTraining = model.training
    model.eval()

    try:
        # No gradients are needed for generation; this also avoids growing an
        # autograd graph through the recurrent state.
        with torch.no_grad():
            # Reset state
            stateHidden, stateMemory = model.resetState(1)
            stateHidden, stateMemory = stateHidden.to(device), stateMemory.to(device)

            # Prime the state with the initial words
            for word in initialWords:
                _word = torch.tensor([[wordsToInt[word]]], device=device)

                output, (stateHidden, stateMemory) = model(
                    _word,
                    (stateHidden, stateMemory)
                )

            choice = _sampleTopWord(output, top)
            words.append(intToWords[choice])

            # Feed every sampled word back in to obtain the next one
            for _ in range(numberOfSteps):
                _word = torch.tensor([[choice]], device=device)

                output, (stateHidden, stateMemory) = model(
                    _word,
                    (stateHidden, stateMemory)
                )

                choice = _sampleTopWord(output, top)
                words.append(intToWords[choice])
    finally:
        model.train(wasTraining)

    print(' '.join(words).encode('utf-8'))

In [9]:
# `torch.cuda.is_available` must be called; the bare function object is always
# truthy, so the original check never verified that a GPU exists.
useCuda = cuda and torch.cuda.is_available()

for epoch in range(epochs):
    batchs = getBatchs(
        data.get('batchs'),
        data.get('labels'),
        batchSize,
        sequenceSize
    )

    # Every epoch starts from a zeroed hidden / memory state
    stateHidden, stateMemory = model.resetState(batchSize)

    if useCuda:
        stateHidden, stateMemory = stateHidden.cuda(), stateMemory.cuda()

    for batch_data, batch_label in batchs:
        iteration += 1

        # Make sure train mode is active (predict() switches the model to
        # evaluation mode while it generates text)
        model.train()

        # Reset gradient
        optimizer.zero_grad()

        # Transform array to tensor; the dtype is explicit because the model
        # and the loss are fed integer class indices
        batch_data = torch.tensor(batch_data, dtype=torch.long)

        batch_label = torch.tensor(batch_label, dtype=torch.long)

        # Send tensor to GPU
        if useCuda:
            batch_data = batch_data.cuda()

            batch_label = batch_label.cuda()

        # Train
        logits, (stateHidden, stateMemory) = model(
            batch_data,
            (stateHidden, stateMemory)
        )

        # Loss; the class dimension is moved to position 1 before it is
        # compared with the labels
        loss = criterion(logits.transpose(1, 2), batch_label)

        # Detach the carried state so the backward pass of the next window
        # stops here instead of reaching into this window's graph
        # (truncated back-propagation through time)
        stateHidden = stateHidden.detach()

        stateMemory = stateMemory.detach()

        # Back-propagation
        loss.backward()

        # Gradient clipping (in place)
        nn.utils.clip_grad_norm_(
            model.parameters(),
            gradientsNorm
        )

        # Update network's parameters
        optimizer.step()

        # Loss value
        print(f'Epoch {epoch}, Iteration: {iteration}, Loss: {loss.item()}')

        # Predict value
        if iteration % 20 == 0:
            predict(model, initialWords, data.get('numberOfWords'), data.get('wordsToInt'), data.get('intToWords'), top)

Epoch 0, Iteration: 1, Loss: 10.297513961791992
Epoch 0, Iteration: 2, Loss: 9.968701362609863
Epoch 0, Iteration: 3, Loss: 9.570663452148438
Epoch 0, Iteration: 4, Loss: 8.739331245422363
Epoch 0, Iteration: 5, Loss: 8.53475284576416
Epoch 0, Iteration: 6, Loss: 8.700026512145996
Epoch 0, Iteration: 7, Loss: 8.66878890991211
Epoch 0, Iteration: 8, Loss: 8.480375289916992
Epoch 0, Iteration: 9, Loss: 8.208251953125
Epoch 0, Iteration: 10, Loss: 8.421578407287598
Epoch 0, Iteration: 11, Loss: 8.497539520263672


KeyboardInterrupt: 