In [1]:
import torch
from torch import nn
import torch.nn.functional as F



In [None]:
# def MultiHeadAttention(self):  # main attention function (placeholder stub)


In [42]:
def scales_dot_product_attention(Q,K,V, mask = None):
  """Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

  Args:
    Q: query tensor whose last two dimensions are (seq_q, d_k).
    K: key tensor whose last two dimensions are (seq_k, d_k).
    V: value tensor whose last two dimensions are (seq_k, d_v).
    mask: optional tensor broadcastable to the score shape (..., seq_q, seq_k).
      Positions where ``mask == 0`` are excluded from attention.

  Returns:
    A tuple ``(output, attention)``: ``output`` has trailing dimensions
    (seq_q, d_v) and ``attention`` has trailing dimensions (seq_q, seq_k).
    Each row of ``attention`` sums to 1 over the key axis.
  """
  d_k = Q.shape[-1]  # per-head feature size, used for scaling

  # Scale by a plain Python float: no per-call tensor allocation and no
  # device/dtype mismatch with the inputs.
  scores = torch.matmul(Q, K.transpose(-2, -1)) / d_k ** 0.5

  if mask is not None:
    # Masked positions get -inf so that softmax assigns them zero weight.
    # Note: a row that is masked everywhere yields NaN after softmax.
    scores = scores.masked_fill(mask == 0, float('-inf'))

  # Normalize over the key axis (the last one). Using dim=1 would normalize
  # over heads (4-D input) or over queries (3-D input) instead.
  attention = F.softmax(scores, dim = -1)
  output = torch.matmul(attention, V)
  return output, attention



In [39]:

class MultiHeadAttention(nn.Module):
  """Multi-head attention layer built on ``scales_dot_product_attention``.

  The inputs are linearly projected, split into ``num_heads`` heads of size
  ``d_k = d_model // num_heads``, attended per head, merged back and passed
  through an output projection.

  Args:
    d_model: feature size of the inputs and of the output.
    num_heads: number of attention heads; must divide ``d_model`` evenly.
  """

  def __init__(self, d_model, num_heads):
    super().__init__()
    assert d_model % num_heads == 0 , "d_model должен делиться нацело на кол-во голов"

    self.d_model = d_model
    self.num_heads = num_heads
    self.d_k = d_model//num_heads   # feature size of each head

    # Linear projections for queries, keys and values.
    self.W_q = nn.Linear(d_model, d_model)
    self.W_k = nn.Linear(d_model, d_model)
    self.W_v = nn.Linear(d_model, d_model)

    # Linear projection applied to the merged heads.
    self.W_o = nn.Linear(d_model, d_model)

  def forward(self, K, Q, V, mask = None):
    """Apply multi-head attention.

    NOTE(review): the positional order is (K, Q, V), not the more common
    (Q, K, V). It is kept for backward compatibility; pass the tensors by
    keyword (``Q=..., K=..., V=...``) to avoid swapping queries and keys.

    Args:
      K: keys, reshaped here to (batch, seq_k, d_model).
      Q: queries, reshaped here to (batch, seq_q, d_model).
      V: values, reshaped here to (batch, seq_k, d_model).
      mask: optional tensor broadcastable to (batch, num_heads, seq_q, seq_k).
        Positions where ``mask == 0`` are excluded from attention.

    Returns:
      A tuple ``(output, attention_weights)`` of shapes
      (batch, seq_q, d_model) and (batch, num_heads, seq_q, seq_k).
    """
    batch_size = Q.shape[0]

    # Project, then split the feature axis into heads:
    # (batch, seq, d_model) -> (batch, num_heads, seq, d_k).
    Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
    K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)
    V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1,2)

    # Forward the mask; previously it was accepted but silently dropped.
    attention_output, attention_weights = scales_dot_product_attention(Q, K, V, mask)

    # Merge the heads back: (batch, num_heads, seq_q, d_k) -> (batch, seq_q, d_model).
    # contiguous() is required because view() cannot operate on the transposed layout.
    attention_output = attention_output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)

    output = self.W_o(attention_output)

    return output, attention_weights







In [43]:
# Smoke test on random data.
# Seed the RNG first so that both the layer's initial weights and the random
# inputs are identical on every Restart & Run All.
torch.manual_seed(0)

batch_size = 2
seq_len = 5
d_model = 512
num_heads = 8

attention = MultiHeadAttention(d_model, num_heads)
Q = torch.rand(batch_size, seq_len, d_model)
K = torch.rand(batch_size, seq_len, d_model)
V = torch.rand(batch_size, seq_len, d_model)

In [44]:
# Pass the tensors by keyword: MultiHeadAttention.forward declares its
# positional parameters in the order (K, Q, V), so the positional call
# attention(Q, K, V) would send the queries through the key projection and
# the keys through the query projection.
output, attention_w = attention(Q=Q, K=K, V=V)

# Fail loudly if the layer does not produce the expected shapes.
assert output.shape == (batch_size, seq_len, d_model)
assert attention_w.shape == (batch_size, num_heads, seq_len, seq_len)

print("Выход Multi-Head Attention:", output.shape)  # expected (batch_size, seq_len, d_model)
print("Веса внимания:", attention_w.shape)  # expected (batch_size, num_heads, seq_len, seq_len)

Выход Multi-Head Attention: torch.Size([2, 5, 512])
Веса внимания: torch.Size([2, 8, 5, 5])
