In [23]:
import torch
import torch.nn as nn

# 定义自注意力机制
class SelfAttention(nn.Module):
    def __init__(self, embedding_size: int, heads: int):
        """Args:
            embedding_size: int, 输入和输出的维度大小
            heads: int, 注意力头的数量
        """
        super(SelfAttention, self).__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads
        
        assert (
            self.head_dim * heads == embedding_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size)

    def forward(self, values: torch.Tensor, keys: torch.Tensor, query: torch.Tensor, mask: torch.Tensor = None):
        """Args:
            values: Tensor of shape (N, value_len, embedding_size)
            keys: Tensor of shape (N, key_len, embedding_size)
            query: Tensor of shape (N, query_len, embedding_size)
            mask: Tensor of shape (N, heads, query_len, key_len)
        Returns:
            out: Tensor of shape (N, query_len, embedding_size)
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # 对输入进行分头处理
        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = query.view(N, query_len, self.heads, self.head_dim)

        # 对输入进行线性变换
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        # 计算注意力分数
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # 计算注意力权重
        attention = torch.softmax(energy, dim=-1)

        # 计算输出
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

In [24]:
# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, embedding_size: int, forward_expansion: int, dropout: float=0.1):
        """Args: 
            embedding_size: int, 输入和输出的维度大小
            forward_expansion: int, 前馈网络中间层的扩展倍数
            dropout: float, dropout概率
        """
        super(FeedForward, self).__init__()
        self.embedding_size = embedding_size
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(embedding_size, forward_expansion * embedding_size)
        self.fc2 = nn.Linear(forward_expansion * embedding_size, embedding_size)

    def forward(self, x: torch.Tensor):
        """Args:
            x: Tensor of shape (N, seq_len, embedding_size)
        Returns:
            out: Tensor of shape (N, seq_len, embedding_size)
        """
        out = self.fc1(x)
        out = nn.ReLU()(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

In [None]:
# Transformer编码器块
class TransformerBlock(nn.Module):
    def __init__(self, embedding_size: int, heads: int, forward_expansion: int, dropout: float=0.1):
        """Args:
            embedding_size: int, 输入和输出的维度大小
            heads: int, 注意力头的数量
            forward_expansion: int, 前馈网络中间层的扩展倍数
            dropout: float, dropout概率
        """
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)

        self.embedding_size = embedding_size
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = FeedForward(embedding_size, forward_expansion)
        

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        attention = self.attention(x, x, x, mask)
        out = self.norm1(attention + x)
        dropout_1 = self.dropout(out)
        out = self.feed_forward(dropout_1)
        out = self.norm2(out + dropout_1)
        dropout_2 = self.dropout(out)
        return dropout_2

In [None]:
# BERT模型核心实现与预训练
class BERTEncoder(nn.Module):
    def __init__(self, embedding_size: int, heads: int, forward_expansion: int, num_layers: int, dropout: float=0.1):
        """Args:
            embedding_size: int, 输入和输出的维度大小
            heads: int, 注意力头的数量
            forward_expansion: int, 前馈网络中间层的扩展倍数
            num_layers: int, Transformer块的数量
            dropout: float, dropout概率
        """
        super(BERTEncoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embedding_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out


In [27]:
# 模拟输入
embedding_size = 768
heads = 12
forward_expansion = 4
dropout = 0.1
num_layers = 12
seq_len = 20
batch_size = 2


In [28]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 初始化BERT编码器堆叠
bert_encoder = BERTEncoder(embedding_size, heads, forward_expansion, num_layers, dropout).to(device)
# 随机生成输入数据和掩码
input_data = torch.randn(batch_size, seq_len, embedding_size).to(device)  # [2, 20, 768]
mask_full_one = torch.ones(seq_len, seq_len)  # [20, 20] 创建一个全1的掩码
# print(mask_full_one)
mask_lower_triangular = torch.tril(mask_full_one)  # 创建下三角矩阵实现单向注意力
# print(mask_lower_triangular)
mask = mask_lower_triangular.expand(batch_size, heads, seq_len, seq_len).to(device)  # mask是实现单向注意力的机制,通过下三角矩阵实现
print(mask.shape)  # [2, 12, 20, 20]
# 前向传播
output = bert_encoder(input_data, mask)
print(output.shape)
print(output)


torch.Size([2, 12, 20, 20])
torch.Size([2, 20, 768])
tensor([[[-1.3843e-01, -1.7316e+00, -1.5024e+00,  ..., -1.2094e+00,
          -0.0000e+00,  0.0000e+00],
         [ 6.1285e-02, -5.6583e-01,  5.1647e-01,  ..., -0.0000e+00,
           1.2017e-01,  0.0000e+00],
         [-2.2074e-01, -0.0000e+00, -0.0000e+00,  ..., -5.8492e-01,
           7.9412e-02, -6.3077e-02],
         ...,
         [ 1.1067e+00,  5.7706e-01, -8.4552e-01,  ...,  2.1099e-01,
           4.5171e-01,  4.7263e-01],
         [ 0.0000e+00,  2.5806e-01, -1.0885e+00,  ..., -0.0000e+00,
          -1.1213e-01,  0.0000e+00],
         [-2.3991e-01,  4.6774e-01, -5.2625e-01,  ..., -1.7556e-01,
           2.0771e+00,  0.0000e+00]],

        [[-1.1635e+00, -6.1539e-01, -7.2478e-01,  ..., -1.3339e-01,
           5.8951e+00, -8.4884e-01],
         [ 3.5002e-01,  1.0124e+00, -3.8720e+00,  ..., -7.2561e-01,
           1.1796e+00,  1.1796e+00],
         [-0.0000e+00,  1.4496e+00, -0.0000e+00,  ..., -2.0154e-01,
           5.2736e-01, 