## Imports

In [1]:
import torch

import torch.nn as nn

import torch.nn.functional as F

import numpy

## Utils

In [2]:
def load_data(file_path, batch_size, sequence_size):
    """Read a whitespace-separated corpus and build next-word training arrays.

    The file is opened with the platform's default text encoding and split on
    whitespace. Every distinct word gets an integer id, with ids assigned in
    order of decreasing frequency (id 0 is the most frequent word).

    Only the first ``number_of_batchs * batch_size * sequence_size`` words are
    kept, so that the text can be reshaped into ``batch_size`` equally long
    rows; any trailing remainder is dropped.

    Args:
        file_path: Path of the text file to read.
        batch_size: Number of rows the corpus is reshaped into.
        sequence_size: Number of words per training window; together with
            ``batch_size`` it determines how many trailing words are dropped.

    Returns:
        dict with the keys
        ``int_to_words`` (id -> word), ``words_to_int`` (word -> id),
        ``batchs`` (integer array of shape ``(batch_size, -1)`` with the word
        ids), ``labels`` (same shape; the id that follows each entry of
        ``batchs``, the very last entry wrapping around to the first word) and
        ``number_of_words`` (vocabulary size).

    Raises:
        ValueError: If the file holds fewer than ``batch_size * sequence_size``
            words, i.e. not even one full batch can be formed.
    """
    # Function-local import: the notebook's import cell does not provide it.
    from collections import Counter

    # Load data
    with open(file_path) as file:
        text = file.read().split()

    # Count how many times each word appears in the data
    words_counter = Counter(text)

    # Most frequent word first
    sorted_words = sorted(words_counter, key=words_counter.get, reverse=True)

    # Create support dictionaries
    int_to_words = dict(enumerate(sorted_words))
    words_to_int = {word: indice for indice, word in int_to_words.items()}
    number_of_words = len(int_to_words)

    # Generate network input, i.e words as integers
    int_text = [words_to_int[word] for word in text]

    number_of_batchs = len(int_text) // (sequence_size * batch_size)

    # Validate explicitly instead of relying on an IndexError raised further
    # down when the truncated list turns out to be empty.
    if number_of_batchs < 1:
        raise ValueError('Invalid amount of words to generate the batchs / sequences')

    # Drop the trailing words that do not fill a complete batch
    usable_words = number_of_batchs * batch_size * sequence_size
    inputs = numpy.asarray(int_text[:usable_words])

    # The target of each input is the word that follows it: shift everything
    # one step to the left, the last position wrapping around to the first word.
    targets = numpy.roll(inputs, -1)

    return dict(
        int_to_words=int_to_words,
        words_to_int=words_to_int,
        batchs=inputs.reshape(batch_size, -1),
        labels=targets.reshape(batch_size, -1),
        number_of_words=number_of_words
    )

In [3]:
def getBatchs(batch, labels, batch_size, sequence_size):
    """Lazily yield aligned ``(inputs, targets)`` column windows.

    ``batch`` and ``labels`` are sliced along their second axis into
    consecutive windows of ``sequence_size`` columns. The number of windows is
    the element count of ``batch`` floor-divided by
    ``sequence_size * batch_size``; columns beyond the last full window are
    never yielded.
    """
    element_count = numpy.prod(batch.shape)
    window_count = element_count // (sequence_size * batch_size)

    start = 0
    for _ in range(window_count):
        stop = start + sequence_size
        yield batch[:, start:stop], labels[:, start:stop]
        start = stop

## Model

In [4]:
class LSTM(nn.Module):
    """Word-level language model: embedding -> single-layer LSTM -> linear layer.

    Args:
        number_of_words: Vocabulary size; used for the embedding table and for
            the width of the output logits.
        sequence_size: Stored on the instance; no method of this class reads it.
        embedding_size: Width of each word embedding.
        lstm_size: Width of the LSTM hidden and memory states.
    """

    def __init__(self, number_of_words, sequence_size, embedding_size, lstm_size):
        super(LSTM, self).__init__()

        self.sequence_size = sequence_size

        self.lstm_size = lstm_size

        self.embedding = nn.Embedding(number_of_words, embedding_size)

        # batch_first=True: inputs and outputs are laid out (batch, sequence, feature)
        self.lstm = nn.LSTM(
            embedding_size,
            lstm_size,
            batch_first=True
        )

        self.dense = nn.Linear(lstm_size, number_of_words)

    def forward(self, state, previous_state):
        """Run one forward pass.

        Args:
            state: Integer tensor of word ids, laid out (batch, sequence).
            previous_state: ``(hidden, memory)`` pair handed to the LSTM as its
                initial state.

        Returns:
            ``(logits, new_state)`` where ``logits`` holds one score per
            vocabulary word for every input position and ``new_state`` is the
            ``(hidden, memory)`` pair produced by the LSTM.
        """
        embed = self.embedding(state)

        # From here on `state` refers to the LSTM's new (hidden, memory) pair
        output, state = self.lstm(embed, previous_state)

        logits = self.dense(output)

        return logits, state

    def resetState(self, batchSize):
        """Return a fresh, all-zero ``(hidden, memory)`` state pair.

        The pair is a real tuple, so it can be unpacked or passed directly as
        ``previous_state`` to ``forward``. Both tensors have shape
        ``(1, batchSize, lstm_size)`` and are allocated with the dtype and on
        the device of the model's weights.
        """
        reference = self.dense.weight

        # Reset the hidden (h) state and the memory (c) state
        return tuple(
            reference.new_zeros(1, batchSize, self.lstm_size)
            for _ in range(2)
        )


## Training

In [5]:
# Run configuration shared by the cells below.

# Number of consecutive words in each training window (columns per mini-batch).
sequence_size = 32

# Number of rows the corpus is reshaped into; also the mini-batch dimension.
batch_size = 16

# Width of the word-embedding vectors.
embedding_size = 64

# Width of the LSTM hidden and memory states.
lstm_size = 64

# Set to True to request moving the model and tensors to the GPU.
cuda = False

# Number of passes over the whole corpus.
epochs = 14

# Learning rate handed to the Adam optimizer.
learn_rating = 0.1

# Maximum gradient norm passed to clip_grad_norm_ during training.
gradients_norm = 5

# Prompt fed to predict(); every word is looked up in words_to_int, so each
# one must occur in the training text.
initial_words = ['I', 'think', 'life', 'is']

# predict() picks each next word at random among this many top-scoring candidates.
top = 5

In [6]:
# Read the corpus file and build the vocabulary dictionaries plus the
# input / label arrays used for training.
data = load_data('data.raw', batch_size, sequence_size)

In [7]:
# Build the language model, sized to the vocabulary found in the corpus.
model = LSTM(
    data.get('number_of_words'),
    sequence_size,
    embedding_size,
    lstm_size
)

# `torch.cuda.is_available` has to be *called*: the bare function object is
# always truthy, which would attempt model.cuda() even on a machine without GPU.
if cuda and torch.cuda.is_available():
    model = model.cuda()

optimizer = torch.optim.Adam(model.parameters(), lr=learn_rating)

criterion = nn.CrossEntropyLoss()

# Global step counter; the training loop increments it once per mini-batch.
iteration = 0

In [8]:
def predict(model, initial_words, number_of_words, words_to_int, int_to_words, top=5):
    """Generate text that continues ``initial_words`` and print it.

    The prompt is fed to the model one word at a time to warm up the recurrent
    state. After that, 101 words are generated: each one is drawn uniformly at
    random (``numpy.random.choice``) among the ``top`` highest-scoring
    candidates and then fed back as the next input. The prompt followed by the
    generated words is joined with spaces, UTF-8 encoded and printed.

    The model is switched to evaluation mode and left in it. All tensors are
    created on the device of the model's parameters, and no autograd graph is
    recorded while generating.

    Args:
        model: Module exposing ``resetState(batch_size)`` and a forward pass
            ``model(word_ids, (hidden, memory)) -> (logits, (hidden, memory))``.
        initial_words: Non-empty sequence of prompt words; it is not modified.
        number_of_words: Unused; kept so existing call sites keep working.
        words_to_int: Mapping word -> id; must contain every prompt word.
        int_to_words: Mapping id -> word.
        top: How many top-scoring candidates to sample from at each step.

    Raises:
        ValueError: If ``initial_words`` is empty (there would be no model
            output to sample the first word from).
        KeyError: If a prompt word is missing from ``words_to_int``.
    """
    if not initial_words:
        raise ValueError('initial_words must contain at least one word')

    # Set evaluation mode
    model.eval()

    words = list(initial_words)

    # Follow the model's own device instead of a notebook-level flag.
    reference = next(model.parameters(), None)
    device = reference.device if reference is not None else torch.device('cpu')

    def feed(word_id, hidden, memory):
        # Push a single word id through the model as a (1, 1) batch.
        token = torch.tensor([[word_id]], device=device)
        return model(token, (hidden, memory))

    def sample(output):
        # Uniform random pick among the `top` best-scoring word ids.
        _, candidates = torch.topk(output[0], k=top)
        return int(numpy.random.choice(candidates.tolist()[0]))

    # Number of words generated after the first sampled one.
    follow_up_steps = 100

    with torch.no_grad():
        # Reset state
        stateHidden, stateMemory = model.resetState(1)
        stateHidden, stateMemory = stateHidden.to(device), stateMemory.to(device)

        # Warm up the recurrent state with the prompt
        for word in words:
            output, (stateHidden, stateMemory) = feed(
                words_to_int[word], stateHidden, stateMemory
            )

        choice = sample(output)
        generated = [int_to_words[choice]]

        for _ in range(follow_up_steps):
            output, (stateHidden, stateMemory) = feed(choice, stateHidden, stateMemory)
            choice = sample(output)
            generated.append(int_to_words[choice])

    words.extend(generated)

    print(' '.join(words).encode('utf-8'))


In [9]:
# Every tensor used for training is created on the device the model lives on.
# This replaces the `torch.cuda.is_available and cuda` checks, which never
# called is_available() and therefore reduced to just `cuda`.
model_device = next(model.parameters()).device

for epoch in range(epochs):
    batchs = getBatchs(
        data.get('batchs'),
        data.get('labels'),
        batch_size,
        sequence_size
    )

    # Every epoch starts from an all-zero recurrent state
    stateHidden, stateMemory = model.resetState(batch_size)

    stateHidden = stateHidden.to(model_device)

    stateMemory = stateMemory.to(model_device)

    for batch_data, batch_label in batchs:
        iteration += 1

        # Set train mode (predict() below switches the model to eval mode)
        model.train()

        # Reset gradient
        optimizer.zero_grad()

        # Transform array to tensor; int64 is forced explicitly because
        # CrossEntropyLoss needs int64 class indices and NumPy's default
        # integer width depends on the platform.
        batch_data = torch.tensor(batch_data, dtype=torch.long, device=model_device)

        batch_label = torch.tensor(batch_label, dtype=torch.long, device=model_device)

        # Train
        logits, (stateHidden, stateMemory) = model(
            batch_data,
            (stateHidden, stateMemory)
        )

        # Loss; CrossEntropyLoss expects the class dimension second, hence
        # the transpose of the last two logits dimensions
        loss = criterion(logits.transpose(1, 2), batch_label)

        # Cut the carried-over state out of the autograd graph so that the
        # next batch's backward pass does not reach back into this batch
        stateHidden = stateHidden.detach()

        stateMemory = stateMemory.detach()

        # Back-propagation
        loss.backward()

        # Gradient clipping (in place)
        nn.utils.clip_grad_norm_(
            model.parameters(),
            gradients_norm
        )

        # Update network's parameters
        optimizer.step()

        # Loss value
        print(f'Epoch {epoch}, Iteration: {iteration}, Loss: {loss.item()}')

        # Predict value
        if iteration % 20 == 0:
            predict(model, initial_words, data.get('number_of_words'), data.get('words_to_int'), data.get('int_to_words'), top)

Epoch 0, Iteration: 1, Loss: 8.65025520324707
Epoch 0, Iteration: 2, Loss: 8.3564453125
Epoch 0, Iteration: 3, Loss: 8.264205932617188
Epoch 0, Iteration: 4, Loss: 7.507660865783691
Epoch 0, Iteration: 5, Loss: 7.194112777709961
Epoch 0, Iteration: 6, Loss: 7.409313678741455
Epoch 0, Iteration: 7, Loss: 7.552872657775879
Epoch 0, Iteration: 8, Loss: 7.43687629699707
Epoch 0, Iteration: 9, Loss: 7.4203877449035645
Epoch 0, Iteration: 10, Loss: 7.384235858917236
Epoch 0, Iteration: 11, Loss: 7.565990924835205
Epoch 0, Iteration: 12, Loss: 7.666876792907715
Epoch 0, Iteration: 13, Loss: 7.5550312995910645
Epoch 0, Iteration: 14, Loss: 7.284774303436279
Epoch 0, Iteration: 15, Loss: 7.402453899383545
Epoch 0, Iteration: 16, Loss: 7.411348342895508
Epoch 0, Iteration: 17, Loss: 7.311158180236816
Epoch 0, Iteration: 18, Loss: 7.591379165649414
Epoch 0, Iteration: 19, Loss: 7.241386413574219
Epoch 0, Iteration: 20, Loss: 7.4343061447143555
b"I guess. INT. SEBASTIAN'S DECKARD'S Japanese) sign 

KeyboardInterrupt: 