In [4]:
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """
    Q: Queries (shape: [n_queries, d_k])
    K: Keys    (shape: [n_keys, d_k])
    V: Values  (shape: [n_keys, d_v])
    """
    d_k = Q.size(-1)  # key dimension
    
    # 1. Raw scores = Q @ K^T
    scores = Q @ K.T  # alternative syntax
    # transpose -2 and -1 because we may have batch dimensions
    print(f"Raw scores: {scores}")
    
    # 2. Scale
    scaled_scores = scores / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    print(f"Scaled scores: {scaled_scores}")
    
    # 3. Softmax along keys axis
    weights = scaled_scores/torch.sum(scaled_scores, dim=-1, keepdim=True)
    # dim -1 is the keys axis, keepdim to maintain shape for broadcasting

    print(f"Attention weights: {weights}")
    
    # 4. Weighted sum of values
    output = weights @ V  # shape: [n_queries, d_v]
    print(f"Final output: {output}")
    
    return output, weights

# Example sentence: "He sat by the river bank"
words = ["He", "sat", "river", "bank"]

Q = torch.tensor([[1.0, 0.0]])   # the query vector for "bank"
K = torch.tensor([
    [1.0, 0.0],  # "He"
    [0.0, 1.0],  # "sat"
    [1.0, 1.0],  # "river"
    [1.0, 0.0],  # "bank"
])
V = torch.tensor([
    [1.0, 0.0],  # info from "He"
    [0.0, 1.0],  # info from "sat"
    [0.5, 0.5],  # info from "river"
    [1.0, 0.0],  # info from "bank"
])

print('Sentence: "He sat by the river bank"')
print('Query: "bank" looking for relevant context\n')

output, weights = scaled_dot_product_attention(Q, K, V)

print(f"\nAttention distribution:")
for word, weight in zip(words, weights[0]):
    print(f"{word}: {weight:.3f}")

Sentence: "He sat by the river bank"
Query: "bank" looking for relevant context

Raw scores: tensor([[1., 0., 1., 1.]])
Scaled scores: tensor([[0.7071, 0.0000, 0.7071, 0.7071]])
Attention weights: tensor([[0.3333, 0.0000, 0.3333, 0.3333]])
Final output: tensor([[0.8333, 0.1667]])

Attention distribution:
He: 0.333
sat: 0.000
river: 0.333
bank: 0.333
