In [38]:
import torch
import torch.nn as nn
import numpy as np


class LongShortTermMemoryModel(nn.Module):
    """Single-layer LSTM followed by a dense layer producing per-time-step class logits.

    The hidden and cell states are stored on the instance and carried over
    from one call of ``logits``/``f``/``loss`` to the next, so call ``reset()``
    before feeding a new, unrelated input sequence.
    """

    def __init__(self, encoding_size_in, encoding_size_out, state_size=128):
        """Build the network.

        Args:
            encoding_size_in: Size of each input vector (last dimension of ``x``).
            encoding_size_out: Number of output classes produced by the dense layer.
            state_size: Size of the LSTM hidden and cell state. Defaults to 128.
        """
        super().__init__()

        self.batch_size = 1
        self.state_size = state_size
        self.lstm = nn.LSTM(encoding_size_in, self.state_size)
        self.dense = nn.Linear(self.state_size, encoding_size_out)

        # Registered as a buffer so that .to()/.double()/.cuda() convert the
        # initial state together with the weights; non-persistent so the
        # state_dict keeps exactly the same keys as before.
        self.register_buffer(
            "zero_state",
            torch.zeros(1, self.batch_size, self.state_size),  # Shape: (number of layers, batch size, state size)
            persistent=False,
        )
        self.hidden_state = self.zero_state
        self.cell_state = self.zero_state

    def reset(self):
        """Reset states prior to new input sequence"""
        self.hidden_state = self.zero_state
        self.cell_state = self.zero_state

    def logits(self, x):
        """Run the LSTM over ``x`` and return one row of logits per time step.

        x = (sequence length, batch size, encoding size)

        The final hidden and cell states replace the stored ones, so a
        following call continues the same sequence unless ``reset()`` is called.
        """
        out, (self.hidden_state, self.cell_state) = self.lstm(x, (self.hidden_state, self.cell_state))
        # Flatten (sequence length, batch size, state size) into rows for the dense layer.
        return self.dense(out.reshape(-1, self.state_size))

    def f(self, x):
        """Return class probabilities (softmax over the logits of each row).

        x = (sequence length, batch size, encoding size)
        """
        return torch.softmax(self.logits(x), dim=1)

    def loss(self, x, y):
        """Return the cross-entropy between the predicted logits and one-hot targets.

        x = (sequence length, batch size, encoding size), y = (sequence length, encoding size)
        """
        # cross_entropy expects class indices, so convert the one-hot rows with argmax.
        return nn.functional.cross_entropy(self.logits(x), y.argmax(1))


In [40]:
# Input alphabet: every character that may occur in an input word.
chars_in = list(" hatrcflmpson")
char_in_count = len(chars_in)
# Row i of the identity matrix is the one-hot vector for chars_in[i].
char_in_codes = np.eye(char_in_count).tolist()

# Output alphabet: the emoji classes the model can predict (plus a space).
chars_out = [
    " ",
    "🎩",
    "🐀",
    "🐈",
    "🏢",
    "🙎",
    "🧢",
    "👶",
]
char_out_count = len(chars_out)
# Row i of the identity matrix is the one-hot vector for chars_out[i].
char_out_codes = np.eye(char_out_count).tolist()


def code_char(character, out=False, wrap=False):
    """Return the one-hot encoding of a single character.

    Args:
        character: Symbol to encode. It is looked up in ``chars_out`` when
            ``out`` is true, otherwise in ``chars_in``.
        out: Select the output alphabet instead of the input alphabet.
        wrap: When true, return the encoding inside a one-element list.

    Raises:
        ValueError: From ``list.index`` if the character is not in the
            selected alphabet.
    """
    alphabet, codes = (chars_out, char_out_codes) if out else (chars_in, char_in_codes)
    encoding = codes[alphabet.index(character)]
    return [encoding] if wrap else encoding

def code(string, out=False, wrap=False):
    """Encode every character of ``string`` with ``code_char`` and return the results as a list.

    ``out`` and ``wrap`` are forwarded unchanged to ``code_char``.
    """
    encoded = []
    for symbol in string:
        encoded.append(code_char(symbol, out, wrap))
    return encoded

# Seed PyTorch's random number generator before the model is built so that the
# weight initialisation, and with it the training results printed below, is
# reproducible when the notebook is re-run from a fresh kernel.
torch.manual_seed(0)

# One training sequence per word; every character is wrapped in a list, which
# gives each word the shape (sequence length, batch size, encoding size).
x_train = torch.tensor([code(word, wrap=True) for word in ["hat ", "rat ", "cat ", "flat", "matt", "cap ", "son "]])
# The target emoji is repeated for all four time steps of the matching word.
y_train = torch.tensor([code(emoji, out=True) for emoji in ["🎩"*4, "🐀"*4, "🐈"*4, "🏢"*4, "🙎"*4, "🧢"*4, "👶"*4]])
model = LongShortTermMemoryModel(encoding_size_in=char_in_count, encoding_size_out=char_out_count)
optimizer = torch.optim.RMSprop(model.parameters(), 0.001)  # second positional argument is the learning rate

def generate_output(string):
    """Return the emoji the model predicts after reading ``string`` character by character.

    The model state is reset first, then each character is fed as a sequence
    of length one; the prediction made after the last character is returned.

    Args:
        string: Non-empty text made of characters from ``chars_in``.

    Returns:
        The element of ``chars_out`` with the highest predicted probability.

    Raises:
        ValueError: If ``string`` is empty (there is nothing to predict from),
            or, from ``code_char``, if a character is not in ``chars_in``.
    """
    if not string:
        raise ValueError("generate_output() requires a non-empty string")

    model.reset()

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for char in string:
            # Shape (1, 1, encoding size): one time step, batch of one.
            y = model.f(torch.tensor([code_char(char, wrap=True)]))

    # y has a single row; convert the winning class index to a plain int.
    return chars_out[y.argmax(1).item()]

# Train on one word at a time for 500 epochs and print sample predictions
# every 100 epochs (starting with epoch 0).
for epoch in range(500):
    for i in range(len(x_train)):
        # Every word is an independent sequence, so start from the zero state.
        model.reset()
        model.loss(x_train[i], y_train[i]).backward()
        optimizer.step()
        optimizer.zero_grad()

    if epoch % 100 != 0:
        continue

    print(f"Results after {epoch} epochs:")
    model.reset()

    # Full training words, prefixes of them, and a few unseen combinations.
    for word in [
        "hat", "ht", "h",
        "rat", "rt", "r",
        "cat", "ct",
        "flat", "fl", "fla", "flt",
        "matt", "ma", "mt", "at",
        "cap", "cp", "c",
        "son", "so", "son",
        "crat", "mart",
    ]:
        emoji = generate_output(word)
        print(f"'{word}' -> '{emoji}'")

    print()


Results after 0 epochs:
'hat -> '🎩'
'ht -> '🎩'
'h -> '🎩'
'rat -> '🎩'
'rt -> '🎩'
'r -> '🎩'
'cat -> '🎩'
'ct -> '🐈'
'flat -> '🎩'
'fl -> '🎩'
'fla -> '🎩'
'flt -> '🎩'
'matt -> '🎩'
'ma -> '🎩'
'mt -> '🎩'
'at -> '🎩'
'cap -> '🎩'
'cp -> '🐈'
'c -> '🐈'
'son -> '🎩'
'so -> '🎩'
'son -> '🎩'
'crat -> '🎩'
'mart -> '🎩'

Results after 100 epochs:
'hat -> '🎩'
'ht -> '🎩'
'h -> '🎩'
'rat -> '🐀'
'rt -> '🐀'
'r -> '🐀'
'cat -> '🐈'
'ct -> '🐈'
'flat -> '🏢'
'fl -> '🏢'
'fla -> '🏢'
'flt -> '🏢'
'matt -> '🙎'
'ma -> '🙎'
'mt -> '🙎'
'at -> '🙎'
'cap -> '🧢'
'cp -> '🧢'
'c -> '🐈'
'son -> '👶'
'so -> '👶'
'son -> '👶'
'crat -> '🐀'
'mart -> '🙎'

Results after 200 epochs:
'hat -> '🎩'
'ht -> '🎩'
'h -> '🎩'
'rat -> '🐀'
'rt -> '🐀'
'r -> '🐀'
'cat -> '🐈'
'ct -> '🐈'
'flat -> '🏢'
'fl -> '🏢'
'fla -> '🏢'
'flt -> '🏢'
'matt -> '🙎'
'ma -> '🙎'
'mt -> '🙎'
'at -> '🙎'
'cap -> '🧢'
'cp -> '🧢'
'c -> '🧢'
'son -> '👶'
'so -> '👶'
'son -> '👶'
'crat -> '🐀'
'mart -> '🙎'

Results after 300 epochs:
'hat -> '🎩'
'ht -> '🎩'
'h -> '🎩'
'rat -> '🐀'
'rt -> '🐀'
'r -> '🐀

In [59]:
# Ask the model which emoji it predicts for the input "ct"; as the last
# expression in the cell, the returned value is displayed as the cell output.
generate_output("ct")


'🐈'