<a href="https://www.kaggle.com/code/william2020/apple-s-mlx-transformers-pe-multi-head-attn?scriptVersionId=187548653" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Apple's MLX: Input/Positional Embeddings and Multi-head Self-Attention for the Transformer architecture

#### MLX: an array framework for machine learning on Apple silicon, created by Apple.

In this tutorial, we walk through core MLX functionality by building the first components of a transformer: tokenization, token embeddings, rotary positional encoding (RoPE), and multi-head self-attention. Whether you are new to machine learning on Apple hardware or an experienced practitioner looking to optimize your workflows, this notebook provides practical examples and step-by-step instructions for working with MLX.

In [1]:
!pip install -q mlx

In [2]:
import numpy as np
import mlx.core as mx
import mlx.nn as nn
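# word_tokenize needs NLTK's Punkt tokenizer data; if it is missing, run
# nltk.download("punkt") (newer NLTK releases use "punkt_tab").
from nltk.tokenize import word_tokenize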
from collections import defaultdict

# Step 1: Tokenize the sentence

In [3]:
sentence = "What are the advantages and disadvantages of using a unified memory architecture?"
tokens = word_tokenize(sentence.lower())
print("Tokens:", tokens)

Tokens: ['what', 'are', 'the', 'advantages', 'and', 'disadvantages', 'of', 'using', 'a', 'unified', 'memory', 'architecture', '?']


# Step 2: Create a simple vocabulary and convert tokens to indices

In [4]:
vocab = defaultdict(lambda: len(vocab))
indices = [vocab[token] for token in tokens]
print("Indices:", indices)

Indices: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
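The `defaultdict` trick assigns the next free index the first time a token is looked up and reuses that index afterwards. Every token in the example sentence is unique, so the indices above simply count from 0 to 12. A short sketch with a hypothetical sentence containing a repeated word shows the reuse:

```python
from collections import defaultdict

# Each unseen key gets the current vocabulary size as its index.
vocab = defaultdict(lambda: len(vocab))
tokens = ["the", "cat", "saw", "the", "dog"]  # hypothetical example with a repeated token
indices = [vocab[token] for token in tokens]

print(indices)      # [0, 1, 2, 0, 3]
print(dict(vocab))  # {'the': 0, 'cat': 1, 'saw': 2, 'dog': 3}
```

One side effect of this design: any lookup of a new token, even an accidental one, grows the vocabulary, so `len(vocab)` should be read only after all tokens have been indexed.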


# Step 3: Initialize the embedding layer

In [5]:
embedding_dim = 64
num_embeddings = len(vocab)
embedding_layer = nn.Embedding(num_embeddings, embedding_dim)

# Step 4: Convert indices to MLX array and embed them

In [6]:
input_data = mx.array(indices)
embedded_tokens = embedding_layer(input_data)
print("Embedded Tokens:", embedded_tokens)

Embedded Tokens: array([[-0.241889, 0.0479568, 0.0795218, ..., 0.0124098, 0.0172652, 0.0071785],
       [0.141629, 0.167849, 0.0870575, ..., -0.0376079, -0.191881, 0.148116],
       [0.0208214, 0.00694565, -0.161432, ..., 0.0463976, -0.00566133, -0.160316],
       ...,
       [0.020175, 0.0723507, 0.174394, ..., -0.0944999, -0.113024, -0.164732],
       [-0.220407, -0.0641856, -0.00881927, ..., 0.0507493, 0.0249283, 0.229439],
       [0.013175, 0.064244, 0.0242759, ..., 0.0134159, 0.147243, -0.0463119]], dtype=float32)
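An embedding layer is a learnable table with one row per vocabulary entry. Calling it with indices returns the corresponding rows, which is equivalent to multiplying one-hot vectors by the table. A NumPy sketch of that equivalence, using a random stand-in table rather than the MLX layer's actual weights:

```python
import numpy as np

rng = np.random.default_rng(0)
num_embeddings, embedding_dim = 13, 64
# Stand-in for the layer's learnable weight matrix: one row per token index.
weight = rng.normal(scale=0.1, size=(num_embeddings, embedding_dim)).astype(np.float32)

indices = np.array([0, 1, 2, 12])
# Lookup: pick rows of the table.
looked_up = weight[indices]

# Same result via one-hot vectors multiplied by the table.
one_hot = np.eye(num_embeddings, dtype=np.float32)[indices]
via_matmul = one_hot @ weight

print(looked_up.shape)                     # (4, 64)
print(np.allclose(looked_up, via_matmul))  # True
```

The row lookup is the form embedding layers typically use, since it avoids building the one-hot matrix.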


# RoPE (Rotary Position Embedding)

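RoPE encodes position by rotating the embedding vectors. Each embedding is split into pairs of dimensions, and each pair is rotated by an angle equal to the token's position multiplied by a pair-specific frequency; the cosines and sines of these angles are the sinusoidal functions involved. Because the dot product of two rotated vectors depends only on the difference between their rotation angles, attention scores computed from RoPE-encoded vectors depend on the relative positions of the tokens. In standard transformer implementations, RoPE is applied to the queries and keys inside each attention layer; for simplicity, this notebook applies it directly to the token embeddings.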

In transformer models, understanding the relative positions of tokens within a sequence is crucial for tasks that require contextual understanding, such as language modeling and translation.
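The relative-position property can be checked numerically. The sketch below uses NumPy and a hypothetical helper `rotate` that mirrors the even/odd rotation used later in this notebook. It rotates a query vector by position `m` and a key vector by position `n`, then shows that their dot product is unchanged when both positions are shifted by the same amount, i.e. it depends only on `m - n`.

```python
import numpy as np

dim = 64
inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))

def rotate(x, pos):
    """Hypothetical helper: rotate each (even, odd) pair of x by pos * inv_freq."""
    angles = pos * inv_freq
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[::2], x[1::2]
    out = np.empty_like(x)
    out[::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.normal(size=dim), rng.normal(size=dim)

score_a = rotate(q, 5) @ rotate(k, 2)      # positions 5 and 2 (offset 3)
score_b = rotate(q, 105) @ rotate(k, 102)  # positions 105 and 102 (offset 3)

print(np.isclose(score_a, score_b))  # True
```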

# Calculate sequence length

In [7]:
seq_len = embedded_tokens.shape[0]
print("Sequence Length:", seq_len)

Sequence Length: 13


# Generate frequencies for the sinusoidal embeddings

In [8]:
inv_freq = 1.0 / (10000 ** (np.arange(0, embedding_dim, 2).astype(np.float32) / embedding_dim))
print("Inverse Frequencies:", inv_freq)

freqs = mx.array(np.outer(np.arange(seq_len), inv_freq).astype(np.float32))
print("Frequencies:", freqs)

Inverse Frequencies: [1.0000000e+00 7.4989420e-01 5.6234133e-01 4.2169651e-01 3.1622776e-01
 2.3713736e-01 1.7782794e-01 1.3335215e-01 1.0000000e-01 7.4989416e-02
 5.6234129e-02 4.2169649e-02 3.1622779e-02 2.3713736e-02 1.7782794e-02
 1.3335215e-02 9.9999998e-03 7.4989423e-03 5.6234132e-03 4.2169648e-03
 3.1622779e-03 2.3713738e-03 1.7782794e-03 1.3335214e-03 1.0000000e-03
 7.4989418e-04 5.6234130e-04 4.2169649e-04 3.1622779e-04 2.3713738e-04
 1.7782794e-04 1.3335215e-04]
Frequencies: array([[0, 0, 0, ..., 0, 0, 0],
       [1, 0.749894, 0.562341, ..., 0.000237137, 0.000177828, 0.000133352],
       [2, 1.49979, 1.12468, ..., 0.000474275, 0.000355656, 0.000266704],
       ...,
       [10, 7.49894, 5.62341, ..., 0.00237137, 0.00177828, 0.00133352],
       [11, 8.24884, 6.18575, ..., 0.00260851, 0.00195611, 0.00146687],
       [12, 8.99873, 6.7481, ..., 0.00284565, 0.00213394, 0.00160023]], dtype=float32)
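Each entry of `inv_freq` is `10000 ** (-2i / d)` for pair index `i` and `d = embedding_dim`, so consecutive frequencies shrink by the constant factor `10000 ** (-2 / d)` (about 0.75 for `d = 64`), and `freqs[m, i]` is the rotation angle `m * inv_freq[i]` for position `m`. A quick NumPy check of those relations:

```python
import numpy as np

d, seq_len = 64, 13
inv_freq = 1.0 / (10000 ** (np.arange(0, d, 2) / d))
freqs = np.outer(np.arange(seq_len), inv_freq)  # freqs[m, i] = m * inv_freq[i]

ratio = inv_freq[1:] / inv_freq[:-1]
print(np.allclose(ratio, 10000 ** (-2 / d)))       # True: constant geometric ratio
print(round(float(10000 ** (-2 / d)), 4))          # 0.7499
print(np.isclose(freqs[12, 1], 12 * inv_freq[1]))  # True
```

Low-index pairs therefore rotate quickly with position, while high-index pairs rotate very slowly, which matches the nearly constant right-hand columns of the cosine output.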


# Calculate cosine and sine of the frequencies

In [9]:
cos_pos = mx.cos(freqs)
sin_pos = mx.sin(freqs)
print("Cosine Positional Encoding:", cos_pos)
print("Sine Positional Encoding:", sin_pos)

Cosine Positional Encoding: array([[1, 1, 1, ..., 1, 1, 1],
       [0.540302, 0.731761, 0.846009, ..., 1, 1, 1],
       [-0.416147, 0.0709483, 0.431463, ..., 1, 1, 1],
       ...,
       [-0.839072, 0.347628, 0.790132, ..., 0.999997, 0.999998, 0.999999],
       [0.0044257, -0.384674, 0.995257, ..., 0.999997, 0.999998, 0.999999],
       [0.843854, -0.910606, 0.893862, ..., 0.999996, 0.999998, 0.999999]], dtype=float32)
Sine Positional Encoding: array([[0, 0, 0, ..., 0, 0, 0],
       [0.841471, 0.681561, 0.533168, ..., 0.000237137, 0.000177828, 0.000133352],
       [0.909297, 0.99748, 0.902131, ..., 0.000474275, 0.000355656, 0.000266704],
       ...,
       [-0.544021, 0.937633, -0.612937, ..., 0.00237137, 0.00177828, 0.00133352],
       [-0.99999, 0.923052, -0.0972765, ..., 0.00260851, 0.00195611, 0.00146687],
       [-0.536573, 0.413275, 0.448343, ..., 0.00284564, 0.00213393, 0.00160023]], dtype=float32)


# Split embedded tokens into even and odd parts

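Each 64-dimensional embedding is split into its 32 even-indexed and 32 odd-indexed dimensions. Dimension `2i` and dimension `2i + 1` form the pair that is rotated by the `i`-th frequency.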

In [10]:
x1 = embedded_tokens[:, ::2]
x2 = embedded_tokens[:, 1::2]
print("x1 (even):", x1.shape)
print("x2 (odd):", x2.shape)

x1 (even): (13, 32)
x2 (odd): (13, 32)
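The slices `[:, ::2]` and `[:, 1::2]` pair dimension 0 with 1, dimension 2 with 3, and so on. On a small NumPy range standing in for one embedding row:

```python
import numpy as np

x = np.arange(8)          # stand-in for one 8-dimensional embedding row
x1, x2 = x[::2], x[1::2]  # even- and odd-indexed dimensions
print(x1)  # [0 2 4 6]
print(x2)  # [1 3 5 7]

# Each row is an (even, odd) pair that gets rotated together.
print(np.stack([x1, x2], axis=1))
```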


# Apply rotational transformation

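Each (even, odd) pair is rotated by the angle in `freqs` for its position and frequency, using the standard 2-D rotation formulas. `cos_pos` and `sin_pos` have shape (13, 32), matching `x1` and `x2`, so the operations are element-wise.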

In [11]:
x1_new = x1 * cos_pos - x2 * sin_pos
x2_new = x1 * sin_pos + x2 * cos_pos
print("Transformed x1:", x1_new.shape)
print("Transformed x2:", x2_new.shape)

Transformed x1: (13, 32)
Transformed x2: (13, 32)
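For each pair, the two lines above apply a 2-D rotation matrix to `(x1, x2)`. A rotation changes direction but not length, so every pair keeps its norm. A NumPy check on random stand-in data (not the MLX arrays above):

```python
import numpy as np

rng = np.random.default_rng(0)
seq_len, half = 13, 32
x1 = rng.normal(size=(seq_len, half))
x2 = rng.normal(size=(seq_len, half))
angles = rng.uniform(0, 2 * np.pi, size=(seq_len, half))
cos_pos, sin_pos = np.cos(angles), np.sin(angles)

# Same formulas as the cell above: [x1', x2'] = [[cos, -sin], [sin, cos]] @ [x1, x2]
x1_new = x1 * cos_pos - x2 * sin_pos
x2_new = x1 * sin_pos + x2 * cos_pos

norm_before = np.hypot(x1, x2)
norm_after = np.hypot(x1_new, x2_new)
print(np.allclose(norm_before, norm_after))  # True
```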


# Concatenate the new x1 and x2 back together

Concatenates the transformed halves back together to get the final positionally encoded embeddings of shape (13, 64).

In [12]:
positional_encoded_embeddings = mx.concatenate([x1_new, x2_new], axis=-1)
print("Positional Encoded Embeddings:", positional_encoded_embeddings)

Positional Encoded Embeddings: array([[-0.241889, 0.0795218, 0.0822457, ..., 0.189025, 0.0124098, 0.0071785],
       [-0.0647175, 0.0345135, 0.040535, ..., 0.0525666, -0.0375987, 0.14809],
       [-0.0149804, 0.0733039, 0.0653143, ..., -0.102836, 0.0464649, -0.160317],
       ...,
       [0.022432, 0.0585029, 0.149466, ..., 0.117419, -0.0942642, -0.164882],
       [-0.0651605, -0.116964, 0.0153244, ..., -0.269856, 0.0503992, 0.229476],
       [0.0455894, -0.0800577, 0.0880328, ..., 0.00635958, 0.0129789, -0.0460762]], dtype=float32)
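Concatenation does not restore the original interleaved order: all rotated even dimensions come first and all rotated odd dimensions second. This is visible in row 0, where position 0 has zero rotation, so its values match row 0 of the embeddings in a different order. If the interleaved order is wanted, stacking the halves on a new last axis and reshaping restores it. A NumPy sketch with zero angles, so the result can be compared with the input directly:

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(2, 6)  # stand-in: 2 tokens, 6 dims
x1, x2 = x[:, ::2], x[:, 1::2]

# With zero rotation angles, the rotated halves equal x1 and x2.
concatenated = np.concatenate([x1, x2], axis=-1)
interleaved = np.stack([x1, x2], axis=-1).reshape(x.shape)

print(concatenated[0])                 # [0. 2. 4. 1. 3. 5.]
print(np.array_equal(interleaved, x))  # True
```

In MLX the same interleaving could be written as `mx.stack([x1_new, x2_new], axis=-1).reshape(seq_len, embedding_dim)`. Because a learned linear projection follows in this notebook, either layout works as long as it is used consistently.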


# Multi-Head Self-Attention

To let the model capture several kinds of relationships at once, the transformer uses multi-head self-attention. The Query, Key, and Value vectors are split into smaller sub-vectors, one per attention head; with `embedding_dim = 64` and 8 heads, each head works with 8-dimensional sub-vectors. Attention is computed independently for each head, and the results are concatenated and passed through a final linear transformation to produce the output.


In [13]:
# Define the dimensions
num_heads = 8
head_dim = embedding_dim // num_heads

# Sample input embeddings after positional encoding (from previous section)
# For demonstration, we assume `positional_encoded_embeddings` is already defined
input_embeddings = positional_encoded_embeddings

# Add a batch dimension if missing: (seq_len, embedding_dim) -> (1, seq_len, embedding_dim)
if input_embeddings.ndim == 2:
    input_embeddings = mx.expand_dims(input_embeddings, 0)

print("Input Embeddings Shape:", input_embeddings.shape)

Input Embeddings Shape: (1, 13, 64)


# Linear Transformations

This cell applies linear projections to the input embeddings to obtain the Query, Key, and Value matrices.

In [14]:
# Step 1: Linear Transformations
query_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
key_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
value_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)

queries = query_proj(input_embeddings)
keys = key_proj(input_embeddings)
values = value_proj(input_embeddings)

print("Queries Shape:", queries.shape)
print("Keys Shape:", keys.shape)
print("Values Shape:", values.shape)

Queries Shape: (1, 13, 64)
Keys Shape: (1, 13, 64)
Values Shape: (1, 13, 64)
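Each `nn.Linear(embedding_dim, embedding_dim, bias=False)` computes `x @ W.T` with its own weight matrix `W`. Because the three projections read the same input, they can also be computed as a single multiplication by the three weight matrices stacked together, followed by a split; some implementations fuse them this way. A NumPy sketch of that equivalence, with random stand-in weights named `w_q`, `w_k`, `w_v` for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, d = 1, 13, 64
x = rng.normal(size=(batch, seq_len, d))

# Stand-in weights, stored as (out_features, in_features) like a linear layer.
w_q, w_k, w_v = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(3))

# Three separate projections: x @ W.T
q, k, v = x @ w_q.T, x @ w_k.T, x @ w_v.T

# One fused projection with the weights stacked along the output axis, then split.
w_qkv = np.concatenate([w_q, w_k, w_v], axis=0)  # (3d, d)
q2, k2, v2 = np.split(x @ w_qkv.T, 3, axis=-1)

print(q.shape)  # (1, 13, 64)
print(all(np.allclose(a, b) for a, b in [(q, q2), (k, k2), (v, v2)]))  # True
```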


# Split for Multi-Head Attention

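This cell reshapes the Query, Key, and Value matrices from (batch, seq_len, 64) to (batch, seq_len, 8, 8), splitting the last dimension into 8 heads of 8 dimensions each. It then transposes them to (batch, num_heads, seq_len, head_dim) so that each head can be processed independently, and prints their shapes.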

In [15]:
# Step 2: Split into heads
batch_size, seq_length, _ = queries.shape

queries = queries.reshape(batch_size, seq_length, num_heads, head_dim).transpose(0, 2, 1, 3)
keys = keys.reshape(batch_size, seq_length, num_heads, head_dim).transpose(0, 2, 1, 3)
values = values.reshape(batch_size, seq_length, num_heads, head_dim).transpose(0, 2, 1, 3)

print("Split Queries Shape:", queries.shape)
print("Split Keys Shape:", keys.shape)
print("Split Values Shape:", values.shape)

Split Queries Shape: (1, 8, 13, 8)
Split Keys Shape: (1, 8, 13, 8)
Split Values Shape: (1, 8, 13, 8)
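After the reshape and transpose, head `h` sees dimensions `h * head_dim` through `(h + 1) * head_dim - 1` of every token's projected vector. A small NumPy check, using a tensor whose values equal their feature index:

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 1, 3, 4, 2
d = num_heads * head_dim
# Every token's vector is [0, 1, ..., d-1], so values reveal the original feature index.
x = np.tile(np.arange(d), (batch, seq_len, 1))  # (1, 3, 8)

heads = x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
print(heads.shape)     # (1, 4, 3, 2)
print(heads[0, 2, 0])  # [4 5]: head 2 holds features 4 and 5
```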


# Scaled Dot-Product Attention

This cell computes scaled dot-product attention for all heads at once. The scores are `Q @ K^T / sqrt(d_k)`, where `d_k = head_dim`; a softmax over the last axis turns each row of scores into attention weights, and the output is the corresponding weighted sum of the values.

In [16]:
# Step 3: Scaled Dot-Product Attention
dk = head_dim
# Divide by sqrt(d_k) so the scores do not grow in magnitude with the head dimension
scores = mx.matmul(queries, keys.transpose(0, 1, 3, 2)) / (dk ** 0.5)
attention_weights = nn.softmax(scores, axis=-1)
attention_output = mx.matmul(attention_weights, values)

print("Attention Weights Shape:", attention_weights.shape)
print("Attention Output Shape:", attention_output.shape)

Attention Weights Shape: (1, 8, 13, 13)
Attention Output Shape: (1, 8, 13, 8)
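The same computation written out in NumPy, as a reference sketch independent of MLX: `softmax(Q @ K^T / sqrt(d_k)) @ V`, with the row maximum subtracted before exponentiating for numerical stability. Each row of the attention weights is a probability distribution over the 13 key positions.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtracting the max does not change the result but avoids overflow in exp.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def scaled_dot_product_attention(q, k, v):
    d_k = q.shape[-1]
    scores = q @ np.swapaxes(k, -1, -2) / np.sqrt(d_k)  # (..., L_q, L_k)
    weights = softmax(scores, axis=-1)
    return weights @ v, weights

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(1, 8, 13, 8)) for _ in range(3))
out, weights = scaled_dot_product_attention(q, k, v)

print(out.shape, weights.shape)                # (1, 8, 13, 8) (1, 8, 13, 13)
print(np.allclose(weights.sum(axis=-1), 1.0))  # True
```

Recent MLX releases also include a fused kernel, `mx.fast.scaled_dot_product_attention`, which takes queries, keys, and values in the same (batch, heads, seq_len, head_dim) layout; check the MLX documentation for its exact signature in your version.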


# Combine Heads

This cell combines the output from all attention heads back into a single tensor. It reshapes the combined output to match the original embedding dimensions and prints the shape.

In [17]:
# Step 4: Combine Heads
combined_output = attention_output.transpose(0, 2, 1, 3).reshape(batch_size, seq_length, embedding_dim)
print("Combined Output Shape:", combined_output.shape)

Combined Output Shape: (1, 13, 64)
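Combining heads is the exact inverse of the split: transposing back to (batch, seq_len, num_heads, head_dim) and merging the last two axes recovers the original layout. A NumPy round-trip check:

```python
import numpy as np

batch, seq_len, num_heads, head_dim = 1, 13, 8, 8
d = num_heads * head_dim
x = np.random.default_rng(0).normal(size=(batch, seq_len, d))

split = x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
combined = split.transpose(0, 2, 1, 3).reshape(batch, seq_len, d)

print(np.array_equal(combined, x))  # True
```

Reshaping `split` directly to (1, 13, 64) without the transpose would also succeed shape-wise, but it would mix values from different tokens, so the transpose is required.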


# Final Linear Transformation

This cell applies a final linear transformation to the combined output from the attention heads.

In [18]:
# Step 5: Final Linear Transformation
output_proj = nn.Linear(embedding_dim, embedding_dim, bias=False)
final_output = output_proj(combined_output)
print("Final Output after Self-Attention Shape:", final_output.shape)

Final Output after Self-Attention Shape: (1, 13, 64)
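Putting the steps together, the whole multi-head self-attention block can be summarized as one function. This is a NumPy sketch with random stand-in weights (the names `w_q`, `w_k`, `w_v`, `w_o` are illustrative, not MLX attributes). It mirrors the cells above but omits RoPE, masking, and dropout.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads):
    """x: (batch, seq_len, d); weights: (d, d) in (out, in) layout, no bias."""
    batch, seq_len, d = x.shape
    head_dim = d // num_heads

    def split(t):  # (B, L, d) -> (B, H, L, head_dim)
        return t.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)

    q, k, v = split(x @ w_q.T), split(x @ w_k.T), split(x @ w_v.T)
    weights = softmax(q @ k.transpose(0, 1, 3, 2) / np.sqrt(head_dim), axis=-1)
    heads = weights @ v                                    # (B, H, L, head_dim)
    combined = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d)
    return combined @ w_o.T                                # final linear projection

rng = np.random.default_rng(0)
d, num_heads = 64, 8
x = rng.normal(size=(1, 13, d))
w_q, w_k, w_v, w_o = (rng.normal(scale=d ** -0.5, size=(d, d)) for _ in range(4))

out = multi_head_self_attention(x, w_q, w_k, w_v, w_o, num_heads)
print(out.shape)  # (1, 13, 64)
```

MLX also ships a ready-made `nn.MultiHeadAttention` layer that packages these projections; the step-by-step version above is mainly useful for seeing where each shape comes from.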


# Conclusion

In this notebook, we explored the foundational components of the transformer model architecture using the MLX framework, specifically tailored for Apple Silicon. Here’s a summary of what we covered:

1. Tokenization and Embedding:
   - We began by tokenizing an example sentence into individual tokens using NLTK.
   - Each token was then converted into a unique index using a simple vocabulary.
   - We utilized an embedding layer to convert these token indices into dense vectors, preparing them for further processing.
2. Rotary Position Embedding (RoPE):
   - We applied the RoPE technique to enhance the input embeddings with positional information.
   - RoPE uses a rotational transformation based on sinusoidal functions to incorporate relative positional data directly into the embeddings.
3. Self-Attention Mechanism:
   - Following positional encoding, we implemented the self-attention mechanism, a core component of transformer models. This included:
     - Linear transformations to obtain Query, Key, and Value matrices from the input embeddings.
     - Splitting these matrices into multiple heads for multi-head attention.
     - Calculating attention scores through scaled dot-product attention and applying the softmax function to obtain normalized attention weights.
     - Combining the output from all attention heads and applying a final linear transformation to produce the final self-attention output.