In [2]:
import torch
import torch.nn.functional as F


In [4]:
def scaled_dot_product_attention(Q, K, V):
    """
    Scaled dot-product attention: softmax(Q @ K^T / sqrt(d_k)) @ V.

    Args:
        Q: query tensor; d_k is read from its last dimension
           (in this notebook Q is (batch, seq_len, d_k)).
        K: key tensor whose last dimension must equal Q's.
        V: value tensor whose second-to-last dimension must equal K's
           sequence length; its last dimension may differ from d_k.

    Returns:
        A 4-tuple ``(attn_weights, output, softmax_before, softmax_after)``:
          - attn_weights: softmax of the scaled scores. This is the very
            same tensor object as ``softmax_after``.
          - output: ``attn_weights @ V``.
          - softmax_before: softmax of the *unscaled* scores. It is returned
            only for comparison and is not used to compute ``output``.
          - softmax_after: softmax of the scaled scores.
    """
    key_dim = Q.size(-1)
    # Kept as a float32 tensor sqrt (not math.sqrt) so the scaled values
    # are bit-for-bit what this notebook has always produced.
    scale = torch.sqrt(torch.tensor(key_dim, dtype=torch.float32))

    # Raw query/key similarities: (B, seq_q, seq_k).
    logits = Q @ K.transpose(-2, -1)

    # Diagnostic only: the distribution we'd get without the 1/sqrt(d_k) factor.
    unscaled_probs = F.softmax(logits, dim=-1)

    # Dividing by sqrt(d_k) tempers logit magnitude before softmax, giving a
    # less peaked distribution (compare the two softmax outputs printed below).
    scaled_probs = F.softmax(logits / scale, dim=-1)

    # Each output row is an attention-weighted combination of V's rows.
    context = torch.matmul(scaled_probs, V)

    return scaled_probs, context, unscaled_probs, scaled_probs


In [6]:
# Test parameters
batch = 1
seq_len = 4
d_k = 8
SEED = 0  # fixed seed so "Restart Kernel -> Run All" reproduces the outputs below

# Random Q, K, V -- seeded so the printed matrices are reproducible
torch.manual_seed(SEED)
Q = torch.randn(batch, seq_len, d_k)
K = torch.randn(batch, seq_len, d_k)
V = torch.randn(batch, seq_len, d_k)

# Run attention
attn_w, out, soft_before, soft_after = scaled_dot_product_attention(Q, K, V)

# Sanity checks: one weight per (query, key) pair, one output vector per query,
# and every row of attention weights is a probability distribution.
assert attn_w.shape == (batch, seq_len, seq_len)
assert out.shape == (batch, seq_len, d_k)
assert torch.allclose(attn_w.sum(dim=-1), torch.ones(batch, seq_len))


In [8]:
# Show each result under its heading. Every heading after the first starts
# with "\n" so the sections are separated by a blank line.
for heading, value in (
    ("=== Attention Weight Matrix ===", attn_w),
    ("\n=== Output Vectors ===", out),
    ("\n=== Softmax BEFORE scaling ===", soft_before),
    ("\n=== Softmax AFTER scaling ===", soft_after),
):
    print(heading)
    print(value)


=== Attention Weight Matrix ===
tensor([[[0.4929, 0.1561, 0.2338, 0.1171],
         [0.0238, 0.2865, 0.5080, 0.1818],
         [0.2184, 0.2550, 0.2867, 0.2399],
         [0.0476, 0.0218, 0.0434, 0.8872]]])

=== Output Vectors ===
tensor([[[ 0.2389,  0.1452,  0.3632,  0.4489,  0.0188,  0.8193,  0.7235,
           0.3997],
         [-0.2456, -0.0512, -0.1037,  0.6577, -0.2569,  0.3588,  0.9007,
           0.3187],
         [ 0.2671,  0.0984,  0.2333,  0.5416, -0.0735,  0.5456,  0.7390,
           0.1525],
         [ 1.5190,  0.3850,  0.7770,  0.8229, -0.4060,  0.1993,  0.3198,
          -2.0086]]])

=== Softmax BEFORE scaling ===
tensor([[[8.4944e-01, 3.2886e-02, 1.0308e-01, 1.4594e-02],
         [1.3814e-04, 1.5796e-01, 7.9829e-01, 4.3613e-02],
         [1.6636e-01, 2.5772e-01, 3.5900e-01, 2.1692e-01],
         [2.5505e-04, 2.7991e-05, 1.9590e-04, 9.9952e-01]]])

=== Softmax AFTER scaling ===
tensor([[[0.4929, 0.1561, 0.2338, 0.1171],
         [0.0238, 0.2865, 0.5080, 0.1818],
         