In [2]:
import torch
from torch import nn
from torch.nn import functional as F

In [3]:
# Pseudo-code placeholder (伪代码): only the class name is provided here.
class LayerNorm(nn.Module):
    """Placeholder for a layer-normalisation module.

    Any positional arguments are accepted and ignored. No ``forward`` is
    defined, so calling an instance raises ``NotImplementedError`` (the
    default ``nn.Module`` behaviour).
    """

    def __init__(self, *args):
        super(LayerNorm, self).__init__()

In [7]:
# Pseudo-code placeholder (伪代码): only the class name is provided here.
class MultiHeadAttention(nn.Module):
    """Placeholder for a multi-head attention module.

    Any positional arguments are accepted and ignored. No ``forward`` is
    defined, so calling an instance raises ``NotImplementedError`` (the
    default ``nn.Module`` behaviour).
    """

    def __init__(self, *args):
        nn.Module.__init__(self)

In [10]:
# Pseudo-code placeholder (伪代码): only the class name is provided here.
class TransformerEmbedding(nn.Module):
    """Placeholder for the token-embedding + positional-encoding module.

    Any positional arguments are accepted and ignored. No ``forward`` is
    defined, so calling an instance raises ``NotImplementedError`` (the
    default ``nn.Module`` behaviour).
    """

    def __init__(self, *args):
        super(TransformerEmbedding, self).__init__()

$$
\text{FFN}(\mathbf{x}) = \max(0, \mathbf{x}W_1 + b_1)W_2 + b_2
$$

* $W_1 \in \mathbb{R}^{d_{model} \times d_{ff}}$
* $W_2 \in \mathbb{R}^{d_{ff} \times d_{model}}$
* $b_1 \in \mathbb{R}^{d_{ff}}, \; b_2 \in \mathbb{R}^{d_{model}}$
* $\max(0, \cdot)$ 表示 **ReLU 激活函数**（现在也常用 GELU）

In [5]:
class PositionWiseFeedForward(nn.Module):
    """Position-wise feed-forward network: FFN(x) = ReLU(x W1 + b1) W2 + b2.

    Both ``nn.Linear`` layers act on the last dimension, so the same weights
    are applied independently at every position. Dropout is applied to the
    hidden activations after the ReLU.

    Args:
        d_model: size of the input and output feature dimension.
        hidden: size of the inner layer (d_ff in the paper).
        dropout: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.hidden = hidden
        # Keep the probability under its own name (consistent with EncoderLayer);
        # `self.dropout` is reserved for the nn.Dropout module below.
        self.dropout_rate = dropout
        self.fc1 = nn.Linear(in_features=d_model, out_features=hidden)
        self.fc2 = nn.Linear(in_features=hidden, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map a tensor whose last dimension is ``d_model`` to the same shape.

        Args:
            x: input tensor with last dimension ``d_model``.

        Returns:
            Tensor with the same leading dimensions and last dimension ``d_model``.
        """
        activated = F.relu(self.fc1(x))
        return self.fc2(self.dropout(activated))

1. **自注意力子层**

$$
\mathbf{z} = \text{LayerNorm}(\mathbf{x} + \text{MultiHeadAttention}(\mathbf{x}, \mathbf{x}, \mathbf{x}))
$$

2. **前馈子层**

$$
\mathbf{y} = \text{LayerNorm}(\mathbf{z} + \text{FFN}(\mathbf{z}))
$$

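最终输出 $\mathbf{y}$ 与输入 $\mathbf{x}$ 维度相同：`(batch_size, seq_len, d_model)`。

注：论文中每个子层的输出在与残差相加、送入 LayerNorm 之前都会先经过 Dropout，即 $\text{LayerNorm}(\mathbf{x} + \text{Dropout}(\text{Sublayer}(\mathbf{x})))$；两个子层各自使用独立的 LayerNorm（参数不共享）。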


In [6]:
class EncoderLayer(nn.Module):
    """A single Transformer encoder layer (post-LayerNorm variant).

    For input ``x`` it computes::

        z = norm1(x + Dropout(MultiHeadAttention(x, x, x)))
        y = norm2(z + Dropout(FFN(z)))

    Dropout is applied to each sub-layer's output *before* the residual add.
    Each sub-layer has its own LayerNorm, so the two normalisations learn
    independent scale/shift parameters ("Attention Is All You Need", Sec. 5.4).

    Args:
        d_model: model dimension; passed to MultiHeadAttention, LayerNorm and
            PositionWiseFeedForward.
        ffn_hidden: inner dimension of the position-wise feed-forward network.
        n_head: number of attention heads.
        dropout: dropout probability for the sub-layer outputs; also passed to
            the feed-forward network.
    """

    def __init__(self, d_model, ffn_hidden, n_head, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.ffn_hidden = ffn_hidden
        self.n_head = n_head
        self.dropout_rate = dropout
        self.attention = MultiHeadAttention(d_model, n_head)
        # Separate norms per sub-layer: sharing one instance would tie the
        # gamma/beta of the attention and FFN sub-layers together.
        self.norm1 = LayerNorm(d_model)
        self.norm2 = LayerNorm(d_model)
        # nn.Dropout holds no parameters, so one module can serve both sub-layers.
        self.dropout = nn.Dropout(p=dropout)
        self.ffn = PositionWiseFeedForward(d_model, ffn_hidden, dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the self-attention and feed-forward sub-layers.

        Args:
            x: input tensor; per the notes above, presumably
                (batch_size, seq_len, d_model) — TODO confirm against
                MultiHeadAttention. The residual additions require the
                attention and FFN outputs to match ``x``'s shape.

        Returns:
            The output of the second LayerNorm.
        """
        # Self-attention sub-layer: dropout on the sub-layer output, then add & norm.
        z = self.norm1(x + self.dropout(self.attention(x, x, x)))
        # Feed-forward sub-layer, same pattern.
        y = self.norm2(z + self.dropout(self.ffn(z)))
        return y

## Encoder 的结构

Transformer 的 **Encoder** 是由 **N 层 EncoderLayer 堆叠**而成。

在 *Attention Is All You Need*（Vaswani et al., 2017）中：

* 输入首先经过 **词嵌入 (Embedding)** 和 **位置编码 (Positional Encoding)**；
* 然后送入 **N 层 EncoderLayer**（论文里 $N=6$）；
* 最终得到每个位置的上下文表示。

In [None]:
class Encoder(nn.Module):
    """Transformer encoder: an embedding followed by ``n_layer`` stacked EncoderLayers.

    Args:
        enc_voc_size: source vocabulary size, passed to TransformerEmbedding.
        max_len: maximum sequence length, passed to TransformerEmbedding.
        d_model: model dimension.
        ffn_hidden: inner dimension of each layer's feed-forward network.
        n_head: number of attention heads per layer.
        n_layer: number of stacked EncoderLayers (N = 6 in the paper).
        device: stored as ``self.device``; not otherwise used by this class.
        dropout: dropout probability forwarded to the embedding and to every layer.
    """

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layer, device, dropout=0.1):
        super().__init__()
        self.enc_voc_size = enc_voc_size
        self.max_len = max_len
        self.d_model = d_model
        self.ffn_hidden = ffn_hidden
        self.n_head = n_head
        self.n_layer = n_layer
        self.device = device
        self.dropout_rate = dropout
        # NOTE(review): argument order must match TransformerEmbedding's real
        # signature; the placeholder in this notebook accepts any *args — TODO confirm.
        self.embedding = TransformerEmbedding(enc_voc_size, max_len, d_model, dropout)
        # EncoderLayer's 4th parameter is the dropout probability; the device
        # was previously passed there by mistake, breaking nn.Dropout(p=...).
        self.layers = nn.ModuleList(
            [EncoderLayer(d_model, ffn_hidden, n_head, dropout=dropout) for _ in range(n_layer)]
        )

    def forward(self, x):
        """Embed ``x`` and run it through every encoder layer in order.

        Args:
            x: input to TransformerEmbedding — presumably token ids of shape
                (batch_size, seq_len); TODO confirm.

        Returns:
            Output of the last EncoderLayer (the embedding output when n_layer == 0).
        """
        hidden = self.embedding(x)
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden