# Transformer 

逐步实现 Transformer 完整架构：

1. 实现主体网络, 用 embedding 层 和 Linear 层代替关键组件
2. 实现输入层
3. 实现 encoder
4. 实现 decoder
5. 实现 输出层
6. 数据集制作
7. 训练代码, 收敛
8. 推理代码
9. 模型保存
10. 模型加载

## Config 

In [1]:
# for debug 

dim = 512 
num_layers = 6
heads = 8

batch_size = 2
src_len = 256
trg_len = 128
max_len = 512

src_vocab_size = 100
trg_vocab_size = 200

IGNORE_INDEX = -100
PAD_TOKEN_ID = 0
SOS_TOKEN_ID = 1
EOS_TOKEN_ID = 2

## Transformer 

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
torch.manual_seed(42)

<torch._C.Generator at 0x11efb8eb0>

In [3]:
class TransformerBasic(nn.Module):
    """
    仅通过 Embedding 和 Lienar 实现 Transformer 计算逻辑
    输入:[bs, src_seq_len]
    输出:[bs, trg_seq_len, vocab_size]
    """
    def __init__(self, src_vocab_size = 100, trg_vocab_size = 200, dim = 512, num_layers = 6, heads = 8, max_len = 512):
        super().__init__()
        self.encoder_input = nn.Embedding(src_vocab_size, dim)
        self.encoder = nn.Linear(dim, dim)
        
        self.decoder_input = nn.Embedding(trg_vocab_size, dim)
        self.decoder = nn.Linear(dim, dim)

        self.output_layer = nn.Linear(dim, trg_vocab_size)
        
    def forward(self, src_ids, trg_ids, src_mask = None, trg_mask = None, src_trg_mask = None):
        X = self.encoder_input(src_ids)
        X_src = self.encoder(X)
        print('encoder output:\t', X_src.shape)

        Y = self.decoder_input(trg_ids)
        Y = self.decoder(Y) + X_src.mean(dim = 1, keepdim = True)
        print('decoder output:\t', Y.shape)

        logits = self.output_layer(Y) 
        prob = F.softmax(logits, dim = -1)
        
        return logits, prob
    
model = TransformerBasic()
print(model)


src_ids = torch.randint(src_vocab_size, (batch_size, src_len))
trg_ids = torch.randint(trg_vocab_size, (batch_size, trg_len))
print('encode input shape: ', src_ids.shape)
print('decode output shape: ', trg_ids.shape)

logits, _ = model(src_ids, trg_ids)
print('transformer output:\t', logits.shape)

TransformerBasic(
  (encoder_input): Embedding(100, 512)
  (encoder): Linear(in_features=512, out_features=512, bias=True)
  (decoder_input): Embedding(200, 512)
  (decoder): Linear(in_features=512, out_features=512, bias=True)
  (output_layer): Linear(in_features=512, out_features=200, bias=True)
)
encode input shape:  torch.Size([2, 256])
decode output shape:  torch.Size([2, 128])
encoder output:	 torch.Size([2, 256, 512])
decoder output:	 torch.Size([2, 128, 512])
transformer output:	 torch.Size([2, 128, 200])


- Decoder 输入 `bs x trg_len` 与 输出 ` bs x trg_len x trg_vocab_size`, 一个序列输出 `trg_len` 个 概率分布

## Transformer Input Layer

In [4]:
torch.arange(0, 10, 2)
a = torch.tensor([1,2,3]) # a is row
b = torch.tensor([2,2,2]) # b is col
torch.outer(a, b)

tensor([[2, 2, 2],
        [4, 4, 4],
        [6, 6, 6]])

In [5]:
class TransformerInputLayer(nn.Module):
    """
    词向量 + 位置编码
    """
    def __init__(self, vocab_size = 100, dim = 512, max_len = 1024, base = 10000.0):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.position_encoding = nn.Parameter( torch.randn(max_len, dim) ) # learnable
        self.max_len = max_len

        # sin-cos position encoding
        # 人工设计的 位置编码
        group = dim // 2
        theta_ids = torch.arange(0, dim, 2) # 0, 2, 4, ..., 512
        theta =  1 / ( base ** ( theta_ids / dim ) )
        pe = torch.zeros(dim) # 512, sin( theta_0 ),cos( theta_0), ...
        pe[theta_ids] = theta
        pe[theta_ids+1] = theta

        position_ids = torch.arange(0, max_len) # 0, 1, 2, ..., 1024
        self.PE = torch.outer(position_ids, pe) # 1024 x 512
        
        self.PE[:, theta_ids] = torch.sin(self.PE[:, theta_ids])
        self.PE[:, theta_ids+1] = torch.sin(self.PE[:, theta_ids+1])

    def forward_NoPE(self, input_ids):
        """
        嵌入向量 + 无位置编码
        """
        X = self.embedding(input_ids)
        return X

    def forward_basic(self, input_ids):
        """
        嵌入向量 + 常数向量位置编码
        """
        bs, seq_len = input_ids.shape
        X = self.embedding(input_ids)
        PE = torch.arange(seq_len).unsqueeze(dim = 0).unsqueeze(dim = 2)
        X_ = X + PE / self.max_len
        return X_

    def forward_learn(self, input_ids):
        """
        嵌入向量 + 可学习位置编码
        """
        bs, seq_len = input_ids.shape
        X = self.embedding(input_ids)
        PE = self.position_encoding[:seq_len, :]
        X_ = X + PE
        return X_

    def forward(self, input_ids):
        """
        嵌入向量 + 绝对位置编码(标准实现)
        """
        bs, seq_len = input_ids.shape
        X = self.embedding(input_ids)
        PE = self.PE[:seq_len, :]
        X_ = X + PE
        return X_

input_layer = TransformerInputLayer(vocab_size = src_vocab_size, dim = 6)
print('NoPE: ', input_layer.forward_NoPE(src_ids[:1, :3]))
print('Constant: ', input_layer.forward_basic(src_ids[:1, :3]))
print('Learnable: ', input_layer.forward_learn(src_ids[:1, :3]))
print('sin-cos-PE: ', input_layer.forward(src_ids[:1, :3]))

NoPE:  tensor([[[-0.5031, -0.2963, -1.6388, -1.1080,  0.0778,  0.0641],
         [ 0.2451,  0.6976, -0.2715,  0.9356,  1.8059, -0.5141],
         [ 0.0044,  0.1593,  0.0037, -0.5692, -0.6155, -0.1125]]],
       grad_fn=<EmbeddingBackward0>)
Constant:  tensor([[[-0.5031, -0.2963, -1.6388, -1.1080,  0.0778,  0.0641],
         [ 0.2461,  0.6986, -0.2705,  0.9366,  1.8068, -0.5131],
         [ 0.0063,  0.1613,  0.0056, -0.5672, -0.6135, -0.1106]]],
       grad_fn=<AddBackward0>)
Learnable:  tensor([[[ 0.3895,  0.2742, -0.8089, -1.9987, -0.0772,  1.6185],
         [-1.3515,  0.7869,  1.7763,  2.6277,  1.2768, -0.4698],
         [ 0.9530,  0.7793, -1.3005, -2.0253, -0.2677, -2.4765]]],
       grad_fn=<AddBackward0>)
Sin-cos-PE:  tensor([[[-0.5031, -0.2963, -1.6388, -1.1080,  0.0778,  0.0641],
         [ 1.0866,  1.5391, -0.2251,  0.9820,  1.8080, -0.5119],
         [ 0.9137,  1.0686,  0.0964, -0.4765, -0.6112, -0.1082]]],
       grad_fn=<AddBackward0>)


## Transformer Encoder

- 归一化层
- 多头注意力层
- 前馈层
- 残差链接

### LayerNorm

In [6]:
class LayerNorm(nn.Module):
    def __init__(self, dim, ):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.beta = nn.Parameter(torch.zeros(dim))
        self.epsilon = 1e-8
    def forward(self, X, ):
        mu = X.mean( dim = -1, keepdim = True)
        var = X.var( dim = -1, keepdim = True)
        X_hat = ( X - mu ) / torch.sqrt( var + self.epsilon)
        Y = X_hat * self.gamma + self.beta
        return Y
        
tmp_dim = 4
LN = LayerNorm(dim = tmp_dim)
print(LN)
X = torch.randn(2,3,4)
print(LN(X))

LayerNorm()
tensor([[[-0.6895,  0.2339, -0.8599,  1.3155],
         [-0.9774, -0.7341,  0.9996,  0.7119],
         [ 1.4324, -0.8155, -0.0928, -0.5241]],

        [[ 0.4565,  0.5562, -1.4987,  0.4860],
         [ 0.5503,  1.1316, -0.8761, -0.8057],
         [ 1.4116, -0.2015, -0.2627, -0.9475]]], grad_fn=<AddBackward0>)


### Feed Forward Network (FFN)

In [7]:
class FeedForwardNetwork(nn.Module):
    def __init__(self, dim, ):
        super().__init__()
        self.dim = dim
        self.W_up = nn.Linear(self.dim, 4 * self.dim)
        self.ReLU = nn.ReLU()
        self.W_down = nn.Linear(4 * self.dim, self.dim)
    def forward(self, X):
        X_ = self.ReLU(self.W_up(X))
        Y = self.W_down(X_)
        return Y
        
FFN = FeedForwardNetwork(dim = 4)
X = torch.randn(2,3,4)
print(FFN)
print(FFN(X))

FeedForwardNetwork(
  (W_up): Linear(in_features=4, out_features=16, bias=True)
  (ReLU): ReLU()
  (W_down): Linear(in_features=16, out_features=4, bias=True)
)
tensor([[[-0.3446,  0.1762, -0.4142, -0.0142],
         [-0.4503,  0.0020, -0.2082,  0.2073],
         [-0.4478,  0.4252, -0.6522, -0.2090]],

        [[-0.5538, -0.0942, -0.1867,  0.0850],
         [-0.0023, -0.1210, -0.1883,  0.1941],
         [-0.0823, -0.0515, -0.1285,  0.3077]]], grad_fn=<ViewBackward0>)


### Multi Heads Attention

self-attention

In [8]:
import math

class MultiHeadScaleDotProductAttention(nn.Module):
    def __init__(self, dim_in, dim_out, heads = 8):
        super().__init__()
        self.WQ = nn.Linear(dim_in, dim_out)
        self.WK = nn.Linear(dim_in, dim_out)
        self.WV = nn.Linear(dim_in, dim_out)
        self.WO = nn.Linear(dim_in, dim_out)
        self.heads = 8
        self.head_dim = dim_out // self.heads
        
    def forward(self, X_Q, X_K, X_V, mask = None):
        bs, seq_len, dim = X_Q.shape
        bs, seq_K_len, dim = X_K.shape
        bs, seq_V_len, dim = X_V.shape
        Q = self.WQ(X_Q)
        K = self.WK(X_K)
        V = self.WV(X_V)

        # 拆分维度
        Q_h = Q.view(bs, seq_len, self.heads, self.head_dim).transpose(1,2)
        K_h = K.view(bs, seq_K_len, self.heads, self.head_dim).transpose(1,2) # KV len 可以不等同于 Q len
        V_h = V.view(bs, seq_V_len, self.heads, self.head_dim).transpose(1,2)

        # 多个 q_i 计算注意力特征
        S = Q_h @ K_h.transpose(2,3) / math.sqrt(self.head_dim) # 1. 为什么要除于 \sqrt{d}

        if mask is not None:
            idx = torch.where(mask == 0)
            # 维度 “：” 表示头并行共享 mask 矩阵
            S[idx[0], :,idx[1],idx[2]] = -10000.0 
        
        P = torch.softmax(S, dim = -1) # 行 softmax
        Z = P @ V_h

        # 恢复维度
        Z = Z.transpose(1,2).reshape(bs, seq_len, dim)
        
        output = self.WO(Z)
        
        return output
tmp_dim = 16
X = torch.randn(2, 4, tmp_dim)
mask = torch.ones(2, 4, tmp_dim)
model = MultiHeadScaleDotProductAttention(tmp_dim, tmp_dim, 8)
Y = model(X, X, X, mask)
print(X.shape, Y.shape)

torch.Size([2, 4, 16]) torch.Size([2, 4, 16])


### Encoder Block

In [9]:
class TransformerEncoderBlock(nn.Module):
    def __init__(self, dim = 512, heads = 8):
        super().__init__()
        self.attn = MultiHeadScaleDotProductAttention(dim, dim, heads)
        self.ln1 = LayerNorm(dim)
        self.ffn = FeedForwardNetwork(dim)
        self.ln2 = LayerNorm(dim)
        
    def forward(self, X, src_mask = None):
        X_attn = self.attn(X, X, X, mask = src_mask)
        X_ln = self.ln1(X_attn)
        X = X + X_ln

        X_ffn = self.ffn(X)
        X_ln = self.ln2(X_ffn)
        X = X + X_ln

        return X

tmp_dim = 16
X = torch.randn(2, 4, tmp_dim)
mask = torch.ones(2, 4, tmp_dim)
model = TransformerEncoderBlock(tmp_dim, 8)
Y = model(X, mask)
print(X.shape, Y.shape)

torch.Size([2, 4, 16]) torch.Size([2, 4, 16])


In [10]:
class TransformerEncoder(nn.Module):
    """
    输入 原文本序列，输出 token 序列的编码表征
    输入:[bs, src_seq_len, dim]
    输出:[bs, src_seq_len, dim]
    """
    def __init__(self, dim = 512, num_layers = 6, heads = 8):
        super().__init__()
        # self.encoder = nn.Linear(dim, dim) 
        self.encoder = nn.ModuleList(
            [TransformerEncoderBlock(dim, heads) for i in range(num_layers)]
        )
    def forward(self, X, mask = None):
        for encode_block in self.encoder:
            X = encode_block(X, mask)
        return X
        
tmp_dim = 16
X = torch.randn(2, 4, tmp_dim)
mask = torch.ones(2, 4, tmp_dim)
model = TransformerEncoder(tmp_dim, tmp_dim, 8)
Y = model(X, mask)
print(X.shape, Y.shape)

torch.Size([2, 4, 16]) torch.Size([2, 4, 16])


In [11]:
print(model)

TransformerEncoder(
  (encoder): ModuleList(
    (0-15): 16 x TransformerEncoderBlock(
      (attn): MultiHeadScaleDotProductAttention(
        (WQ): Linear(in_features=16, out_features=16, bias=True)
        (WK): Linear(in_features=16, out_features=16, bias=True)
        (WV): Linear(in_features=16, out_features=16, bias=True)
        (WO): Linear(in_features=16, out_features=16, bias=True)
      )
      (ln1): LayerNorm()
      (ffn): FeedForwardNetwork(
        (W_up): Linear(in_features=16, out_features=64, bias=True)
        (ReLU): ReLU()
        (W_down): Linear(in_features=64, out_features=16, bias=True)
      )
      (ln2): LayerNorm()
    )
  )
)


### Encoder Mask detail

In [12]:
def get_src_mask(input_ids, pad_token_id = 0):
    bs, seq_len = input_ids.shape
    mask = torch.ones(bs, seq_len, seq_len)
    for i in range(bs):
        pad_idx =  torch.where(input_ids[i, :]  == pad_token_id)[0]
        mask[i, pad_idx, :] = 0
        mask[i, :, pad_idx] = 0
    return mask
    
input_ids = torch.tensor([[1, 2, 3, 0, 0],
                          [1, 2, 0, 0, 0]], dtype = torch.long) # 0 is pad
bs, seq_len = input_ids.shape
mask = get_src_mask(input_ids, pad_token_id = PAD_TOKEN_ID)
print(mask)

score = torch.randn(bs, seq_len, seq_len)
score * mask

tensor([[[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


tensor([[[-0.5355,  1.0730,  0.4873, -0.0000,  0.0000],
         [ 0.4597, -0.0714,  1.8693,  0.0000,  0.0000],
         [-1.1784, -0.6136, -0.9517, -0.0000, -0.0000],
         [-0.0000, -0.0000, -0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000]],

        [[ 1.5798, -1.2041,  0.0000, -0.0000, -0.0000],
         [-2.3888,  0.3134, -0.0000,  0.0000,  0.0000],
         [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
         [-0.0000,  0.0000,  0.0000, -0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]])

In [13]:
multi_head_score = torch.randn(bs, 2, seq_len, seq_len) # multi-head score
multi_head_score * mask.unsqueeze(1)

tensor([[[[-2.0993e+00, -6.1178e-01,  1.5540e+00,  0.0000e+00, -0.0000e+00],
          [ 4.3575e-01,  8.6956e-01, -1.0213e-01, -0.0000e+00, -0.0000e+00],
          [ 4.4461e-01,  1.5828e+00,  2.7130e-01,  0.0000e+00, -0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -0.0000e+00],
          [-0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00]],

         [[ 8.1156e-01,  9.3808e-02,  4.4381e-01,  0.0000e+00, -0.0000e+00],
          [-1.0390e+00,  1.7692e-01,  1.1105e+00, -0.0000e+00,  0.0000e+00],
          [ 9.5825e-01, -3.1649e-04,  7.1094e-01,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00, -0.0000e+00],
          [-0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000e+00, -0.0000e+00]]],


        [[[-7.9658e-01, -1.1100e+00,  0.0000e+00, -0.0000e+00,  0.0000e+00],
          [-3.5486e-01,  1.4551e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00, -0.0000e+00, -0.0000e+00, -0.0000e+00,  0.0000

## Transformer Decoder

In [14]:
class TransformerDecoderBlock(nn.Module):
    def __init__(self, dim = 512, heads = 8):
        super().__init__()
        self.masked_attn = MultiHeadScaleDotProductAttention(dim, dim, heads)
        self.ln1 = LayerNorm(dim)
    
        self.cross_attn = MultiHeadScaleDotProductAttention(dim, dim, heads)
        self.ln2 = LayerNorm(dim)
        
        self.ffn = FeedForwardNetwork(dim)
        self.ln3 = LayerNorm(dim)
        
    def forward(self, X, X_src, trg_mask = None, src_trg_mask = None):
        X_attn = self.masked_attn(X, X, X, trg_mask)
        X_ln = self.ln1(X_attn)
        X = X + X_ln
        
        X_attn = self.cross_attn(X, X_src, X_src, src_trg_mask)
        X_ln = self.ln2(X_attn)
        X = X + X_ln

        X_ffn = self.ffn(X)
        X_ln = self.ln3(X_ffn)
        X = X + X_ln

        return X

tmp_dim = 16
X = torch.randn(2, 4, tmp_dim)
X_src = torch.randn(2, 8, tmp_dim)
mask = torch.ones(2, 4, tmp_dim)
model = TransformerDecoderBlock(tmp_dim, 8)
Y = model(X, X_src)
print(X.shape, Y.shape)

torch.Size([2, 4, 16]) torch.Size([2, 4, 16])


In [15]:
class TransformerDecoder(nn.Module):
    """
    输入:[bs, trg_seq_len, dim]
    输出:[bs, trc_seq_len, dim]
    """
    def __init__(self, dim = 512, num_layers = 6, heads = 8):
        super().__init__()
        # self.decoder = nn.Linear(dim, dim) 
        self.decoder = nn.ModuleList(
            [TransformerDecoderBlock(dim, heads) for i in range(num_layers)]
        )
    def forward(self, X, X_src, trg_mask = None, src_trg_mask = None):
        for decoder_block in self.decoder:
            X = decoder_block(X, X_src, trg_mask, src_trg_mask)
        return X
        
tmp_dim = 16
X = torch.randn(2, 4, tmp_dim)
X_src = torch.randn(2, 3, tmp_dim)
model = TransformerDecoder(tmp_dim)
Y = model(X, X_src, )
print(X.shape, Y.shape)

torch.Size([2, 4, 16]) torch.Size([2, 4, 16])


### masked-self-attention mask detail

In [16]:
def get_trg_mask(input_ids, pad_token_id = 0):
    bs, seq_len = input_ids.shape
    mask = torch.tril(torch.ones(bs, seq_len, seq_len)) # tril
    for i in range(bs):
        pad_idx =  torch.where(input_ids[i, :]  == pad_token_id)[0]
        mask[i, pad_idx, :] = 0
        mask[i, :, pad_idx] = 0
    return mask
    
input_ids = torch.tensor([[1, 2, 3, 0, 0],
                          [1, 2, 0, 0, 0]], dtype = torch.long) # 0 is pad
mask = get_trg_mask(input_ids, pad_token_id = PAD_TOKEN_ID)
print(mask)

score = torch.randn(bs, seq_len, seq_len)
score * mask

tensor([[[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 0., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


tensor([[[ 0.3775, -0.0000, -0.0000, -0.0000, -0.0000],
         [-0.2391,  0.4032, -0.0000, -0.0000,  0.0000],
         [ 0.2447,  0.4989,  1.8628,  0.0000,  0.0000],
         [ 0.0000,  0.0000, -0.0000,  0.0000, -0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000]],

        [[ 1.7257, -0.0000, -0.0000,  0.0000,  0.0000],
         [ 0.3218, -1.6551,  0.0000, -0.0000, -0.0000],
         [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000, -0.0000,  0.0000, -0.0000,  0.0000],
         [-0.0000, -0.0000,  0.0000, -0.0000,  0.0000]]])

In [17]:
multi_head_score = torch.randn(bs, 2, seq_len, seq_len) # multi-head score
multi_head_score * mask.unsqueeze(1)

tensor([[[[-0.5273,  0.0000,  0.0000, -0.0000, -0.0000],
          [ 0.3483,  1.0282,  0.0000,  0.0000,  0.0000],
          [ 1.1696, -2.3923,  1.0051, -0.0000, -0.0000],
          [ 0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000, -0.0000]],

         [[-0.3363,  0.0000,  0.0000, -0.0000,  0.0000],
          [-0.2155, -0.9674,  0.0000,  0.0000, -0.0000],
          [ 0.2109, -0.2376,  0.9523, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000,  0.0000,  0.0000],
          [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000]]],


        [[[ 2.1990, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 0.7963, -0.1189, -0.0000, -0.0000, -0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000,  0.0000],
          [ 0.0000,  0.0000, -0.0000, -0.0000, -0.0000],
          [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000]],

         [[-0.9682, -0.0000,  0.0000,  0.0000,  0.0000],
          [ 1.6906,  0.7477, -0.0000, -0.0000,  0.0000],
          [-0.0000, -0.

### cross-attention mask detail


In [18]:
def get_src_trg_mask(src_ids, trg_ids, pad_token_id = 0):
    bs, src_seq_len = src_ids.shape
    bs, trg_seq_len = trg_ids.shape
    
    mask = torch.ones(bs, trg_seq_len, src_seq_len) # tril
    for i in range(bs):
        src_pad_idx =  torch.where(src_ids[i, :]  == pad_token_id)[0]
        trg_pad_idx =  torch.where(trg_ids[i, :]  == pad_token_id)[0]
        mask[i, trg_pad_idx, :] = 0
        mask[i, :, src_pad_idx] = 0
    return mask
    
src_ids = torch.tensor([[1, 2, 3, 0, 0],
                          [1, 2, 0, 0, 0]], dtype = torch.long) # 0 is pad

trg_ids = torch.tensor([[4, 5, 0, 0, ],
                          [4, 5, 6, 0, ]], dtype = torch.long) # 0 is pad

mask = get_src_trg_mask(src_ids, trg_ids, PAD_TOKEN_ID)
print(mask)

score = torch.randn(bs, 4, 5)
score * mask

# 头并行同理

tensor([[[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])


tensor([[[ 0.0457,  0.9644, -0.8429, -0.0000,  0.0000],
         [-0.8643, -1.7688, -0.9666, -0.0000, -0.0000],
         [-0.0000, -0.0000, -0.0000, -0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0000]],

        [[ 0.3906, -0.6990,  0.0000,  0.0000,  0.0000],
         [-2.5454, -1.3998,  0.0000, -0.0000, -0.0000],
         [ 0.8878,  0.9592,  0.0000, -0.0000,  0.0000],
         [-0.0000,  0.0000, -0.0000, -0.0000,  0.0000]]])

## Transformer output layer

In [19]:
class TransformerOutputLayer(nn.Module):
    """
    """
    def __init__(self, vocab_size = 100, dim = 512):
        super().__init__()
        self.lm_head = nn.Linear(dim, vocab_size)
        self.softmax = nn.Softmax(dim = -1)

    def forward(self, X):
        logits = self.lm_head(X)
        prob = self.softmax(logits)
        return logits

X = torch.randn(2, 8, tmp_dim)
model = TransformerOutputLayer(trg_vocab_size, tmp_dim)
print(model)
Y = model(X)
print(X.shape, Y.shape)

TransformerOutputLayer(
  (lm_head): Linear(in_features=16, out_features=200, bias=True)
  (softmax): Softmax(dim=-1)
)
torch.Size([2, 8, 16]) torch.Size([2, 8, 200])


## Loss Function

In [20]:
labels = torch.tensor([[1, 2, 1],
                          [1, 0, 0]], dtype = torch.long) # 0 is pad
logits = torch.randn(2, 3, 3)

print(logits)
logprob = torch.log_softmax(logits, dim = -1) # why use `log_softmax()` instead `softmax()`
logprob = torch.gather(logprob, dim = -1, index = labels.unsqueeze(dim = -1))
print(logprob)

mask = torch.ones(2, 3, 1)
mask[ torch.where(labels == PAD_TOKEN_ID) ] = 0
print(mask)

torch.mean(logprob * mask)

tensor([[[-0.1794, -1.7957, -0.6651],
         [-0.2517, -0.2236, -1.9435],
         [-1.0218,  0.3721,  1.4052]],

        [[-2.0408,  0.0177,  0.4979],
         [-0.0796, -1.9046,  0.4364],
         [ 1.5434,  0.7333, -0.1994]]])
tensor([[[-2.2118],
         [-2.4860],
         [-1.4007]],

        [[-1.0094],
         [-1.0426],
         [-0.4823]]])
tensor([[[1.],
         [1.],
         [1.]],

        [[1.],
         [0.],
         [0.]]])


tensor(-1.1847)

In [21]:
def CrossEntropyLoss(logits, labels, ignore_index = 0):
    """
    logits: bs x seq_len x num_classes
    label: bs x seq_len
    return: loss
    """
    bs, seq_len, num_classes = logits.shape
    log_prob = F.log_softmax(logits, dim = -1) # vec
    log_prob = torch.gather(log_prob, dim = -1, index = labels.unsqueeze(dim = -1)) # point
    
    if ignore_index is not None:
        mask = torch.ones(bs, seq_len, 1)
        mask[torch.where(labels == ignore_index)] = 0
    loss = torch.mean(log_prob * mask)
    return loss

trg_ids = torch.tensor([[4, 5, 0, 0, ],
                          [4, 5, 6, 0, ]], dtype = torch.long) # 0 is pad

logits = torch.randn(2, 4, trg_vocab_size)
print(logits.shape)

print(trg_ids[:, 1:])
print(logits[:, :-1, 0])

loss = CrossEntropyLoss(logits, labels = trg_ids, ignore_index = 0)
print(loss)

torch.Size([2, 4, 200])
tensor([[5, 0, 0],
        [5, 6, 0]])
tensor([[ 1.1350,  0.4015, -0.9549],
        [ 1.2343, -1.5474,  0.1226]])
tensor(-3.6103)


# Transformer New

In [22]:
class Transformer(nn.Module):
    """
    输入:[bs, src_seq_len]
    输出:[bs, trg_seq_len, vocab_size]
    """
    def __init__(self, src_vocab_size = 100, trg_vocab_size = 200, dim = 512, num_layers = 6, heads = 8, max_len = 512):
        super().__init__()
        #self.encoder_input = nn.Embedding(src_vocab_size, dim)
        self.encoder_input = TransformerInputLayer(vocab_size = src_vocab_size, 
                                                   dim = dim, 
                                                   max_len = max_len, )
        # self.encoder = nn.Linear(dim, dim)
        self.encoder = TransformerEncoder(dim = dim, 
                                          num_layers = num_layers, 
                                          heads = heads)
        
        # self.decoder_input = nn.Embedding(trg_vocab_size, dim)
        self.decoder_input = TransformerInputLayer(vocab_size = trg_vocab_size, 
                                                   dim = dim, 
                                                   max_len = max_len, )
        # self.decoder = nn.Linear(dim, dim)
        self.decoder = TransformerDecoder(dim = dim, 
                                          num_layers = num_layers, 
                                          heads = heads)

        # self.output_layer = nn.Linear(dim, trg_vocab_size)
        self.output_layer = TransformerOutputLayer(vocab_size = trg_vocab_size, 
                                                   dim = dim)
        
    def forward(self, src_ids, trg_ids, src_mask = None, trg_mask = None, src_trg_mask = None):
        X = self.encoder_input(src_ids)
        X_src = self.encoder(X, src_mask)
        # print('encoder output:\t', X_src.shape)

        Y = self.decoder_input(trg_ids)
        # Y = self.decoder(Y) + X_src.mean(dim = 1, keepdim = True)
        Y = self.decoder(Y, X_src, trg_mask = trg_mask, src_trg_mask = src_trg_mask)
        # print('decoder output:\t', Y.shape)

        logits = self.output_layer(Y) 
        prob = F.softmax(logits, dim = -1)
        
        return logits, prob
    
model = Transformer()
print(model)

Transformer(
  (encoder_input): TransformerInputLayer(
    (embedding): Embedding(100, 512)
  )
  (encoder): TransformerEncoder(
    (encoder): ModuleList(
      (0-5): 6 x TransformerEncoderBlock(
        (attn): MultiHeadScaleDotProductAttention(
          (WQ): Linear(in_features=512, out_features=512, bias=True)
          (WK): Linear(in_features=512, out_features=512, bias=True)
          (WV): Linear(in_features=512, out_features=512, bias=True)
          (WO): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln1): LayerNorm()
        (ffn): FeedForwardNetwork(
          (W_up): Linear(in_features=512, out_features=2048, bias=True)
          (ReLU): ReLU()
          (W_down): Linear(in_features=2048, out_features=512, bias=True)
        )
        (ln2): LayerNorm()
      )
    )
  )
  (decoder_input): TransformerInputLayer(
    (embedding): Embedding(200, 512)
  )
  (decoder): TransformerDecoder(
    (decoder): ModuleList(
      (0-5): 6 x TransformerDecode

In [23]:
src_ids = torch.randint(src_vocab_size, (batch_size, src_len))
trg_ids = torch.randint(trg_vocab_size, (batch_size, trg_len))
print('encode input shape: ', src_ids.shape)
print('decode output shape: ', trg_ids.shape)

src_mask = get_src_mask(src_ids)
trg_mask = get_trg_mask(trg_ids)
src_trg_mask = get_src_trg_mask(src_ids, trg_ids)

logits, _ = model(src_ids, trg_ids, src_mask = src_mask, trg_mask = trg_mask, src_trg_mask = src_trg_mask)
print('transformer output:\t', logits.shape)

encode input shape:  torch.Size([2, 256])
decode output shape:  torch.Size([2, 128])
transformer output:	 torch.Size([2, 128, 200])


## Train config

In [24]:
dim = 512 
num_layers = 6
heads = 8

batch_size = 2
src_len = 256
trg_len = 128
max_len = 512

src_vocab_size = 100
trg_vocab_size = 200

N_train = 1024 # dataset
N_test = 128

IGNORE_INDEX = -100
PAD_TOKEN_ID = 0
SOS_TOKEN_ID = 1
EOS_TOKEN_ID = 2

## Dataset

In [25]:
import torch
from torch.utils.data import Dataset

class Seq2SeqDataset(Dataset):
    def __init__(self, X, Y):
        self.X = X
        self.Y = Y
    
    def __len__(self):
        return self.X.shape[0]
    
    def __getitem__(self, idx):
        return self.X[idx,:], self.Y[idx,:]

def create_dataset(N, src_len,trg_len,src_vocab_size,trg_vocab_size):
    X = torch.randint(2, src_vocab_size-1, (N, src_len), dtype = torch.long) 
    Y = torch.randint(2, trg_vocab_size-1, (N, trg_len), dtype = torch.long) 
    X[:, 0] = SOS_TOKEN_ID # <SOS>
    X[:, src_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
    Y[:, 0] = SOS_TOKEN_ID # <SOS>
    Y[:, trg_len - 1] = EOS_TOKEN_ID # <PAD> or <EOS>
    dataset = Seq2SeqDataset(X, Y)
    return dataset

train_dataset = create_dataset(N = N_train, 
              src_len = src_len, 
              trg_len = trg_len,  
              src_vocab_size = src_vocab_size, 
              trg_vocab_size = trg_vocab_size)


test_dataset = create_dataset(N = N_test, 
              src_len = src_len, 
              trg_len = trg_len,  
              src_vocab_size =src_vocab_size, 
              trg_vocab_size = trg_vocab_size)

## Dataloader

In [26]:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(train_dataset, 
                    batch_size = 8 , 
                    # collate_fn=collate_fn,
                    pin_memory=True)

test_dataloader = DataLoader(test_dataset, 
                    batch_size= 32, 
                    # collate_fn=collate_fn,
                    pin_memory=True)

for i in train_dataloader:
    print(i)
    print(i[0].shape)
    print(i[1].shape)
    break

[tensor([[ 1, 47,  3,  ..., 62, 61,  2],
        [ 1, 77, 25,  ..., 69, 59,  2],
        [ 1, 45, 97,  ...,  4, 92,  2],
        ...,
        [ 1, 44, 53,  ..., 76, 44,  2],
        [ 1, 15,  4,  ..., 96, 85,  2],
        [ 1, 96, 38,  ..., 37, 84,  2]]), tensor([[  1,  62, 137,  ...,  40,  93,   2],
        [  1, 196, 155,  ...,  20, 171,   2],
        [  1, 137,   3,  ...,  64, 194,   2],
        ...,
        [  1, 116,  85,  ...,  83, 111,   2],
        [  1, 197, 111,  ..., 150, 146,   2],
        [  1, 153, 116,  ..., 175, 171,   2]])]
torch.Size([8, 256])
torch.Size([8, 128])


## Training

In [27]:
import torch.optim as optim

model = Transformer(src_vocab_size = src_vocab_size, 
                    trg_vocab_size = trg_vocab_size, 
                    dim = dim, 
                    num_layers = num_layers, 
                    heads = heads, 
                    max_len = max_len)

optimizer = optim.SGD(model.parameters(), lr = 1e-4)

loss_fn = nn.CrossEntropyLoss(ignore_index = -100)

epochs = 2
train_loss = []
test_loss = []
# PPL = []
total_step = 0


# train
for i in range(epochs):
    # train
    for k, batch in enumerate(train_dataloader):
        optimizer.zero_grad()
        X = batch[0]
        # trg_ids : <SOS>, t1, t2, t3, <EOS>
        # trg_ids[:,:-1]: <SOS>, t1, t2, t3
        # trg_ids[:, 1:]:    t1, t2, t3, <EOS>
        Y = batch[1][:, :-1] 

        src_mask = get_src_mask(X)
        trg_mask = get_trg_mask(Y)
        src_trg_mask = get_src_trg_mask(X, Y)

        logits, _ = model(X, Y, src_mask, trg_mask, src_trg_mask)

        label = batch[1][:, 1:] 
        bs, tmp_trg_len = label.shape
        loss = loss_fn(logits.reshape( bs*tmp_trg_len, trg_vocab_size) , label.reshape( bs * tmp_trg_len) )
        
        train_loss.append(loss.item())
        loss.backward()
        optimizer.step()

        total_step = total_step + 1
        
        if total_step % 10 == 0: 
            print(f"epochs:{i}, step:{total_step}, train_loss: {loss.item()}")

epochs:0, step:10, train_loss: 6.925608158111572
epochs:0, step:20, train_loss: 6.485151290893555
epochs:0, step:30, train_loss: 6.398404598236084
epochs:0, step:40, train_loss: 6.281773567199707
epochs:0, step:50, train_loss: 6.337588787078857
epochs:0, step:60, train_loss: 6.179869651794434
epochs:0, step:70, train_loss: 6.144567012786865
epochs:0, step:80, train_loss: 6.168476581573486
epochs:0, step:90, train_loss: 6.039062023162842
epochs:0, step:100, train_loss: 6.064135551452637
epochs:0, step:110, train_loss: 6.0511794090271
epochs:0, step:120, train_loss: 6.056873321533203
epochs:1, step:130, train_loss: 5.989947319030762
epochs:1, step:140, train_loss: 6.0201592445373535
epochs:1, step:150, train_loss: 5.9559736251831055
epochs:1, step:160, train_loss: 5.962271213531494
epochs:1, step:170, train_loss: 6.026018142700195
epochs:1, step:180, train_loss: 5.931738376617432
epochs:1, step:190, train_loss: 5.9036102294921875
epochs:1, step:200, train_loss: 5.962732791900635
epochs:1

## Inference

In [28]:
max_new_tokens = 10

src_len = 5

src_ids = torch.randint(0, src_vocab_size, (1, src_len))
src_mask = get_src_mask(src_ids, pad_token_id = PAD_TOKEN_ID)

拆分 forward 为两部分

```txt
    def forward(self, src_ids, trg_ids, src_mask = None, trg_mask = None, src_trg_mask = None):

        # --------------------------Stage 1--------------------------------
        X = self.encoder_input(src_ids)
        X_src = self.encoder(X, src_mask)


        # --------------------------Stage 2--------------------------------
        Y = self.decoder_input(trg_ids)
        Y = self.decoder(Y, X_src, trg_mask = trg_mask, src_trg_mask = src_trg_mask)

        logits = self.output_layer(Y) 
        prob = F.softmax(logits, dim = -1)
        
        return logits, prob
```

In [29]:
### Stage1: Encode

with torch.no_grad():
    X = model.encoder_input(src_ids)
    X_src = model.encoder(X, mask = src_mask)
    print(X_src.shape)

torch.Size([1, 5, 512])


In [31]:
### Stage2: Decode
trg_ids = torch.randint(0, src_vocab_size, (1, 1)) # 仅有 1 个 token
trg_ids[:,0] = SOS_TOKEN_ID # trg 一定要 <SOS> 


trg_mask = get_src_mask(trg_ids, pad_token_id = PAD_TOKEN_ID)
src_trg_mask = get_src_trg_mask(src_ids, trg_ids, pad_token_id = PAD_TOKEN_ID)

print(trg_ids)
with torch.no_grad():
    for i in range(max_new_tokens):
        Y = model.decoder_input(trg_ids)
        Y = model.decoder(Y, X_src, trg_mask = None, src_trg_mask = None)
        logits = model.output_layer(Y) # [bs, seq_len, trg_vocab_size], 并行多个 next token logits 预测
        next_token_logits = logits[:, -1, :] # [bs, trg_vocab_size]
        next_token_prob = F.softmax(next_token_logits, dim = -1)
        next_token = torch.argmax(next_token_prob, dim = -1, keepdim = True)
        trg_ids = torch.concat((trg_ids, next_token), dim = 1)
        print(trg_ids)
print('final:', trg_ids)

tensor([[1]])
tensor([[  1, 146]])
tensor([[  1, 146,  14]])
tensor([[  1, 146,  14,  14]])
tensor([[  1, 146,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14,  14,  14,  14,  14]])
tensor([[  1, 146,  14,  14,  14,  14,  14,  14,  14,  14,  14]])
final: tensor([[  1, 146,  14,  14,  14,  14,  14,  14,  14,  14,  14]])


## Model Save

在训练过程中，要保存模型，具体包含：

- 保存网络权重参数, torch 用字典数据类型管理权重
- 保存模型超参数配置: config，可以由独立的文件来管理，如 json, yaml

In [35]:
import os
def model_save(filepath, model):
    # save_path = os.path(filepath)
    save_dict = {'model_state_dict': model.state_dict()}
    # torch.save(model, save_path)
    torch.save(save_dict, filepath)
    print(f"Model saved to {filepath}")

model_save('./output/model.pth', model)

Model saved to ./output/model.pth


## Model Load

权重加载是容易的（权重不等于一个模型类对象），但是会遇到 config 超参数加载问题。

1. 加载 config 文件，初始化 model, 此时参数是 random 的。
2. 加载 权重文件，并将参数 copy 到 model 对象中。

In [49]:
def load_model(filename, full_model=True):
    # load_path = os.path(filename)
    if not os.path.exists(filename):
        raise FileNotFoundError(f"No model found at {filename}")
    
    if full_model:
        model = torch.load(filename, weights_only = False)
    else:
        checkpoint = torch.load(filename)
        
        # 加载模型参数
        model.load_state_dict(checkpoint['model_state_dict'], strict=strict_load)
        
        # 返回其他保存的信息
        return {k: v for k, v in checkpoint.items() if k != 'model_state_dict'}
    return model
    
new_model = load_model('./output/model.pth', )
# print(new_model)

In [50]:
init_model = Transformer(src_vocab_size = src_vocab_size, 
                    trg_vocab_size = trg_vocab_size, 
                    dim = dim, 
                    num_layers = num_layers, 
                    heads = heads, 
                    max_len = max_len)
init_model.load_state_dict(new_model['model_state_dict'] )

<All keys matched successfully>

In [51]:
print(init_model.encoder_input.embedding.weight)

Parameter containing:
tensor([[-0.4045,  0.6274,  0.3160,  ..., -0.3810,  1.5396,  0.7373],
        [ 0.6353,  1.3252,  0.8607,  ..., -1.2796, -0.7702,  0.1237],
        [-0.2740, -0.2623,  1.5734,  ..., -0.8170, -1.2125,  0.3567],
        ...,
        [ 0.4068, -0.2399, -0.3577,  ..., -1.3916, -1.0082, -1.3997],
        [ 0.4731,  0.9009,  1.5134,  ...,  0.5403,  0.0127,  0.2677],
        [-1.8452,  0.3997, -0.2465,  ...,  0.3447, -0.8801, -0.4935]],
       requires_grad=True)


In [52]:
print(init_model)

Transformer(
  (encoder_input): TransformerInputLayer(
    (embedding): Embedding(100, 512)
  )
  (encoder): TransformerEncoder(
    (encoder): ModuleList(
      (0-5): 6 x TransformerEncoderBlock(
        (attn): MultiHeadScaleDotProductAttention(
          (WQ): Linear(in_features=512, out_features=512, bias=True)
          (WK): Linear(in_features=512, out_features=512, bias=True)
          (WV): Linear(in_features=512, out_features=512, bias=True)
          (WO): Linear(in_features=512, out_features=512, bias=True)
        )
        (ln1): LayerNorm()
        (ffn): FeedForwardNetwork(
          (W_up): Linear(in_features=512, out_features=2048, bias=True)
          (ReLU): ReLU()
          (W_down): Linear(in_features=2048, out_features=512, bias=True)
        )
        (ln2): LayerNorm()
      )
    )
  )
  (decoder_input): TransformerInputLayer(
    (embedding): Embedding(200, 512)
  )
  (decoder): TransformerDecoder(
    (decoder): ModuleList(
      (0-5): 6 x TransformerDecode

将在 `./model_io.ipynb` 中描写一个完整的模型 IO 类