In [1]:
import math

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
  """Compute scaled dot-product attention: softmax(Q K^T / sqrt(d_k) + mask * -1e9) V.

  Args:
    query: tensor whose last two dims are (seq_len_q, d_k); any leading
      dims are treated as (broadcastable) batch dims by ``torch.matmul``.
    key: tensor whose last two dims are (seq_len_k, d_k).
    value: tensor whose last two dims are (seq_len_k, d_v).
    mask: optional tensor broadcastable to (..., seq_len_q, seq_len_k).
      Entries equal to 1 (or True) suppress the corresponding logit by
      adding -1e9 to it; entries equal to 0 leave it unchanged.

  Returns:
    (output, attention_weights): ``output`` has last two dims
    (seq_len_q, d_v); ``attention_weights`` has last two dims
    (seq_len_q, seq_len_k) and sums to 1 along the last dim.
  """
  # (..., seq_len_q, d_k) @ (..., d_k, seq_len_k) -> (..., seq_len_q, seq_len_k)
  matmul_qk = torch.matmul(query, key.transpose(-2, -1))

  # Scale by sqrt(d_k) so the logits' magnitude does not grow with the key
  # dimension (large logits saturate the softmax). A plain Python float keeps
  # full precision for float64 inputs and avoids allocating a tensor per call.
  dk = key.size(-1)
  scaled_attention_logits = matmul_qk / math.sqrt(dk)

  # Additive masking: masked positions get a huge negative logit, so their
  # softmax weight underflows to ~0.
  # NOTE(review): -1e9 is not representable in float16; with half-precision
  # tensors masked logits may become -inf and a fully masked row NaN — confirm
  # if fp16 inputs are expected.
  if mask is not None:
    scaled_attention_logits = scaled_attention_logits + (mask * -1e9)

  attention_weights = F.softmax(scaled_attention_logits, dim=-1)
  output = torch.matmul(attention_weights, value)
  return output, attention_weights

In [2]:
def test_scaled_dot_product_attention():
  """Check scaled_dot_product_attention on seeded random inputs.

  Prints the attention weights for inspection, then asserts properties that
  must hold for any inputs: output/weight shapes, each row of the weights
  summing to 1, output equal to weights @ value, and a masked key position
  receiving ~0 weight.
  """
  torch.manual_seed(0)  # fixed seed so the printed weights are reproducible

  batch_size = 1
  seq_len = 5
  d_k = 64

  query = torch.randn(batch_size, seq_len, d_k)
  key = torch.randn(batch_size, seq_len, d_k)
  value = torch.randn(batch_size, seq_len, d_k)
  mask = None

  output, attention_weights = scaled_dot_product_attention(query, key, value, mask)
  print("Attention Weights:")
  print(attention_weights)

  # Invariants of the unmasked case.
  assert output.shape == (batch_size, seq_len, d_k)
  assert attention_weights.shape == (batch_size, seq_len, seq_len)
  assert torch.allclose(attention_weights.sum(dim=-1), torch.ones(batch_size, seq_len))
  assert torch.allclose(output, torch.matmul(attention_weights, value))

  # Mask out the last key position for every query; it must get ~0 weight
  # while each row still sums to 1 over the remaining keys.
  last_key_mask = torch.zeros(batch_size, 1, seq_len)
  last_key_mask[..., -1] = 1
  _, masked_weights = scaled_dot_product_attention(query, key, value, last_key_mask)
  assert torch.all(masked_weights[..., -1] < 1e-6)
  assert torch.allclose(masked_weights.sum(dim=-1), torch.ones(batch_size, seq_len))

test_scaled_dot_product_attention()

Attention Weights:
tensor([[[0.1289, 0.1020, 0.3652, 0.1170, 0.2869],
         [0.1229, 0.1987, 0.0404, 0.4620, 0.1760],
         [0.4934, 0.1396, 0.0509, 0.0403, 0.2759],
         [0.2127, 0.1136, 0.3153, 0.0549, 0.3035],
         [0.1189, 0.3081, 0.2382, 0.2010, 0.1338]]])
