In [2]:
import numpy as np

def self_attention(Q, K, V, verbose=True):
    """
    Scaled dot-product attention for a single attention head.

    Computes ``softmax(Q @ K^T / sqrt(d_model)) @ V`` where the softmax is
    taken over the last axis of the score matrix (one distribution over the
    key positions for every query position).

    Parameters:
    - Q (array-like): Query matrix, shape (..., query_length, d_model).
    - K (array-like): Key matrix, shape (..., key_length, d_model).
    - V (array-like): Value matrix, shape (..., key_length, d_value).
      Any leading (batch) axes follow ``np.matmul`` broadcasting rules.
    - verbose (bool): When True (the default), print ``d_model``, the scaled
      scores and the attention weights. Pass False to compute silently.

    Returns:
    - Attention output (ndarray): Weighted sum of values based on attention
      scores, shape (..., query_length, d_value).
    """
    Q = np.asarray(Q)
    K = np.asarray(K)
    V = np.asarray(V)

    # Scaling factor is the size of the feature axis of the queries.
    d_model = Q.shape[-1]
    if verbose:
        print('d_model: ',d_model)

    # Transpose only the last two axes of K. `K.T` would reverse *all* axes,
    # which is wrong as soon as the inputs carry leading batch dimensions.
    # For a plain 2-D matrix this is exactly `K.T`.
    K_transposed = np.swapaxes(K, -1, -2) if K.ndim >= 2 else K

    # Calculate attention scores (scaled dot-product attention)
    scores = np.matmul(Q, K_transposed) / np.sqrt(d_model)
    if verbose:
        print('scores: ',scores)

    # Apply softmax to obtain attention weights. Subtracting the row-wise
    # maximum first keeps np.exp from overflowing and does not change the
    # result, because softmax is invariant to a constant shift per row.
    attention_weights = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
    attention_weights /= np.sum(attention_weights, axis=-1, keepdims=True)
    if verbose:
        print('attention_weights: ',attention_weights)

    # Calculate attention output as a weighted sum of values
    attention_output = np.matmul(attention_weights, V)

    return attention_output

# Example usage
sequence_length = 5
d_model = 3

# Seed the random generator so the example matrices, and therefore every
# value printed below, are identical on each Restart & Run All.
SEED = 42
rng = np.random.default_rng(SEED)

# Example query, key, and value matrices (uniform samples from [0, 1))
Q = rng.random((sequence_length, d_model))
K = rng.random((sequence_length, d_model))
V = rng.random((sequence_length, d_model))

# Apply self-attention
attention_output = self_attention(Q, K, V)

print("Query Matrix (Q):\n", Q)
print("Key Matrix (K):\n", K)
print("Value Matrix (V):\n", V)
print("Attention Output:\n", attention_output)


d_model:  3
scores:  [[0.23597567 0.50181978 0.4517456  0.49794841 0.25514919]
 [0.57074441 0.43283436 0.75930215 0.39959128 0.57724248]
 [0.17780042 0.53484149 0.37844119 0.53948227 0.21710484]
 [0.30004891 0.16501067 0.35482251 0.14731527 0.31059413]
 [0.17739261 0.31114326 0.17613404 0.31917965 0.26748215]]
attention_weights:  [[0.1705263  0.22245696 0.21159191 0.22159741 0.17382743]
 [0.20292638 0.1767848  0.24503498 0.17100454 0.2042493 ]
 [0.1632091  0.23324124 0.19947182 0.23432617 0.16975167]
 [0.20837921 0.18205727 0.22011126 0.17886403 0.21058824]
 [0.18558201 0.21214023 0.18534858 0.21385194 0.20307724]]
Query Matrix (Q):
 [[0.34389445 0.84954018 0.22716043]
 [0.15460511 0.66601988 0.96000454]
 [0.48509182 0.81343594 0.06162514]
 [0.08220974 0.17436809 0.52095687]
 [0.62363214 0.02672537 0.07829541]]
Key Matrix (K):
 [[0.37479224 0.08617332 0.90960026]
 [0.81081018 0.63966366 0.20656881]
 [0.34752642 0.52670637 0.94856179]
 [0.84111984 0.63620184 0.14411171]
 [0.62562712 0.0