In [2]:
import torch
import torch.nn as nn
import numpy as np


class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a dense layer that maps each time step back to encoding size.

    The recurrent (hidden, cell) state is stored on the instance and carried over
    between calls to `logits`/`f`/`loss`; call `reset()` before feeding a new,
    unrelated sequence. The stored zero state has batch size 1.
    """

    def __init__(self, encoding_size, state_size=128):
        """
        Args:
            encoding_size: Length of each input vector and of each output row.
            state_size: Size of the LSTM hidden/cell state (defaults to 128).
        """
        super().__init__()

        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size, self.state_size)
        self.dense = nn.Linear(self.state_size, encoding_size)

        # Shape: (number of layers, batch size, state size).
        # Registered as a buffer so that `.to(device)` / `.double()` also convert it;
        # a plain tensor attribute would stay behind and make the LSTM call fail with
        # a device/dtype mismatch after the next `reset()`. `persistent=False` keeps
        # it out of `state_dict()`, so saved checkpoints are unaffected.
        self.register_buffer("zero_state", torch.zeros(1, 1, self.state_size), persistent=False)
        self.hidden_state = self.zero_state
        self.cell_state = self.zero_state

    def reset(self):
        """Reset states prior to new input sequence.

        Also call this after moving or casting the model, so the carried state
        picks up the converted zero state.
        """
        self.hidden_state = self.zero_state
        self.cell_state = self.zero_state

    def logits(self, x):
        """Run `x` through the LSTM from the current state and return unnormalized scores.

        x = (sequence length, batch size, encoding size)
        Returns a tensor of shape (sequence length * batch size, encoding size).
        Side effect: the stored hidden/cell state is replaced by the LSTM's final state.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        # Flatten the time and batch axes so the dense layer sees one row per step.
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Return per-step probabilities (softmax over the encoding axis).

        x = (sequence length, batch size, encoding size)
        """
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Cross-entropy between the predicted logits and the one-hot targets `y`.

        x = (sequence length, batch size, encoding size), y = (sequence length, encoding size)
        """
        # cross_entropy takes class indices here, so convert one-hot rows with argmax.
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))


In [34]:
# Vocabulary: the distinct characters of "hello world", with the space first.
chars = list(" helowrd")
char_count = len(chars)
# Row i is the one-hot encoding (a list of floats) of chars[i].
char_codes = np.eye(char_count).tolist()

def code_char(character, wrap):
    """Return the one-hot row of `char_codes` for `character`.

    When `wrap` is truthy the row is nested in a one-element list;
    otherwise the row itself is returned.
    """
    one_hot = char_codes[chars.index(character)]
    return [one_hot] if wrap else one_hot

def code(string, wrap=False):
    """Encode every character of `string` with `code_char`, returning a list in string order.

    `wrap` is forwarded unchanged to `code_char` for each character.
    """
    encoded = []
    for symbol in string:
        encoded.append(code_char(symbol, wrap))
    return encoded

# Next-character prediction: the target is the input shifted one character ahead.
x_train = torch.tensor(code(" hello world", wrap=True))  # (sequence length, batch size 1, encoding size)
y_train = torch.tensor(code("hello world "))  # (sequence length, encoding size)

model = LongShortTermMemoryModel(encoding_size=char_count)
optimizer = torch.optim.RMSprop(model.parameters(), 0.001)

for epoch in range(500):
    # Start every pass over the training sequence from the zero state.
    model.reset()
    model.loss(x_train, y_train).backward()
    optimizer.step()
    optimizer.zero_grad()

    if epoch % 10 == 9:
        # Generate characters from the initial characters ' h'
        model.reset()
        text = ' h'
        # Sampling is inference only: without no_grad every generated character
        # would extend an autograd graph through the carried recurrent state.
        with torch.no_grad():
            model.f(torch.tensor([[char_codes[0]]]))  # feed ' ' to prime the state
            y = model.f(torch.tensor([[char_codes[1]]]))  # feed 'h'
            text += chars[y.argmax(1).item()]
            for _step in range(50):
                # Greedy decoding: feed the most likely character back in as the next input.
                y = model.f(torch.tensor([[char_codes[y.argmax(1).item()]]]))
                text += chars[y.argmax(1).item()]
        print(text)


 hllloooorllll  lllll  llll  llll  llll  llll  llll  
 hlllo worlld   rddd   dddd   ddd    ddd   ddd    dd 
 hlllo world  wrrdd  wrrdd  wrrdd  wrrdd  wrrdd  wrrd
 hello world  wrrd  wordd  wrrdd  wrld  wrrdd  wrrd  
 hello world  wrrd  world  wrrdd world  wrrld world  
 hello world  wrrd  world  wrrdd world  wrrld world  
 hello world world  wrrld world  wrrld world  wrrld w
 hello world world  wrrld world  wrrld world  wrrld w
 hello world world  wrrld world world  wrrld world  w
 hello world world wwrrd  world world  wrrld world wo
 hello world world world  wrrld world world  wrrld wo
 hello world world world  wrrld world world  wrrld wo
 hello world world world wwrrld world world wwrrld wo
 hello world world world wwrld  world world world  wr
 hello world world world world  wrrld world world wor
 hello world world world world  wrrld world world wor
 hello world world world world wwrrld world world wor
 hello world world world world world  world world wor
 hello world world world wor

In [35]:
# Regenerate a longer sample from scratch instead of extending the `text` / `y`
# (and the model's recurrent state) left over from the training cell. That makes
# this cell idempotent: re-running it prints the same string rather than growing
# it by another 50 characters. Greedy (argmax) decoding is deterministic, so the
# first run yields exactly what continuing the last training-cell sample would:
# ' h' + 1 + 100 characters == (' h' + 1 + 50) + 50.
model.reset()
text = ' h'
with torch.no_grad():  # inference only: do not build an autograd graph through the recurrent state
    model.f(torch.tensor([[char_codes[0]]]))  # feed ' ' to prime the state
    y = model.f(torch.tensor([[char_codes[1]]]))  # feed 'h'
    text += chars[y.argmax(1).item()]
    for _step in range(100):
        y = model.f(torch.tensor([[char_codes[y.argmax(1).item()]]]))
        text += chars[y.argmax(1).item()]
print(text)


 hello world world world world world world world world world world world world world world world world 
