In [2]:
import torch

In [3]:
# Character vocabulary for the toy "hello" -> "ohlol" task: char -> index.
word2idx = {"e": 0, "h": 1, "l": 2, "o": 3}

# Reverse lookup table: index -> char.
idx2word = {index: char for char, index in word2idx.items()}

def encode(seq):
    """Translate each character of ``seq`` into its index in ``word2idx``.

    Returns a list of ints, one per character. A character missing from the
    vocabulary raises ``KeyError``.
    """
    return list(map(word2idx.__getitem__, seq))

def decode(token):
    """Inverse of ``encode``: map each index in ``token`` back to its character.

    Returns a list of single-character strings (not a joined string). An index
    missing from ``idx2word`` raises ``KeyError``.
    """
    chars = []
    for index in token:
        chars.append(idx2word[index])
    return chars

def one_hot(idx, len = 4):
    """Return a list of ``len`` ints with a 1 at position ``idx`` and 0 elsewhere.

    The parameter name ``len`` shadows the builtin; it is kept so existing
    keyword callers keep working. If ``idx`` is not in ``range(len)`` the
    result is all zeros.
    """
    return [1 if position == idx else 0 for position in range(len)]

def batch_one_hot(seq):
    """One-hot encode every character of ``seq``.

    Returns a list with one ``one_hot`` row per character, in order.
    """
    return list(map(one_hot, encode(seq)))


In [4]:
# Seed torch before the model is built so weight initialisation (and hence the
# training log below) is reproducible across Restart & Run All on CPU.
seed = 42
torch.manual_seed(seed)

# One-hot width: must equal the number of entries in word2idx (4) and the
# default `len` of one_hot.
input_size = 4
# The RNN hidden state is passed straight to CrossEntropyLoss as class scores,
# so it must be exactly as wide as the vocabulary.
hidden_size = 4
batch_size = 1

In [5]:
# Toy task: at step t, read input_seq[t] and predict output_seq[t].
input_seq, output_seq = "hello", "ohlol"
x_data, y_data = batch_one_hot(input_seq), batch_one_hot(output_seq)

In [6]:
# Shape (seq_len, batch_size, input_size): iterating over the first axis
# yields one time step at a time in the training loop.
inputs = torch.tensor(x_data, dtype=torch.float32).view(-1, batch_size, input_size)
labels = torch.tensor(y_data, dtype=torch.float32).view(-1, batch_size, input_size)
print(inputs.shape)
print(labels.shape)

torch.Size([5, 1, 4])
torch.Size([5, 1, 4])


In [7]:
class Model(torch.nn.Module):
    """A single ``torch.nn.RNNCell`` stepped manually, one time step per call.

    Args:
        input_size: width of each input vector.
        hidden_size: width of the hidden state returned by ``forward``.
        batch_size: leading dimension of the zero state built by ``init_hidden``.
    """

    def __init__(self, input_size, hidden_size, batch_size):
        super().__init__()
        self.input_size, self.hidden_size, self.batch_size = input_size, hidden_size, batch_size
        self.rnncell = torch.nn.RNNCell(input_size=input_size, hidden_size=hidden_size)

    def forward(self, input, hidden):
        """Advance the cell by one step and return the new hidden state."""
        return self.rnncell(input, hidden)

    def init_hidden(self):
        """Return an all-zero hidden state of shape (batch_size, hidden_size)."""
        return torch.zeros(self.batch_size, self.hidden_size)

net = Model(input_size=input_size, hidden_size=hidden_size, batch_size=batch_size)

In [8]:
# Loss over the vocabulary classes; the training loop feeds the RNN hidden
# state in as the class scores.
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=net.parameters(), lr=0.1)

In [9]:
# Single source of truth for the epoch count: previously the loop ran 20
# epochs while the log hard-coded "/15", printing e.g. "Epoch [20/15]".
num_epochs = 20

for epoch in range(num_epochs):
    # Clear gradients from the previous epoch and restart from a zero state.
    optimizer.zero_grad()
    hidden = net.init_hidden()
    loss = 0  # sum of per-step losses; becomes a tensor after the first +=
    print("Predicted string: ", end="")
    for x_t, y_t in zip(inputs, labels):
        # One time step; each is (batch_size, input_size) per the view() above.
        hidden = net(x_t, hidden)
        # The hidden state doubles as class scores. y_t is a one-hot float row,
        # which CrossEntropyLoss interprets as class probabilities.
        loss += criterion(hidden, y_t)
        _, idx = hidden.max(dim=1)
        # .item() needs a single element, i.e. this relies on batch_size == 1.
        print(idx2word[idx.item()], end="")
    loss.backward()
    optimizer.step()
    print(", Epoch [%d/%d] loss=%.4f" % (epoch + 1, num_epochs, loss.item()))

Predicted string: ooooh, Epoch [1/15] loss=6.3014
Predicted string: ohool, Epoch [2/15] loss=5.3462
Predicted string: oholl, Epoch [3/15] loss=4.8429
Predicted string: ohlll, Epoch [4/15] loss=4.5127
Predicted string: ohlll, Epoch [5/15] loss=4.2176
Predicted string: ohlll, Epoch [6/15] loss=3.9339
Predicted string: ohlll, Epoch [7/15] loss=3.6590
Predicted string: ohool, Epoch [8/15] loss=3.4299
Predicted string: ohool, Epoch [9/15] loss=3.2712
Predicted string: ohool, Epoch [10/15] loss=3.1406
Predicted string: ohool, Epoch [11/15] loss=2.9982
Predicted string: ohool, Epoch [12/15] loss=2.8628
Predicted string: ohool, Epoch [13/15] loss=2.7637
Predicted string: ohlol, Epoch [14/15] loss=2.6736
Predicted string: ohlol, Epoch [15/15] loss=2.5535
Predicted string: ohlol, Epoch [16/15] loss=2.4242
Predicted string: ohlol, Epoch [17/15] loss=2.4584
Predicted string: ohlol, Epoch [18/15] loss=2.2555
Predicted string: ohlol, Epoch [19/15] loss=2.2744
Predicted string: ohlol, Epoch [20/15] l