In [108]:
import torch
import torch.nn as nn
import numpy as np

In [109]:
class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a linear read-out, used as a stateful
    next-symbol predictor.

    The recurrent state is stored on the instance (``hidden_state`` /
    ``cell_state``) and is carried over between calls to ``logits``, ``f`` and
    ``loss``. Call ``reset`` before feeding a new, independent sequence.

    Args:
        encoding_size: Width of each input vector and number of output classes.
        state_size: Width of the LSTM hidden/cell state. Defaults to 128.
    """

    def __init__(self, encoding_size, state_size=128):
        super(LongShortTermMemoryModel, self).__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, state_size)
        self.dense = nn.Linear(state_size, encoding_size)

    def reset(self):
        """Set the hidden and cell state to zeros.

        The state shape is (1, 1, state_size), i.e. one LSTM layer and a batch
        size of one. Two separate tensors are allocated so that an in-place
        edit of one state can never leak into the other, and they follow the
        dtype/device of the model weights so the model keeps working after
        ``.to(...)`` or ``.double()``.
        """
        reference = self.dense.weight
        state_shape = (1, 1, self.state_size)
        self.hidden_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)
        self.cell_state = torch.zeros(state_shape, dtype=reference.dtype, device=reference.device)

    def logits(self, x):
        """Run ``x`` through the LSTM and the read-out layer.

        Args:
            x: Input of shape (sequence_length, 1, encoding_size).

        Returns:
            Unnormalized class scores of shape (sequence_length, encoding_size).

        Side effects:
            Replaces ``hidden_state`` and ``cell_state`` with the state after
            the last step of ``x``. ``reset`` must have been called at least
            once beforehand.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        # Collapse the (sequence, batch) axes so the linear layer sees one row per step.
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Return per-step class probabilities (softmax over the logits)."""
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Cross-entropy between the predictions for ``x`` and targets ``y``.

        Args:
            x: Input of shape (sequence_length, 1, encoding_size).
            y: Targets with one row per step; the position of each row's
                maximum is used as the class index (e.g. one-hot rows).
        """
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))


In [110]:
# Vocabulary: a character's position in this list is its class index.
index_to_char = [' ', 'h', 'e', 'l', 'o', 'w', 'r', 'd']
# One one-hot row per character, stored as plain nested Python lists.
char_encodings = np.eye(len(index_to_char), dtype=float).tolist()
encoding_size = len(char_encodings)


# Inputs: one character per time step with a batch axis of size 1.
x_train = torch.tensor([[char_encodings[index_to_char.index(ch)]] for ch in ' hello world'])

# Targets: the character that follows each input character.
y_train = torch.tensor([char_encodings[index_to_char.index(ch)] for ch in 'hello world '])


In [111]:
model = LongShortTermMemoryModel(encoding_size)

optimizer = torch.optim.RMSprop(model.parameters(), 0.001)
for epoch in range(500):
    # Each epoch is one full pass over the training sequence, started from a
    # zeroed recurrent state.
    model.reset()
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        # Every 10th epoch: prime the model with ' h', then feed its own most
        # likely prediction back in for 50 more steps and print the result.
        # Sampling needs no gradients; without no_grad() the recurrent state
        # would chain an autograd graph through every generated step.
        with torch.no_grad():
            model.reset()
            text = ' h'
            model.f(torch.tensor([[char_encodings[0]]]))
            y = model.f(torch.tensor([[char_encodings[1]]]))
            # Convert the 1-element argmax tensor to a plain int before using
            # it to index Python lists.
            next_index = int(y.argmax(1))
            text += index_to_char[next_index]
            for step in range(50):
                y = model.f(torch.tensor([[char_encodings[next_index]]]))
                next_index = int(y.argmax(1))
                text += index_to_char[next_index]
        print(text)


 hlllo                                               
 hlllo wrld    d                                     
 hlll world  wrld  wrld  wrld  wrld  wrld  wrld  wrld
 hello world  wrld  world  wrld  world  wrlld  wrld  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  world  wrld  world  wrld  world  
 hello world  wrld  wrlld world  wrlld world  wrlld w
 hello world world  wrlld world  wrll  world world  w
 hello world world  wrlld world world  wrlld world wo
 hello world world  wrll  world world  wrlld world wo
 hello world world world  wrlld world world  wrlld wo
 hello world world world  wrlld world world  wrll  wo
 hello world world world world  wrlld world world wor
 hello world world world world  wrll  world world wor
 hello world world world world world world  wrll  wor
 hello world world world world world world world worl
 hello world world world world world world world worl
 hello world world world wor