In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
def rnn_impl(inp, weight_ih, weight_hh, bias_ih, bias_hh, h_prev=None):
    """Hand-written forward pass of a single-layer, single-direction tanh RNN.

    For every time step t it computes
        h_t = tanh(W_ih @ x_t + W_hh @ h_{t-1} + b_ih + b_hh)
    which is the recurrence used by ``torch.nn.RNN(..., batch_first=True)``
    with its default tanh nonlinearity.

    Args:
        inp: batch-first input of shape [bs, seq_len, input_size].
        weight_ih: input-to-hidden weight, shape [h_dim, input_size].
        weight_hh: hidden-to-hidden weight, shape [h_dim, h_dim].
        bias_ih: input-to-hidden bias, shape [h_dim].
        bias_hh: hidden-to-hidden bias, shape [h_dim].
        h_prev: initial hidden state, shape [bs, h_dim]; ``None`` means zeros.

    Returns:
        (h_out, h_n) where h_out is [bs, seq_len, h_dim] (hidden state at every
        step) and h_n is [1, bs, h_dim] (final hidden state with the leading
        num_layers * num_directions axis that nn.RNN uses).
    """
    bs, seq_len, _ = inp.shape
    h_dim = weight_ih.shape[0]
    if h_prev is None:
        h_prev = inp.new_zeros(bs, h_dim)
    # Follow the input's dtype/device; a bare torch.zeros would always be
    # float32 on CPU and silently downcast float64 results.
    h_out = inp.new_zeros(bs, seq_len, h_dim)

    # The weights do not depend on t, so replicate them per batch once:
    # [h_dim, input_size] -> [bs, h_dim, input_size], [h_dim, h_dim] -> [bs, h_dim, h_dim].
    bw_ih = weight_ih.unsqueeze(dim=0).tile(bs, 1, 1)
    bw_hh = weight_hh.unsqueeze(dim=0).tile(bs, 1, 1)

    # Cost is linear in seq_len: each step needs the previous hidden state.
    for t in range(seq_len):
        x_t = inp[:, t, :].unsqueeze(dim=2)  # [bs, input_size, 1]
        # [bs, h_dim, input_size] @ [bs, input_size, 1] -> [bs, h_dim, 1] -> [bs, h_dim]
        wih_times_x = torch.bmm(bw_ih, x_t).squeeze(-1)
        # [bs, h_dim, h_dim] @ [bs, h_dim, 1] -> [bs, h_dim]
        whh_times_h = torch.bmm(bw_hh, h_prev.unsqueeze(2)).squeeze(-1)
        h_prev = torch.tanh(wih_times_x + whh_times_h + bias_ih + bias_hh)
        h_out[:, t, :] = h_prev

    # nn.RNN returns h_n as a 3-D tensor, so add the leading layer/direction axis.
    return h_out, h_prev.unsqueeze(0)


def test_rnn_impl(seed=0):
    """Compare rnn_impl with torch.nn.RNN on a small random batch.

    Prints the nn.RNN parameters, both output tensors, and two booleans from
    torch.allclose: whether the per-step outputs match and whether the final
    hidden states h_n match.

    Args:
        seed: passed to torch.manual_seed so the input and the nn.RNN weight
            initialisation are reproducible.
    """
    torch.manual_seed(seed)
    bs, seq_len = 2, 3
    input_size, hidden_size = 2, 3
    # Random input, batch-first: [bs, seq_len, input_size].
    inp = torch.randn(bs, seq_len, input_size)
    # Non-zero initial hidden state so the h_prev path is exercised from t = 0.
    h_prev = torch.randn(bs, hidden_size)
    rnn = nn.RNN(input_size, hidden_size, batch_first=True)
    # nn.RNN expects h_0 with a leading [num_layers * num_directions] axis.
    res1, h_n1 = rnn(inp, h_prev.unsqueeze(dim=0))

    # named_parameters() yields (name, tensor) pairs.
    for name, parameter in rnn.named_parameters():
        print(name, parameter)
    print("=========================")

    res2, h_n2 = rnn_impl(inp, rnn.weight_ih_l0, rnn.weight_hh_l0,
                          rnn.bias_ih_l0, rnn.bias_hh_l0, h_prev)

    print(res2)
    print(res1)
    print(torch.allclose(res1, res2))
    print(torch.allclose(h_n1, h_n2))

# Run the rnn_impl vs nn.RNN comparison defined above; its printed output follows.
test_rnn_impl()

weight_ih_l0 Parameter containing:
tensor([[ 0.1128,  0.0155],
        [-0.2187,  0.0653],
        [ 0.2091,  0.0609]], requires_grad=True)
weight_hh_l0 Parameter containing:
tensor([[-0.3954, -0.2865, -0.1061],
        [-0.1641,  0.5477, -0.1421],
        [-0.1960,  0.5420, -0.3079]], requires_grad=True)
bias_ih_l0 Parameter containing:
tensor([ 0.1051,  0.3406, -0.1874], requires_grad=True)
bias_hh_l0 Parameter containing:
tensor([-0.3666,  0.2226,  0.5342], requires_grad=True)
tensor([[[-0.3690,  0.6068,  0.0786],
         [-0.4361,  0.7997,  0.3363],
         [-0.3997,  0.8359,  0.5713]],

        [[-0.1562,  0.2542,  0.4664],
         [-0.4288,  0.7636,  0.1443],
         [-0.2091,  0.6744,  0.7675]]], grad_fn=<CopySlices>)
tensor([[[-0.3690,  0.6068,  0.0786],
         [-0.4361,  0.7997,  0.3363],
         [-0.3997,  0.8359,  0.5713]],

        [[-0.1562,  0.2542,  0.4664],
         [-0.4288,  0.7636,  0.1443],
         [-0.2091,  0.6744,  0.7675]]], grad_fn=<TransposeBackward1>)

In [37]:
def bi_rnn_impl(inp, weight_ih, weight_hh, bias_ih, bias_hh, h_prev,
                weight_ih_reverse, weight_hh_reverse, bias_ih_reverse,
                bias_hh_reverse, h_prev_reverse):
    """Single-layer bidirectional tanh RNN built from two calls to ``rnn_impl``.

    Args:
        inp: batch-first input of shape [bs, seq_len, input_size].
        weight_ih, weight_hh, bias_ih, bias_hh: forward-direction parameters.
        h_prev: forward initial hidden state, shape [bs, h_dim].
        weight_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse:
            backward-direction parameters.
        h_prev_reverse: backward initial hidden state, shape [bs, h_dim].

    Returns:
        (h_out, hn) where h_out is [bs, seq_len, 2 * h_dim]. Forward features
        are in [..., :h_dim] and backward features in [..., h_dim:], both
        aligned to the original time order. hn is [2, bs, h_dim]: hn[0] is the
        forward state after the last step, and hn[1] is the backward state
        after consuming the whole reversed sequence (original t = 0), matching
        the layout of nn.RNN's h_n.
    """
    # Forward direction over the sequence as given.
    forward_output, hn_forward = rnn_impl(inp, weight_ih, weight_hh,
                                          bias_ih, bias_hh, h_prev)
    # Backward direction: the same recurrence over the time-reversed input (dim 1).
    backward_output, hn_backward = rnn_impl(torch.flip(inp, [1]), weight_ih_reverse,
                                            weight_hh_reverse, bias_ih_reverse,
                                            bias_hh_reverse, h_prev_reverse)
    # Flip the backward features back so position t in both halves refers to
    # the same input step. torch.cat keeps the input dtype/device.
    h_out = torch.cat([forward_output, torch.flip(backward_output, [1])], dim=-1)
    # The final backward state is the last step of the *reversed* run
    # (original t = 0). It is NOT h_out[:, -1, h_dim:], which has only seen
    # the last input element.
    hn = torch.cat([hn_forward, hn_backward], dim=0)
    return h_out, hn


def test_bi_rnn_impl(seed=0):
    """Compare bi_rnn_impl with a bidirectional torch.nn.RNN on a small random batch.

    Prints the nn.RNN parameters, both output tensors, and two booleans from
    torch.allclose: whether the per-step outputs match and whether the final
    hidden states h_n ([2, bs, hidden_size]) match.

    Args:
        seed: passed to torch.manual_seed so the input and the nn.RNN weight
            initialisation are reproducible.
    """
    torch.manual_seed(seed)
    bs, seq_len = 2, 3
    input_size, hidden_size = 2, 3
    # Random input, batch-first: [bs, seq_len, input_size].
    inp = torch.randn(bs, seq_len, input_size)
    # Initial states for both directions: [num_directions=2, bs, hidden_size].
    # Non-zero so each direction's h_prev path is exercised from its first step.
    h_prev = torch.randn(2, bs, hidden_size)
    bi_rnn = nn.RNN(input_size, hidden_size, bidirectional=True, batch_first=True)
    res1, h_n1 = bi_rnn(inp, h_prev)

    # named_parameters() yields (name, tensor) pairs.
    for name, parameter in bi_rnn.named_parameters():
        print(name, parameter)
    print("=========================")

    res2, hn2 = bi_rnn_impl(inp,
                            bi_rnn.weight_ih_l0,
                            bi_rnn.weight_hh_l0,
                            bi_rnn.bias_ih_l0,
                            bi_rnn.bias_hh_l0,
                            h_prev[0],
                            bi_rnn.weight_ih_l0_reverse,
                            bi_rnn.weight_hh_l0_reverse,
                            bi_rnn.bias_ih_l0_reverse,
                            bi_rnn.bias_hh_l0_reverse,
                            h_prev[1])
    print(res1)
    print(res2)
    print(torch.allclose(res1, res2))
    print(torch.allclose(h_n1, hn2))


# Run the bi_rnn_impl vs bidirectional nn.RNN comparison defined above; its printed output follows.
test_bi_rnn_impl()

weight_ih_l0 Parameter containing:
tensor([[-0.0896,  0.2071],
        [ 0.1370,  0.1086],
        [ 0.0022, -0.2868]], requires_grad=True)
weight_hh_l0 Parameter containing:
tensor([[ 0.2489,  0.5583,  0.0484],
        [ 0.0950, -0.0650, -0.5075],
        [-0.2816,  0.3662,  0.2736]], requires_grad=True)
bias_ih_l0 Parameter containing:
tensor([0.3195, 0.5765, 0.5038], requires_grad=True)
bias_hh_l0 Parameter containing:
tensor([ 0.0648, -0.1833, -0.1685], requires_grad=True)
weight_ih_l0_reverse Parameter containing:
tensor([[-0.4970, -0.4178],
        [ 0.0177, -0.1048],
        [ 0.5066,  0.3071]], requires_grad=True)
weight_hh_l0_reverse Parameter containing:
tensor([[ 0.4729,  0.2107,  0.0673],
        [ 0.2862,  0.3941,  0.4727],
        [-0.3802,  0.0082,  0.1022]], requires_grad=True)
bias_ih_l0_reverse Parameter containing:
tensor([-0.1800,  0.1912, -0.2034], requires_grad=True)
bias_hh_l0_reverse Parameter containing:
tensor([-0.4098, -0.2872, -0.2664], requires_grad=True)
t