#### Transformer架构
<div align="center"><img decoding="async" src="img/transformer.webp" width="60%"></div>


In [32]:
import torch
import torch.nn as nn
import numpy as np
import math

构建config类，用于保存超参数

In [33]:
class Config(object):
    """Hyper-parameters shared by the Transformer building blocks in this file."""

    def __init__(self):
        # Number of rows in the token embedding table.
        self.vocab_size = 6

        # Model (embedding) width and number of attention heads.
        self.d_model = 20
        self.n_heads = 2

        # The model width must split evenly across the heads.
        assert self.d_model % self.n_heads == 0
        # Output widths of the Q/K and V projections.  This has to be floor
        # division: once the assert above holds, ``d_model % n_heads`` is
        # always 0, which would create zero-width projection layers.
        self.dim_k = self.d_model // self.n_heads
        self.dim_v = self.d_model // self.n_heads

        # Every input sequence is padded or truncated to this many tokens.
        self.padding_size = 30
        # Token indices reserved for out-of-vocabulary and padding tokens.
        self.UNK = 5
        self.PAD = 4

        # NOTE(review): presumably the number of stacked layers; it is not
        # read anywhere in the visible code.
        self.N = 6
        # Dropout probability.
        self.p = 0.1


config = Config()

Embedding 进行元素编码

In [34]:
class Embedding(nn.Module):
    """Token embedding that pads or truncates each sequence to a fixed length.

    The embedding width, padding index and target length are read from the
    module-level ``config`` (``d_model``, ``PAD`` and ``padding_size``).
    """

    def __init__(self, vocab_size):
        """Create the lookup table.

        Args:
            vocab_size: number of rows in the embedding table.
        """
        super(Embedding, self).__init__()
        # ``padding_idx`` keeps the PAD row out of gradient updates, so
        # positions filled with ``config.PAD`` carry no learned signal.
        self.embedding = nn.Embedding(
            vocab_size, config.d_model, padding_idx=config.PAD
        )

    def forward(self, x):
        """Embed a batch of token-index sequences.

        Args:
            x: sequence of token-index sequences (e.g. a list of lists of
                ints).  The input is not modified.

        Returns:
            Tensor of shape ``(batch_size, config.padding_size, config.d_model)``.
        """
        target_len = config.padding_size
        # Build fresh rows instead of extending the caller's lists in place.
        # Short rows are filled with PAD (the index registered as
        # ``padding_idx`` above); long rows are cut to ``target_len``.  A
        # non-positive repeat count yields an empty list, so long rows get
        # no filler.
        padded = [
            list(seq[:target_len]) + [config.PAD] * (target_len - len(seq))
            for seq in x
        ]
        tokens = torch.tensor(padded, dtype=torch.long)
        return self.embedding(tokens)  # batch_size * seq_len * d_model

Positional Encoding（位置编码）

In [35]:
class Positional_Encoding(nn.Module):
    """Sinusoidal positional encoding table.

    For position ``pos`` and dimension pair index ``k``:

        PE[pos, 2k]     = sin(pos / 10000 ** (2k / d_model))
        PE[pos, 2k + 1] = cos(pos / 10000 ** (2k / d_model))
    """

    def __init__(self, d_model):
        """Store the model width used as the frequency denominator."""
        super(Positional_Encoding, self).__init__()
        self.d_model = d_model

    def forward(self, seq_len, embedding_dim):
        """Build the encoding table.

        Args:
            seq_len: number of positions (rows).
            embedding_dim: number of encoding dimensions (columns).

        Returns:
            A ``float64`` tensor of shape ``(seq_len, embedding_dim)``.
        """
        positions = np.arange(seq_len, dtype=np.float64).reshape(-1, 1)
        dims = np.arange(embedding_dim)
        # Both members of a sin/cos pair share one frequency, so the exponent
        # is driven by the pair index ``dims // 2`` rather than the raw
        # dimension index.
        rates = np.power(10000.0, 2 * (dims // 2) / self.d_model)
        angles = positions / rates  # (seq_len, embedding_dim)

        table = np.empty((seq_len, embedding_dim))
        table[:, 0::2] = np.sin(angles[:, 0::2])  # even dimensions
        table[:, 1::2] = np.cos(angles[:, 1::2])  # odd dimensions
        return torch.from_numpy(table)

Multi Head Attention

In [36]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries are projected from ``x``; keys and values are projected from
    ``y``.  Passing the same tensor for both gives self-attention.
    ``dim_k`` and ``dim_v`` are the total projection widths and are split
    evenly across ``n_head`` heads.
    """

    def __init__(self, d_model, dim_k, dim_v, n_head):
        """Create the projection layers.

        Args:
            d_model: width of the input and output features.
            dim_k: total width of the query/key projections.
            dim_v: total width of the value projection.
            n_head: number of attention heads.

        Raises:
            ValueError: if a width is not positive or is not divisible by
                ``n_head``.
        """
        super(MultiHeadAttention, self).__init__()
        if n_head <= 0 or dim_k <= 0 or dim_v <= 0:
            raise ValueError("n_head, dim_k and dim_v must all be positive")
        if dim_k % n_head != 0 or dim_v % n_head != 0:
            raise ValueError("dim_k and dim_v must be divisible by n_head")
        self.n_head = n_head
        self.dim_k = dim_k
        self.dim_v = dim_v
        self.d_model = d_model
        self.q = nn.Linear(d_model, dim_k)
        self.k = nn.Linear(d_model, dim_k)
        self.v = nn.Linear(d_model, dim_v)
        self.o = nn.Linear(dim_v, d_model)
        # Scores are scaled by 1/sqrt(d_k), where d_k is the per-head key width.
        self.__normal_factor = 1 / math.sqrt(dim_k // n_head)

    def generate_mask(self, dim):
        """Return a ``(dim, dim)`` bool mask that is True strictly above the diagonal."""
        mask = torch.ones(dim, dim).triu(diagonal=1).bool()
        return mask

    def forward(self, x, y, require_mask=False):
        """Attend from ``x`` (queries) to ``y`` (keys and values).

        Args:
            x: tensor of shape ``(batch, q_len, d_model)``.
            y: tensor of shape ``(batch, k_len, d_model)``.
            require_mask: when True, position ``i`` of ``x`` may only attend
                to positions ``<= i`` of ``y``.

        Returns:
            Tensor of shape ``(batch, q_len, d_model)``.
        """
        batch, q_len = x.shape[0], x.shape[1]
        k_len = y.shape[1]
        head_k = self.dim_k // self.n_head
        head_v = self.dim_v // self.n_head

        # Split into heads and move the head axis ahead of the sequence axis
        # so the matmuls below run over sequence positions within each head:
        # (batch, n_head, seq_len, head_dim).
        Q = self.q(x).view(batch, q_len, self.n_head, head_k).transpose(1, 2)
        K = self.k(y).view(batch, k_len, self.n_head, head_k).transpose(1, 2)
        V = self.v(y).view(batch, k_len, self.n_head, head_v).transpose(1, 2)

        # (batch, n_head, q_len, k_len)
        scores = torch.matmul(Q, K.transpose(-2, -1)) * self.__normal_factor

        if require_mask:
            mask = self.generate_mask(max(q_len, k_len))[:q_len, :k_len]
            # Column 0 is never masked, so no row becomes entirely -inf.
            scores = scores.masked_fill(mask.to(scores.device), -float("inf"))

        # Normalise each query's scores into weights over the key positions.
        weights = torch.softmax(scores, dim=-1)

        # (batch, n_head, q_len, head_v) -> (batch, q_len, dim_v)
        context = torch.matmul(weights, V)
        context = context.transpose(1, 2).reshape(batch, q_len, self.dim_v)

        output = self.o(context)
        return output

Feed Forward

In [37]:
class Feed_forward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, input_dim, hidden_dim=None):
        """Create the two linear layers.

        Args:
            input_dim: width of the input and of the output.
            hidden_dim: width of the inner layer.  Defaults to
                ``4 * input_dim``, the expansion ratio of the original
                Transformer (d_ff = 2048 for d_model = 512), so the block
                can be built from the model width alone.
        """
        super(Feed_forward, self).__init__()
        if hidden_dim is None:
            hidden_dim = 4 * input_dim
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the network over the last dimension of ``x``; the shape is preserved."""
        return self.linear2(self.relu(self.linear1(x)))

Add & Norm（残差连接与层归一化）

In [38]:
class Add_Normal(nn.Module):
    """Residual connection followed by layer normalisation.

    Computes ``LayerNorm(x + Dropout(sub_layer(**kwargs)))``.  The dropout
    probability and the normalised feature width are read from the
    module-level ``config`` (``p`` and ``d_model``).
    """

    def __init__(self):
        super(Add_Normal, self).__init__()
        self.dropout = nn.Dropout(config.p)
        # Registered once here so its scale/shift parameters are trained,
        # saved and moved together with the rest of the module, instead of
        # being re-created with fresh values on every forward pass.
        self.layer_norm = nn.LayerNorm(config.d_model)

    def forward(self, x, sub_layer, **kwargs):
        """Apply ``sub_layer`` and merge its output with the residual ``x``.

        Args:
            x: residual input; its last dimension must be ``config.d_model``.
            sub_layer: callable invoked as ``sub_layer(**kwargs)``.  ``x`` is
                not forwarded to it, so the callable must already be bound
                to its input or receive it through ``kwargs``.
            **kwargs: keyword arguments passed through to ``sub_layer``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        sub_output = sub_layer(**kwargs)
        # Dropout regularises only the sub-layer output; the residual path
        # is added untouched before normalising over the feature dimension.
        return self.layer_norm(x + self.dropout(sub_output))

Encoder

In [39]:
class Encoder(nn.Module):
    """One encoder layer: positional encoding, self-attention, feed-forward.

    Each sub-layer is wrapped in an ``Add_Normal`` block.  All sizes are
    read from the module-level ``config``.
    """

    def __init__(self):
        super(Encoder, self).__init__()
        self.positional_encoding = Positional_Encoding(config.d_model)
        self.multi_head_attention = MultiHeadAttention(
            config.d_model, config.dim_k, config.dim_v, config.n_heads
        )
        # Feed_forward needs an explicit inner width; use a 4x expansion.
        self.feed_forward = Feed_forward(config.d_model, 4 * config.d_model)
        # One residual/normalisation wrapper per sub-layer so the two never
        # share state.
        self.add_normal = Add_Normal()
        self.add_normal_ff = Add_Normal()

    def forward(self, x):
        """Encode a batch of embedded sequences.

        Args:
            x: tensor of shape ``(batch, seq_len, d_model)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        positional_encoding = self.positional_encoding(x.shape[1], x.shape[2])
        # The table is built as float64 on the CPU; match the input so the
        # sum is not promoted to double, which the linear layers' default
        # float32 weights would reject.
        x = x + positional_encoding.to(dtype=x.dtype, device=x.device)

        # Add_Normal calls the sub-layer with keyword arguments only, and its
        # own first parameter is already named ``x``, so the sub-layer input
        # is bound through a default argument rather than passed as ``x=``.
        x = self.add_normal(
            x, lambda inp=x: self.multi_head_attention(inp, inp)
        )
        x = self.add_normal_ff(x, lambda inp=x: self.feed_forward(inp))
        return x