In [1]:
import torch
import torch.nn as nn

In [14]:
# Problem sizes: batch, sequence length, input features, hidden (cell) size.
bs = 2
T = 3
i_size = 4
h_size = 5
# Size of the projected hidden state (LSTM with projection).
proj_size = 3

# Random input sequence, laid out batch-first: (bs, T, i_size).
input = torch.randn(bs, T, i_size)
# Initial states; plain tensors, not trained parameters.
# The cell state keeps the full hidden size, the hidden state uses proj_size.
c0 = torch.randn(bs, h_size)
h0 = torch.randn(bs, proj_size)

# Call the official LSTM API.
lstm_layer = nn.LSTM(
    input_size=i_size,
    hidden_size=h_size,
    batch_first=True,
    proj_size=proj_size,
)
# nn.LSTM expects initial states with a leading axis, hence unsqueeze(0).
output, (h_final, c_final) = lstm_layer(
    input,
    (h0.unsqueeze(0), c0.unsqueeze(0)),
)
print(output)

# List every parameter of the layer together with its shape.
for name, param in lstm_layer.named_parameters():
    print(name, param.shape)

tensor([[[ 0.0009, -0.0547, -0.2468],
         [-0.1161, -0.0975, -0.1596],
         [-0.1372, -0.1323, -0.1379]],

        [[-0.0529, -0.0406, -0.0500],
         [-0.0169,  0.0008,  0.0691],
         [-0.1009, -0.0772, -0.0245]]], grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])


In [11]:
# 自己写一个LSTM
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr = None):
    """Hand-written forward pass of a single-layer, batch-first LSTM.

    The gate blocks are read from the stacked weights in the order
    input (i), forget (f), cell (g), output (o).

    Args:
        input: Input sequence of shape (bs, T, i_size).
        initial_states: Tuple ``(h0, c0)``. ``c0`` has shape (bs, h_size);
            ``h0`` has shape (bs, h_size), or (bs, p_size) when ``w_hr`` is
            given. Neither carries a leading layer axis.
        w_ih: Input-to-hidden weights, shape (4 * h_size, i_size).
        w_hh: Hidden-to-hidden weights, shape (4 * h_size, h_size), or
            (4 * h_size, p_size) when ``w_hr`` is given.
        b_ih: Input-to-hidden bias, shape (4 * h_size,).
        b_hh: Hidden-to-hidden bias, shape (4 * h_size,).
        w_hr: Optional projection weights of shape (p_size, h_size). When
            given, the hidden state is projected to p_size at every step.

    Returns:
        ``(output, (h_final, c_final))`` where ``output`` has shape
        (bs, T, output_size) with ``output_size`` equal to p_size if ``w_hr``
        is given and h_size otherwise; ``h_final`` and ``c_final`` are the
        hidden and cell state after the last time step.
    """
    h0, c0 = initial_states
    bs, T, _ = input.shape
    # w_ih stacks the four gate blocks along dim 0, so the hidden size is a
    # quarter of that dimension. Dividing by i_size is only correct when
    # i_size happens to be 4.
    h_size = w_ih.shape[0] // 4

    output_size = h_size if w_hr is None else w_hr.shape[0]
    # Allocate the output sequence with the input's dtype and device.
    output = torch.zeros(bs, T, output_size, dtype=input.dtype, device=input.device)

    prev_h = h0
    prev_c = c0
    for t in range(T):
        x = input[:, t, :]  # input vector at the current step, (bs, i_size)

        # Pre-activations of all four gates at once: (bs, 4 * h_size).
        # The weights broadcast over the batch, so no tiling is needed.
        gates = x @ w_ih.t() + prev_h @ w_hh.t() + b_ih + b_hh
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)

        i_t = torch.sigmoid(i_pre)  # input gate
        f_t = torch.sigmoid(f_pre)  # forget gate
        g_t = torch.tanh(g_pre)     # cell candidate
        o_t = torch.sigmoid(o_pre)  # output gate

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)  # (bs, h_size)

        if w_hr is not None:
            # Project the hidden state down to (bs, p_size).
            prev_h = prev_h @ w_hr.t()

        output[:, t, :] = prev_h
    return output, (prev_h, prev_c)


# Run the hand-written forward pass with the parameters of the official layer.
output_custom, (h_final_custom, c_final_custom) = lstm_forward(
    input,
    (h0, c0),
    lstm_layer.weight_ih_l0,
    lstm_layer.weight_hh_l0,
    lstm_layer.bias_ih_l0,
    lstm_layer.bias_hh_l0,
    lstm_layer.weight_hr_l0,
)

# Print the official result followed by the custom one for comparison.
print(output)
print(output_custom)

tensor([[[ 0.3073, -0.0818, -0.3247, -0.1339, -0.0454],
         [ 0.0115, -0.0052, -0.0490, -0.1998,  0.1117],
         [ 0.2132,  0.0567, -0.0906, -0.2161,  0.0475]],

        [[ 0.3088,  0.1451,  0.5611,  0.0116,  0.1632],
         [ 0.1782,  0.0421,  0.2567, -0.0816,  0.2803],
         [ 0.0659,  0.0565, -0.1242, -0.2200,  0.0616]]],
       grad_fn=<TransposeBackward0>)
tensor([[[ 0.3073, -0.0818, -0.3247, -0.1339, -0.0454],
         [ 0.0115, -0.0052, -0.0490, -0.1998,  0.1117],
         [ 0.2132,  0.0567, -0.0906, -0.2161,  0.0475]],

        [[ 0.3088,  0.1451,  0.5611,  0.0116,  0.1632],
         [ 0.1782,  0.0421,  0.2567, -0.0816,  0.2803],
         [ 0.0659,  0.0565, -0.1242, -0.2200,  0.0616]]], grad_fn=<CopySlices>)
