In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [15]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention (no output projection).

    Query, key and value are each passed through their own linear layer,
    split into ``num_heads`` heads, attended per head, and the per-head
    results are concatenated back along the channel axis.

    Args:
        embedding_dim: Channel width of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``embedding_dim``.
        qkv_bias: Whether the three projection layers carry a bias term.
    """

    def __init__(self, embedding_dim, num_heads, qkv_bias = False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads

        assert self.embedding_dim % self.num_heads == 0, "embedding_dim should be divisible by num of heads"

        # Width of each head; this is the d_k used for score scaling.
        self.head_dim = self.embedding_dim // self.num_heads

        self.wq = nn.Linear(self.embedding_dim, self.embedding_dim, bias = qkv_bias)
        self.wk = nn.Linear(self.embedding_dim, self.embedding_dim, bias = qkv_bias)
        self.wv = nn.Linear(self.embedding_dim, self.embedding_dim, bias = qkv_bias)

    def forward(self, query, key, value, mask = None):
        """Compute multi-head attention.

        Args:
            query: Tensor of shape [batch, query_len, embedding_dim].
            key: Tensor of shape [batch, key_len, embedding_dim].
            value: Tensor of shape [batch, key_len, embedding_dim].
            mask: Optional tensor broadcastable to
                [batch, num_heads, query_len, key_len]. A 3-D mask
                [batch, query_len, key_len] is expanded over the head axis.
                Non-zero entries are attended to; zero entries are masked out.

        Returns:
            Tensor of shape [batch, query_len, embedding_dim].

        Note:
            A query row whose keys are all masked yields NaN, because softmax
            over a row consisting only of ``-inf`` is undefined.
        """
        batch_size, query_length, _ = query.shape

        query = self.wq(query)  # [b, Lq, c]
        key = self.wk(key)      # [b, Lk, c]
        value = self.wv(value)  # [b, Lk, c]

        # [b, L, c] -> [b, num_heads, L, head_dim]. The sequence axis is
        # inferred (-1) so key/value may have a different length than query.
        query = query.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, self.num_heads, self.head_dim).transpose(1, 2)

        # Scale by sqrt(head_dim): each dot product runs over head_dim
        # features, not over the full embedding width.
        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.head_dim ** -0.5

        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # [b, 1, Lq, Lk], broadcast over heads
            attention_scores = attention_scores.masked_fill(~mask.to(torch.bool), float('-inf'))

        attention_weights = F.softmax(attention_scores, dim = -1)
        attention_output = torch.matmul(attention_weights, value)  # [b, num_heads, Lq, head_dim]

        # Move the head axis back next to head_dim before flattening; merging
        # [b, num_heads, Lq, head_dim] directly would interleave positions.
        attention_output = attention_output.transpose(1, 2).reshape(batch_size, query_length, self.embedding_dim)
        return attention_output

In [16]:
# Smoke test: run SelfAttention on random inputs and print the output shape.
batch_size = 2
sequence_length = 77
embedding_dim = 768
num_heads = 8

# Three independent random tensors, drawn in query -> key -> value order.
query, key, value = (
    torch.randn(batch_size, sequence_length, embedding_dim) for _ in range(3)
)

# 1 = attend, 0 = ignore: every query may only see the first half of the keys.
mask = torch.ones(batch_size, sequence_length, sequence_length)
mask[..., sequence_length // 2:] = 0

model = SelfAttention(embedding_dim, num_heads)
output = model(query, key, value, mask)
print(output.shape)

torch.Size([2, 77, 768])


##### Attention-mask的作用是什么？

1. 在自注意力机制中, attention-mask 的主要作用是控制模型在计算注意力时, 关注哪些位置的特征、忽略哪些位置的特征。
2. attention-mask 是一个矩阵, 用于在注意力计算中掩盖掉某些位置。通常的做法是: 在 softmax 之前, 将被掩盖位置的注意力分数设置为一个非常大的负数 (例如 -inf), 以确保这些位置在 softmax 之后的注意力权重接近 0。