In [1]:
import paddle
paddle.__version__

'2.2.2'

In [2]:
# First put on the official api call, forgive my laziness.
import paddle
from paddle.nn import Transformer

# src: [batch_size, tgt_len, d_model]
enc_input = paddle.rand((2, 4, 128))
# tgt: [batch_size, src_len, d_model]
dec_input = paddle.rand((2, 6, 128))
# src_mask: [batch_size, n_head, src_len, src_len]
enc_self_attn_mask = paddle.rand((2, 2, 4, 4))
# tgt_mask: [batch_size, n_head, tgt_len, tgt_len]
dec_self_attn_mask = paddle.rand((2, 2, 6, 6))
# memory_mask: [batch_size, n_head, tgt_len, src_len]
cross_attn_mask = paddle.rand((2, 2, 6, 4))
transformer = Transformer(128, 2, 4, 4, 512)
output = transformer(enc_input,
                     dec_input,
                     enc_self_attn_mask,
                     dec_self_attn_mask,
                     cross_attn_mask)  # [2, 6, 128]
print(output)

Tensor(shape=[2, 6, 128], dtype=float32, place=CPUPlace, stop_gradient=False,
       [[[ 0.83344501, -1.21844459,  0.48055834, ...,  0.20295902,
          -0.24485584,  0.63060206],
         [ 0.02267257, -1.50108027,  0.92258799, ..., -0.56661445,
          -0.15777314,  0.36035648],
         [ 0.10309691, -0.62477779,  0.71399939, ..., -0.90896100,
           0.35118568,  1.10827327],
         [-0.28135449, -1.22320652,  0.58787996, ...,  0.04233040,
           0.45569146,  0.88440514],
         [ 0.93284696, -0.67029673,  0.18253788, ...,  0.07425120,
           0.05858831,  1.09639645],
         [ 0.81336921, -1.01010621,  0.89278758, ..., -0.03291705,
           0.10923889,  0.28272694]],

        [[ 1.28254485, -0.68442333,  0.47062826, ..., -0.49197811,
          -0.76637751,  1.21772254],
         [ 1.70308220, -0.29101193,  1.18233836, ...,  0.54455656,
          -0.71247047,  0.82578307],
         [ 1.38030028, -1.23640335,  0.07849841, ...,  0.14094420,
          -0.34609523

In [3]:
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import numpy as np

In [4]:
class ScaledDotProductAttention(nn.Layer):
   
    def __init__(self, temp, attn_dropout=0.1):
        super().__init__()
        self.temp = temp
        self.dropout = nn.Dropout(p=attn_dropout)

    def forward(self, q, k, v, mask=None):
        
        attn = paddle.matmul(q/self.temp, k, transpose_y=True)

        if mask is not None:
            attn = attn * mask
        attn = self.dropout(F.softmax(attn, axis=-1))

        output = paddle.matmul(attn, v)

        return output, attn

In [5]:
class MultiHeadAttention(nn.Layer):
   
    def __init__(self, n_head=8, d_model=512, d_k=None, d_v=None, dropout=0.1):
        super().__init__()

        self.n_head = n_head
        self.d_k = d_k
        self.d_v = d_v

        self.w_qs = nn.Linear(d_model, n_head * d_k, bias_attr=False)
        self.w_ks = nn.Linear(d_model, n_head * d_k, bias_attr=False)
        self.w_vs = nn.Linear(d_model, n_head * d_v, bias_attr=False)
        self.fc   = nn.Linear(n_head * d_v, d_model, bias_attr=False)

        self.attention = ScaledDotProductAttention(temp= d_k**0.5)

        self.dropout   = nn.Dropout(dropout)

        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-5) 

    def forward(self, q, k, v, mask=None):
      
        d_k, d_v, n_head = self.d_k, self.d_v, self.n_head

        batch_size, len_q, len_k, len_v = q.shape[0], q.shape[1], k.shape[1], v.shape[1]

        residual = q
        
        q = self.w_qs(q).reshape((batch_size, len_q, n_head, d_k))
        k = self.w_ks(k).reshape((batch_size, len_k, n_head, d_k))
        v = self.w_vs(v).reshape((batch_size, len_v, n_head, d_v))

        
        q, k, v = q.transpose([0, 2, 1, 3]), k.transpose([0, 2, 1, 3]), v.transpose([0, 2, 1, 3])

        if mask is not None:
            mask = mask.unsqueeze(1)
        
        q, attn = self.attention(q, k, v, mask=mask)
        

        q = q.transpose([0, 2, 1, 3]).reshape((batch_size, len_q, -1))
        q = self.dropout(self.fc(q))

        q += residual
        q = self.layer_norm(q)

        return q, attn

In [6]:
class PositionwiseForward(nn.Layer):
    
    def __init__(self, d_in, d_hid, dropout=0.1):
        super().__init__()
        self.w_1 = nn.Linear(d_in, d_hid)
        self.w_2 = nn.Linear(d_hid, d_in)
        self.layer_norm = nn.LayerNorm(d_in, epsilon=1e-5) # d_in: 需规范化的shape  epsilon:指明在计算过程中是否添加较小的值到方差中以防止除零,paddle中一般都采用1e-5
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):

        residual = x

        x = self.w_2(F.relu(self.w_1(x)))
        x = self.dropout(x)
        x += residual

        x = self.layer_norm(x)

        return x

In [7]:
class EncoderLayer(nn.Layer):

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()

        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn  = PositionwiseForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input, self_attn_mask=None):

        enc_output, enc_self_attn = self.slf_attn(enc_input, enc_input, enc_input, mask=self_attn_mask)
        enc_output = self.pos_ffn(enc_output)

        return enc_output, enc_self_attn

In [8]:
class DecoderLayer(nn.Layer):

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.self_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn  = PositionwiseForward(d_model, d_inner, dropout=dropout)
    def forward(self, dec_input, enc_output, self_attn_mask=None, dec_enc_attn_mask=None):
        dec_output, dec_self_attn = self.self_attn(dec_input, dec_input, dec_input, mask=self_attn_mask)
        dec_output, dec_enc_attn = self.enc_attn(dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        dec_output = self.pos_ffn(dec_output)

        return dec_output, dec_self_attn, dec_enc_attn

In [9]:
#I don't understand it entirely
class PositionalEncoding(nn.Layer):

    def __init__(self, d_hid, n_position=200):
        super(PositionalEncoding, self).__init__()
        self.register_buffer('pos_table', self._get_sin_encoding_table(n_position, d_hid))

    def _get_sin_encoding_table(self, n_position, d_hid):

        def get_position_angle_vec(position):
            return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]

        sin_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
        sin_table[:, 0::2] = np.sin(sin_table[:, 0::2])  # dim 2i
        sin_table[:, 1::2] = np.cos(sin_table[:, 1::2])  # dim 2i+1

        return paddle.to_tensor(sin_table, dtype='float32').unsqueeze(0)

    def forward(self, x):
  
        return x + paddle.cast(self.pos_table[:, :x.shape[1]], dtype='float32').detach()

In [10]:
class Encoder(nn.Layer):

    def __init__(self, n_src_vocab=200, d_word_vec=20, n_layers=6, n_head=2, 
        d_k=10, d_v=10, d_model=20, d_inner=10, pad_idx= 0, dropout=0.1, n_position=200, emb_weight=None):

        super().__init__()

        self.src_word_emb = nn.Embedding(n_src_vocab, d_word_vec, sparse=True, padding_idx=pad_idx)

        if emb_weight is not None:
            self.src_word_emb.weight.set_value(emb_weight)
            self.src_word_emb.stop_gradient=True

        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout      = nn.Dropout(dropout)
        self.layer_stack  = nn.LayerList(
            [
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
                for _ in range(n_layers)
            ]
        )
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, src_seq, src_mask, return_attns=False):

        enc_slf_attn_list = []
        print("src_seq:",src_seq.shape)
    
        enc_output = self.dropout(self.position_enc(self.src_word_emb(src_seq)))
        enc_output = self.layer_norm(enc_output)

        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output, self_attn_mask=src_mask)
            enc_slf_attn_list += [enc_slf_attn] if return_attns else []

        if return_attns:
            return enc_output, enc_slf_attn_list
        return enc_output, 

In [11]:
class Decoder(nn.Layer):
    
    def __init__(self, n_trg_vocab=200, d_word_vec=20, n_layers=6, n_head=2, d_k=10, d_v=10,
        d_model=20, d_inner=10, pad_idx=0, dropout=0.1, n_position=200, emb_weight=None):

        super().__init__()
        self.trg_word_emb = nn.Embedding(n_trg_vocab, d_word_vec, padding_idx=pad_idx)
        if emb_weight is not None:
            self.trg_word_emb.weight.set_value(emb_weight)
            self.trg_word_emb.stop_gradient=True
        self.position_enc = PositionalEncoding(d_word_vec, n_position=n_position)
        self.dropout      = nn.Dropout(dropout)
        self.layer_stack  = nn.LayerList([
            DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.layer_norm = nn.LayerNorm(d_model, epsilon=1e-5)

    def forward(self, trg_seq, trg_mask, enc_output, src_mask, return_attns=False):

        dec_self_attn_list, dec_enc_attn_list = [], []

        dec_output = self.dropout(self.position_enc(self.trg_word_emb(trg_seq)))
        dec_output = self.layer_norm(dec_output)

        for dec_layer in self.layer_stack:
            dec_output, dec_self_attn, dec_enc_attn = dec_layer(
                dec_output, enc_output, self_attn_mask=trg_mask, dec_enc_attn_mask=src_mask
            )
            dec_self_attn_list += [dec_self_attn] if return_attns else []
            dec_enc_attn_list += [dec_enc_attn] if return_attns else []
        if return_attns:
            return dec_output, dec_self_attn_list, dec_enc_attn_list
        return dec_output, 

In [12]:
def get_pad_mask(seq, pad_idx):
    return (seq != pad_idx).unsqueeze(-2)

def get_subsquent_mask(seq):

    batch_size, len_s = seq.shape[0], seq.shape[1]
    subsequent_mask = (1 - paddle.triu(paddle.ones((1, len_s, len_s)), diagonal=1)) 
    return subsequent_mask

class Transformer(nn.Layer):

    def __init__(
        self, n_src_vocab, n_trg_vocab, src_pad_idx=0, trg_pad_idx=0, 
        d_word_vec=512, d_model=512, d_inner=2048,
        n_layers=6, n_head=8, d_k=64, d_v=64, dropout=0.1, n_position=200,
        src_emb_weight=None, trg_emb_weight=None,
        trg_emd_prj_weight_sharing=True, emb_src_trg_weight_sharing=True,
    ):
        super(Transformer, self).__init__()
        self.src_pad_idx, self.trg_pad_idx = src_pad_idx, trg_pad_idx
        self.encoder = Encoder(
            n_src_vocab=n_src_vocab, pad_idx=src_pad_idx, d_word_vec=d_word_vec,
            n_layers=n_layers, n_head=n_head, d_model=d_model, d_inner=d_inner,
            d_k=d_k, d_v=d_v, dropout=dropout, n_position=n_position,
            emb_weight=src_emb_weight)

        self.decoder = Decoder(
            n_trg_vocab=n_trg_vocab, pad_idx=trg_pad_idx, d_word_vec=d_word_vec,
            n_layers=n_layers, n_head=n_head, d_model=d_model, d_inner=d_inner,
            d_k=d_k, d_v=d_v, dropout=dropout, n_position=n_position,
            emb_weight=trg_emb_weight)
        
        self.trg_word_prj = nn.Linear(d_model, n_trg_vocab, bias_attr=False, 
                                      weight_attr=nn.initializer.XavierUniform())

        for p in self.parameters():
            if p.dim()>1:
                print(p.shape)
                nn.initializer.XavierUniform(p)

        # 判断维度是否相等，残差链接的维度是相等的
        assert d_model == d_word_vec, 'To facilitate the residual connections, the dimensions of all module outputs shall be the same'
        
        self.x_logit_scale = 1.

        if trg_emd_prj_weight_sharing:
            weight = self.decoder.trg_word_emb.weight.numpy()
            weight = np.transpose(weight)
            self.trg_word_prj.weight.set_value(weight)
            self.x_logit_scale= (d_model ** -0.5)

        if emb_src_trg_weight_sharing:
            weight = self.decoder.trg_word_emb.weight.numpy()
            self.encoder.src_word_emb.weight.set_value(weight)
        
    def forward(self, src_seq, trg_seq):
        src_mask = get_pad_mask(src_seq, self.src_pad_idx)

        trg_mask = get_pad_mask(trg_seq, self.trg_pad_idx).numpy().astype(bool) & get_subsquent_mask(trg_seq).numpy().astype(bool)
        trg_mask = paddle.to_tensor(trg_mask)
        print("trg_mask:",trg_mask.shape)
        enc_output, *_ = self.encoder(src_seq, src_mask)

        print("trg_seq,enc_output:",trg_seq.shape, enc_output.shape)
        dec_output, *_ = self.decoder(trg_seq, trg_mask, enc_output, src_mask)

        seq_logit= self.trg_word_prj(dec_output) * self.x_logit_scale
        print("seq_logit:",seq_logit.shape)

        return seq_logit.reshape((-1, seq_logit.shape[2]))

In [13]:
import warnings
warnings.filterwarnings("ignore")

In [14]:
test_data = paddle.to_tensor(100*np.random.random((3, 10)), dtype='int64')
print("*"*30)
print(test_data)
print("*"*30)
enc = Encoder()
dec = Decoder()
transformer = Transformer(n_head=3, n_layers=6, src_pad_idx=0, trg_pad_idx=0, n_src_vocab=200, n_trg_vocab=200)
enc_output, *_ = enc(test_data, src_mask=None)


dec(test_data, trg_mask=None, enc_output=enc_output, src_mask=None)
t = transformer(test_data, test_data)
print("*"*30)
print(t)
print("*"*30)

******************************
Tensor(shape=[3, 10], dtype=int64, place=CPUPlace, stop_gradient=True,
       [[16, 89, 43, 31, 64, 79, 75, 41, 19, 47],
        [40, 91, 60, 58, 48, 62, 80, 65, 10, 89],
        [75, 62, 5 , 92, 8 , 81, 34, 65, 32, 77]])
******************************
[200, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[200, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 2048]
[2048, 512]
[512, 192]
[512, 192]
[512, 192]
[192, 512]
[512, 192]
[512, 1