In [1]:
import numpy as np

def scaled_dot_product_attention(Q: np.ndarray, K: np.ndarray, V: np.ndarray):
    """
    Computes scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V.

    Any number of leading batch dimensions is supported (including none), e.g.
    (seq, d), (batch, seq, d) or (batch, heads, seq, d). Leading dimensions must
    be broadcast-compatible under ``np.matmul`` rules.

    Args:
        Q: Query array of shape (..., seq_len_q, d_k)
        K: Key array   of shape (..., seq_len_k, d_k)
        V: Value array of shape (..., seq_len_k, d_v)

    Returns:
        attention_weights: (..., seq_len_q, seq_len_k); each row sums to 1.
        context: (..., seq_len_q, d_v)

    Raises:
        ValueError: (from ``np.matmul``) if the last dim of Q and K differ, or if
            K and V disagree on seq_len_k.
    """
    d_k = Q.shape[-1]

    # Step 1: scores = Q K^T / sqrt(d_k). Swap only the last two axes of K so
    # this works for any rank >= 2, not just rank-3 inputs.
    scores = np.matmul(Q, np.swapaxes(K, -1, -2)) / np.sqrt(d_k)

    # Step 2: softmax over the key axis. Subtracting the row max first keeps
    # np.exp from overflowing without changing the result.
    exp_scores = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights = exp_scores / np.sum(exp_scores, axis=-1, keepdims=True)

    # Step 3: context = attention_weights @ V (weighted average of value rows).
    context = np.matmul(attention_weights, V)

    return attention_weights, context


# Example test: seeded random inputs so Restart & Run All reproduces the same
# values, plus invariant checks so a regression fails loudly instead of silently.
rng = np.random.default_rng(42)
BATCH, SEQ_LEN, D_MODEL = 1, 5, 64

Q = rng.random((BATCH, SEQ_LEN, D_MODEL))
K = rng.random((BATCH, SEQ_LEN, D_MODEL))
V = rng.random((BATCH, SEQ_LEN, D_MODEL))

attn_wt, context = scaled_dot_product_attention(Q, K, V)

# Weights are (batch, seq_q, seq_k) and every row is a probability distribution.
assert attn_wt.shape == (BATCH, SEQ_LEN, SEQ_LEN)
assert context.shape == (BATCH, SEQ_LEN, D_MODEL)
assert np.allclose(attn_wt.sum(axis=-1), 1.0)

print("Attention weights shape:", attn_wt.shape)
print("Context shape:", context.shape)



Attention weights shape: (1, 5, 5)
Context shape: (1, 5, 64)
