In [1]:
import mlx.core as mx
import mlx.nn as nn
import math

In [None]:
class LlamaAttention(nn.Module):
    """Multi-head attention with rotary positional embeddings (RoPE).

    Queries, keys and values are each projected by a bias-free linear layer,
    split into ``num_heads`` heads, rotated with RoPE (queries and keys only),
    and combined with scaled dot-product attention. An optional
    ``(key_cache, value_cache)`` tuple lets incremental decoding append new
    keys/values to previously computed ones.

    Args:
        dims: Feature size of the inputs and of the output.
        num_heads: Number of attention heads; must divide ``dims``.

    Raises:
        ValueError: If ``dims`` is not divisible by ``num_heads``.
    """

    def __init__(self, dims: int, num_heads: int):
        super().__init__()

        # Fail at construction time instead of deep inside a reshape in __call__.
        if dims % num_heads != 0:
            raise ValueError(
                f"dims ({dims}) must be divisible by num_heads ({num_heads})"
            )

        self.num_heads = num_heads  # Number of heads in the multi-head attention

        # Rotary positional embedding applied per head, so it operates on the
        # head dimension (dims // num_heads).
        self.rope = nn.RoPE(dims // num_heads, traditional=True)
        self.query_proj = nn.Linear(dims, dims, bias=False)  # Query projection
        self.key_proj = nn.Linear(dims, dims, bias=False)  # Key projection
        self.value_proj = nn.Linear(dims, dims, bias=False)  # Value projection
        self.out_proj = nn.Linear(dims, dims, bias=False)  # Output projection

    def __call__(self, queries, keys, values, mask=None, cache=None):
        """Run attention and return the output plus the keys/values for caching.

        Args:
            queries: Rank-3 array ``(B, L, dims)``.
            keys: Rank-3 array ``(B, S, dims)``; ``S`` may differ from ``L``.
            values: Rank-3 array ``(B, S, dims)``.
            mask: Optional additive mask, added to the attention scores of shape
                ``(B, num_heads, L, S_total)`` before the softmax (presumably
                broadcastable to that shape — confirm against callers).
            cache: Optional ``(key_cache, value_cache)`` as returned by a
                previous call; new keys/values are concatenated after them along
                the sequence axis, and RoPE is offset by the cached length.

        Returns:
            A tuple ``(output, (keys, values))`` where ``output`` has shape
            ``(B, L, dims)`` and ``keys``/``values`` are the per-head arrays
            (RoPE already applied to keys, cache included) suitable as the next
            ``cache``.
        """
        queries = self.query_proj(queries)
        keys = self.key_proj(keys)
        values = self.value_proj(values)

        num_heads = self.num_heads
        B, L, _ = queries.shape

        # Split into heads: (B, seq, dims) -> (B, num_heads, seq, dims // num_heads).
        # Keys/values use their own sequence length, which need not equal the
        # query length (e.g. cross-attention).
        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
        keys = keys.reshape(B, keys.shape[1], num_heads, -1).transpose(0, 2, 1, 3)
        values = values.reshape(B, values.shape[1], num_heads, -1).transpose(
            0, 2, 1, 3
        )

        # Apply RoPE; with a cache, positions start after the cached tokens.
        if cache is not None:
            key_cache, value_cache = cache
            offset = key_cache.shape[2]
            queries = self.rope(queries, offset=offset)
            keys = self.rope(keys, offset=offset)
            keys = mx.concatenate([key_cache, keys], axis=2)
            values = mx.concatenate([value_cache, values], axis=2)
        else:
            queries = self.rope(queries)
            keys = self.rope(keys)

        # Scaled dot-product attention: softmax(Q K^T / sqrt(head_dim)) V.
        scale = math.sqrt(1 / queries.shape[-1])
        scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
        if mask is not None:
            scores = scores + mask
        scores = mx.softmax(scores, axis=-1)

        # Merge heads back: (B, num_heads, L, head_dim) -> (B, L, dims).
        values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)

        # Return keys/values so the caller can feed them back in as the cache.
        return self.out_proj(values_hat), (keys, values)