<a href="https://colab.research.google.com/github/pko89403/DeepLearningSelfStudy/blob/master/FirstStepOfRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch 
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [0]:
n_hidden = 35
lr = 0.01
epochs = 1000

string = "hello pytorch. how long can a rnn cell remember?"
chars = "abcedefghijklmnopqrstuvwxyz ?!.,:;01"
char_list = [i for i in chars]
n_letters = len(char_list)

In [0]:
def string_to_onehot(string):
  start = np.zeros(shape=n_letters, dtype=int)
  end = np.zeros(shape=n_letters, dtype=int)
  start[-2] = 1
  end[-1] = 1
  for i in string:
    idx = char_list.index(i)
    zero = np.zeros(shape=n_letters, dtype=int)
    zero[idx] = 1
    start = np.vstack([start, zero])

  output = np.vstack([start, end])
  return output

In [0]:
onehot_apple = string_to_onehot("apple")

In [0]:
def onehot_to_word(onehot_1):
  onehot = torch.Tensor.numpy(onehot_1)
  print(onehot)
  print(onehot.argmax())
  return char_list[onehot.argmax()]

In [28]:
onehot_to_word(torch.from_numpy(onehot_apple)[2])

[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
16


'p'

In [0]:
class RNN(nn.Module):
  def __init__(self, input_size, hidden_size, output_size):
    super(RNN, self).__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.output_size = output_size

    self.i2h = nn.Linear(input_size, hidden_size)
    self.h2h = nn.Linear(hidden_size, hidden_size)
    self.i2o = nn.Linear(hidden_size, output_size)
    self.act_fn = nn.Tanh()
    
  def forward(self, input, hidden):
    hidden = self.act_fn( self.i2h(input) + self.h2h(hidden))
    output = self.i2o(hidden)
    return output, hidden
  
  def init_hidden(self):
    return torch.zeros(1, self.hidden_size)

In [0]:
rnn = RNN(n_letters, n_hidden, n_letters)

In [0]:
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)

In [0]:
one_hot = torch.from_numpy(string_to_onehot(string)).type_as(torch.FloatTensor())

In [40]:
for i in range(epochs):
  rnn.zero_grad()
  total_loss = 0
  hidden = rnn.init_hidden()

  for j in range(one_hot.size()[0]-1):
    input_ = one_hot[j:j+1,:]
    target = one_hot[j+1]

    output, hidden = rnn.forward(input_, hidden)
    loss = loss_func(output.view(-1), target.view(-1))
    total_loss += loss
    input_ = output

  total_loss.backward()
  optimizer.step()
  if i % 10 == 0:
    print(total_loss)

tensor(2.3231, grad_fn=<AddBackward0>)
tensor(0.9360, grad_fn=<AddBackward0>)
tensor(0.6199, grad_fn=<AddBackward0>)
tensor(0.4080, grad_fn=<AddBackward0>)
tensor(0.2660, grad_fn=<AddBackward0>)
tensor(0.2088, grad_fn=<AddBackward0>)
tensor(0.1575, grad_fn=<AddBackward0>)
tensor(0.1246, grad_fn=<AddBackward0>)
tensor(0.1345, grad_fn=<AddBackward0>)
tensor(0.1010, grad_fn=<AddBackward0>)
tensor(0.0824, grad_fn=<AddBackward0>)
tensor(0.0687, grad_fn=<AddBackward0>)
tensor(0.0721, grad_fn=<AddBackward0>)
tensor(0.0587, grad_fn=<AddBackward0>)
tensor(0.0540, grad_fn=<AddBackward0>)
tensor(0.0464, grad_fn=<AddBackward0>)
tensor(0.0421, grad_fn=<AddBackward0>)
tensor(0.0451, grad_fn=<AddBackward0>)
tensor(0.0365, grad_fn=<AddBackward0>)
tensor(0.0329, grad_fn=<AddBackward0>)
tensor(0.0295, grad_fn=<AddBackward0>)
tensor(0.0270, grad_fn=<AddBackward0>)
tensor(0.0603, grad_fn=<AddBackward0>)
tensor(0.0301, grad_fn=<AddBackward0>)
tensor(0.0249, grad_fn=<AddBackward0>)
tensor(0.0224, grad_fn=<A

In [41]:
start = torch.zeros(1, n_letters)
start[:, -2] = 1

with torch.no_grad():
  hidden = rnn.init_hidden()
  input_ = start 
  output_string = ""
  for i in range( len(string) ):
    output, hidden = rnn.forward(input_, hidden)
    output_string += onehot_to_word(output.data)
    input_ = output

print(output_string)

[[-6.37850165e-03  4.39833105e-03 -4.30901349e-03 -3.19696590e-03
   4.77719947e-12 -1.91182356e-12  1.09366345e-04  4.31960355e-03
   1.00054169e+00 -2.00454757e-12  7.27850988e-08  3.01601290e-07
  -3.82766873e-03 -8.22949596e-03 -3.12593952e-03 -5.98376244e-03
   3.32588330e-04  2.76991763e-10 -5.91028482e-04 -8.82210693e-11
   7.01257586e-03 -2.79134440e-08  8.07589231e-07  1.45000666e-02
   5.00649776e-07 -1.34850703e-02 -1.39085010e-09 -5.35494834e-03
   1.00440532e-03  1.76293497e-07  3.65361199e-03  8.05858635e-13
   1.29132722e-08 -1.52015553e-12  1.13971481e-08  1.95004139e-02]]
8
[[ 3.05759907e-03  3.19517702e-02 -8.10939074e-03  1.02494359e+00
  -1.61864619e-12 -1.60473371e-12 -3.38093960e-05 -5.68813551e-03
  -6.24473020e-02  2.90690649e-13 -1.97254906e-08 -1.05481604e-07
  -4.54721078e-02  4.53421399e-02 -2.92121060e-02 -1.65214017e-02
  -9.37980972e-03 -1.00741415e-10  2.71100886e-02  2.99615333e-11
  -9.06240940e-03  5.26141264e-09 -1.68974111e-07  3.58745605e-02
  -7.0