In [11]:
import torch
import math
import torch.nn as nn

In [12]:
class SelfAttentionV1(nn.Module):
    """Single-head scaled dot-product self-attention followed by an output projection.

    Queries, keys and values are each produced by a separate ``nn.Linear(dim, dim)``
    applied to the same input. Attention scores are scaled by ``sqrt(dim)``,
    optionally masked, normalised with softmax over the last axis, passed through
    dropout, and used to mix the values.

    Args:
        dim: Size of the last dimension of the input (and of the output).
        dropout: Dropout probability applied to the attention weights.
            Defaults to 0.1.
    """

    def __init__(self, dim, dropout=0.1) -> None:
        super().__init__()
        self.dim = dim
        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)
        self.output_proj = nn.Linear(dim, dim)
        self.att_drop = nn.Dropout(dropout)

    def forward(self, X, atten_mask=None):
        """Apply self-attention to ``X``.

        Args:
            X: Input tensor whose last dimension equals ``dim``,
                e.g. ``(batch, seq_len, dim)``.
            atten_mask: Optional tensor that must broadcast against the score
                matrix (last two axes ``seq_len x seq_len``). Positions where
                the mask equals 0 are excluded from attention. If every entry
                of a row is 0, that row's softmax is NaN.

        Returns:
            Tensor with the same shape as ``X``.
        """
        Q = self.query_proj(X)
        K = self.key_proj(X)
        V = self.value_proj(X)

        # Pairwise query/key scores, scaled by sqrt(dim).
        atten_weight = Q @ K.transpose(-1, -2) / math.sqrt(self.dim)

        # Identity check: avoids invoking the tensor's overloaded `!=` on None.
        if atten_mask is not None:
            # -inf scores become exactly 0 after softmax.
            atten_weight = atten_weight.masked_fill(atten_mask == 0, float("-inf"))

        atten_weight = torch.softmax(atten_weight, dim=-1)
        atten_weight = self.att_drop(atten_weight)
        output = atten_weight @ V
        return self.output_proj(output)

In [13]:
# Toy batch: 3 sequences of 4 tokens, each token a 2-dimensional vector.
X = torch.rand(3, 4, 2)

# One row of flags per sequence; a 0 marks a key position to be masked out.
b = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0], [1, 0, 0, 0]])
print("b shape:", b.shape)

# Build the mask:
#   b[:, None, :]     -> (3, 1, 4)
#   .repeat(1, 4, 1)  -> (3, 4, 4), the same key flags copied for each query row
mask = b[:, None, :].repeat(1, 4, 1)
print("Mask shape:", mask.shape)

net = SelfAttentionV1(dim=2)
output = net(X, atten_mask=mask)
print("Output shape:", output.shape)

b shape: torch.Size([3, 4])
Mask shape: torch.Size([3, 4, 4])
Output shape: torch.Size([3, 4, 2])


In [14]:
class SelfAttention(nn.Module):
    def __init__(self,dim) -> None:
        super().__init__()
        self.dim = dim

        self.query_proj = nn.Linear(dim, dim)
        self.key_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)

        self.att_dropout = nn.Dropout(0.1)
        