<a href="https://colab.research.google.com/github/warn4n/dl2025/blob/main/Notebooks/Chap12/12_1_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12.1: Self Attention**

This notebook builds a self-attention mechanism from scratch, as discussed in section 12.2 of the book.

Work through the cells below, running each cell in turn. In various places you will see the word "TODO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
import numpy as np
import matplotlib.pyplot as plt

The self-attention mechanism takes $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$.



In [2]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 3
# Number of dimensions of each input
D = 4
# Create an empty list
all_x = []
# Create elements x_n and append to list
for n in range(N):
  all_x.append(np.random.normal(size=(D,1)))
# Print out the list
print(all_x)


[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


We'll also need the weights and biases for the queries, keys, and values (equations 12.2 and 12.4).

In [3]:
# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters
omega_q = np.random.normal(size=(D,D))
omega_k = np.random.normal(size=(D,D))
omega_v = np.random.normal(size=(D,D))
beta_q = np.random.normal(size=(D,1))
beta_k = np.random.normal(size=(D,1))
beta_v = np.random.normal(size=(D,1))

Now let's compute the queries, keys, and values for each input.
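Written out, each input $\mathbf{x}_n$ passes through three separate affine maps, one each for the query, key, and value. Here each $\boldsymbol{\Omega}$ is a $D\times D$ matrix (`omega_q`, `omega_k`, `omega_v`) and each $\boldsymbol{\beta}$ is a $D\times 1$ vector (`beta_q`, `beta_k`, `beta_v`):

$$
\begin{aligned}
\mathbf{q}_n &= \boldsymbol{\beta}_q + \boldsymbol{\Omega}_q\mathbf{x}_n,\\
\mathbf{k}_n &= \boldsymbol{\beta}_k + \boldsymbol{\Omega}_k\mathbf{x}_n,\\
\mathbf{v}_n &= \boldsymbol{\beta}_v + \boldsymbol{\Omega}_v\mathbf{x}_n.
\end{aligned}
$$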

In [4]:
# Make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []
# For every input
for x in all_x:
  # TODO -- compute the keys, queries and values.
  # Replace these three lines
  query = omega_q @ x + beta_q
  key = omega_k @ x + beta_k
  value = omega_v @ x + beta_v


  all_queries.append(query)
  all_keys.append(key)
  all_values.append(value)

We'll need a softmax function (equation 12.5). It takes a list of arbitrary numbers and returns a list whose elements are non-negative and sum to one.


In [5]:
def softmax(items_in):

  # TODO Compute the elements of items_out
  # Replace this line
  items_out = np.exp(items_in) / np.sum(np.exp(items_in))

  return items_out

In [6]:
np.exp([0,1,2,3,4])

array([ 1.        ,  2.71828183,  7.3890561 , 20.08553692, 54.59815003])
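The output above shows how quickly the exponential grows. In float64, `np.exp` overflows to `inf` once its argument exceeds about 709, so the direct formula in `softmax` can return `nan` when the dot products are large. A standard remedy, sketched below under the hypothetical name `stable_softmax` (not part of the book's code), subtracts the largest input before exponentiating. The common factor $e^{-\max}$ cancels between numerator and denominator, so the result is mathematically unchanged.

```python
import numpy as np

def stable_softmax(items_in):
  # Convert to an array so plain lists and lists of small arrays are handled alike
  items_in = np.asarray(items_in, dtype=float)
  # Shift so the largest exponent is exp(0) = 1; this avoids overflow
  shifted = items_in - np.max(items_in)
  exp_values = np.exp(shifted)
  return exp_values / np.sum(exp_values)

# Agrees with the direct formula on small inputs
small = np.array([0.0, 1.0, 2.0])
print(stable_softmax(small))
print(np.exp(small) / np.sum(np.exp(small)))

# Still well defined where the direct formula would overflow
large = np.array([1000.0, 1001.0, 1002.0])
print(stable_softmax(large))
```

It also accepts the list of $1\times 1$ arrays built in the attention loop below, because `np.asarray` stacks them into an array of shape $(N, 1, 1)$.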

Now compute the self attention values:
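In symbols, the loop below computes each output as a weighted sum of the values, where the weights come from a softmax over the dot products of the query $\mathbf{q}_n$ with every key $\mathbf{k}_m$:

$$
\mathbf{x}'_n = \sum_{m=1}^{N} a[\mathbf{x}_m, \mathbf{x}_n]\,\mathbf{v}_m,
\qquad
a[\mathbf{x}_m, \mathbf{x}_n] = \frac{\exp\left[\mathbf{k}_m^{T}\mathbf{q}_n\right]}{\sum_{m'=1}^{N}\exp\left[\mathbf{k}_{m'}^{T}\mathbf{q}_n\right]}.
$$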

In [7]:
# Create empty list for output
all_x_prime = []

# For each output
for n in range(N):
  # Create list for dot products of query n with all keys
  all_km_qn = []
  # Compute the dot products
  for key in all_keys:
    # TODO -- compute the appropriate dot product
    # Replace this line
    dot_product = key.T @ all_queries[n]

    # Store dot product
    all_km_qn.append(dot_product)

  # Apply softmax over the keys to get the attention weights
  attention = softmax(all_km_qn)
  # Print result (should be positive and sum to one)
  print("Attentions for output ", n)
  print(attention)

  # TODO: Compute a weighted sum of all of the values according to the attention
  # (equation 12.3)
  # Replace this line
  x_prime = np.zeros((D,1))
  for i in range(N):
    x_prime += attention[i] * all_values[i]

  all_x_prime.append(x_prime)


# Print out true values to check you have it correct
print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")


Attentions for output  0
[[[1.24326146e-13]]

 [[9.98281489e-01]]

 [[1.71851130e-03]]]
Attentions for output  1
[[[2.79525306e-12]]

 [[5.85506360e-03]]

 [[9.94144936e-01]]]
Attentions for output  2
[[[0.00505708]]

 [[0.00654776]]

 [[0.98839516]]]
x_prime_0_calculated: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_calculated: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_calculated: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


Now let's compute the same thing using matrix calculations. We'll store the $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ in the columns of a $D\times N$ matrix $\mathbf{X}$, using equations 12.6, 12.7, and 12.8.
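Writing $\mathbf{1}$ for an $N\times 1$ vector of ones, so that $\boldsymbol{\beta}\mathbf{1}^{T}$ copies a bias into every column (NumPy broadcasting does this automatically when a $D\times 1$ array is added to a $D\times N$ one), the per-input formulas become

$$
\begin{aligned}
\mathbf{Q} &= \boldsymbol{\beta}_q\mathbf{1}^{T} + \boldsymbol{\Omega}_q\mathbf{X},\qquad
\mathbf{K} = \boldsymbol{\beta}_k\mathbf{1}^{T} + \boldsymbol{\Omega}_k\mathbf{X},\qquad
\mathbf{V} = \boldsymbol{\beta}_v\mathbf{1}^{T} + \boldsymbol{\Omega}_v\mathbf{X},\\
\mathbf{X}' &= \mathbf{V}\,\mathrm{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right],
\end{aligned}
$$

where Softmax acts independently on each column of the $N\times N$ matrix $\mathbf{K}^{T}\mathbf{Q}$. Entry $(m, n)$ of that matrix is $\mathbf{k}_m^{T}\mathbf{q}_n$, so after the softmax, column $n$ holds the attention weights for output $n$, and $\mathbf{V}$ times that column is the weighted sum of values.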

Note: The book uses column vectors (for compatibility with the rest of the text), but in the wider literature it is more common to store the inputs in the rows of a matrix. In that case the computation is the same, but all the matrices are transposed and the operations proceed in the reverse order.
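To make the note concrete, the sketch below computes the same attention both ways on its own random data (sizes and seed are arbitrary): with inputs as columns, as in this notebook, and with inputs as rows. The helper names `softmax_over_axis`, `attention_columns`, `attention_rows`, and `compare_conventions` are hypothetical. Because the same parameters are reused in both forms, they appear transposed and multiply from the right in the row form, the softmax runs along rows instead of columns, and the result is the transpose of the column-form output.

```python
import numpy as np

def softmax_over_axis(a, axis):
  # Softmax along one axis (max subtracted for numerical stability)
  e = np.exp(a - np.max(a, axis=axis, keepdims=True))
  return e / np.sum(e, axis=axis, keepdims=True)

def attention_columns(X, Wq, Wk, Wv, bq, bk, bv):
  # Inputs are the N columns of the D x N matrix X (the book's convention)
  Q, K, V = Wq @ X + bq, Wk @ X + bk, Wv @ X + bv
  # Entry (m, n) of K.T @ Q is k_m . q_n; normalize each column over m
  return V @ softmax_over_axis(K.T @ Q, axis=0)

def attention_rows(X_rows, Wq, Wk, Wv, bq, bk, bv):
  # Inputs are the N rows of the N x D matrix X_rows (the common convention)
  Q, K, V = X_rows @ Wq.T + bq.T, X_rows @ Wk.T + bk.T, X_rows @ Wv.T + bv.T
  # Entry (n, m) of Q @ K.T is q_n . k_m; normalize each row over m
  return softmax_over_axis(Q @ K.T, axis=1) @ V

def compare_conventions(D=4, N=3, seed=0):
  # Local variables only, so the notebook's global X, D, N are not overwritten
  rng = np.random.default_rng(seed)
  X = rng.normal(size=(D, N))
  Wq, Wk, Wv = (rng.normal(size=(D, D)) for _ in range(3))
  bq, bk, bv = (rng.normal(size=(D, 1)) for _ in range(3))
  out_cols = attention_columns(X, Wq, Wk, Wv, bq, bk, bv)
  out_rows = attention_rows(X.T, Wq, Wk, Wv, bq, bk, bv)
  # The row-convention output should be the transpose of the column-convention output
  return np.allclose(out_cols, out_rows.T)

print(compare_conventions())
```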

In [9]:
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # Exponentiate all of the values
  exp_values = np.exp(data_in) ;
  # log out exp_values
  print("exp_values")
  print(exp_values)
  # Sum over columns
  denom = np.sum(exp_values, axis = 0);
  # log out denom
  print("denom")
  print(denom)
  # Replicate denominator to N rows
  denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])
  # log out replicated denom
  print("replicated denom")
  print(denom)
  # Compute softmax
  softmax = exp_values / denom
  # return the answer
  return softmax

In [10]:
# Now let's compute self attention in matrix form
def self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Apply softmax to calculate attentions
  # 4. Weight values by attentions
  # Replace this line
  XQ = omega_q @ X + beta_q
  XK = omega_k @ X + beta_k
  XV = omega_v @ X + beta_v
  dots = np.matmul(XK.T,XQ)
  attentions = softmax_cols(dots)
  X_prime =  XV @ attentions
  print(attentions)



  return X_prime

In [11]:
# Copy data into matrix
X = np.zeros((D, N))
X[:,0] = np.squeeze(all_x[0])
X[:,1] = np.squeeze(all_x[1])
X[:,2] = np.squeeze(all_x[2])

# Run the self attention mechanism
X_prime = self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

exp_values
[[4.66985467e-14 2.21621546e-09 1.67867896e+00]
 [3.74967746e-01 4.64218522e+00 2.17350530e+00]
 [6.45495599e-04 7.88207480e+02 3.28094170e+02]]
denom
[3.75613241e-01 7.92849665e+02 3.31946354e+02]
replicated denom
[[3.75613241e-01 7.92849665e+02 3.31946354e+02]
 [3.75613241e-01 7.92849665e+02 3.31946354e+02]
 [3.75613241e-01 7.92849665e+02 3.31946354e+02]]
[[1.24326146e-13 2.79525306e-12 5.05707907e-03]
 [9.98281489e-01 5.85506360e-03 6.54776072e-03]
 [1.71851130e-03 9.94144936e-01 9.88395160e-01]]
[[ 0.94744244  1.64201168  1.61949281]
 [-0.24348429 -0.08470004 -0.06641533]
 [-0.91310441  4.02764044  3.96863308]
 [-0.44522983  2.18690791  2.15858316]]



If you did this correctly, the values should be the same as above.

TODO: Print out the attention matrix.

You will see that the values are quite extreme: in each column, one entry is very close to one and the others are very close to zero. Now we'll fix this problem by using scaled dot-product attention.
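Why does dividing by $\sqrt{D_q}$ (the query and key dimension) help? If the entries of a query and a key are independent with zero mean and unit variance, their dot product is a sum of $D_q$ such products, so its variance grows like $D_q$. Large dot products push the softmax towards the near one-hot columns seen above, and dividing by $\sqrt{D_q}$ brings the variance back to about one. The simulation below illustrates this under an idealized assumption (independent standard-normal queries and keys; trained networks need not behave this way). The function name `dot_product_variances` is hypothetical.

```python
import numpy as np

def dot_product_variances(dim, n_samples=5000, seed=1):
  rng = np.random.default_rng(seed)
  # Independent standard-normal queries and keys (an idealized assumption)
  q = rng.normal(size=(n_samples, dim))
  k = rng.normal(size=(n_samples, dim))
  # One dot product q_i . k_i per sample
  dots = np.sum(q * k, axis=1)
  # Variance before and after scaling by 1 / sqrt(dim)
  return dots.var(), (dots / np.sqrt(dim)).var()

for dim in [4, 64, 512]:
  raw_var, scaled_var = dot_product_variances(dim)
  print(f"D={dim:4d}  var(q.k)={raw_var:8.2f}  var(q.k/sqrt(D))={scaled_var:5.2f}")
```

The unscaled variance should come out close to $D$ for each row, while the scaled variance stays close to one.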

In [14]:
# Now let's compute scaled dot-product self attention in matrix form
def scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Scale the dot products as in equation 12.9
  # 4. Apply softmax to calculate attentions
  # 5. Weight values by attentions
  # Replace this line
  XQ = omega_q @ X + beta_q
  XK = omega_k @ X + beta_k
  XV = omega_v @ X + beta_v
  dots = np.matmul(XK.T, XQ)
  # Scale by the square root of the query/key dimension D_q (equation 12.9);
  # here D_q = D because omega_q and omega_k are D x D
  attentions = softmax_cols(dots / np.sqrt(XQ.shape[0]))
  # Print the attentions
  print("attentions")
  print(attentions)

  X_prime =  XV @ attentions

  return X_prime

In [15]:
# Run the self attention mechanism
X_prime = scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print("X_prime")
print(X_prime)

exp_values
[[2.16098465e-07 4.70766976e-05 1.29563844e+00]
 [6.12346100e-01 2.15457309e+00 1.47428128e+00]
 [2.54066054e-02 2.80750330e+01 1.81133699e+01]]
denom
[ 0.63775292 30.22965321 20.88328965]
replicated denom
[[ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]]
attentions
[[3.38843552e-07 1.55730194e-06 6.20418746e-02]
 [9.60161968e-01 7.12734969e-02 7.05962187e-02]
 [3.98376935e-02 9.28724946e-01 8.67361907e-01]]
X_prime
[[ 0.97411966  1.59622051  1.32638014]
 [-0.23738409 -0.09516106  0.13062402]
 [-0.72333202  3.70194096  3.02371664]
 [-0.34413007  2.01339538  1.6902419 ]]


TODO -- Investigate whether the self-attention mechanism is covariant with respect to permutation.
If it is, then permuting the columns of the input matrix $\mathbf{X}$ permutes the columns of the output matrix $\mathbf{X}'$ in exactly the same way.
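Before checking numerically, here is a short argument that the answer should be yes (up to floating-point rounding). Let $\mathbf{P}$ be an $N\times N$ permutation matrix, so $\mathbf{P}^{T}\mathbf{P} = \mathbf{P}\mathbf{P}^{T} = \mathbf{I}$, and each column of $\mathbf{P}$ sums to one, i.e. $\mathbf{1}^{T}\mathbf{P} = \mathbf{1}^{T}$ where $\mathbf{1}$ is an $N\times 1$ vector of ones. Let $\mathbf{Q}$, $\mathbf{K}$, $\mathbf{V}$ be the query, key, and value matrices for the unpermuted input $\mathbf{X}$. Replacing $\mathbf{X}$ by $\mathbf{X}\mathbf{P}$ gives

$$
\begin{aligned}
\boldsymbol{\beta}_q\mathbf{1}^{T} + \boldsymbol{\Omega}_q\mathbf{X}\mathbf{P} &= \left(\boldsymbol{\beta}_q\mathbf{1}^{T} + \boldsymbol{\Omega}_q\mathbf{X}\right)\mathbf{P} = \mathbf{Q}\mathbf{P}, \quad\text{and likewise } \mathbf{K}\mathbf{P} \text{ and } \mathbf{V}\mathbf{P},\\
(\mathbf{K}\mathbf{P})^{T}(\mathbf{Q}\mathbf{P}) &= \mathbf{P}^{T}\left(\mathbf{K}^{T}\mathbf{Q}\right)\mathbf{P},\\
\mathrm{Softmax}\left[\mathbf{P}^{T}\mathbf{K}^{T}\mathbf{Q}\mathbf{P}\right] &= \mathbf{P}^{T}\,\mathrm{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]\mathbf{P},\\
\mathbf{V}\mathbf{P}\,\mathbf{P}^{T}\,\mathrm{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]\mathbf{P} &= \mathbf{V}\,\mathrm{Softmax}\left[\mathbf{K}^{T}\mathbf{Q}\right]\mathbf{P} = \mathbf{X}'\mathbf{P}.
\end{aligned}
$$

The third line holds because the column-wise softmax does not care about the order of entries within a column (left-multiplying by $\mathbf{P}^{T}$ permutes rows) and treats each column independently (right-multiplying by $\mathbf{P}$ permutes columns). Dividing $\mathbf{K}^{T}\mathbf{Q}$ by $\sqrt{D_q}$ does not affect the argument, since scaling commutes with both permutations.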


In [17]:
X

array([[ 1.78862847, -0.2773882 , -0.04381817],
       [ 0.43650985, -0.35475898, -0.47721803],
       [ 0.09649747, -0.08274148, -1.31386475],
       [-1.8634927 , -0.62700068,  0.88462238]])

In [21]:
# Generate all n x n permutation matrices; the first one returned is the identity
import itertools

def generate_permutation_matrices(n):
  """Generates all n x n permutation matrices.

  Args:
    n: The dimension of the square matrix.

  Returns:
    A list of n x n NumPy arrays, where each array is a permutation matrix.
  """
  permutation_matrices = []
  # Get all permutations of the indices [0, 1, ..., n-1]
  for permutation in itertools.permutations(range(n)):
    # Create an identity matrix of size n x n
    matrix = np.eye(n)
    # Permute the columns of the identity matrix according to the permutation
    permuted_matrix = matrix[:, permutation]
    permutation_matrices.append(permuted_matrix)
  return permutation_matrices
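Because each matrix is built by reordering the columns of the identity, right-multiplying by it reorders columns: column $j$ of $\mathbf{X}\mathbf{P}$ is column `permutation[j]` of $\mathbf{X}$. The sketch below checks this for every $3\times 3$ permutation on an arbitrary test matrix (`X_test` is a hypothetical name chosen so the notebook's `X` is not overwritten), together with two properties that make permutation matrices convenient: $\mathbf{P}^{T}\mathbf{P} = \mathbf{I}$, so $\mathbf{P}^{T}$ undoes the permutation, and each column of $\mathbf{P}$ sums to one.

```python
import itertools
import numpy as np

# Arbitrary 4 x 3 test matrix with distinct entries
X_test = np.arange(12.0).reshape(4, 3)

for permutation in itertools.permutations(range(3)):
  # Same construction as generate_permutation_matrices: reorder the identity's columns
  P = np.eye(3)[:, permutation]
  # Right-multiplying by P moves column permutation[j] of X_test to position j
  assert np.array_equal(X_test @ P, X_test[:, list(permutation)])
  # P is orthogonal: its transpose is its inverse
  assert np.array_equal(P.T @ P, np.eye(3))
  # Each column of P sums to one, i.e. ones^T P = ones^T
  assert np.array_equal(np.ones(3) @ P, np.ones(3))

print("all 3x3 permutation checks passed")
```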



In [22]:
permutation_matrices_3x3 = generate_permutation_matrices(3)


In [23]:
permutation_matrices_3x3

[array([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]]),
 array([[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]]),
 array([[0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.]]),
 array([[0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]]),
 array([[0., 1., 0.],
        [0., 0., 1.],
        [1., 0., 0.]]),
 array([[0., 0., 1.],
        [0., 1., 0.],
        [1., 0., 0.]])]

In [24]:
permuted_Xs = [X @ P for P in permutation_matrices_3x3]

In [25]:
permuted_Xs

[array([[ 1.78862847, -0.2773882 , -0.04381817],
        [ 0.43650985, -0.35475898, -0.47721803],
        [ 0.09649747, -0.08274148, -1.31386475],
        [-1.8634927 , -0.62700068,  0.88462238]]),
 array([[ 1.78862847, -0.04381817, -0.2773882 ],
        [ 0.43650985, -0.47721803, -0.35475898],
        [ 0.09649747, -1.31386475, -0.08274148],
        [-1.8634927 ,  0.88462238, -0.62700068]]),
 array([[-0.2773882 ,  1.78862847, -0.04381817],
        [-0.35475898,  0.43650985, -0.47721803],
        [-0.08274148,  0.09649747, -1.31386475],
        [-0.62700068, -1.8634927 ,  0.88462238]]),
 array([[-0.2773882 , -0.04381817,  1.78862847],
        [-0.35475898, -0.47721803,  0.43650985],
        [-0.08274148, -1.31386475,  0.09649747],
        [-0.62700068,  0.88462238, -1.8634927 ]]),
 array([[-0.04381817,  1.78862847, -0.2773882 ],
        [-0.47721803,  0.43650985, -0.35475898],
        [-1.31386475,  0.09649747, -0.08274148],
        [ 0.88462238, -1.8634927 , -0.62700068]]),
 array([[-

In [26]:
[scaled_dot_product_self_attention(Y,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k) for Y in permuted_Xs]

exp_values
[[2.16098465e-07 4.70766976e-05 1.29563844e+00]
 [6.12346100e-01 2.15457309e+00 1.47428128e+00]
 [2.54066054e-02 2.80750330e+01 1.81133699e+01]]
denom
[ 0.63775292 30.22965321 20.88328965]
replicated denom
[[ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]]
attentions
[[3.38843552e-07 1.55730194e-06 6.20418746e-02]
 [9.60161968e-01 7.12734969e-02 7.05962187e-02]
 [3.98376935e-02 9.28724946e-01 8.67361907e-01]]
exp_values
[[2.16098465e-07 1.29563844e+00 4.70766976e-05]
 [2.54066054e-02 1.81133699e+01 2.80750330e+01]
 [6.12346100e-01 1.47428128e+00 2.15457309e+00]]
denom
[ 0.63775292 20.88328965 30.22965321]
replicated denom
[[ 0.63775292 20.88328965 30.22965321]
 [ 0.63775292 20.88328965 30.22965321]
 [ 0.63775292 20.88328965 30.22965321]]
attentions
[[3.38843552e-07 6.20418746e-02 1.55730194e-06]
 [3.98376935e-02 8.67361907e-01 9.28724946e-01]
 [9.60161968e-01 7.05962187e-02 7.12734969e-02]]
exp_values
[[2.154

[array([[ 0.97411966,  1.59622051,  1.32638014],
        [-0.23738409, -0.09516106,  0.13062402],
        [-0.72333202,  3.70194096,  3.02371664],
        [-0.34413007,  2.01339538,  1.6902419 ]]),
 array([[ 0.97411966,  1.32638014,  1.59622051],
        [-0.23738409,  0.13062402, -0.09516106],
        [-0.72333202,  3.02371664,  3.70194096],
        [-0.34413007,  1.6902419 ,  2.01339538]]),
 array([[ 1.59622051,  0.97411966,  1.32638014],
        [-0.09516106, -0.23738409,  0.13062402],
        [ 3.70194096, -0.72333202,  3.02371664],
        [ 2.01339538, -0.34413007,  1.6902419 ]]),
 array([[ 1.59622051,  1.32638014,  0.97411966],
        [-0.09516106,  0.13062402, -0.23738409],
        [ 3.70194096,  3.02371664, -0.72333202],
        [ 2.01339538,  1.6902419 , -0.34413007]]),
 array([[ 1.32638014,  0.97411966,  1.59622051],
        [ 0.13062402, -0.23738409, -0.09516106],
        [ 3.02371664, -0.72333202,  3.70194096],
        [ 1.6902419 , -0.34413007,  2.01339538]]),
 array([[ 

In [27]:
permuted_outputs = [scaled_dot_product_self_attention(Y,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k) for Y in permuted_Xs]

exp_values
[[2.16098465e-07 4.70766976e-05 1.29563844e+00]
 [6.12346100e-01 2.15457309e+00 1.47428128e+00]
 [2.54066054e-02 2.80750330e+01 1.81133699e+01]]
denom
[ 0.63775292 30.22965321 20.88328965]
replicated denom
[[ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]
 [ 0.63775292 30.22965321 20.88328965]]
attentions
[[3.38843552e-07 1.55730194e-06 6.20418746e-02]
 [9.60161968e-01 7.12734969e-02 7.05962187e-02]
 [3.98376935e-02 9.28724946e-01 8.67361907e-01]]
exp_values
[[2.16098465e-07 1.29563844e+00 4.70766976e-05]
 [2.54066054e-02 1.81133699e+01 2.80750330e+01]
 [6.12346100e-01 1.47428128e+00 2.15457309e+00]]
denom
[ 0.63775292 20.88328965 30.22965321]
replicated denom
[[ 0.63775292 20.88328965 30.22965321]
 [ 0.63775292 20.88328965 30.22965321]
 [ 0.63775292 20.88328965 30.22965321]]
attentions
[[3.38843552e-07 6.20418746e-02 1.55730194e-06]
 [3.98376935e-02 8.67361907e-01 9.28724946e-01]
 [9.60161968e-01 7.05962187e-02 7.12734969e-02]]
exp_values
[[2.154

In [28]:
permuted_outputs

[array([[ 0.97411966,  1.59622051,  1.32638014],
        [-0.23738409, -0.09516106,  0.13062402],
        [-0.72333202,  3.70194096,  3.02371664],
        [-0.34413007,  2.01339538,  1.6902419 ]]),
 array([[ 0.97411966,  1.32638014,  1.59622051],
        [-0.23738409,  0.13062402, -0.09516106],
        [-0.72333202,  3.02371664,  3.70194096],
        [-0.34413007,  1.6902419 ,  2.01339538]]),
 array([[ 1.59622051,  0.97411966,  1.32638014],
        [-0.09516106, -0.23738409,  0.13062402],
        [ 3.70194096, -0.72333202,  3.02371664],
        [ 2.01339538, -0.34413007,  1.6902419 ]]),
 array([[ 1.59622051,  1.32638014,  0.97411966],
        [-0.09516106,  0.13062402, -0.23738409],
        [ 3.70194096,  3.02371664, -0.72333202],
        [ 2.01339538,  1.6902419 , -0.34413007]]),
 array([[ 1.32638014,  0.97411966,  1.59622051],
        [ 0.13062402, -0.23738409, -0.09516106],
        [ 3.02371664, -0.72333202,  3.70194096],
        [ 1.6902419 , -0.34413007,  2.01339538]]),
 array([[ 

In [32]:
# prompt: Determine if all of the elements of permuted_outputs have the same columns but possibly in different orders within a numerical tolerance of 0.0001

import numpy as np
def matrices_are_column_permutation_within_tolerance(matrix1, matrix2, tolerance=1e-4):
  """
  Checks if two matrices have the same columns within a tolerance,
  allowing for different column orders.

  Args:
    matrix1: The first NumPy array.
    matrix2: The second NumPy array.
    tolerance: The numerical tolerance for comparison.

  Returns:
    True if the matrices have the same columns within the tolerance, False otherwise.
  """
  if matrix1.shape != matrix2.shape:
    return False

  # Create lists of columns for easier comparison
  cols1 = [matrix1[:, i] for i in range(matrix1.shape[1])]
  cols2 = [matrix2[:, i] for i in range(matrix2.shape[1])]

  # Check if every column in cols1 is close to a column in cols2
  for col1 in cols1:
    found_match = False
    for col2 in cols2:
      if np.allclose(col1, col2, atol=tolerance):
        found_match = True
        break
    if not found_match:
      return False

  # Check if every column in cols2 is close to a column in cols1 (redundant but robust)
  for col2 in cols2:
    found_match = False
    for col1 in cols1:
      if np.allclose(col2, col1, atol=tolerance):
        found_match = True
        break
    if not found_match:
      return False

  return True

# Check if all pairs of matrices in permuted_outputs have the same columns
all_same_columns = True
for i in range(len(permuted_outputs)):
  for j in range(i + 1, len(permuted_outputs)):
    if not matrices_are_column_permutation_within_tolerance(permuted_outputs[i], permuted_outputs[j], tolerance=0.0001):
      all_same_columns = False
      break
  if not all_same_columns:
    break

print(f"All permuted outputs have the same columns (possibly different order) within tolerance: {all_same_columns}")

# You can also check a specific pair
# print(f"Are the first two outputs column permutations: {matrices_are_column_permutation_within_tolerance(permuted_outputs[0], permuted_outputs[1], tolerance=0.0001)}")


All permuted outputs have the same columns (possibly different order) within tolerance: True
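The check above only confirms that the outputs share the same set of columns. Permutation covariance makes a sharper claim: if input $\mathbf{X}$ gives output $\mathbf{X}'$, then input $\mathbf{X}\mathbf{P}$ gives output exactly $\mathbf{X}'\mathbf{P}$, with the columns reordered in the same way as the input columns. The self-contained sketch below tests that directly. It re-implements scaled dot-product attention without the debug printing; the names `quiet_scaled_attention` and `check_permutation_covariance` are hypothetical, and the random parameters are arbitrary.

```python
import itertools
import numpy as np

def quiet_scaled_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
  # Scaled dot-product self-attention with inputs in the columns of X
  Q = omega_q @ X + beta_q
  K = omega_k @ X + beta_k
  V = omega_v @ X + beta_v
  scores = K.T @ Q / np.sqrt(Q.shape[0])
  # Column-wise softmax (max subtracted for stability)
  weights = np.exp(scores - scores.max(axis=0, keepdims=True))
  weights /= weights.sum(axis=0, keepdims=True)
  return V @ weights

def check_permutation_covariance(D=4, N=3, seed=0, atol=1e-10):
  rng = np.random.default_rng(seed)
  X = rng.normal(size=(D, N))
  # Three D x D weight matrices followed by three D x 1 biases (v, q, k order)
  params = [rng.normal(size=(D, D)) for _ in range(3)] + [rng.normal(size=(D, 1)) for _ in range(3)]
  X_prime = quiet_scaled_attention(X, *params)
  for permutation in itertools.permutations(range(N)):
    P = np.eye(N)[:, permutation]
    # Permuting the inputs should permute the outputs in the same way
    if not np.allclose(quiet_scaled_attention(X @ P, *params), X_prime @ P, atol=atol):
      return False
  return True

print(check_permutation_covariance())
```

In the notebook itself the same test is `np.allclose(scaled_dot_product_self_attention(X @ P, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k), X_prime @ P)` for each `P` in `permutation_matrices_3x3`, at the cost of the debug output.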


In [33]:
# prompt: For each pair of elements of permuted_outputs, print the difference of the closest (in L1 norm) pair of columns.

import numpy as np
def find_closest_columns_difference(matrix1, matrix2):
  """
  Finds the L1 norm difference between the closest pair of columns
  between two matrices.

  Args:
    matrix1: The first NumPy array.
    matrix2: The second NumPy array.

  Returns:
    The minimum L1 norm difference between any column in matrix1 and
    any column in matrix2.
  """
  min_diff = float('inf')
  for col1 in matrix1.T:  # Iterate over columns of the first matrix
    for col2 in matrix2.T:  # Iterate over columns of the second matrix
      diff = np.sum(np.abs(col1 - col2))  # Calculate L1 norm difference
      if diff < min_diff:
        min_diff = diff
  return min_diff

for i in range(len(permuted_outputs)):
    for j in range(i + 1, len(permuted_outputs)):
        diff = find_closest_columns_difference(permuted_outputs[i], permuted_outputs[j])
        print(f"Difference of closest columns between matrix {i} and matrix {j}: {diff}")


Difference of closest columns between matrix 0 and matrix 1: 1.942890293094024e-16
Difference of closest columns between matrix 0 and matrix 2: 0.0
Difference of closest columns between matrix 0 and matrix 3: 0.0
Difference of closest columns between matrix 0 and matrix 4: 1.942890293094024e-16
Difference of closest columns between matrix 0 and matrix 5: 2.220446049250313e-16
Difference of closest columns between matrix 1 and matrix 2: 1.942890293094024e-16
Difference of closest columns between matrix 1 and matrix 3: 1.942890293094024e-16
Difference of closest columns between matrix 1 and matrix 4: 0.0
Difference of closest columns between matrix 1 and matrix 5: 1.942890293094024e-16
Difference of closest columns between matrix 2 and matrix 3: 0.0
Difference of closest columns between matrix 2 and matrix 4: 1.942890293094024e-16
Difference of closest columns between matrix 2 and matrix 5: 0.0
Difference of closest columns between matrix 3 and matrix 4: 1.942890293094024e-16
Difference 