In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [20]:
def lstm_forward(inp, initial_states, w_ih, w_hh, b_ih, b_hh):
    """
    Hand-written forward pass of a single-layer, unidirectional, batch-first LSTM.

    The four gates are stored stacked along dim 0 of the weights/biases in the
    order [i, f, g, o], each occupying h_dim rows.

    Args:
        inp: input sequence, [bs, T, input_size].
        initial_states: tuple (h0, c0), each [bs, h_dim].
        w_ih: input-to-hidden weights, [4 * h_dim, input_size].
        w_hh: hidden-to-hidden weights, [4 * h_dim, h_dim].
        b_ih: input-to-hidden bias, [4 * h_dim].
        b_hh: hidden-to-hidden bias, [4 * h_dim].

    Returns:
        output: hidden state of every time step, [bs, T, h_dim]. Its dtype and
            device follow the computation (no separately allocated buffer).
        (h_T, c_T): states after the last step, each [bs, h_dim]. When T == 0
            these are the initial states unchanged.
    """
    prev_h, prev_c = initial_states
    bs, seq_len, _ = inp.shape
    h_size = w_ih.shape[0] // 4  # w_ih is [4 * h_dim, input_size]

    # The input contribution does not depend on the recurrence, so compute it
    # for all time steps with one matmul instead of once per step. A plain
    # matmul broadcasts over the batch; no per-sample copy of the weights.
    x_gates = inp @ w_ih.T + b_ih  # [bs, T, 4 * h_dim]

    steps = []
    # Walk over the time axis.
    for t in range(seq_len):
        gates = x_gates[:, t, :] + (prev_h @ w_hh.T + b_hh)  # [bs, 4 * h_dim]
        # Quarters along the last dim are the i, f, g, o pre-activations.
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)
        i_t = torch.sigmoid(i_pre)
        f_t = torch.sigmoid(f_pre)
        g_t = torch.tanh(g_pre)
        o_t = torch.sigmoid(o_pre)

        # Cell state update.
        prev_c = f_t * prev_c + i_t * g_t

        # Hidden state update.
        prev_h = o_t * torch.tanh(prev_c)

        steps.append(prev_h)

    if steps:
        output = torch.stack(steps, dim=1)  # [bs, T, h_dim]
    else:
        # Empty sequence: nothing to stack, return an empty tensor of the right shape.
        output = inp.new_empty(bs, 0, h_size)

    return output, (prev_h, prev_c)


def test_lstm_impl():
    """
    Check the hand-written lstm_forward against nn.LSTM on random data.

    Prints the parameter shapes of the reference layer, the allclose result and
    both output tensors, then asserts that the per-step outputs and the final
    (h, c) states agree, so a mismatch fails the cell instead of only printing
    False.
    """
    bs, t, i_size, h_size = 2, 3, 4, 5
    inp = torch.randn(bs, t, i_size)
    # Initial states are plain tensors; they are not trained.
    c0 = torch.randn(bs, h_size)
    h0 = torch.randn(bs, h_size)

    # Reference: the official API. nn.LSTM expects states shaped
    # [num_layers, bs, h_dim], hence the unsqueeze(0).
    lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)
    output, (h_n, c_n) = lstm_layer(inp, (h0.unsqueeze(0), c0.unsqueeze(0)))
    for k, v in lstm_layer.named_parameters():
        print(k, "# #", v.shape)

    print("++++++++++++++++++++++++++++++++++++++")

    w_ih = lstm_layer.weight_ih_l0
    w_hh = lstm_layer.weight_hh_l0
    b_ih = lstm_layer.bias_ih_l0
    b_hh = lstm_layer.bias_hh_l0

    output2, (h_n2, c_n2) = lstm_forward(inp, (h0, c0), w_ih, w_hh, b_ih, b_hh)
    print(torch.allclose(output2, output))
    print(output)
    print(output2)

    # Fail loudly on disagreement. atol is sized for float32 round-off, since
    # the two implementations do not add the terms in the same order.
    assert torch.allclose(output2, output, atol=1e-6), "per-step outputs differ from nn.LSTM"
    assert torch.allclose(h_n2, h_n.squeeze(0), atol=1e-6), "final hidden state differs from nn.LSTM"
    assert torch.allclose(c_n2, c_n.squeeze(0), atol=1e-6), "final cell state differs from nn.LSTM"

# Run the comparison against nn.LSTM; the function prints the reference layer's
# parameter shapes, the allclose result and both output tensors.
test_lstm_impl()

weight_ih_l0 # # torch.Size([20, 4])
weight_hh_l0 # # torch.Size([20, 5])
bias_ih_l0 # # torch.Size([20])
bias_hh_l0 # # torch.Size([20])
++++++++++++++++++++++++++++++++++++++
True
tensor([[[-0.0253, -0.2863,  0.0269, -0.1507, -0.1089],
         [ 0.0480, -0.0321,  0.0813, -0.1903, -0.1225],
         [ 0.0992, -0.1596,  0.1424, -0.2368, -0.0450]],

        [[-0.2811,  0.2649, -0.1801, -0.1610,  0.6522],
         [-0.1713,  0.2637, -0.0560, -0.0387,  0.2390],
         [-0.0578,  0.1238, -0.0354, -0.2534, -0.0732]]],
       grad_fn=<TransposeBackward0>)
tensor([[[-0.0253, -0.2863,  0.0269, -0.1507, -0.1089],
         [ 0.0480, -0.0321,  0.0813, -0.1903, -0.1225],
         [ 0.0992, -0.1596,  0.1424, -0.2368, -0.0450]],

        [[-0.2811,  0.2649, -0.1801, -0.1610,  0.6522],
         [-0.1713,  0.2637, -0.0560, -0.0387,  0.2390],
         [-0.0578,  0.1238, -0.0354, -0.2534, -0.0732]]], grad_fn=<CopySlices>)


In [33]:
def lstm_forward(inp, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    """
    Hand-written forward pass of a single-layer, unidirectional, batch-first
    LSTM, optionally with a hidden-state projection (LSTMP).

    This redefines the projection-less lstm_forward from the earlier cell; with
    w_hr=None it computes the same thing.

    The four gates are stored stacked along dim 0 of the weights/biases in the
    order [i, f, g, o], each occupying h_dim rows.

    Args:
        inp: input sequence, [bs, T, input_size].
        initial_states: tuple (h0, c0). c0 is [bs, h_dim]; h0 is [bs, h_dim]
            without projection and [bs, p_dim] with projection.
        w_ih: input-to-hidden weights, [4 * h_dim, input_size].
        w_hh: hidden-to-hidden weights, [4 * h_dim, h_dim] without projection,
            [4 * h_dim, p_dim] with projection (the projected state is what
            gets fed back).
        b_ih: input-to-hidden bias, [4 * h_dim].
        b_hh: hidden-to-hidden bias, [4 * h_dim].
        w_hr: optional projection weights, [p_dim, h_dim]. When not None the
            hidden state of every step is projected down to p_dim.

    Returns:
        output: hidden state of every time step, [bs, T, p_dim] with projection,
            otherwise [bs, T, h_dim]. Its dtype and device follow the
            computation (no separately allocated buffer).
        (h_T, c_T): states after the last step. When T == 0 these are the
            initial states unchanged.
    """
    prev_h, prev_c = initial_states
    bs, seq_len, _ = inp.shape
    h_size = w_ih.shape[0] // 4  # w_ih is [4 * h_dim, input_size]
    output_size = h_size if w_hr is None else w_hr.shape[0]

    # The input contribution does not depend on the recurrence, so compute it
    # for all time steps with one matmul instead of once per step. A plain
    # matmul broadcasts over the batch; no per-sample copy of the weights.
    x_gates = inp @ w_ih.T + b_ih  # [bs, T, 4 * h_dim]

    steps = []
    # Walk over the time axis.
    for t in range(seq_len):
        gates = x_gates[:, t, :] + (prev_h @ w_hh.T + b_hh)  # [bs, 4 * h_dim]
        # Quarters along the last dim are the i, f, g, o pre-activations.
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)
        i_t = torch.sigmoid(i_pre)
        f_t = torch.sigmoid(f_pre)
        g_t = torch.tanh(g_pre)
        o_t = torch.sigmoid(o_pre)

        # Cell state update.
        prev_c = f_t * prev_c + i_t * g_t

        # Hidden state update.
        prev_h = o_t * torch.tanh(prev_c)  # [bs, h_dim]

        # Compress the hidden state when a projection matrix is given.
        if w_hr is not None:
            prev_h = prev_h @ w_hr.T  # [bs, p_dim]

        steps.append(prev_h)

    if steps:
        output = torch.stack(steps, dim=1)  # [bs, T, output_size]
    else:
        # Empty sequence: nothing to stack, return an empty tensor of the right shape.
        output = inp.new_empty(bs, 0, output_size)

    return output, (prev_h, prev_c)

def test_lstmp_impl():
    """
    Check lstm_forward with projection against nn.LSTM(proj_size=...) on random data.

    Prints the parameter shapes of the reference layer, the allclose result and
    both output shapes, then asserts that the per-step outputs and the final
    (h, c) states agree, so a mismatch fails the cell instead of only printing
    False.
    """
    bs, t, i_size, h_size = 2, 3, 4, 5
    proj_size = 3
    inp = torch.randn(bs, t, i_size)
    # Initial states are plain tensors; they are not trained.
    c0 = torch.randn(bs, h_size)
    # With projection the hidden state lives in the projected space.
    h0 = torch.randn(bs, proj_size)

    # Reference: the official API. nn.LSTM expects states shaped
    # [num_layers, bs, dim], hence the unsqueeze(0).
    lstm_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
    output, (h_n, c_n) = lstm_layer(inp, (h0.unsqueeze(0), c0.unsqueeze(0)))
    for k, v in lstm_layer.named_parameters():
        print(k, "# #", v.shape)

    print("++++++++++++++++++++++++++++++++++++++")

    w_ih = lstm_layer.weight_ih_l0
    w_hh = lstm_layer.weight_hh_l0 # [4 * h_dim, p_size]: second dim shrinks from h_dim to p_size
    b_ih = lstm_layer.bias_ih_l0
    b_hh = lstm_layer.bias_hh_l0
    w_hr = lstm_layer.weight_hr_l0 # [p_size, h_dim]
    output2, (h_n2, c_n2) = lstm_forward(inp, (h0, c0), w_ih, w_hh, b_ih, b_hh, w_hr)
    print(torch.allclose(output2, output))
    print(output.shape)
    print(output2.shape) # [bs, seq, p_size]

    # Fail loudly on disagreement. atol is sized for float32 round-off, since
    # the two implementations do not add the terms in the same order.
    assert torch.allclose(output2, output, atol=1e-6), "per-step outputs differ from nn.LSTM"
    assert torch.allclose(h_n2, h_n.squeeze(0), atol=1e-6), "final hidden state differs from nn.LSTM"
    assert torch.allclose(c_n2, c_n.squeeze(0), atol=1e-6), "final cell state differs from nn.LSTM"

# Run the projection (LSTMP) comparison against nn.LSTM; the function prints the
# reference layer's parameter shapes, the allclose result and both output shapes.
test_lstmp_impl()

weight_ih_l0 # # torch.Size([20, 4])
weight_hh_l0 # # torch.Size([20, 3])
bias_ih_l0 # # torch.Size([20])
bias_hh_l0 # # torch.Size([20])
weight_hr_l0 # # torch.Size([3, 5])
++++++++++++++++++++++++++++++++++++++
True
torch.Size([2, 3, 3])
torch.Size([2, 3, 3])
