In [1]:
import math
import torch
from torch import nn

# Seed PyTorch's global random number generator; the torch.rand / torch.randint
# calls and the layer weight initialisations in later cells draw from it.
torch.manual_seed(2021)
# Pick CUDA when it is available, otherwise the CPU.
# NOTE(review): `device` is not referenced by any of the cells visible here.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

### 点积缩放注意力机制

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k)) @ v``.

    Args:
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dropout=0.1):
        super(ScaledDotProductAttention, self).__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(self, q, k, v, mask=None):
        """Compute the attention-weighted sum of ``v``.

        Args:
            q: Queries, 3-D tensor ``[bs, q_len, feats]``.
            k: Keys, 3-D tensor ``[bs, k_len, feats]``; the last dimension
                must match ``q``.
            v: Values, 3-D tensor ``[bs, k_len, v_feats]``.
            mask: Optional boolean tensor broadcastable to
                ``[bs, q_len, k_len]``. ``True`` marks key positions that
                must receive zero attention weight.

        Returns:
            Context tensor ``[bs, q_len, v_feats]``.
        """
        # q,k,v = [bs, seq, feats]
        d_k = k.size(-1)
        assert d_k == q.size(-1)

        # 1. + 2. similarity scores, scaled by sqrt(d_k)
        scores = torch.bmm(q, k.transpose(-1, -2)) / math.sqrt(d_k)  # [bs, seq, seq]

        # 3. mask (optional): masked positions must end up with weight 0 after
        # the softmax, so they are filled with the most negative finite value.
        # Filling with 0 would leave them a weight of exp(0) = 1 before
        # normalisation. Using the finite minimum rather than -inf keeps a
        # fully masked row from turning into NaN.
        if mask is not None:
            scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)

        # 4. softmax over the key axis; torch.softmax is numerically stable,
        # unlike a raw exp()/sum() which overflows for large scores.
        weights = torch.softmax(scores, dim=-1)  # [bs, seq, seq]

        weights = self.dropout(weights)

        # 5. weighted sum of the values
        context = torch.bmm(weights, v)  # [bs, seq, feats]

        return context

In [8]:
# Smoke test: run the attention layer on random inputs and report its
# parameter count and output size.
attn = ScaledDotProductAttention()
# q, k, v are notebook globals of shape [5, 10, 20]; the multi-head demo cell
# further down reuses them.
q = torch.rand(5, 10, 20)
k = torch.rand(5, 10, 20)
v = torch.rand(5, 10, 20)
result = attn(q, k, v)
print(f"ScaledDotProductAttention parameters: {sum(x.numel() for x in attn.parameters())} \t size: {result.size()}")

ScaledDotProductAttention parameters: 0 	 size: torch.Size([5, 10, 20])


### 多头注意力机制

In [9]:
class AttentionHead(nn.Module):
    """A single attention head.

    Queries, keys and values are each passed through their own linear layer
    (``d_model -> d_feature``) and then handed to ``ScaledDotProductAttention``.

    Args:
        d_model: Size of the last dimension of the incoming tensors.
        d_feature: Size of the last dimension after projection.
        dropout: Dropout probability forwarded to the attention layer.
    """

    def __init__(self, d_model, d_feature, dropout=0.1):
        super().__init__()
        self.attn = ScaledDotProductAttention(dropout)
        self.tfm_query = nn.Linear(d_model, d_feature)
        self.tfm_key = nn.Linear(d_model, d_feature)
        self.tfm_value = nn.Linear(d_model, d_feature)

    def forward(self, queries, keys, values, mask=None):
        """Project the three inputs and return the attention context."""
        projected_q = self.tfm_query(queries)  # [bs, seq, feats]
        projected_k = self.tfm_key(keys)  # [bs, seq, feats]
        projected_v = self.tfm_value(values)  # [bs, seq, feats]
        return self.attn(projected_q, projected_k, projected_v, mask)

class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``AttentionHead`` modules.

    Each head sees the full inputs, the head outputs are concatenated along
    the last dimension, and a final linear layer maps the result to
    ``d_model`` features.

    Args:
        d_model: Input and output feature size.
        d_feature: Output feature size of each head.
        n_heads: Number of heads; ``d_feature * n_heads`` must equal ``d_model``.
        dropout: Dropout probability forwarded to every head.
    """

    def __init__(self, d_model, d_feature, n_heads, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.d_feature = d_feature
        self.n_heads = n_heads
        self.dropout = dropout

        assert d_model == d_feature * n_heads

        heads = []
        for _ in range(n_heads):
            heads.append(AttentionHead(d_model=d_model, d_feature=d_feature, dropout=dropout))
        self.attn_heads = nn.ModuleList(heads)
        self.projection = nn.Linear(n_heads * d_feature, d_model)

    def forward(self, queries, keys, values, mask=None):
        """Run every head on the same inputs and project the concatenation."""
        head_outputs = []  # n_heads * [bs, seq, feats]
        for head in self.attn_heads:
            head_outputs.append(head(queries, keys, values, mask=mask))

        # Concatenate the heads along the feature axis.
        merged = torch.cat(head_outputs, dim=-1)  # [bs, seq, feats * n_heads(=d_model)]

        # Map the concatenated head outputs through the output projection.
        return self.projection(merged)  # [bs, seq, d_model]

In [10]:
# Smoke test for a single head. q, k, v are the [5, 10, 20] tensors created
# in the ScaledDotProductAttention demo cell above.
attn_head = AttentionHead(20, 20)
result = attn_head(q, k, v)
print(f"AttentionHead parameters: {sum(x.numel() for x in attn_head.parameters())} \t size: {result.size()}")

# Multi-head configuration: 8 heads of 20 features each, so d_model = 160.
d_model = 20 * 8
d_feature = 20
n_heads = 8

heads = MultiHeadAttention(d_model=d_model, d_feature=d_feature, n_heads=n_heads)
# repeat(1, 1, 8) tiles the last dimension 8 times (20 -> 160) so the inputs
# match d_model.
result = heads(q.repeat(1, 1, 8), k.repeat(1, 1, 8), v.repeat(1, 1, 8))
print(f"MultiHeadAttention parameters: {sum(x.numel() for x in heads.parameters())} \t size: {result.size()}")

AttentionHead parameters: 1260 	 size: torch.Size([5, 10, 20])
MultiHeadAttention parameters: 103040 	 size: torch.Size([5, 10, 160])


### LayerNorm

In [11]:
class LayerNorm(nn.Module):
    """Normalise the last dimension, then apply a learned scale and shift.

    Computes ``gamma * (x - mean) / (std + eps) + beta`` where ``mean`` and
    ``std`` are taken over the last dimension. ``std`` comes from
    ``Tensor.std`` with its default arguments, and ``eps`` is added to the
    standard deviation itself rather than to the variance.

    Args:
        d_model: Size of the last dimension (length of ``gamma``/``beta``).
        eps: Constant added to the standard deviation to avoid division by zero.
    """

    def __init__(self, d_model, eps=1e-8):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Return the normalised, rescaled tensor (same shape as ``x``)."""
        centered = x - x.mean(-1, keepdim=True)
        denominator = x.std(-1, keepdim=True) + self.eps
        # Scale before dividing, matching the evaluation order gamma * c / d.
        scaled = self.gamma * centered
        return scaled / denominator + self.beta

### Encoder子层

In [13]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention followed by a position-wise MLP.

    Each sub-layer output is layer-normalised, passed through dropout and
    added back onto the sub-layer input.

    Args:
        d_model: Feature size of the input and output.
        d_feature: Feature size of each attention head.
        d_ff: Hidden size of the feed-forward network.
        n_heads: Number of attention heads.
        dropout: Dropout probability (attention weights and residual branches).
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super(EncoderBlock, self).__init__()
        self.attn_head = MultiHeadAttention(d_model=d_model, d_feature=d_feature, n_heads=n_heads, dropout=dropout)
        self.layer_norm1 = LayerNorm(d_model=d_model)
        self.dropout = nn.Dropout(dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.layer_norm2 = LayerNorm(d_model=d_model)

    def forward(self, x, mask=None):
        """Apply the block.

        Args:
            x: Input tensor whose last dimension is ``d_model``.
            mask: Optional attention mask forwarded to the self-attention.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # NOTE(review): the norm is applied to the sub-layer output before the
        # residual add (x + dropout(norm(sublayer(x)))), which differs from
        # the LayerNorm(x + Sublayer(x)) form of the original Transformer
        # paper -- confirm this is intentional.

        # 1.1 multi head self-attention
        att = self.attn_head(x, x, x, mask=mask)
        # 1.2 apply normalization and residual connection
        x = x + self.dropout(self.layer_norm1(att))

        # 2.1 apply position-wise feedforward network
        pos = self.position_wise_feed_forward(x)
        # 2.2 apply normalization and residual connection
        x = x + self.dropout(self.layer_norm2(pos))
        return x

In [14]:
# Smoke test: one default-sized encoder block on a random [5, 10, 512] input.
enc = EncoderBlock()
result = enc(torch.rand(5, 10, 512))
print(f"EncoderBlock parameters: {sum(x.numel() for x in enc.parameters())} \t size: {result.size()}")

EncoderBlock parameters: 3152384 	 size: torch.Size([5, 10, 512])


### Encoder

In [15]:
class TransformerEncoder(nn.Module):
    """A stack of ``n_blocks`` ``EncoderBlock`` layers applied in sequence.

    Args:
        n_blocks: Number of encoder blocks.
        d_model: Feature size of the input and output.
        n_heads: Attention heads per block; each head gets
            ``d_model // n_heads`` features.
        d_ff: Hidden size of each block's feed-forward network.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, n_blocks=6, d_model=512, n_heads=8, d_ff=2048, dropout=0.1):
        super().__init__()
        per_head = d_model // n_heads
        blocks = []
        for _ in range(n_blocks):
            blocks.append(
                EncoderBlock(d_model=d_model, d_feature=per_head, d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            )
        self.encoders = nn.ModuleList(blocks)

    def forward(self, x: torch.FloatTensor, mask=None):
        """Feed ``x`` through every block, passing the same ``mask`` to each."""
        hidden = x
        for block in self.encoders:
            hidden = block(hidden, mask=mask)
        return hidden

In [19]:
# Smoke test: the default 6-block encoder stack on a random [5, 10, 512] input.
t_enc = TransformerEncoder()
result = t_enc(torch.rand(5, 10, 512))
print(f"TransformerEncoder parameters: {sum(x.numel() for x in t_enc.parameters())} \t size: {result.size()}")

TransformerEncoder parameters: 18914304 	 size: torch.Size([5, 10, 512])


### Decoder子层

In [18]:
class DecoderBlock(nn.Module):
    """One decoder layer.

    Three sub-layers are applied in order: self-attention over the decoder
    input, attention over the encoder output, and a position-wise MLP. Each
    sub-layer output is layer-normalised, passed through dropout and added
    back onto the sub-layer input.

    Args:
        d_model: Feature size of the input and output.
        d_feature: Feature size of each attention head.
        d_ff: Hidden size of the feed-forward network.
        n_heads: Number of attention heads.
        dropout: Dropout probability (attention weights and residual branches).
    """

    def __init__(self, d_model=512, d_feature=64, d_ff=2048, n_heads=8, dropout=0.1):
        super(DecoderBlock, self).__init__()
        self.masked_attn_head = MultiHeadAttention(d_model=d_model, d_feature=d_feature, n_heads=n_heads, dropout=dropout)
        self.attn_head = MultiHeadAttention(d_model=d_model, d_feature=d_feature, n_heads=n_heads, dropout=dropout)
        self.position_wise_feed_forward = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model)
        )
        self.layer_norm1 = LayerNorm(d_model=d_model)
        self.layer_norm2 = LayerNorm(d_model=d_model)
        self.layer_norm3 = LayerNorm(d_model=d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, dec_self_attn_mask=None, dec_enc_attn_mask=None):
        """Apply the block.

        Args:
            x: Decoder input whose last dimension is ``d_model``.
            enc_out: Encoder output, used as keys and values of the second
                attention sub-layer.
            dec_self_attn_mask: Optional mask for the decoder self-attention.
            dec_enc_attn_mask: Optional mask for the encoder-decoder attention.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # 1.1 apply masked multihead attention
        att = self.masked_attn_head(x, x, x, mask=dec_self_attn_mask)
        # 1.2 layer normalization + residual connection
        x = x + self.dropout(self.layer_norm1(att))

        # 2.1 apply multihead attention (encoder and decoder)
        att = self.attn_head(queries=x, keys=enc_out, values=enc_out, mask=dec_enc_attn_mask)
        # 2.2 layer normalization + residual connection
        x = x + self.dropout(self.layer_norm2(att))

        # 3.1 apply position wise feedforward network
        pos = self.position_wise_feed_forward(x)
        # 3.2 layer normalization + residual connection. The normalised tensor
        # must be the feed-forward output `pos`; normalising `x` here would
        # leave the feed-forward network without any effect on the result.
        x = x + self.dropout(self.layer_norm3(pos))

        return x

In [22]:
# Smoke test: a single decoder block attending over the output of a freshly
# built encoder stack. Both inputs are random [5, 10, 512] tensors and no
# attention masks are passed.
t_enc = TransformerEncoder()
dec = DecoderBlock()

x = torch.rand(5, 10, 512)
enc_out = t_enc(torch.rand(5, 10, 512))
result = dec(x, enc_out)
print(f"DecoderBlock parameters: {sum(x.numel() for x in dec.parameters())} \t size: {result.size()}")

DecoderBlock parameters: 4204032 	 size: torch.Size([5, 10, 512])


### Decoder

In [21]:
class TransformerDecoder(nn.Module):
    """A stack of ``n_blocks`` ``DecoderBlock`` layers applied in sequence.

    Args:
        n_blocks: Number of decoder blocks.
        d_model: Feature size of the input and output.
        d_ff: Hidden size of each block's feed-forward network.
        n_heads: Attention heads per block; each head gets
            ``d_model // n_heads`` features.
        dropout: Dropout probability passed to every block.
    """

    def __init__(self, n_blocks=6, d_model=512, d_ff=2048, n_heads=8, dropout=0.1):
        super().__init__()
        per_head = d_model // n_heads
        blocks = []
        for _ in range(n_blocks):
            blocks.append(
                DecoderBlock(d_model=d_model, d_feature=per_head, d_ff=d_ff, n_heads=n_heads, dropout=dropout)
            )
        self.decoders = nn.ModuleList(blocks)

    def forward(self, x: torch.FloatTensor, enc_out: torch.FloatTensor, dec_self_attn_mask=None, dec_enc_attn_mask=None):
        """Feed ``x`` through every block; each block receives the same
        ``enc_out`` and the same two masks."""
        hidden = x
        for block in self.decoders:
            hidden = block(hidden, enc_out, dec_self_attn_mask, dec_enc_attn_mask)
        return hidden

In [23]:
# Smoke test: the default 6-block decoder stack attending over the output of
# a freshly built 6-block encoder stack. Both inputs are random [5, 10, 512]
# tensors and no attention masks are passed.
t_enc = TransformerEncoder()
t_dec = TransformerDecoder()

x = torch.rand(5, 10, 512)
enc_out = t_enc(torch.rand(5, 10, 512))
result = t_dec(x, enc_out)
print(f"TransformerDecoder parameters: {sum(x.numel() for x in t_dec.parameters())} \t size: {result.size()}")

TransformerDecoder parameters: 25224192 	 size: torch.Size([5, 10, 512])


### 词向量编码

In [24]:
# 位置编码
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position encodings.

    A ``[1, max_len, d_model]`` table is built once: even feature columns
    hold ``sin(pos * w_i)`` and odd columns hold ``cos(pos * w_i)`` with
    ``w_i = 10000 ** (-2i / d_model)``. The table is stored as a parameter
    with ``requires_grad=False``, so it receives no gradient.

    Args:
        d_model: Number of feature columns (even or odd).
        max_len: Number of positions in the table.
    """

    def __init__(self, d_model, max_len=512):
        super(PositionalEmbedding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]
        pe[:, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the angle table is trimmed to the number of odd columns.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        self.weight = nn.Parameter(pe, requires_grad=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of dimension 1 of ``x`` is used. If it exceeds
        ``max_len`` the slice simply stops at ``max_len`` rows.
        """
        return self.weight[:, :x.size(1), :]  # [1, seq, feat]


# 词向量编码
class WordPositionEmbedding(nn.Module):
    """Token embedding plus positional encoding.

    Args:
        vocab_size: Number of rows in the token embedding table.
        d_model: Embedding size, shared by both components.
    """

    def __init__(self, vocab_size, d_model=512):
        super().__init__()
        self.word_embedding = nn.Embedding(vocab_size, d_model)
        self.positional_embedding = PositionalEmbedding(d_model)

    def forward(self, x: torch.LongTensor) -> torch.FloatTensor:
        """Look up the token ids in ``x`` and add the positional encodings."""
        token_vectors = self.word_embedding(x)
        position_vectors = self.positional_embedding(x)
        return token_vectors + position_vectors

In [27]:
# End-to-end smoke test: embed random token ids, encode the source and decode
# the target against the encoder output. No attention masks are passed.
emb = WordPositionEmbedding(vocab_size=1000)
encoder = TransformerEncoder()
decoder = TransformerDecoder()

# Random ids in [0, 1000) with shape [5, 30]; the same `emb` module embeds
# both the source and the target ids.
src_ids = torch.randint(1000, (5, 30))
tgt_ids = torch.randint(1000, (5, 30))
x = encoder(emb(src_ids))
result = decoder(emb(tgt_ids), x)
# The reported count covers the encoder and decoder only; the parameters of
# `emb` are not included in the sum.
print(f"Transformer parameters: {sum(x.numel() for x in list(encoder.parameters()) + list(decoder.parameters()))} \t size: {result.size()}")

Transformer parameters: 44138496 	 size: torch.Size([5, 30, 512])
