# 手写模型

## 1. Model

In [1]:
import torch
import torch.nn.functional as F
from torch import nn
import math


In [2]:
# Hyper-parameters shared by the hand-written model classes below.
d_model = 512                     # embedding / hidden size
context_size = 16                 # maximum sequence length; sizes the causal mask in Attention
num_heads = 8                     # number of heads in multi-head attention
head_dim = d_model // num_heads   # width of a single attention head
dropout = 0.1                     # dropout probability
num_blocks = 6                    # number of stacked Transformer blocks
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
# ffn层
class Feedforward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear -> Dropout.

    Expands the last dimension by 4x, applies ReLU, projects back to the
    original width and applies dropout. Only the last dimension is touched,
    so any leading shape (e.g. ``(batch, seq, dim)``) is accepted.

    Args:
        dim: input/output feature size; defaults to the module-level ``d_model``.
        p_drop: dropout probability; defaults to the module-level ``dropout``.
    """

    def __init__(self, dim=None, p_drop=None):
        super(Feedforward, self).__init__()
        dim = d_model if dim is None else dim
        p_drop = dropout if p_drop is None else p_drop
        # Assignment (the original used `==`, so `self.ffn` was never created).
        self.ffn = nn.Sequential(
            nn.Linear(dim, 4 * dim),
            nn.ReLU(),
            nn.Linear(4 * dim, dim),
            nn.Dropout(p_drop),
        )

    def forward(self, x):
        """Apply the feed-forward network to ``x`` and return the result."""
        return self.ffn(x)
        

In [4]:
# 注意力层
class Attention(nn.Module):
    """A single causal self-attention head.

    Projects the input to query/key/value of width ``head_size`` and applies
    scaled dot-product attention under a lower-triangular mask, so position
    ``i`` may only attend to positions ``<= i`` (itself and earlier).

    Args:
        d_in: input feature size; defaults to the module-level ``d_model``.
        head_size: width of q/k/v for this head; defaults to ``head_dim``.
        max_len: largest supported sequence length (side of the causal mask);
            defaults to ``context_size``.
        p_drop: dropout probability applied to the attention weights;
            defaults to ``dropout``.
    """

    def __init__(self, d_in=None, head_size=None, max_len=None, p_drop=None):
        super(Attention, self).__init__()
        d_in = d_model if d_in is None else d_in
        head_size = head_dim if head_size is None else head_size
        max_len = context_size if max_len is None else max_len
        p_drop = dropout if p_drop is None else p_drop

        self.Wq = nn.Linear(d_in, head_size)
        self.Wk = nn.Linear(d_in, head_size)
        self.Wv = nn.Linear(d_in, head_size)
        # Registered once so it follows model.train()/model.eval(); the original
        # referenced a non-existent self.dropout and built a new (always
        # training-mode) nn.Dropout on every forward call.
        self.dropout = nn.Dropout(p_drop)
        # Causal mask of shape (max_len, max_len): 1 on/below the diagonal.
        self.register_buffer("mask", torch.tril(torch.ones(max_len, max_len)))

    def forward(self, x):
        """Return the attended values for ``x``.

        Args:
            x: tensor of shape ``(batch, seq_len, d_in)`` with
                ``seq_len <= max_len``.

        Returns:
            Tensor of shape ``(batch, seq_len, head_size)``.

        Raises:
            ValueError: if ``seq_len`` exceeds the size of the causal mask.
        """
        _, seq_len, _ = x.shape
        if seq_len > self.mask.size(0):
            raise ValueError(
                f"sequence length {seq_len} exceeds context size {self.mask.size(0)}"
            )
        q = self.Wq(x)  # (batch, seq_len, head_size)
        k = self.Wk(x)
        v = self.Wv(x)

        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(k.size(-1))
        # Hide future positions (strictly above the diagonal); the diagonal stays
        # visible, so each token can attend to itself.
        scores = scores.masked_fill(self.mask[:seq_len, :seq_len] == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)  # (batch, seq_len, seq_len)
        weights = self.dropout(weights)

        return torch.matmul(weights, v)

In [5]:
# 多头注意力层
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention built from ``num_heads`` ``Attention`` heads.

    Every head sees the full input. The head outputs are concatenated along
    the last dimension, projected back to ``d_model`` and passed through
    dropout.
    """

    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.heads = nn.ModuleList([Attention() for _ in range(num_heads)])
        # The concatenated width is num_heads * head_dim. That equals d_model only
        # when d_model is divisible by num_heads, so size the projection from the
        # real concat width instead of assuming d_model.
        self.linear = nn.Linear(num_heads * head_dim, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run all heads on ``x``, concatenate, project and apply dropout."""
        output = torch.cat([head(x) for head in self.heads], dim=-1)
        return self.dropout(self.linear(output))

In [6]:
# Transformer块
class TransformerBlock(nn.Module):
    """Pre-LayerNorm Transformer block.

    Applies multi-head self-attention and then the feed-forward network. Each
    sub-layer sees a LayerNorm-ed input, and its output is added back onto
    its input (residual connection).
    """

    def __init__(self):
        super().__init__()
        self.mha = MultiHeadAttention()
        self.ffn = Feedforward()
        self.layernorm1 = nn.LayerNorm(d_model)
        self.layernorm2 = nn.LayerNorm(d_model)
        # NOTE(review): registered but never used in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        attn_branch = self.mha(self.layernorm1(x))
        hidden = x + attn_branch
        ffn_branch = self.ffn(self.layernorm2(hidden))
        return hidden + ffn_branch

In [7]:
class SimplifiedTransformerBlock(nn.Module):
    """Transformer block without LayerNorm and without a plain residual.

    The attention sub-layer output is ``alpha * x + beta * attn(x)``, where
    ``alpha`` and ``beta`` are learnable scalars initialised to 0.1 and 1.0.
    This is a "Value-SkipInit"-style weighting that replaces the usual
    residual. The result is fed through a GELU MLP, whose output is returned
    directly (there is no skip connection around the MLP).

    Args:
        d_model: embedding size.
        num_heads: number of attention heads (default 8, as before).
        batch_first: when True (default), inputs are ``(batch, seq, d_model)``,
            matching every other block in this file. The original relied on
            ``nn.MultiheadAttention``'s seq-first default, which made tokens
            attend across different batch items. Pass False for
            ``(seq, batch, d_model)`` input.
    """

    def __init__(self, d_model, num_heads=8, batch_first=True):
        super().__init__()
        self.attn = nn.MultiheadAttention(d_model, num_heads=num_heads, batch_first=batch_first)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        # Learnable weights on the skip path and the attention path.
        self.alpha = nn.Parameter(torch.tensor(0.1))
        self.beta = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        # NOTE(review): no attn_mask is passed, so attention is bidirectional;
        # add a causal mask if this block is used for autoregressive decoding.
        attn_out, _ = self.attn(x, x, x)
        mixed = self.alpha * x + self.beta * attn_out
        return self.mlp(mixed)

In [8]:
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class BlockTransformer(nn.Module):
    """Two-level (block-local then block-global) attention.

    Input and output are ``(batch, seq, d_model)``.

    1. Split the sequence into non-overlapping blocks of ``block_size`` tokens
       and run a 2-layer encoder inside each block (local attention).
    2. Average each block into one vector (the block embedding).
    3. Run a 1-layer encoder across the block embeddings (global attention).
    4. Repeat each block's vector over its ``block_size`` positions and trim
       to the input length.

    If the sequence length is not a multiple of ``block_size``, the input is
    zero-padded. Pad positions are masked out of local attention and left out
    of the block average.

    NOTE(review): step 4 gives every token in a block the same output vector;
    confirm this loss of per-token detail is intended.

    Args:
        d_model: embedding size.
        block_size: tokens per block.
        nhead: attention heads in both encoders (default 8, as before).
    """

    def __init__(self, d_model, block_size=4, nhead=8):
        super().__init__()
        self.block_size = block_size
        # Local attention inside each block. batch_first so (batch, seq, d) input
        # is read correctly; the original fed a 4-D tensor here, which crashes.
        self.local_attn = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead=nhead, dim_feedforward=4 * d_model, batch_first=True),
            num_layers=2,
        )
        # Global attention across block embeddings.
        self.global_attn = TransformerEncoder(
            TransformerEncoderLayer(d_model, nhead=nhead, batch_first=True),
            num_layers=1,
        )

    def forward(self, x):
        B, T, D = x.shape
        bs = self.block_size
        pad = (-T) % bs  # extra positions needed to fill the last block
        if pad:
            x = F.pad(x, (0, 0, 0, pad))  # zero-pad along the sequence dimension
        n_blocks = (T + pad) // bs

        # (B, n_blocks * bs, D) -> (B * n_blocks, bs, D): each block is its own sequence.
        blocks = x.reshape(B * n_blocks, bs, D)

        if pad:
            valid = torch.arange(T + pad, device=x.device) < T
            valid = valid.view(1, n_blocks, bs).expand(B, n_blocks, bs).reshape(B * n_blocks, bs)
            # Every block has at least one real token (pad < bs), so no row is fully masked.
            blocks = self.local_attn(blocks, src_key_padding_mask=~valid)
            blocks = blocks.masked_fill(~valid.unsqueeze(-1), 0.0)
            counts = valid.sum(dim=1, keepdim=True).to(blocks.dtype)
            block_emb = blocks.sum(dim=1) / counts  # mean over real tokens only
        else:
            blocks = self.local_attn(blocks)
            block_emb = blocks.mean(dim=1)

        block_emb = block_emb.reshape(B, n_blocks, D)
        out = self.global_attn(block_emb)  # inter-block dependencies
        return out.repeat_interleave(bs, dim=1)[:, :T]

In [ ]:
import math
import torch
import torch.nn as nn

# 位置编码优化
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to the input.

    ``PE[pos, 2i] = sin(pos / 10000^(2i/d_model))`` and
    ``PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model))``.

    Args:
        d_model: feature size (odd values are supported).
        max_len: longest sequence the precomputed table covers.
    """

    def __init__(self, d_model, max_len=5000):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)  # even dimensions: sin
        # Odd dimensions: cos. With an odd d_model there is one fewer odd column
        # than even column; slicing avoids the shape mismatch the original hit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))  # shape (1, max_len, d_model)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        Raises:
            ValueError: if ``x.size(1)`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(f"sequence length {seq_len} exceeds max_len {self.pe.size(1)}")
        return x + self.pe[:, :seq_len]

In [ ]:


class OptimizedModel(nn.Module):
    """Token model: embedding -> sinusoidal positions -> hybrid stack -> vocab logits.

    The stack is ``num_blocks // 2`` ``SimplifiedTransformerBlock`` layers,
    then the remaining ``BlockTransformer`` layers, then a final LayerNorm.

    Args:
        vocab_size: number of token ids (embedding rows / output logits).
        d_model: embedding size used throughout the model.
        num_blocks: total number of blocks in the stack.
    """

    def __init__(self, vocab_size, d_model=512, num_blocks=6):
        super().__init__()
        # Keep this instance's width. The original forward() read the module-level
        # `d_model` global, which is wrong whenever another d_model is passed.
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.vocab_linear = nn.Linear(d_model, vocab_size)
        # Hybrid stack: simplified blocks first, block-level attention after.
        self.transformer = nn.Sequential(
            *[SimplifiedTransformerBlock(d_model) if i < num_blocks // 2
              else BlockTransformer(d_model) for i in range(num_blocks)],
            nn.LayerNorm(d_model)
        )
        # Sinusoidal positional encoding (the custom PositionalEncoding class,
        # not a PyTorch built-in).
        self.pos_encoder = PositionalEncoding(d_model)

    def forward(self, x_batch):
        """Map token ids to vocabulary logits.

        ``x_batch`` holds integer token ids, presumably shaped (batch, seq) —
        TODO confirm against callers.

        NOTE(review): neither sub-block type, as originally written, applies a
        causal mask; confirm before using this for next-token prediction.
        """
        x = self.embedding(x_batch) * math.sqrt(self.d_model)
        x = self.pos_encoder(x)
        return self.vocab_linear(self.transformer(x))