# Vanilla RNN

References:
- https://github.com/hunkim/PyTorchZeroToAll/blob/master/12_1_rnn_basics.py
- https://github.com/hunkim/PyTorchZeroToAll/blob/master/12_2_hello_rnn.py


In [24]:
import sys
import torch 
import torch.nn as nn
from torch.autograd import Variable

In [2]:
# One-hot vectors for the four distinct characters of 'hello' (h, e, l, o):
# row k of the 4x4 identity matrix encodes the k-th character.
h, e, l, o = ([int(col == row) for col in range(4)] for row in range(4))

In [11]:
# Single-layer RNN: 4-dim one-hot input -> 2-dim hidden/output, applied to the
# 5-character sequence 'hello' one time step at a time.

# The sequence as a (seq_len=5, input_size=4) tensor; row t is character t.
inputs = Variable(torch.Tensor([h, e, l, l, o]))

# With batch_first=True the cell expects input shaped (batch, seq_len, input_size).
cell = nn.RNN(input_size=4, hidden_size=2, batch_first=True)

# The initial hidden state is (num_layers * num_directions, batch, hidden_size)
# whether or not batch_first is set.
hidden = Variable(torch.randn(1, 1, 2))

# Feed each character as a (batch=1, seq_len=1, input_size=4) tensor, carrying
# the hidden state from one step to the next.
for step in range(inputs.size(0)):
    one = inputs[step].view(1, 1, -1)
    out, hidden = cell(one, hidden)
    print("one input size", one.size(), "out size", out.size())

one input size torch.Size([1, 1, 4]) out size torch.Size([1, 1, 2])
one input size torch.Size([1, 1, 4]) out size torch.Size([1, 1, 2])
one input size torch.Size([1, 1, 4]) out size torch.Size([1, 1, 2])
one input size torch.Size([1, 1, 4]) out size torch.Size([1, 1, 2])
one input size torch.Size([1, 1, 4]) out size torch.Size([1, 1, 2])


In [14]:
# The same cell can also consume the whole sequence in one call. Reshape the
# (5, 4) sequence to (batch=1, seq_len=5, input_size=4) and continue from the
# hidden state left behind by the step-by-step loop.
inputs = inputs.view(1, 5, -1)
out, hidden = cell(inputs, hidden)

print("sequence input size", inputs.size(), "out size", out.size())


sequence input size torch.Size([1, 5, 4]) out size torch.Size([1, 5, 2])


In [18]:
# Batched run: three sequences of length 5 through the same 4 -> 2 cell.
# The hidden state is (num_layers * num_directions, batch, hidden_size)
# regardless of batch_first, so batch=3 here.
hidden = Variable(torch.randn(1, 3, 2))

# batch_first=True, so the input is (batch, seq_len, input_size) = (3, 5, 4).
inputs = Variable(torch.Tensor([
    [h, e, l, l, o],  # 'hello'
    [e, o, l, l, l],  # 'eolll'
    [l, l, e, e, l],  # 'lleel'
]))

out, hidden = cell(inputs, hidden)
print("batch input size", inputs.size(), "out size", out.size())

batch input size torch.Size([3, 5, 4]) out size torch.Size([3, 5, 2])


In [20]:
# A fresh 4 -> 2 RNN cell with the default batch_first=False, which expects
# input shaped (seq_len, batch, input_size).
cell = nn.RNN(input_size=4, hidden_size=2)

# Swap dims 0 and 1: (batch=3, seq_len=5, input=4) -> (seq_len=5, batch=3, input=4).
inputs = inputs.transpose(0, 1)

# `hidden` was returned by the previous batch-of-3 call. Its layout
# (num_layers, batch, hidden_size) does not depend on batch_first, so it fits here.
out, hidden = cell(inputs, hidden)
print("batch input size", inputs.size(), "out size", out.size())

batch input size torch.Size([5, 3, 4]) out size torch.Size([5, 3, 2])


---

In [46]:
torch.manual_seed(777)  # reproducibility

# Vocabulary: index -> character.
idx2char = ['h', 'i', 'e', 'l', 'o']

# Teach the model to map 'hihell' -> 'ihello': each input character should
# predict the character that follows it.
x_data = [idx2char.index(ch) for ch in 'hihell']  # [0, 1, 0, 2, 3, 3]
y_data = [idx2char.index(ch) for ch in 'ihello']  # [1, 0, 2, 3, 3, 4]

# Row k of the 5x5 identity matrix is the one-hot code for idx2char[k].
one_hot_lookup = [[int(col == row) for col in range(5)] for row in range(5)]
x_one_hot = [one_hot_lookup[x] for x in x_data]

In [47]:
# 샘플 배치를 가졌으니 한번에 변수로 바꿔보자
# The whole sample is one sequence, so convert it to tensors in one go:
# targets as integer class indices, inputs as float one-hot rows.
labels = Variable(torch.LongTensor(y_data))
inputs = Variable(torch.Tensor(x_one_hot))

In [48]:
# Vocabulary size, one-hot width and RNN output width all coincide: the RNN's
# output is read directly as scores over the 5 characters.
num_classes = input_size = hidden_size = 5

# One sentence per batch, fed one character per step, through a single layer.
batch_size = sequence_length = num_layers = 1

In [53]:
class RnnModel(nn.Module):
    """Single-layer-by-default vanilla RNN whose output is used directly as class scores.

    All sizes come from the module-level configuration: ``input_size``,
    ``hidden_size``, ``batch_size``, ``sequence_length``, ``num_layers`` and
    ``num_classes``. They are read once, at construction time, so that
    ``forward`` and ``init_hidden`` always agree with the shapes the
    underlying ``nn.RNN`` was built with.

    Raises:
        ValueError: if ``hidden_size != num_classes``. The RNN output is
            reshaped to ``(-1, num_classes)``, so a mismatch would either fail
            or silently regroup values into the wrong rows.
    """

    def __init__(self):
        super(RnnModel, self).__init__()
        # Snapshot the configuration so later edits to the globals cannot
        # desynchronise forward()/init_hidden() from the constructed layer.
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.num_layers = num_layers
        self.num_classes = num_classes

        if self.hidden_size != self.num_classes:
            raise ValueError(
                "hidden_size (%d) must equal num_classes (%d): the RNN output "
                "is used directly as class scores"
                % (self.hidden_size, self.num_classes))

        # num_layers is passed explicitly so the layer matches the hidden
        # state that init_hidden() allocates.
        self.rnn = nn.RNN(input_size=self.input_size,
                          hidden_size=self.hidden_size,
                          num_layers=self.num_layers,
                          batch_first=True)

    def forward(self, hidden, x):
        """Run one chunk of input through the RNN.

        Args:
            hidden: previous hidden state, shaped
                (num_layers * num_directions, batch_size, hidden_size),
                e.g. as returned by ``init_hidden``.
            x: input holding batch_size * sequence_length * input_size
                elements; it is reshaped to (batch, seq_len, input_size)
                because the RNN was built with ``batch_first=True``.

        Returns:
            ``(hidden, out)``: the new hidden state, and the per-step outputs
            flattened to shape (batch_size * sequence_length, num_classes).
        """
        x = x.view(self.batch_size, self.sequence_length, self.input_size)
        out, hidden = self.rnn(x, hidden)
        return hidden, out.view(-1, self.num_classes)

    def init_hidden(self):
        """Return an all-zero initial hidden state.

        A vanilla RNN has only a hidden state (no cell state). Its shape is
        (num_layers * num_directions, batch_size, hidden_size).
        """
        return Variable(torch.zeros(self.num_layers, self.batch_size, self.hidden_size))

In [54]:
# RNN 모델 인스턴스
# Instantiate the character-level RNN and print its module summary.
model = RnnModel()
print(model)

RnnModel(
  (rnn): RNN(5, 5, batch_first=True)
)


In [55]:
# loss, optimizer 설정
# Training setup: Adam (learning rate 0.1) over all model parameters,
# minimising cross-entropy between predicted scores and target indices.
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

In [56]:
# Train the model
# Train the model. Each epoch feeds 'hihell' one character at a time from a
# zero hidden state, sums the per-step cross-entropy against 'ihello', and
# takes a single optimizer step.
for epoch in range(1000):
    optimizer.zero_grad()
    loss = 0
    hidden = model.init_hidden()

    sys.stdout.write("predicted string: ")
    for input_, label in zip(inputs, labels):
        # output holds the per-step class scores (expected (1, num_classes) here).
        hidden, output = model(hidden, input_)
        _, idx = output.max(1)
        sys.stdout.write(idx2char[idx.item()])
        # Iterating a 1-D LongTensor yields 0-dim labels. CrossEntropyLoss
        # needs an (N,) target to match the (N, C) scores, so reshape to (1,).
        loss += criterion(output, label.view(1))

    # .item() extracts the Python number; indexing a 0-dim tensor with [0]
    # raises on PyTorch >= 0.4.
    print(", epoch: %d, loss: %1.3f" % (epoch + 1, loss.item()))

    loss.backward()
    optimizer.step()

print("Learning finished!")

predicted string: lellll, epoch: 1, loss: 9.266
predicted string: lillll, epoch: 2, loss: 7.890
predicted string: lieloo, epoch: 3, loss: 7.014
predicted string: liello, epoch: 4, loss: 6.249
predicted string: lhello, epoch: 5, loss: 5.653
predicted string: ihello, epoch: 6, loss: 5.191
predicted string: ihelll, epoch: 7, loss: 4.807
predicted string: ihelll, epoch: 8, loss: 4.514
predicted string: ihelll, epoch: 9, loss: 4.324
predicted string: ihelll, epoch: 10, loss: 4.156
predicted string: ihelll, epoch: 11, loss: 4.004
predicted string: ihelll, epoch: 12, loss: 3.887
predicted string: ihelll, epoch: 13, loss: 3.786
predicted string: ihelll, epoch: 14, loss: 3.702
predicted string: ihelll, epoch: 15, loss: 3.629
predicted string: ihelll, epoch: 16, loss: 3.570
predicted string: ihelll, epoch: 17, loss: 3.529
predicted string: ihelll, epoch: 18, loss: 3.490
predicted string: ihelll, epoch: 19, loss: 3.444
predicted string: ihelll, epoch: 20, loss: 3.409
predicted string: ihelll, epo