In [1]:
import torch
import numpy as np

from torch import nn
from torch import optim

In [2]:
# Fix torch's RNG so the RNN weight initialisation -- and therefore the whole
# loss / prediction trace printed during training -- is reproducible on re-run.
SEED = 0
torch.manual_seed(SEED)

# Training corpus: a single word. Later cells train the model to predict each
# next character of this string from the characters before it.
sample_data = "hihello"
learning_rate = 0.1

In [3]:
## one-hot encoding을 위한 label을 제작하기 위함.
char_set = sorted(list(set(sample_data))) ## unique char 선별.
char_dict = {c : i for i, c in enumerate(char_set)} ## label index 부여.
sample_idx = [char_dict[c] for c in sample_data] ## sparse_categories list

print(char_set)
print(char_dict)
print(sample_idx)

['e', 'h', 'i', 'l', 'o']
{'e': 0, 'h': 1, 'i': 2, 'l': 3, 'o': 4}
[1, 2, 1, 0, 3, 3, 4]


In [4]:
# The RNN output is fed straight into the loss as per-character scores, so both the
# one-hot input width and the hidden size are the vocabulary size.
input_dim = hidden_dim = len(char_set)

In [5]:
# Next-character prediction setup.
# Inputs: every character of the sample except the last one ('o').
x_data = [sample_idx[:-1]]
# Targets: the same sequence shifted by one (every character except the first 'h'),
# so the label at step t is the character that follows input step t.
y_data = [sample_idx[1:]]
# One-hot encode each input sequence by selecting rows of the identity matrix;
# each list entry is a stacked (seq_len, input_dim) float array.
x_one_hot = [np.identity(input_dim)[seq] for seq in x_data]

print(x_data)
print(x_one_hot)
print(y_data)

[[1, 2, 1, 0, 3, 3]]
[array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.]])]
[[2, 1, 0, 3, 3, 4]]


In [6]:
# X: one-hot inputs of shape (batch, seq_len, input_dim), as float32 to match
# the default dtype of the RNN's parameters.
X = torch.tensor(np.stack(x_one_hot), dtype=torch.float32)
# Y: target class indices of shape (batch, seq_len); class-index targets must be int64.
Y = torch.tensor(y_data, dtype=torch.long)
# Initial hidden state h_{t-1} for the first time step: (num_layers=1, batch=1, hidden_dim).
h0 = torch.zeros((1, 1, hidden_dim))

print(X.shape, Y.shape)

torch.Size([1, 6, 5]) torch.Size([1, 6])


In [7]:
# Single-layer vanilla RNN; batch_first=True means inputs/outputs are (batch, seq, feature).
model = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
# Multi-class classification over the character vocabulary.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
# Train for 100 epochs on the single sequence, printing the loss and the greedy
# decoding of each forward pass.
for i in range(100):
    optimizer.zero_grad()

    # Run the whole sequence starting from the explicit initial hidden state h0
    # (all zeros -- the same value nn.RNN uses when no hidden state is passed).
    outputs, status = model(X, h0)

    if i == 0:
        # outputs: (batch, seq_len, hidden_dim) because batch_first=True;
        # status is the final hidden state, (num_layers, batch, hidden_dim) --
        # batch_first does not apply to it.
        print(outputs.shape, status.shape)

    # The RNN outputs are used directly as per-character logits, so their last
    # dimension is hidden_dim (not input_dim). Flatten batch and time so
    # CrossEntropyLoss sees (N, C) logits and (N,) class-index targets.
    loss = criterion(outputs.reshape(-1, hidden_dim), Y.reshape(-1))
    loss.backward()
    optimizer.step()

    # Greedy decode of this epoch's forward pass (computed before the update above):
    # most likely character at every time step of the first (only) sequence.
    # Indexing [0] instead of np.squeeze keeps this correct for seq_len == 1.
    result = outputs.detach().numpy().argmax(axis=2)
    result_str = ''.join(char_set[c] for c in result[0])

    print(f"epoch{i} | loss : {loss.item()}, prediction : {result_str}")

torch.Size([1, 6, 5]) torch.Size([1, 1, 5])
epoch0 | loss : 1.8366345167160034, prediction : hiehoh
epoch1 | loss : 1.543753743171692, prediction : ioehii
epoch2 | loss : 1.345700740814209, prediction : eoelio
epoch3 | loss : 1.2051149606704712, prediction : ioello
epoch4 | loss : 1.0974088907241821, prediction : ioello
epoch5 | loss : 1.01070237159729, prediction : ioello
epoch6 | loss : 0.9334042072296143, prediction : ioello
epoch7 | loss : 0.8635385632514954, prediction : ioello
epoch8 | loss : 0.7992921471595764, prediction : ioello
epoch9 | loss : 0.7528714537620544, prediction : ioello
epoch10 | loss : 0.7204034328460693, prediction : ioello
epoch11 | loss : 0.6878079771995544, prediction : ioello
epoch12 | loss : 0.6629229187965393, prediction : ihello
epoch13 | loss : 0.6429349780082703, prediction : ihello
epoch14 | loss : 0.6217709183692932, prediction : ihello
epoch15 | loss : 0.6023480892181396, prediction : ihello
epoch16 | loss : 0.5876160264015198, prediction : ihello
e