In [1]:
# Environment check: print the GPU / driver / CUDA version of this machine.
# Note: none of the cells below move tensors or modules to CUDA, so the
# attention demos actually run on the CPU.
!nvidia-smi

Tue Feb 18 11:24:05 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   47C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

# 手写 self-attention
## 1 公式
$$\mathrm{Attention}(Q,K,V) = \mathrm{softmax}\left(\frac{QK^{T}}{\sqrt{d_k}}\right)V$$
其中 $d_k$ 为 key 向量的维度（单头实现中即 hidden_dim）。

## 基础版本

In [3]:
import torch
import math
import torch.nn as nn

In [7]:
# Self-attention layer written by hand (no built-in attention helpers such as
# nn.MultiheadAttention or F.scaled_dot_product_attention).
# v1
class selfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention with separate Q/K/V projections.

    Computes softmax(Q K^T / sqrt(hidden_dim)) V, where Q, K and V are three
    independent linear projections of the same input.

    Args:
        hidden_dim: feature size of the input and of every projection.
            NOTE(review): the default 728 looks like a typo for 768 (the
            BERT-base hidden size); kept unchanged for compatibility.
    """

    def __init__(self, hidden_dim: int = 728) -> None:
        super().__init__()
        self.hidden_dim = hidden_dim
        # Construction order (query, key, value) fixes the random init order.
        self.query_proj = nn.Linear(hidden_dim, hidden_dim)
        self.key_proj = nn.Linear(hidden_dim, hidden_dim)
        self.value_proj = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, X):
        """Attend over X and return a tensor of the same shape.

        Args:
            X: input whose last axis has size ``hidden_dim``; the demo uses
                (batch_size, seq_len, hidden_dim).

        Returns:
            softmax(Q K^T / sqrt(hidden_dim)) @ V. The attention weights are
            also printed for inspection.
        """
        query = self.query_proj(X)
        key = self.key_proj(X)
        value = self.value_proj(X)

        # Swap the last two axes of K so the product is (..., seq, seq).
        scores = query @ key.transpose(-2, -1)
        # Scaling by sqrt(d) keeps the logits from saturating the softmax,
        # which would otherwise give vanishing gradients.
        weights = (scores / math.sqrt(self.hidden_dim)).softmax(dim=-1)
        print(weights)
        return weights @ value

# Seed first so the random input AND the layer's initial weights are
# reproducible under Restart & Run All.
torch.manual_seed(0)
X = torch.rand(3, 2, 4)  # (batch=3, seq_len=2, hidden_dim=4)

self_att_net = selfAttentionV1(4)
out = self_att_net(X)  # also prints the attention weights
assert out.shape == X.shape  # self-attention preserves the input shape
print(out)

tensor([[[0.5258, 0.4742],
         [0.5487, 0.4513]],

        [[0.5083, 0.4917],
         [0.4965, 0.5035]],

        [[0.4969, 0.5031],
         [0.4967, 0.5033]]], grad_fn=<SoftmaxBackward0>)
tensor([[[-0.7143, -0.3710, -0.0556, -0.4849],
         [-0.7055, -0.3773, -0.0633, -0.4800]],

        [[-0.6510, -0.1262,  0.1903, -0.7163],
         [-0.6441, -0.1247,  0.1908, -0.7155]],

        [[-0.7555, -0.3069, -0.0465, -0.6092],
         [-0.7555, -0.3069, -0.0465, -0.6092]]], grad_fn=<UnsafeViewBackward0>)


## 效率优化

In [8]:
# For a small network, compute Q/K/V with one fused matmul (efficiency) -- V2
class selfAttentionV2(nn.Module):
    """Single-head self-attention with a fused QKV projection.

    A single nn.Linear(dim, 3 * dim) produces Q, K and V in one matmul. The
    result is split along the last axis into three chunks of width ``dim``.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj = nn.Linear(dim, dim * 3)

    def forward(self, X):
        """Return softmax(Q K^T / sqrt(dim)) V for an input whose last axis is ``dim``."""
        q, k, v = self.proj(X).split(self.dim, dim=-1)
        logits = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.dim)
        return logits.softmax(dim=-1) @ v

# Seed so the input and the fused projection's initial weights are reproducible.
torch.manual_seed(0)
X = torch.rand(3, 2, 4)  # (batch=3, seq_len=2, dim=4)

net = selfAttentionV2(4)
out = net(X)
assert out.shape == X.shape  # self-attention preserves the input shape
print(out)





tensor([[[ 0.0857, -0.1794,  0.8006,  0.1887],
         [ 0.0854, -0.1785,  0.7998,  0.1885]],

        [[ 0.2623, -0.3023,  0.9188,  0.1146],
         [ 0.2635, -0.3020,  0.9179,  0.1132]],

        [[-0.3038,  0.0788,  0.6206,  0.3794],
         [-0.3038,  0.0788,  0.6206,  0.3795]]], grad_fn=<UnsafeViewBackward0>)


## 加入一些细节

In [11]:
# Details added on top of V2:
#   1. where dropout goes (on the attention weights, after softmax)
#   2. attention_mask support
#   3. an output projection
class selfAttentionV3(nn.Module):
    """Single-head self-attention with a fused QKV projection, masking,
    attention dropout and an output projection.

    Args:
        dim: feature size of the input, of each of Q/K/V, and of the output.
        *args, **kwargs: forwarded to ``nn.Module.__init__``.
        dropout_rate: keyword-only dropout probability applied to the attention
            weights. Defaults to 0.1, the previously hard-coded value.
    """

    def __init__(self, dim, *args, dropout_rate: float = 0.1, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.dim = dim

        self.proj = nn.Linear(dim, dim * 3)
        self.attention_dropout = nn.Dropout(dropout_rate)
        self.output_proj = nn.Linear(dim, dim)

    def forward(self, X, attention_mask=None):
        """Apply (optionally masked) self-attention to X.

        Args:
            X: input whose last axis has size ``dim``; the demo uses
                (batch, seq_len, dim).
            attention_mask: optional tensor broadcastable to the score matrix
                (..., seq_len, seq_len). Positions where it equals 0 are
                excluded from attention.

        Returns:
            Tensor with the same shape as X.
        """
        QKV = self.proj(X)
        Q, K, V = torch.split(QKV, self.dim, dim=-1)  # split the fused projection

        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Use the most negative finite value of the score dtype, not a
            # hard-coded -1e20. -1e20 overflows float16 (masked_fill raises),
            # and a finite fill (unlike -inf) keeps fully-masked rows NaN-free.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )
        attention_weight = torch.softmax(attention_weight, dim=-1)
        # Dropout on the normalised weights (the "dropout position" detail).
        attention_weight = self.attention_dropout(attention_weight)
        output = attention_weight @ V
        output = self.output_proj(output)  # output projection
        return output


# Seed so the input, initial weights and dropout draws are reproducible.
torch.manual_seed(0)
X = torch.rand(3, 4, 2)  # (batch=3, seq_len=4, dim=2)
# Key mask: 1 = attend, 0 = masked. Batch items can attend to 3, 2 and 1 keys.
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)
# (batch, seq) -> (batch, seq, seq): every query row shares the same key mask.
mask = mask.unsqueeze(dim=1).repeat(1, 4, 1)
print(f"repeat shape:{mask.size()}")

# The module is in training mode here, so attention dropout is active.
net = selfAttentionV3(2)
out = net(X, attention_mask=mask)
assert out.shape == X.shape
print(out)



repeat shape:torch.Size([3, 4, 4])
tensor([[[-0.1209, -0.3495],
         [-0.1172, -0.3238],
         [-0.1174, -0.3492],
         [-0.1183, -0.3237]],

        [[-0.3265, -0.3805],
         [-0.3371, -0.3820],
         [-0.3314, -0.3812],
         [-0.3241, -0.3319]],

        [[-0.3254, -0.3647],
         [-0.3254, -0.3647],
         [-0.3254, -0.3647],
         [-0.3254, -0.3647]]], grad_fn=<ViewBackward0>)


## 面试写法


In [13]:
class selfAttentionInterview(nn.Module):
    """Interview-style single-head self-attention.

    Uses separate Q/K/V projections, an optional attention mask, and dropout
    on the attention weights.

    Args:
        dim: feature size of the input and of Q/K/V.
        dropout_rate: dropout probability on the attention weights. Defaults
            to 0.1, the previously hard-coded value.
    """

    def __init__(self, dim, dropout_rate: float = 0.1) -> None:
        super().__init__()
        self.dim = dim

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)

        self.attention_dropout = nn.Dropout(dropout_rate)

    def forward(self, X, attention_mask=None):
        """Apply (optionally masked) self-attention to X.

        Args:
            X: input whose last axis has size ``dim``; the demo uses
                (batch, seq_len, dim).
            attention_mask: optional tensor broadcastable to the score matrix
                (..., seq_len, seq_len). Positions where it equals 0 are
                excluded from attention.

        Returns:
            Tensor with the same shape as X. The post-softmax (pre-dropout)
            attention weights are printed for verification.
        """
        Q = self.query(X)
        K = self.key(X)
        V = self.value(X)

        attention_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)
        if attention_mask is not None:
            # Most negative finite value instead of -inf. With -inf, a row whose
            # positions are all masked becomes softmax(-inf, ..., -inf) = NaN.
            attention_weight = attention_weight.masked_fill(
                attention_mask == 0,
                torch.finfo(attention_weight.dtype).min,
            )
        attention_weight = torch.softmax(attention_weight, dim=-1)
        print(attention_weight)  # verification: masked positions should be 0
        attention_weight = self.attention_dropout(attention_weight)
        output = attention_weight @ V
        return output

# Seed so the input, initial weights and dropout draws are reproducible.
torch.manual_seed(0)
X = torch.rand(3, 4, 2)  # (batch=3, seq_len=4, dim=2)
# Key mask: 1 = attend, 0 = masked. Batch items can attend to 3, 2 and 1 keys.
mask = torch.tensor(
    [
        [1, 1, 1, 0],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)
# (batch, seq) -> (batch, seq, seq): every query row shares the same key mask.
mask = mask.unsqueeze(dim=1).repeat(1, 4, 1)
net = selfAttentionInterview(2)
out = net(X, attention_mask=mask)  # also prints the attention weights
assert out.shape == X.shape
print(out)


tensor([[[0.3487, 0.3186, 0.3327, 0.0000],
         [0.3655, 0.3010, 0.3335, 0.0000],
         [0.3527, 0.3142, 0.3331, 0.0000],
         [0.3653, 0.3008, 0.3340, 0.0000]],

        [[0.5355, 0.4645, 0.0000, 0.0000],
         [0.5609, 0.4391, 0.0000, 0.0000],
         [0.5380, 0.4620, 0.0000, 0.0000],
         [0.5385, 0.4615, 0.0000, 0.0000]],

        [[1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000],
         [1.0000, 0.0000, 0.0000, 0.0000]]], grad_fn=<SoftmaxBackward0>)
tensor([[[ 0.0567, -0.4883],
         [ 0.0519, -0.4861],
         [ 0.0307, -0.3295],
         [ 0.0519, -0.4861]],

        [[ 0.6038, -0.5999],
         [ 0.5965, -0.5966],
         [ 0.3495, -0.3087],
         [ 0.3491, -0.3083]],

        [[ 0.5236, -0.5632],
         [ 0.5236, -0.5632],
         [ 0.5236, -0.5632],
         [ 0.5236, -0.5632]]], grad_fn=<UnsafeViewBackward0>)
