In [None]:
import numpy as np
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data.sampler import SubsetRandomSampler

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual connection.

    The input is linearly projected to queries, keys and values, split into
    ``num_heads`` heads of size ``embed_dim // num_heads``, attended over the
    sequence axis, merged back, passed through an output projection and
    added to the input.

    Args:
        embed_dim: Feature size of the input and of the output.
        num_heads: Number of attention heads; must evenly divide ``embed_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim`` (the head split in ``forward`` would fail otherwise).
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Plain Python float: avoids allocating a CPU tensor on every forward
        # call and any device/dtype mixing with the input.
        self.scale = self.head_dim ** 0.5

        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t, batch_size, seq_len):
        """Reshape ``(batch, seq, embed)`` into ``(batch, heads, seq, head_dim)``."""
        return t.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        """Apply multi-head self-attention and add the residual.

        Args:
            x: Tensor of shape ``(batch, seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        batch_size, seq_len, embed_dim = x.size()

        q = self._split_heads(self.query(x), batch_size, seq_len)
        k = self._split_heads(self.key(x), batch_size, seq_len)
        v = self._split_heads(self.value(x), batch_size, seq_len)

        # (batch, heads, seq, seq) scores, scaled by sqrt(head_dim)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / self.scale
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # Attention-weighted sum of values, then merge heads back to (batch, seq, embed)
        attended_values = (
            torch.matmul(attn_weights, v)
            .transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, embed_dim)
        )

        # Output projection followed by the residual connection
        return self.fc(attended_values) + x

In [None]:
class MultiHead(nn.Module):
    """Multi-head scaled dot-product self-attention with a residual connection.

    Accepts unbatched input of shape ``(seq_len, embed_dim)`` as well as input
    with any number of leading batch dimensions, ``(..., seq_len, embed_dim)``.

    Args:
        embed_dim: Feature size of the input and of the output.
        head_num: Number of attention heads; must evenly divide ``embed_dim``.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self, embed_dim, head_num):
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it the nn.Linear assignments below raise AttributeError.
        super().__init__()
        if head_num <= 0:
            raise ValueError(f"head_num must be positive, got {head_num}")
        if embed_dim % head_num != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by head_num ({head_num})"
            )
        self.embed_dim = embed_dim
        self.head_num = head_num
        self.head_dim = embed_dim // head_num
        self.q = nn.Linear(embed_dim, embed_dim)
        self.k = nn.Linear(embed_dim, embed_dim)
        self.v = nn.Linear(embed_dim, embed_dim)
        self.fc = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, t):
        """Reshape ``(..., seq, embed)`` into ``(..., heads, seq, head_dim)``."""
        *lead, seq_len, _ = t.shape
        # Swap the seq and head axes (counted from the end) so attention runs
        # over the sequence for each head, whatever the number of batch dims.
        return t.view(*lead, seq_len, self.head_num, self.head_dim).transpose(-3, -2)

    def forward(self, x):
        """Apply multi-head self-attention and add the residual.

        Args:
            x: Tensor of shape ``(seq_len, embed_dim)`` or
                ``(..., seq_len, embed_dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        *lead, seq_len, embed_dim = x.shape
        q = self._split_heads(self.q(x))
        k = self._split_heads(self.k(x))
        v = self._split_heads(self.v(x))

        # (..., heads, seq, seq) scores, scaled by sqrt(head_dim)
        attn_weights = torch.matmul(q, k.transpose(-2, -1)) / (self.head_dim ** 0.5)
        attn_weights = torch.softmax(attn_weights, dim=-1)

        # Weighted sum of values; merge heads back to (..., seq, embed) so the
        # output projection sees the full feature dimension.
        attn_value = (
            torch.matmul(attn_weights, v)
            .transpose(-3, -2)
            .reshape(*lead, seq_len, embed_dim)
        )

        # Output projection followed by the residual connection
        return self.fc(attn_value) + x