<a href="https://colab.research.google.com/github/udlbook/udlbook/blob/main/Notebooks/Chap12/12_1_Self_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Notebook 12.1: Self Attention**

This notebook builds a self-attention mechanism from scratch, as discussed in section 12.2 of the book.

Work through the cells below, running each cell in turn. In various places you will see the words "TODO". Follow the instructions at these places and make predictions about what is going to happen or write code to complete the functions.

Contact me at udlbookmail@gmail.com if you find any mistakes or have any suggestions.



In [11]:
import numpy as np
import matplotlib.pyplot as plt

The self-attention mechanism takes $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ and returns $N$ outputs $\mathbf{x}'_{n}\in \mathbb{R}^{D}$ of the same size.



In [45]:
# Set seed so we get the same random numbers
np.random.seed(3)
# Number of inputs
N = 3
# Number of dimensions of each input
D = 4
# Create an empty list
all_x = []
# Create elements x_n and append to list
for n in range(N):
  x = np.random.normal(size=(D,1))
  all_x.append(x)
  print(f"Input vector x_{n} (shape {x.shape}):")
  print(f"    {x.flatten()}")

print(f"\nOur complete input has {len(all_x)} vectors:")


Input vector x_0 (shape (4, 1)):
    [ 1.78862847  0.43650985  0.09649747 -1.8634927 ]
Input vector x_1 (shape (4, 1)):
    [-0.2773882  -0.35475898 -0.08274148 -0.62700068]
Input vector x_2 (shape (4, 1)):
    [-0.04381817 -0.47721803 -1.31386475  0.88462238]

Our complete input has 3 vectors:


We'll also need the weights and biases for the queries, keys, and values (equations 12.2 and 12.4).

In [46]:
# Set seed so we get the same random numbers
np.random.seed(0)

# Choose random values for the parameters
omega_q = np.random.normal(size=(D,D))
omega_k = np.random.normal(size=(D,D))
omega_v = np.random.normal(size=(D,D))
beta_q = np.random.normal(size=(D,1))
beta_k = np.random.normal(size=(D,1))
beta_v = np.random.normal(size=(D,1))

Now let's compute the queries, keys, and values for each input.

In [51]:
# Make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []

# For every input
for i, x in enumerate(all_x):
  print(f"\nTransforming input x_{i}:")
  print(f"    Original x_{i}: {x.flatten()}")

  # Compute the keys, queries and values
  query = beta_q + omega_q @ x
  key = beta_k + omega_k @ x
  value = beta_v + omega_v @ x

  print(f"        {beta_q.shape} + {omega_q.shape} * {x.shape}")
  print(f"       Query_{i} = β_q + Ω_q @ x_{i}")
  print(f"        Shape: {query.shape}, Values: {query.flatten()}")
  print(f"       Key_{i} = β_k + Ω_k @ x_{i}")
  print(f"        Shape: {key.shape}, Values: {key.flatten()}")
  print(f"       Value_{i} = β_v + Ω_v @ x_{i}")
  print(f"        Shape: {value.shape}, Values: {value.flatten()}")

  all_queries.append(query)
  all_keys.append(key)
  all_values.append(value)

print(f"\n  Summary after transformation:")
print(f"    We now have {len(all_queries)} queries, {len(all_keys)} keys, and {len(all_values)} values")
print(f"    Each query/key/value has shape: {all_queries[0].shape}")


Transforming input x_0:
    Original x_0: [ 1.78862847  0.43650985  0.09649747 -1.8634927 ]
        (4, 1) + (4, 4) * (4, 1)
       Query_0 = β_q + Ω_q @ x_0
        Shape: (4, 1), Values: [-2.36543342  3.07476988 -3.59698468  1.22226059]
       Key_0 = β_k + Ω_k @ x_0
        Shape: (4, 1), Values: [ 3.69380505 -3.9952365   3.7499519   3.12154293]
       Value_0 = β_v + Ω_v @ x_0
        Shape: (4, 1), Values: [-2.71096662  3.5538184  -6.92955207 -3.0352825 ]

Transforming input x_1:
    Original x_1: [-0.2773882  -0.35475898 -0.08274148 -0.62700068]
        (4, 1) + (4, 4) * (4, 1)
       Query_1 = β_q + Ω_q @ x_1
        Shape: (4, 1), Values: [-3.7312083  -0.36779138 -1.93624723 -0.11330562]
       Key_1 = β_k + Ω_k @ x_1
        Shape: (4, 1), Values: [-0.34284839 -0.31052676 -0.02825783 -0.76803995]
       Value_1 = β_v + Ω_v @ x_1
        Shape: (4, 1), Values: [ 0.94623971 -0.24375925 -0.92165993 -0.44978771]

Transforming input x_2:
    Original x_2: [-0.04381817 -0.47721803 

We'll need a softmax function (equation 12.5) -- here, it will take a list of arbitrary numbers and return a list where the elements are non-negative and sum to one


In [52]:
def softmax(items_in):
  print(f"       Softmax input: {items_in}")
  e_x = np.exp(items_in - np.max(items_in))
  items_out = e_x / e_x.sum()
  print(f"       Softmax output: {items_out} (sum = {items_out.sum():.6f})")
  return items_out
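
As a quick sanity check on the properties described above, here is a minimal sketch. The helper name `softmax_quiet` and the example scores are illustrative assumptions, not part of the original notebook. It checks that the outputs are non-negative and sum to one, and that subtracting the maximum before exponentiating leaves the result unchanged while guarding against overflow.

```python
import numpy as np

def softmax_quiet(items_in):
  # Same computation as the softmax cell, without the print statements
  # (hypothetical helper introduced for this illustration)
  items_in = np.asarray(items_in, dtype=float)
  e_x = np.exp(items_in - np.max(items_in))
  return e_x / e_x.sum()

# Arbitrary example scores (assumed values)
scores = np.array([2.0, -1.0, 0.5])
weights = softmax_quiet(scores)
print(weights, weights.sum())

# Adding the same constant to every score leaves the softmax unchanged
shifted = softmax_quiet(scores + 100.0)
print(np.allclose(weights, shifted))

# Scores this large would overflow np.exp without the max subtraction
large = softmax_quiet(np.array([1000.0, 1001.0, 1002.0]))
print(large)
```

The shift invariance is why the max subtraction is safe: it changes the numerator and denominator by the same factor.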

Now compute the self attention values:
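
In symbols, the loop in the next cell computes the following for each output index $n$. The notation $a[\mathbf{x}_m,\mathbf{x}_n]$ for the attention weights is an assumption intended to match the book, and the code indexes from $0$ rather than $1$:

$$
a[\mathbf{x}_m,\mathbf{x}_n] = \frac{\exp\left[\mathbf{k}_m^{T}\mathbf{q}_n\right]}{\sum_{m'=1}^{N}\exp\left[\mathbf{k}_{m'}^{T}\mathbf{q}_n\right]},
\qquad
\mathbf{x}'_n = \sum_{m=1}^{N} a[\mathbf{x}_m,\mathbf{x}_n]\,\mathbf{v}_m ,
$$

where $\mathbf{q}_n = \boldsymbol\beta_q + \boldsymbol\Omega_q\mathbf{x}_n$, $\mathbf{k}_m = \boldsymbol\beta_k + \boldsymbol\Omega_k\mathbf{x}_m$, and $\mathbf{v}_m = \boldsymbol\beta_v + \boldsymbol\Omega_v\mathbf{x}_m$, as computed in the earlier cell.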

In [55]:
# Create empty list for output
all_x_prime = []

# For each output
for n in range(N):
  print(f"\n  Computing output x'_{n} (the {n}-th attended representation):")
  print(f"    We'll see how much x'_{n} should 'pay attention' to each input")

  # Create list for dot products of query n with all keys
  all_km_qn = []
  print(f"\n       Step 1: Query_{n} meets all Keys (computing compatibility)")

  # Compute the dot products
  for m, key in enumerate(all_keys):
    print(f"        Query_{n} • Key_{m}:")
    print(f"            Query_{n}: {all_queries[n].flatten()}")
    print(f"            Key_{m}: {key.flatten()}")

    dot_product = key.T @ all_queries[n]
    scalar_dot = dot_product[0,0]  # Extract scalar from 1x1 matrix

    print(f"            Dot product: {scalar_dot:.6f}")
    print(f"            (This measures how 'compatible' query_{n} is with key_{m})")

    all_km_qn.append(scalar_dot)

  # Convert the list of dot products to a numpy array
  all_km_qn = np.array(all_km_qn)
  print(f"\n      All compatibility scores for output {n}: {all_km_qn}")

  # Compute attention
  print(f"\n      Step 2: Converting compatibility to attention weights (softmax)")
  attention = softmax(all_km_qn)

  print(f"       Final attention weights for output {n}: {attention}")
  print(f"       Interpretation:")
  for m in range(N):
    percentage = attention[m] * 100
    print(f"        Output {n} pays {percentage:.2f}% attention to input {m}")

  print(f"\n       Step 3: Creating the weighted combination of values")
  x_prime = np.zeros((D,1))
  print(f"        Starting with zero vector: {x_prime.flatten()}")

  for m in range(N):
    contribution = attention[m] * all_values[m]
    print(f"        + {attention[m]:.6f} × Value_{m}")
    print(f"          = {attention[m]:.6f} × {all_values[m].flatten()}")
    print(f"          = {contribution.flatten()}")
    x_prime += contribution
    print(f"        Running sum: {x_prime.flatten()}")

  print(f"\n       Final result x'_{n}: {x_prime.flatten()}")
  all_x_prime.append(x_prime)

print("\n" + "="*80)
print("  The Final Results")
print("="*80)

# Print out true values to check you have it correct
print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")


  Computing output x'_0 (the 0-th attended representation):
    We'll see how much x'_0 should 'pay attention' to each input

       Step 1: Query_0 meets all Keys (computing compatibility)
        Query_0 • Key_0:
            Query_0: [-2.36543342  3.07476988 -3.59698468  1.22226059]
            Key_0: [ 3.69380505 -3.9952365   3.7499519   3.12154293]
            Dot product: -30.695063
            (This measures how 'compatible' query_0 is with key_0)
        Query_0 • Key_1:
            Query_0: [-2.36543342  3.07476988 -3.59698468  1.22226059]
            Key_1: [-0.34284839 -0.31052676 -0.02825783 -0.76803995]
            Dot product: -0.980915
            (This measures how 'compatible' query_0 is with key_1)
        Query_0 • Key_2:
            Query_0: [-2.36543342  3.07476988 -3.59698468  1.22226059]
            Key_2: [-1.64524855 -3.17297146  0.34070328 -0.20908514]
            Dot product: -7.345492
            (This measures how 'compatible' query_0 is with key_2)

      

Now let's compute the same thing using matrix calculations (equations 12.6 and 12.7/8). We'll store the $N$ inputs $\mathbf{x}_{n}\in\mathbb{R}^{D}$ in the columns of a $D\times N$ matrix.

Note: The book uses column vectors (for compatibility with the rest of the text), but in the wider literature it is more common to store the inputs in the rows of a matrix. In that case, the computation is the same, but all the matrices are transposed and the operations proceed in the reverse order.
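
To make the note about conventions concrete, the sketch below computes self-attention both ways on random data and checks that the results are transposes of each other. The helper `softmax_along`, the random generator seed, and the sizes are assumptions made for this illustration; they are not part of the notebook.

```python
import numpy as np

def softmax_along(A, axis):
  # Softmax applied independently along the given axis
  # (hypothetical helper; subtracts the max for numerical stability)
  e = np.exp(A - A.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
D, N = 4, 3
X = rng.normal(size=(D, N))  # inputs stored in columns
omega_q, omega_k, omega_v = (rng.normal(size=(D, D)) for _ in range(3))
beta_q, beta_k, beta_v = (rng.normal(size=(D, 1)) for _ in range(3))

# Column convention (as in this notebook): softmax over each column of K^T Q
Q = beta_q + omega_q @ X  # broadcasting adds the bias to every column
K = beta_k + omega_k @ X
V = beta_v + omega_v @ X
X_prime_cols = V @ softmax_along(K.T @ Q, axis=0)

# Row convention: inputs stored in rows, every matrix transposed,
# and the products taken in the reverse order
X_rows = X.T
Q_r = X_rows @ omega_q.T + beta_q.T
K_r = X_rows @ omega_k.T + beta_k.T
V_r = X_rows @ omega_v.T + beta_v.T
X_prime_rows = softmax_along(Q_r @ K_r.T, axis=1) @ V_r

# The two results should agree up to a transpose
print(np.allclose(X_prime_cols, X_prime_rows.T))
```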

In [56]:
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # Exponentiate all of the values
  exp_values = np.exp(data_in)
  # Sum the exponentials down each column (one denominator per column)
  denom = np.sum(exp_values, axis=0)
  # Replicate the denominators across all rows so the shape matches exp_values
  denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])
  # Compute softmax
  softmax = exp_values / denom
  # Return the answer
  return softmax

In [64]:
# Now let's compute self attention in matrix form
def self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Apply softmax to calculate attentions
  # 4. Weight values by attentions
  # Number of inputs, taken from X so the function does not depend on the global N
  N = X.shape[1]

  # Queries, Keys and Values
  Q = beta_q @ np.ones((1,N)) + omega_q @ X
  K = beta_k @ np.ones((1,N)) + omega_k @ X
  V = beta_v @ np.ones((1,N)) + omega_v @ X

  # Dot product
  dot_product = K.T @ Q

  # Apply softmax
  attention = softmax_cols(dot_product)

  # Weight values
  X_prime = V @ attention

  return X_prime

In [65]:
# Copy data into matrix
X = np.zeros((D, N))
X[:,0] = np.squeeze(all_x[0])
X[:,1] = np.squeeze(all_x[1])
X[:,2] = np.squeeze(all_x[2])

# Run the self attention mechanism
X_prime = self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
print(X_prime)

[[ 0.94744244  1.64201168  1.61949281]
 [-0.24348429 -0.08470004 -0.06641533]
 [-0.91310441  4.02764044  3.96863308]
 [-0.44522983  2.18690791  2.15858316]]


If you did this correctly, the values should be the same as above.

TODO: Print out the attention matrix.

You will see that the values are quite extreme (one is very close to one and the others are very close to zero). Now we'll fix this problem by using scaled dot-product attention.

In [63]:
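# Note: `attention` is local to self_attention, so this prints the vector left
# over from the loop version above (presumably the weights for the last output)
print(attention)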

[0.00505708 0.00654776 0.98839516]
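
To see what the scaling does, the sketch below applies softmax to the three dot products printed earlier for query 0, with and without division by $\sqrt{D_q}$. The helper name `softmax_1d` is an assumption for this illustration; the dot products are copied from the loop output, and $D_q = 4$ matches the value of `D` used in this notebook.

```python
import numpy as np

def softmax_1d(z):
  # Numerically stable softmax over a 1D array (hypothetical helper)
  e = np.exp(z - np.max(z))
  return e / e.sum()

# Dot products of query 0 with keys 0, 1, and 2, copied from the loop output
dots = np.array([-30.695063, -0.980915, -7.345492])
D_q = 4  # dimension of the queries and keys in this notebook

unscaled = softmax_1d(dots)
scaled = softmax_1d(dots / np.sqrt(D_q))

print("unscaled:", np.round(unscaled, 4))
print("scaled:  ", np.round(scaled, 4))
```

Dividing by $\sqrt{D_q}$ shrinks the gaps between the scores, so the largest weight moves away from one and the smaller weights grow. In this example the same key still dominates.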


In [66]:
# Now let's compute self attention in matrix form
def scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):

  # TODO -- Write this function
  # 1. Compute queries, keys, and values
  # 2. Compute dot products
  # 3. Scale the dot products as in equation 12.9
  # 4. Apply softmax to calculate attentions
  # 5. Weight values by attentions
  # Number of inputs, taken from X so the function does not depend on the global N
  N = X.shape[1]

  # Compute Queries, Keys and Values
  Q = beta_q @ np.ones((1, N)) + omega_q @ X
  K = beta_k @ np.ones((1, N)) + omega_k @ X
  V = beta_v @ np.ones((1, N)) + omega_v @ X

  # Compute dot product
  dot_product = K.T @ Q

  # Scale the dot product
  scaled_dot_product = dot_product / np.sqrt(omega_q.shape[0])

  # Apply softmax to calculate attentions
  attention = softmax_cols(scaled_dot_product)

  # Multiply Value matrix with attention
  X_prime = V @ attention

  return X_prime

In [67]:
# Run the self attention mechanism
X_prime = scaled_dot_product_self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results; with scaling, these differ from the unscaled output above
print(X_prime)


TODO -- Investigate whether the self-attention mechanism is covariant with respect to permutation.
If it is, when we permute the columns of the input matrix $\mathbf{X}$, the columns of the output matrix $\mathbf{X}'$ will also be permuted.
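
One way to investigate this numerically is sketched below. This property is often called permutation equivariance. The function name `self_attention_sketch`, the random seed, and the sizes are assumptions for this illustration; the function is a compact stand-in for the scaled dot-product routine above, not the notebook's own implementation.

```python
import numpy as np

def self_attention_sketch(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
  # Compact scaled dot-product self-attention, inputs in the columns of X
  Q = beta_q + omega_q @ X
  K = beta_k + omega_k @ X
  V = beta_v + omega_v @ X
  scores = (K.T @ Q) / np.sqrt(Q.shape[0])
  # Column-wise softmax, with the max subtracted for numerical stability
  scores = scores - scores.max(axis=0, keepdims=True)
  attention = np.exp(scores)
  attention = attention / attention.sum(axis=0, keepdims=True)
  return V @ attention

rng = np.random.default_rng(1)
D, N = 4, 5
X = rng.normal(size=(D, N))
omega_v, omega_q, omega_k = (rng.normal(size=(D, D)) for _ in range(3))
beta_v, beta_q, beta_k = (rng.normal(size=(D, 1)) for _ in range(3))
params = (omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Random reordering of the N input columns
perm = rng.permutation(N)

out = self_attention_sketch(X, *params)
out_perm = self_attention_sketch(X[:, perm], *params)

# If the mechanism is equivariant, permuting the inputs permutes
# the outputs in the same way
print(np.allclose(out[:, perm], out_perm))
```

Permuting the columns of $\mathbf{X}$ permutes the columns of the queries, keys, and values identically, which permutes both the rows and the columns of the attention matrix; the product with the values then comes out with its columns permuted in the same way.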
