In [1]:
import torch
from torch import nn
import torch.nn.functional as F
import math

## Embedding


### TokenEmbedding

In [2]:
from torch import Tensor
# Map input vocabulary indices to embedding vectors of the given dimension

class TokenEmbedding(nn.Embedding):
  """Lookup table that maps vocabulary indices to ``d_model``-wide vectors.

  Thin subclass of ``nn.Embedding`` that only fixes a default padding index.

  Args:
    vocab_size: Number of rows in the table (vocabulary size).
    d_model: Width of every embedding vector.
    padding_idx: Index treated as padding by ``nn.Embedding`` (its row starts
      as zeros and is not updated by gradients). Defaults to 1, the value that
      used to be hard-coded, so existing two-argument callers are unaffected.
  """

  def __init__(self,vocab_size,d_model,padding_idx=1):
    super().__init__(vocab_size,d_model,padding_idx=padding_idx)

### Positional Embedding

In [16]:
class PositionalEmbedding(nn.Module):
  """Fixed (non-learned) sinusoidal positional encoding table.

  For position ``pos`` and even column ``2i`` the table holds
  ``sin(pos / 10000**(2i/d_model))``; the following odd column holds the
  cosine of the same angle.

  Args:
    d_model: Width of each encoding vector (odd values are supported).
    max_len: Number of positions pre-computed in the table.
  """

  def __init__(self,d_model,max_len):
    super().__init__()
    pos=torch.arange(0,max_len).float().unsqueeze(dim=1)   # (max_len, 1)
    _2i=torch.arange(0,d_model,step=2).float()             # (ceil(d_model/2),)
    angle=pos/10000**(_2i/d_model)                         # (max_len, ceil(d_model/2))
    encoding=torch.zeros(max_len,d_model)
    encoding[:,0::2]=torch.sin(angle)
    # With an odd d_model there is one fewer odd column than even columns,
    # so trim the angle table to the number of odd columns.
    encoding[:,1::2]=torch.cos(angle[:,:d_model//2])
    # Register as a buffer so the table follows .to()/.cuda()/.double() with
    # the module; persistent=False keeps it out of state_dict so checkpoints
    # saved before this change still load without key mismatches.
    self.register_buffer("encoding",encoding,persistent=False)

  def forward(self,x):
    """Return the first ``seq_len`` rows of the table.

    Args:
      x: 2-D tensor of token indices, ``(batch, seq_len)``; only its shape is used.

    Returns:
      Tensor of shape ``(seq_len, d_model)``.

    Raises:
      ValueError: if ``seq_len`` exceeds the ``max_len`` the table was built for.
    """
    _,seq_len=x.size()
    max_len=self.encoding.size(0)
    if seq_len>max_len:
      # Slicing would silently return only max_len rows and fail (or broadcast
      # incorrectly) later when added to the token embeddings.
      raise ValueError(
        f"sequence length {seq_len} exceeds max_len {max_len} of the positional table"
      )
    return self.encoding[:seq_len,:]

### TransformerEmbedding

In [4]:
class TransformerEmbedding(nn.Module):
  """Sum of token embedding and positional encoding, followed by dropout.

  Args:
    vocab_size: Vocabulary size passed to ``TokenEmbedding``.
    d_model: Embedding width shared by both sub-embeddings.
    max_len: Number of positions passed to ``PositionalEmbedding``.
    drop_prob: Dropout probability applied to the summed embedding.
  """

  def __init__(self,vocab_size,d_model,max_len,drop_prob):
    super().__init__()
    self.tok_emb=TokenEmbedding(vocab_size,d_model)
    self.pos_emb=PositionalEmbedding(d_model,max_len)
    # Despite its name, this attribute holds the nn.Dropout module itself.
    self.drop_prob=nn.Dropout(p=drop_prob)

  def forward(self,x):
    """Embed the index tensor ``x`` and add the positional encoding."""
    token_vectors=self.tok_emb(x)
    position_vectors=self.pos_emb(x)
    # The positional term is added to the token vectors via broadcasting.
    combined=token_vectors+position_vectors
    return self.drop_prob(combined)

## Multi-Head Attention

In [5]:
# Random float tensor of shape (128, 32, 512).
# NOTE(review): no later cell visible here reads this global `x` (every other
# `x` is a local/parameter) — presumably a leftover scratch input; confirm
# before removing.
x=torch.rand(128,32,512)

In [18]:
class MultiHeadAttention(nn.Module):
  """Scaled dot-product attention split across ``n_head`` heads.

  Args:
    d_model: Width of the input and output features.
    n_head: Number of attention heads; must divide ``d_model`` evenly.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``n_head``.
  """

  def __init__(self,d_model,n_head):
    super().__init__()
    if d_model%n_head!=0:
      # Fail at construction instead of with an opaque view() error in forward.
      raise ValueError(
        f"d_model ({d_model}) must be divisible by n_head ({n_head})"
      )
    self.n_head=n_head
    self.d_model=d_model
    self.w_q=nn.Linear(d_model,d_model)
    self.w_k=nn.Linear(d_model,d_model)
    self.w_v=nn.Linear(d_model,d_model)
    self.w_combine=nn.Linear(d_model,d_model)
    self.softmax=nn.Softmax(dim=-1)

  def forward(self,q,k,v,mask=None):
    """Attend from ``q`` to ``k``/``v``.

    Args:
      q: Query tensor, ``(batch, time_q, d_model)``.
      k: Key tensor, ``(batch, time_k, d_model)``.
      v: Value tensor, ``(batch, time_v, d_model)``.
      mask: Optional tensor broadcastable to ``(batch, n_head, time_q, time_k)``;
        positions where ``mask == 0`` are excluded from attention.

    Returns:
      Tensor of shape ``(batch, time_q, d_model)``.
    """
    batch,time_q,_=q.shape
    time_k=k.shape[1]
    time_v=v.shape[1]
    n_d=self.d_model//self.n_head   # per-head feature width
    q,k,v=self.w_q(q),self.w_k(k),self.w_v(v)
    # (batch, time, d_model) -> (batch, n_head, time, n_d)
    q=q.view(batch,time_q,self.n_head,n_d).permute(0,2,1,3)
    k=k.view(batch,time_k,self.n_head,n_d).permute(0,2,1,3)
    v=v.view(batch,time_v,self.n_head,n_d).permute(0,2,1,3)
    score=q@k.transpose(-2,-1)/math.sqrt(n_d)
    if mask is not None:
      # Use the most negative finite value rather than -inf: masked entries
      # still receive exactly zero weight after softmax, but a row that is
      # entirely masked yields a uniform distribution instead of NaN (which
      # would otherwise spread through attn @ v into later layers).
      score=score.masked_fill(mask==0,torch.finfo(score.dtype).min)
    attn=self.softmax(score)
    context=attn@v
    # (batch, n_head, time_q, n_d) -> (batch, time_q, d_model)
    context=context.permute(0,2,1,3).contiguous().view(batch,time_q,self.d_model)
    out=self.w_combine(context)
    return out

## LayerNorm

In [7]:
class LayerNorm(nn.Module):
  """Layer normalisation over the last dimension with learnable scale and shift.

  Args:
    d_model: Size of the last dimension being normalised.
    eps: Constant added to the variance before the square root.
  """

  def __init__(self,d_model,eps=1e-12):
    super().__init__()
    self.gamma=nn.Parameter(torch.ones(d_model))   # scale, starts at 1
    self.beta=nn.Parameter(torch.zeros(d_model))   # shift, starts at 0
    self.eps=eps

  def forward(self,x):
    """Normalise ``x`` along its last axis, then apply ``gamma`` and ``beta``."""
    mu=x.mean(dim=-1,keepdim=True)
    # Biased (population) variance, i.e. divide by N rather than N-1.
    sigma2=x.var(dim=-1,unbiased=False,keepdim=True)
    normalised=(x-mu)/torch.sqrt(sigma2+self.eps)
    return self.gamma*normalised+self.beta

## Encoder

### PositionwiseFeedForward

In [8]:
class PositionwiseFeedForward(nn.Module):
  """Two-layer feed-forward network: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: Input and output feature width.
    hidden: Width of the intermediate layer.
    dropout: Dropout probability applied after the ReLU.
  """

  def __init__(self,d_model,hidden,dropout=0.1):
    super().__init__()
    self.fc1=nn.Linear(d_model,hidden)
    self.fc2=nn.Linear(hidden,d_model)
    self.dropout=nn.Dropout(dropout)

  def forward(self,x):
    """Expand to ``hidden`` features, activate, drop, project back to ``d_model``."""
    activated=F.relu(self.fc1(x))
    return self.fc2(self.dropout(activated))


### EncoderLayer

In [9]:
class EncoderLayer(nn.Module):
  """One encoder block: self-attention and feed-forward, each wrapped in
  dropout, a residual connection and a post-sublayer LayerNorm.

  Args:
    d_model: Feature width.
    ffn_hidden: Hidden width of the feed-forward sub-layer.
    n_head: Number of attention heads.
    dropout: Dropout probability used throughout the block.
  """

  def __init__(self,d_model,ffn_hidden,n_head,dropout=0.1):
    super().__init__()
    # Sub-layer 1: self-attention
    self.attention=MultiHeadAttention(d_model,n_head)
    self.dropout1=nn.Dropout(dropout)
    self.norm1=LayerNorm(d_model)
    # Sub-layer 2: position-wise feed-forward
    self.ffn=PositionwiseFeedForward(d_model,ffn_hidden,dropout)
    self.dropout2=nn.Dropout(dropout)
    self.norm2=LayerNorm(d_model)

  def forward(self,x,mask=None):
    """Run ``x`` through both sub-layers; ``mask`` is forwarded to the attention."""
    attended=self.dropout1(self.attention(x,x,x,mask))
    x=self.norm1(attended+x)            # residual + norm
    transformed=self.dropout2(self.ffn(x))
    return self.norm2(transformed+x)    # residual + norm

### Encoder

In [10]:
class Encoder(nn.Module):
  """Embedding followed by a stack of ``n_layer`` encoder layers.

  Args:
    enc_voc_size: Source vocabulary size.
    max_len: Number of positions in the positional table.
    d_model: Feature width.
    ffn_hidden: Hidden width of each feed-forward sub-layer.
    n_head: Number of attention heads per layer.
    n_layer: Number of stacked ``EncoderLayer`` blocks.
    dropout: Dropout probability used by the embedding and every layer.
  """

  def __init__(self,enc_voc_size,max_len,d_model,ffn_hidden,n_head,n_layer,dropout=0.1):
    super().__init__()
    self.embedding=TransformerEmbedding(enc_voc_size,d_model,max_len,dropout)
    blocks=[EncoderLayer(d_model,ffn_hidden,n_head,dropout) for _ in range(n_layer)]
    self.layers=nn.ModuleList(blocks)

  def forward(self,x,src_mask):
    """Embed the source indices ``x`` and apply every layer with ``src_mask``."""
    hidden=self.embedding(x)
    for block in self.layers:
      hidden=block(hidden,src_mask)
    return hidden

## Decoder

### DecoderLayer

In [11]:
class DecoderLayer(nn.Module):
  """One decoder block: self-attention, cross-attention over the encoder
  output, and a feed-forward network. Each sub-layer is followed by dropout,
  a residual connection and LayerNorm.

  Args:
    d_model: Feature width.
    ffn_hidden: Hidden width of the feed-forward sub-layer.
    n_head: Number of attention heads.
    dropout: Dropout probability used throughout the block.
  """

  def __init__(self,d_model,ffn_hidden,n_head,dropout=0.1):
    super().__init__()
    # Sub-layer 1: self-attention over the target sequence
    self.self_attention=MultiHeadAttention(d_model,n_head)
    self.dropout1=nn.Dropout(dropout)
    self.norm1=LayerNorm(d_model)
    # Sub-layer 2: attention from the target to the encoder memory
    self.cross_attention=MultiHeadAttention(d_model,n_head)
    self.dropout2=nn.Dropout(dropout)
    self.norm2=LayerNorm(d_model)
    # Sub-layer 3: position-wise feed-forward
    self.ffn=PositionwiseFeedForward(d_model,ffn_hidden,dropout)
    self.dropout3=nn.Dropout(dropout)
    self.norm3=LayerNorm(d_model)

  def forward(self,x,memory,tar_mask,src_mask):
    """Apply the three sub-layers in order.

    ``tar_mask`` is passed to the self-attention; ``src_mask`` is passed to
    the cross-attention, whose keys and values are both ``memory``.
    """
    self_out=self.dropout1(self.self_attention(x,x,x,tar_mask))
    x=self.norm1(self_out+x)
    cross_out=self.dropout2(self.cross_attention(x,memory,memory,src_mask))
    x=self.norm2(cross_out+x)
    ffn_out=self.dropout3(self.ffn(x))
    return self.norm3(ffn_out+x)

### Decoder

In [12]:
class Decoder(nn.Module):
  """Embedding followed by a stack of ``n_layer`` decoder layers.

  Args:
    dec_voc_size: Target vocabulary size.
    max_len: Number of positions in the positional table.
    d_model: Feature width.
    ffn_hidden: Hidden width of each feed-forward sub-layer.
    n_head: Number of attention heads per layer.
    n_layer: Number of stacked ``DecoderLayer`` blocks.
    dropout: Dropout probability used by the embedding and every layer.
  """

  def __init__(self,dec_voc_size,max_len,d_model,ffn_hidden,n_head,n_layer,dropout=0.1):
    super().__init__()
    self.embedding=TransformerEmbedding(dec_voc_size,d_model,max_len,dropout)
    blocks=[DecoderLayer(d_model,ffn_hidden,n_head,dropout) for _ in range(n_layer)]
    self.layers=nn.ModuleList(blocks)

  def forward(self,x,memory,tar_mask,src_mask):
    """Embed the target indices ``x`` and run every layer against ``memory``."""
    hidden=self.embedding(x)
    for block in self.layers:
      hidden=block(hidden,memory,tar_mask,src_mask)
    return hidden

## Transformer

### Generator

In [13]:
class Generator(nn.Module):
  """Linear projection from ``d_model`` features to ``dec_voc_size`` outputs.

  No softmax is applied; the raw linear output is returned.
  """

  def __init__(self,d_model,dec_voc_size):
    super().__init__()
    self.fc=nn.Linear(d_model,dec_voc_size)

  def forward(self,x):
    """Project the last dimension of ``x`` onto the vocabulary."""
    logits=self.fc(x)
    return logits

### Transformer

In [14]:
class Transformer(nn.Module):
  """Encoder-decoder model with a final projection onto the target vocabulary.

  Args:
    enc_voc_size: Source vocabulary size.
    dec_voc_size: Target vocabulary size.
    max_len: Number of positions in the positional tables.
    d_model: Feature width.
    ffn_hidden: Hidden width of the feed-forward sub-layers.
    n_head: Number of attention heads.
    n_layer: Number of layers in both the encoder and the decoder.
    dropout: Dropout probability.
  """

  def __init__(self,enc_voc_size,dec_voc_size,max_len,d_model,ffn_hidden,n_head,n_layer,dropout=0.1):
    super().__init__()
    self.encoder=Encoder(enc_voc_size,max_len,d_model,ffn_hidden,n_head,n_layer,dropout)
    self.decoder=Decoder(dec_voc_size,max_len,d_model,ffn_hidden,n_head,n_layer,dropout)
    self.generator=Generator(d_model,dec_voc_size)

  def forward(self,src,tar,src_mask,tar_mask):
    """Encode ``src``, decode ``tar`` against the encoding, and project.

    Note the argument order: ``forward`` takes ``(src_mask, tar_mask)`` while
    the decoder is called with ``(tar_mask, src_mask)``.
    """
    encoded=self.encoder(src,src_mask)
    decoded=self.decoder(tar,encoded,tar_mask,src_mask)
    return self.generator(decoded)

## Test

In [19]:
# Smoke test: push random indices through a tiny Transformer and check the
# output shape.

# Fix the RNG so the random inputs and the weight initialisation are the same
# on every Restart & Run All.
torch.manual_seed(0)

# Hyper-parameters for the toy example
batch = 2
src_len = 5
tgt_len = 6
src_vocab = 20
tgt_vocab = 30
d_model = 16

# Random input indices. Index 1 may be drawn; it is the padding index of
# TokenEmbedding, which is harmless for a shape check.
src = torch.randint(0, src_vocab, (batch, src_len))
tar = torch.randint(0, tgt_vocab, (batch, tgt_len))

# Trivial masks (all ones, i.e. nothing is masked out — note that this means
# the decoder self-attention is NOT causal in this test)
src_mask = torch.ones((batch, 1, 1, src_len), dtype=torch.bool)
tar_mask = torch.ones((batch, 1, tgt_len, tgt_len), dtype=torch.bool)

# Build the Transformer
model = Transformer(src_vocab, tgt_vocab, max_len=10, d_model=d_model,
                    ffn_hidden=32, n_head=2, n_layer=2)

# Forward pass
out = model(src, tar, src_mask, tar_mask)

print("输出 shape:", out.shape)   # expected: [batch, tgt_len, tgt_vocab]

# Enforce the expectation instead of leaving it to visual inspection.
assert out.shape == (batch, tgt_len, tgt_vocab), f"unexpected output shape {tuple(out.shape)}"


输出 shape: torch.Size([2, 6, 30])
