In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [32]:
def gru_forward(inp, h0, w_ih, w_hh, b_ih, b_hh):
    """
    wih和whh是三个矩阵堆叠
    """
    bs, seq, i_size = inp.shape
    h_size = w_ih.shape[0] // 3
    prev_h = h0 # [bs, h_dim]

    bw_ih = w_ih.unsqueeze(0).tile(bs, 1, 1) # [bs, 3*h_dim, i_size]
    bw_hh = w_hh.unsqueeze(0).tile(bs, 1, 1) # [bs, 3*h_dim, h_dim]

    output = torch.randn(bs, seq, h_size)

    for t in range(seq):
        x = inp[:, t, :] # [bs, i_size]
        w_times_x = torch.bmm(bw_ih, x.unsqueeze(-1)).squeeze(-1) # [bs, 3*h_d]
        w_times_h = torch.bmm(bw_hh, prev_h.unsqueeze(-1)).squeeze(-1) # [bs, 3*h_d]
        i = 0
        ind_l = h_size * i
        ind_r = h_size * (i+1)
        r_t = torch.sigmoid(w_times_x[:, ind_l:ind_r] + w_times_h[:, ind_l:ind_r] + b_ih[ind_l:ind_r]
                            + b_hh[ind_l:ind_r])
        i += 1
        ind_l = h_size * i
        ind_r = h_size * (i+1)
        z_t = torch.sigmoid(w_times_x[:, ind_l:ind_r] + w_times_h[:, ind_l:ind_r] + b_ih[ind_l:ind_r]
                            + b_hh[ind_l:ind_r])
        # 候选状态
        i += 1
        ind_l = h_size * i
        ind_r = h_size * (i+1)
        n_t = torch.tanh(w_times_x[:, ind_l:ind_r]+b_ih[ind_l:ind_r] +
                         r_t * (w_times_h[:, ind_l:ind_r] + b_hh[ind_l:ind_r]))
        prev_h  = (1-z_t) * n_t + z_t * prev_h
        output[:, t, :] = prev_h

    return output, prev_h

def test_gru_impl():
    bs, seq, i_size, h_dim = 2, 3, 4, 5
    inp = torch.randn(bs, seq, i_size)

    h0 = torch.randn(bs, h_dim)

    gru = nn.GRU(i_size, h_dim, batch_first=True)
    res1, _ = gru(inp, h0.unsqueeze(0))

    for k, v in gru.named_parameters():
        print(k, v.shape)
    w_ih = gru.weight_ih_l0
    w_hh = gru.weight_hh_l0
    b_ih = gru.bias_ih_l0
    b_hh = gru.bias_hh_l0
    res2, _ = gru_forward(inp, h0, w_ih, w_hh, b_ih, b_hh)
    print(torch.allclose(res1, res2))
    print(res1)
    print(res2)


test_gru_impl()


weight_ih_l0 torch.Size([15, 4])
weight_hh_l0 torch.Size([15, 5])
bias_ih_l0 torch.Size([15])
bias_hh_l0 torch.Size([15])
True
tensor([[[ 0.2041, -1.3078, -0.9466,  0.2533, -0.6753],
         [-0.0092, -1.0803, -0.7793,  0.3769, -0.1754],
         [-0.0859, -0.6894, -0.6592,  0.0124,  0.1118]],

        [[ 1.0236, -0.7514,  0.0233,  0.1132,  0.1051],
         [ 0.8297, -0.3459,  0.0732,  0.1759,  0.2705],
         [ 0.7027, -0.2098, -0.0653,  0.1357,  0.5134]]],
       grad_fn=<TransposeBackward1>)
tensor([[[ 0.2041, -1.3078, -0.9466,  0.2533, -0.6753],
         [-0.0092, -1.0803, -0.7793,  0.3769, -0.1754],
         [-0.0859, -0.6894, -0.6592,  0.0124,  0.1118]],

        [[ 1.0236, -0.7514,  0.0233,  0.1132,  0.1051],
         [ 0.8297, -0.3459,  0.0732,  0.1759,  0.2705],
         [ 0.7027, -0.2098, -0.0653,  0.1357,  0.5134]]], grad_fn=<CopySlices>)
