In [17]:
import torch
import torch.nn as nn
import math

In [2]:
# PositionalEncoding is a fixed (non-learned) matrix; the forward pass only adds it to the input embeddings
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding (no learnable parameters).

    PE(pos, 2i)   = sin(pos / 10000 ** (2i / d_model))
    PE(pos, 2i+1) = cos(pos / 10000 ** (2i / d_model))

    The table is registered as a non-persistent buffer so that ``.to()`` / ``.cuda()``
    move it together with the rest of the model, while ``state_dict`` keys stay unchanged.

    Args:
        max_len: largest sequence length the table covers.
        d_model: embedding width (may be odd).
    """

    def __init__(self, max_len, d_model):
        super().__init__()
        self.max_len = max_len
        self.d_model = d_model

        encoding = torch.zeros(max_len, d_model)                          # (max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)    # (max_len, 1)
        _2i = torch.arange(0, d_model, 2, dtype=torch.float)              # (ceil(d_model/2),)

        # pos (max_len, 1) broadcasts against the per-frequency divisor (ceil(d_model/2),)
        angles = pos / (10000 ** (_2i / d_model))                         # (max_len, ceil(d_model/2))
        encoding[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Add positional encodings to ``x``.

        Args:
            x: tensor of shape (batch_size, seq_len, d_model) with seq_len <= max_len.
               A longer sequence fails to broadcast (RuntimeError).
        Returns:
            Tensor of the same shape as ``x``.
        """
        seq_len = x.size(1)
        return x + self.encoding[:seq_len]

In [3]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus fixed positional encoding, followed by dropout.

    Args:
        vocab_size: number of token ids accepted by the embedding table.
        d_model: embedding width.
        max_len: length of the positional-encoding table.
        drop_prob: dropout probability applied to the summed embeddings.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(max_len, d_model)
        self.dropout = nn.Dropout(p=drop_prob)

    def forward(self, x):
        """Map token ids (batch_size, seq_len) to embeddings (batch_size, seq_len, d_model)."""
        embedded = self.tok_emb(x)
        with_positions = self.pos_emb(embedded)
        return self.dropout(with_positions)

In [4]:
# scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    """Attention(Q, K, V) = softmax(Q @ K^T / sqrt(dk)) @ V.

    Works on any number of leading (batch / head) dimensions; the last two dimensions
    are (seq_len, features).
    """

    # Added to masked-out scores before softmax; large enough to give ~0 weight while
    # staying finite (a fully masked row yields uniform weights instead of NaN).
    MASK_FILL_VALUE = -10000

    def __init__(self):
        super().__init__()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        """Compute scaled dot-product attention.

        Args:
            q: (..., q_len, dk)
            k: (..., k_len, dk)
            v: (..., k_len, dv)
            mask: optional tensor broadcastable to (..., q_len, k_len); positions where
                ``mask == 0`` are suppressed.
        Returns:
            (outputs, weights) with shapes (..., q_len, dv) and (..., q_len, k_len).
        """
        dk = k.size(-1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dk)   # (..., q_len, k_len)

        if mask is not None:
            scores = scores.masked_fill(mask == 0, self.MASK_FILL_VALUE)

        scores = self.softmax(scores)
        outputs = torch.matmul(scores, v)   # (..., q_len, dv)
        return outputs, scores

In [5]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project q/k/v, attend per head, concatenate heads, project back.

    Queries may have a different sequence length from keys/values (as in the decoder's
    encoder-decoder attention); the output always follows the query length.

    Args:
        d_model: model width; must be divisible by ``n_head``.
        n_head: number of attention heads.
    """

    def __init__(self, d_model, n_head):
        super().__init__()
        assert d_model % n_head == 0, "d_model is not divisible by n_head"

        self.n_head = n_head
        self.attention = ScaledDotProductAttention()

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def _split_heads(self, x):
        """(batch_size, seq_len, d_model) -> (batch_size, n_head, seq_len, d_model // n_head)."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.n_head, -1).transpose(1, 2)

    def forward(self, q, k, v, mask=None):
        """Attend from ``q`` (batch, q_len, d_model) over ``k``/``v`` (batch, k_len, d_model).

        ``mask`` is passed through to the scaled dot-product attention and must broadcast
        to (batch, n_head, q_len, k_len). Returns a tensor of shape (batch, q_len, d_model).
        """
        q, k, v = self.w_q(q), self.w_k(k), self.w_v(v)   # widths stay d_model

        # Use the *query* length for the output: k/v may be a different length.
        batch_size, q_len, d_model = q.shape
        q, k, v = self._split_heads(q), self._split_heads(k), self._split_heads(v)

        outputs, _ = self.attention(q, k, v, mask=mask)   # (batch, n_head, q_len, d_head)

        # concat heads back to (batch, q_len, d_model)
        outputs = outputs.transpose(1, 2).contiguous().view(batch_size, q_len, d_model)
        return self.w_o(outputs)

In [6]:
# layer norm: normalize each position over its last (feature) dimension
class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Uses the biased (population) variance, as in the original Transformer paper.

    Args:
        d_model: size of the normalized (last) dimension.
        eps: added to the variance before the square root for numerical stability.
    """

    def __init__(self, d_model, eps=1e-12):
        super().__init__()
        self.gama = nn.Parameter(torch.ones(d_model))   # scale; attribute name kept as-is for state_dict keys
        self.beta = nn.Parameter(torch.zeros(d_model))  # shift
        self.eps = eps

    def forward(self, x):
        """Normalize ``x`` along its last dimension; output has the same shape as ``x``."""
        centered = x - x.mean(dim=-1, keepdim=True)
        variance = x.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered / (variance + self.eps).sqrt()
        return normalized * self.gama + self.beta

In [7]:
# position wise feed forward
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every position: Linear -> ReLU -> Linear -> Dropout.

    Args:
        d_model: input and output width.
        n_hidden: width of the hidden layer.
        dropout: dropout probability applied to the output.
    """

    def __init__(self, d_model, n_hidden, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(d_model, n_hidden)
        self.relu = nn.ReLU()
        self.linear_2 = nn.Linear(n_hidden, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Map (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_model)."""
        hidden = self.relu(self.linear_1(x))       # (batch_size, seq_len, n_hidden)
        projected = self.linear_2(hidden)          # (batch_size, seq_len, d_model)
        return self.dropout(projected)

In [8]:
# encoder layer
class EncoderLayer(nn.Module):
    """One encoder layer: self-attention then feed-forward.

    Each sublayer is followed by dropout, a residual connection, and LayerNorm
    (normalization applied after the residual add). Shape is preserved:
    (batch_size, seq_len, d_model) in and out.
    """

    def __init__(self, d_model, n_hidden, n_head, drop_prob):
        super().__init__()
        # sublayer 1: multi-head self-attention
        self.attention = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # sublayer 2: position-wise feed-forward
        self.ffn = PositionWiseFeedForward(d_model, n_hidden, drop_prob)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

    def forward(self, x, src_mask):
        """Apply both sublayers; ``src_mask`` is forwarded to the self-attention."""
        attended = self.dropout1(self.attention(x, x, x, mask=src_mask))
        x = self.norm1(attended + x)

        transformed = self.dropout2(self.ffn(x))
        return self.norm2(transformed + x)

In [9]:
class Encoder(nn.Module):
    """Embedding followed by a stack of ``n_layers`` encoder layers.

    Input: token ids (batch_size, seq_len). Output: (batch_size, seq_len, d_model).
    """

    def __init__(self, enc_vocab_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):
        super().__init__()
        self.emb = TransformerEmbedding(enc_vocab_size, d_model, max_len, drop_prob)
        self.layers = nn.ModuleList(
            EncoderLayer(d_model=d_model, n_hidden=ffn_hidden, n_head=n_head, drop_prob=drop_prob)
            for _ in range(n_layers)
        )

    def forward(self, x, src_mask):
        """Embed ``x`` and run it through every layer with the same ``src_mask``."""
        hidden = self.emb(x)
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden, src_mask)
        return hidden

In [10]:
class DecoderLayer(nn.Module):
    """One decoder layer with three sublayers.

    The sublayers are masked self-attention, encoder-decoder attention, and feed-forward.
    Each is followed by dropout, a residual connection, and LayerNorm. The
    encoder-decoder sublayer is skipped when ``enc`` is None.
    """

    def __init__(self, d_model, n_hidden, n_head, drop_prob):
        super().__init__()
        # sublayer 1: self-attention over the decoder input
        self.attention1 = MultiHeadAttention(d_model, n_head)
        self.norm1 = LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=drop_prob)

        # sublayer 2: encoder-decoder attention
        self.attention2 = MultiHeadAttention(d_model, n_head)
        self.norm2 = LayerNorm(d_model)
        self.dropout2 = nn.Dropout(p=drop_prob)

        # sublayer 3: position-wise feed-forward
        self.ffn = PositionWiseFeedForward(d_model, n_hidden, drop_prob)
        self.norm3 = LayerNorm(d_model)
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, dec, enc, trg_mask, src_mask):
        """Run the layer.

        Args:
            dec: decoder input (batch_size, trg_len, d_model).
            enc: encoder output used as keys/values, or None to skip cross-attention.
            trg_mask: mask for self-attention (hides padding and future positions).
            src_mask: mask over encoder positions for cross-attention.
        Returns:
            Tensor shaped like ``dec``.
        """
        self_attended = self.dropout1(self.attention1(dec, dec, dec, mask=trg_mask))
        x = self.norm1(self_attended + dec)

        if enc is not None:
            cross_attended = self.dropout2(self.attention2(q=x, k=enc, v=enc, mask=src_mask))
            x = self.norm2(cross_attended + x)

        fed_forward = self.dropout3(self.ffn(x))
        return self.norm3(fed_forward + x)

In [11]:
class Decoder(nn.Module):
    """Target embedding, a stack of decoder layers, then a projection to the target vocabulary.

    The output holds unnormalized scores of shape (batch_size, trg_len, dec_vocab_size);
    no softmax is applied here.
    """

    def __init__(self, dec_vocab_size, max_len, d_model, ffn_hidden, n_head, n_layers, drop_prob):
        super().__init__()
        self.emb = TransformerEmbedding(vocab_size=dec_vocab_size, d_model=d_model,
                                        max_len=max_len, drop_prob=drop_prob)
        self.layers = nn.ModuleList(
            DecoderLayer(d_model, ffn_hidden, n_head, drop_prob) for _ in range(n_layers)
        )
        self.linear = nn.Linear(d_model, dec_vocab_size)

    def forward(self, trg, src, trg_mask, src_mask):
        """``trg``: target token ids; ``src``: encoder output passed to every layer."""
        hidden = self.emb(trg)
        for decoder_layer in self.layers:
            hidden = decoder_layer(hidden, src, trg_mask, src_mask)
        return self.linear(hidden)

In [18]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer.

    Args:
        src_pad_idx: padding token id in the source sequence (masked out as keys).
        trg_pad_idx: padding token id in the target sequence.
        trg_bos_idx: stored for callers; not used inside this class.
        enc_voc_size: source vocabulary size.
        dev_voc_size: target (decoder) vocabulary size.
        d_model, n_head, max_len, ffn_hidden, n_layers, drop_prob: shared encoder/decoder
            hyper-parameters.
    """

    def __init__(self, src_pad_idx, trg_pad_idx, trg_bos_idx, enc_voc_size, dev_voc_size, d_model, n_head, max_len,
                 ffn_hidden, n_layers, drop_prob):
        super().__init__()
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.trg_bos_idx = trg_bos_idx
        self.encoder = Encoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               enc_vocab_size=enc_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers)
        self.decoder = Decoder(d_model=d_model,
                               n_head=n_head,
                               max_len=max_len,
                               ffn_hidden=ffn_hidden,
                               dec_vocab_size=dev_voc_size,
                               drop_prob=drop_prob,
                               n_layers=n_layers)

    def forward(self, src, trg):
        """Encode ``src`` (batch, src_len) and decode ``trg`` (batch, trg_len).

        Returns the decoder output (scores over the target vocabulary).
        """
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)   # the target mask must be built from the target
        enc_src = self.encoder(src, src_mask)
        return self.decoder(trg, enc_src, trg_mask, src_mask)

    def make_src_mask(self, src):
        """(batch, src_len) -> (batch, 1, 1, src_len); 1 for real tokens, 0 for padding."""
        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2).to(torch.long)
        return src_mask

    def make_trg_mask(self, trg):
        """(batch, trg_len) -> bool mask (batch, 1, trg_len, trg_len).

        Position (i, j) is True when query i is not padding and j <= i (causal).
        The mask is built on ``trg``'s device, so CPU and GPU both work.
        """
        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(3)   # (batch, 1, trg_len, 1)
        trg_len = trg.shape[1]
        trg_sub_mask = torch.ones(trg_len, trg_len, device=trg.device).tril().bool()   # (trg_len, trg_len)
        trg_mask = trg_pad_mask & trg_sub_mask
        return trg_mask

In [39]:
# Internal mechanism of src_mask: key positions holding token id 0 (<pad>) are filled
# with -10000 before softmax, so they receive (numerically) zero attention weight.
a = torch.tensor([[1, 2, 3, 4, 0, 0, 0],    # (batch_size, max_len) = (2, 7)
                  [1, 9, 2, 0, 0, 0, 0]])
src_mask = (a != 0).unsqueeze(1).unsqueeze(2).to(torch.long)   # (batch_size, 1, 1, max_len) = (2, 1, 1, 7)
scores = torch.randn([2, 8, 7, 7])                              # (batch_size, n_head, max_len, max_len)
_scores = scores.masked_fill(src_mask == 0, -10000)

head0_raw = scores[0, 0]
print(head0_raw.shape)
print(head0_raw)
print(_scores[0, 0])
attn_weights = torch.softmax(_scores, dim=-1)
v = attn_weights @ torch.randn([2, 8, 7, 7])
print(v[0, 0])
print(attn_weights[0, 0])

torch.Size([7, 7])
tensor([[-0.9139,  0.4950, -0.5827,  1.7843,  0.0912, -0.0686,  0.9386],
        [-0.1362, -0.8931,  0.9399, -0.5369, -1.2713,  0.3504,  1.1994],
        [-0.4657,  1.3153, -2.0890,  0.1006, -0.1378,  1.2424, -0.8644],
        [ 0.5345,  0.4295, -1.4305, -0.6551,  0.6428,  0.4277,  1.8065],
        [ 1.5145, -0.0970, -0.3496, -1.2437,  0.1786, -0.1402,  0.3212],
        [-0.1933, -0.1760,  0.0326, -1.6319,  1.1695, -0.0429, -1.0467],
        [ 0.5202, -0.9314, -0.7990,  1.8442, -0.7157,  1.7078,  1.0778]])
tensor([[-9.1393e-01,  4.9496e-01, -5.8266e-01,  1.7843e+00, -1.0000e+04,
         -1.0000e+04, -1.0000e+04],
        [-1.3623e-01, -8.9307e-01,  9.3992e-01, -5.3694e-01, -1.0000e+04,
         -1.0000e+04, -1.0000e+04],
        [-4.6569e-01,  1.3153e+00, -2.0890e+00,  1.0064e-01, -1.0000e+04,
         -1.0000e+04, -1.0000e+04],
        [ 5.3445e-01,  4.2954e-01, -1.4305e+00, -6.5511e-01, -1.0000e+04,
         -1.0000e+04, -1.0000e+04],
        [ 1.5145e+00, -9.7021

In [44]:
# Multiply the key-padding row by its transpose: 1 only where both query and key are real tokens.
_mask = src_mask[0, 0]
_mask * _mask.T

tensor([[1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])

In [13]:
# Compare the broadcast-only key mask (every query row identical) with the mask obtained
# by multiplying src_mask with its own transpose (padded queries are zeroed as well).
my_mask = src_mask * src_mask.transpose(2, 3)
print(src_mask[0, 0].repeat(7, 1))
print('-' * 40)
print(my_mask[0, 0])

tensor([[1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0]])
----------------------------------------
tensor([[1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0]])


In [26]:
# Target (decoder self-attention) mask = padding mask AND causal (lower-triangular) mask.
trg = torch.tensor([1, 3, 8, 2, 0, 0, 0]).view(1, -1)   # (1, 7)    1 -> sos ; 2 -> eos; 0 -> pad  (batch_size, max_len)
max_len = trg.shape[1]
trg_pad_mask = (trg != 0).unsqueeze(1).unsqueeze(3)     # (batch_size, 1, max_len, 1): zeroes whole rows of padded queries
trg_sub_mask = torch.tril(torch.ones(max_len, max_len)).to(torch.long)  # (max_len, max_len): query i may see keys 0..i
trg_mask = trg_pad_mask & trg_sub_mask  # broadcasts to (batch_size, 1, max_len, max_len)
print(trg_mask)

scores = torch.randn([1, 8, max_len, max_len])          # (batch_size, n_head, max_len, max_len)
_scores = scores.masked_fill(trg_mask == 0, -10000)

print(scores[0, 0, :, :])
# Rows of padded queries are fully masked, so their softmax is uniform over all positions.
print(torch.softmax(_scores[0, 0, :, :], dim=-1))

tensor([[[[1, 0, 0, 0, 0, 0, 0],
          [1, 1, 0, 0, 0, 0, 0],
          [1, 1, 1, 0, 0, 0, 0],
          [1, 1, 1, 1, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0],
          [0, 0, 0, 0, 0, 0, 0]]]])
tensor([[-1.0952, -0.2178, -0.2489, -0.8556, -1.7917,  1.8808,  0.2819],
        [-0.9901, -0.3119, -0.0455,  0.3520, -0.1477,  0.0234, -0.7999],
        [-0.6403, -1.4774,  1.0258, -1.2600, -1.3721, -1.6903, -0.1964],
        [-0.4174, -0.0520,  1.7719,  0.2723,  1.7473,  0.2218, -0.0814],
        [ 1.6706,  0.8964, -0.5824,  0.1054,  0.2975, -0.8554, -1.0972],
        [-0.3924, -2.1945, -0.3546, -1.6486,  0.4085, -0.0718,  0.4826],
        [-0.2904,  0.5262, -0.8305,  0.0751,  2.2724,  1.4320,  0.1211]])
tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3367, 0.6633, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1487, 0.0644, 0.7869, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0748, 0.1078, 0.6682, 0.1491, 0.0000, 0.0000,

In [46]:
# Causal (lower-triangular) mask on its own, before it is combined with the padding mask.
torch.ones(max_len, max_len).tril()

tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1., 1.]])