# 双方向LSTMを計算するためのTimeBiLSTMクラスを実装する

In [1]:
import numpy as np
from common.time_layers import TimeLSTM

### [演習]
* 以下のTimeBiLSTMクラスを完成させましょう

In [2]:
class TimeBiLSTM:
    """Bidirectional LSTM layer composed of two ``TimeLSTM`` layers.

    One layer consumes the input in its given order, the other consumes the
    input reversed along axis 1. Their outputs are joined along axis 2, so
    the output's last axis is the sum of both layers' output widths.
    """

    def __init__(self, Wx1, Wh1, b1, Wx2, Wh2, b2, stateful=False):
        """Create both directional layers and gather their parameters.

        Args:
            Wx1, Wh1, b1: Weights and bias handed to the forward-direction
                ``TimeLSTM``.
            Wx2, Wh2, b2: Weights and bias handed to the backward-direction
                ``TimeLSTM``.
            stateful: Passed through unchanged to both ``TimeLSTM`` layers.
        """
        self.forward_lstm = TimeLSTM(Wx1, Wh1, b1, stateful)
        self.backward_lstm = TimeLSTM(Wx2, Wh2, b2, stateful)

        # Expose the parameters and gradients of both layers as one
        # combined collection each (forward layer first).
        fwd, bwd = self.forward_lstm, self.backward_lstm
        self.params = fwd.params + bwd.params
        self.grads = fwd.grads + bwd.grads

    def forward(self, xs):
        """Run the forward pass.

        Args:
            xs: Input data, indexed as (batch, time, feature).

        Returns:
            The forward-direction output and the backward-direction output
            concatenated along axis 2.
        """
        fwd_out = self.forward_lstm.forward(xs)

        # The backward-direction layer sees the sequence reversed along
        # axis 1; its output is flipped back so that both outputs line up
        # step by step before they are joined.
        reversed_xs = xs[:, ::-1, :]
        bwd_out = self.backward_lstm.forward(reversed_xs)[:, ::-1, :]

        return np.concatenate((fwd_out, bwd_out), axis=2)

    def backward(self, dhs):
        """Run the backward pass.

        Args:
            dhs: Gradient with respect to the output of ``forward``. Its
                last axis is split in half: the first half goes to the
                forward-direction layer, the second half to the
                backward-direction layer.

        Returns:
            The sum of the input gradients produced by both layers.
        """
        half = dhs.shape[2] // 2
        grad_fwd, grad_bwd = dhs[:, :, :half], dhs[:, :, half:]

        dxs_fwd = self.forward_lstm.backward(grad_fwd)

        # Mirror the forward pass: reverse along axis 1 going in, and
        # reverse the resulting input gradient coming out.
        dxs_bwd = self.backward_lstm.backward(grad_bwd[:, ::-1])[:, ::-1]

        return dxs_fwd + dxs_bwd

In [3]:
# Demo: build a TimeBiLSTM with random weights and run one forward and one
# backward pass on random data, printing the input, output and input gradient.

# Vocabulary size
V = 3
# Number of dimensions after embedding
D = 3
# Number of hidden-layer nodes
H = 4
# Number of data samples
N = 3
# Number of words (sequence length)
T = 5

# Random weights, scaled by 1/sqrt of their first dimension; biases are zero.
# The 4 * H width is presumably what TimeLSTM expects for its gates —
# TODO confirm against common.time_layers.
rn = np.random.randn
Wx1 = rn(D, 4 * H) / np.sqrt(D)
Wh1 = rn(H, 4 * H) / np.sqrt(H)
b1 = np.zeros(4 * H)
Wx2 = rn(D, 4 * H) / np.sqrt(D)
Wh2 = rn(H, 4 * H) / np.sqrt(H)
b2 = np.zeros(4 * H)

# Build the model
tb = TimeBiLSTM(Wx1, Wh1, b1, Wx2, Wh2, b2)


# Random integer input in [0, V) with shape (N, T, D)
xs = np.random.randint(0, V, N*T*D).reshape(N, T, D)
print("xs=", xs)
print()

# Forward pass
out = tb.forward(xs)
print("out=", out)
print()

# Backward pass: the upstream gradient has 2 * H features because the output
# concatenates both directions.
dhs = np.random.randn(N*T*H*2).reshape(N, T, H*2)
dxs = tb.backward(dhs)
print("dxs=", dxs)
print()

xs= [[[2 2 1]
  [1 2 1]
  [2 2 1]
  [0 1 0]
  [0 1 0]]

 [[2 0 2]
  [0 2 0]
  [1 2 0]
  [0 0 2]
  [2 2 1]]

 [[0 1 0]
  [1 1 2]
  [1 0 1]
  [1 2 2]
  [0 1 2]]]

out= [[[ 0.00316584  0.01610517 -0.20639893  0.22352903 -0.04121241
    0.03706722  0.54854523  0.35893382]
  [ 0.01590218  0.03396751 -0.37896126 -0.10983331 -0.07905664
    0.01991452  0.5090837   0.22628845]
  [ 0.02386993  0.03811029 -0.47659989  0.14308863 -0.01474335
    0.02111286  0.52913304  0.20738989]
  [ 0.0698984   0.20508753 -0.2929311  -0.09390141  0.00529681
    0.12082558  0.26425947  0.04976377]
  [ 0.07412194  0.17164285 -0.25639764 -0.14304603  0.02037226
    0.07001595  0.20618691  0.00791608]]

 [[ 0.05174704  0.01168178  0.03691464  0.57577591 -0.03587848
   -0.22819143  0.06960132 -0.07069305]
  [ 0.01294978  0.16424095 -0.25976653 -0.03394118 -0.24148771
    0.09040284  0.35312299 -0.08644511]
  [ 0.020492    0.20221584 -0.42839704 -0.12288047 -0.18086183
   -0.02160068  0.3258086  -0.17808359]
  [ 0.23