In [1]:
import torch
import torch.nn as nn

# 定义常量
bs, T, i_size, h_size = 2, 3, 4, 5 #batch_size  Time  input_size hidden_size
proj_size = 3 #投影大小
input = torch.randn(bs, T, i_size)
print(input.shape)
c0 = torch.randn(bs, h_size) # 初始值 不参与训练
h0 = torch.randn(bs, proj_size) # 初始值 不参与训练

# 调用官方API
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)
output, (h_n, c_n) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))
print(output)

for k, v in lstm_layer.named_parameters():
    print(k, v.shape)

torch.Size([2, 3, 4])
tensor([[[-0.0338,  0.1641,  0.2356],
         [-0.1276,  0.1819,  0.3085],
         [-0.0627,  0.1581,  0.1736]],

        [[-0.0144,  0.1480, -0.0043],
         [ 0.0130,  0.1638,  0.0774],
         [ 0.0719,  0.1236,  0.0584]]], grad_fn=<TransposeBackward0>)
weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])


In [11]:
# 自己写一个LSTM
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr=None):
    h0, c0 = initial_states # 初始状态
    bs, T, i_size = input.shape
    h_size = w_ih.shape[0] # 20
    h_size = h_size // 4 # 分为4份
    
    prev_h = h0 # 不停更新h0 c0
    prev_c = c0 
    batch_w_ih = w_ih.unsqueeze(0).tile(bs, 1, 1) # bs * 4*h_size * i_size
    batch_w_hh = w_hh.unsqueeze(0).tile(bs, 1, 1) # bs * 4*h_size * h_size
    
    if w_hr is not None:
        p_size, _ = w_hr.shape
        output_size = p_size
        batch_w_hr = w_hr.unsqueeze(0).tile(bs, 1, 1)
    else:
        output_size = h_size
    output = torch.zeros(bs, T, output_size) # 输出序列
    
    for t in range(T):
        x = input[:, t, :] # 当前时刻的输入向量 bs * i_size
        w_times_x = torch.bmm(batch_w_ih, x.unsqueeze(-1)) # bs * 4*h_size * 1
        w_times_x = w_times_x.squeeze(-1) # bs * 4*h_size
        
        w_times_h = torch.bmm(batch_w_hh, prev_h.unsqueeze(-1)) # bs * 4*h_size * 1
        w_times_h = w_times_h.squeeze(-1) # bs * 4*h_size
        
        # 分别各取四分之一
        i_t = torch.sigmoid(w_times_x[:, :h_size] + w_times_h[:, :h_size] + b_ih[:h_size] + b_hh[:h_size]) # 取前四分之一
        f_t = torch.sigmoid(w_times_x[:, h_size: 2*h_size] + w_times_h[:, h_size: 2*h_size] + b_ih[h_size: 2*h_size] + b_hh[h_size: 2*h_size])
        g_t = torch.tanh(w_times_x[:, 2*h_size: 3*h_size] + w_times_h[:, 2*h_size: 3*h_size] + b_ih[2*h_size: 3*h_size] + b_hh[2*h_size: 3*h_size])
        o_t = torch.sigmoid(w_times_x[:, 3*h_size:] + w_times_h[:, 3*h_size:] + b_ih[3*h_size:] + b_hh[3*h_size:])
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)
        
        if w_hr is not None:  # 进行projection 对维度进行压缩
            prev_h = torch.bmm(batch_w_hr, prev_h.unsqueeze(-1))
            prev_h = prev_h.squeeze(-1)
        
        output[:, t, :] = prev_h
        
    return output, (prev_h, prev_c)
 
    
    
output_custom, (h_0_custom, c_0_custom) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0, lstm_layer.weight_hh_l0,\
                                                       lstm_layer.bias_ih_l0, lstm_layer.bias_hh_l0, lstm_layer.weight_hr_l0)

print(output_custom)

tensor([[[-2.0233e-01, -1.6884e-01, -8.2945e-02],
         [ 7.6853e-03,  8.5598e-02, -6.2300e-02],
         [ 1.3547e-02, -2.7407e-03, -1.4973e-04]],

        [[-1.3269e-01, -9.3741e-02, -8.3748e-02],
         [-2.5617e-02,  1.6450e-02, -2.7930e-02],
         [ 4.7940e-02,  9.6475e-02, -5.1577e-02]]], grad_fn=<CopySlices>)
