In [34]:
import numpy as np
import matplotlib.pyplot as plt

In [35]:
# set seed so we get the same random numbers
np.random.seed(3)
# number of inputs
N = 3
# number of dimensions of each input
D = 4
# create an empty list
all_x = []
# create elements in x_n and append to list
for n in range(N):
    all_x.append(np.random.normal(size = (D, 1)))

# print out the list
print(all_x)

[array([[ 1.78862847],
       [ 0.43650985],
       [ 0.09649747],
       [-1.8634927 ]]), array([[-0.2773882 ],
       [-0.35475898],
       [-0.08274148],
       [-0.62700068]]), array([[-0.04381817],
       [-0.47721803],
       [-1.31386475],
       [ 0.88462238]])]


In [36]:
# choose weights and biases for the keys, queries, and values
np.random.seed(0)

# choose random values for the parameters 
omega_q = np.random.normal(size = (D, D))
omega_k = np.random.normal(size = (D, D))
omega_v = np.random.normal(size = (D, D))
beta_q = np.random.normal(size = (D, 1))
beta_k = np.random.normal(size = (D, 1))
beta_v = np.random.normal(size = (D, 1))

In [37]:
# make three lists to store queries, keys, and values
all_queries = []
all_keys = []
all_values = []

# for every input
for x in all_x:
    query = beta_q + np.dot(omega_q, x)
    key = beta_k + np.dot(omega_k, x)
    value = beta_v + np.dot(omega_v, x)

    all_queries.append(query)
    all_keys.append(key)
    all_values.append(value)

In [22]:
# need a softmax 
def softmax(dot_products):
    res = np.exp(dot_products) / np.sum(np.exp(dot_products))
    return res

# create empty list for output
all_x_prime = []

# for each output
for n in range(N):
    # create a list for dot products of query N with all keys
    all_km_qn = []
    for key in all_keys:
        dot_product = np.dot(np.transpose(key), query[n])

        all_km_qn.append(dot_product)
    
    # compute dot product attention
    attention = softmax(all_km_qn)

    # print result (should be positive)
    print("Attentions for output ", n)
    print(attention)

    # compute a weighted sum of all the values according to the attention
    x_prime_sum = 0
    for n in range(N):
        x_prime_sum += np.dot(attention, all_values[n])
        
    all_x_prime.append(x_prime_sum)

# print out the true values to check you have it correct
print("x_prime_0_calculated:", all_x_prime[0].transpose())
print("x_prime_0_true: [[ 0.94744244 -0.24348429 -0.91310441 -0.44522983]]")
print("x_prime_1_calculated:", all_x_prime[1].transpose())
print("x_prime_1_true: [[ 1.64201168 -0.08470004  4.02764044  2.18690791]]")
print("x_prime_2_calculated:", all_x_prime[2].transpose())
print("x_prime_2_true: [[ 1.61949281 -0.06641533  3.96863308  2.15858316]]")

ValueError: shapes (1,4,3) and (1,) not aligned: 3 (dim 2) != 1 (dim 0)

In [42]:
# try the matrix version provided by the writer 
# Define softmax operation that works independently on each column
def softmax_cols(data_in):
  # Exponentiate all of the values
  exp_values = np.exp(data_in)
  # Sum over columns
  denom = np.sum(exp_values, axis = 0)
  # Replicate denominator to N rows
  denom = np.matmul(np.ones((data_in.shape[0],1)), denom[np.newaxis,:])
  # Compute softmax
  softmax = exp_values / denom
  # return the answer
  return softmax

In [49]:
# compute self-attention in matrix form 
def self_attention(X, omega_v, omega_q, omega_k, beta_v, beta_q, beta_k):
    one_T = np.transpose(np.ones((N, 1)))
    queries = np.add(np.dot(beta_q, one_T), np.dot(omega_q, X))
    keys = np.add(np.dot(beta_k, one_T), np.dot(omega_k, X))
    values = np.add(np.dot(beta_v, one_T), np.dot(omega_v, X))
    dot_prod = np.dot(np.transpose(keys), queries)
    softmax_values = softmax_cols(dot_prod)
    X_prime = np.dot(values, softmax_values)

    return X_prime

# Copy data into matrix
X = np.zeros((D, N))
X[:,0] = np.squeeze(all_x[0])
X[:,1] = np.squeeze(all_x[1])
X[:,2] = np.squeeze(all_x[2])

# Run the self attention mechanism
X_prime = self_attention(X,omega_v, omega_q, omega_k, beta_v, beta_q, beta_k)

# Print out the results
np.set_printoptions(suppress=True)
print(X_prime)

[[ 0.94744244  1.64201168  1.61949281]
 [-0.24348429 -0.08470004 -0.06641533]
 [-0.91310441  4.02764044  3.96863308]
 [-0.44522983  2.18690791  2.15858316]]


In [29]:
wan_T = np.transpose(np.ones((N, 1)))
print(wan_T.shape)

(1, 3)
