位置编码器
  - Transformer 的自注意力机制本身不感知输入顺序(对词序不敏感)
  - 因此需要通过位置编码显式引入位置信息
  - 使模型能够识别不同词之间的相对位置关系

In [2]:
import torch
import torch.nn as nn
import math


class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a batch of embeddings.

    A table of shape ``(1, max_len, embedding_size)`` is computed once in the
    constructor: even columns hold ``sin(pos * div_term)`` and odd columns
    hold ``cos(pos * div_term)``, where
    ``div_term[k] = exp(-2k * ln(10000) / embedding_size)``.

    Args:
        embedding_size: Width of the embedding vectors. Both even and odd
            widths are supported.
        max_len: Number of positions to precompute. ``forward`` slices the
            table to the input's sequence length, so inputs longer than
            ``max_len`` are not supported.
    """

    def __init__(self, embedding_size, max_len=5000):
        super().__init__()
        # Column vector of position indices: (max_len, 1).
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # One frequency per sin/cos pair: ceil(embedding_size / 2) entries.
        div_term = torch.exp(
            torch.arange(0, embedding_size, 2).float()
            * (-math.log(10000.0) / embedding_size)
        )
        angles = position * div_term  # (max_len, ceil(embedding_size / 2))

        pe = torch.zeros(max_len, embedding_size)
        pe[:, 0::2] = torch.sin(angles)
        # For an odd width there is one fewer odd column than even columns,
        # so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : embedding_size // 2])
        pe = pe.unsqueeze(0)  # leading batch dimension for broadcasting
        # Registered as a buffer, not a parameter, so the optimizer never
        # updates it.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Args:
            x: Tensor of shape ``(batch, seq_len, embedding_size)``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        return x + self.pe[:, :x.size(1), :]

In [6]:
# Build a column vector holding the position indices 0 .. max_len - 1
# and display it together with its shape.
max_len = 10
position = torch.arange(max_len, dtype=torch.float).reshape(-1, 1)
position, position.shape

(tensor([[0.],
         [1.],
         [2.],
         [3.],
         [4.],
         [5.],
         [6.],
         [7.],
         [8.],
         [9.]]),
 torch.Size([10, 1]))

In [13]:
# Round-trip check: exp(ln(10000)) should give back 10000. The slight deviation
# in the displayed result is presumably float32 rounding (torch.tensor on a
# Python float uses the default dtype) — TODO confirm.
torch.exp(torch.tensor(math.log(10000.0)))

tensor(10000.0010)

$$ \text{div\_term}_i = \exp\left(-\frac{2i \cdot \ln 10000}{d_{\text{model}}}\right) = 10000^{-2i / d_{\text{model}}}  \\
    \texttt{math.log(10000.0)} = \ln 10000
$$

In [2]:
# Embedding layer combined with positional encoding
class EmbeddingWithPositionalEncoding(nn.Module):
    """Token-embedding lookup whose output is passed through a positional encoder.

    Args:
        vocab_size: Number of entries in the embedding table.
        embedding_size: Width of each embedding vector.
        max_len: Number of positions handed to ``PositionalEncoding``.
    """

    def __init__(self, vocab_size, embedding_size, max_len=5000):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.pos_encoder = PositionalEncoding(embedding_size, max_len)

    def forward(self, x):
        """Embed the index tensor ``x`` and apply the positional encoder to it."""
        # NOTE: the embeddings are used unscaled; multiplying them by
        # sqrt(embedding_dim) is intentionally not applied here.
        token_vectors = self.embedding(x)
        return self.pos_encoder(token_vectors)

In [3]:
# Hyperparameters for the demo cells
vocab_size, embedding_size = 100, 64  # vocabulary size, embedding width
max_len, batch_size = 50, 2  # maximum sequence length, batch size

In [4]:
# Simulated input data
input_data = torch.randint(0, vocab_size, (batch_size, max_len))  # random integers used as token indices
# Instantiate the embedding layer with positional encoding
embedding_layer = EmbeddingWithPositionalEncoding(vocab_size, embedding_size, max_len)
# Forward pass
output = embedding_layer(input_data)
print(output.shape)  # expected shape: (batch_size, max_len, embedding_size)
print("位置编码后的嵌入输出:", output)
# Inspect concrete values of the positional-encoding table
print("位置编码矩阵的一部分:", embedding_layer.pos_encoder.pe[:, :5, :10])  # first 10 dimensions of the first 5 positions

torch.Size([2, 50, 64])
位置编码后的嵌入输出: tensor([[[-1.0561,  1.0751, -1.3608,  ...,  1.2327,  0.1943, -0.2491],
         [ 0.6463, -0.5637,  1.2275,  ...,  0.2415, -0.4657,  0.0312],
         [ 1.3585, -1.6629,  0.6527,  ...,  1.3106, -1.4897,  0.2261],
         ...,
         [ 0.4542, -0.4600, -1.7503,  ...,  2.2198, -0.1154,  1.9548],
         [-1.3759, -1.7832, -2.2850,  ...,  0.3963, -0.0343,  0.2413],
         [ 0.2892, -0.0244, -1.7518,  ...,  2.2570,  0.4946,  2.2084]],

        [[ 0.0775,  2.0395, -0.9150,  ...,  0.0586, -1.0018, -0.6656],
         [ 2.5678,  0.7628,  1.6714,  ...,  0.7785,  0.1843,  1.2955],
         [-0.0403, -0.2317,  0.0995,  ...,  0.2330,  0.4641,  1.9797],
         ...,
         [ 0.0164, -1.0728, -1.1679,  ..., -0.1144,  0.1658,  1.6534],
         [-1.3759, -1.7832, -2.2850,  ...,  0.3963, -0.0343,  0.2413],
         [-2.5312,  1.1633, -1.8766,  ...,  1.9803,  1.7273,  0.3746]]],
       grad_fn=<AddBackward0>)
位置编码矩阵的一部分: tensor([[[ 0.0000,  1.0000,  0.0000, 

In [6]:
# Check that the layer also works for a sequence length other than max_len
test_input = torch.randint(0, vocab_size, (batch_size, max_len - 10))  # shorter sequence (max_len - 10 positions), not a longer one
test_output = embedding_layer(test_input)
print("测试数据的嵌入输出:", test_output.shape)
print("测试数据的嵌入输出:", test_output)

测试数据的嵌入输出: torch.Size([2, 40, 64])
测试数据的嵌入输出: tensor([[[ 1.1289,  1.9779,  0.9760,  ...,  2.6218, -0.0465,  2.3552],
         [-0.2126,  0.3550,  0.5644,  ..., -0.3758,  0.0059, -1.0395],
         [-0.9947, -0.2407,  1.2984,  ...,  0.6317,  0.4573,  0.7703],
         ...,
         [-0.5170,  1.5539, -0.4908,  ...,  0.4382,  0.3909,  1.3073],
         [ 0.6961, -0.5537,  1.0388,  ...,  1.9344, -0.9189,  1.0266],
         [ 0.5713, -0.0043, -0.5297,  ...,  2.0219,  0.0780,  0.0214]],

        [[-1.9113,  0.4304,  0.3782,  ...,  1.7448, -0.8705,  2.3921],
         [ 1.7945,  0.4640,  1.6173,  ...,  2.4166, -0.7445,  0.7879],
         [-1.2218, -0.4883,  1.2362,  ...,  2.5640,  0.7730,  2.5356],
         ...,
         [ 0.5144,  1.6762, -0.0917,  ...,  1.3068, -1.4846,  1.4938],
         [-1.2811,  1.8178, -1.2805,  ...,  1.9803,  1.7258,  0.3746],
         [ 1.6554, -0.2844,  1.0370,  ...,  1.0341,  1.3276,  2.6463]]],
       grad_fn=<AddBackward0>)
