### V1: Simple Self-Attention Layer

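This is a simple implementation of a single-head self-attention layer. It takes a sequence of vectors as input and produces a sequence of the same length. Each output vector is a weighted sum of the value vectors (linear projections of the inputs). The weights come from a softmax over the scaled dot products between that position's query and every key.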

The self-attention mechanism allows the model to focus on different parts of the input sequence when making predictions, which can be particularly useful for tasks such as language modeling and translation.

In [3]:
###simple version 1
import math
import torch
import torch.nn as nn
from sklearn.utils.extmath import softmax


class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with separate Q/K/V layers.

    Args:
        hidden_dim: size of the last input dimension and of each projection.
            NOTE(review): the default 728 may be a typo for 768 — confirm.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim

        def make_projection() -> nn.Linear:
            # Square projection: hidden_dim -> hidden_dim.
            return nn.Linear(in_features=hidden_dim, out_features=hidden_dim)

        # Keep the creation order query -> key -> value.
        self.query = make_projection()
        self.key = make_projection()
        self.value = make_projection()

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Attend every position of X to every other position.

        Args:
            X: tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            Tensor of shape (batch, seq_len, hidden_dim). The (batch, seq_len,
            seq_len) attention weights are printed as a side effect.
        """
        queries, keys, values = self.query(X), self.key(X), self.value(X)

        # Dividing by sqrt(hidden_dim) keeps the logits from growing with the
        # dimension, which would saturate the softmax and shrink its gradients.
        scale = math.sqrt(self.hidden_dim)
        scores = (queries @ keys.transpose(-2, -1)) / scale  # (batch, seq, seq)
        weights = scores.softmax(dim=-1)  # each row sums to 1 over the keys
        print(weights)

        return weights @ values  # (batch, seq, hidden_dim)
    
# Demo input of shape (batch=3, seq=2, hidden_dim=4). No seed is set, so the
# printed values change between runs.
X = torch.rand(3, 2, 4)

self_attention_net = SelfAttentionV1(4)
# forward() prints the (3, 2, 2) attention weights; the notebook then displays
# the returned (3, 2, 4) output.
self_attention_net(X)

tensor([[[0.5332, 0.4668],
         [0.5271, 0.4729]],

        [[0.4961, 0.5039],
         [0.4906, 0.5094]],

        [[0.5510, 0.4490],
         [0.5345, 0.4655]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.4533, -0.0090,  0.1798,  0.5238],
         [-0.4566, -0.0082,  0.1788,  0.5268]],

        [[-0.2665, -0.2063,  0.0037,  0.5727],
         [-0.2649, -0.2069,  0.0046,  0.5712]],

        [[-0.5426, -0.3203, -0.2754,  0.7483],
         [-0.5384, -0.3169, -0.2683,  0.7455]]], grad_fn=<UnsafeViewBackward0>)

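### V2: Efficiency Optimization

V1 used three separate `nn.Linear` layers. V2 replaces them with a single fused projection that outputs `3 * dim` features, which are then split into Q, K and V. All three projections therefore run as one matrix multiplication.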

In [4]:
class SelfAttentionV2(nn.Module):
    """Single-head self-attention using one fused Q/K/V projection.

    A single ``nn.Linear(dim, 3 * dim)`` replaces V1's three projections; its
    output is split into Q, K and V along the last axis.

    Args:
        dim: size of the last input dimension and of each of Q, K and V.
    """

    def __init__(self, dim: int = 728) -> None:
        super().__init__()
        self.dim = dim
        self.qkv = nn.Linear(in_features=dim, out_features=3 * dim)

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Return softmax(Q K^T / sqrt(dim)) V for input X of shape (batch, seq, dim).

        The (batch, seq, seq) attention weights are printed as a side effect.
        """
        fused = self.qkv(X)  # (batch, seq, 3 * dim)
        q, k, v = fused.split(self.dim, dim=-1)  # three (batch, seq, dim) views

        logits = torch.matmul(q, k.transpose(-2, -1))  # (batch, seq, seq)
        weights = torch.softmax(logits / math.sqrt(self.dim), dim=-1)
        print(weights)

        return torch.matmul(weights, v)  # (batch, seq, dim)

# Demo input of shape (batch=3, seq=2, dim=4); unseeded, so values vary per run.
X = torch.rand(3, 2, 4)
self_attention_net = SelfAttentionV2(4)
# Prints the (3, 2, 2) attention weights, then displays the (3, 2, 4) output.
self_attention_net(X)

tensor([[[0.5095, 0.4905],
         [0.5166, 0.4834]],

        [[0.5135, 0.4865],
         [0.5115, 0.4885]],

        [[0.4994, 0.5006],
         [0.4932, 0.5068]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.2236, -0.0855,  0.4374,  0.8119],
         [-0.2232, -0.0832,  0.4369,  0.8130]],

        [[-0.1544, -0.3866,  0.5923,  0.9547],
         [-0.1546, -0.3868,  0.5921,  0.9546]],

        [[-0.1830, -0.0566,  0.3942,  0.9942],
         [-0.1830, -0.0578,  0.3947,  0.9941]]], grad_fn=<UnsafeViewBackward0>)

### V3: Add Some Details (Attention Dropout, Attention Mask, Output Projection)

In [5]:
# 1.dropout position
# 2.attention_mask
# 3.output projection

class SelfAttentionV3(nn.Module):
    """Single-head self-attention with masking, attention dropout and an output projection.

    Args:
        dim: size of the last input dimension and of each of Q, K and V.
        dropout_rate: dropout probability applied to the attention weights
            (``nn.Dropout`` is only active in training mode).
    """

    def __init__(self, dim, dropout_rate=0.1) -> None:
        super().__init__()
        self.dim = dim
        # Stores the rate itself (a float); the dropout module is attention_dropout.
        self.dropout = dropout_rate
        self.qkv = nn.Linear(dim, dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)

        self.output = nn.Linear(dim, dim)

    def forward(self, X: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """Attend over X, optionally masked, and apply the output projection.

        Args:
            X: tensor of shape (batch, seq, dim).
            attention_mask: optional tensor broadcastable to (batch, seq, seq);
                positions equal to 0 receive no attention weight.

        Returns:
            Tensor of shape (batch, seq, dim). The post-softmax (pre-dropout)
            attention weights are printed as a side effect.
        """
        QKV = self.qkv(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)

        scores = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)  # (batch, seq, seq)
        if attention_mask is not None:
            # Fill with the most negative finite value instead of -inf. With -inf,
            # a query row whose keys are ALL masked softmaxes to NaN, which then
            # poisons the output and gradients. With a finite fill, such a row
            # becomes uniform, and partially masked rows still give masked keys
            # zero weight (exp underflows to 0).
            scores = scores.masked_fill(
                attention_mask == 0,
                torch.finfo(scores.dtype).min,
            )
        attention_weight = torch.softmax(scores, dim=-1)  # (batch, seq, seq)
        print(attention_weight)

        # Dropout is applied to the attention weights, after the softmax.
        attention_weight = self.attention_dropout(attention_weight)
        output = self.output(attention_weight @ V)  # (batch, seq, dim)
        return output
    
# Demo input of shape (batch=3, seq=4, dim=2).
X = torch.rand(3, 4, 2)
# Per-batch key mask of shape (batch, seq); 0 marks keys to exclude (e.g. padding).
# It is expanded below to (batch, seq, seq), the shape of the attention scores.
mask = torch.tensor(
    [
        [1, 1, 1, 0], 
        [1, 1, 0, 0],
        [1, 0, 0, 0]
    ]
)
print(mask.shape)
# (batch, seq) -> (batch, 1, seq) -> (batch, seq, seq): every query row gets the
# same key mask, matching the (batch, seq, seq) attention-score tensor.
mask = mask.unsqueeze(dim=1).repeat(1, 4, 1)
print(f"repeat shape is: {mask.size()}")

attention_net = SelfAttentionV3(2)
attention_net(X, mask)

torch.Size([3, 4])
repeat shape is: torch.Size([3, 4, 4])
tensor([[[0.3685, 0.3538, 0.2777, 0.0000],
         [0.3601, 0.3496, 0.2903, 0.0000],
         [0.3374, 0.3346, 0.3280, 0.0000],
         [0.3736, 0.3562, 0.2702, 0.0000]],

        [[0.4933, 0.5067, 0.0000, 0.0000],
         [0.4467, 0.5533, 0.0000, 0.0000],
         [0.4558, 0.5442, 0.0000, 0.0000],
         [0.4586, 0.5414, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)


tensor([[[-0.2232, -0.7291],
         [-0.0272, -0.7461],
         [-0.0371, -0.7525],
         [-0.0219, -0.7426]],

        [[-0.3550, -0.7401],
         [-0.1194, -0.7139],
         [-0.1213, -0.7164],
         [-0.3631, -0.7319]],

        [[ 0.1005, -0.6753],
         [ 0.1005, -0.6753],
         [ 0.1005, -0.6753],
         [ 0.1005, -0.6753]]], grad_fn=<ViewBackward0>)

### V4: Multi-Head Self-Attention

In [6]:
import math
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The model dimension is split into ``num_heads`` heads of size
    ``head_dim = dim // num_heads``. Each head attends independently, and the
    concatenated head outputs are mixed by a final linear projection.

    Args:
        dim: size of the last input dimension; must be divisible by num_heads.
        num_heads: number of attention heads (positive).
        dropout_rate: dropout probability applied to the attention weights.

    Raises:
        ValueError: if num_heads is not positive or does not divide dim.
    """

    def __init__(self, dim: int, num_heads: int, dropout_rate=0.1) -> None:
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.head_dim = dim // num_heads  # num_heads * head_dim == dim
        self.num_heads = num_heads

        # Fused Q/K/V projection: (dim) -> (3 * dim) == 3 * num_heads * head_dim.
        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3)
        self.output = nn.Linear(in_features=dim, out_features=dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    def _split_heads(self, t: torch.Tensor, batch: int, seq: int) -> torch.Tensor:
        """Reshape (batch, seq, dim) -> (batch, num_heads, seq, head_dim)."""
        return t.view(batch, seq, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, X: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """Apply multi-head self-attention to X.

        Args:
            X: tensor of shape (batch, seq, dim).
            attention_mask: optional tensor broadcastable to
                (batch, num_heads, seq, seq); positions equal to 0 receive no
                attention weight.

        Returns:
            Tensor of shape (batch, seq, dim). The shape of the attention
            weights is printed as a side effect.
        """
        batch, seq, _ = X.size()

        Q, K, V = torch.split(self.qkv(X), self.dim, dim=-1)  # each (b, s, dim)
        q_state = self._split_heads(Q, batch, seq)
        k_state = self._split_heads(K, batch, seq)
        v_state = self._split_heads(V, batch, seq)

        # (b, h, s, head_dim) @ (b, h, head_dim, s) -> (b, h, s, s).
        # Scale by sqrt(head_dim): each head's dot product sums over head_dim
        # terms, not dim.
        attention_weight = q_state @ k_state.transpose(-1, -2) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            # Finite fill instead of -inf, so fully masked rows give a uniform
            # row rather than NaN; partially masked keys still get zero weight.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )
        attention_weight = torch.softmax(attention_weight, dim=-1)
        print(attention_weight.shape)
        attention_weight = self.attention_dropout(attention_weight)
        output_mid = attention_weight @ v_state  # (b, h, s, head_dim)

        # Merge heads back: (b, h, s, head_dim) -> (b, s, h, head_dim) -> (b, s, dim).
        output_mid = output_mid.transpose(1, 2).contiguous().view(batch, seq, self.dim)
        return self.output(output_mid)
    

# Key mask per batch element, shape (3, 2) -> (3, 1, 1, 2), then expanded to
# (batch=3, num_heads=8, query=2, key=2) to match the attention scores.
# 0 marks keys to exclude. Row [0, 0] masks EVERY key for batch element 1, so
# all of its score rows are fully masked; how that is handled depends on the
# fill value used in MultiHeadAttention.forward.
mask = (
    torch.tensor(
        [
            [0, 1],
            [0, 0],
            [1, 0]
        ]
    )
    .unsqueeze(1)
    .unsqueeze(2)
    .expand(3, 8, 2, 2)
)

# Demo input of shape (batch=3, seq=2, dim=128).
X = torch.rand(3, 2, 128)
attention_net = MultiHeadAttention(dim=128, num_heads=8) # head_dim = 16
attention_net(X, mask)

torch.Size([3, 8, 2, 2])


tensor([[[ 0.0997, -0.1123, -0.1628, -0.0790, -0.0662,  0.2278, -0.4019,
          -0.1423,  0.1320,  0.0509,  0.0600, -0.1963, -0.4512, -0.3134,
           0.1748, -0.0897,  0.0406, -0.1287, -0.2564,  0.0798,  0.2863,
          -0.0454, -0.2507,  0.0787, -0.5036,  0.1970, -0.0495,  0.2191,
           0.0671, -0.0168,  0.0675, -0.4743, -0.2189,  0.3280, -0.1422,
          -0.3457, -0.1562,  0.4148,  0.3013, -0.1683, -0.0898,  0.0133,
          -0.2101, -0.1575,  0.4553, -0.0332, -0.4689,  0.3274, -0.2142,
          -0.2233, -0.4041,  0.0894, -0.1285,  0.6013,  0.0090,  0.0839,
           0.2653, -0.0127,  0.2401, -0.0509, -0.0785,  0.1627,  0.1882,
          -0.0582, -0.0187, -0.0171,  0.0437,  0.0630,  0.2385, -0.1078,
           0.0395,  0.0492,  0.2838,  0.0390, -0.2167, -0.2571, -0.0387,
          -0.1593, -0.1299,  0.2698, -0.1954,  0.5995, -0.5483, -0.3328,
          -0.0645,  0.0079,  0.1874, -0.1107, -0.3210,  0.1329, -0.1958,
          -0.1351, -0.3772, -0.0134, -0.1852, -0.10