In [None]:
import torch
from torch.nn import functional as F
from torch import nn

In [None]:
# 伪代码
class Encoder(nn.Module):
    """Placeholder encoder (pseudocode stub).

    Any positional arguments are accepted and discarded; only the base
    ``nn.Module`` initialisation runs. No ``forward`` is defined here.
    """

    def __init__(self, *args):
        super(Encoder, self).__init__()

In [None]:
# 伪代码
class Decoder(nn.Module):
    """Placeholder decoder (pseudocode stub).

    Any positional arguments are accepted and discarded; only the base
    ``nn.Module`` initialisation runs. No ``forward`` is defined here.
    """

    def __init__(self, *args):
        super(Decoder, self).__init__()

### 完整公式总结

$\mathbf{P} = \text{softmax}\left(
    \text{Decoder}\left(
        \mathbf{X}_{\text{trg}},
        \text{Encoder}(\mathbf{X}_{\text{src}}, \mathbf{M}_{\text{src}}),
        \mathbf{M}_{\text{trg}},
        \mathbf{M}_{\text{enc\_dec}}
    \right) \cdot \mathbf{W}_{\text{out}}
\right)$

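其中各个掩码的定义如上所述：$\mathbf{M}_{\text{src}}$ 为源序列的填充掩码；$\mathbf{M}_{\text{trg}}$ 为目标序列填充掩码与因果（下三角）掩码的逐元素与；$\mathbf{M}_{\text{enc\_dec}}$ 为以目标序列为查询、源序列为键的填充掩码，其形状对应 $(\text{len}_{\text{trg}}, \text{len}_{\text{src}})$。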

---

### 连接点的核心公式

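编码器和解码器连接的核心体现在解码器的交叉注意力层：查询来自解码器隐状态 $\mathbf{H}_{\text{dec}}$，键和值来自编码器输出 $\mathbf{H}_{\text{enc}}$；此处 $\mathbf{M}_{\text{enc\_dec}}$ 以加性形式出现（保留位置取 $0$，屏蔽位置取 $-\infty$）：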

$\text{CrossAttention} = \text{softmax}\left(
    \frac{(\mathbf{H}_{\text{dec}}\mathbf{W}_{Q,\text{cross}})(\mathbf{H}_{\text{enc}}\mathbf{W}_{K,\text{cross}})^\top}{\sqrt{d_k}} + \mathbf{M}_{\text{enc\_dec}}
\right) \cdot (\mathbf{H}_{\text{enc}}\mathbf{W}_{V,\text{cross}})$

In [None]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer that builds the attention masks and wires the two halves together.

    Args:
        src_pad_idx: token id marking padding in source sequences.
        trg_pad_idx: token id marking padding in target sequences.
        enc_voc_size: source vocabulary size (forwarded to ``Encoder``).
        dec_voc_size: target vocabulary size (forwarded to ``Decoder``).
        d_model, max_len, n_head, ffn_hidden, n_layer, drop_prob: hyper-parameters
            forwarded to ``Encoder`` / ``Decoder``.
        device: anything ``torch.device`` accepts; the causal mask is created on it.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, enc_voc_size, dec_voc_size, d_model, max_len, n_head, ffn_hidden, n_layer, drop_prob, device):
        super().__init__()
        # Build the encoder and decoder.
        # NOTE(review): ``device`` and ``drop_prob`` are passed in opposite positional
        # order to Encoder and Decoder; confirm against their real signatures.
        self.encoder = Encoder(enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layer, device, drop_prob)
        self.decoder = Decoder(dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layer, drop_prob, device)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_pad_mask(self, q: torch.Tensor, k: torch.Tensor, pad_idx_q, pad_idx_k):
        """Return a bool mask that is True where both the query and the key token are non-padding.

        Args:
            q: query token ids, shape (batch_size, len_q).
            k: key token ids, shape (batch_size, len_k).
            pad_idx_q: padding id for ``q``.
            pad_idx_k: padding id for ``k``.

        Returns:
            Bool tensor of shape (batch_size, 1, len_q, len_k).
        """
        q_valid = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)  # (batch_size, 1, len_q, 1)
        k_valid = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)  # (batch_size, 1, 1, len_k)
        # Broadcasting expands both operands to (batch_size, 1, len_q, len_k);
        # no explicit repeat() is needed.
        return q_valid & k_valid

    def make_causal_mask(self, q: torch.Tensor, k: torch.Tensor):
        """Return a lower-triangular bool mask: query position i may see key positions <= i.

        Args:
            q: query tensor; only ``size(0)`` (batch) and ``size(1)`` (len_q) are used.
            k: key tensor; only ``size(1)`` (len_k) is used.

        Returns:
            Bool tensor of shape (batch_size, 1, len_q, len_k) on ``self.device``.
        """
        len_q, len_k = q.size(1), k.size(1)
        batch_size = q.size(0)
        # Build the (len_q, len_k) triangle once, then copy it for every batch element.
        tri = torch.tril(torch.ones(len_q, len_k, device=torch.device(self.device))).bool()
        return tri.unsqueeze(0).unsqueeze(1).repeat(batch_size, 1, 1, 1)

    # Backward-compatible alias for the original, misspelled method name.
    make_casual_mask = make_causal_mask

    def forward(self, src, trg):
        """Encode ``src`` and decode ``trg`` against the encoder output.

        Args:
            src: source token ids, shape (batch_size, len_src).
            trg: target token ids, shape (batch_size, len_trg).

        Returns:
            Whatever ``self.decoder(trg, enc, trg_mask, src_trg_mask)`` returns.
        """
        # Encoder self-attention: source queries x source keys.
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        # Decoder self-attention: target padding AND causal (no attending to future tokens).
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx) & self.make_causal_mask(trg, trg)
        # Cross-attention: target queries x source keys, shape (batch_size, 1, len_trg, len_src).
        # The square src_mask must not be reused here: it does not match the
        # cross-attention score shape when len_trg != len_src.
        src_trg_mask = self.make_pad_mask(trg, src, self.trg_pad_idx, self.src_pad_idx)
        enc = self.encoder(src, src_mask)
        return self.decoder(trg, enc, trg_mask, src_trg_mask)