In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class PositionwiseFeedForward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear.

    The same two-layer MLP acts on the last dimension of the input, so it is
    applied independently at every position.

    Args:
        d_model: size of the input and output feature dimension.
        hidden: size of the intermediate (hidden) layer.
        drop_prob: dropout probability applied after the ReLU.
    """

    def __init__(self, d_model, hidden, drop_prob=0.1):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise the first `self.fc1 = ...` raises AttributeError.
        super(PositionwiseFeedForward, self).__init__()
        self.fc1 = nn.Linear(d_model, hidden)
        self.fc2 = nn.Linear(hidden, d_model)
        # Assignment (`=`), not a comparison (`==`) against a missing attribute.
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``.

        Args:
            x: tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [49]:
import math
class ScaleDotProductAttention(nn.Module):
    """Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Inputs are typically ``[batch_size, head, length, dim]``. Only the last
    two axes are used, so any number of leading dimensions is accepted.
    """

    def __init__(self):
        super(ScaleDotProductAttention, self).__init__()
        # Normalize attention scores over the key axis.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None, e=1e-12):
        """Compute the attention output and the attention weights.

        Args:
            q: query tensor ``[..., length_q, dim]``.
            k: key tensor ``[..., length_k, dim]``.
            v: value tensor ``[..., length_k, dim_v]``.
            mask: optional tensor broadcastable to ``[..., length_q, length_k]``.
                Positions where ``mask == 0`` get a score of -10000 before the
                softmax, i.e. (near-)zero weight.
            e: unused; kept so existing callers keep working.

        Returns:
            ``(output, score)``. ``output`` is ``[..., length_q, dim_v]`` and
            ``score`` is the softmax-normalized ``[..., length_q, length_k]``.
        """
        # dim is the size of each key/query vector; it sets the 1/sqrt(dim) scaling.
        dim = k.size(-1)

        # The keys must be transposed so that q @ k_t contracts over the feature axis.
        k_t = k.transpose(-2, -1)
        score = (q @ k_t) / math.sqrt(dim)
        if mask is not None:
            score = score.masked_fill(mask == 0, -10000)

        score = self.softmax(score)

        v = score @ v

        return v, score

In [3]:
class transformer_blocks(nn.Module):
    """Placeholder module for transformer blocks; it currently defines no layers."""

    def __init__(self):
        # Only the base nn.Module state is set up; no submodules or parameters yet.
        super().__init__()
        

In [12]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``ScaleDotProductAttention``.

    Steps:
        1. Project q/k/v with separate linear layers.
        2. Split the feature dimension into ``n_head`` heads.
        3. Attend within each head.
        4. Concatenate the heads and apply an output projection.

    Args:
        d_model: model feature size; must be divisible by ``n_head``.
        n_head: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, d_model, n_head):
        super(MultiHeadAttention, self).__init__()
        # split() reshapes d_model into n_head * (d_model // n_head); any
        # remainder would make that view fail later, so reject it up front.
        if d_model % n_head != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_head ({n_head})")
        self.n_head = n_head
        self.attention = ScaleDotProductAttention()
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_concat = nn.Linear(d_model, d_model)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q, k, v: tensors ``[batch_size, length, d_model]``.
            mask: optional mask, passed unchanged to the attention module.

        Returns:
            Tensor ``[batch_size, length_q, d_model]``.
        """
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)

        q, k, v = self.split(q), self.split(k), self.split(v)

        # The per-head attention weights are not needed by callers.
        out, _ = self.attention(q, k, v, mask=mask)
        out = self.concat(out)
        # Apply the output projection (the original assigned the module
        # itself here and never returned anything).
        out = self.w_concat(out)
        return out

    def split(self, tensor):
        """Reshape ``[batch, length, d_model]`` into ``[batch, n_head, length, d_model // n_head]``."""
        batch_size, length, d_model = tensor.size()
        d_tensor = d_model // self.n_head
        tensor = tensor.view(batch_size, length, self.n_head, d_tensor).transpose(1, 2)

        return tensor

    def concat(self, tensor):
        """Inverse of ``split``: ``[batch, head, length, dim]`` -> ``[batch, length, head * dim]``."""
        batch_size, head, length, dim = tensor.size()
        d_model = head * dim

        # contiguous() is required because view() cannot follow the transpose's strides.
        tensor = tensor.transpose(1, 2).contiguous().view(batch_size, length, d_model)
        return tensor

In [13]:
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension, with a learnable scale and shift.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        # Learnable per-feature scale (starts at 1) and shift (starts at 0).
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension, then apply gamma and beta."""
        mu = x.mean(dim=-1, keepdim=True)
        # unbiased=False: population (biased) variance, dividing by N rather than N-1.
        sigma_sq = x.var(dim=-1, keepdim=True, unbiased=False)
        x_hat = (x - mu) / torch.sqrt(sigma_sq + self.eps)
        return self.gamma * x_hat + self.beta

In [27]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding table.

        PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
        PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
        d_model: encoding size. Odd sizes are supported; the extra last column is a sin.
        max_len: maximum number of positions that can be encoded.
        device: device on which the table is initially built.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEncoding, self).__init__()

        encoding = torch.zeros(max_len, d_model, device=device)

        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=-1)
        # _2i holds the even column indices 0, 2, 4, ...
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        angle = pos / (10000 ** (_2i / d_model))

        encoding[:, 0::2] = torch.sin(angle)
        # For odd d_model there is one fewer cos column than sin column.
        encoding[:, 1::2] = torch.cos(angle[:, : d_model // 2])

        # Registering a buffer (not a parameter) has two effects:
        # the table moves with module.to(device), and it is never trained.
        # persistent=False keeps state_dict keys identical to the previous
        # plain-attribute version.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``seq_len`` positions.

        Args:
            x: token-id tensor ``[batch_size, seq_len, ...]``; only its size is used.

        Returns:
            Tensor ``[seq_len, d_model]``. It broadcasts over the batch when added
            to ``[batch_size, seq_len, d_model]`` embeddings.
        """
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

In [28]:
class TokenEmbedding(nn.Embedding):
    """Token-id lookup table whose padding index is fixed at 1.

    Args:
        vocab_size: number of rows (distinct token ids).
        d_model: embedding size.
    """

    def __init__(self, vocab_size, d_model):
        # padding_idx=1: that row is zero-initialized and receives no gradient updates.
        super().__init__(num_embeddings=vocab_size, embedding_dim=d_model, padding_idx=1)

In [44]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        vocab_size: vocabulary size of the token embedding.
        d_model: embedding size.
        max_len: maximum sequence length for the positional encoding.
        drop_prob: dropout probability applied to the sum.
        device: device passed to the positional encoding.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        self.tok_emb = TokenEmbedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.drop_out = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Embed token ids ``x``, add the positional encoding, and apply dropout."""
        summed = self.tok_emb(x) + self.pos_emb(x)
        return self.drop_out(summed)

In [46]:
class EncoderLayer(nn.Module):
    """One post-norm Transformer encoder layer.

    Sub-layer 1: self-attention -> dropout -> residual add -> LayerNorm.
    Sub-layer 2: position-wise FFN -> dropout -> residual add -> LayerNorm.

    Args:
        d_model: model feature size.
        ffn_hidden: hidden size of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super().__init__()
        # Attention sub-layer.
        self.attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # Feed-forward sub-layer.
        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, s_mask):
        """Run both sub-layers on ``x``.

        Args:
            x: input features (assumed ``[batch, length, d_model]`` — TODO confirm).
            s_mask: source mask, passed to the attention module.
        """
        attn_out = self.dropout1(self.attention(q=x, k=x, v=x, mask=s_mask))
        x = self.norm1(attn_out + x)

        ffn_out = self.dropout2(self.ffn(x))
        return self.norm2(ffn_out + x)
    
class DecoderLayer(nn.Module):
    """One post-norm Transformer decoder layer.

    Sub-layers:
        1. Masked self-attention -> dropout -> residual add -> LayerNorm.
        2. Encoder-decoder attention -> dropout -> residual add -> LayerNorm
           (skipped when ``enc`` is None).
        3. Position-wise FFN -> dropout -> residual add -> LayerNorm.

    Args:
        d_model: model feature size.
        ffn_hidden: hidden size of the feed-forward network.
        n_head: number of attention heads.
        drop_prob: dropout probability.
    """

    def __init__(self, d_model, ffn_hidden, n_head, drop_prob):
        super(DecoderLayer, self).__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm1 = LayerNorm(d_model=d_model)
        # nn.Dropout's keyword is `p`; passing `drop_prob=` raised TypeError.
        self.dropout1 = nn.Dropout(p=drop_prob)

        self.enc_dec_attention = MultiHeadAttention(d_model=d_model, n_head=n_head)
        self.norm2 = LayerNorm(d_model=d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        self.ffn = PositionwiseFeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNorm(d_model=d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, t_mask, s_mask):
        """Run the decoder sub-layers.

        Args:
            dec: decoder input features.
            enc: encoder output used as keys/values, or None to skip cross-attention.
            t_mask: target mask for self-attention.
            s_mask: source mask for encoder-decoder attention.
        """
        x_ = dec
        x = self.self_attention(q=dec, k=dec, v=dec, mask=t_mask)

        x = self.dropout1(x)
        x = self.norm1(x + x_)

        if enc is not None:
            x_ = x
            x = self.enc_dec_attention(q=x, k=enc, v=enc, mask=s_mask)

            x = self.dropout2(x)
            x = self.norm2(x + x_)

        x_ = x
        x = self.ffn(x)

        x = self.dropout3(x)
        x = self.norm3(x + x_)
        return x

In [47]:
class Encoder(nn.Module):
    """Transformer encoder: an embedding followed by ``n_layers`` EncoderLayers.

    Args:
        enc_voc_size: source vocabulary size.
        max_len: maximum sequence length for the positional encoding.
        d_model: model feature size.
        ffn_hidden: hidden size of each layer's feed-forward network.
        n_head: number of attention heads.
        n_layers: number of stacked encoder layers.
        drop_prob: dropout probability.
        device: device passed to the embedding.
    """

    def __init__(self, enc_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        max_len=max_len,
                                        vocab_size=enc_voc_size,
                                        drop_prob=drop_prob,
                                        device=device)
        # range(n_layers): iterating over the int itself raises TypeError.
        self.layers = nn.ModuleList([EncoderLayer(d_model=d_model,
                                                  ffn_hidden=ffn_hidden,
                                                  n_head=n_head,
                                                  drop_prob=drop_prob)
                                     for _ in range(n_layers)])

    def forward(self, x, s_mask):
        """Embed token ids ``x`` and pass them through every layer with ``s_mask``."""
        x = self.emb(x)
        for layer in self.layers:
            x = layer(x, s_mask)

        return x

In [48]:
class Decoder(nn.Module):
    """Transformer decoder: embedding, a stack of DecoderLayers, and a final linear layer.

    Args:
        dec_voc_size: target vocabulary size; also the output size of ``fc``.
        max_len: maximum sequence length for the positional encoding.
        d_model: model feature size.
        ffn_hidden: hidden size of each layer's feed-forward network.
        n_head: number of attention heads.
        n_layers: number of stacked decoder layers.
        drop_prob: dropout probability.
        device: device passed to the embedding.
    """

    def __init__(self, dec_voc_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob, device):
        super().__init__()
        self.emb = TransformerEmbedding(d_model=d_model,
                                        drop_prob=drop_prob,
                                        max_len=max_len,
                                        vocab_size=dec_voc_size,
                                        device=device)
        stack = []
        for _ in range(n_layers):
            stack.append(DecoderLayer(d_model=d_model,
                                      ffn_hidden=ffn_hidden,
                                      n_head=n_head,
                                      drop_prob=drop_prob))
        self.layers = nn.ModuleList(stack)
        # Projects d_model features to dec_voc_size outputs (presumably per-token logits).
        self.fc = nn.Linear(d_model, dec_voc_size)

    def forward(self, trg, enc_src, trg_mask, src_mask):
        """Decode target ids ``trg`` against the encoder output ``enc_src``."""
        hidden = self.emb(trg)
        for block in self.layers:
            hidden = block(hidden, enc_src, trg_mask, src_mask)
        return self.fc(hidden)

In [42]:
# emb = nn.Embedding(1000, 64, padding_idx=1)
# input_ = torch.randint(1,10,(4,10))
# print(input_)
# emb(input_)

tensor([[4, 1, 8, 1, 6, 2, 3, 3, 7, 8],
        [7, 7, 6, 9, 5, 5, 9, 8, 4, 3],
        [4, 8, 5, 1, 9, 6, 3, 7, 7, 2],
        [5, 8, 4, 2, 8, 2, 6, 3, 5, 2]])


tensor([[[ 1.5287, -2.0504, -1.6227,  ...,  0.8604, -2.1784, -1.1553],
         [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
         [ 0.2667,  0.6300, -0.5273,  ...,  0.1845,  0.0470,  1.7130],
         ...,
         [ 0.0839, -0.2690,  1.4557,  ..., -1.7293, -2.3495, -0.6752],
         [-0.1391, -0.7260,  0.0273,  ...,  0.6196, -1.2513, -0.0145],
         [ 0.2667,  0.6300, -0.5273,  ...,  0.1845,  0.0470,  1.7130]],

        [[-0.1391, -0.7260,  0.0273,  ...,  0.6196, -1.2513, -0.0145],
         [-0.1391, -0.7260,  0.0273,  ...,  0.6196, -1.2513, -0.0145],
         [ 0.2435, -0.9968, -0.4456,  ...,  0.2026,  1.1445, -1.7328],
         ...,
         [ 0.2667,  0.6300, -0.5273,  ...,  0.1845,  0.0470,  1.7130],
         [ 1.5287, -2.0504, -1.6227,  ...,  0.8604, -2.1784, -1.1553],
         [ 0.0839, -0.2690,  1.4557,  ..., -1.7293, -2.3495, -0.6752]],

        [[ 1.5287, -2.0504, -1.6227,  ...,  0.8604, -2.1784, -1.1553],
         [ 0.2667,  0.6300, -0.5273,  ...,  0