In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The embedding dimension is split evenly across ``heads`` attention
    heads, each of size ``head_dim = embed_size // heads``. Queries, keys
    and values are linearly projected, attention is computed independently
    per head, and the concatenated head outputs go through a final linear
    layer.

    Args:
        embed_size: Size of the last dimension of the inputs and output.
        heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        # Kept as an assert so the raised exception type is unchanged for callers.
        assert (self.head_dim * heads == embed_size), "Embed size needs to be divisible by heads"

        self.W_V = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.W_K = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.W_Q = nn.Linear(self.embed_size, self.embed_size, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask=None):
        """Compute multi-head attention.

        Args:
            values: Tensor of shape ``(N, value_len, embed_size)``.
            keys: Tensor of shape ``(N, key_len, embed_size)``;
                ``key_len`` must equal ``value_len``.
            query: Tensor of shape ``(N, query_len, embed_size)``.
            mask: Optional tensor broadcastable to
                ``(N, heads, query_len, key_len)``. Positions where
                ``mask == 0`` are excluded from attention. ``None``
                disables masking.

        Returns:
            A tuple ``(out, attention)`` where ``out`` has shape
            ``(N, query_len, embed_size)`` and ``attention`` has shape
            ``(N, heads, query_len, key_len)`` with each row summing to 1.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Project, then split the embedding into self.heads pieces.
        values = self.W_V(values).reshape(N, value_len, self.heads, self.head_dim)
        keys = self.W_K(keys).reshape(N, key_len, self.heads, self.head_dim)
        queries = self.W_Q(query).reshape(N, query_len, self.heads, self.head_dim)

        # Per-head dot products: (N, heads, query_len, key_len).
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys)

        # Scale by sqrt(d_k), the per-head key dimension, not the full
        # embedding size; dividing by sqrt(embed_size) over-damps the logits
        # by a factor of sqrt(heads) and flattens the attention distribution.
        energy = energy / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the dtype's most negative finite value so the fill is
            # representable in reduced precision (e.g. float16) as well.
            energy = energy.masked_fill(mask == 0, torch.finfo(energy.dtype).min)

        attention = torch.softmax(energy, dim=-1)

        # Weighted sum of values per head, then concatenate the heads.
        out = torch.einsum("nhql,nlhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )
        out = self.fc_out(out)
        return out, attention

# Example usage: run self-attention over a single random sequence.
embed_size, heads = 256, 8
seq_length, batch_size = 10, 1

mha = MultiHeadAttention(embed_size, heads)
x = torch.randn(batch_size, seq_length, embed_size)
# All-ones mask: no position is hidden from any other.
mask = torch.ones(batch_size, 1, seq_length, seq_length)

output, attention = mha(x, x, x, mask)

# Visualize the attention distribution of each head on a 2x4 grid.
plt.figure(figsize=(15, 8))
head_maps = attention[0].detach().numpy()
for head_no, head_map in enumerate(head_maps, start=1):
    plt.subplot(2, 4, head_no)
    plt.imshow(head_map, cmap='viridis')
    plt.title(f'Head {head_no}')
    plt.axis('off')
plt.tight_layout()
plt.show()