In [68]:
# 3.1.2 BERT的自注意力机制与掩码任务

import torch
import torch.nn as nn
import torch.nn.functional as F
import random

In [69]:
# 自注意力机制实现
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (BERT style).

    Q, K and V are each produced by a full ``embedding_size x embedding_size``
    projection of the input and only then split into ``heads`` heads of size
    ``head_dim = embedding_size // heads``. This lets every head attend over a
    learned mix of the whole embedding, instead of one small projection shared
    across heads that only sees a fixed slice of the embedding.
    """

    def __init__(self, embedding_size: int, heads: int):
        """Args:
            embedding_size: int, size of the input and output feature dimension.
            heads: int, number of attention heads; must divide embedding_size.
        """
        super().__init__()
        self.embedding_size = embedding_size
        self.heads = heads
        self.head_dim = embedding_size // heads
        assert self.head_dim * self.heads == embedding_size, "embedding_size必须是heads的整数倍"

        # Full-width projections; the per-head split happens in forward().
        self.values = nn.Linear(embedding_size, embedding_size, bias=False)
        self.keys = nn.Linear(embedding_size, embedding_size, bias=False)
        self.queries = nn.Linear(embedding_size, embedding_size, bias=False)
        self.fc_out = nn.Linear(embedding_size, embedding_size)

    def forward(self, values: torch.Tensor, keys: torch.Tensor, query: torch.Tensor, mask: torch.Tensor = None):
        """Attend from ``query`` over ``keys``/``values``.

        Args:
            values: (N, value_len, embedding_size); value_len must equal key_len.
            keys: (N, key_len, embedding_size)
            query: (N, query_len, embedding_size)
            mask: optional; positions where ``mask == 0`` are excluded from attention.
                Accepted shapes: (N, key_len) padding mask, (N, query_len, key_len),
                or anything already broadcastable to (N, heads, query_len, key_len),
                e.g. (N, 1, 1, key_len).
        Returns:
            Tensor of shape (N, query_len, embedding_size).
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project with the full matrices, then split the last dim into heads.
        values = self.values(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.keys(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.queries(query).reshape(N, query_len, self.heads, self.head_dim)

        # (N, heads, query_len, key_len), scaled by sqrt(head_dim)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)
        if mask is not None:
            # Lift 2-D / 3-D masks so they broadcast over heads (and queries).
            if mask.dim() == 2:
                mask = mask[:, None, None, :]
            elif mask.dim() == 3:
                mask = mask.unsqueeze(1)
            # finfo(dtype).min is representable in fp16/bf16, unlike -1e20.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)  # (N, heads, query_len, key_len)
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )  # (N, query_len, embedding_size)
        return self.fc_out(out)


# 掩码任务数据集创建函数
def create_masked_data(inputs, mask_token_id, vocab_size, mask_prob=0.15, generator=None):
    """Build masked-language-model (MLM) inputs and labels, BERT style.

    Each position is independently selected with probability ``mask_prob``.
    A selected position is replaced by ``mask_token_id`` 80% of the time,
    by a uniformly random id in ``[0, vocab_size)`` 10% of the time, and left
    unchanged the remaining 10%.

    The implementation is vectorised with torch ops. A per-element Python loop
    is slow, and for CUDA tensors it issues one tiny device write per element.
    The draws come from torch's RNG, so ``torch.manual_seed`` (or an explicit
    ``generator``) makes the masking reproducible.

    Args:
        inputs: torch.Tensor of token ids, e.g. shape [batch_size, seq_len]
            (any shape is accepted). Not modified.
        mask_token_id: int, id written at positions chosen for mask replacement.
        vocab_size: int, exclusive upper bound for random replacement ids.
        mask_prob: float, probability that a position is selected for prediction.
        generator: optional CPU ``torch.Generator`` for the random draws;
            defaults to torch's global RNG.
    Returns:
        masked_inputs: copy of ``inputs`` with the replacements applied.
        labels: copy of ``inputs`` with non-selected positions set to -100
            (the default ``ignore_index`` of ``nn.CrossEntropyLoss``).
    """
    device = inputs.device
    # Draw on CPU so a CPU generator works for inputs on any device, then move.
    probs = torch.rand(inputs.shape, generator=generator).to(device)
    random_tokens = torch.randint(
        0, vocab_size, inputs.shape, generator=generator, dtype=inputs.dtype
    ).to(device)

    selected = probs < mask_prob
    # Same split as "prob / mask_prob < 0.8 / < 0.9", without dividing by mask_prob.
    use_mask_token = selected & (probs < 0.8 * mask_prob)
    use_random_token = selected & ~use_mask_token & (probs < 0.9 * mask_prob)
    # The remaining selected positions keep their original token.

    masked_inputs = inputs.clone()
    masked_inputs[use_mask_token] = mask_token_id
    masked_inputs[use_random_token] = random_tokens[use_random_token]

    labels = inputs.clone()
    labels[~selected] = -100
    return masked_inputs, labels


class FeedForward(nn.Module):
    """Position-wise two-layer MLP: Linear -> ReLU -> Dropout -> Linear.

    The hidden layer is ``forward_expansion`` times wider than the embedding.
    """

    def __init__(self, embedding_size: int, forward_expansion: int, dropout: float=0.1):
        """Args:
            embedding_size: int, input and output feature size.
            forward_expansion: int, width multiplier of the hidden layer.
            dropout: float, dropout probability applied after the activation.
        """
        super().__init__()
        hidden_size = embedding_size * forward_expansion
        self.embedding_size = embedding_size
        self.fc1 = nn.Linear(embedding_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, embedding_size)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor):
        """Apply the MLP to the last dimension of ``x``.

        Args:
            x: (N, seq_len, embedding_size)
        Returns:
            (N, seq_len, embedding_size)
        """
        hidden = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(hidden)


class TransformerBlock(nn.Module):
    """Post-LayerNorm Transformer encoder block (BERT layout).

    h   = LayerNorm(x + Dropout(SelfAttention(x)))
    out = LayerNorm(h + Dropout(FeedForward(h)))

    Dropout is applied to each sub-layer's output before the residual add,
    not to the normalised residual stream itself. Otherwise a stack of blocks
    would drop activations of the whole hidden state at every layer.
    """

    def __init__(self, embedding_size: int, heads: int, forward_expansion: int, dropout: float=0.1):
        """Args:
            embedding_size: int, input and output feature size.
            heads: int, number of attention heads.
            forward_expansion: int, width multiplier of the feed-forward hidden layer.
            dropout: float, dropout probability.
        """
        super().__init__()
        self.attention = SelfAttention(embedding_size, heads)
        self.norm1 = nn.LayerNorm(embedding_size)
        self.norm2 = nn.LayerNorm(embedding_size)

        self.feed_forward = FeedForward(embedding_size, forward_expansion, dropout)

        # Residual dropout, shared by both sub-layers (nn.Dropout holds no state).
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """Args:
            x: (N, seq_len, embedding_size)
            mask: optional attention mask, forwarded unchanged to SelfAttention.
        Returns:
            Tensor with the same shape as ``x``.
        """
        attention = self.attention(x, x, x, mask)
        hidden = self.norm1(x + self.dropout(attention))
        feed_forward = self.feed_forward(hidden)
        return self.norm2(hidden + self.dropout(feed_forward))

# BERT编码器模块
class BERTEncoder(nn.Module):
    """A stack of ``num_layers`` TransformerBlocks applied sequentially."""

    def __init__(self, embedding_size: int, heads: int, forward_expansion: int, num_layers: int, dropout: float=0.1):
        """Args:
            embedding_size: int, input and output feature size.
            heads: int, number of attention heads per block.
            forward_expansion: int, width multiplier of each feed-forward hidden layer.
            num_layers: int, number of Transformer encoder blocks.
            dropout: float, dropout probability used inside each block.
        """
        super().__init__()
        self.layers = nn.ModuleList(
            TransformerBlock(embedding_size, heads, forward_expansion, dropout)
            for _ in range(num_layers)
        )
        # NOTE: not applied in forward(); kept so existing references to
        # `encoder.dropout` keep working. Dropout happens inside each block.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """Run ``x`` through every block in order.

        Args:
            x: input embeddings, presumably (N, seq_len, embedding_size).
            mask: optional attention mask passed unchanged to every block.
        Returns:
            Output of the last block, same shape as ``x``.
        """
        out = x
        for layer in self.layers:
            out = layer(out, mask)
        return out

In [70]:
# Simulation settings (BERT-base sizes)
vocab_size = 30522  # vocabulary size
embedding_size = 768  # BERT-base embedding dimension
num_layers = 12  # number of Transformer encoder blocks
heads = 12  # number of attention heads
forward_expansion = 4  # feed-forward hidden-layer width multiplier
dropout = 0.1  # dropout probability
seq_len = 20  # sequence length
batch_size = 2  # batch size
mask_token_id = 103  # BERT's [MASK] token id

# Seed every RNG the notebook uses (weight init, randint, masking) so that
# Restart & Run All reproduces the same outputs.
seed = 42
random.seed(seed)
torch.manual_seed(seed)  # also seeds all CUDA devices

In [78]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Build the stacked BERT encoder
bert_encoder = BERTEncoder(embedding_size, heads, forward_expansion, num_layers, dropout).to(device)
# Random token ids plus an all-ones mask (no padding); allocated directly on `device`
input_data = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)  # [N, L] = [2, 20]
mask_full_one = torch.ones(batch_size, seq_len, device=device)  # [N, L] = [2, 20]; 1 = attend, 0 = ignore
mask_for_attention = mask_full_one.unsqueeze(1).unsqueeze(2)  # [N, 1, 1, L], broadcasts over heads and queries
# Masked-LM inputs and labels
masked_input_ids, labels = create_masked_data(input_data, mask_token_id, vocab_size)
# NOTE: token embeddings only -- unlike real BERT there are no position/segment
# embeddings, so the encoder has no notion of token order.
embedding_layer = nn.Embedding(vocab_size, embedding_size).to(device)
masked_embeddings = embedding_layer(masked_input_ids)  # [2, 20, 768]
print("输入数据 (带掩码):", masked_embeddings.shape, mask_full_one.shape)

输入数据 (带掩码): torch.Size([2, 20, 768]) torch.Size([2, 20])


In [80]:
# Forward pass (inference demo): eval() disables dropout so the output is
# deterministic, and no_grad() skips autograd bookkeeping.
# Call bert_encoder.train() again before any training cell.
bert_encoder.eval()
with torch.no_grad():
    output = bert_encoder(masked_embeddings, mask=mask_for_attention)
print("BERT编码器输出形状:", output.shape)  # expected [2, 20, 768]
print("掩码后的输入:", masked_embeddings)
print("掩码标签:", labels)
print("编码器输出:", output)

BERT编码器输出形状: torch.Size([2, 20, 768])
掩码后的输入: tensor([[[ 0.6669,  1.1187, -1.0041,  ..., -0.2817, -0.4558,  1.9066],
         [-1.1914, -0.9554,  0.7611,  ...,  0.8293,  0.1727, -0.4522],
         [ 1.9041,  0.0209, -1.4513,  ..., -0.2897,  0.1818, -1.7809],
         ...,
         [-0.7646, -0.6862, -2.3349,  ..., -0.3386, -1.5757, -0.1685],
         [-1.6908,  0.6005, -0.6600,  ..., -0.3241,  1.2650,  0.1150],
         [ 1.0049,  2.0781,  1.2398,  ..., -1.5259, -1.6640, -0.0551]],

        [[-0.1353, -1.4016, -0.0999,  ..., -0.1649,  0.0735, -0.6180],
         [-0.2818,  0.6535, -0.4367,  ...,  0.0062, -1.7736,  0.7129],
         [-0.7647,  0.8749,  0.3230,  ..., -0.2911,  0.2479, -0.7094],
         ...,
         [-1.1440,  0.8340, -2.2205,  ..., -0.8262,  1.8734, -0.8777],
         [-0.1865,  0.0545,  0.8164,  ..., -0.1453, -0.7726,  0.1557],
         [-1.1914, -0.9554,  0.7611,  ...,  0.8293,  0.1727, -0.4522]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)
掩码标签: tensor([[ 