In [64]:
import torch
import torch.optim as optim
import numpy as np

In [65]:
# Random seed to make results deterministic and reproducible
torch.manual_seed(0)

<torch._C.Generator at 0x7fd1c03b9710>

In [66]:
sample = " if you want you"

In [67]:
# make dictionary of sample
char_set = list(set(sample))
char_dic = {c: i for i, c in enumerate(char_set)}
print(char_dic)

{'t': 0, 'i': 1, 'o': 2, 'n': 3, 'y': 4, ' ': 5, 'u': 6, 'w': 7, 'f': 8, 'a': 9}


In [68]:
# hyper parameters
dic_size = len(char_dic) #10
hidden_size = len(char_dic) #10
learning_rate = 0.1

In [69]:
# data setting
sample_idx = [char_dic[c] for c in sample]  #[5, 1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]
x_data = [sample_idx[:-1]] #마지막 원소 제외하고, [[...]] 차원 증가
x_one_hot = [np.eye(dic_size)[x] for x in x_data] #sample_idx 순서대로 대응되는 원핫 인코딩 벡터 집합
y_data = [sample_idx[1:]] #공백 자르고 차원 증가

In [70]:
print(x_one_hot) # input vector

[array([[0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.]])]


In [71]:
print(y_data)

[[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]]


In [72]:
# transform as torch tensor variable
X = torch.FloatTensor(x_one_hot)
Y = torch.LongTensor(y_data)

In [73]:
print(Y.view(-1))

tensor([1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6])


In [74]:
# declare RNN
rnn = torch.nn.RNN(dic_size, hidden_size, batch_first=True)

In [75]:
# loss & optimizer setting
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(rnn.parameters(), learning_rate)

In [76]:
# start training
for i in range(50):
    optimizer.zero_grad()
    outputs, _status = rnn(X)
    loss = criterion(outputs.view(-1, dic_size), Y.view(-1))
    loss.backward()
    optimizer.step()
    #outputs.data.numpy()은 (50, 15, 10) 사이즈
    result = outputs.data.numpy().argmax(axis=2)  #outputs는 모든 은닉상태 출력 -> 행마다(10개 원소) 비교해서 최대값인 인덱스 저장
    result_str = ''.join([char_set[c] for c in np.squeeze(result)]) #대응되는 문자로 바꿔주고 문자열로 저장
    print(i, "loss: ", loss.item(), "prediction: ", result, "true Y: ", y_data, "prediction str: ", result_str)

0 loss:  2.3198115825653076 prediction:  [[7 0 8 7 7 8 5 7 0 0 7 5 7 7 8]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  wtfwwf wttw wwf
1 loss:  2.0635344982147217 prediction:  [[7 5 5 5 5 5 5 7 6 5 5 5 5 5 5]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  w      wu      
2 loss:  1.8523945808410645 prediction:  [[4 6 5 4 5 6 5 7 4 5 2 5 7 6 5]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  yu y u wy o wu 
3 loss:  1.6916230916976929 prediction:  [[4 6 5 4 2 6 5 7 4 3 0 5 4 2 5]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  yu you wynt yo 
4 loss:  1.537466287612915 prediction:  [[4 2 5 4 2 6 5 7 9 1 0 5 4 2 6]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  yo you wait you
5 loss:  1.4321811199188232 prediction:  [[4 8 5 4 2 6 5 7 9 3 0 5 4 2 6]] true Y:  [[1, 8, 5, 4, 2, 6, 5, 7, 9, 3, 0, 5, 4, 2, 6]] prediction str:  yf you want you
6 loss:  1.

In [61]:
#3차원 배열 axis 헷갈려서 정리
a = np.arange(36).reshape(3,4,3)
print(a)
print(a.shape)
print(np.sum(a, axis=2).shape) #원래 행렬 shape에서 해당 axis만 뺀 것이 shape이 됨
np.sum(a, axis=2)

[[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]
  [ 9 10 11]]

 [[12 13 14]
  [15 16 17]
  [18 19 20]
  [21 22 23]]

 [[24 25 26]
  [27 28 29]
  [30 31 32]
  [33 34 35]]]
(3, 4, 3)
(3, 4)


array([[  3,  12,  21,  30],
       [ 39,  48,  57,  66],
       [ 75,  84,  93, 102]])