In [None]:
import torch
import torch.nn as nn

class TransformerModel(nn.Module):
    """Encoder-only Transformer that maps token ids to per-token vocabulary logits.

    The model is sequence-first: the encoder layers are built without
    ``batch_first``, so inputs are laid out as ``(seq_len, batch_size)``.
    Positional information comes from a learnable parameter initialised to
    zeros, not from fixed sinusoidal encodings.
    """

    def __init__(
        self,
        vocab_size: int,
        d_model: int,
        nhead: int,
        num_encoder_layers: int,
        dim_feedforward: int,
        max_len: int = 5000,
    ) -> None:
        """Build the embedding, positional table, encoder stack and output projection.

        Args:
            vocab_size: Number of distinct token ids (embedding rows and
                output logits).
            d_model: Embedding / hidden size used throughout the encoder.
            nhead: Number of attention heads per encoder layer.
            num_encoder_layers: Number of stacked encoder layers.
            dim_feedforward: Hidden size of each layer's feed-forward network.
            max_len: Longest sequence the positional table can cover.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Stored as (1, max_len, d_model); forward() reshapes the slice it
        # needs, so the parameter layout stays compatible with old checkpoints.
        self.positional_encoding = nn.Parameter(torch.zeros(1, max_len, d_model))
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=dim_feedforward,
            ),
            num_layers=num_encoder_layers,
        )
        self.decoder = nn.Linear(d_model, vocab_size)

    def forward(self, src: torch.Tensor) -> torch.Tensor:
        """Compute vocabulary logits for every position of every sequence.

        Args:
            src: Integer token ids of shape ``(seq_len, batch_size)``.

        Returns:
            Tensor of shape ``(seq_len, batch_size, vocab_size)``.

        Raises:
            ValueError: If ``src`` is not 2-D, or if ``seq_len`` exceeds the
                ``max_len`` the positional table was built with.
        """
        if src.dim() != 2:
            raise ValueError(
                f"src must have shape (seq_len, batch_size), got {tuple(src.shape)}"
            )

        seq_len = src.size(0)
        max_len = self.positional_encoding.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len}"
            )

        # (max_len, d_model) -> (seq_len, 1, d_model): the sequence axis must
        # line up with dim 0 of the embeddings and broadcast over the batch
        # axis. Adding the raw (1, seq_len, d_model) slice would align the
        # sequence axis with the batch axis instead.
        positions = self.positional_encoding[0, :seq_len].unsqueeze(1)

        embedded = self.embedding(src) + positions
        encoded = self.transformer_encoder(embedded)
        return self.decoder(encoded)

# Hyper-parameters
vocab_size = 10000
d_model = 512
nhead = 8
num_encoder_layers = 6
dim_feedforward = 2048

# Build the model
model = TransformerModel(
    vocab_size=vocab_size,
    d_model=d_model,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dim_feedforward=dim_feedforward,
)

# Example input: random token ids laid out as (seq_len, batch_size)
src = torch.randint(low=0, high=vocab_size, size=(30, 64))
output = model(src)
print(output.shape)  # intended output: torch.Size([30, 64, 10000])