In [3]:
import torch
import torch.nn as nn

# Hyperparameters
d_model = 512             # hidden (embedding) dimension
nhead = 8                 # number of attention heads
num_encoder_layers = 6    # number of stacked encoder layers
num_decoder_layers = 6    # number of stacked decoder layers
dim_feedforward = 2048    # inner width of the feed-forward sublayer
dropout = 0.1             # dropout probability handed to every layer

# The encoder and decoder layers are built from the same settings, so the
# arguments are spelled out once. batch_first=True means tensors are laid out
# as (batch, sequence, feature).
_layer_kwargs = dict(
    d_model=d_model,
    nhead=nhead,
    dim_feedforward=dim_feedforward,
    dropout=dropout,
    batch_first=True,
)
encoder_layer = nn.TransformerEncoderLayer(**_layer_kwargs)
decoder_layer = nn.TransformerDecoderLayer(**_layer_kwargs)

# Stack the prototype layers into the full encoder and decoder
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_encoder_layers)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_decoder_layers)

# Random input data
batch_size = 2
src_seq_len = 10  # source sequence length
tgt_seq_len = 8   # target sequence length
src = torch.randn((batch_size, src_seq_len, d_model))  # source sequence
tgt = torch.randn((batch_size, tgt_seq_len, d_model))  # target sequence

# Build the autoregressive (causal) mask
def generate_square_subsequent_mask(sz, device=None):
    """Return an ``sz x sz`` boolean mask that hides future positions.

    Entry ``[i, j]`` is ``True`` exactly when ``j > i`` (strictly above the
    main diagonal) and ``False`` elsewhere. When passed as a boolean
    ``tgt_mask`` to the PyTorch transformer modules, ``True`` marks a
    position that may not be attended to, so each target position can only
    see itself and earlier positions.

    Args:
        sz: Side length of the square mask (the target sequence length).
        device: Optional device on which to create the mask. ``None`` keeps
            the default device, matching the previous behaviour.

    Returns:
        A ``torch.bool`` tensor of shape ``(sz, sz)``.
    """
    # Compare column indices against row indices directly. This yields the
    # boolean mask without first materialising a float matrix and casting it.
    positions = torch.arange(sz, device=device)
    return positions.unsqueeze(0) > positions.unsqueeze(1)

# Causal mask for the target sequence, moved to the same device as the inputs
tgt_mask = generate_square_subsequent_mask(tgt_seq_len).to(device=src.device)

# Forward pass
# (1) The encoder processes the source sequence
memory = transformer_encoder(src=src)  # shape: (batch_size, src_seq_len, d_model)

# (2) The decoder processes the target sequence and cross-attends to the
#     encoder output; tgt_mask is the autoregressive mask
output = transformer_decoder(tgt=tgt, memory=memory, tgt_mask=tgt_mask)
# shape: (batch_size, tgt_seq_len, d_model)

# Print the resulting shapes
print("Encoder output (memory) shape:", memory.shape)
print("Decoder output shape:", output.shape)

Encoder output (memory) shape: torch.Size([2, 10, 512])
Decoder output shape: torch.Size([2, 8, 512])
