In [10]:
import torch
import torch.nn as nn
import math
from torch.nn import functional as F
# Relative position embedding module
class RelativePositionEmbedding(nn.Module):
    """Learnable embeddings indexed by the relative offset ``j - i`` between positions.

    One vector of size ``embedding_dim`` is stored for every offset in
    ``[-(max_seq_len - 1), max_seq_len - 1]``, i.e. ``2 * max_seq_len - 1`` rows.

    Args:
        max_seq_len: Longest sequence length that ``forward`` supports.
        embedding_dim: Size of each relative-position vector.
    """

    def __init__(self, max_seq_len: int, embedding_dim: int):
        super().__init__()
        self.max_seq_len = max_seq_len
        self.embedding_dim = embedding_dim
        # Row k holds the embedding for relative offset k - (max_seq_len - 1).
        self.relative_positions = nn.Parameter(torch.randn(max_seq_len * 2 - 1, embedding_dim))

    def forward(self, seq_len: int) -> torch.Tensor:
        """Return a ``(seq_len, seq_len, embedding_dim)`` table of relative embeddings.

        Entry ``[i, j]`` is the embedding of offset ``j - i``.

        Raises:
            ValueError: if ``seq_len`` is negative or larger than ``max_seq_len``.
        """
        relative_positions_matrix = self._generate_relative_positions_matrix(seq_len)
        return F.embedding(relative_positions_matrix, self.relative_positions)

    def _generate_relative_positions_matrix(self, seq_len: int) -> torch.Tensor:
        """Return a ``(seq_len, seq_len)`` integer matrix of row indices into ``relative_positions``.

        Entry ``[i, j]`` is ``j - i + max_seq_len - 1``, which shifts the offset range
        ``[-(max_seq_len - 1), max_seq_len - 1]`` onto ``[0, 2 * max_seq_len - 2]``.

        Raises:
            ValueError: if ``seq_len`` is negative or larger than ``max_seq_len``
                (the indices would fall outside the stored table).
        """
        if not 0 <= seq_len <= self.max_seq_len:
            raise ValueError(f"seq_len must be in [0, {self.max_seq_len}], got {seq_len}")
        # Build indices on the parameter's device so F.embedding also works off-CPU.
        range_vec = torch.arange(seq_len, device=self.relative_positions.device)
        distance_matrix = range_vec.unsqueeze(0) - range_vec.unsqueeze(1)  # [i, j] = j - i
        return distance_matrix + self.max_seq_len - 1


# Multi-head attention module
class MultiHeadAttention(nn.Module):
    """Multi-head attention with a learned relative-position term on the keys.

    For each head, the score of query position ``i`` against key position ``j`` is
    ``(q_i . k_j + q_i . r_{j-i}) / sqrt(head_dim)``. Here ``r`` comes from a
    ``RelativePositionEmbedding`` table that has no head axis and is therefore
    shared by every head.

    Args:
        embed_dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        max_seq_len: Longest sequence length the relative-position table covers.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim: int, num_heads: int, max_seq_len: int):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Construction order is kept as-is so parameter initialisation (RNG draws) is unchanged.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.rel_pos_embeddings = RelativePositionEmbedding(max_seq_len, self.head_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape ``(batch, seq, embed_dim)`` into ``(batch, num_heads, seq, head_dim)``."""
        batch_size, seq_len, _ = x.shape
        return x.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
        """Attend ``query`` over ``key``/``value`` with relative-position key terms.

        Args:
            query, key, value: Tensors of shape ``(batch, seq_len, embed_dim)``.
                All three must share batch size and sequence length, because the
                relative-position table is square.

        Returns:
            Tensor of shape ``(batch, seq_len, embed_dim)``.

        Raises:
            ValueError: if key/value batch or sequence sizes differ from the query's.
        """
        batch_size, seq_len, embed_dim = query.size()
        if key.shape[:2] != query.shape[:2] or value.shape[:2] != query.shape[:2]:
            raise ValueError(
                "query, key and value must share batch size and sequence length; got "
                f"{tuple(query.shape)}, {tuple(key.shape)}, {tuple(value.shape)}"
            )

        q = self._split_heads(self.q_proj(query))  # (B, H, S, D)
        k = self._split_heads(self.k_proj(key))    # (B, H, S, D)
        v = self._split_heads(self.v_proj(value))  # (B, H, S, D)

        # (S, S, D): [i, j] embeds offset j - i; the same table is used for every batch/head.
        rel = self.rel_pos_embeddings(seq_len)

        content_scores = torch.matmul(q, k.transpose(-1, -2))             # (B, H, S, S)
        position_scores = torch.einsum("bhid,ijd->bhij", q, rel)          # (B, H, S, S)
        attention_scores = (content_scores + position_scores) / math.sqrt(self.head_dim)
        attention_weights = F.softmax(attention_scores, dim=-1)

        output = torch.matmul(attention_weights, v)                        # (B, H, S, D)
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        return self.out_proj(output)
# Example usage
max_seq_len = 10
embed_dim = 64
num_heads = 4
batch_size = 2
seq_len = 5

# Three random inputs of shape (batch_size, seq_len, embed_dim), drawn in query/key/value order.
query, key, value = (torch.randn(batch_size, seq_len, embed_dim) for _ in range(3))

attention = MultiHeadAttention(embed_dim, num_heads, max_seq_len)
rel_module = attention.rel_pos_embeddings
print(rel_module._generate_relative_positions_matrix(seq_len))
print(rel_module.relative_positions.shape)
print(rel_module(seq_len).shape)
# Forward pass (not run in this cell):
# output = attention(query, key, value)
# print(output.shape)  # expected shape: (batch_size, seq_len, embed_dim)

tensor([[ 9, 10, 11, 12, 13],
        [ 8,  9, 10, 11, 12],
        [ 7,  8,  9, 10, 11],
        [ 6,  7,  8,  9, 10],
        [ 5,  6,  7,  8,  9]])
torch.Size([19, 16])
torch.Size([5, 5, 16])
