In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# 定义自注意力机制
class SelfAttention(nn.Module):
    def __init__(self, embedding_size, heads):
        super(SelfAttention, self).__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads
        
        assert (
            self.head_dim * heads == embedding_size
        ), "Embedding size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embedding_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]
        
        # 对输入进行分头处理
        values = values.view(N, value_len, self.heads, self.head_dim)
        keys = keys.view(N, key_len, self.heads, self.head_dim)
        queries = query.view(N, query_len, self.heads, self.head_dim)

        # 对输入进行线性变换
        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)
        
        # 计算注意力分数
        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys]) / (self.head_dim ** 0.5)
        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
        
        # 计算注意力权重
        attention = torch.softmax(energy, dim=-1)

        # 计算输出
        out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        return self.fc_out(out)

In [64]:
# 前馈神经网络
class FeedForward(nn.Module):
    def __init__(self, embedding_size, forward_expansion, dropout=0.1):
        super(FeedForward, self).__init__()
        self.embedding_size = embedding_size
        self.dropout = nn.Dropout(dropout)
        self.fc1 = nn.Linear(embedding_size, forward_expansion * embedding_size)
        self.fc2 = nn.Linear(forward_expansion * embedding_size, embedding_size)

    def forward(self, x):
        out = self.fc1(x)
        out = nn.ReLU()(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

In [65]:
# Transformer编码器块
class TransformerBlock(nn.Module):
    def __init__(self, embedding_size, heads, forward_expansion, dropout=0.1):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)

        self.embedding_size = embedding_size
        self.dropout = nn.Dropout(dropout)
        self.feed_forward = FeedForward(embedding_size, forward_expansion)
        

    def forward(self, x, mask):
        attention = self.attention(x, x, x, mask)
        out = self.norm1(attention + x)
        dropout_1 = self.dropout(out)
        out = self.feed_forward(dropout_1)
        out = self.norm2(out + dropout_1)
        dropout_2 = self.dropout(out)
        return dropout_2

In [66]:
# BERT模型核心实现与预训练
class BERTEncoder(nn.Module):
    def __init__(self, embedding_size, heads, forward_expansion, num_layers, dropout=0.1):
        super(BERTEncoder, self).__init__()
        self.layers = nn.ModuleList([
            TransformerBlock(embedding_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        ])
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        for i, layer in enumerate(self.layers):
            x = layer(x, mask)
            x = self.dropout(x)
        return x


In [67]:
# 模拟输入
embedding_size = 768
heads = 12
forward_expansion = 4
dropout = 0.1
num_layers = 12
seq_len = 20
batch_size = 2


In [68]:
# 初始化BERT编码器堆叠
bert_encoder = BERTEncoder(embedding_size, heads, forward_expansion, num_layers, dropout)
# 随机生成输入数据和掩码
input_data = torch.randn(batch_size, seq_len, embedding_size)
mask = torch.tril(torch.ones(seq_len, seq_len)).expand(batch_size, heads, seq_len, seq_len)  # mask是实现单向注意力的机制,通过下三角矩阵实现

# 前向传播
output = bert_encoder(input_data, mask)
print(output.shape)
print(output)


torch.Size([2, 20, 768])
tensor([[[-0.0000, -0.9497, -0.4968,  ...,  0.8411,  0.6234,  0.3803],
         [-0.0000, -0.1906, -0.0000,  ..., -0.5954,  1.1762, -2.4484],
         [ 8.4211,  0.7008, -0.3765,  ..., -0.3821,  0.2273, -0.5241],
         ...,
         [-4.9552,  0.0328, -0.0000,  ...,  0.1010,  0.2527,  1.3961],
         [ 0.0000, -0.9603,  0.0000,  ...,  0.6102,  1.4641,  0.4910],
         [ 0.1294, -1.0197,  0.0104,  ..., -0.3218,  1.9343,  1.2534]],

        [[-1.3132,  0.0000, -0.0782,  ...,  1.4497, -0.2155, -0.0000],
         [ 0.0830, -0.1739,  0.0000,  ...,  0.2428, -0.2089,  0.7737],
         [-0.2147,  0.3273,  0.0581,  ...,  0.3750, -0.0000,  0.0451],
         ...,
         [ 1.4524,  1.0399,  0.1295,  ...,  1.1996, -1.6566, -3.2263],
         [-0.0000,  0.4060, -1.4194,  ..., -0.2064, -0.7548,  0.0805],
         [-2.1884, -0.0938, -2.6976,  ...,  0.0755, -0.2438,  0.0000]]],
       grad_fn=<MulBackward0>)
