In [1]:
import torch
import math
from torch import nn
import torch.nn.functional as F

In [2]:
# Embedding Part: Token Embedding
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, d_model):
        super().__init__(vocab_size, d_model, padding_idx=1)

In [3]:
# Embedding Part: Positional Embedding
class PositionalEmbedding(nn.Module):
    def __init__(self, max_len, d_model):
        super(PositionalEmbedding, self).__init__()
        # P.shape -> (batch_size, max_len, d_model)
        self.P = torch.zeros(size=(1, max_len, d_model))
        # pos_matrix.shape -> (max_len, 1)
        pos_matrix = torch.arange(0, max_len, dtype=torch.float32).reshape((-1, 1))
        # div_matrix.shape -> (1, d_model / 2)
        div_matrix = torch.pow(10000, torch.arange(0, d_model, 2, dtype=torch.float32) / d_model)
        # X.shape -> (max_len, d_model / 2)
        X = pos_matrix / div_matrix
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)
    def forward(self, X):
        # X.shape -> (batch_size, num_steps , d_model)
        # num_steps is the position, num_steps is the index
        print('X.shape = ', X.shape)
        print('P.shpae = ', self.P.shape)
        return X + self.P[:, :X.shape[1], :]

In [4]:
# Embedding Part: Transformer Embedding
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, dropout_rate=0.1):
        super(TransformerEmbedding, self).__init__()
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.position_embedding = PositionalEmbedding(max_len, d_model)
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, X):
        # X -> (batch_size, num_steps)
        X = self.token_embedding(X)
        # X -> (batch_size, num_steps, d_model)
        X = self.position_embedding(X)
        return self.dropout(X)

In [5]:
# 测试 Embedding
X = torch.arange(12).reshape((3, 4))
embedding = TransformerEmbedding(15, 128, 15)
Y = embedding(X)
Y.shape

X.shape =  torch.Size([3, 4, 128])
P.shpae =  torch.Size([1, 15, 128])


torch.Size([3, 4, 128])

In [6]:
# MultiHeadAttention 多头注意力(不包含 Dropout)
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, d_model):
        super(MultiHeadAttention, self).__init__()
        self.d_head = d_model // num_heads
        self.query_proj = nn.Linear(d_model, d_model)
        self.key_proj = nn.Linear(d_model, d_model)
        self.value_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self._attention_weights = None
    def forward(self, q, k, v, mask=None):
        # q, k, v.shape -> (batch_size, seq_len, d_model)
        batch_size, seq_len, _ = q.shape
        # q, k, v 经过 Linear 层
        q = self.query_proj(q)
        k = self.key_proj(k)
        v = self.value_proj(v)
        # q, k, v.shape -> (batch_size, seq_len, d_model)
        # -> (batch_size, num_heads, seq_len, d_head)
        q = q.reshape(batch_size, seq_len, -1, self.d_head).permute(0, 2, 1, 3)
        k = k.reshape(batch_size, seq_len, -1, self.d_head).permute(0, 2, 1, 3)
        v = v.reshape(batch_size, seq_len, -1, self.d_head).permute(0, 2, 1, 3)
        # q, k, v.shape -> (batch_size, num_heads, seq_len, d_head)
        # 计算注意力分数
        score = q @ k.permute(0, 1, 3, 2) / math.sqrt(self.d_head)
        # mask 操作
        if mask is not None:
            score = score.masked_fill(
                mask == 0,
                float("-inf")
            )
        # score -> (batch_size, num_heads, seq_len, seq_len)
        self._attention_weights = torch.softmax(score, dim=-1)
        # attention_weights -> (batch_size, num_heads, seq_len, seq_len)
        output = self._attention_weights @ v
        # output.shape -> (batch_size, num_heads, seq_len, d_head)
        # -> (batch_size, seq_len, d_model)
        output = output.permute(0, 2, 1, 3).reshape(batch_size, seq_len, -1)
        output = self.out_proj(output)
        return output
    @property
    def attention_weights(self):
        return self._attention_weights

In [7]:
# 测试 MultiHead
multi_head = MultiHeadAttention(8, 4)
# mask -> (batch_size, num_heads, seq_len, seq_len) -> (3, 8, 4, 4)
# Y -> (3, 4, 128)
# mask 的含义:
# 第 i 行的第 j 列为 0  -> 第 i 个词元, 屏蔽第 j 个词元
# 但是现实情况中一般是排除无用词元, 所以一般来说都是通过 (n,) 的向量进行拓展的
mask = torch.tensor([
    [1, 1, 1, 1],
    [1, 1, 1, 0],
    [1, 1, 0, 0],
    [1, 0, 0, 0]
])
mask = mask.unsqueeze(0).unsqueeze(0).repeat(3, 8, 1, 1)
multi_head = MultiHeadAttention(8, 128)
Y = multi_head(Y, Y, Y, mask)
# (3, 8, 4, 4)
Y.shape, multi_head.attention_weights

(torch.Size([3, 4, 128]),
 tensor([[[[0.1822, 0.1772, 0.4764, 0.1642],
           [0.2799, 0.2984, 0.4217, 0.0000],
           [0.4610, 0.5390, 0.0000, 0.0000],
           [1.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.2176, 0.2830, 0.2426, 0.2568],
           [0.3026, 0.3580, 0.3394, 0.0000],
           [0.3696, 0.6304, 0.0000, 0.0000],
           [1.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1944, 0.2520, 0.2342, 0.3194],
           [0.2906, 0.3639, 0.3455, 0.0000],
           [0.5404, 0.4596, 0.0000, 0.0000],
           [1.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.2188, 0.2023, 0.3029, 0.2760],
           [0.3124, 0.2421, 0.4456, 0.0000],
           [0.4839, 0.5161, 0.0000, 0.0000],
           [1.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.1371, 0.2103, 0.1460, 0.5067],
           [0.2634, 0.4252, 0.3114, 0.0000],
           [0.5425, 0.4575, 0.0000, 0.0000],
           [1.0000, 0.0000, 0.0000, 0.0000]],
 
          [[0.3783, 0.1540, 0.2116, 0.2561],
           [0.

In [8]:
# Layer Norm, 对于维度进行线性变换
class LayerNorm(nn.Module):
    # eps 用于防止 var = 0
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta  = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
    def forward(self, X):
        # X.shape -> (batch_size, dim)
        mean = X.mean(dim=-1, keepdim=True)
        var  = X.var(dim=-1,  keepdim=True, unbiased=False)
        out = (X - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

In [9]:
# FFN 前馈网络
class PositionWiseFFN(nn.Module):
    def __init__(self, d_model, num_hidden, dropout_rate=0.1):
        super(PositionWiseFFN, self).__init__()
        self.fc1 = nn.Linear(d_model, num_hidden)
        self.fc2 = nn.Linear(num_hidden, d_model)
        self.dropout = nn.Dropout(dropout_rate)
    def forward(self, X):
        X = self.dropout(F.relu(self.fc1(X)))
        X = self.fc2(X)
        return X

In [10]:
# Encoder Block
class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, num_hiddens, dropout_rate=0.1):
        super(EncoderBlock, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.ffn = PositionWiseFFN(d_model, num_hiddens)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout_rate)
    def forward(self, X, mask=None):
        # 残差连接
        init_X = X
        X = self.dropout1(self.attention(X, X, X, mask))
        X = self.norm1(X + init_X)
        init_X = X
        X = self.dropout1(self.ffn(X))
        X = self.norm2(X + init_X)
        return X

In [11]:
# 定义 Encoder 结构
class TransformerEncoder(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, num_heads, num_hiddens, num_layers, dropout_rate=0.1):
        super(TransformerEncoder, self).__init__()
        self.embedding = TransformerEmbedding(vocab_size, d_model, max_len, dropout_rate)
        self.layers = nn.ModuleList([
            EncoderBlock(d_model, num_heads, num_hiddens, dropout_rate) for _ in range(num_layers)
        ])
    def forward(self, X, mask=None):
        X = self.embedding(X)
        for layer in self.layers:
            X = layer(X, mask)
        return X

In [12]:
X = torch.arange(12).reshape(3, 4)
encoder = TransformerEncoder(20, 128, 20, 8, 256, 6)
Y = encoder(X)
Y.shape

X.shape =  torch.Size([3, 4, 128])
P.shpae =  torch.Size([1, 20, 128])


torch.Size([3, 4, 128])

In [13]:
# Decoder
class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, num_hiddens, dropout_rate=0.1):
        super(DecoderBlock, self).__init__()
        self.attention = MultiHeadAttention(num_heads, d_model)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout_rate)
        self.cross_attention = MultiHeadAttention(num_heads, d_model)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(dropout_rate)
        self.ffn = PositionWiseFFN(d_model, num_hiddens, dropout_rate)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(dropout_rate)
    def forward(self, enc_output, X, t_mask=None, s_mask=None):
        init_X = X
        X = self.dropout1(self.attention(X, X, X, t_mask))
        X = self.norm1(X + init_X)
        init_X = X
        X = self.dropout2(self.cross_attention(X, enc_output, enc_output, s_mask))
        X = self.norm2(X + init_X)
        init_X = X
        X = self.dropout3(self.ffn(X))
        X = self.norm3(X + init_X)
        return X

In [14]:
# Decoder 结构
class TransformerDecoder(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_hiddens, max_len, num_layers, dropout_rate=0.1):
        super(TransformerDecoder, self).__init__()
        self.embedding = TransformerEmbedding(vocab_size, d_model, max_len, dropout_rate)
        self.decoder_layers = nn.ModuleList(
            [
                DecoderBlock(d_model, num_heads, num_hiddens, dropout_rate) for _ in range(num_layers)
            ]
        )
        self.fc = nn.Linear(d_model, vocab_size)
    def forward(self, enc_output, X, t_mask=None, s_mask=None):
        X = self.embedding(X)
        for layer in self.decoder_layers:
            X = layer(enc_output, X, t_mask, s_mask)
        X = self.fc(X)
        return X

In [15]:
# 测试 Decoder
# Encoder 输入 X
X = torch.arange(12).reshape(3, 4)
encoder = TransformerEncoder(20, 128, 20, 8, 256, 6)
# Encoder 输出
enc_output = encoder(X)

DX = torch.arange(12).reshape(3, 4)
decoder = TransformerDecoder(20, 128, 8, 256, 20, 6)
Out = decoder(enc_output, X)
Out = Out.argmax(dim=-1)
Out

X.shape =  torch.Size([3, 4, 128])
P.shpae =  torch.Size([1, 20, 128])
X.shape =  torch.Size([3, 4, 128])
P.shpae =  torch.Size([1, 20, 128])


tensor([[13, 13, 13, 13],
        [12, 13,  7, 13],
        [13, 13, 13, 13]])

In [32]:
# Transfomer
class Transformer(nn.Module):
    def __init__(self,
                 src_pad_idx,
                 trg_pad_idx,
                 enc_voc_size,
                 num_heads,
                 max_len,
                 trg_voc_size,
                 d_model,
                 ffn_hidden,
                 num_layers,
                 dropout_rate=0.1):
        super(Transformer, self).__init__()
        self.encoder = TransformerEncoder(enc_voc_size, d_model, 
                                          max_len, num_heads, 
                                          ffn_hidden, num_layers, dropout_rate)
        self.decoder = TransformerDecoder(trg_voc_size, d_model, num_heads,ffn_hidden,
                                          max_len, num_layers, dropout_rate)
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
    def make_pad_mask(self, q, k, pad_idx_q, pad_idx_k):
        # 注意此时 q, k 长度 -> seq_len
        # 最终需要比较的序列 -> (batch_size, num_heads, seq_len, seq_len)
        # 所以这里可以使用掩码, 并且涉及到注意力分数的计算只需要考虑 q, k 即可
        len_q, len_k = q.size(1), k.size(1)
        q = q.ne(pad_idx_q).unsqueeze(1).unsqueeze(3)
        q = q.repeat(1, 1, 1, len_k)
        k = k.ne(pad_idx_k).unsqueeze(1).unsqueeze(2)
        k = k.repeat(1, 1, len_q, 1)
        mask = q & k
        return mask
    def make_causal_mask(self, q, k):
        len_q, len_k = q.size(1), k.size(1)
        mask = torch.tril(torch.ones(len_q, len_k)).type(torch.float32)
        return mask
    def forward(self, src, trg):
        src_mask = self.make_pad_mask(src, src, self.src_pad_idx, self.src_pad_idx)
        trg_mask = self.make_pad_mask(trg, trg, self.trg_pad_idx, self.trg_pad_idx)*self.make_causal_mask(trg, trg)
        enc = self.encoder(src, src_mask)
        out = self.decoder(enc, trg, trg_mask, src_mask)
        return out

In [None]:
mask = torch.tril(torch.ones(3, 3))
mask 

(tensor([[1., 0., 0.],
         [1., 1., 0.],
         [1., 1., 1.]]),
 tensor([[ True, False, False],
         [ True,  True, False],
         [ True,  True,  True]]))