<a href="https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap12/12_1_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12.1: Self Attention**

This notebook builds a self-attention mechanism from scratch, as discussed in section 12.2 of the book.

Work through the cells below, running each cell in turn. In various places you will see the words "TO DO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [1]:
import numpy as np
import matplotlib.pyplot as plt
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

The self-attention mechanism takes $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$.



In [2]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 3
# Number of dimensions of each input
D = 4
# Create an empty list
all_x = []
# Create elements x_n and append to list
for n in range(N):
  all_x.append(np.random.normal(size=(D,1)))
# Print out the list
print(all_x)


[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


We'll also need the weights and biases for the keys, queries, and values (equations 12.2 and 12.4).

In [3]:
# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters
omega_q = np.random.normal(size=(D,D))
omega_k = np.random.normal(size=(D,D))
omega_v = np.random.normal(size=(D,D))
beta_q = np.random.normal(size=(D,1))
beta_k = np.random.normal(size=(D,1))
beta_v = np.random.normal(size=(D,1))

Now let's compute the queries, keys, and values for each input.

![attention 12.2](../../public/attention122.png)
![attention 12.4](../../public/attention124.png)
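For reference, these are the per-input formulas that the next cell implements (equations 12.2 and 12.4), written with symbols matching the code: `omega_*` are the weight matrices $\boldsymbol\Omega$ and `beta_*` the bias vectors $\boldsymbol\beta$:

$$
\begin{aligned}
\mathbf{v}_m &= \boldsymbol\beta_v + \boldsymbol\Omega_v\mathbf{x}_m,\\
\mathbf{q}_n &= \boldsymbol\beta_q + \boldsymbol\Omega_q\mathbf{x}_n,\\
\mathbf{k}_m &= \boldsymbol\beta_k + \boldsymbol\Omega_k\mathbf{x}_m.
\end{aligned}
$$

Stacking the inputs as the columns of a $D\times N$ matrix $\mathbf{X}$ gives the matrix form $\mathbf{Q} = \boldsymbol\beta_q\mathbf{1}^{T} + \boldsymbol\Omega_q\mathbf{X}$ (and similarly for the keys and values), which is what the `all_queries2`, `all_keys2`, and `all_values2` arrays compute.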

In [4]:
# Make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []
# Matrix method: copy the inputs into the columns of a D x N matrix
# (all_x itself is a list of N arrays, each of shape (D, 1))
all_x2 = np.zeros((D, N))
for idx, x in enumerate(all_x):
  all_x2[:, idx] = x[:,0]

# Each result is D x N; the (D, 1) biases broadcast across the N columns
all_queries2 = beta_q + np.dot(omega_q, all_x2)
all_keys2 = beta_k + np.dot(omega_k, all_x2)
all_values2 = beta_v + np.dot(omega_v, all_x2)

# Loop method: for every input
for x in all_x:
  # TODO -- compute the keys, queries and values.
  # Replace these three lines
  # Each result is a D x 1 column vector: weights times one input, plus bias
  query = beta_q + np.dot(omega_q, x)
  key = beta_k + np.dot(omega_k, x)
  value = beta_v + np.dot(omega_v, x)

  all_queries.append(query)
  all_keys.append(key)
  all_values.append(value)

aq = np.array([a[:,0] for a in all_queries]).T
ak = np.array([a[:,0] for a in all_keys]).T
av = np.array([a[:,0] for a in all_values]).T
print(f"The two methods differ only by floating-point rounding error, so exact comparison with np.array_equal or (==).all() returns False. Use np.allclose() (or np.isclose().all()) instead.\n{aq - all_queries2}\n{ak - all_keys2}\n{av - all_values2}")

assert np.allclose(aq, all_queries2)
assert np.allclose(ak, all_keys2)
assert np.allclose(av, all_values2)

The two methods differ only by floating-point rounding error, so exact comparison with np.array_equal or (==).all() returns False. Use np.allclose() (or np.isclose().all()) instead.
[[ 0.00000000e+00 -4.44089210e-16  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00 -1.11022302e-16  0.00000000e+00]]
[[ 0.00000000e+00  5.55111512e-17  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 8.88178420e-16 -1.04083409e-17  5.55111512e-17]
 [ 0.00000000e+00  2.22044605e-16  1.11022302e-16]]
[[ 4.44089210e-16  0.00000000e+00  2.22044605e-16]
 [ 0.00000000e+00  0.00000000e+00 -5.55111512e-17]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00]]


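We'll need a softmax function (equation 12.5). In general, softmax takes a list of arbitrary numbers and returns a list whose elements are non-negative and sum to one. The version below instead returns a single element of that list: given the key $\mathbf{k}_m$, the query $\mathbf{q}_n$, and the list `items_in` of all the dot products $\mathbf{k}_{m'}^{T}\mathbf{q}_n$, it returns the attention weight $a[\mathbf{x}_m,\mathbf{x}_n]$.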


![softmax 12.5](../../public/notebook125prince.png)

In [5]:
def softmax(k_m, q_n, items_in):

  # TODO Compute the elements of items_out
  # Replace this line
  # Numerator: exp(k_m . q_n); denominator: sum over all m' of exp(k_m' . q_n)
  items_out = np.exp(np.dot(k_m[:, 0], q_n[:, 0])) / np.sum(np.exp(items_in))

  return items_out
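For comparison, here is a minimal sketch of the list-in, list-out softmax described in the text. The name `softmax_list` is hypothetical, and subtracting the maximum is a standard numerical-stability trick rather than part of equation 12.5: the factor $e^{-\max}$ cancels between numerator and denominator, so the result is unchanged, but `np.exp` can no longer overflow.

```python
import numpy as np

def softmax_list(items_in):
    """Hypothetical helper: softmax over a 1D list of arbitrary numbers."""
    z = np.asarray(items_in, dtype=float)
    # Subtract the maximum before exponentiating; this cancels in the ratio
    exp_z = np.exp(z - np.max(z))
    return exp_z / np.sum(exp_z)

# Usage: the outputs are non-negative and sum to one
weights = softmax_list([1.0, 2.0, 3.0])
print(weights, weights.sum())

# Without the max subtraction, np.exp(1000.0) would overflow to inf
print(softmax_list([1000.0, 1000.0]))  # [0.5 0.5]
```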

Now compute the self attention values:

![softmax 12.5](../../public/notebook125prince.png)
![self attention 12.3](../../public/notebook123prince.png)
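Written out, the next cell computes the attention weights (equation 12.5) and then each output as a weighted sum of the values (equation 12.3), as we read them:

$$
\begin{aligned}
a[\mathbf{x}_m,\mathbf{x}_n] &= \operatorname{softmax}_m\!\left[\mathbf{k}_m^{T}\mathbf{q}_n\right] = \frac{\exp\left[\mathbf{k}_m^{T}\mathbf{q}_n\right]}{\sum_{m'=1}^{N}\exp\left[\mathbf{k}_{m'}^{T}\mathbf{q}_n\right]},\\
\mathbf{x}'_n = \mathbf{sa}_n[\mathbf{x}_1,\ldots,\mathbf{x}_N] &= \sum_{m=1}^{N} a[\mathbf{x}_m,\mathbf{x}_n]\,\mathbf{v}_m.
\end{aligned}
$$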

In [6]:
# Create empty list for output
all_x_prime = []

# For each output n, compute the dot products of query n with all keys
# Note: there are a lot of unnecessary loops here because we aren't vectorizing
for n in range(N):
  # Compute the dot products k_m . q_n for all m; these are the terms in the denominator of equation 12.5
  all_dots_km_over_given_qn = [np.dot(key[:, 0], all_queries[n][:, 0]) for key in all_keys]
  # Compute the attention weights a[x_m, x_n] for every m. Collecting them in a list
  # lets us sanity-check the softmax: the weights should be positive and sum to one
  attention = [softmax(all_keys[m], all_queries[n], all_dots_km_over_given_qn) for m in range(N)]
  # Print result (should be positive and sum to one)
  print("Attentions for output ", n)
  print(attention)
  print(f"Check that it sums to one: {np.sum(attention)}")
  # TODO: Compute a weighted sum of all of the values according to the attention
  # (equation 12.3)
  # Replace this line
  # Weight each value v_m by its attention weight a[x_m, x_n] and sum over m
  all_x_prime.append(sum(attention[m] * all_values[m] for m in range(N)))

# Print out true values to check you have it correct
print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")


Attentions for output  0
[1.2432614615724094e-13, 0.9982814887008525, 0.0017185112990231368]
Check that it sums to one: 1.0
Attentions for output  1
[2.7952530620087617e-12, 0.0058550635983758564, 0.994144936398829]
Check that it sums to one: 1.0
Attentions for output  2
[0.005057079072941125, 0.006547760717181933, 0.9883951602098769]
Check that it sums to one: 1.0
x_prime_0_calculated: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_calculated: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_calculated: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


Now let's compute the same thing, but using matrix calculations. We'll store the $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ in the columns of a $D\times N$ matrix, using equations 12.6, 12.7, and 12.8.

Note: The book uses column vectors (for compatibility with the rest of the text), but in the wider literature it is more common to store the inputs in the rows of a matrix; in this case, the computation is the same, but all the matrices are transposed and the operations proceed in the reverse order.

![12.6 and 12.7](../../public/attention126127prince.png)
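A minimal sketch of the note above, assuming the same parameter shapes as this notebook ($D\times D$ weights, $D\times 1$ biases). The function names and the `_demo` data are hypothetical, and the max subtraction inside each softmax is only for numerical stability. With the inputs in the rows, the scores become $\mathbf{Q}\mathbf{K}^{T}$ and the softmax runs along rows instead of columns; the output is the transpose of the column-convention output.

```python
import numpy as np

def stable_softmax_cols(scores):
    # Softmax applied independently to each column (max subtracted for stability)
    exp_scores = np.exp(scores - scores.max(axis=0, keepdims=True))
    return exp_scores / exp_scores.sum(axis=0, keepdims=True)

def self_attention_cols(X_in, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    # Book convention: the N inputs are the columns of a D x N matrix
    Q = beta_q + omega_q @ X_in
    K = beta_k + omega_k @ X_in
    V = beta_v + omega_v @ X_in
    # (K^T Q)[m, n] = k_m . q_n, so each column is normalized
    return V @ stable_softmax_cols(K.T @ Q)

def self_attention_rows(X_rows, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    # Common convention: the N inputs are the rows of an N x D matrix, so every
    # matrix is transposed and the products are taken in the reverse order
    Q = beta_q.T + X_rows @ omega_q.T
    K = beta_k.T + X_rows @ omega_k.T
    V = beta_v.T + X_rows @ omega_v.T
    # (Q K^T)[n, m] = q_n . k_m, so each row is normalized
    scores = Q @ K.T
    exp_scores = np.exp(scores - scores.max(axis=1, keepdims=True))
    attention = exp_scores / exp_scores.sum(axis=1, keepdims=True)
    return attention @ V

# Usage with hypothetical random data
rng = np.random.default_rng(0)
D_demo, N_demo = 4, 3
X_demo = rng.normal(size=(D_demo, N_demo))
params_demo = ([rng.normal(size=(D_demo, D_demo)) for _ in range(3)]   # omega_v, omega_q, omega_k
               + [rng.normal(size=(D_demo, 1)) for _ in range(3)])     # beta_v, beta_q, beta_k

out_cols = self_attention_cols(X_demo, *params_demo)     # D x N
out_rows = self_attention_rows(X_demo.T, *params_demo)   # N x D
print(np.allclose(out_cols, out_rows.T))                 # True
```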

In [7]:
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # data_in is the N x N matrix K^T Q (K and Q are both D x N)
  # Exponentiate all the values
  exp_values = np.exp(data_in)
  # Sum down each column, giving one sum per column
  denom = np.sum(exp_values, axis = 0)
  # Replicating the denominator to N rows with an outer product (commented out below)
  # is unnecessary: NumPy broadcasting divides each column by its own sum directly.
  # The outer product is presumably kept for pedagogical reasons
  # denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])
  # Compute softmax
  # Normalize column-wise: entry (m, n) of K^T Q is k_m . q_n, so column n holds the dot
  # products of query n with every key. Writing out equations 12.5 and 12.3 shows that
  # normalizing each column gives exactly the weights a[x_m, x_n]
  result = exp_values / denom
  # return the answer
  return result
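The comments in `softmax_cols` say the outer-product replication is not needed. A minimal sketch on a small hypothetical matrix (standing in for $\mathbf{K}^{T}\mathbf{Q}$) confirming that broadcasting and the explicit outer product give the same column-normalized result:

```python
import numpy as np

# Hypothetical 3 x 3 matrix of dot products (stands in for K^T Q)
scores_demo = np.array([[1.0, 0.0, 2.0],
                        [2.0, 1.0, 0.0],
                        [3.0, 1.0, 1.0]])

exp_demo = np.exp(scores_demo)
denom_demo = np.sum(exp_demo, axis=0)   # shape (3,): one sum per column

# Explicit version: outer product of a column of ones with the row of column sums
denom_outer = np.ones((scores_demo.shape[0], 1)) @ denom_demo[np.newaxis, :]   # shape (3, 3)

# Broadcasting version: the (3,) array is treated as shape (1, 3) and
# stretched down the rows, so column n is divided by denom_demo[n]
by_broadcast = exp_demo / denom_demo
by_outer = exp_demo / denom_outer

print(np.allclose(by_broadcast, by_outer))   # True
print(by_broadcast.sum(axis=0))              # each column sums to 1 (up to rounding)
```

Broadcasting only lines up this way because the sum was taken over `axis=0`. Normalizing rows instead would need `keepdims=True` (or `denom[:, np.newaxis]`); for a square matrix, forgetting this raises no error and silently divides by the wrong sums.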

In [8]:
# Now let's compute self attention in matrix form
def self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Apply softmax to calculate attentions
  # 4. Weight values by attentions
  # Replace this line
  Q = beta_q + np.dot(omega_q, X)
  K = beta_k + np.dot(omega_k, X)
  V = beta_v + np.dot(omega_v, X)
  attention = softmax_cols(np.dot(K.T, Q))
  X_prime = np.dot(V, attention)


  return X_prime, attention

In [9]:
# Copy data into matrix
X = np.zeros((D, N))
X[:,0] = np.squeeze(all_x[0])
X[:,1] = np.squeeze(all_x[1])
X[:,2] = np.squeeze(all_x[2])

# Run the self attention mechanism
X_prime, attention_matrix = self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime.T)
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")

[[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]
 [ 1.64201168 -0.08470004  4.02764044  2.18690791]
 [ 1.61949281 -0.06641533  3.96863308  2.15858316]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


If you did this correctly, the values should be the same as above.

TODO:  

Print out the attention matrix.
You will see that the values are quite extreme (one is very close to one and the others are very close to zero). Now we'll fix this problem by using scaled dot-product attention.

In [10]:
print(attention_matrix)

[[1.24326146e-13 2.79525306e-12 5.05707907e-03]
 [9.98281489e-01 5.85506360e-03 6.54776072e-03]
 [1.71851130e-03 9.94144936e-01 9.88395160e-01]]


<span style="color:green;white-space:pre-wrap">In every column of the attention matrix, one weight is much larger than the others, in some cases very close to 1. Because the softmax exponentiates the dot products, large gaps between them are magnified into extreme ratios between the weights, so the remaining weights are pushed almost to zero.</span>
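One common explanation for these extreme weights, sketched here under the assumption that the entries of the queries and keys are independent with zero mean and unit variance (this is not true of the notebook's actual queries and keys): each dot product $\mathbf{k}^{T}\mathbf{q}$ is then a sum of $D_q$ terms of variance one, so its variance is $D_q$. Wider spreads between dot products become near one-hot weights after exponentiation, and dividing by $\sqrt{D_q}$, as in equation 12.9, brings the variance back to one. The simulated data below is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
n_samples = 20_000
results = {}

for D_q in (4, 64, 256):
    # Hypothetical queries and keys with independent standard-normal entries
    q = rng.normal(size=(n_samples, D_q))
    k = rng.normal(size=(n_samples, D_q))
    dots = np.sum(q * k, axis=1)    # one dot product k.q per sample
    scaled = dots / np.sqrt(D_q)    # the rescaling used in equation 12.9
    results[D_q] = (dots.var(), scaled.var())
    print(f"D_q={D_q:4d}  var(k.q)={dots.var():8.2f}  var(k.q/sqrt(D_q))={scaled.var():.2f}")
```

The unscaled variance should track $D_q$, while the scaled variance should stay near one for every dimension.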

![12.9](../../public/attention129prince.png)

In [11]:
# Now let's compute scaled dot-product self attention in matrix form
def scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Scale the dot products as in equation 12.9
  # 4. Apply softmax to calculate attentions
  # 5. Weight values by attentions
  # Replace this line
  Q = beta_q + omega_q @ X
  K = beta_k + omega_k @ X
  V = beta_v + omega_v @ X
  # D_q is the dimension of the queries (and keys), i.e. the number of rows of Q
  D_q = Q.shape[0]
  attention = softmax_cols((K.T @ Q) / np.sqrt(D_q))
  X_prime = V @ attention

  return X_prime

In [12]:
# Run the self attention mechanism
X_prime = scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime.T)
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")

[[ 0.97411966 -0.23738409 -0.72333202 -0.34413007]
 [ 1.59622051 -0.09516106  3.70194096  2.01339538]
 [ 1.32638014  0.13062402  3.02371664  1.6902419 ]]
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


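<span style="color:green;white-space:pre-wrap">The values are expected to differ from the "true" values printed above, which are the unscaled results: dividing the dot products by sqrt(D_q) before the softmax makes the attention weights less extreme, so each output mixes the values more evenly.</span>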
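A minimal sketch to check this directly, using a hypothetical helper `attention_weights` that returns only the attention matrix, on hypothetical random data. Dividing the scores by $\sqrt{D_q} > 1$ acts like raising the softmax temperature: the largest weight in a column equals $1/\sum_m \exp\left[(s_m - s_{\max})/T\right]$, every exponent is at most zero and moves toward zero as $T$ grows, so the largest weight can only decrease or stay the same.

```python
import numpy as np

def softmax_cols_demo(scores):
    # Column-wise softmax (max subtracted for numerical stability)
    exp_scores = np.exp(scores - scores.max(axis=0, keepdims=True))
    return exp_scores / exp_scores.sum(axis=0, keepdims=True)

def attention_weights(X_in, omega_q, omega_k, beta_q, beta_k, scale=True):
    # Hypothetical helper: returns the N x N attention matrix
    # (column n holds the weights used for output n)
    Q = beta_q + omega_q @ X_in
    K = beta_k + omega_k @ X_in
    scores = K.T @ Q
    if scale:
        scores = scores / np.sqrt(Q.shape[0])
    return softmax_cols_demo(scores)

rng = np.random.default_rng(0)
D_demo, N_demo = 4, 3
X_demo = rng.normal(size=(D_demo, N_demo))
omega_q_demo, omega_k_demo = rng.normal(size=(2, D_demo, D_demo))
beta_q_demo, beta_k_demo = rng.normal(size=(2, D_demo, 1))

A_unscaled = attention_weights(X_demo, omega_q_demo, omega_k_demo, beta_q_demo, beta_k_demo, scale=False)
A_scaled = attention_weights(X_demo, omega_q_demo, omega_k_demo, beta_q_demo, beta_k_demo, scale=True)

print("largest weight per column, unscaled:", A_unscaled.max(axis=0))
print("largest weight per column, scaled:  ", A_scaled.max(axis=0))
```

In this notebook, `attention_weights(X, omega_q, omega_k, beta_q, beta_k, scale=False)` should reproduce the attention matrix printed earlier, and `scale=True` gives the weights behind the scaled outputs above.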

TODO -- Investigate whether the self-attention mechanism is covariant with respect to permutation.
If it is, when we permute the columns of the input matrix $\mathbf{X}$, the columns of the output matrix $\mathbf{X}'$ will also be permuted.


In [13]:
# A random permutation, e.g. X[:, np.random.permutation(X.shape[1])], is unreliable here:
# with only 3 columns it leaves the order unchanged with probability 1/6, so use a fixed swap
X_permuted = X[:, [1,0,2]]
# Run the self attention mechanism
X_prime, attention_matrix = self_attention(X_permuted,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime.T)
print("\nTrue answers for the original (non-permuted) matrix")
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")

[[ 1.64201168 -0.08470004  4.02764044  2.18690791]
 [ 0.94744244 -0.24348429 -0.91310441 -0.44522983]
 [ 1.61949281 -0.06641533  3.96863308  2.15858316]]

True answers for the original (non-permuted) matrix
x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]
x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]
x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]


<span style="color:green;white-space:pre-wrap">The permutation is reflected in the output: swapping the first and second columns of the input swaps the first and second columns of the output (shown here as rows, because we transposed the output to match the format of the true answers). This can also be shown mathematically by tracking what happens to each column of X as it passes through the queries, keys, values, and attention matrix.</span>
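A sketch of that argument, assuming $\mathbf{P}$ is an $N\times N$ permutation matrix (so $\mathbf{P}\mathbf{P}^{T}=\mathbf{I}$ and $\mathbf{1}^{T}\mathbf{P}=\mathbf{1}^{T}$), writing each bias as $\boldsymbol\beta\mathbf{1}^{T}$ added to every column, and letting $\mathbf{S}=\mathbf{K}^{T}\mathbf{Q}$ and $\mathbf{X}'=\mathbf{V}\operatorname{Softmax}[\mathbf{S}]$. Replacing $\mathbf{X}$ by $\mathbf{X}\mathbf{P}$ gives:

$$
\begin{aligned}
\mathbf{Q}_{\mathbf{P}} &= \boldsymbol\beta_q\mathbf{1}^{T} + \boldsymbol\Omega_q\mathbf{X}\mathbf{P} = \left(\boldsymbol\beta_q\mathbf{1}^{T} + \boldsymbol\Omega_q\mathbf{X}\right)\mathbf{P} = \mathbf{Q}\mathbf{P}, \quad\text{and likewise } \mathbf{K}_{\mathbf{P}}=\mathbf{K}\mathbf{P},\ \mathbf{V}_{\mathbf{P}}=\mathbf{V}\mathbf{P},\\
\mathbf{K}_{\mathbf{P}}^{T}\mathbf{Q}_{\mathbf{P}} &= \mathbf{P}^{T}\mathbf{K}^{T}\mathbf{Q}\mathbf{P} = \mathbf{P}^{T}\mathbf{S}\mathbf{P},\\
\operatorname{Softmax}\left[\mathbf{P}^{T}\mathbf{S}\mathbf{P}\right] &= \mathbf{P}^{T}\operatorname{Softmax}\left[\mathbf{S}\right]\mathbf{P},\\
\mathbf{V}_{\mathbf{P}}\operatorname{Softmax}\left[\mathbf{K}_{\mathbf{P}}^{T}\mathbf{Q}_{\mathbf{P}}\right] &= \mathbf{V}\mathbf{P}\mathbf{P}^{T}\operatorname{Softmax}\left[\mathbf{S}\right]\mathbf{P} = \mathbf{V}\operatorname{Softmax}\left[\mathbf{S}\right]\mathbf{P} = \mathbf{X}'\mathbf{P}.
\end{aligned}
$$

The third line holds because the column-wise softmax commutes with reordering the entries within each column (left-multiplying by $\mathbf{P}^{T}$) and with reordering the columns themselves (right-multiplying by $\mathbf{P}$). So permuting the columns of the input permutes the columns of the output in the same way.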