In [1]:
import mlx.core as mx

# Attending to different parts of the input with self-attention

In [2]:
inputs = mx.array(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
query = inputs[1]
attn_scores_2 = mx.zeros(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = mx.matmul(x_i, query)
attn_scores_2

array([0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865], dtype=float32)

In [4]:
res = 0.
for i, element in enumerate(inputs[0]):
    res += inputs[0][i] * query[i]
res, mx.matmul(inputs[0], query)

(array(0.9544, dtype=float32), array(0.9544, dtype=float32))

In [5]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()
print("attention weights:", attn_weights_2_tmp)
print("sum:", attn_weights_2_tmp.sum())

attention weights: array([0.14545, 0.227837, 0.22485, 0.128534, 0.107746, 0.165582], dtype=float32)
sum: array(1, dtype=float32)


In [6]:
def softmax_naive(x):
    return mx.exp(x) / mx.exp(x).sum(axis=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)
print("attention weights (naive):", attn_weights_2_naive)
print("sum:", attn_weights_2_naive.sum())

attn_weights_2 = mx.softmax(attn_scores_2, axis=0)
print("attention weights (mlx):", attn_weights_2)
print("sum:", attn_weights_2.sum())

attention weights (naive): array([0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114], dtype=float32)
sum: array(1, dtype=float32)
attention weights (mlx): array([0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114], dtype=float32)
sum: array(1, dtype=float32)


In [7]:
query = inputs[1]
context_vec_2 = mx.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i
context_vec_2

array([0.441866, 0.651482, 0.568309], dtype=float32)

In [8]:
attn_scores = mx.zeros((6, 6))
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = mx.matmul(x_i, x_j)
attn_scores

array([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.631],
       [0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865],
       [0.9422, 1.4754, 1.457, 0.8296, 0.7154, 1.0605],
       [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
       [0.4576, 0.707, 0.7154, 0.3474, 0.6654, 0.2935],
       [0.631, 1.0865, 1.0605, 0.6565, 0.2935, 0.945]], dtype=float32)

In [9]:
attn_scores = inputs @ inputs.T
attn_scores

array([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.631],
       [0.9544, 1.495, 1.4754, 0.8434, 0.707, 1.0865],
       [0.9422, 1.4754, 1.457, 0.8296, 0.7154, 1.0605],
       [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
       [0.4576, 0.707, 0.7154, 0.3474, 0.6654, 0.2935],
       [0.631, 1.0865, 1.0605, 0.6565, 0.2935, 0.945]], dtype=float32)

In [10]:
attn_weights = mx.softmax(attn_scores, axis=-1)
attn_weights

array([[0.209835, 0.200581, 0.198149, 0.124228, 0.122049, 0.145158],
       [0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114],
       [0.139008, 0.236921, 0.232602, 0.124204, 0.1108, 0.156464],
       [0.143527, 0.207394, 0.204552, 0.146192, 0.126295, 0.172039],
       [0.152611, 0.195839, 0.197491, 0.136687, 0.187859, 0.129514],
       [0.138471, 0.218364, 0.212759, 0.142048, 0.0988064, 0.189552]], dtype=float32)

In [11]:
row_2_sum = sum([0.138548, 0.237891, 0.233274, 0.123992, 0.108182, 0.158114])
print('row 2 sum:', row_2_sum)
print('all row sums:', attn_weights.sum(axis=-1))

row 2 sum: 1.000001
all row sums: array([1, 1, 1, 1, 1, 1], dtype=float32)


In [12]:
all_context_vecs = attn_weights @ inputs
all_context_vecs

array([[0.442059, 0.593099, 0.578989],
       [0.441866, 0.651482, 0.568309],
       [0.443128, 0.649595, 0.567073],
       [0.43039, 0.629828, 0.551027],
       [0.467102, 0.590993, 0.526596],
       [0.417724, 0.650323, 0.564535]], dtype=float32)

# Implementing self-attention with trainable weights

In [13]:
import mlx.nn as nn

In [14]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [15]:
mx.random.seed(123)
W_query = nn.init.uniform()(mx.zeros((d_in, d_out)))
W_key = nn.init.uniform()(mx.zeros((d_in, d_out)))
W_value = nn.init.uniform()(mx.zeros((d_in, d_out)))

In [16]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
query_2, key_2, value_2

(array([0.730349, 1.27947], dtype=float32),
 array([1.31264, 1.31876], dtype=float32),
 array([0.699694, 0.660281], dtype=float32))

In [17]:
keys = inputs @ W_key
values = inputs @ W_value
keys.shape, values.shape

((6, 2), (6, 2))

In [18]:
keys_2 = keys[1]
attn_score_22 = query_2.__matmul__(keys_2)
attn_score_22, mx.matmul(query_2, keys_2)

(array(2.646, dtype=float32), array(2.646, dtype=float32))

In [19]:
attn_scores_2 = query_2 @ keys.T
attn_scores_2

array([1.35143, 2.646, 2.6213, 1.51807, 1.43814, 1.84981], dtype=float32)

In [20]:
d_k = keys.shape[1]
attn_weights_2 = mx.softmax(attn_scores_2 / d_k**0.05, axis=-1)
attn_weights_2

array([0.0848754, 0.296384, 0.289396, 0.0996982, 0.0922902, 0.137357], dtype=float32)

In [21]:
context_vec_2 = attn_weights_2 @ values
context_vec_2

array([0.591523, 0.540025], dtype=float32)

In [22]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.init.uniform()(mx.zeros((d_in, d_out)))
        self.W_key   = nn.init.uniform()(mx.zeros((d_in, d_out)))
        self.W_value = nn.init.uniform()(mx.zeros((d_in, d_out)))
    
    def forward(self, x):
        queries = x @ self.W_query
        keys    = x @ self.W_key
        valeus  = x @ self.W_value

        attn_scores = queries @ keys.T
        attn_weights = mx.softmax(
            attn_scores / keys.shape[-1]**0.5, axis=-1
        )
        context_vec = attn_weights @ values
        return context_vec
    
    def __call__(self, x):
        return self.forward(x)

In [23]:
mx.random.seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
sa_v1.parameters(), sa_v1(inputs)

({'W_query': array([[0.0292066, 0.190664],
         [0.180269, 0.656488],
         [0.844624, 0.91433]], dtype=float32),
  'W_key': array([[0.967267, 0.386227],
         [0.843854, 0.925712],
         [0.0704427, 0.456013]], dtype=float32),
  'W_value': array([[0.750346, 0.114293],
         [0.131906, 0.66192],
         [0.260978, 0.0326512]], dtype=float32)},
 array([[0.573207, 0.509778],
        [0.576258, 0.517739],
        [0.575252, 0.516262],
        [0.558173, 0.489055],
        [0.547414, 0.471013],
        [0.568997, 0.506183]], dtype=float32))

In [24]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        queries = self.W_query(x)
        keys    = self.W_key(x)
        valeus  = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = mx.softmax(
            attn_scores / keys.shape[-1]**0.5, axis=-1
        )
        context_vec = attn_weights @ values
        return context_vec
    
    def __call__(self, x):
        return self.forward(x)

In [25]:
mx.random.seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v2.parameters(), sa_v2(inputs)

({'W_query': {'weight': array([[-0.543625, -0.357191, -0.369193],
          [0.180697, 0.397937, 0.478428]], dtype=float32)},
  'W_key': {'weight': array([[0.539553, -0.131374, 0.397048],
          [0.491569, -0.49601, -0.050792]], dtype=float32)},
  'W_value': {'weight': array([[0.289075, -0.445376, -0.425039],
          [0.186969, -0.275999, -0.539648]], dtype=float32)}},
 array([[0.536247, 0.444135],
        [0.535136, 0.44122],
        [0.535003, 0.441506],
        [0.537125, 0.446077],
        [0.533649, 0.449737],
        [0.538212, 0.442919]], dtype=float32))

# Hiding future words with causal attention

In [26]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = mx.softmax(attn_scores / keys.shape[-1]*0.05, axis=-1)
attn_weights

array([[0.166613, 0.166252, 0.166287, 0.166944, 0.167353, 0.166551],
       [0.166577, 0.1661, 0.166146, 0.167058, 0.167589, 0.166529],
       [0.166569, 0.166107, 0.166152, 0.16706, 0.16757, 0.166543],
       [0.166643, 0.166358, 0.166385, 0.166862, 0.167193, 0.16656],
       [0.166447, 0.16638, 0.166393, 0.166994, 0.166961, 0.166826],
       [0.166719, 0.166275, 0.166314, 0.166853, 0.167416, 0.166423]], dtype=float32)

In [27]:
context_length = attn_scores.shape[0]
mask_simple = mx.tril(mx.ones((context_length, context_length)), k=0)
mask_simple

array([[1, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 0, 0],
       [1, 1, 1, 1, 1, 0],
       [1, 1, 1, 1, 1, 1]], dtype=float32)

In [28]:
masked_simple = attn_weights * mask_simple
masked_simple

array([[0.166613, 0, 0, 0, 0, 0],
       [0.166577, 0.1661, 0, 0, 0, 0],
       [0.166569, 0.166107, 0.166152, 0, 0, 0],
       [0.166643, 0.166358, 0.166385, 0.166862, 0, 0],
       [0.166447, 0.16638, 0.166393, 0.166994, 0.166961, 0],
       [0.166719, 0.166275, 0.166314, 0.166853, 0.167416, 0.166423]], dtype=float32)

In [29]:
row_sums = masked_simple.sum(axis=-1, keepdims=True)
masked_simple_norm = masked_simple / row_sums
masked_simple_norm

array([[1, 0, 0, 0, 0, 0],
       [0.500717, 0.499283, 0, 0, 0, 0],
       [0.333921, 0.332994, 0.333085, 0, 0, 0],
       [0.250121, 0.249694, 0.249734, 0.25045, 0, 0],
       [0.199775, 0.199694, 0.199709, 0.200431, 0.200391, 0],
       [0.166719, 0.166275, 0.166314, 0.166853, 0.167416, 0.166423]], dtype=float32)

In [30]:
mask = mx.tril(mx.ones((context_length, context_length)), k=0)
masked = mx.where(mask.astype(mx.bool_), attn_scores, -mx.inf)
masked

array([[-0.296727, -inf, -inf, -inf, -inf, -inf],
       [-0.412889, -0.527579, -inf, -inf, -inf, -inf],
       [-0.412098, -0.523246, -0.512378, -inf, -inf, -inf],
       [-0.214446, -0.28282, -0.27638, -0.161949, -inf, -inf],
       [-0.281896, -0.297944, -0.29498, -0.150774, -0.15863, -inf],
       [-0.237691, -0.344331, -0.334893, -0.20546, -0.0706202, -0.308676]], dtype=float32)

In [31]:
attn_weights = mx.softmax(masked / keys.shape[-1]**0.05, axis=-1)
attn_weights

array([[1, 0, 0, 0, 0, 0],
       [0.527668, 0.472332, 0, 0, 0, 0],
       [0.356394, 0.320114, 0.323492, 0, 0, 0],
       [0.254451, 0.238188, 0.239675, 0.267686, 0, 0],
       [0.191076, 0.188136, 0.188676, 0.216876, 0.215236, 0],
       [0.167983, 0.151541, 0.152929, 0.173295, 0.197402, 0.156851]], dtype=float32)

In [32]:
mx.random.seed(123)
dropout = nn.Dropout(0.5)
example = mx.ones((6,6))
dropout(example), example  # non-droped values are scaled by a factor of 1/(1-p) (p=0.5)

(array([[0, 2, 0, 0, 0, 0],
        [0, 0, 0, 0, 2, 2],
        [0, 0, 0, 2, 2, 0],
        [2, 2, 0, 0, 0, 2],
        [2, 2, 0, 0, 2, 2],
        [0, 0, 2, 0, 2, 2]], dtype=float32),
 array([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1]], dtype=float32))

In [33]:
batch = mx.stack((inputs, inputs), axis=0)
batch.shape

(2, 6, 3)

In [34]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # This mask is a constant, not a trainable parameter
        self.mask = mx.tril(mx.ones((context_length, context_length)), k=0) 
    
    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose((0, 2, 1))
        attn_scores = mx.where(
            self.mask.astype(mx.bool_),
            attn_scores,
            -mx.inf
        )
        attn_weights = mx.softmax(attn_scores / keys.shape[-1]**0.05, axis=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec
    
    def __call__(self, x):
        return self.forward(x)

In [35]:
mx.random.seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0).freeze(keys="mask")

In [36]:
context_vecs = ca(batch)
context_vecs, context_vecs.shape

(array([[[-0.320789, -0.44129],
         [-0.409692, -0.465929],
         [-0.434428, -0.468376],
         [-0.408975, -0.422795],
         [-0.306576, -0.326374],
         [-0.340087, -0.348353]],
        [[-0.320789, -0.44129],
         [-0.409692, -0.465929],
         [-0.434428, -0.468376],
         [-0.408975, -0.422795],
         [-0.306576, -0.326374],
         [-0.340087, -0.348353]]], dtype=float32),
 (2, 6, 2))

# Extending single-head attention to multi-head attention

In [37]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qwk_bias=False):
        super().__init__()
        self.heads = [
            CausalAttention(d_in, d_out, context_length, dropout, qwk_bias)
            for _ in range(num_heads)
        ]
    def forward(self, x):
        return mx.concat([head(x) for head in self.heads], axis=-1)
    
    def __call__(self, x):
        return self.forward(x)

In [38]:
mx.random.seed(123)
context_length = batch.shape[1]
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(array([[[-0.320789, -0.44129, -0.29132, -0.473982],
         [-0.409692, -0.465929, -0.293583, -0.425317],
         [-0.434428, -0.468376, -0.287393, -0.411218],
         [-0.408975, -0.422795, -0.265838, -0.343605],
         [-0.306576, -0.326374, -0.18771, -0.354522],
         [-0.340087, -0.348353, -0.231509, -0.304149]],
        [[-0.320789, -0.44129, -0.29132, -0.473982],
         [-0.409692, -0.465929, -0.293583, -0.425317],
         [-0.434428, -0.468376, -0.287393, -0.411218],
         [-0.408975, -0.422795, -0.265838, -0.343605],
         [-0.306576, -0.326374, -0.18771, -0.354522],
         [-0.340087, -0.348353, -0.231509, -0.304149]]], dtype=float32),
 (2, 6, 4))

In [39]:
# instead of concatenating, create single W_query, W_key, and W_value weight matrices 
# and split those into individual matrices for each attention heads. 
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.mask = mx.tril(mx.ones((context_length, context_length)), k=0)

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # splitting the matrix by addinga num_heads dimension
        keys = keys.reshape(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.reshape(b, num_tokens, self.num_heads, self.head_dim)
        values = values.reshape(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose((0, 2, 1, 3))
        queries = queries.transpose((0, 2, 1, 3))
        values = values.transpose((0, 2, 1, 3))

        attn_scores = queries @ keys.transpose((0, 1, 3, 2))
        mask_bool = self.mask.astype(mx.bool_)[:num_tokens, :num_tokens]
        attn_scores = mx.where(mask_bool, attn_scores, -mx.inf)
        attn_weights = mx.softmax(attn_scores / keys.shape[-1]**0.5, axis=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(0, 2, 1, 3)
        # combine heads; self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.reshape(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec
    
    def __call__(self, x):
        return self.forward(x)


In [40]:
mx.array([1,2,3,4]).reshape(1, 4)

array([[1, 2, 3, 4]], dtype=int32)

In [41]:
mx.random.seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2).freeze(keys="mask")
context_vecs = mha(batch)
context_vecs, context_vecs.shape

(array([[[0.409174, 0.713412],
         [0.384418, 0.771063],
         [0.379907, 0.784335],
         [0.402935, 0.767026],
         [0.461338, 0.713255],
         [0.446312, 0.741353]],
        [[0.409174, 0.713412],
         [0.384418, 0.771063],
         [0.379907, 0.784335],
         [0.402935, 0.767026],
         [0.461338, 0.713255],
         [0.446312, 0.741353]]], dtype=float32),
 (2, 6, 2))

In [None]:
from mlx.utils import tree_flatten
sum(v.size for _, v in tree_flatten(mha.parameters()))

60