# **Notebook 12.2: Multihead Self-Attention**

This notebook builds a multihead self-attention mechanism as in figure 12.6.

Work through the cells below, running each cell in turn. In various places you will see the words "TO DO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
import numpy as np
import matplotlib.pyplot as plt

The multihead self-attention mechanism takes $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$.
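Every output is computed from all $N$ inputs by the same rule, and nothing in the mechanism encodes position. As a result, reordering the inputs simply reorders the outputs in the same way. The sketch below checks this numerically. It is an illustration added here, not part of the exercise. It assumes the column convention used in this notebook (one input per column), a numerically stable column softmax, and hypothetical helper names `attention_head` and `mhsa`.

```python
import numpy as np

def softmax_cols_stable(s):
    # Softmax applied independently to each column, shifted by the column max for stability
    e = np.exp(s - s.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def attention_head(X, omega_q, omega_k, omega_v, beta_q, beta_k, beta_v):
    # X: (D, N); weights: (D/H, D); biases: (D/H, 1), broadcast over the N columns
    Q = omega_q @ X + beta_q
    K = omega_k @ X + beta_k
    V = omega_v @ X + beta_v
    A = softmax_cols_stable(K.T @ Q / np.sqrt(Q.shape[0]))  # (N, N), each column sums to 1
    return V @ A                                             # (D/H, N)

def mhsa(X, heads, omega_c):
    # heads: one parameter tuple per head; head outputs are stacked vertically, then projected
    return omega_c @ np.vstack([attention_head(X, *p) for p in heads])

rng = np.random.default_rng(0)
D, N, H = 8, 6, 2
H_D = D // H
heads_demo = [tuple(rng.normal(size=(H_D, D)) for _ in range(3))
              + tuple(rng.normal(size=(H_D, 1)) for _ in range(3))
              for _ in range(H)]
omega_c_demo = rng.normal(size=(D, D))
X_demo = rng.normal(size=(D, N))

perm = rng.permutation(N)
out = mhsa(X_demo, heads_demo, omega_c_demo)
out_perm = mhsa(X_demo[:, perm], heads_demo, omega_c_demo)
# Permuting the input columns should permute the output columns identically
print(np.allclose(out_perm, out[:, perm]))
```

This is also why transformers add positional information to the inputs when word order matters.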



In [2]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 6
# Number of dimensions of each input
D = 8
# Create the D x N input matrix; each column is one input x_n
X = np.random.normal(size=(D,N))
print(X)

[[ 1.78862847  0.43650985  0.09649747 -1.8634927  -0.2773882  -0.35475898]
 [-0.08274148 -0.62700068 -0.04381817 -0.47721803 -1.31386475  0.88462238]
 [ 0.88131804  1.70957306  0.05003364 -0.40467741 -0.54535995 -1.54647732]
 [ 0.98236743 -1.10106763 -1.18504653 -0.2056499   1.48614836  0.23671627]
 [-1.02378514 -0.7129932   0.62524497 -0.16051336 -0.76883635 -0.23003072]
 [ 0.74505627  1.97611078 -1.24412333 -0.62641691 -0.80376609 -2.41908317]
 [-0.92379202 -1.02387576  1.12397796 -0.13191423 -1.62328545  0.64667545]
 [-0.35627076 -1.74314104 -0.59664964 -0.58859438 -0.8738823   0.02971382]]


We'll use two heads. We'll need the weights and biases for the keys, queries, and values (equations 12.2 and 12.4). As in the figure, we'll make the queries, keys, and values of size $D/H$.
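In matrix form, each bias is added to every column, as in $\boldsymbol\beta_q\mathbf{1}^T + \boldsymbol\Omega_q\mathbf{X}$ where $\mathbf{1}$ is a vector of $N$ ones. NumPy broadcasting does the same thing when a $(D/H)\times 1$ bias is added to a $(D/H)\times N$ matrix. A small sketch with illustrative random values (the names `X_demo`, `omega_q`, `beta_q` are hypothetical) confirms that the two forms agree and shows the resulting shape:

```python
import numpy as np

rng = np.random.default_rng(1)
D, N, H = 8, 6, 2
H_D = D // H                              # per-head query/key/value size D/H = 4

X_demo = rng.normal(size=(D, N))          # one input per column
omega_q = rng.normal(size=(H_D, D))
beta_q = rng.normal(size=(H_D, 1))

# Explicit form: beta_q 1^T + Omega_q X, where 1 is a length-N vector of ones
Q_explicit = beta_q @ np.ones((1, N)) + omega_q @ X_demo
# Broadcast form: NumPy copies the (D/H, 1) bias across all N columns
Q_broadcast = omega_q @ X_demo + beta_q

print(Q_broadcast.shape)                  # (4, 6): one D/H-dimensional query per input
print(np.allclose(Q_explicit, Q_broadcast))
```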

In [3]:
# Number of heads
H = 2
# QDV dimension
H_D = int(D/H)

# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters for the first head
omega_q1 = np.random.normal(size=(H_D,D))
omega_k1 = np.random.normal(size=(H_D,D))
omega_v1 = np.random.normal(size=(H_D,D))
beta_q1 = np.random.normal(size=(H_D,1))
beta_k1 = np.random.normal(size=(H_D,1))
beta_v1 = np.random.normal(size=(H_D,1))

# Choose random values for the parameters for the second head
omega_q2 = np.random.normal(size=(H_D,D))
omega_k2 = np.random.normal(size=(H_D,D))
omega_v2 = np.random.normal(size=(H_D,D))
beta_q2 = np.random.normal(size=(H_D,1))
beta_k2 = np.random.normal(size=(H_D,1))
beta_v2 = np.random.normal(size=(H_D,1))

# Choose random values for the output matrix that combines the heads
omega_c = np.random.normal(size=(D,D))
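A quick count of the parameters just created: each head has three $D/H\times D$ weight matrices and three $D/H$-dimensional biases, and $\boldsymbol\Omega_c$ adds $D\times D$ more. With $D=8$ and $H=2$ this is $3(4\cdot 8+4)=108$ per head, $216$ for both heads, and $216+64=280$ in total. The hypothetical `param_count` helper below reproduces the arithmetic. Because each head has size $D/H$, the total is $3(D^2+D)+D^2$ for any $H$ that divides $D$.

```python
def param_count(D, H):
    # Hypothetical helper: weights and biases of H heads of size D/H, plus Omega_c
    H_D = D // H
    per_head = 3 * (H_D * D + H_D)   # Omega_q, Omega_k, Omega_v and beta_q, beta_k, beta_v
    return H * per_head + D * D      # Omega_c is D x D with no bias in this notebook

print(param_count(8, 2))   # 280
print(param_count(8, 4))   # 280: four heads of size 2 use the same number of parameters
```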

Now let's compute the multihead scaled self-attention.
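For reference, here is the computation the following cells implement, written in the notebook's conventions. The inputs are the columns of $\mathbf{X}$, $\mathbf{1}$ is a vector of $N$ ones, $\mathrm{Softmax}[\cdot]$ acts on each column separately, and the scores are divided by $\sqrt{D/H}$, the size of each query and key. This is a summary of the code below rather than a quotation of the book's numbered equations.

$$
\begin{aligned}
\mathbf{Q}_h &= \boldsymbol\beta_{qh}\mathbf{1}^T + \boldsymbol\Omega_{qh}\mathbf{X},\qquad
\mathbf{K}_h = \boldsymbol\beta_{kh}\mathbf{1}^T + \boldsymbol\Omega_{kh}\mathbf{X},\qquad
\mathbf{V}_h = \boldsymbol\beta_{vh}\mathbf{1}^T + \boldsymbol\Omega_{vh}\mathbf{X},\\
\mathbf{Sa}_h[\mathbf{X}] &= \mathbf{V}_h\,\mathrm{Softmax}\left[\frac{\mathbf{K}_h^T\mathbf{Q}_h}{\sqrt{D/H}}\right],\qquad h\in\{1,2\},\\
\mathbf{X}' &= \boldsymbol\Omega_c\begin{bmatrix}\mathbf{Sa}_1[\mathbf{X}]\\ \mathbf{Sa}_2[\mathbf{X}]\end{bmatrix}.
\end{aligned}
$$

Entry $(m,n)$ of $\mathbf{K}_h^T\mathbf{Q}_h$ is $\mathbf{k}_m^T\mathbf{q}_n$. Column $n$ of the softmax therefore holds the weights that output $n$ assigns to each of the $N$ values.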

In [4]:
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # Exponentiate all of the values
  exp_values = np.exp(data_in)
  # Sum each column (i.e., over the rows)
  denom = np.sum(exp_values, axis = 0)
  # Compute softmax (numpy broadcasts denominator to all rows automatically)
  softmax = exp_values / denom
  # return the answer
  return softmax
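Subtracting each column's maximum before exponentiating does not change the result, since the factor $e^{-\max}$ cancels between numerator and denominator. It does, however, prevent overflow when the scores are large. The sketch below compares the direct version with a hypothetical stable variant `softmax_cols_stable` on toy scores chosen for illustration. The direct version is reimplemented as `softmax_cols_naive` so that the block runs on its own.

```python
import numpy as np

def softmax_cols_naive(data_in):
    # Direct version, as in the cell above
    exp_values = np.exp(data_in)
    return exp_values / np.sum(exp_values, axis=0)

def softmax_cols_stable(data_in):
    # Hypothetical stable variant: shift each column so its maximum is 0
    shifted = data_in - np.max(data_in, axis=0, keepdims=True)
    exp_values = np.exp(shifted)
    return exp_values / np.sum(exp_values, axis=0, keepdims=True)

scores = np.array([[1.0, 2.0],
                   [3.0, 0.5],
                   [0.0, 2.0]])
print(np.allclose(softmax_cols_naive(scores), softmax_cols_stable(scores)))  # same result
print(softmax_cols_stable(scores).sum(axis=0))                               # each column sums to 1

big = scores * 1000.0
with np.errstate(over="ignore", invalid="ignore"):
    naive = softmax_cols_naive(big)   # exp overflows to inf, and inf / inf gives nan
print(np.isnan(naive).any(), np.isfinite(softmax_cols_stable(big)).all())
```

The stable form is the one the final cell of this notebook uses.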

In [10]:
# Now let's compute self-attention in matrix form.
# Convention: X has shape (D, N), one input per column, as in the cells above.
def multihead_scaled_self_attention(X, omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1,
                                     omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2, omega_c):
    """
    Compute multihead scaled self-attention with 2 heads.

    Args:
        X: Input matrix of shape (D, N); each column is one input
        omega_v1, omega_q1, omega_k1: Weight matrices for head 1, each of shape (D/H, D)
        beta_v1, beta_q1, beta_k1: Bias vectors for head 1, each of shape (D/H, 1)
        omega_v2, omega_q2, omega_k2: Weight matrices for head 2, each of shape (D/H, D)
        beta_v2, beta_q2, beta_k2: Bias vectors for head 2, each of shape (D/H, 1)
        omega_c: Output matrix of shape (D, D) that combines the stacked heads

    Returns:
        X_prime: Output of multihead attention, shape (D, N)
    """

    # Query/key dimension (number of rows of omega_k1), used for scaling
    d_k = omega_k1.shape[0]

    print("Computing Head 1...")
    # Biases of shape (d_k, 1) broadcast across the N columns
    V1 = omega_v1 @ X + beta_v1
    Q1 = omega_q1 @ X + beta_q1
    K1 = omega_k1 @ X + beta_k1

    # Entry [m, n] is the dot product of key m with query n
    dot_products_1 = K1.T @ Q1

    scaled_dot_products_1 = dot_products_1 / np.sqrt(d_k)

    # Each column holds the attention weights for one query
    attention_weights_1 = softmax_cols(scaled_dot_products_1)

    head1_output = V1 @ attention_weights_1

    print("Computing Head 2...")
    V2 = omega_v2 @ X + beta_v2
    Q2 = omega_q2 @ X + beta_q2
    K2 = omega_k2 @ X + beta_k2

    dot_products_2 = K2.T @ Q2

    scaled_dot_products_2 = dot_products_2 / np.sqrt(d_k)

    attention_weights_2 = softmax_cols(scaled_dot_products_2)

    head2_output = V2 @ attention_weights_2

    print("Concatenating and combining heads...")
    # Stack the two (D/H, N) head outputs into a (D, N) matrix
    concatenated_output = np.concatenate([head1_output, head2_output], axis=0)

    X_prime = omega_c @ concatenated_output

    return X_prime
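The same computation can be written with inputs as rows instead of columns (an $N\times D$ matrix, as is common in deep learning libraries). In that convention every product is transposed: the weights become $D\times D/H$, the scores are $\mathbf{Q}\mathbf{K}^T$, the softmax runs along each row, and head outputs are concatenated along the last axis. Mixing the two conventions fails with a shape error, for example by multiplying the $D\times N$ matrix `X` on the right by a $D/H\times D$ weight matrix. The sketch below uses hypothetical helper names `head_cols` and `head_rows` and checks, for a single head, that the row-wise result is exactly the transpose of the column-wise one.

```python
import numpy as np

def head_cols(X, omega_q, omega_k, omega_v, beta_q, beta_k, beta_v):
    # Column convention: X is (D, N), weights (D/H, D), biases (D/H, 1)
    Q, K, V = omega_q @ X + beta_q, omega_k @ X + beta_k, omega_v @ X + beta_v
    S = K.T @ Q / np.sqrt(Q.shape[0])              # [m, n] = k_m . q_n
    A = np.exp(S - S.max(axis=0, keepdims=True))
    A /= A.sum(axis=0, keepdims=True)              # softmax down each column
    return V @ A                                   # (D/H, N)

def head_rows(X_rows, W_q, W_k, W_v, b_q, b_k, b_v):
    # Row convention: X_rows is (N, D), weights (D, D/H), biases (1, D/H)
    Q, K, V = X_rows @ W_q + b_q, X_rows @ W_k + b_k, X_rows @ W_v + b_v
    S = Q @ K.T / np.sqrt(Q.shape[1])              # [n, m] = q_n . k_m
    A = np.exp(S - S.max(axis=1, keepdims=True))
    A /= A.sum(axis=1, keepdims=True)              # softmax along each row
    return A @ V                                   # (N, D/H)

rng = np.random.default_rng(2)
D, N, H_D = 8, 6, 4
X_demo = rng.normal(size=(D, N))
params_demo = [rng.normal(size=(H_D, D)) for _ in range(3)] + \
              [rng.normal(size=(H_D, 1)) for _ in range(3)]

out_cols = head_cols(X_demo, *params_demo)
# Transpose the input and every parameter to switch conventions
out_rows = head_rows(X_demo.T, *[p.T for p in params_demo])
print(np.allclose(out_rows, out_cols.T))
```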

In [14]:
# Same mechanism as above, printing the shape of every intermediate result.
# Numerically stable softmax that normalizes each column independently
def softmax(x):
    e_x = np.exp(x - np.max(x, axis=0, keepdims=True))
    return e_x / e_x.sum(axis=0, keepdims=True)


def multihead_scaled_self_attention(X, omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1,
                                     omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2, omega_c):

    print("Computing Head 1...")
    print(f"X shape: {X.shape}")
    print(f"omega_v1 shape: {omega_v1.shape}")

    # X is (D, N), so the weights multiply from the left; biases broadcast over columns
    V1 = omega_v1 @ X + beta_v1
    Q1 = omega_q1 @ X + beta_q1
    K1 = omega_k1 @ X + beta_k1

    print(f"V1 shape: {V1.shape}, Q1 shape: {Q1.shape}, K1 shape: {K1.shape}")

    # Entry [m, n] is the dot product of key m with query n
    dot_products_1 = K1.T @ Q1
    print(f"dot_products_1 shape: {dot_products_1.shape}")

    # Query/key dimension is the number of rows
    d_k = K1.shape[0]
    scaled_dot_products_1 = dot_products_1 / np.sqrt(d_k)

    attention_weights_1 = softmax(scaled_dot_products_1)
    print(f"attention_weights_1 shape: {attention_weights_1.shape}")

    head1_output = V1 @ attention_weights_1
    print(f"head1_output shape: {head1_output.shape}")

    print("\nComputing Head 2...")
    V2 = omega_v2 @ X + beta_v2
    Q2 = omega_q2 @ X + beta_q2
    K2 = omega_k2 @ X + beta_k2

    print(f"V2 shape: {V2.shape}, Q2 shape: {Q2.shape}, K2 shape: {K2.shape}")

    dot_products_2 = K2.T @ Q2
    print(f"dot_products_2 shape: {dot_products_2.shape}")

    d_k = K2.shape[0]
    scaled_dot_products_2 = dot_products_2 / np.sqrt(d_k)

    attention_weights_2 = softmax(scaled_dot_products_2)
    print(f"attention_weights_2 shape: {attention_weights_2.shape}")

    head2_output = V2 @ attention_weights_2
    print(f"head2_output shape: {head2_output.shape}")

    print("\nConcatenating and combining heads...")
    # Stack the two (D/H, N) head outputs into a (D, N) matrix
    concatenated_output = np.concatenate([head1_output, head2_output], axis=0)
    print(f"concatenated_output shape: {concatenated_output.shape}")
    print(f"omega_c shape: {omega_c.shape}")

    X_prime = omega_c @ concatenated_output
    print(f"X_prime shape: {X_prime.shape}")

    return X_prime
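The two-head function spells out each head by hand. Below is a minimal sketch of the same mechanism for any number of heads $H$. It assumes the per-head parameters are kept in a list of dictionaries, a hypothetical layout that the notebook itself does not use (as are the helper names `multihead_self_attention` and `random_heads`). It runs for $H = 1, 2, 4$. The output is always $D\times N$ because the stacked head outputs have $H\cdot D/H = D$ rows.

```python
import numpy as np

def softmax_cols_stable(S):
    # Stable softmax applied to each column independently
    E = np.exp(S - S.max(axis=0, keepdims=True))
    return E / E.sum(axis=0, keepdims=True)

def multihead_self_attention(X, heads, omega_c):
    # X: (D, N); heads: list of H dicts with keys 'omega_q', 'omega_k', 'omega_v',
    # 'beta_q', 'beta_k', 'beta_v' (hypothetical layout); omega_c: (D, D)
    outputs = []
    for p in heads:
        Q = p["omega_q"] @ X + p["beta_q"]
        K = p["omega_k"] @ X + p["beta_k"]
        V = p["omega_v"] @ X + p["beta_v"]
        A = softmax_cols_stable(K.T @ Q / np.sqrt(Q.shape[0]))
        outputs.append(V @ A)                    # (D/H, N) per head
    return omega_c @ np.vstack(outputs)          # stack heads, then project

def random_heads(rng, D, n_heads):
    # Hypothetical helper: random parameters for n_heads heads of size D/n_heads
    H_D = D // n_heads
    return [{**{k: rng.normal(size=(H_D, D)) for k in ("omega_q", "omega_k", "omega_v")},
             **{k: rng.normal(size=(H_D, 1)) for k in ("beta_q", "beta_k", "beta_v")}}
            for _ in range(n_heads)]

rng = np.random.default_rng(3)
D, N = 8, 6
X_demo = rng.normal(size=(D, N))
omega_c_demo = rng.normal(size=(D, D))

for n_heads in (1, 2, 4):
    heads_demo = random_heads(rng, D, n_heads)
    print(n_heads, multihead_self_attention(X_demo, heads_demo, omega_c_demo).shape)  # (8, 6)
```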

In [15]:
import numpy as np

# Set seed so we get the same random numbers
np.random.seed(3)

# Number of inputs and dimensions
N = 6
D = 8

# Inputs matrix (each column is a token/vector)
X = np.random.normal(size=(D, N))
print(X)

# Number of heads
H = 2
# Per-head Q/K/V dimension
H_D = int(D / H)

# Set seed so we get the same random numbers for parameters
np.random.seed(0)

# Head 1 parameters
omega_q1 = np.random.normal(size=(H_D, D))
omega_k1 = np.random.normal(size=(H_D, D))
omega_v1 = np.random.normal(size=(H_D, D))
beta_q1  = np.random.normal(size=(H_D, 1))
beta_k1  = np.random.normal(size=(H_D, 1))
beta_v1  = np.random.normal(size=(H_D, 1))

# Head 2 parameters
omega_q2 = np.random.normal(size=(H_D, D))
omega_k2 = np.random.normal(size=(H_D, D))
omega_v2 = np.random.normal(size=(H_D, D))
beta_q2  = np.random.normal(size=(H_D, 1))
beta_k2  = np.random.normal(size=(H_D, 1))
beta_v2  = np.random.normal(size=(H_D, 1))

# Output projection
omega_c = np.random.normal(size=(D, D))

# Column-wise softmax (stable)
def softmax_cols(data_in):
    shifted = data_in - np.max(data_in, axis=0, keepdims=True)
    exp_values = np.exp(shifted)
    denom = np.sum(exp_values, axis=0, keepdims=True)
    return exp_values / denom

# Multihead scaled dot-product self-attention
def multihead_scaled_self_attention(
    X,
    omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1,
    omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2,
    omega_c
):
    D, N = X.shape
    H_D = omega_q1.shape[0]
    ones_row = np.ones((1, N))

    # ---- Head 1 ----
    Q1 = omega_q1 @ X + beta_q1 @ ones_row        # (H_D, N)
    K1 = omega_k1 @ X + beta_k1 @ ones_row        # (H_D, N)
    V1 = omega_v1 @ X + beta_v1 @ ones_row        # (H_D, N)

    scores1 = K1.T @ Q1                            # (N, N) with [m, n] = k_m^T q_n
    scores1 = scores1 * (1.0 / np.sqrt(H_D))       # scale by sqrt(d_k)
    A1 = softmax_cols(scores1)                     # (N, N)
    O1 = V1 @ A1                                   # (H_D, N)

    # ---- Head 2 ----
    Q2 = omega_q2 @ X + beta_q2 @ ones_row        # (H_D, N)
    K2 = omega_k2 @ X + beta_k2 @ ones_row        # (H_D, N)
    V2 = omega_v2 @ X + beta_v2 @ ones_row        # (H_D, N)

    scores2 = K2.T @ Q2                            # (N, N)
    scores2 = scores2 * (1.0 / np.sqrt(H_D))       # scale by sqrt(d_k)
    A2 = softmax_cols(scores2)                     # (N, N)
    O2 = V2 @ A2                                   # (H_D, N)

    # Concatenate head outputs and project
    O = np.vstack([O1, O2])                        # (D, N)
    X_prime = omega_c @ O                          # (D, N)

    return X_prime

# Run the self attention mechanism
X_prime = multihead_scaled_self_attention(
    X,
    omega_v1, omega_q1, omega_k1, beta_v1, beta_q1, beta_k1,
    omega_v2, omega_q2, omega_k2, beta_v2, beta_q2, beta_k2,
    omega_c
)

# Print out the results
np.set_printoptions(precision=3, suppress=True)
print("My answer:")
print(X_prime)

print("\nTrue values:")
print("[[-21.207  -5.373 -20.933  -9.179 -11.319 -17.812]")
print(" [ -1.995   7.906 -10.516   3.452   9.863  -7.24 ]")
print(" [  5.479   1.115   9.244   0.453   5.656   7.089]")
print(" [ -7.413  -7.416   0.363  -5.573  -6.736  -0.848]")
print(" [-11.261  -9.937  -4.848  -8.915 -13.378  -5.761]")
print(" [  3.548  10.036  -2.244   1.604  12.113  -2.557]")
print(" [  4.888  -5.814   2.407   3.228  -4.232   3.71 ]")
print(" [  1.248  18.894  -6.409   3.224  19.717  -5.629]]")


[[ 1.789  0.437  0.096 -1.863 -0.277 -0.355]
 [-0.083 -0.627 -0.044 -0.477 -1.314  0.885]
 [ 0.881  1.71   0.05  -0.405 -0.545 -1.546]
 [ 0.982 -1.101 -1.185 -0.206  1.486  0.237]
 [-1.024 -0.713  0.625 -0.161 -0.769 -0.23 ]
 [ 0.745  1.976 -1.244 -0.626 -0.804 -2.419]
 [-0.924 -1.024  1.124 -0.132 -1.623  0.647]
 [-0.356 -1.743 -0.597 -0.589 -0.874  0.03 ]]
My answer:
[[-21.207  -5.373 -20.933  -9.179 -11.319 -17.812]
 [ -1.995   7.906 -10.516   3.452   9.863  -7.24 ]
 [  5.479   1.115   9.244   0.453   5.656   7.089]
 [ -7.413  -7.416   0.363  -5.573  -6.736  -0.848]
 [-11.261  -9.937  -4.848  -8.915 -13.378  -5.761]
 [  3.548  10.036  -2.244   1.604  12.113  -2.557]
 [  4.888  -5.814   2.407   3.228  -4.232   3.71 ]
 [  1.248  18.894  -6.409   3.224  19.717  -5.629]]

True values:
[[-21.207  -5.373 -20.933  -9.179 -11.319 -17.812]
 [ -1.995   7.906 -10.516   3.452   9.863  -7.24 ]
 [  5.479   1.115   9.244   0.453   5.656   7.089]
 [ -7.413  -7.416   0.363  -5.573  -6.736  -0.848]
 