# Self-attention

To manually implement **scaled dot-product attention** (the core mechanism used in transformers), we can break it down into a few key steps. Here's how we can approach it:

Scaled dot-product attention computes the attention scores for a query with respect to a set of keys and uses those scores to combine the values.

Given:
- **Q** (Query matrix)
- **K** (Key matrix)
- **V** (Value matrix)

The steps are:
1. Compute the raw attention scores by taking the dot product of **Q** and **K**.
2. Scale the attention scores by dividing by the square root of the dimension of the keys  $\sqrt{d_k}$ , to prevent very large dot product values from making the softmax output too sharp.
3. Apply a softmax function to obtain normalized attention weights. 
4. Use these attention weights to compute a weighted sum of the **V** (Values) to get the output.

$$
\text{Attention}(Q, K, V) = \text{softmax}\left( \frac{Q K^T}{\sqrt{d_k}} \right) V
$$

Where:
- $ Q $ is the query matrix with shape $(n_q, d_k)$ 
- $ K $ is the key matrix with shape $(n_k, d_k)$
- $ V $ is the value matrix with shape $(n_v, d_v)$
- $ d_k $ is the dimensionality of the query and key vectors





In [2]:
import math
import torch
import torch.nn as nn

## 1.简化版本

In [None]:
class SelfAttention_V1(nn.Module):
    
    def __init__(self, hidden_dim:int=728) -> None:
        super(SelfAttention_V1, self).__init__()
        self.hidden_dim = hidden_dim
        
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)
        
    def forward(self, X:torch.tensor):
        # x shape: (batch_size, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = X.size()
        # QKV shape: (batch_size, seq_len, hidden_dim)
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        
        # (batch_size, seq_len, seq_len)
        attention_value = torch.matmul(Q, K.transpose(-1,-2))
        attention_weight = torch.softmax(attention_value/math.sqrt(hidden_dim), dim=-1)
        print(attention_weight)
        
        output = torch.matmul(attention_weight, V)
        
        return output
    

X = torch.rand((3,3,5))
print(X.shape)
sa = SelfAttention_V1(5)
Y = sa(X)
print(Y.shape)

## 2.效率优化

合并QKV权重矩阵

In [None]:
class SelfAttention_V2(nn.Module):
    
    def __init__(self, hidden_dim:int=728) -> None:
        super(SelfAttention_V2, self).__init__()
        self.hidden_dim = hidden_dim
        
        self.qkv_proj = nn.Linear(hidden_dim, hidden_dim*3)
        
    def forward(self, X:torch.tensor):
        # x shape: (batch_size, seq_len, hidden_dim)
        batch_size, seq_len, hidden_dim = X.size()
        # QKV shape: (batch_size, seq_len, hidden_dim)
        QKV = self.qkv_proj(X)
        Q, K, V = torch.split(QKV, self.hidden_dim, -1)
        
        # (batch_size, seq_len, seq_len)
        attention_value = torch.matmul(Q, K.transpose(-1,-2))
        attention_weight = torch.softmax(attention_value/math.sqrt(hidden_dim), dim=-1)
        
        output = torch.matmul(attention_weight, V)
        
        return output
    

X = torch.rand((3,3,5))
print(X.shape)
sa = SelfAttention_V2(5)
Y = sa(X)
print(Y.shape)

## 3. 细节

1. dropout位置 
2. attention_mask
3. output_proj

In [None]:
class SelfAttention_V3(nn.Module):
    
    def __init__(self, dim, drop_out_rate:float=0.1) -> None:
        super(SelfAttention_V3, self).__init__()
        self.dim = dim
        
        self.proj = nn.Linear(dim, dim * 3)
        self.dropout = nn.Dropout(drop_out_rate)
        self.output_proj = nn.Linear(dim, dim)
        
        
    def forward(self, X, attention_mask):
        # X: batch, seq, dim
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)
        attention_weight = Q @ K.transpose(-1,-2) / math.sqrt(self.dim)
        
        if attention_mask is not None:
            attention_weight = attention_weight.masked_fill(attention_mask==0, float('-1e20'))
        
        attention_weight = torch.softmax(attention_weight, dim=-1)
        print(attention_weight)
        attention_weight = self.dropout(attention_weight)
        
        output = attention_weight @ V
        
        output = self.output_proj(output)
        return output
        
X = torch.rand((2,3,5))
mask = torch.tensor(
    [
        [1,1,1],
        [1,1,0]
    ]
)
mask = mask.unsqueeze(dim=1).repeat(1,3,1)
print(X.shape)
print(mask.shape)
sa = SelfAttention_V3(5)
Y = sa(X, mask)
print(Y.shape)
            

In [None]:
mask = torch.tensor(
    [
        [1,1,1],
        [1,1,0]
    ]
)
mask = mask.unsqueeze(dim=1).repeat(1,3,1)
mask

## 4.面试


In [None]:
class SelfAttention_Final(nn.Module):
    
    def __init__(self, dim:int, drop_out_rate:float=0.1) -> None:
        super(SelfAttention_Final, self).__init__()
        
        self.dim = dim
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        self.attention_dropout = nn.Dropout(drop_out_rate)
        self.output_proj = nn.Linear(dim, dim)
    
    def forward(self, X:torch.tensor, attention_mask:torch.tensor):
    
        assert X.shape[:-1] == attention_mask.shape
        
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)
        
        attention_score = Q @ K.transpose(-1,-2) / math.sqrt(self.dim)
        
        
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1).repeat(1,X.shape[1],1)
            attention_score = torch.masked_fill(attention_score, attention_mask==0, -float('inf'))
        
        attention_weight = torch.softmax(attention_score, dim=-1)
        
        attention_weight = self.attention_dropout(attention_weight)
        
        output = attention_weight @ V
        output = self.output_proj(output)
        
        return output
        
        
X = torch.rand((2,3,5))
mask = torch.tensor(
    [
        [1,1,1],
        [1,1,0]
    ]
)
model = SelfAttention_Final(5, 0.1)
y = model(X, mask)
print(y.shape)

以上代码还有个问题，self attention 和 cross attention其实没有本质的区别，因此最好还是把qkv的输入区分开.

# Multi-head attention

To manually implement **Multi-Head Attention** in PyTorch, we need to combine several scaled dot-product attention mechanisms (one per head). Each attention head has its own set of query, key, and value weights, and after computing the attention for each head, we concatenate the results and apply a final linear transformation.

### Multi-Head Attention Overview:

1. **Input**: Given:
   - $ X $ (Input matrix) of shape $ (n_{batch}, n_{seq}, d_{embedding}) $
   - **$ d_{embedding} $** is the model's dimension (e.g., 512 or 1024)
   - **$ n_{\text{heads}} $** is the number of attention heads (e.g., 8)

2. **Linear Projections**: For each attention head, we project $ Q $, $ K $, and $ V $ into smaller dimensions using learned weight matrices:
   - $ Q_i = XW_q^i $
   - $ K_i = XW_k^i $
   - $ V_i = XW_v^i $
   where $ W_q^i$ , $ W_k^i $, and $ W_v^i $ are the learned projection matrices for each head $ i $.

3. **Scaled Dot-Product Attention**: For each head, we compute the scaled dot-product attention:
   $$
   \text{Output}_i = \text{Attention}(Q_i, K_i) \cdot V_i = \text{softmax}\left(\frac{Q_i K_i^T}{\sqrt{d_k}}\right) V_i
   $$
   where $ d_k = \frac{d_{embedding}}{n_{\text{heads}}} $.

4. **Concatenate Heads**: After computing the attention for each head, we concatenate the results across all heads and apply a final linear transformation to the concatenated vector of shape $ (n_{seq}, d_{embedding}) $.

   $$
   \text{MultiHeadOutput} = Concat(\text{Output}_1, \text{Output}_2, \dots, \text{Output}_H) \cdot W_{linear}
   $$


In [3]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, dim:int, num_head:int, attention_dropout_rate:float=0.1) -> None:
        super(MultiHeadAttention, self).__init__()
        
        assert dim % num_head == 0 
        self.dim = dim
        self.num_head = num_head
        self.dim_head = dim // num_head
        
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.o_proj = nn.Linear(dim, dim)
        
        self.attention_dropout = nn.Dropout(attention_dropout_rate)
        
        
    def forward(self, query, key, value, attention_mask):
        
        # assert query.shape == key.shape == value.shape
        bs, sl, hs = query.shape
        
        Q = self.q_proj(query)
        K = self.k_proj(key)
        V = self.v_proj(value)
        print(Q.shape)
        
        # (b, s, d) => (b, n, s, d/n)
        Q_heads = torch.reshape(Q, (bs, sl, self.num_head, self.dim_head)).transpose(1,2)
        K_heads = torch.reshape(K, (bs, sl, self.num_head, self.dim_head)).transpose(1,2)
        V_heads = torch.reshape(V, (bs, sl, self.num_head, self.dim_head)).transpose(1,2)
        print(Q_heads.shape)
        
        
        # (b, n, s, s)
        attention_score = Q_heads @ K_heads.transpose(-1,-2) / math.sqrt(self.num_head)
        
        if attention_mask is not None:
            attention_score = attention_score.masked_fill(attention_mask==0, float('-inf'))
            
        attention_weight = torch.softmax(attention_score, dim=-1)
        attention_weight = self.attention_dropout(attention_weight)
        
        # (b, n, s, d/n)
        output = attention_weight @ V_heads
         # (b, s, n, d/n)
        output = output.transpose(1,2)
        output = torch.reshape(output, (bs,sl,-1))
        
        output = self.o_proj(output)
        return output
    
    
X = torch.rand((2,3,6))
mask = torch.tensor(
    [
        [1,1,1],
        [1,1,0]
    ]
)

mask = mask.unsqueeze(1).repeat(1,3,1)
model = MultiHeadAttention(6,2, 0.1)
y = model(X, X, X, mask)
print(y.shape)

torch.Size([2, 3, 6])
torch.Size([2, 2, 3, 3])
torch.Size([2, 3, 6])
