# RNNレイヤの実装

In [1]:
import numpy as np
from common.functions import sigmoid

### [演習]
* 以下のRNNレイヤのクラスを完成させましょう

In [2]:
class RNN:
    """One time step of a simple RNN layer with a tanh activation.

    forward:  h_next = tanh(h_prev · Wh + x · Wx + b)
    backward: gradients w.r.t. Wx, Wh, b are written in place into
              ``self.grads``; gradients w.r.t. x and h_prev are returned.
    """

    def __init__(self, Wx, Wh, b):
        """
        Wx : weight multiplied with the input x
        Wh : weight multiplied with the previous hidden state h
        b  : bias
        """
        # Learnable parameters, always in the order (Wx, Wh, b).
        self.params = [Wx, Wh, b]

        # One zero-initialised gradient buffer per parameter, same order.
        # backward() overwrites these arrays in place, so the list keeps
        # pointing at the same array objects.
        self.grads = [np.zeros_like(param) for param in self.params]

        # (x, h_prev, h_next) from the most recent forward() call.
        self.cache = None

    def forward(self, x, h_prev):
        """Advance the hidden state by one time step.

        x      : input at this time step
        h_prev : hidden state from the previous time step
        Returns the new hidden state tanh(h_prev · Wh + x · Wx + b).
        """
        Wx, Wh, b = self.params
        h_next = np.tanh(np.dot(h_prev, Wh) + np.dot(x, Wx) + b)
        # Saved for backward(); tanh's derivative is expressed via h_next.
        self.cache = x, h_prev, h_next
        return h_next

    def backward(self, dh_next):
        """Back-propagate the gradient of the new hidden state.

        dh_next : gradient of the loss w.r.t. the h_next returned by forward()
        Returns (dx, dh_prev); parameter gradients are stored in self.grads.
        Must be called after forward(), because it reads self.cache.
        """
        Wx, Wh, _ = self.params
        x, h_prev, h_next = self.cache

        # Gradient through tanh: d tanh(t)/dt = 1 - tanh(t)**2 = 1 - h_next**2.
        dt = dh_next * (1 - h_next ** 2)

        # b was broadcast over axis 0 (the batch axis N in the demo cell),
        # so its gradient is the sum over that axis.
        db = dt.sum(axis=0)
        dWh = np.dot(h_prev.T, dt)
        dh_prev = np.dot(dt, Wh.T)
        dWx = np.dot(x.T, dt)
        dx = np.dot(dt, Wx.T)

        # Copy into the existing buffers (slice assignment) rather than
        # rebinding, so holders of a reference to self.grads see new values.
        for buffer, grad in zip(self.grads, (dWx, dWh, db)):
            buffer[:] = grad

        return dx, dh_prev

In [3]:
# Demo / sanity check of the RNN layer on a small random batch.
np.random.seed(0)  # fixed seed so Restart & Run All reproduces the same numbers

D = 1  # input size (number of input-layer nodes)
H = 5  # hidden size (number of hidden-layer nodes)
N = 4  # batch size

Wx = np.random.randn(D, H)
Wh = np.random.randn(H, H)
b = np.zeros(H)

# Create the layer and record the identity of each gradient buffer.
rnn = RNN(Wx, Wh, b)
grad_ids_before = [id(g) for g in rnn.grads]
for i, grad_id in enumerate(grad_ids_before):
    print(f"id of rnn.grads[{i}]", grad_id)
print()

# Forward pass
x = np.random.randn(N, D)
h_prev = np.random.randn(N, H)
h_next = rnn.forward(x, h_prev)
assert h_next.shape == (N, H)
print("h_next=", h_next)
print()

# Backward pass
dh_next = np.random.randn(N, H)
dx, dh_prev = rnn.backward(dh_next)
assert dx.shape == (N, D) and dh_prev.shape == (N, H)
print("dx=", dx)
print()
print("dh_prev=", dh_prev)
print()

# backward() writes into the existing gradient arrays (grads[i][:] = ...),
# so every id must be unchanged.
grad_ids_after = [id(g) for g in rnn.grads]
for i, grad_id in enumerate(grad_ids_after):
    print(f"id of rnn.grads[{i}]", grad_id)
print()
assert grad_ids_after == grad_ids_before, "backward() rebound a gradient array"

# Finite-difference check: for L = sum(h_next * dh_next), dL/dx must equal dx.
# (Calls forward() again, which overwrites rnn.cache; no backward() follows.)
eps = 1e-6
dx_numeric = np.zeros_like(x)
for idx in np.ndindex(*x.shape):
    x_plus, x_minus = x.copy(), x.copy()
    x_plus[idx] += eps
    x_minus[idx] -= eps
    loss_plus = np.sum(rnn.forward(x_plus, h_prev) * dh_next)
    loss_minus = np.sum(rnn.forward(x_minus, h_prev) * dh_next)
    dx_numeric[idx] = (loss_plus - loss_minus) / (2 * eps)
np.testing.assert_allclose(dx_numeric, dx, rtol=1e-5, atol=1e-7)


id of rnn.grads[0] 4607344352
id of rnn.grads[1] 4607344032
id of rnn.grads[2] 4607906320

h_next= [[-0.27051552  0.95542557  0.94911125  0.23734408 -0.99991545]
 [ 0.94066667 -0.99981665 -0.08333901 -0.98972713  0.99999264]
 [ 0.98323665 -0.99844006 -0.98932822  0.98917993 -0.98991407]
 [-0.91761981 -0.98493142  0.99056924 -0.99840465  0.99999908]]

dx= [[-1.0759228 ]
 [-0.15347811]
 [ 0.00502181]
 [ 0.08895468]]

dh_prev= [[-7.21359293e-01 -4.80841484e+00 -2.04551816e+00 -6.60240321e-01
   1.57245028e+00]
 [-8.26766808e-01 -4.75817394e-01  4.53963919e-01 -6.24690013e-01
   5.23285627e-01]
 [-8.72635387e-04  3.45682083e-02  2.19641737e-02  1.29540589e-02
  -6.61903458e-02]
 [ 1.04657843e-01  3.22933735e-01  1.24694200e-01  1.09659281e-01
  -3.98961902e-02]]

id of rnn.grads[0] 4607344352
id of rnn.grads[1] 4607344032
id of rnn.grads[2] 4607906320

