In [1]:
import mlx.core as mx

# Attending to different parts of the input with self-attention

In [2]:
# Hand-picked 3-dimensional vectors standing in for the embeddings of the six
# tokens of "Your journey starts with one step"; row i holds x^(i+1).
inputs = mx.array([
    [0.43, 0.15, 0.89],  # x^1  "Your"
    [0.55, 0.87, 0.66],  # x^2  "journey"
    [0.57, 0.85, 0.64],  # x^3  "starts"
    [0.22, 0.58, 0.33],  # x^4  "with"
    [0.77, 0.25, 0.10],  # x^5  "one"
    [0.05, 0.80, 0.55],  # x^6  "step"
])

In [10]:
# Unnormalized attention scores for the query x^2 ("journey"): the dot
# product of every input row with the query vector.
query = inputs[1]
attn_scores_2 = mx.zeros(inputs.shape[0])
for i in range(inputs.shape[0]):
    x_i = inputs[i]
    attn_scores_2[i] = mx.matmul(x_i, query)
attn_scores_2

array([0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865], dtype=float32)

In [11]:
# Dot product of x^1 and the query x^2 written out element by element, to
# show that mx.matmul on two 1-D vectors is exactly this sum of products.
res = 0.
for x_0k, q_k in zip(inputs[0], query):
    res += x_0k * q_k
# Inline check instead of eyeballing the two printed values.
assert mx.allclose(res, mx.matmul(inputs[0], query)).item()
res, mx.matmul(inputs[0], query)

(array(0.9544, dtype=float32), array(0.9544, dtype=float32))

In [12]:
# Baseline normalization: divide each score by the total so the weights sum
# to 1 (the softmax in the next cell is the version used afterwards).
attn_weights_2_tmp = mx.divide(attn_scores_2, mx.sum(attn_scores_2))
print("attention weights:", attn_weights_2_tmp)
print("sum:", mx.sum(attn_weights_2_tmp))

attention weights: array([0.14545, 0.227837, 0.22485, 0.128534, 0.107746, 0.165582], dtype=float32)
sum: array(1, dtype=float32)


In [None]:
def softmax_naive(x, axis=0):
    """Softmax computed directly as exp(x) / sum(exp(x)).

    Deliberately "naive": there is no max-subtraction, so sufficiently large
    inputs overflow exp() to inf and the result becomes NaN. Prefer
    mx.softmax outside of this demonstration.

    Args:
        x: input array.
        axis: axis to normalize over. Defaults to 0, which matches the
            original behaviour for a 1-D score vector; pass -1 to normalize
            the rows of a 2-D score matrix.

    Returns:
        Array with the same shape as ``x`` whose entries along ``axis``
        sum to 1.
    """
    exps = mx.exp(x)  # computed once, reused for numerator and denominator
    # keepdims=True keeps the reduced axis so the denominator broadcasts
    # against ``exps`` for any choice of axis.
    return exps / exps.sum(axis=axis, keepdims=True)

# Compare the hand-written softmax with MLX's built-in one: both weight
# vectors should match and each should sum to 1.
attn_weights_2_naive = softmax_naive(attn_scores_2)
attn_weights_2 = mx.softmax(attn_scores_2, axis=0)

print("attention weights (naive):", attn_weights_2_naive)
print("sum:", attn_weights_2_naive.sum())
print("attention weights (mlx):", attn_weights_2)
print("sum:", attn_weights_2.sum())

attention weights (naive): array([0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114], dtype=float32)
sum: array(1, dtype=float32)
attention weights (mlx): array([0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114], dtype=float32)
sum: array(1, dtype=float32)


In [15]:
# Context vector z^2: the attention-weighted sum of all input vectors,
# using the softmax weights for the query x^2.
query = inputs[1]
context_vec_2 = mx.zeros(query.shape)
for i in range(inputs.shape[0]):
    x_i = inputs[i]
    context_vec_2 = context_vec_2 + attn_weights_2[i] * x_i
context_vec_2

array([0.441866, 0.651482, 0.568309], dtype=float32)

In [17]:
# Full pairwise score matrix via explicit loops: entry (i, j) is the dot
# product of input rows i and j. The size comes from `inputs` rather than a
# hard-coded 6, so the cell keeps working if the toy sentence changes length.
attn_scores = mx.zeros((inputs.shape[0], inputs.shape[0]))
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = mx.matmul(x_i, x_j)
attn_scores

array([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.631],
       [0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865],
       [0.9422, 1.4754, 1.457, 0.8296, 0.7154, 1.0605],
       [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
       [0.4576, 0.707, 0.7154, 0.3474, 0.6654, 0.2935],
       [0.631, 1.0865, 1.0605, 0.6565, 0.2935, 0.945]], dtype=float32)

In [19]:
# Same score matrix in a single matrix product: (inputs · inputs^T)[i, j]
# is the dot product of rows i and j.
attn_scores = mx.matmul(inputs, inputs.T)
attn_scores

array([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.631],
       [0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865],
       [0.9422, 1.4754, 1.457, 0.8296, 0.7154, 1.0605],
       [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
       [0.4576, 0.707, 0.7154, 0.3474, 0.6654, 0.2935],
       [0.631, 1.0865, 1.0605, 0.6565, 0.2935, 0.945]], dtype=float32)

In [21]:
# Row-wise softmax over the 2-D score matrix: row i holds the weights that
# query x^(i+1) assigns to every input token.
attn_weights = mx.softmax(attn_scores, axis=1)
attn_weights

array([[0.209835, 0.200581, 0.198149, 0.124228, 0.122049, 0.145158],
       [0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114],
       [0.139008, 0.236921, 0.232602, 0.124204, 0.1108, 0.156464],
       [0.143527, 0.207394, 0.204552, 0.146192, 0.126295, 0.172039],
       [0.152611, 0.195839, 0.197491, 0.136687, 0.187859, 0.129514],
       [0.138471, 0.218364, 0.212759, 0.142048, 0.0988064, 0.189552]], dtype=float32)

In [28]:
# Sanity check that every softmax row sums to 1. Row index 1 is the
# "journey" query used earlier; it is read from attn_weights rather than
# re-typed from printed output, which would silently go stale on re-run.
row_2_sum = attn_weights[1].sum()
print('row 2 sum:', row_2_sum)
print('all row sums:', attn_weights.sum(axis=-1))

row 2 sum: 1.000001
all row sums: array([1, 1, 1, 1, 1, 1], dtype=float32)


In [29]:
# All context vectors at once: row i is the sum of the input rows weighted
# by row i of attn_weights.
all_context_vecs = mx.matmul(attn_weights, inputs)
all_context_vecs

array([[0.442059, 0.593099, 0.578989],
       [0.441866, 0.651482, 0.568309],
       [0.443128, 0.649595, 0.567073],
       [0.43039, 0.629828, 0.551027],
       [0.467102, 0.590993, 0.526596],
       [0.417724, 0.650323, 0.564535]], dtype=float32)

# Implementing self-attention with trainable weights

In [31]:
import mlx.nn as nn

In [30]:
# Running example is the second token; project from the input width down to
# a 2-dimensional query/key/value space.
x_2 = inputs[1]
d_in = inputs.shape[-1]
d_out = 2

In [37]:
# Fixed seed so the draws are reproducible. The three (d_in, d_out) matrices
# are drawn in query, key, value order using nn.init.uniform's default range.
mx.random.seed(123)
W_query, W_key, W_value = (
    nn.init.uniform()(mx.zeros((d_in, d_out))) for _ in range(3)
)

In [38]:
# Project the single token x^2 with each of the three weight matrices.
query_2, key_2, value_2 = (x_2 @ W for W in (W_query, W_key, W_value))
query_2, key_2, value_2

(array([0.730349, 1.27947], dtype=float32),
 array([1.31264, 1.31876], dtype=float32),
 array([0.699694, 0.660281], dtype=float32))

In [39]:
# Keys and values for every input token at once (one row per token).
keys, values = inputs @ W_key, inputs @ W_value
keys.shape, values.shape

((6, 2), (6, 2))