## 实现GRU网络

关于GRU和LSTM这些门网络的选择, 参考:

[Empirical Evaluation of Gated Recurrent Neural Networks on Sequence Modeling](https://arxiv.org/pdf/1412.3555.pdf)

In [1]:
import torch
import torch.nn as nn

In [2]:
lstm_layer = nn.LSTM(3, 5)
gru_layer = nn.GRU(3, 5)

print(sum(p.numel() for p in lstm_layer.parameters()))
print(sum(p.numel() for p in gru_layer.parameters()))

200
150


在input_size相同的情况下 GRU的参数数目大致是LSTM的$\frac{3}{4}$倍

### 实现GRU

In [3]:
def gru_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    prev_h = initial_states
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] // 3

    # 对权重扩维 复制成batch_size倍
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1)
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1)

    output = torch.zeros(bs, T, h_size)  # GRU网络的输出状态序列

    for t in range(T):
        x = input[:, t, :]  # t时刻GRU cell的输入特征向量, [bs, i_size]
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_x = w_times_x.squeeze(-1)  # [bs, 3*h_size]

        w_times_h_prev = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1))  # [bs, 3*h_size, 1]
        w_times_h_prev = w_times_h_prev.squeeze(-1)  # [bs, 3*h_size]

        # 重置门
        r_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h_prev[:, :h_size]
                            + b_ih[:h_size] + b_hh[:h_size])
        # 更新门
        z_t = torch.sigmoid(w_times_x[:, h_size:2 * h_size] + w_times_h_prev[:, h_size:2 * h_size]
                            + b_ih[h_size:2 * h_size] + b_hh[h_size:2 * h_size])
        # 候选状态
        n_t = torch.tanh(w_times_x[:, 2 * h_size:3 * h_size] + b_ih[2 * h_size:3 * h_size] +
                         r_t*(w_times_h_prev[:, 2 * h_size:3 * h_size] + b_hh[2 * h_size:3 * h_size]))
        prev_h = (1 - z_t) * n_t + z_t * prev_h  # 增量更新隐含状态
        output[:, t, :] = prev_h

    return output, prev_h

In [4]:
bs, T, i_size, h_size = 2, 3, 4, 5

input = torch.randn(bs, T, i_size)  # 输入序列
h0 = torch.randn(bs, h_size)

# 调用PyTorch官方的GRU API
gru_layer = nn.GRU(i_size, h_size, batch_first=True)
output, h_final = gru_layer(input, h0.unsqueeze(0))
print(output)

tensor([[[ 0.1593,  0.7147, -1.0306, -0.1275,  0.3484],
         [-0.3227, -0.0368, -0.5100,  0.4591, -0.1904],
         [ 0.0651,  0.2047, -0.3852,  0.1842,  0.1674]],

        [[ 0.4132, -0.6484, -0.1507, -0.4509,  1.4405],
         [ 0.0431, -0.4333, -0.1946,  0.1758,  1.0490],
         [ 0.1021, -0.1294,  0.0391,  0.4045,  0.6429]]],
       grad_fn=<TransposeBackward1>)


In [5]:
for k, v in gru_layer.named_parameters():
    print(k, v.shape)

weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])


In [6]:
# 调用自定义的gru_forward函数

output_custom, h_final_custom = gru_forward(input, h0, gru_layer.weight_ih_l0, gru_layer.weight_hh_l0,
                                            gru_layer.bias_ih_l0, gru_layer.bias_hh_l0)
print(output_custom)
print(torch.allclose(output, output_custom))

tensor([[[ 0.1593,  0.7147, -1.0306, -0.1275,  0.3484],
         [-0.3227, -0.0368, -0.5100,  0.4591, -0.1904],
         [ 0.0651,  0.2047, -0.3852,  0.1842,  0.1674]],

        [[ 0.4132, -0.6484, -0.1507, -0.4509,  1.4405],
         [ 0.0431, -0.4333, -0.1946,  0.1758,  1.0490],
         [ 0.1021, -0.1294,  0.0391,  0.4045,  0.6429]]], grad_fn=<CopySlices>)
True
