# 双方向LSTMを計算するためのTimeBiLSTMクラスを実装する

In [1]:
import numpy as np
from common.time_layers import TimeLSTM

### [演習]
* 以下のTimeBiLSTMクラスを完成させましょう

In [2]:
class TimeBiLSTM:
    """Bidirectional LSTM applied to a whole time sequence.

    Two TimeLSTM layers are held. One consumes the input in its given
    order; the other consumes it reversed along axis 1. Their outputs are
    joined along axis 2.
    """

    def __init__(self, Wx1, Wh1, b1, Wx2, Wh2, b2, stateful=False):
        """Create both directional layers and pool their parameters.

        Wx1, Wh1, b1 : weights and bias given to the forward-direction TimeLSTM
        Wx2, Wh2, b2 : weights and bias given to the backward-direction TimeLSTM
        stateful     : passed unchanged to both TimeLSTM constructors
        """
        fwd_layer = TimeLSTM(Wx1, Wh1, b1, stateful)
        bwd_layer = TimeLSTM(Wx2, Wh2, b2, stateful)
        self.forward_lstm = fwd_layer
        self.backward_lstm = bwd_layer

        # Expose the parameters and gradients of both layers as single
        # sequences, forward-direction entries first.
        self.params = fwd_layer.params + bwd_layer.params
        self.grads = fwd_layer.grads + bwd_layer.grads

    def forward(self, xs):
        """Forward propagation.

        xs : input data; indexed with three axes, axis 1 being the one
             that is reversed (the demo in this file passes (N, T, D)).

        Returns the forward-direction output and the backward-direction
        output concatenated along axis 2.
        """
        in_order = self.forward_lstm.forward(xs)

        # Feed the sequence reversed along axis 1, then flip the result
        # back so that position t of both outputs refers to the same step.
        reversed_out = self.backward_lstm.forward(np.flip(xs, axis=1))
        realigned = np.flip(reversed_out, axis=1)

        return np.concatenate((in_order, realigned), axis=2)

    def backward(self, dhs):
        """Backward propagation.

        dhs : gradient with respect to the output of forward(); the first
              half of axis 2 goes to the forward-direction layer and the
              remainder to the backward-direction layer.

        Returns the sum of the input gradients produced by both layers.
        """
        half = dhs.shape[2] // 2
        grad_in_order = dhs[:, :, :half]
        # The backward-direction layer saw reversed input, so its upstream
        # gradient has to be reversed along axis 1 as well.
        grad_reversed = np.flip(dhs[:, :, half:], axis=1)

        dxs_in_order = self.forward_lstm.backward(grad_in_order)
        dxs_reversed = self.backward_lstm.backward(grad_reversed)

        # Undo the reversal before combining the two input gradients.
        return dxs_in_order + np.flip(dxs_reversed, axis=1)

In [3]:
# Vocabulary size (exclusive upper bound of the random integer inputs below)
V = 3
# Number of dimensions after embedding
D = 3
# Number of hidden-layer nodes
H = 4
# Number of samples
N = 3
# Number of words per sequence
T = 5

rn = np.random.randn
# Random weights scaled by 1/sqrt(number of input rows), zero biases.
# NOTE(review): the 4 * H width presumably packs the four LSTM gates;
# confirm against common.time_layers.TimeLSTM.
Wx1 = rn(D, 4 * H) / np.sqrt(D)
Wh1 = rn(H, 4 * H) / np.sqrt(H)
b1 = np.zeros(4 * H)
Wx2 = rn(D, 4 * H) / np.sqrt(D)
Wh2 = rn(H, 4 * H) / np.sqrt(H)
b2 = np.zeros(4 * H)

# Build the model
tb = TimeBiLSTM(Wx1, Wh1, b1, Wx2, Wh2, b2)


# Random integer input of shape (N, T, D) with values in [0, V)
xs = np.random.randint(0, V, size=(N, T, D))
print("xs=", xs)
print()

# Forward propagation
out = tb.forward(xs)
print("out=", out)
print()

# Backward propagation with a random upstream gradient of shape (N, T, 2 * H)
dhs = rn(N, T, H * 2)
dxs = tb.backward(dhs)
print("dxs=", dxs)
print()

xs= [[[1 2 1]
  [1 2 0]
  [2 1 2]
  [0 2 2]
  [0 0 2]]

 [[2 0 2]
  [0 0 0]
  [0 1 0]
  [2 1 2]
  [2 0 2]]

 [[2 1 0]
  [0 0 2]
  [1 0 2]
  [1 0 2]
  [0 0 0]]]

out= [[[ 0.16685588  0.009486    0.13236564  0.23199225 -0.12323231
    0.02879749 -0.26732792 -0.06916514]
  [ 0.01137014  0.02610252  0.07351707  0.41700616  0.05831055
    0.02854604 -0.20052598 -0.1032093 ]
  [ 0.09142294  0.01335917  0.14931676  0.35770804  0.04060165
    0.03361097 -0.40133624 -0.02257903]
  [ 0.0754943   0.0087739   0.3732713   0.34468753  0.48483296
    0.01223175 -0.16865829 -0.0238287 ]
  [-0.00380614 -0.01396081  0.45295622 -0.16970229  0.15480918
    0.04257493  0.28387804 -0.04730145]]

 [[ 0.07416103  0.00723287  0.09772374 -0.20484657  0.00178951
    0.09846216 -0.00983824 -0.02876495]
  [ 0.05459274 -0.00779988  0.10075503 -0.06862001 -0.07518063
    0.11499957 -0.09894058 -0.07636624]
  [ 0.05560332  0.02673525  0.12264325  0.04663169 -0.15509559
    0.08064685 -0.216393   -0.14149599]
  [ 0.16