# RNN 的原理及其手写复现

+ 视频：[29、PyTorch RNN的原理及其手写复现](https://www.bilibili.com/video/BV13i4y1R7jB/)
+ 视频：[30、PyTorch LSTM和LSTMP的原理及其手写复现](https://www.bilibili.com/video/BV1zq4y1m7aH/)
+ [Gated RNN | yubinCloud](https://yubincloud.github.io/notebook/pages/nlp/gated-rnn/)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from loguru import logger
from typing import Tuple

## 1. PyTorch 的使用示例

+ [PyTorch RNN 官方文档](https://pytorch.org/docs/stable/generated/torch.nn.RNN.html)

### 1.1 单向、单层 RNN

![RNN 示例](../imgs/rnn.png)


In [2]:
# Hyper-parameters shared by the RNN demos
BATCH_SIZE = 2
SEQ_LEN = 4               # length of each input sequence
INPUT_FEATURE_SIZE = 5    # feature size of every time step
HIDDEN_SIZE = 3           # size of the hidden state

# Single-layer, unidirectional RNN; tensors are laid out as [B, T, feature]
single_rnn = nn.RNN(
    input_size=INPUT_FEATURE_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=1,
    batch_first=True,
)

nn.RNN 的输出：

+ `output`：各个时刻的 hidden state，在 `batch_first=True` 时 shape 为 [B, seq_len, num_directions * hidden]
    + 当使用双向时，在 output 最后一维的 num_directions * hidden 元素中，前 hidden 个属于前向 RNN 的结果，后 hidden 个属于反向 RNN 的结果
+ `final_state`：最后一个时刻的最终 hidden state，shape 为 [num_directions * num_layers, B, hidden]（不受 `batch_first` 影响）；当只有一层单向 RNN 时，它就等于 `output` 在最后一个时刻的结果

对于 many-to-many 的 task，往往是使用 output，比如词性标注任务；对于 many-to-one 的 task，往往是使用 final_state，比如文本分类任务。

In [3]:
# NOTE: `input` shadows the builtin input(); the name is kept because later cells use it.
input = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_FEATURE_SIZE)  # [batch_size, seq_len, feature_size]
output, final_state = single_rnn(input)
logger.info(f'output:\n{output}')  # [B, seq_len, num_directions * hidden]
# h_n keeps the layer/direction axis first even when batch_first=True
logger.info(f'final_state:\n{final_state}')  # [num_directions * num_layers, B, hidden]

2023-01-24 20:34:58.737 | INFO     | __main__:<module>:3 - output:
tensor([[[-0.6163, -0.0473,  0.4505],
         [-0.9223, -0.8221,  0.2352],
         [ 0.5787,  0.8265,  0.5624],
         [-0.3236, -0.0832,  0.7732]],

        [[-0.8881, -0.7515, -0.0158],
         [-0.0427,  0.5174,  0.1675],
         [-0.7324, -0.1763,  0.3568],
         [-0.6012, -0.1244,  0.0937]]], grad_fn=<TransposeBackward1>)
2023-01-24 20:34:58.740 | INFO     | __main__:<module>:4 - final_state:
tensor([[[-0.3236, -0.0832,  0.7732],
         [-0.6012, -0.1244,  0.0937]]], grad_fn=<StackBackward0>)


从上面的结果中可以看出，对于单层单向 RNN，最后时刻的 output 就等于最终的 hidden state。

### 1.2 双向、单层 RNN

主要是在实例化 `nn.RNN` 时设置 `bidirectional=True`。

In [4]:
# Same sizes as single_rnn, plus a second RNN that reads the sequence right-to-left
bi_rnn = nn.RNN(
    input_size=INPUT_FEATURE_SIZE,
    hidden_size=HIDDEN_SIZE,
    num_layers=1,
    batch_first=True,
    bidirectional=True,
)

In [5]:
# Run the bidirectional RNN on the same `input` tensor that was fed to single_rnn
output, final_state = bi_rnn(input)
logger.info(f'output:\n{output}')
# Last axis: the first `hidden` entries come from the forward RNN, the rest from the reverse RNN
logger.info(f'output shape: {output.shape}')  # [B, seq_len, num_directions * hidden]
logger.info(f'final_state:\n{final_state}')
logger.info(f'final_state shape: {final_state.shape}')  # [num_directions * num_layers, B, hidden]

2023-01-24 20:35:00.397 | INFO     | __main__:<module>:2 - output:
tensor([[[ 0.9410,  0.4722, -0.5096, -0.4159,  0.2891, -0.3208],
         [ 0.9068,  0.2262,  0.1875, -0.6178, -0.1427, -0.5753],
         [ 0.8753,  0.4265, -0.6604, -0.6597,  0.1480,  0.3321],
         [ 0.9202,  0.7293, -0.4818, -0.8707, -0.2412,  0.0443]],

        [[ 0.6369,  0.1251, -0.4725,  0.0477, -0.4647,  0.0939],
         [ 0.8985,  0.3284, -0.3878, -0.3990, -0.0277,  0.8199],
         [ 0.7848,  0.1551, -0.0625, -0.4547, -0.1110, -0.1604],
         [ 0.4576,  0.4690, -0.5791, -0.5404, -0.5316,  0.3072]]],
       grad_fn=<TransposeBackward1>)
2023-01-24 20:35:00.399 | INFO     | __main__:<module>:3 - output shape: torch.Size([2, 4, 6])
2023-01-24 20:35:00.402 | INFO     | __main__:<module>:4 - final_state:
tensor([[[ 0.9202,  0.7293, -0.4818],
         [ 0.4576,  0.4690, -0.5791]],

        [[-0.4159,  0.2891, -0.3208],
         [ 0.0477, -0.4647,  0.0939]]], grad_fn=<StackBackward0>)
2023-01-24 20:35:00.403

## 2. 单层单向 RNN 的逐行实现

$h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh})$

In [6]:
# Dump every learnable tensor of the official RNN (name followed by value)
for name_and_param in single_rnn.named_parameters():
    print(*name_and_param)

weight_ih_l0 Parameter containing:
tensor([[-0.4731, -0.2062,  0.1367, -0.3421, -0.4200],
        [-0.4495, -0.4266,  0.3260, -0.1161, -0.4985],
        [-0.0610,  0.1733, -0.3137, -0.5572, -0.4128]], requires_grad=True)
weight_hh_l0 Parameter containing:
tensor([[ 0.2235, -0.5123,  0.0136],
        [-0.4945,  0.2558, -0.3320],
        [ 0.5721, -0.4248, -0.3409]], requires_grad=True)
bias_ih_l0 Parameter containing:
tensor([ 0.2150, -0.2177,  0.1753], requires_grad=True)
bias_hh_l0 Parameter containing:
tensor([-0.5560,  0.4096,  0.4706], requires_grad=True)


### 2.1 逐行实现 RNN

In [7]:
# Step-by-step implementation of the forward pass of a single-layer, unidirectional RNN
def rnn_forward(
    input: Tensor,  # [B, T, input_size]
    weight_ih: Tensor,  # [hidden, input_size]
    weight_hh: Tensor,  # [hidden, hidden]
    bias_ih: Tensor,  # [hidden]
    bias_hh: Tensor,  # [hidden]
    h_prev: Tensor,  # initial hidden state, [B, hidden]
) -> Tuple[Tensor, Tensor]:
    """Run a single-layer, unidirectional tanh RNN over a batch-first sequence.

    Computes, one time step at a time,
    h_t = tanh(x_t W_ih^T + b_ih + h_{t-1} W_hh^T + b_hh).

    Returns:
        h_out: hidden states of every time step, [B, T, hidden], with the same
            dtype/device as ``input``.
        h_n: hidden state after the last step, [1, B, hidden] (the layout of
            nn.RNN's final state for one layer and one direction). When T == 0
            this is just the initial state.
    """
    bs, T, _ = input.shape
    h_dim = weight_ih.shape[0]  # W_ih is [hidden, input_size]

    # Follow the input's dtype/device; plain torch.zeros would force float32 on CPU
    h_out = input.new_zeros(bs, T, h_dim)
    for t in range(T):
        x = input[:, t, :]  # input features at step t, [B, input_size]
        # x @ W^T is the batched form of W @ x, so no per-step weight tiling is needed
        h_prev = torch.tanh(x @ weight_ih.T + bias_ih + h_prev @ weight_hh.T + bias_hh)
        h_out[:, t, :] = h_prev

    return h_out, h_prev.unsqueeze(0)

### 2.2 结果验证

通过与 PyTorch 官方实现的运算结果进行对比，验证 RNN 的实现

In [8]:
# Check rnn_forward against the official nn.RNN: same weights, same zero initial state
h_prev = torch.zeros(BATCH_SIZE, HIDDEN_SIZE)  # initial hidden state, [B, hidden]
output1, final_state1 = rnn_forward(
    input=input,
    weight_ih=single_rnn.weight_ih_l0,
    weight_hh=single_rnn.weight_hh_l0,
    bias_ih=single_rnn.bias_ih_l0,
    bias_hh=single_rnn.bias_hh_l0,
    h_prev=h_prev,
)
# nn.RNN takes h_0 with an extra leading axis (num_layers * num_directions = 1 here)
output2, final_state2 = single_rnn(input, h_prev[None])
logger.info(f'自己实现的 RNN 的 output:\n{output1}')
logger.info(f'PyTorch 的 RNN 的 output:\n{output2}')
logger.info(f'自己实现的 RNN 的 final_state:\n{final_state1}')
logger.info(f'PyTorch 的 RNN 的 final_state:\n{final_state2}')

2023-01-24 20:35:04.865 | INFO     | __main__:<module>:12 - 自己实现的 RNN 的 output:
tensor([[[-0.6163, -0.0473,  0.4505],
         [-0.9223, -0.8221,  0.2352],
         [ 0.5787,  0.8265,  0.5624],
         [-0.3236, -0.0832,  0.7732]],

        [[-0.8881, -0.7515, -0.0158],
         [-0.0427,  0.5174,  0.1675],
         [-0.7324, -0.1763,  0.3568],
         [-0.6012, -0.1244,  0.0937]]], grad_fn=<CopySlices>)
2023-01-24 20:35:04.869 | INFO     | __main__:<module>:13 - PyTorch 的 RNN 的 output:
tensor([[[-0.6163, -0.0473,  0.4505],
         [-0.9223, -0.8221,  0.2352],
         [ 0.5787,  0.8265,  0.5624],
         [-0.3236, -0.0832,  0.7732]],

        [[-0.8881, -0.7515, -0.0158],
         [-0.0427,  0.5174,  0.1675],
         [-0.7324, -0.1763,  0.3568],
         [-0.6012, -0.1244,  0.0937]]], grad_fn=<TransposeBackward1>)
2023-01-24 20:35:04.871 | INFO     | __main__:<module>:14 - 自己实现的 RNN 的 final_state:
tensor([[[-0.3236, -0.0832,  0.7732],
         [-0.6012, -0.1244,  0.0937]]], grad_

## 3. 单层双向 RNN 的逐行实现

### 3.1 逐行实现 RNN

In [9]:
def bidirectional_rnn_forward(
    input: Tensor,  # [B, T, input_size]
    weight_ih: Tensor,  # forward RNN, [hidden, input_size]
    weight_hh: Tensor,  # forward RNN, [hidden, hidden]
    bias_ih: Tensor,  # forward RNN, [hidden]
    bias_hh: Tensor,  # forward RNN, [hidden]
    h_prev: Tensor,  # initial hidden state of the forward RNN, [B, hidden]
    weihgt_ih_reverse: Tensor,  # (sic) misspelled name kept so keyword callers keep working
    weight_hh_reverse: Tensor,
    bias_ih_reverse: Tensor,
    bias_hh_reverse: Tensor,
    h_prev_reverse: Tensor  # initial hidden state of the reverse RNN, [B, hidden]
) -> Tuple[Tensor, Tensor]:
    """Run a single-layer bidirectional tanh RNN built from two calls to rnn_forward.

    Returns:
        h_out: [B, T, 2 * hidden]; along the last axis the first ``hidden``
            entries belong to the forward RNN and the last ``hidden`` to the
            reverse RNN, both aligned with input step t.
        h_n: [2, B, hidden]; index 0 is the forward RNN's final state, index 1
            the reverse RNN's final state (the state after it has consumed the
            first input step).
    """
    # Forward direction: read the sequence left-to-right
    forward_output, forward_final = rnn_forward(
        input, weight_ih, weight_hh, bias_ih, bias_hh, h_prev
    )
    # Reverse direction: run the same recurrence over the time-reversed sequence (flip dim=1)
    reverse_input = input.flip([1])
    backward_output, backward_final = rnn_forward(
        reverse_input, weihgt_ih_reverse, weight_hh_reverse, bias_ih_reverse, bias_hh_reverse, h_prev_reverse
    )

    # Flip the reverse outputs back so step t lines up with input step t; forward features go first
    h_out = torch.cat([forward_output, backward_output.flip([1])], dim=-1)

    # rnn_forward already returns each final state as [1, B, hidden]. Using those instead of
    # output[:, -1] also works when T == 0, where the final state is just the initial state.
    h_n = torch.cat([forward_final, backward_final], dim=0)  # [num_directions, B, hidden]
    return h_out, h_n

### 3.2 结果验证

In [10]:
# Dump the parameters of the bidirectional RNN; the *_reverse ones belong to the right-to-left RNN
for name_and_param in bi_rnn.named_parameters():
    print(*name_and_param)

weight_ih_l0 Parameter containing:
tensor([[-0.1291,  0.2647, -0.0030,  0.2523, -0.3660],
        [-0.4363,  0.1191, -0.2971,  0.3472,  0.2682],
        [ 0.5282, -0.0040,  0.0744, -0.3637, -0.3977]], requires_grad=True)
weight_hh_l0 Parameter containing:
tensor([[-0.0840,  0.3124, -0.3386],
        [ 0.4872, -0.4758, -0.2817],
        [-0.0568, -0.2534, -0.4144]], requires_grad=True)
bias_ih_l0 Parameter containing:
tensor([ 0.5236, -0.4694, -0.0271], requires_grad=True)
bias_hh_l0 Parameter containing:
tensor([0.5643, 0.0778, 0.0066], requires_grad=True)
weight_ih_l0_reverse Parameter containing:
tensor([[ 0.3810, -0.5524,  0.5134,  0.4822, -0.0038],
        [ 0.0877, -0.0310,  0.1023,  0.0720, -0.3541],
        [-0.4617,  0.1002,  0.5680,  0.2013,  0.3023]], requires_grad=True)
weight_hh_l0_reverse Parameter containing:
tensor([[ 0.1870,  0.1410,  0.2251],
        [-0.1406, -0.1604, -0.1103],
        [ 0.5239,  0.1003, -0.1858]], requires_grad=True)
bias_ih_l0_reverse Parameter cont

In [11]:
NUM_DIRECTIONS = 2  # bidirectional
# One zero initial state per direction: [num_directions, B, hidden]
h_prev = torch.zeros(NUM_DIRECTIONS, BATCH_SIZE, HIDDEN_SIZE)
output1, final_state1 = bidirectional_rnn_forward(
    input,
    # forward direction: weights, biases, initial state
    bi_rnn.weight_ih_l0, bi_rnn.weight_hh_l0,
    bi_rnn.bias_ih_l0, bi_rnn.bias_hh_l0,
    h_prev[0],
    # reverse direction: weights, biases, initial state
    bi_rnn.weight_ih_l0_reverse, bi_rnn.weight_hh_l0_reverse,
    bi_rnn.bias_ih_l0_reverse, bi_rnn.bias_hh_l0_reverse,
    h_prev[1],
)
# The official layer takes both directions' initial states stacked along dim 0
output2, final_state2 = bi_rnn(input, h_prev)

logger.info(f'自己实现的 RNN 的 output:\n{output1}')
logger.info(f'PyTorch 的 RNN 的 output:\n{output2}')
logger.info(f'自己实现的 RNN 的 final_state:\n{final_state1}')
logger.info(f'PyTorch 的 RNN 的 final_state:\n{final_state2}')

2023-01-24 20:35:08.178 | INFO     | __main__:<module>:18 - 自己实现的 RNN 的 output:
tensor([[[ 0.9410,  0.4722, -0.5096, -0.4159,  0.2891, -0.3208],
         [ 0.9068,  0.2262,  0.1875, -0.6178, -0.1427, -0.5753],
         [ 0.8753,  0.4265, -0.6604, -0.6597,  0.1480,  0.3321],
         [ 0.9202,  0.7293, -0.4818, -0.8707, -0.2412,  0.0443]],

        [[ 0.6369,  0.1251, -0.4725,  0.0477, -0.4647,  0.0939],
         [ 0.8985,  0.3284, -0.3878, -0.3990, -0.0277,  0.8199],
         [ 0.7848,  0.1551, -0.0625, -0.4547, -0.1110, -0.1604],
         [ 0.4576,  0.4690, -0.5791, -0.5404, -0.5316,  0.3072]]],
       grad_fn=<CopySlices>)
2023-01-24 20:35:08.182 | INFO     | __main__:<module>:19 - PyTorch 的 RNN 的 output:
tensor([[[ 0.9410,  0.4722, -0.5096, -0.4159,  0.2891, -0.3208],
         [ 0.9068,  0.2262,  0.1875, -0.6178, -0.1427, -0.5753],
         [ 0.8753,  0.4265, -0.6604, -0.6597,  0.1480,  0.3321],
         [ 0.9202,  0.7293, -0.4818, -0.8707, -0.2412,  0.0443]],

        [[ 0.6369,  0

## 4. LSTM 手写实现

![LSTM 示意图](../imgs/LSTM.png)

计算公式：

+ 输入门：$i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1}+b_{hi})$
+ 遗忘门：$f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1}+b_{hf})$
+ cell 门：$g_t = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1}+b_{hg})$
+ 输出门：$o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1}+b_{ho})$
+ 记忆单元的更新：$c_t = f_t \odot c_{t-1} + i_t \odot g_t$
+ 隐藏状态的更新：$h_t = o_t \odot \tanh(c_t)$

### 4.1 PyTorch 官方 API

+ [PyTorch LSTM 官方文档](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)

In [12]:
# Constants for the LSTM demos (SEQ_LEN and HIDDEN_SIZE are redefined here)
BATCH_SIZE = 2
SEQ_LEN = 3
INPUT_SIZE = 4
HIDDEN_SIZE = 5

input = torch.randn(BATCH_SIZE, SEQ_LEN, INPUT_SIZE)  # input sequence, [B, seq_len, input_size]
c0 = torch.randn(BATCH_SIZE, HIDDEN_SIZE)  # initial cell state; a plain tensor without requires_grad, so it is not trained
h0 = torch.randn(BATCH_SIZE, HIDDEN_SIZE)  # initial hidden state, [B, hidden]

In [22]:
# Build and call the official LSTM layer
lstm_layer = nn.LSTM(input_size=INPUT_SIZE, hidden_size=HIDDEN_SIZE, batch_first=True)
# nn.LSTM expects (h_0, c_0) with a leading num_layers * num_directions axis (= 1 here), hence [None]
output, (h_final, c_final) = lstm_layer(input, (h0[None], c0[None]))
logger.info(f'LSTM 的 output shape: {output.shape}')   # [B, seq_len, h_dim]
logger.info(f'LSTM 的 h_final shape: {h_final.shape}') # [1, B, h_dim]
logger.info(f'LSTM 的 c_final shape: {c_final.shape}') # [1, B, h_dim]

2023-01-24 20:56:39.577 | INFO     | __main__:<module>:7 - LSTM 的 output shape: torch.Size([2, 3, 5])
2023-01-24 20:56:39.578 | INFO     | __main__:<module>:8 - LSTM 的 h_final shape: torch.Size([1, 2, 5])
2023-01-24 20:56:39.580 | INFO     | __main__:<module>:9 - LSTM 的 c_final shape: torch.Size([1, 2, 5])


In [14]:
# Show the shape of every parameter of the official LSTM
for param_name, param in lstm_layer.named_parameters():
    print(param_name, param.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 5])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])


### 4.2 逐行实现 LSTM

+ 输入门：$i_t = \sigma(W_{ii}x_t + b_{ii} + W_{hi}h_{t-1}+b_{hi})$
+ 遗忘门：$f_t = \sigma(W_{if}x_t + b_{if} + W_{hf}h_{t-1}+b_{hf})$
+ cell 门：$g_t = \tanh(W_{ig}x_t + b_{ig} + W_{hg}h_{t-1}+b_{hg})$
+ 输出门：$o_t = \sigma(W_{io}x_t + b_{io} + W_{ho}h_{t-1}+b_{ho})$
+ 记忆单元的更新：$c_t = f_t \odot c_{t-1} + i_t \odot g_t$
+ 隐藏状态的更新：$h_t = o_t \odot \tanh(c_t)$

In [19]:
def lstm_forward(
    input: Tensor,  # [B, seq_len, input_size]
    initial_states: Tuple[Tensor, Tensor],  # (h0, c0), each [B, h_dim]
    w_ih: Tensor,  # [h_dim*4, input_size]
    w_hh: Tensor,  # [h_dim*4, h_dim]
    b_ih: Tensor,  # [h_dim*4]
    b_hh: Tensor   # [h_dim*4]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """Run a single-layer, unidirectional LSTM over a batch-first sequence.

    The stacked weights/biases hold the four gates in the order
    [input (i); forget (f); cell (g); output (o)], the same layout as
    nn.LSTM's ``weight_ih_l0`` / ``weight_hh_l0`` / ``bias_*_l0``.

    Returns:
        h_out: hidden states of every step, [B, seq_len, h_dim], with the
            dtype/device of ``input``.
        (h_n, c_n): final hidden and cell states, each [B, h_dim]. Unlike
            nn.LSTM there is no leading num_layers axis. When seq_len == 0
            these are h0 and c0.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    h_dim = w_ih.shape[0] // 4  # w_ih stacks the four gate matrices

    # Follow the input's dtype/device instead of torch.zeros' float32-on-CPU default
    h_out = input.new_zeros(bs, T, h_dim)

    for t in range(T):
        x = input[:, t, :]  # input at step t, [B, input_size]

        # All four gate pre-activations at once: [B, h_dim*4]
        gates = x @ w_ih.T + b_ih + prev_h @ w_hh.T + b_hh
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)  # each [B, h_dim]

        i_t = torch.sigmoid(i_pre)  # input gate
        f_t = torch.sigmoid(f_pre)  # forget gate
        g_t = torch.tanh(g_pre)     # cell (candidate) gate
        o_t = torch.sigmoid(o_pre)  # output gate

        # Update the cell state c and the hidden state h
        prev_c = f_t * prev_c + i_t * g_t
        prev_h = o_t * torch.tanh(prev_c)

        h_out[:, t, :] = prev_h

    return h_out, (prev_h, prev_c)

### 4.3 结果验证

In [20]:
# Feed identical input, initial states and weights to the hand-written and the official LSTM
output1, (h_final1, c_final1) = lstm_forward(
    input=input,
    initial_states=(h0, c0),
    w_ih=lstm_layer.weight_ih_l0,
    w_hh=lstm_layer.weight_hh_l0,
    b_ih=lstm_layer.bias_ih_l0,
    b_hh=lstm_layer.bias_hh_l0,
)

# [None] adds the leading num_layers * num_directions axis that nn.LSTM expects
output2, (h_final2, c_final2) = lstm_layer(input, (h0[None], c0[None]))

logger.info(f'自己实现的 LSTM 的 output:\n{output1}')
logger.info(f'PyTorch 的 LSTM 的 output:\n{output2}')
logger.info(f'自己实现的 LSTM 的 h_final:\n{h_final1}')
logger.info(f'PyTorch 的 LSTM 的 h_final:\n{h_final2}')
logger.info(f'自己实现的 LSTM 的 c_final:\n{c_final1}')
logger.info(f'PyTorch 的 LSTM 的 c_final:\n{c_final2}')

2023-01-24 20:41:36.167 | INFO     | __main__:<module>:12 - 自己实现的 LSTM 的 output:
tensor([[[ 0.2445, -0.1039,  0.2655,  0.0444, -0.1417],
         [ 0.3375, -0.0979,  0.1451,  0.0821, -0.0425],
         [ 0.0544, -0.1059,  0.1582,  0.1207, -0.0798]],

        [[ 0.1822, -0.0253,  0.2478,  0.0370, -0.0945],
         [ 0.3750, -0.0538,  0.0535,  0.0458, -0.0489],
         [ 0.2384, -0.1202,  0.0717,  0.0612, -0.0362]]], grad_fn=<CopySlices>)
2023-01-24 20:41:36.170 | INFO     | __main__:<module>:13 - PyTorch 的 LSTM 的 output:
tensor([[[ 0.2445, -0.1039,  0.2655,  0.0444, -0.1417],
         [ 0.3375, -0.0979,  0.1451,  0.0821, -0.0425],
         [ 0.0544, -0.1059,  0.1582,  0.1207, -0.0798]],

        [[ 0.1822, -0.0253,  0.2478,  0.0370, -0.0945],
         [ 0.3750, -0.0538,  0.0535,  0.0458, -0.0489],
         [ 0.2384, -0.1202,  0.0717,  0.0612, -0.0362]]],
       grad_fn=<TransposeBackward0>)
2023-01-24 20:41:36.173 | INFO     | __main__:<module>:14 - 自己实现的 LSTM 的 h_final:
tensor([[ 0.0

## 5. LSTMP 手写实现

### 5.1 PyTorch 官方 API

在 PyTorch 的 API 中，只需要在 `nn.LSTM` 实例化时加上一个 `proj_size` 参数即可。

这个 projection 的作用就是把 hidden state 的维度从 hidden_size 压缩到 proj_size（记忆单元 c 的维度保持不变）。

In [21]:
PROJ_SIZE = 3

# Same LSTM as before, plus a projection (weight_hr_l0) that maps the hidden state to PROJ_SIZE
proj_lstm_layer = nn.LSTM(
    input_size=INPUT_SIZE,
    hidden_size=HIDDEN_SIZE,
    batch_first=True,
    proj_size=PROJ_SIZE,
)

for param_name, param in proj_lstm_layer.named_parameters():
    print(param_name, param.shape)

weight_ih_l0 torch.Size([20, 4])
weight_hh_l0 torch.Size([20, 3])
bias_ih_l0 torch.Size([20])
bias_hh_l0 torch.Size([20])
weight_hr_l0 torch.Size([3, 5])


查看一下 `proj_lstm_layer` 的参数，可以看到它就是比 `lstm_layer` 多了一个 `weight_hr_l0` 的参数，这个参数就是用来对 hidden state 进行压缩的。

因此现在 hidden state 的大小变成了 3 (PROJ_SIZE)，而不是之前的 5 (HIDDEN_SIZE)。

从下面的运行结果可以看到，只是对 hidden state 进行了压缩，并没有对记忆单元 c 进行压缩。

In [24]:
c0 = torch.randn(BATCH_SIZE, HIDDEN_SIZE)  # the cell state keeps HIDDEN_SIZE, as before
h0 = torch.randn(BATCH_SIZE, PROJ_SIZE)    # the hidden state now has PROJ_SIZE instead of HIDDEN_SIZE

# [None] adds the leading num_layers * num_directions axis (= 1 for this unidirectional layer)
output, (h_final, c_final) = proj_lstm_layer(input, (h0[None], c0[None]))

logger.info(f'LSTMP 的 output shape: {output.shape}')   # [B, seq_len, proj_size]
logger.info(f'LSTMP 的 h_final shape: {h_final.shape}') # [1, B, proj_size]
logger.info(f'LSTMP 的 c_final shape: {c_final.shape}') # [1, B, h_dim]

2023-01-24 21:00:22.710 | INFO     | __main__:<module>:9 - LSTMP 的 output shape: torch.Size([2, 3, 3])
2023-01-24 21:00:22.712 | INFO     | __main__:<module>:10 - LSTMP 的 h_final shape: torch.Size([1, 2, 3])
2023-01-24 21:00:22.713 | INFO     | __main__:<module>:11 - LSTMP 的 c_final shape: torch.Size([1, 2, 5])


### 5.2 逐行实现 LSTMP

这里只需要对 `lstm_forward` 进行简单修改即可实现：

+ 参数中增加一个 `w_hr`，表示 projection，并通过这个参数可以获得 `proj_size`
+ 对 `w_hr` 进行扩维，获得 `batch_w_hr`
+ 最后的输出 `h_out` 的 shape：(bs, T, h_dim) -> (bs, T, proj_size)
+ 在之前计算完 `prev_h` 后，再通过 `w_hr` 对 `prev_h` 进行降维

In [25]:
def proj_lstm_forward(
    input: Tensor,  # [B, seq_len, input_size]
    initial_states: Tuple[Tensor, Tensor],  # (h0 [B, proj_size], c0 [B, h_dim])
    w_ih: Tensor,  # [h_dim*4, input_size]
    w_hh: Tensor,  # [h_dim*4, proj_size]
    b_ih: Tensor,  # [h_dim*4]
    b_hh: Tensor,  # [h_dim*4]
    w_hr: Tensor   # [proj_size, h_dim]
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
    """Run a single-layer, unidirectional LSTM with a projection (LSTMP).

    Identical to a plain LSTM step, except that the hidden state is projected
    after every step, h_t = (o_t * tanh(c_t)) W_hr^T. This shrinks it from
    h_dim to proj_size. The cell state c keeps h_dim. Gate order in the stacked
    weights is [i; f; g; o], as in nn.LSTM.

    Returns:
        h_out: projected hidden states of every step, [B, seq_len, proj_size],
            with the dtype/device of ``input``.
        (h_n, c_n): final states, h_n [B, proj_size] and c_n [B, h_dim]. There is
            no leading num_layers axis. When seq_len == 0 these are h0 and c0.
    """
    prev_h, prev_c = initial_states
    bs, T, _ = input.shape
    proj_size = w_hr.shape[0]

    # Follow the input's dtype/device instead of torch.zeros' float32-on-CPU default
    h_out = input.new_zeros(bs, T, proj_size)

    for t in range(T):
        x = input[:, t, :]  # input at step t, [B, input_size]

        # All four gate pre-activations at once: [B, h_dim*4]
        gates = x @ w_ih.T + b_ih + prev_h @ w_hh.T + b_hh
        i_pre, f_pre, g_pre, o_pre = gates.chunk(4, dim=-1)  # each [B, h_dim]

        i_t = torch.sigmoid(i_pre)  # input gate
        f_t = torch.sigmoid(f_pre)  # forget gate
        g_t = torch.tanh(g_pre)     # cell (candidate) gate
        o_t = torch.sigmoid(o_pre)  # output gate

        # The cell state is updated at full size and is not projected: [B, h_dim]
        prev_c = f_t * prev_c + i_t * g_t
        # Projection: [B, h_dim] @ [h_dim, proj_size] -> [B, proj_size]
        prev_h = (o_t * torch.tanh(prev_c)) @ w_hr.T

        h_out[:, t, :] = prev_h

    return h_out, (prev_h, prev_c)

### 5.3 结果验证

In [26]:
# Compare the hand-written LSTMP with nn.LSTM(proj_size=...) on identical inputs and weights
output1, (h_final1, c_final1) = proj_lstm_forward(
    input=input,
    initial_states=(h0, c0),
    w_ih=proj_lstm_layer.weight_ih_l0,
    w_hh=proj_lstm_layer.weight_hh_l0,
    b_ih=proj_lstm_layer.bias_ih_l0,
    b_hh=proj_lstm_layer.bias_hh_l0,
    w_hr=proj_lstm_layer.weight_hr_l0,
)

# [None] adds the leading num_layers * num_directions axis (= 1) that nn.LSTM expects
output2, (h_final2, c_final2) = proj_lstm_layer(input, (h0[None], c0[None]))

logger.info(f'自己实现的 LSTMP 的 output:\n{output1}')
logger.info(f'PyTorch 的 LSTMP 的 output:\n{output2}')
logger.info(f'自己实现的 LSTMP 的 h_final:\n{h_final1}')
logger.info(f'PyTorch 的 LSTMP 的 h_final:\n{h_final2}')
logger.info(f'自己实现的 LSTMP 的 c_final:\n{c_final1}')
logger.info(f'PyTorch 的 LSTMP 的 c_final:\n{c_final2}')

2023-01-24 21:18:10.829 | INFO     | __main__:<module>:16 - 自己实现的 LSTMP 的 output:
tensor([[[ 0.0256,  0.1098, -0.1193],
         [ 0.1177,  0.1309, -0.1649],
         [ 0.2817,  0.1614, -0.0983]],

        [[ 0.1013, -0.1186, -0.0720],
         [ 0.0932, -0.0278, -0.1713],
         [ 0.1110, -0.0149, -0.1693]]], grad_fn=<CopySlices>)
2023-01-24 21:18:10.832 | INFO     | __main__:<module>:17 - PyTorch 的 LSTMP 的 output:
tensor([[[ 0.0256,  0.1098, -0.1193],
         [ 0.1177,  0.1309, -0.1649],
         [ 0.2817,  0.1614, -0.0983]],

        [[ 0.1013, -0.1186, -0.0720],
         [ 0.0932, -0.0278, -0.1713],
         [ 0.1110, -0.0149, -0.1693]]], grad_fn=<TransposeBackward0>)
2023-01-24 21:18:10.835 | INFO     | __main__:<module>:18 - 自己实现的 LSTMP 的 h_final:
tensor([[ 0.2817,  0.1614, -0.0983],
        [ 0.1110, -0.0149, -0.1693]], grad_fn=<SqueezeBackward1>)
2023-01-24 21:18:10.837 | INFO     | __main__:<module>:19 - PyTorch 的 LSTMP 的 h_final:
tensor([[[ 0.2817,  0.1614, -0.0983],
     

LSTMP 相比于 LSTM，由于对 hidden state 做了降维（循环权重 `W_hh` 从 [4*hidden, hidden] 变为 [4*hidden, proj_size]），因此计算量会比 LSTM 小不少，而且效果也没有差太多，是个很好的 trick。