## LSTM与LSTMP

In [1]:
import torch
import torch.nn as nn

定义常量

In [2]:
# Batch size, sequence length, input feature size, hidden (cell) size.
bs, T, i_size, h_size = 2, 3, 4, 5
# Projected hidden size, used only by the LSTMP cells further down.
proj_size = 3

input = torch.randn(bs, T, i_size)  # input sequence, [bs, T, i_size]
c0 = torch.randn(bs, h_size)  # initial cell state; a fixed input, not trained
h0 = torch.randn(bs, h_size)  # initial hidden state for the plain LSTM

### 调用官方API(LSTM)

In [3]:
# Reference implementation: PyTorch LSTM fed with batch-first input.
lstm_layer = nn.LSTM(i_size, h_size, batch_first=True)

# The initial states get a leading axis via unsqueeze(0); the returned final
# states carry that same leading axis (see the printed shapes).
output, (h_final, c_final) = lstm_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))

print(output)
print(output.shape, h_final.shape, c_final.shape)

tensor([[[-0.7186,  0.1966,  0.6559, -0.1657,  0.0615],
         [-0.2585, -0.0097,  0.3003, -0.1403,  0.0449],
         [-0.3948,  0.0111,  0.1536, -0.1586,  0.0535]],

        [[ 0.0563,  0.0184,  0.0143,  0.1093,  0.3215],
         [ 0.3571,  0.1254,  0.3315,  0.0475, -0.1405],
         [ 0.1477,  0.1639,  0.0257,  0.0627, -0.1628]]],
       grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 5]) torch.Size([1, 2, 5]) torch.Size([1, 2, 5])


In [4]:
# List every parameter of the official layer by name, together with its shape.
for k, v in lstm_layer.named_parameters():
    print(k, v.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])


### 实现LSTM模型

In [5]:
def lstm_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh):
    """Run a single-layer, unidirectional LSTM over a batch-first sequence.

    The four gates are expected stacked along dim 0 of every parameter in
    the order input (i), forget (f), cell (g), output (o); the stacked
    pre-activations are split in that order below.

    Args:
        input: Input sequence of shape [bs, T, i_size].
        initial_states: Tuple ``(h0, c0)``, each of shape [bs, h_size].
        w_ih: Input-to-hidden weights, shape [4 * h_size, i_size].
        w_hh: Hidden-to-hidden weights, shape [4 * h_size, h_size].
        b_ih: Input-to-hidden bias, shape [4 * h_size].
        b_hh: Hidden-to-hidden bias, shape [4 * h_size].

    Returns:
        A tuple ``(output, (h_last, c_last))``. ``output`` has shape
        [bs, T, h_size] and holds the hidden state of every time step;
        ``h_last`` and ``c_last`` are the states after the final step,
        each of shape [bs, h_size]. For ``T == 0`` the initial states are
        returned unchanged.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 4

    # Allocate from `input` so the result keeps its dtype and device instead
    # of silently falling back to the default float32 CPU tensor.
    output = input.new_zeros(bs, T, h_size)

    # The weights are shared by every sample, so a plain matmul against the
    # transposed weight replaces tiling them `bs` times for bmm.
    w_ih_t = w_ih.t()  # [i_size, 4 * h_size]
    w_hh_t = w_hh.t()  # [h_size, 4 * h_size]

    for t in range(T):
        x = input[:, t, :]  # input vector of the current step, [bs, i_size]

        # Pre-activations of all four gates at once, [bs, 4 * h_size].
        gates = x @ w_ih_t + prev_h @ w_hh_t + b_ih + b_hh
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

        i_t = torch.sigmoid(i_gate)  # input gate
        f_t = torch.sigmoid(f_gate)  # forget gate
        g_t = torch.tanh(g_gate)  # candidate cell values
        o_t = torch.sigmoid(o_gate)  # output gate

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

In [6]:
# Run the hand-written forward pass with the official layer's own parameters;
# the printed values can be compared with the nn.LSTM output printed earlier.
output_custom, (h_final_custom, c_final_custom) = lstm_forward(input, (h0, c0), lstm_layer.weight_ih_l0,
                                                               lstm_layer.weight_hh_l0, lstm_layer.bias_ih_l0,
                                                               lstm_layer.bias_hh_l0)
print(output_custom)

tensor([[[-0.7186,  0.1966,  0.6559, -0.1657,  0.0615],
         [-0.2585, -0.0097,  0.3003, -0.1403,  0.0449],
         [-0.3948,  0.0111,  0.1536, -0.1586,  0.0535]],

        [[ 0.0563,  0.0184,  0.0143,  0.1093,  0.3215],
         [ 0.3571,  0.1254,  0.3315,  0.0475, -0.1405],
         [ 0.1477,  0.1639,  0.0257,  0.0627, -0.1628]]], grad_fn=<CopySlices>)


### 调用官方API(LSTMP,带 proj_size)

In [7]:
# With a projection the hidden state has proj_size features, so h0 is redrawn
# with that width while c0 keeps h_size features.
# NOTE: this rebinds the h0 used by the plain-LSTM cells; re-run the constants
# cell before re-running those cells.
h0 = torch.randn(bs, proj_size)
# Reference implementation: PyTorch LSTM with a projected hidden state (LSTMP).
lstmp_layer = nn.LSTM(i_size, h_size, batch_first=True, proj_size=proj_size)

output, (h_final, c_final) = lstmp_layer(input, (h0.unsqueeze(0), c0.unsqueeze(0)))

print(output)
print(output.shape, h_final.shape, c_final.shape)

tensor([[[ 0.0429, -0.0155, -0.1125],
         [ 0.0399, -0.0240, -0.1188],
         [ 0.0226,  0.0233, -0.0486]],

        [[ 0.0499,  0.1909,  0.1218],
         [-0.0208,  0.0090,  0.1503],
         [ 0.0155,  0.0440,  0.1079]]], grad_fn=<TransposeBackward0>)
torch.Size([2, 3, 3]) torch.Size([1, 2, 3]) torch.Size([1, 2, 5])


In [8]:
# List every parameter of the projected layer by name, together with its shape.
for k, v in lstmp_layer.named_parameters():
    print(k, v.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])


### 实现LSTMP

In [9]:

def lstmp_forward(input, initial_states, w_ih, w_hh, b_ih, b_hh, w_hr = None):
    """Run a single-layer LSTM, optionally with a projected hidden state (LSTMP).

    The four gates are expected stacked along dim 0 of every gate parameter
    in the order input (i), forget (f), cell (g), output (o). When ``w_hr``
    is given, the hidden state of each step is multiplied by ``w_hr`` before
    it is emitted and fed back, so the recurrent width becomes ``p_size``
    while the cell state keeps ``h_size`` features.

    Args:
        input: Input sequence of shape [bs, T, i_size].
        initial_states: Tuple ``(h0, c0)``. ``c0`` has shape [bs, h_size];
            ``h0`` has shape [bs, p_size] when ``w_hr`` is given, otherwise
            [bs, h_size].
        w_ih: Input-to-hidden weights, shape [4 * h_size, i_size].
        w_hh: Hidden-to-hidden weights, shape [4 * h_size, p_size] when
            ``w_hr`` is given, otherwise [4 * h_size, h_size].
        b_ih: Input-to-hidden bias, shape [4 * h_size].
        b_hh: Hidden-to-hidden bias, shape [4 * h_size].
        w_hr: Optional projection weights, shape [p_size, h_size]. ``None``
            disables the projection.

    Returns:
        A tuple ``(output, (h_last, c_last))``. ``output`` has shape
        [bs, T, p_size] with projection and [bs, T, h_size] without;
        ``h_last`` is the hidden state after the final step (same width as
        ``output``) and ``c_last`` is the final cell state, [bs, h_size].
        For ``T == 0`` the initial states are returned unchanged.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    h_size = w_ih.shape[0] // 4
    output_size = h_size if w_hr is None else w_hr.shape[0]

    # Allocate from `input` so the result keeps its dtype and device instead
    # of silently falling back to the default float32 CPU tensor.
    output = input.new_zeros(bs, T, output_size)

    # The weights are shared by every sample, so plain matmuls against the
    # transposed weights replace tiling them `bs` times for bmm. Everything
    # stays local: no module-level state is written.
    w_ih_t = w_ih.t()  # [i_size, 4 * h_size]
    w_hh_t = w_hh.t()  # [p_size or h_size, 4 * h_size]
    w_hr_t = None if w_hr is None else w_hr.t()  # [h_size, p_size]

    for t in range(T):
        x = input[:, t, :]  # input vector of the current step, [bs, i_size]

        # Pre-activations of all four gates at once, [bs, 4 * h_size].
        gates = x @ w_ih_t + prev_h @ w_hh_t + b_ih + b_hh
        i_gate, f_gate, g_gate, o_gate = gates.chunk(4, dim=-1)

        i_t = torch.sigmoid(i_gate)  # input gate
        f_t = torch.sigmoid(f_gate)  # forget gate
        g_t = torch.tanh(g_gate)  # candidate cell values
        o_t = torch.sigmoid(o_gate)  # output gate

        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)  # [bs, h_size]

        if w_hr_t is not None:  # apply the projection
            prev_h = prev_h @ w_hr_t  # [bs, p_size]

        output[:, t, :] = prev_h

    return output, (prev_h, prev_c)

In [10]:
# Run the hand-written LSTMP forward pass with the official layer's own
# parameters (including the projection weight weight_hr_l0); the printed values
# can be compared with the projected nn.LSTM output printed earlier.
output_custom, (h_final_custom, c_final_custom) = lstmp_forward(input, (h0, c0), lstmp_layer.weight_ih_l0,
                                                               lstmp_layer.weight_hh_l0, lstmp_layer.bias_ih_l0,
                                                               lstmp_layer.bias_hh_l0, lstmp_layer.weight_hr_l0)
print(output_custom)

tensor([[[ 0.0429, -0.0155, -0.1125],
         [ 0.0399, -0.0240, -0.1188],
         [ 0.0226,  0.0233, -0.0486]],

        [[ 0.0499,  0.1909,  0.1218],
         [-0.0208,  0.0090,  0.1503],
         [ 0.0155,  0.0440,  0.1079]]], grad_fn=<CopySlices>)
