In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

paper : https://arxiv.org/abs/1706.03762

http://nlp.seas.harvard.edu/2018/04/03/attention.html

maps input sequences to a sequence of continuous representations $x=(x_1 \cdots, x_n) \rightarrow z=(z_1, \cdots, y_n)$, given $z$ generates an output sequence $y=(y_1, \cdots, y_n)$

## 3.2.1 Scaled Dot-Product Attention

In [2]:
class ScaledDotProductAttention(nn.Module):
    """Scaled Dot-Product Attention"""
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = torch.tensor(d_k).float()
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, q, k, v, mask=None):
        """
        Inputs:
        * q: (B, T_q, d_q), d_q = d_k
        * k: (B, T_k, d_k)
        * v: (B, T_v, d_v), T_k = T_v
        -------------------------------
        Outputs:
        * output: (B, T_q, d_v)
        * probs: (B, T_q, T_k)
        """
        assert q.size(2) == k.size(2), "d_q = d_k"
        assert k.size(1) == v.size(1), "T_k = T_v"
        attn = torch.bmm(q, k.transpose(1, 2))  # (B, T_q, d_k) * (B, T_k, d_k) -> (B, T_q, T_k)
        attn = attn / torch.sqrt(self.d_k)
        # why doing this? 
        # for the large values of d_k, the dot products grow large in magnitude, 
        # pushing the softmax function into regions where it has extremely small gradients
        # to counteract this effect, scaled the dot products by 1/sqrt(d_k)
        # to illustrate why the dot products get large,
        # check the function 'check_dotproduct_dist'
        if mask is not None:
            attn = attn.masked_fill_(mask, -np.inf)
        
        attn = self.softmax(attn)  # (B, T_q, T_k) --> (B, T_q, T_k)
        output = torch.bmm(attn, v)  # (B, T_q, T_k) * (B, T_v, d_v) --> (B, T_q, d_v), make sure that T_k == T_v
        return output, attn

In [3]:
# ex) translation: 
# q - previous decoder
# k, v - encoder output
batch = 6
T_q = 1
T_k, T_v = (7, 7)
d_k = 10
d_v = 12
q, k, v = torch.randn((batch, T_q, d_k)), torch.randn((batch, T_k, d_k)), torch.randn((batch, T_v, d_v))
attention = ScaledDotProductAttention(d_k)

In [4]:
output, attn = attention(q, k, v)
output.size(), attn.size()

(torch.Size([6, 1, 12]), torch.Size([6, 1, 7]))

In [5]:
def check_dotproduct_dist(d_k, sampling_size=1, seq_len=1, threshold=1e-10):
    """
    to check "https://arxiv.org/abs/1706.03762" Paper page 4, annotation 4
    -------------------------------
    To illustrate why the dot products get large, 
    assume that the components of q and k are independent random variables 
    with mean 0 and variance 1.
    Then their dot product has mean 0 and variance d_k
    """
    def cal_grad(attn):
        y = torch.softmax(attn, dim=2)
        return y * (1-y)
    
    q = nn.init.normal_(torch.rand((sampling_size, seq_len, d_k)), mean=0, std=1)
    k = nn.init.normal_(torch.rand((sampling_size, seq_len, d_k)), mean=0, std=1)
    attn = torch.bmm(q, k.transpose(1, 2))
    print('size of vector d_k is {}, sampling result, dot product distribution has \n - mean: {:.4f}, \n - var: {:.4f}'.\
          format(d_k, attn.mean().item(), attn.var().item()))
    grad = cal_grad(attn)
    print( "count of gradients that smaller than threshod({}) is {}, {:.4f}%".format(
        threshold, grad.le(threshold).sum(), grad.le(threshold).sum().item()/grad.view(-1).size(0)*100 ) )
    attn2 = attn / torch.sqrt(torch.as_tensor(d_k).float())
    grad2 = cal_grad(attn2)
    print( "after divide by sqrt(d_k), count of gradients that smaller than threshod({}) is {}, {:.4f}% \n".format(
        threshold, grad2.le(threshold).sum(), grad2.le(threshold).sum().item()/grad2.view(-1).size(0)*100 ) )

In [6]:
print("*** notice that the gradient of softmax is y(1-y) ***")
for d_k in [10, 100, 1000]:
    check_dotproduct_dist(d_k, sampling_size=100000, seq_len=5, threshold=1e-10)

*** notice that the gradient of softmax is y(1-y) ***
size of vector d_k is 10, sampling result, dot product distribution has 
 - mean: -0.0027, 
 - var: 10.0053
count of gradients that smaller than threshod(1e-10) is 176, 0.0070%
after divide by sqrt(d_k), count of gradients that smaller than threshod(1e-10) is 0, 0.0000% 

size of vector d_k is 100, sampling result, dot product distribution has 
 - mean: -0.0039, 
 - var: 100.0387
count of gradients that smaller than threshod(1e-10) is 402517, 16.1007%
after divide by sqrt(d_k), count of gradients that smaller than threshod(1e-10) is 0, 0.0000% 

size of vector d_k is 1000, sampling result, dot product distribution has 
 - mean: -0.0409, 
 - var: 998.2934
count of gradients that smaller than threshod(1e-10) is 1735779, 69.4312%
after divide by sqrt(d_k), count of gradients that smaller than threshod(1e-10) is 0, 0.0000% 



## 3.2.2 Multi-Head Attention

In [7]:
class XavierLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(XavierLinear, self).__init__()
        self.linear = nn.Linear(in_features, out_features, bias=bias)
        nn.init.xavier_normal_(self.linear.weight)

    def forward(self, inputs):
        return self.linear(inputs)

In [8]:
class MultiHeadAttention(nn.Module):
    """Multi-head Attention"""
    def __init__(self, n_head, d_model, d_k, d_v, drop_rate=0.1):
        """
        paper setting: n_head = 8, d_k = d_v = d_model / n_head = 64
        Multi-head attention allows the model to jointly attend to information from 
        different representation subspaces at different positions.
        with a single attention head, averaging inhibits this.
        """
        super(MultiHeadAttention, self).__init__()
        self.n_head = n_head
        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.linear_q = XavierLinear(d_model, d_k)
        self.linear_k = XavierLinear(d_model, d_k)
        self.linear_v = XavierLinear(d_model, d_v)
        self.linear_o = XavierLinear(n_head*d_v, d_model)
        self.attention = ScaledDotProductAttention(d_k)
        self.drop_out = nn.Dropout(drop_rate)
        
    def forward(self, q, k, v, mask=None):
        """
        Inputs:
        * q: (B, T_q, d_model)
        * k: (B, T_k, d_model)
        * v: (B, T_v, d_model)
        ---------------------
        Outputs:
        * output: (B, T_q, d_model)
        * attn: (n_head * B, T_q, T_k)
        """
        n_head, d_model, d_k, d_v = self.n_head, self.d_model, self.d_k, self.d_v
        
        # repeat to compute n_heads
        n_qs = q.repeat(n_head, 1, 1)  # (n_head * B, T_q, d_model)
        n_ks = k.repeat(n_head, 1, 1)  # (n_head * B, T_k, d_model)
        n_vs = v.repeat(n_head, 1, 1)  # (n_head * B, T_v, d_model)
        if mask is not None:
            mask = mask.repeat(n_head, 1, 1)
        
        # through linear layer: 
        lin_qs = self.linear_q(n_qs)  # (n_head * B, T_q, d_model) --> (n_head * B, T_q, d_k) 
        lin_ks = self.linear_q(n_ks)  # (n_head * B, T_k, d_model) --> (n_head * B, T_k, d_k) 
        lin_vs = self.linear_q(n_vs)  # (n_head * B, T_v, d_model) --> (n_head * B, T_v, d_v)
        
        # attention: Scaled Dot-Product Attention
        ## heads: (n_head * B, T_q, d_v)
        ## attn: (n_head * B, T_q, T_k)
        heads, attn = self.attention(q=lin_qs, k=lin_ks, v=lin_vs, mask=mask)
        
        # concat
        heads_cat = torch.cat(list(heads.chunk(n_head, dim=0)), dim=-1)  # (n_head * B, T_q, d_v) --> (B, T_q, n_head * d_v)
        output = self.linear_o(heads_cat)  # (B, T_q, n_head * d_v) --> (B, T_q, d_model)
        output = self.drop_out(output)
        return output, attn

In [9]:
batch = 6
T_q = 1
T_k, T_v = (7, 7)
n_head = 8
d_model = 64*n_head
d_k = 64
d_v = 64
q, k, v = torch.randn((batch, T_q, d_model)), torch.randn((batch, T_k, d_model)), torch.randn((batch, T_v, d_model))
q.size(), k.size(), v.size()

(torch.Size([6, 1, 512]), torch.Size([6, 7, 512]), torch.Size([6, 7, 512]))

In [10]:
multiheadattention = MultiHeadAttention(n_head, d_model, d_k, d_v)
o, attn = multiheadattention(q, k, v, mask=None)
multiheadattention.modules

<bound method Module.modules of MultiHeadAttention(
  (linear_q): XavierLinear(
    (linear): Linear(in_features=512, out_features=64, bias=True)
  )
  (linear_k): XavierLinear(
    (linear): Linear(in_features=512, out_features=64, bias=True)
  )
  (linear_v): XavierLinear(
    (linear): Linear(in_features=512, out_features=64, bias=True)
  )
  (linear_o): XavierLinear(
    (linear): Linear(in_features=512, out_features=512, bias=True)
  )
  (attention): ScaledDotProductAttention(
    (softmax): Softmax()
  )
  (drop_out): Dropout(p=0.1)
)>

In [11]:
o.size(), attn.size()

(torch.Size([6, 1, 512]), torch.Size([48, 1, 7]))

## 3.2.3 Application

### Encoder-Decoder attention

* queries: the previous decoder layer
* keys & values: output of the encoder

### Encoder + self-attention

* queries, keys, values come form the output of the previous layer in the encoder

### Decoder + self-attention

* need to prevent leftward information flow in the decoder to preserve the auto-regressive property. 
* implement by masking out (-inf)

## 3.3 Position-wise Feed-Forward Networks

$$FFN(x) = max(0, xW_1 + b_1)W_2 + b_2$$
$$\begin{aligned} W_1 &\in \Bbb{R}^{d_{model} \times d_f} \\ 
b_1 &\in \Bbb{R}^{d_f} \\
W_2 &\in \Bbb{R}^{d_f \times d_{model}} \\ 
b_2 &\in \Bbb{R}^{d_{model}} \\
\end{aligned}$$

same as $FFN = Linear(ReLU(Linear(x)) = Conv1d(ReLU(Conv1d))$

In [12]:
class PositionWiseFFN(nn.Module):
    """Position-wise Feed-Forward Networks"""
    def __init__(self, d_model, d_f, drop_rate=0.1, use_conv=False):
        super(PositionWiseFFN, self).__init__()
        self.use_conv = use_conv
        if use_conv:
            self.fc = nn.Sequential(
                nn.Conv1d(d_model, d_f, kernel_size=1),
                nn.ReLU(),
                nn.Conv1d(d_f, d_model, kernel_size=1)
            )
        else:
            self.fc = nn.Sequential(
                nn.Linear(d_model, d_f),
                nn.ReLU(),
                nn.Linear(d_f, d_model)
            )
        self.drop_out = nn.Dropout(drop_rate)
    
    def forward(self, x):
        """
        Inputs:
        x: (B, T, d_model)
        -----------------------
        Ouputs:
        output: (B, T, d_model)
        """
        if self.use_conv:
            x = x.transpose(1, 2)  # (B, T, d_model) --> (B, d_model, T), reshape like (batch, channel, dim)
            output = self.fc(x).transpose(1, 2)  # (B, d_model, T) --> (B, T, d_model)
        else:
            output = self.fc(x)
            
        output = self.drop_out(output)
        return output

In [13]:
d_f = d_model*4
PWFFN = PositionWiseFFN(d_model, d_f, use_conv=False)
PWFFN(q).size()

torch.Size([6, 1, 512])

In [14]:
PWFFN = PositionWiseFFN(d_model, d_f, use_conv=True)
PWFFN(q).size()

torch.Size([6, 1, 512])

## 3.5 Positional Encoding

$$\begin{aligned} PE_{(pos, 2i)} &= sin(pos/10000^{2i / d_{model}}) \\
PE_{(pos, 2i+1)} &= cos(pos/10000^{2i / d_{model}}) \\
\end{aligned}$$

In [15]:
class PositionalEncoding(nn.Module):
    """Positional Encoding"""
    def __init__(self, n_pos, d_model, pad_idx=None):
        """
        n_pos = max sequence length + 1
        """
        super(PositionalEncoding, self).__init__()
        self.n_pos = n_pos
        self.d_model = d_model
        self.pad_idx = pad_idx
        self.pe_table = np.array(self.get_pe_table())
        self.pe_table[:, 0::2] = np.sin(self.pe_table[:, 0::2])
        self.pe_table[:, 1::2] = np.cos(self.pe_table[:, 1::2])
        if pad_idx is not None:
            # zero vector for padding dimension
            self.pe_table[pad_idx] = 0.
            
        self.pe = nn.Embedding.from_pretrained(torch.FloatTensor(self.pe_table), freeze=True)
        
    def cal_angle(self, pos, hid_idx):
        return pos / (10000 ** ((2*(hid_idx // 2) / self.d_model)) )
    
    def get_pe_table(self):
        return [[self.cal_angle(pos, i) for i in range(self.d_model)] for pos in range(self.n_pos)]         
        
    def forward(self, inputs):
        return self.pe(inputs)

In [16]:
n_pos = 3
vocab_len = 10
pos_layer = PositionalEncoding(n_pos+1, d_model, pad_idx=0)  # n_pos + 1 for pad idx
embed_layer = nn.Embedding(vocab_len, d_model)

In [17]:
x = torch.LongTensor(np.array([[6, 5, 3], [1, 3, 0], [5, 3, 6], [3, 6, 2]]))
po_x = torch.LongTensor(np.array([[1, 2, 3], [1, 2, 0], [1, 2, 3], [1, 2, 3]]))
embed_layer(x).size(), pos_layer(po_x).size()

(torch.Size([4, 3, 512]), torch.Size([4, 3, 512]))

In [18]:
inputs = embed_layer(x) + pos_layer(po_x)
inputs.size()

torch.Size([4, 3, 512])

## Layers

In [19]:
class Encode_Layer(nn.Module):
    """encode layer"""
    def __init__(self, n_head, d_model, d_k, d_v, d_f, drop_rate=0.1, use_conv=False):
        super(Encode_Layer, self).__init__()
        self.selfattn = MultiHeadAttention(n_head, d_model, d_k, d_v, drop_rate=drop_rate)
        self.pwffn = PositionWiseFFN(d_model, d_f, drop_rate=drop_rate, use_conv=use_conv)
        self.norm_selfattn = nn.LayerNorm(d_model)
        self.norm_pwffn = nn.LayerNorm(d_model)
        
    def forward(self, enc_input, enc_mask=None):
        """
        Inputs:
        * enc_input: (B, T, d_model)
        * enc_mask: (B, T, T)
        -------------------------------------
        Outputs:
        * enc_output: (B, T, d_model)
        * enc_attn: (n_head * B, T, T)
        """
        # Layer: Multi-Head Attention + Add & Norm
        # encode self-attention
        enc_output, enc_attn = self.selfattn(enc_input, enc_input, enc_input, mask=enc_mask)
        enc_output = self.norm_selfattn(enc_input + enc_output)
        
        # Layer: PositionWiseFFN + Add & Norm
        pw_output = self.pwffn(enc_output)
        enc_output = self.norm_pwffn(enc_output + pw_output)
        
        return enc_output, enc_attn


class Decode_Layer(nn.Module):
    """decode layer"""
    def __init__(self, n_head, d_model, d_k, d_v, d_f, drop_rate=0.1, use_conv=False):
        super(Decode_Layer, self).__init__()
        self.selfattn_masked = MultiHeadAttention(n_head, d_model, d_k, d_v, drop_rate=drop_rate)
        self.dec_enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, drop_rate=drop_rate)
        self.pwffn = PositionWiseFFN(d_model, d_f, drop_rate=drop_rate, use_conv=use_conv)
        self.norm_selfattn_masked = nn.LayerNorm(d_model)
        self.norm_dec_enc_attn = nn.LayerNorm(d_model)
        self.norm_pwffn = nn.LayerNorm(d_model)
    
    def forward(self, dec_input, enc_output, dec_self_mask=None, dec_enc_mask=None):
        """
        Inputs:
        * dec_input: (B, T_q, d_model)
        * enc_input: (B, T, d_model)
        * dec_self_mask: (B, T_q, T_q)
        * dec_enc_mask: (B, T_q, T)
        -------------------------------------
        Outputs:
        * dec_output: (B, T_q, d_model)
        * dec_self_attn: (n_head * B, T_q, T_q)
        * dec_enc_attn: (n_head * B, T_q, T)
        """
        # Layer: Multi-Head Attention + Add & Norm
        # decode self-attention
        dec_self_output, dec_self_attn = self.selfattn_masked(dec_input, dec_input, dec_input, 
                                                              mask=dec_self_mask)
        dec_self_output = self.norm_selfattn_masked(dec_input + dec_self_output)
        
        # Layer: Multi-Head Attention + Add & Norm
        # decode output(queries) + encode output(keys, values)
        dec_output, dec_enc_attn = self.dec_enc_attn(dec_self_output, enc_output, enc_output, 
                                                     mask=dec_enc_mask)
        dec_output = self.norm_dec_enc_attn(dec_self_output + dec_output)
        
        # Layer: PositionWiseFFN + Add & Norm
        pw_output = self.pwffn(dec_output)
        dec_output = self.norm_pwffn(dec_output + pw_output)
        
        return dec_output, dec_self_attn, dec_enc_attn

In [20]:
target_n_pos = 4  # equal to max_seq_len
t = torch.LongTensor(np.array([[1, 0, 0, 0], [5, 7, 9, 2], [3, 7, 0, 0], [2, 9, 4, 0]]))
po_t = torch.LongTensor(np.array([[1, 0, 0, 0], [1, 2, 3, 4], [1, 2, 0, 0], [1, 2, 3, 0]]))
pos_layer = PositionalEncoding(target_n_pos+1, d_model, pad_idx=0)  # n_pos + 1 for pad idx
embed_layer = nn.Embedding(10, d_model)
embed_layer(t).size(), pos_layer(po_t).size()

(torch.Size([4, 4, 512]), torch.Size([4, 4, 512]))

In [21]:
target_inputs = embed_layer(t) + pos_layer(po_t)
target_inputs.size()

torch.Size([4, 4, 512])

In [22]:
enc_layer = Encode_Layer(n_head, d_model, d_k, d_v, d_f)
dec_layer = Decode_Layer(n_head, d_model, d_k, d_v, d_f)
enc_output, enc_attn = enc_layer.forward(inputs)
dec_output, dec_self_attn, dec_enc_attn = dec_layer.forward(target_inputs, enc_output)

In [23]:
enc_output.size(), enc_attn.size()

(torch.Size([4, 3, 512]), torch.Size([32, 3, 3]))

In [24]:
dec_output.size(), dec_self_attn.size(), dec_enc_attn.size()

(torch.Size([4, 4, 512]), torch.Size([32, 4, 4]), torch.Size([32, 4, 3]))

## Models: Encoder & Decoder - Transformer

In [25]:
def get_padding_mask(q, k=None, pad_idx=0, mode='attn'):
    """
    mode: attn
    > mask out for pad in attention with queries & keys sequences
    > return shape: (B, T_q, T_k)
    mode: subseq
    > mask out next tokens to preserve 'auto-regressive property'
    > return shape: (B, T_q, T_q)
    """
    B, q_len = q.size()
    if mode == 'attn':
        assert k is not None, "must have key sequences"
        padding_mask = k.eq(pad_idx)
        padding_mask = padding_mask.unsqueeze(1).expand(B, q_len, -1)
        return padding_mask
    elif mode =='subseq':
        assert k is None, "don't need key sequences"
        subseq_mask = torch.triu(torch.ones((q_len, q_len), device=q.device, dtype=torch.uint8), 
                                 diagonal=1)
        subseq_mask = subseq_mask.unsqueeze(0).expand(B, -1, -1)
        return subseq_mask

In [26]:
class Encoder(nn.Module):
    def __init__(self, vocab_len, max_seq_len, n_layer, n_head, d_model, d_k, d_v, d_f, 
                 pad_idx=0, drop_rate=0.1, use_conv=False, return_attn=True):
        super(Encoder, self).__init__()
        self.pad_idx = pad_idx
        self.return_attn = return_attn
        self.embed_layer = nn.Embedding(vocab_len, d_model, padding_idx=pad_idx)
        self.pos_layer = PositionalEncoding(max_seq_len+1, d_model, pad_idx)
        self.layers = nn.ModuleList([Encode_Layer(n_head, d_model, d_k, d_v, d_f, 
                                                  drop_rate=drop_rate, 
                                                  use_conv=use_conv) \
                                     for i in range(n_layer)])
        
    def forward(self, enc, enc_pos):
        """
        Inputs:
        * enc: (B, T)
        * enc_pos: (B, T)
        -------------------------------------
        Outputs:
        * enc_output: (B, T, d_model)
        * self_attns: (n_layer, n_head*B, T, T)
        """
        self_attns = []  # (n_layer, n_head*B, T, T)
        # self attention padding mask: (B, T, T)
        attn_mask = get_padding_mask(q=enc, k=enc, pad_idx=self.pad_idx, mode='attn')
        
        # embedding + position encoding: (B, T) --> (B, T, d_model)
        enc_output = self.embed_layer(enc) + self.pos_layer(enc_pos)
        
        # forward encode layer
        for enc_layer in self.layers:
            enc_output, enc_self_attn = enc_layer(enc_input=enc_output, enc_mask=attn_mask)
            if self.return_attn:
                self_attns.append(enc_self_attn)
        
        if self.return_attn:
            return enc_output, self_attns
        return enc_output

In [27]:
class Decoder(nn.Module):
    def __init__(self, vocab_len, max_seq_len, n_layer, n_head, d_model, d_k, d_v, d_f, 
                 pad_idx=0, drop_rate=0.1, use_conv=False, return_attn=True):
        super(Decoder, self).__init__()
        self.pad_idx = pad_idx
        self.return_attn = return_attn
        self.embed_layer = nn.Embedding(vocab_len, d_model, padding_idx=pad_idx)
        self.pos_layer = PositionalEncoding(max_seq_len+1, d_model, pad_idx)
        self.layers = nn.ModuleList([Decode_Layer(n_head, d_model, d_k, d_v, d_f, 
                                                  drop_rate=drop_rate, 
                                                  use_conv=use_conv) \
                                     for i in range(n_layer)])
        
    def forward(self, dec, dec_pos, enc, enc_output):
        """
        Inputs:
        * dec: (B, T_q)
        * dec_pos: (B, T_q)
        * enc: (B, T)
        * enc_output: (B, T, d_model)
        -------------------------------------
        Outputs:
        * dec_output: (B, T_q, d_model)
        * self_attns: (n_layer, n_head*B, T_q, T_q)
        * dec_enc_attns: (n_layer, n_haed*B, T_q, T)
        """
        self_attns = []  # (n_layer, n_head*B, T_q, T_q)
        dec_enc_attns = []  # (n_layer, n_head*B, T_q, T)
        
        # self attention padding mask: (B, T_q, T)
        attn_mask = get_padding_mask(q=dec, k=dec, pad_idx=self.pad_idx, mode='attn')
        subseq_mask = get_padding_mask(q=dec, mode='subseq')
        self_attn_mask = (attn_mask + subseq_mask).gt(0)
        # enc_dec attention padding mask
        dec_enc_attn_mask = get_padding_mask(q=dec, k=enc, pad_idx=self.pad_idx, mode='attn')
        
        # embedding + position encoding: (B, T) --> (B, T, d_model)
        dec_output = self.embed_layer(dec) + self.pos_layer(dec_pos)
        
        # forward decode layer
        for dec_layer in self.layers:
            dec_output, dec_self_attn, dec_enc_attn = dec_layer(dec_input=dec_output, 
                                                                enc_output=enc_output, 
                                                                dec_self_mask=self_attn_mask, 
                                                                dec_enc_mask=dec_enc_attn_mask)
            if self.return_attn:
                self_attns.append(dec_self_attn)
                dec_enc_attns.append(dec_enc_attn)
        
        if self.return_attn:
            return dec_output, self_attns, dec_enc_attns
        return dec_output

In [28]:
n_layer = 3
encoder = Encoder(vocab_len, 3, n_layer, n_head, d_model, d_k, d_v, d_f)
decoder = Decoder(vocab_len, 4, n_layer, n_head, d_model, d_k, d_v, d_f)

In [29]:
enc_output, enc_self_attns = encoder(x, po_x)
dec_output, dec_self_attns, dec_enc_attns = decoder.forward(t, po_t, x, enc_output)

In [30]:
enc_output.size(), dec_output.size()

(torch.Size([4, 3, 512]), torch.Size([4, 4, 512]))

In [31]:
class Transformer(nn.Module):
    def __init__(self, enc_vocab_len, enc_max_seq_len, dec_vocab_len, dec_max_seq_len, 
                 n_layer, n_head, d_model, d_k, d_v, d_f, 
                 pad_idx=0, drop_rate=0.1, use_conv=False, return_attn=True,
                 linear_weight_share=True, embed_weight_share=True):
        super(Transformer, self).__init__()
        self.return_attn = return_attn
        self.encoder = Encoder(enc_vocab_len, enc_max_seq_len, n_layer, n_head, 
                               d_model, d_k, d_v, d_f, 
                               pad_idx=pad_idx, 
                               drop_rate=drop_rate, 
                               use_conv=use_conv, 
                               return_attn=return_attn)
        self.decoder = Decoder(dec_vocab_len, dec_max_seq_len, n_layer, n_head, 
                               d_model, d_k, d_v, d_f,
                               pad_idx=pad_idx, 
                               drop_rate=drop_rate, 
                               use_conv=use_conv, 
                               return_attn=return_attn)
        self.projection = XavierLinear(d_model, dec_vocab_len, bias=False)
        if linear_weight_share:
            # share the same weight matrix between the decoder embedding layer 
            # and the pre-softmax linear transformation
            self.projection.linear.weight = self.decoder.embed_layer.weight
        
        if embed_weight_share:
            # share the same weight matrix between the decoder embedding layer 
            # and the encoder embedding layer
            assert enc_vocab_len == dec_vocab_len, "vocab length must be same"
            self.encoder.embed_layer.weight = self.decoder.embed_layer.weight
            
    def forward(self, enc, enc_pos, dec, dec_pos):
        """
        Inputs:
        * enc: (B, T)
        * enc_pos: (B, T)
        * dec: (B, T_q)
        * dec_pos: (B, T_q)
        -------------------------------------
        Outputs:
        * dec_output: (B, T_q, d_model)
        * attns_dict:
            * enc_self_attns: (n_layer, n_head*B, T, T)
            * dec_self_attns: (n_layer, n_head*B, T_q, T_q)
            * dec_enc_attns: (n_layer, n_haed*B, T_q, T)
        """
        enc_output, enc_self_attns = self.encoder(enc, enc_pos)
        dec_output, dec_self_attns, dec_enc_attns = self.decoder(dec, dec_pos, enc, enc_output)
        dec_output = self.projection(dec_output)
        attns_dict = {'enc_self_attns': enc_self_attns, 
                     'dec_self_attns': dec_self_attns,
                     'dec_enc_attns': dec_enc_attns}
        if self.return_attn:
            return dec_output, attns_dict
        return dec_output

In [32]:
model = Transformer(vocab_len, 3, vocab_len, 4, n_layer, n_head, d_model, d_k, d_v, d_f, 
                 pad_idx=0, drop_rate=0.1, use_conv=False, return_attn=True,
                 linear_weight_share=True, embed_weight_share=True)

In [33]:
dec_output, attns_dict = model(x, po_x, t, po_t)

In [34]:
dec_output.size()

torch.Size([4, 4, 10])