In [1]:
import torch
import numpy as np

from torch import nn
from torch import optim

In [2]:
# Seed PyTorch's RNG before any model is created, so the randomly initialised
# RNN weights -- and therefore the whole training log -- are identical on every
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Training text: the network is trained to predict each next character of it.
sample_data = "hihello"
# Step size handed to the optimizer.
learning_rate = 0.1

In [3]:
# Vocabulary: the distinct characters of the sample text, in sorted order.
char_set = sorted(set(sample_data))
# Lookup table {character: integer label}; a character's label is its position in char_set.
char_dict = dict(zip(char_set, range(len(char_set))))
# The sample text re-expressed as a list of integer labels.
sample_idx = [char_dict[ch] for ch in sample_data]

# One value per line: vocabulary, lookup table, encoded text.
print(char_set, char_dict, sample_idx, sep="\n")

['e', 'h', 'i', 'l', 'o']
{'e': 0, 'h': 1, 'i': 2, 'l': 3, 'o': 4}
[1, 2, 1, 0, 3, 3, 4]


In [4]:
# Both widths equal the vocabulary size: inputs are one-hot vectors over the
# characters, and the hidden state gets one unit per character.
input_dim = hidden_dim = len(char_set)

print(f"{input_dim} {hidden_dim}")

5 5


In [5]:
# Inputs: every label except the last one ('o'), wrapped in a list so the data
# forms a batch holding a single sequence.
x_data = [sample_idx[:-1]]
# Targets: every label except the first one ('h'), i.e. the inputs shifted by one
# position, so each target is the character that follows the corresponding input.
y_data = [sample_idx[1:]]

# One-hot encode each input sequence by picking rows of the identity matrix,
# one row per label (the rows are stacked into a 2-D array per sequence).
x_one_hot = [np.identity(input_dim)[seq] for seq in x_data]

print(x_data, x_one_hot, y_data, sep="\n")

[[1, 2, 1, 0, 3, 3]]
[array([[0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 1., 0.]])]
[[2, 1, 0, 3, 3, 4]]


In [6]:
# Network inputs: the one-hot sequences stacked into one float32 tensor.
X = torch.as_tensor(np.stack(x_one_hot), dtype=torch.float32)
# Targets as integer (int64) class labels.
Y = torch.tensor(y_data, dtype=torch.long)
# Initial hidden state, all zeros: there is no previous hidden state before the
# first time step.
h0 = torch.zeros((1, 1, hidden_dim))

print(X.shape, Y.shape)

torch.Size([1, 6, 5]) torch.Size([1, 6])


In [7]:
# Recurrent network; batch_first=True puts the batch axis first in its inputs
# and outputs.
model = nn.RNN(input_size=input_dim, hidden_size=hidden_dim, batch_first=True)
# Cross-entropy loss for multi-class classification over the characters.
criterion = nn.CrossEntropyLoss()
# Adam over all of the RNN's parameters, using the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [8]:
# Run 100 optimisation steps on the single training sequence, printing the loss
# and the decoded prediction after every step.
for i in range(100):
    optimizer.zero_grad()

    # Pass the explicit all-zero initial hidden state h0 defined next to X and Y.
    outputs, status = model(X, h0)

    if i == 0:
        print(outputs.shape, status.shape)

    # Flatten (batch, seq, classes) to (batch * seq, classes) for the loss. The
    # class axis is read from `outputs` itself instead of being assumed equal to
    # input_dim, and reshape() is used because it also accepts non-contiguous
    # tensors.
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), Y.reshape(-1))
    loss.backward()
    optimizer.step()

    # Decode from a detached tensor so the conversion to NumPy stays outside autograd.
    # These are the pre-step outputs, i.e. the ones the printed loss was computed from.
    result = outputs.detach().numpy().argmax(axis=2)
    # Index the single batch entry explicitly: np.squeeze would also collapse a
    # length-1 sequence into a 0-d array, which cannot be iterated.
    result_str = ''.join(char_set[c] for c in result[0])

    print(f"epoch{i} | loss : {loss.item()}, prediction : {result_str}")

torch.Size([1, 6, 5]) torch.Size([1, 1, 5])


epoch0 | loss : 1.650267481803894, prediction : eeehee
epoch1 | loss : 1.4092960357666016, prediction : llelll
epoch2 | loss : 1.276911735534668, prediction : lllllo
epoch3 | loss : 1.2005442380905151, prediction : illllo
epoch4 | loss : 1.139829397201538, prediction : ilillo
epoch5 | loss : 1.0740702152252197, prediction : ilillo
epoch6 | loss : 1.0043365955352783, prediction : ilillo
epoch7 | loss : 0.9455127716064453, prediction : ilello
epoch8 | loss : 0.9044656753540039, prediction : ilello
epoch9 | loss : 0.8667243123054504, prediction : ilello
epoch10 | loss : 0.8247909545898438, prediction : ilello
epoch11 | loss : 0.7881093621253967, prediction : ilello
epoch12 | loss : 0.7645301818847656, prediction : ihello
epoch13 | loss : 0.73813796043396, prediction : ihello
epoch14 | loss : 0.7148441672325134, prediction : ihello
epoch15 | loss : 0.6978746056556702, prediction : ihello
epoch16 | loss : 0.6842898726463318, prediction : ihello
epoch17 | loss : 0.6715064644813538, predictio