In [103]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [104]:
# Strictly upper-triangular boolean matrix: True above the diagonal, i.e. the
# complement of the lower-triangular "keep" mask used for causal attention.
~torch.ones(5, 5, dtype=torch.bool).tril()

tensor([[False,  True,  True,  True,  True],
        [False, False,  True,  True,  True],
        [False, False, False,  True,  True],
        [False, False, False, False,  True],
        [False, False, False, False, False]])

In [110]:
def scaled_dot_product_self_attention(
    Q,  # query (...BATCH, TARGET_SEQ_DIM, QK_DIM)
    K,  # key   (...BATCH, SRC_SEQ_DIM, QK_DIM)
    V,  # value (...BATCH, SRC_SEQ_DIM, V_DIM)
    attn_mask=None,  # BOOL, broadcastable to (...BATCH, TARGET_SEQ_DIM, SRC_SEQ_DIM)
    is_causal=False,  # causal attention masking
):  # -> ((...BATCH, TARGET_SEQ_DIM, V_DIM), (...BATCH, TARGET_SEQ_DIM, SRC_SEQ_DIM))
    """Compute softmax(Q K^T / sqrt(QK_DIM)) V and return it with the weights.

    Args:
        Q: query tensor.
        K: key tensor.
        V: value tensor.
        attn_mask: optional boolean mask; ``True`` marks positions that may be
            attended to, ``False`` positions that are masked out. It is never
            modified by this function.
        is_causal: when truthy, position ``i`` of the target sequence may only
            attend to source positions ``j <= i`` (combined with ``attn_mask``
            if one is given).

    Returns:
        A tuple ``(result, attn_weight)``: the attended values and the
        attention weights (softmax over the source dimension).

    Note:
        A row whose positions are all masked out yields NaN weights, because
        the softmax is taken over logits that are all ``-inf``.
    """
    target_seq_dim, qk_dim = Q.shape[-2:]
    src_seq_dim = K.size(-2)

    attn_logits = Q @ K.transpose(-2, -1) / (qk_dim ** 0.5)

    if is_causal:
        # Build the mask on the inputs' device so non-CPU tensors work too.
        causal_mask = torch.ones(
            (target_seq_dim, src_seq_dim), dtype=torch.bool, device=Q.device
        ).tril()
        # Out-of-place AND: an in-place `&=` would overwrite the caller's mask.
        attn_mask = causal_mask if attn_mask is None else attn_mask & causal_mask

    if attn_mask is not None:
        attn_logits = attn_logits.masked_fill(torch.logical_not(attn_mask), -float('inf'))

    attn_weight = torch.softmax(attn_logits, dim=-1)
    result = attn_weight @ V

    return result, attn_weight

In [111]:
# Example inputs: batch of 2, 4 target positions, 5 source positions.
Q = torch.rand((2, 4, 6))  # (BATCH, TARGET_SEQ_DIM, QK_DIM)
K = torch.rand((2, 5, 6))  # (BATCH, SRC_SEQ_DIM, QK_DIM)
V = torch.rand((2, 5, 7))  # (BATCH, SRC_SEQ_DIM, V_DIM)
# Random boolean mask: each entry is True when its uniform sample is below 0.9.
attn_mask = torch.rand((2, 4, 5)).lt(0.9)

scaled_dot_product_self_attention(
    Q=Q,
    K=K,
    V=V,
    attn_mask=attn_mask,
    is_causal=True,
)

(tensor([[[0.6975, 0.4676, 0.4804, 0.7067, 0.6771, 0.9294, 0.8087],
          [0.5592, 0.5975, 0.4835, 0.7846, 0.6701, 0.7906, 0.5447],
          [0.5577, 0.5248, 0.3294, 0.6061, 0.7338, 0.6542, 0.5945],
          [0.4969, 0.5615, 0.5238, 0.5219, 0.7216, 0.6420, 0.4770]],
 
         [[0.1245, 0.5985, 0.9666, 0.4224, 0.0725, 0.8845, 0.6975],
          [0.1245, 0.5985, 0.9666, 0.4224, 0.0725, 0.8845, 0.6975],
          [0.2116, 0.5361, 0.6706, 0.5960, 0.3072, 0.8712, 0.4641],
          [0.3834, 0.6951, 0.4935, 0.5209, 0.4729, 0.7073, 0.4580]]]),
 tensor([[[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5741, 0.4259, 0.0000, 0.0000, 0.0000],
          [0.3090, 0.3100, 0.3810, 0.0000, 0.0000],
          [0.2505, 0.1797, 0.2708, 0.2990, 0.0000]],
 
         [[1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [1.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.5421, 0.0000, 0.4579, 0.0000, 0.0000],
          [0.2533, 0.2228, 0.2694, 0.2545, 0.0000]]]))

In [118]:
class MultiheadAttention(nn.Module):
    """Multi-head attention built on ``scaled_dot_product_self_attention``.

    Queries, keys and values are linearly projected, split into ``num_heads``
    heads, attended per head, joined again and passed through an output
    projection.

    Mask convention (all masks are BOOL): ``True`` marks a position that may be
    attended to, ``False`` a position that is masked out.

    Args:
        embed_dim: feature size of the queries and of the output.
        num_heads: number of heads; must divide ``embed_dim``.
        bias: whether the four linear projections use a bias term.
        kfeatdim: input feature size of the keys (defaults to ``embed_dim``).
        vfeatdim: input feature size of the values (defaults to ``embed_dim``).
        vdim: per-head size of the projected values (defaults to
            ``embed_dim // num_heads``).
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        bias=True,
        kfeatdim=None,
        vfeatdim=None,
        vdim=None,
    ):
        assert embed_dim % num_heads == 0, "Error: the embedding dimension should be divisible by the number of heads"

        super().__init__()

        self.embed_dim = embed_dim
        self.kfeatdim = embed_dim if kfeatdim is None else kfeatdim
        self.vfeatdim = embed_dim if vfeatdim is None else vfeatdim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.vdim = self.head_dim if vdim is None else vdim

        self.WQ = nn.Linear(embed_dim, embed_dim, bias=bias)  # (num_heads, embed_dim, head_dim) ~ (embed_dim, embed_dim)
        self.WK = nn.Linear(self.kfeatdim, embed_dim, bias=bias)  # (num_heads, kfeatdim, head_dim) ~ (kfeatdim, embed_dim)
        self.WV = nn.Linear(self.vfeatdim, num_heads * self.vdim, bias=bias)  # (num_heads, vfeatdim, vdim) ~ (vfeatdim, num_heads * vdim)
        self.WO = nn.Linear(num_heads * self.vdim, embed_dim, bias=bias)  # (num_heads * vdim, embed_dim)

    def forward(
        self,
        Q,  # (...BATCH, TARGET_SEQ_DIM, embed_dim)
        K,  # (...BATCH, SRC_SEQ_DIM, kfeatdim)
        V,  # (...BATCH, SRC_SEQ_DIM, vfeatdim)
        key_padding_mask=None,  # (...BATCH, SRC_SEQ_DIM)
        attn_mask=None,  # (TARGET_SEQ_DIM, SRC_SEQ_DIM) or (...BATCH, num_heads, TARGET_SEQ_DIM, SRC_SEQ_DIM)
        is_causal=False
    ):  # -> (...BATCH, TARGET_SEQ_DIM, embed_dim), (...BATCH, num_heads, TARGET_SEQ_DIM, SRC_SEQ_DIM)
        """Apply multi-head attention.

        Returns a tuple ``(output, attn_weight)``: the projected attention
        output and the per-head attention weights. Neither ``key_padding_mask``
        nor ``attn_mask`` is modified.
        """
        *batch_dim, _, _ = Q.shape
        src_seq_dim = K.size(-2)

        Q_proj = self.__separate_heads(self.WQ(Q))  # (...BATCH, num_heads, TARGET_SEQ_DIM, head_dim)
        K_proj = self.__separate_heads(self.WK(K))  # (...BATCH, num_heads, SRC_SEQ_DIM, head_dim)
        V_proj = self.__separate_heads(self.WV(V))  # (...BATCH, num_heads, SRC_SEQ_DIM, vdim)

        # Padding to mask (...BATCH, SRC_SEQ_DIM) -> (...BATCH, 1, 1, SRC_SEQ_DIM)
        # so it broadcasts over heads and target positions.
        # reshape (not view) also accepts non-contiguous masks.
        if key_padding_mask is not None:
            key_padding_mask = key_padding_mask.reshape(*batch_dim, 1, 1, src_seq_dim)

        attn_mask = self.__merge_masks(key_padding_mask, attn_mask)

        heads, attn_weight = scaled_dot_product_self_attention(
            Q_proj, K_proj, V_proj, attn_mask, is_causal
        )  # (...BATCH, num_heads, TARGET_SEQ_DIM, vdim)

        return self.WO(self.__join_heads(heads)), attn_weight

    def __separate_heads(self, mat):
        # (...BATCH, SEQ, num_heads * proj_dim) -> (...BATCH, num_heads, SEQ, proj_dim)
        *batch_dim, seq_dim, proj_dim = mat.shape
        return mat.view(*batch_dim, seq_dim, self.num_heads, proj_dim // self.num_heads).transpose(-2, -3)

    def __join_heads(self, mat):
        # (...BATCH, num_heads, SEQ, proj_dim) -> (...BATCH, SEQ, num_heads * proj_dim)
        *batch_dim, num_heads, seq_dim, proj_dim = mat.shape
        return mat.transpose(-2, -3).contiguous().view(*batch_dim, seq_dim, num_heads * proj_dim)

    def __merge_masks(self, *masks):
        # Logical AND of every mask that is not None; None when there is none.
        masks = [mask for mask in masks if mask is not None]
        if len(masks) == 0:
            return None
        final_mask = masks[0]
        for mask in masks[1:]:
            # Out-of-place AND: the in-place form cannot broadcast a
            # (..., 1, 1, SRC) padding mask against a (TARGET, SRC) mask and
            # would otherwise write into the caller's tensor.
            final_mask = final_mask & mask
        return final_mask

In [121]:
# Shape check: two leading batch dimensions, 4 target and 3 source positions.
ma = MultiheadAttention(16, 4)
Q = torch.rand(2, 2, 4, 16)  # (...BATCH, TARGET_SEQ_DIM, embed_dim)
K = torch.rand(2, 2, 3, 16)  # (...BATCH, SRC_SEQ_DIM, kfeatdim)
V = torch.rand(2, 2, 3, 16)  # (...BATCH, SRC_SEQ_DIM, vfeatdim)
# Call the module itself instead of .forward() so registered hooks are run.
ma(Q, K, V)[0].shape

torch.Size([2, 2, 4, 16])

In [132]:
class TransformerEncoderLayer(nn.Module):
    """Transformer encoder layer with layer normalisation after each residual.

    Data flow: self-attention -> dropout -> residual add + LayerNorm ->
    feed-forward (Linear, activation, Linear) -> dropout -> residual add +
    LayerNorm.

    Args:
        embed_dim: feature size of the input and output.
        num_heads: number of attention heads.
        p_dropout: dropout probability for both dropout layers. When omitted,
            ``dropout`` is used; when both are omitted the probability is 0.1.
        dim_feedforward: hidden size of the feed-forward block.
        dropout: alternative name for ``p_dropout``; only consulted when
            ``p_dropout`` is not given.
        activation: zero-argument callable returning the activation module.
        layer_norm_eps: epsilon passed to both LayerNorm layers.
    """

    def __init__(
        self,
        embed_dim,
        num_heads=4,
        p_dropout=None,
        dim_feedforward=2048,
        dropout=None,
        activation=nn.ReLU,
        layer_norm_eps=1e-05,
    ):
        super().__init__()

        # `dropout` used to be accepted and silently ignored. An explicit
        # `p_dropout` still wins, so existing callers keep their behaviour.
        if p_dropout is None:
            p_dropout = 0.1 if dropout is None else dropout

        self.mh_attention = MultiheadAttention(embed_dim, num_heads)
        self.dropout_1 = nn.Dropout(p_dropout)
        self.layer_norm_1 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

        self.ff_l1 = nn.Linear(embed_dim, dim_feedforward)
        self.activation = activation()
        self.ff_l2 = nn.Linear(dim_feedforward, embed_dim)
        self.dropout_2 = nn.Dropout(p_dropout)
        self.layer_norm_2 = nn.LayerNorm(embed_dim, eps=layer_norm_eps)

    def forward(
        self,
        src,  # (...BATCH, SRC_SEQ_DIM, embed_dim)
        src_mask=None,  # (TARGET_SEQ_DIM, SRC_SEQ_DIM) or (...BATCH, num_heads, TARGET_SEQ_DIM, SRC_SEQ_DIM)
        src_key_padding_mask=None,  # (...BATCH, SRC_SEQ_DIM)
        is_causal=False
    ):  # -> (...BATCH, SRC_SEQ_DIM, embed_dim)
        """Encode ``src``; the attention weights are computed but discarded."""
        # Self-attention: src is used as query, key and value.
        attended, _ = self.mh_attention(
            src,
            src,
            src,
            key_padding_mask=src_key_padding_mask,
            attn_mask=src_mask,
            is_causal=is_causal,
        )
        x = self.layer_norm_1(src + self.dropout_1(attended))

        y = self.ff_l2(self.activation(self.ff_l1(x)))
        y = self.dropout_2(y)
        return self.layer_norm_2(x + y)

In [138]:
# Run one encoder layer on a random input with two leading batch dimensions.
el = TransformerEncoderLayer(16)
src = torch.rand(2, 2, 4, 16)  # (...BATCH, SRC_SEQ_DIM, embed_dim)
# Call the module itself instead of .forward() so registered hooks are run.
el(src)

tensor([[[[ 0.5431, -1.0846, -0.9011,  0.8957,  0.5347, -0.6398,  0.1175,
            1.9653, -1.0611,  0.5459, -2.0089,  0.7471,  0.4047,  0.8768,
            0.2236, -1.1588],
          [-0.8203, -0.1239, -0.7349, -0.8493,  1.3454,  1.2977,  0.4614,
            1.2601, -1.6294,  0.1631, -0.4998, -0.1653,  1.9158, -0.2224,
            0.0425, -1.4405],
          [-0.3493,  0.0955,  0.6185, -1.9106,  1.4605,  0.5502,  0.8610,
           -0.6984, -0.8787,  0.7809, -2.2249,  0.6605,  0.7090, -0.4498,
            0.8332, -0.0577],
          [-0.9797, -0.5216, -1.3075, -0.8019,  1.5519,  0.7661,  0.6680,
            0.8565, -0.7424, -0.0350, -1.9466,  0.1383,  1.8502,  0.5921,
           -0.2695,  0.1811]],

         [[ 0.1245, -1.1639, -0.7406, -1.4156,  1.1140, -0.4898,  0.8720,
            1.8175, -0.2387,  0.5932, -1.9118,  0.5538,  0.9475,  0.8126,
           -0.1045, -0.7701],
          [ 1.3201, -0.6396,  0.4988, -0.0741,  1.0548, -0.0987, -0.4132,
            0.5036, -1.3795, -0.41