In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from einops import rearrange
import math
import random
import numpy as np

# One seed shared by every RNG the notebook touches (torch, NumPy legacy global, stdlib random)
SEED = 12046
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [5]:
class PositionalEmbedding(nn.Module):
    """BERT-style input embedding.

    Sums three learned embeddings (token, absolute position, token type / segment),
    then applies LayerNorm followed by dropout.

    Args:
        vocab_size: number of token ids (rows of ``tok_emb``).
        d_model: embedding width.
        max_len: number of learned position rows; sequences longer than this
            cannot be embedded.
        d_type: number of token-type (segment) ids (rows of ``type_emb``).
        dropout: dropout probability applied after LayerNorm.
    """

    def __init__(self, vocab_size, d_model, max_len, d_type, dropout=0.1):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.type_emb = nn.Embedding(d_type, d_model)
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, type_ids=None):
        """Embed token ids ``x`` of shape (batch_size, seq_len) -> (batch_size, seq_len, d_model).

        ``type_ids`` has the same shape as ``x``; when omitted every position uses type 0.
        """
        segments = torch.zeros_like(x) if type_ids is None else type_ids
        # Positions 0..seq_len-1; the (seq_len, d_model) result broadcasts over the batch
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        summed = self.tok_emb(x) + self.pos_emb(positions) + self.type_emb(segments)
        return self.dropout(self.layer_norm(summed))

In [6]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries, keys and values are projected by separate linear layers, split into
    ``num_heads`` heads of width ``d_k = d_model // num_heads``, attended per head
    in parallel, re-concatenated and passed through the output projection ``wo``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)

        self.wo = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, q, k, v, mask=None):
        """Compute softmax(q @ k^T / sqrt(d_k)) @ v for all heads at once.

        Args:
            q: (batch_size, num_heads, q_len, d_k)
            k, v: (batch_size, num_heads, kv_len, d_k)
            mask: optional tensor broadcastable to (batch_size, num_heads, q_len, kv_len);
                positions where ``mask == 0`` receive (effectively) zero attention weight.

        Returns:
            (outputs, attn_weights) with shapes (batch_size, num_heads, q_len, d_k)
            and (batch_size, num_heads, q_len, kv_len).
        """
        # Heads live in dim 1, so one batched matmul covers every head; the last two
        # dims line up as (q_len, d_k) @ (d_k, kv_len) for the score matrix.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the dtype's most negative finite value instead of a hard-coded -1e9:
            # -1e9 overflows float16 (masked_fill raises), while a finite value still
            # keeps fully-masked rows from turning into NaN after softmax.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attn_weights = F.softmax(scores, dim=-1)
        outputs = torch.matmul(attn_weights, v)

        return outputs, attn_weights

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) -> (batch_size, num_heads, seq_len, d_k)."""
        # d_model is factored as num_heads * d_k; seq_len is inferred with -1
        x = x.view(batch_size, -1, self.num_heads, self.d_k)
        return x.transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` (batch_size, q_len, d_model) over ``k``/``v`` (batch_size, kv_len, d_model).

        Returns:
            (outputs, attn_weights): outputs is (batch_size, q_len, d_model);
            attn_weights is (batch_size, num_heads, q_len, kv_len).
        """
        q = self.wq(q)
        k = self.wk(k)
        v = self.wv(v)

        batch_size = q.size(0)

        q = self.split_heads(q, batch_size)
        k = self.split_heads(k, batch_size)
        v = self.split_heads(v, batch_size)

        # (batch_size, num_heads, q_len, d_k)
        outputs, attn_weights = self.scaled_dot_product_attention(q, k, v, mask)

        # Merge heads back: transpose makes the tensor non-contiguous, so copy before view
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

        outputs = self.wo(outputs)
        return outputs, attn_weights

In [7]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Computes fc2(dropout(GELU(fc1(x)))), mapping d_model -> d_ff -> d_model.
    """

    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expand to the inner width
        self.fc2 = nn.Linear(d_ff, d_model)   # project back to the model width
        self.dropout = nn.Dropout(p=dropout)
        self.activation = nn.GELU()  # BERT uses the GELU activation

    def forward(self, x):
        hidden = self.dropout(self.activation(self.fc1(x)))
        return self.fc2(hidden)

In [14]:
class EncoderLayer(nn.Module):
    """One Transformer encoder layer in post-LN form.

    Each sub-layer (self-attention, then feed-forward) is wrapped as
    ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Sub-layer 1: self-attention (attention weights discarded), residual add, LayerNorm
        attended, _ = self.self_attn(x, x, x, mask)
        hidden = self.norm1(x + self.dropout1(attended))

        # Sub-layer 2: position-wise feed-forward, residual add, LayerNorm
        return self.norm2(hidden + self.dropout2(self.ffn(hidden)))

In [18]:
class PoolingLayer(nn.Module):
    """Pool a sequence into one vector from its first position.

    Takes the hidden state at index 0 (the [CLS] slot), applies a square linear
    layer and a tanh: (batch_size, seq_len, d_model) -> (batch_size, d_model).
    """

    def __init__(self, d_model):
        super().__init__()
        self.dense = nn.Linear(d_model, d_model)
        self.tanh = nn.Tanh()

    def forward(self, x):
        first_token = x[:, 0]
        return self.tanh(self.dense(first_token))


In [19]:
class BERT(nn.Module):
    """BERT-style encoder: embeddings -> stack of encoder layers -> first-token pooler.

    Args:
        vocab_size: number of token ids.
        d_model: hidden width.
        num_heads: attention heads per layer.
        d_ff: inner width of each feed-forward block.
        num_layers: number of stacked encoder layers.
        max_len: maximum sequence length (rows of the learned position embedding).
        d_type: number of token-type (segment) ids.
        dropout: dropout probability used throughout.
    """

    def __init__(self, vocab_size, d_model=512, num_heads=8, d_ff=2048, num_layers=6, max_len=5000, d_type=2, dropout=0.1):
        super().__init__()
        self.embedding = PositionalEmbedding(vocab_size, d_model, max_len, d_type, dropout)
        self.EncoderLayers = nn.ModuleList([
            EncoderLayer(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])
        self.pooler = PoolingLayer(d_model)

    def generate_mask(self, src, pad=0):
        """Key-padding mask from token ids ``src`` (batch_size, seq_len).

        Returns a bool tensor of shape (batch_size, 1, 1, seq_len) that is False at
        ``pad`` tokens; the singleton dims broadcast over heads and query positions.
        """
        return (src != pad).unsqueeze(1).unsqueeze(2)

    def forward(self, x, type_ids=None):
        """Encode token ids ``x`` (batch_size, seq_len).

        Args:
            x: token ids; id 0 is treated as padding and masked out as an attention key.
            type_ids: optional segment ids, same shape as ``x``, in ``[0, d_type)``.
                When None, every position uses segment 0.

        Returns:
            (sequence_output, pooled_output) of shapes (batch_size, seq_len, d_model)
            and (batch_size, d_model).
        """
        # The mask must be built from the raw ids, before they are embedded
        mask = self.generate_mask(x)
        x = self.embedding(x, type_ids=type_ids)
        for encoder in self.EncoderLayers:
            x = encoder(x, mask)
        pooled_output = self.pooler(x)
        return x, pooled_output

In [21]:
def test():
    """Smoke-test a BERT-base-sized model on a random batch and check output shapes.

    Builds the model with vocab 30522, width 768, 12 heads, 12 layers and FFN
    width 3072. It runs a random (2, 10) batch of token ids in inference mode,
    prints both output shapes and asserts them.
    """
    vocab_size = 30522
    d_model = 768
    num_heads = 12
    num_layers = 12
    d_ff = 3072
    max_len = 512

    bert = BERT(vocab_size, d_model, num_heads, d_ff, num_layers, max_len)
    # Inference-only check: disable dropout and skip building an autograd graph
    bert.eval()

    batch_size = 2
    seq_len = 10
    x = torch.randint(0, vocab_size, (batch_size, seq_len))

    with torch.no_grad():
        encoder_outputs, pooled_output = bert(x)

    print(f"Encoder outputs shape: {encoder_outputs.shape}")
    print(f"Pooled output shape: {pooled_output.shape}")

    assert encoder_outputs.shape == (batch_size, seq_len, d_model)
    assert pooled_output.shape == (batch_size, d_model)

test()

Encoder outputs shape: torch.Size([2, 10, 768])
Pooled output shape: torch.Size([2, 768])
