# Reproduce GPT2 with MLX

This document walks through implementing and training a GPT-2-style transformer in Apple's MLX, covering the GPT model architecture, the multi-head attention mechanism, the multi-layer perceptron (MLP), and the individual transformer blocks. The explanations describe each component's role within the broader model. By the end, you should understand how to implement and train a transformer model in MLX and how the pieces fit together.

## Table of Contents

- [GPT Model Implementation](#gpt-model-implementation)
  - [GPT Model Class Definition](#gpt-model-class-definition)
  - [Initialization Method (`__init__`)](#initialization-method-__init__)
  - [Generate Method (`generate`)](#generate-method-generate)
  - [Parameter Initialization Method (`_init_parameters`)](#parameter-initialization-method-_init_parameters)
  - [Forward Pass Method (`__call__`)](#forward-pass-method-__call__)

- [Multi-Head Attention Implementation](#multi-head-attention-implementation)
  - [Multi-Head Attention Class Definition](#multi-head-attention-class-definition)
  - [Initialization Method (`__init__`)](#initialization-method-__init__)
  - [Forward Pass Method (`__call__`)](#forward-pass-method-__call__)

- [Multi-Layer Perceptron (MLP) Implementation](#multi-layer-perceptron-mlp-implementation)
  - [MLP Class Definition](#mlp-class-definition)
  - [Attributes](#attributes)
  - [Initialization Method (`__init__`)](#initialization-method-__init__)
  - [Forward Pass Method (`__call__`)](#forward-pass-method-__call__)

- [Transformer Block Implementation](#transformer-block-implementation)
  - [Block Class Definition](#block-class-definition)
  - [Initialization Method (`__init__`)](#initialization-method-__init__)
  - [Forward Pass Method (`__call__`)](#forward-pass-method-__call__)

- [Training Loop Implementation](#training-loop-implementation)
  - [Training Loop Overview](#training-loop-overview)
  - [Training Phase](#training-phase)
  - [Evaluation Phase](#evaluation-phase)
  - [Logging](#logging)
  - [Conclusion](#conclusion)

In [None]:
!pip install -q mlx numpy

## Preparing the data

Install mlx and run the following imports.

In [420]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math

The first step to training an LLM is collecting a large corpus of text data and then tokenizing it. Tokenization is the process of mapping text to integers, which can be fed into the LLM. Our training corpus for this model will be the works of Shakespeare concatenated into one file. This is roughly 1 million characters and looks like this:

In [421]:
with open('../input.txt', 'r') as f:
    text = f.read()

print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


First, we read the file as a single long string into the `text` variable. Then we use the `set()` function to get all the unique characters in the text, which will be our vocabulary. By printing `vocab` you can see all the characters in our vocabulary; we have a total of 65 characters, which will be our tokens.

In [422]:
vocab = set(text)
vocab

{'\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

We'll then wrap a list around it:

In [423]:
vocab = list(set(text))
vocab

['G',
 '\n',
 'B',
 'v',
 'L',
 "'",
 ';',
 '!',
 't',
 'a',
 ',',
 'Z',
 'C',
 ':',
 'W',
 'R',
 '$',
 'F',
 'x',
 'I',
 'i',
 ' ',
 'c',
 'J',
 'Y',
 'Q',
 'l',
 'p',
 '?',
 'O',
 '&',
 'X',
 'w',
 'g',
 'T',
 'q',
 'd',
 '-',
 'K',
 'y',
 '.',
 'D',
 'U',
 'P',
 'N',
 'e',
 'j',
 'z',
 'o',
 'k',
 'M',
 'b',
 'm',
 'E',
 'A',
 '3',
 'h',
 'r',
 'n',
 'S',
 'H',
 'V',
 'u',
 'f',
 's']

Finally, we'll sort it so that the character-to-integer mapping is the same on every run:

In [424]:
vocab = sorted(list(set(text)))
vocab

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [425]:
vocab_size = len(vocab)
print(f"vocab_size: {vocab_size}")

print(''.join(vocab))

vocab_size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [426]:
# Create mapping from vocab to integers
itos = {i:c for i,c in enumerate(vocab)} # int to string
stoi = {c:i for i,c in enumerate(vocab)} # string to int
encode = lambda x: [stoi[c] for c in x] # encode string to int
decode = lambda x: ''.join([itos[i] for i in x]) # decode int to string

print(encode("hello world"))
# [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
print(decode(encode("hello world")))
# hello world

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [427]:
data = encode(text) # encode the entire text
split = int(0.9 * len(data)) # 90% train, 10% validation
train_data = data[:split] # first 90%
val_data = data[split:] # last 10%

In [428]:
ctx_len = 4
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58]
# x: [18, 47, 56, 57] | y: 58

# 4 sub examples
# [18] --> 47
# [18, 47] --> 56
# [18, 47, 56] --> 57
# [18, 47, 56, 57] --> 58

[18, 47, 56, 57, 58]


In [429]:
print("inputs: ", train_data[:ctx_len])
print("labels: ", train_data[1:ctx_len+1]) # labels = inputs indexed 1 higher
# inputs: [18, 47, 56, 57]
# labels: [47, 56, 57, 58]

inputs:  [18, 47, 56, 57]
labels:  [47, 56, 57, 58]


In [430]:
# Creating training and validation datasets
ctx_len = 4
X_train = mx.array([train_data[i:i+ctx_len] for i in range(0, len(train_data) - ctx_len, ctx_len)]) 
y_train = mx.array([train_data[i+1:i+ctx_len+1] for i in range(0, len(train_data) - ctx_len, ctx_len)]) 
X_val = mx.array([val_data[i:i+ctx_len] for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([val_data[i+1:i+ctx_len+1] for i in range(0, len(val_data) - ctx_len, ctx_len)])

In [431]:
def get_batches(X, y, b_size, shuffle=True):
    if shuffle:
        ix = np.arange(X.shape[0])
        np.random.shuffle(ix)
        ix = mx.array(ix)
        X = X[ix]
        y = y[ix]
    for i in range(0, X.shape[0], b_size):
        input = X[i:i+b_size]
        label = y[i:i+b_size]
        yield input, label

In the output below, the label array `y` is the input array `X` shifted left by one position, so that each position's label is the next token in the sequence:

In [432]:
batch = get_batches(X_train, y_train, 1)
for X, y in batch:
    print(X)
    print(y)
    break

array([[40, 43, 42, 57]], dtype=int32)
array([[43, 42, 57, 10]], dtype=int32)
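
The dataset construction above can be checked on a toy sequence. This is a minimal pure-Python sketch, not the notebook's MLX code: the helper name `make_windows` and the toy data are hypothetical, chosen only to make the windowing visible.

```python
def make_windows(data, ctx_len):
    """Cut data into non-overlapping windows; labels are inputs shifted by one."""
    # Stopping at len(data) - ctx_len guarantees every label window is full.
    starts = range(0, len(data) - ctx_len, ctx_len)
    X = [data[i:i + ctx_len] for i in starts]
    y = [data[i + 1:i + ctx_len + 1] for i in starts]
    return X, y

toy = list(range(10))  # stand-in for the encoded text
X, y = make_windows(toy, ctx_len=4)
for inputs, labels in zip(X, y):
    print(inputs, "->", labels)
```

Because the stride equals `ctx_len`, the input windows do not overlap, and the trailing tokens that cannot fill a complete window are dropped.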


In [433]:
ctx_len = 128
n_emb = 128
dropout = 0.1
head_size = 128
n_heads = 4 
n_layers = 3 
num_epochs = 20
batch_size = 64
lr = 1e-3

In [434]:
# class Attention(nn.Module):
#     def __init__(self, head_size):
#         super().__init__()
#         self.head_size = head_size # Define the head size of the attention mechanism 
#         self.k_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the key projection
#         self.q_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the query projection
#         self.v_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the value projection
#         indices = mx.arange(ctx_len) # Create a tensor with values from 0 to ctx_len - 1
#         print(f"indices: \n {indices} \n")
#         mask = indices[:, None] < indices[None] # True where the row index is less than the column index, i.e. where a key lies in the future of the query (the strictly upper triangle)
#         print(f"mask: \n {mask} \n")
#         self._causal_mask = mask * -1e9 # Multiply the mask tensor by -1e9 to get a tensor with -1e9 where the value of the first tensor is less than the value of the second tensor
#         print(f"mask: \n {self._causal_mask} \n")
#         self.c_proj = nn.Linear(head_size, n_emb) # output projection layer to get the output of the attention mechanism
#         self.resid_dropout = nn.Dropout(dropout) # Define the dropout layer for the residual connection
#     def __call__(self, x): # shapes commented
#         B, T, C = x.shape # (batch_size, ctx_len, n_emb) - x is the input tensor
#         K = self.k_proj(x) # (B, T, head_size) - Project the keys
#         Q = self.q_proj(x) # (B, T, head_size) - Project the queries
#         V = self.v_proj(x) # (B, T, head_size) - Project the values
#         attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size) # We use K.transpose([0, 2, 1]) to transpose the second and third dimensions of K. This is because we want to multiply the queries with the keys. The shape of the attention weights is (B, T, T) 
#         # attn_weights.shape = (B, T, T)
#         attn_weights = attn_weights + self._causal_mask # Add the causal mask to the attention weights
#         attn_weights = mx.softmax(attn_weights, axis=-1) # Apply the softmax function to the attention weights to get the attention scores
#         o = (attn_weights @ V) # (B, T, head_size) - Multiply the attention scores with the values to get the output
#         o = self.c_proj(self.resid_dropout(o)) # (B, T, n_emb) - Apply the output projection layer to the output
#         return o # Return the output of the attention mechanism which will be used as the input to the feedforward neural network

## Multi-Head Attention Class

The `MultiHeadAttention` class is a subclass of `nn.Module`, designed to perform the attention mechanism across multiple heads in parallel. This allows the model to focus on different parts of the input sequence simultaneously.

### Initialization Method (`__init__`)

The initialization method sets up the following components:

- **Head Size (`self.head_size`)**: The total width of the key, query, and value projections. This width is split evenly across the `n_heads` heads, so each head works with `head_size // n_heads` dimensions.
- **Key Projection (`self.k_proj`)**: A linear layer that projects the input embeddings into key vectors of the specified `head_size`.
- **Query Projection (`self.q_proj`)**: A linear layer that projects the input embeddings into query vectors of the specified `head_size`.
- **Value Projection (`self.v_proj`)**: A linear layer that projects the input embeddings into value vectors of the specified `head_size`.
- **Causal Mask (`self._causal_mask`)**: A mask that ensures the model cannot attend to future tokens in the sequence. It is a `(ctx_len, ctx_len)` matrix that holds a large negative value (`-1e9`) strictly above the diagonal and zero elsewhere. Adding it to the attention scores before the softmax drives the weights on future positions to zero.
- **Output Projection (`self.c_proj`)**: A linear layer that projects the concatenated output of all heads back into the original embedding dimension (`n_emb`).
- **Attention Dropout (`self.attn_dropout`)**: A dropout layer applied to the attention weights to prevent overfitting.
- **Residual Dropout (`self.resid_dropout`)**: A dropout layer applied to the output of the attention mechanism before the final projection.
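
To see what the causal mask does, here is a small NumPy sketch (NumPy rather than MLX, so it runs anywhere). It builds the mask with the same comparison as the class, using a toy length `T = 4` and uniform raw scores, both of which are assumptions for illustration.

```python
import numpy as np

T = 4  # toy context length (the notebook uses ctx_len)
indices = np.arange(T)
# True where the key position (column) lies in the future of the query (row)
mask = indices[:, None] < indices[None]
causal_mask = mask * -1e9  # additive mask: -1e9 above the diagonal, 0 elsewhere

scores = np.zeros((T, T)) + causal_mask  # uniform raw scores, then masked
# Numerically stable softmax over the key axis
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights = weights / weights.sum(axis=-1, keepdims=True)
print(np.round(weights, 2))
```

Each row spreads its weight evenly over the current and earlier positions only: the first token can attend only to itself, while the last token attends to all four positions.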

### Forward Pass Method (`__call__`)

The `__call__` method defines the forward pass of the multi-head attention mechanism:

- **Input Shape**: The input tensor `x` has the shape `(B, T, C)` where `B` is the batch size, `T` is the context length (number of tokens), and `C` is the embedding dimension.
- **Key, Query, Value Projections**: The input tensor is projected into key (`K`), query (`Q`), and value (`V`) tensors using the respective linear layers. Each projection has the shape `(B, T, head_size)`.
- **Reshape for Multi-Head Attention**: The projected tensors are reshaped to enable multi-head attention. The new shape is `(B, n_heads, T, head_size//n_heads)`, allowing each head to focus on different parts of the sequence.
- **Attention Weight Calculation**: The attention weights are computed by taking the dot product of the query and key matrices, then dividing by the square root of the per-head dimension (`head_size // n_heads`). The shape of the attention weights is `(B, n_heads, T, T)`.
- **Causal Masking**: The causal mask is added to the attention weights to prevent attending to future tokens.
- **Softmax and Dropout**: The masked attention weights are passed through a softmax function to normalize them into probabilities. Dropout is then applied to these attention weights.
- **Weighted Sum of Values**: The attention weights are used to compute a weighted sum of the value vectors, producing the attention output for each head.
- **Output Reshape and Projection**: The multi-head attention output is reshaped back to the original embedding dimensions and passed through the output projection layer. Dropout is applied before the final projection.
- **Final Output**: The method returns the output of the attention mechanism, which will be used as input to the subsequent layers, typically a feedforward neural network.

This implementation of multi-head attention allows the model to attend to different parts of the input sequence simultaneously, enhancing its ability to capture complex dependencies within the data.
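
The shape bookkeeping in the forward pass can be traced with a NumPy sketch. It is a simplified stand-in for the MLX class, not the class itself: the learned projections are replaced by random tensors, dropout and the output projection are omitted, and the toy sizes are assumptions.

```python
import numpy as np

B, T, head_size, n_heads = 2, 5, 8, 4
d = head_size // n_heads  # per-head width
rng = np.random.default_rng(0)
# Random stand-ins for the projected queries, keys, and values
Q = rng.normal(size=(B, T, head_size))
K = rng.normal(size=(B, T, head_size))
V = rng.normal(size=(B, T, head_size))

def split_heads(x):
    # (B, T, head_size) -> (B, T, n_heads, d) -> (B, n_heads, T, d)
    return x.reshape(B, T, n_heads, d).transpose(0, 2, 1, 3)

Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
scores = Qh @ Kh.transpose(0, 1, 3, 2) / np.sqrt(d)   # (B, n_heads, T, T)
idx = np.arange(T)
scores = scores + (idx[:, None] < idx[None]) * -1e9    # causal mask
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)         # softmax over keys
# Merge the heads back: (B, n_heads, T, d) -> (B, T, head_size)
out = (weights @ Vh).transpose(0, 2, 1, 3).reshape(B, T, head_size)
print(scores.shape, out.shape)
```

A useful check on the masking and the head merge: the first position can only attend to itself, so `out[:, 0]` equals `V[:, 0]`.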

In [435]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_size = head_size # Define the head size of the attention mechanism 
        self.k_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the key projection
        self.q_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the query projection
        self.v_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the value projection
        indices = mx.arange(ctx_len) # Create a tensor with values from 0 to ctx_len - 1
        # print(f"indices: \n {indices} \n")
        mask = indices[:, None] < indices[None] # True where the row index is less than the column index, i.e. where a key lies in the future of the query (the strictly upper triangle)
        # print(f"mask: \n {mask} \n")
        self._causal_mask = mask * -1e9 # Multiply the mask tensor by -1e9 to get a tensor with -1e9 where the value of the first tensor is less than the value of the second tensor
        # print(f"mask: \n {self._causal_mask} \n")
        self.c_proj = nn.Linear(head_size, n_emb) # output projection layer to get the output of the attention mechanism
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout) # Define the dropout layer for the residual connection

    # Define the forward pass of the model
    def __call__(self, x): # shapes commented
        B, T, C = x.shape # (batch_size, ctx_len, n_emb) - x is the input tensor
        K = self.k_proj(x) # (B, T, head_size) - Project the keys
        Q = self.q_proj(x) # (B, T, head_size) - Project the queries
        V = self.v_proj(x) # (B, T, head_size) - Project the values
        mha_shape = (B, T, n_heads, head_size//n_heads) # Split the last dimension (head_size) into n_heads chunks of head_size//n_heads each
        # print(f"mha_shape: \n {mha_shape} \n")
        K = K.reshape(mha_shape).transpose([0, 2, 1, 3]) # (B, T, n_heads, head_size//n_heads) -> (B, n_heads, T, head_size//n_heads)
        Q = Q.reshape(mha_shape).transpose([0, 2, 1, 3]) # (B, T, n_heads, head_size//n_heads) -> (B, n_heads, T, head_size//n_heads)
        V = V.reshape(mha_shape).transpose([0, 2, 1, 3]) # (B, T, n_heads, head_size//n_heads) -> (B, n_heads, T, head_size//n_heads)
        attn_weights = (Q @ K.transpose([0, 1, 3, 2])) / math.sqrt(Q.shape[-1]) # K.transpose([0, 1, 3, 2]) swaps the last two dimensions of K so that the queries can be multiplied with the keys. The shape of the attention weights is (B, n_heads, T, T)
        # print(f"attn_weights: \n {attn_weights} \n")
        attn_weights = attn_weights + self._causal_mask[:T, :T] # Add the causal mask to the attention weights
        # print(f"attn_weights + casual mask: \n {attn_weights} \n")
        attn_weights = mx.softmax(attn_weights, axis=-1) # Apply the softmax function to the attention weights to get the attention scores
        # print(f"softmax attn_weights: \n {attn_weights} \n")
        attn_weights = self.attn_dropout(attn_weights) # Apply the dropout layer to the attention weights
        # print(f"dropout attn_weights: \n {attn_weights} \n")
        o = (attn_weights @ V) # (B, n_heads, T, head_size//n_heads) - Multiply the attention scores with the values to get the output
        # print(f"output: \n {o} \n")
        o = o.transpose([0, 2, 1, 3]).reshape((B, T, head_size)) # We transpose the dimensions of the output and then reshape it to get the desired shape
        # print(f"output reshaped: \n {o} \n")
        o = self.c_proj(self.resid_dropout(o)) # (B, T, n_emb) - Apply the output projection layer to the output
        # print(f"output projection: \n {o} \n")
        return o # Return the output of the attention mechanism which will be used as the input to the feedforward neural network

## MLP Class Definition

The `MLP` class is a subclass of `nn.Module`, which defines a straightforward feedforward neural network used within a larger model, such as a transformer.

### Attributes

- **Fully Connected Layer (`c_fc`)**: This is a linear layer that expands the input dimension from `n_emb` to `4 * n_emb`. It performs a fully connected operation on the input tensor.
- **GELU Activation (`gelu`)**: The GELU (Gaussian Error Linear Unit) activation function is applied after the fully connected layer to introduce non-linearity into the model.
- **Projection Layer (`c_proj`)**: Another linear layer that projects the expanded dimension back down to the original embedding size (`n_emb`), so that the MLP output can be added back to its input through the residual connection.
- **Dropout Layer (`dropout`)**: A dropout layer is applied for regularization, helping to prevent overfitting by randomly setting some of the activations to zero during training.

### Initialization Method (`__init__`)

The initialization method sets up the layers of the MLP:

- **c_fc**: This linear layer takes an input of size `n_emb` and expands it to `4 * n_emb`. This expansion allows the network to capture more complex patterns.
- **gelu**: The GELU activation function is applied to the output of the `c_fc` layer, introducing non-linearity and enabling the model to learn more complex relationships.
- **c_proj**: This linear layer reduces the dimension back to `n_emb`, balancing the expansion performed by `c_fc`.
- **dropout**: Dropout is applied after the projection layer to regularize the network, reducing the risk of overfitting.

### Forward Pass Method (`__call__`)

The `__call__` method defines the forward pass of the MLP:

- **Input to Fully Connected Layer**: The input tensor `x` is passed through the `c_fc` layer, which increases its dimensionality to `4 * n_emb`.
- **GELU Activation**: The expanded tensor is then passed through the GELU activation function, introducing non-linearity.
- **Projection Layer**: The activated tensor is passed through the `c_proj` layer, reducing its dimensionality back to `n_emb`.
- **Dropout**: Dropout is applied to the output of the projection layer for regularization.
- **Output**: The method returns the output tensor, which is now ready to be used in subsequent layers of the model.

This implementation of the MLP is commonly used in transformer models, where it acts as the feedforward network within each transformer block, helping the model learn complex mappings between the input and output spaces.
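
The expand, activate, project pattern can be sketched in NumPy as follows. This is an illustration rather than the MLX module: the weights are random, dropout is omitted (as in evaluation mode), the toy `n_emb = 8` is arbitrary, and the exact erf-based form of GELU is assumed.

```python
import math
import numpy as np

n_emb = 8
rng = np.random.default_rng(0)
W_fc, b_fc = rng.normal(0, 0.02, (n_emb, 4 * n_emb)), np.zeros(4 * n_emb)
W_proj, b_proj = rng.normal(0, 0.02, (4 * n_emb, n_emb)), np.zeros(n_emb)

# Exact GELU: x * Phi(x), with Phi the standard normal CDF
gelu = np.vectorize(lambda v: 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))))

def mlp(x):
    h = gelu(x @ W_fc + b_fc)      # (B, T, 4 * n_emb)
    return h @ W_proj + b_proj     # (B, T, n_emb); dropout omitted

x = rng.normal(size=(2, 3, n_emb))
print(mlp(x).shape)
```

The output has the same shape as the input, which is what allows the transformer block to add it back onto its input.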

In [436]:
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(n_emb, 4 * n_emb)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * n_emb, n_emb)
        self.dropout = nn.Dropout(dropout)
    def __call__(self, x):
        x = self.gelu(self.c_fc(x))
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

## Transformer Block Class Definition

The `Block` class is a subclass of `nn.Module` that defines a single transformer block. This block combines a multi-head attention mechanism with a feedforward neural network (MLP) and includes layer normalization and residual connections.

### Initialization Method (`__init__`)

The initialization method sets up the following components:

- **Multi-Layer Perceptron (`self.mlp`)**: An instance of the `MLP` class, which is a simple feedforward neural network used within the transformer block.
- **Multi-Head Attention (`self.mha`)**: An instance of the `MultiHeadAttention` class, which allows the block to attend to different parts of the input sequence simultaneously.
- **Layer Normalization 1 (`self.ln_1`)**: A layer normalization layer that normalizes the input tensor before it is passed to the multi-head attention mechanism.
- **Layer Normalization 2 (`self.ln_2`)**: Another layer normalization layer that normalizes the input tensor before it is passed to the feedforward neural network (MLP).

### Forward Pass Method (`__call__`)

The `__call__` method defines the forward pass of the transformer block:

- **Multi-Head Attention with Residual Connection**: 
  - The input tensor `x` is first normalized using `ln_1`.
  - The normalized tensor is then passed through the multi-head attention mechanism (`mha`).
  - The output of the attention mechanism is added to the original input tensor, creating a residual connection. This helps in preserving the original information while also incorporating the attention-modified information.

- **Feedforward Neural Network (MLP) with Residual Connection**:
  - The output from the previous step is normalized using `ln_2`.
  - The normalized tensor is then passed through the MLP (`mlp`).
  - The output of the MLP is added to the input tensor from the previous step, creating another residual connection. This further refines the information while maintaining the integrity of the original input.

- **Final Output**: The method returns the final output tensor, which has been processed by both the multi-head attention mechanism and the feedforward neural network, with normalization and residual connections at each step.

This transformer block is a core component of transformer models, where multiple such blocks are stacked to form a deep network capable of learning complex patterns and dependencies in the data.
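
The pre-normalization residual pattern can be sketched in NumPy. The sublayers are passed in as plain functions, and the layer norm omits the learned scale and shift; both are simplifications for illustration, not how the MLX `Block` is written.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize over the last (embedding) axis; learned scale/shift omitted
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

def block(x, attn, mlp):
    x = x + attn(layer_norm(x))  # attention sublayer with residual connection
    x = x + mlp(layer_norm(x))   # MLP sublayer with residual connection
    return x

x = np.random.default_rng(0).normal(size=(2, 3, 8))
zero = lambda h: np.zeros_like(h)  # hypothetical sublayer that outputs nothing
y = block(x, attn=zero, mlp=zero)
print(np.allclose(y, x))
```

With sublayers that output zeros the block reduces to the identity. This shows the role of the residual path: each sublayer only adds a correction on top of its input.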

In [437]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP() # Feedforward neural network
        self.mha = MultiHeadAttention() # Multi-head attention mechanism
        self.ln_1 = nn.LayerNorm(dims=n_emb) # Layer normalization layer
        self.ln_2 = nn.LayerNorm(dims=n_emb) # Layer normalization layer
    def __call__(self, x): 
        x = x + self.mha(self.ln_1(x)) # Add the output of the multi-head attention mechanism to the input tensor
        x = x + self.mlp(self.ln_2(x)) # Add the output of the feedforward neural network to the input tensor 
        return x # Return the output of the block

In [438]:
def loss_fn(model, x, y):
    logits = model(x)
    B, T, C = logits.shape # (batch_size, seq_len, vocab_size)
    logits = logits.reshape(B*T, C)
    y = y.reshape(B*T)
    loss = nn.losses.cross_entropy(logits, y, reduction='mean')
    return loss
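
A quick sanity check on this loss: if the model assigns equal probability to every token, the cross-entropy equals `ln(vocab_size)`, so a freshly initialized model should start with a loss in that neighbourhood. The NumPy sketch below reimplements mean cross-entropy by hand (the function is ours, not `nn.losses.cross_entropy`) and applies it to all-zero logits.

```python
import math
import numpy as np

def cross_entropy(logits, targets):
    # logits: (N, C), targets: (N,) integer class ids; mean over N
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()

B, T, vocab_size = 2, 4, 65
logits = np.zeros((B, T, vocab_size))  # uniform prediction over the vocabulary
targets = np.random.default_rng(0).integers(0, vocab_size, size=(B, T))
# Flatten batch and time, as loss_fn does before calling cross_entropy
loss = cross_entropy(logits.reshape(B * T, vocab_size), targets.reshape(B * T))
print(round(float(loss), 4), round(math.log(vocab_size), 4))
```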

## GPT Model Class Definition

The `GPT` class is a subclass of `nn.Module`, which defines the architecture and operations of the GPT model. The model includes an embedding layer for both tokens and positions, a series of transformer blocks, and a final layer normalization followed by a linear projection for output.

### Initialization Method (`__init__`)

The initialization method sets up the following components:

- **Word Embedding (`self.wte`)**: This is a lookup table that maps each token in the vocabulary to its corresponding embedding vector.
- **Position Embedding (`self.wpe`)**: This embedding layer assigns a unique embedding vector to each position in the input sequence.
- **Transformer Blocks (`self.blocks`)**: A sequence of transformer blocks that processes the input embeddings. The number of blocks is determined by `n_layers`.
- **Layer Normalization (`self.ln_f`)**: A normalization layer applied to the output of the transformer blocks.
- **Output Projection (`self.lm_head`)**: A linear layer that projects the final hidden state into the vocabulary space, producing logits for each token in the vocabulary.
- **Parameter Initialization (`self._init_parameters`)**: This method initializes the parameters of the model, including weights and biases, using specific initialization strategies.

### Generate Method (`generate`)

The `generate` method is responsible for generating text by predicting the next token iteratively:

- **Context Initialization**: A `(1, 1)` tensor holding token id 0 (the newline character in this vocabulary) is created to serve as the initial context.
- **Token Generation Loop**: For each step in the generation process, the current context is passed through the model to generate logits, which are used to sample the next token.
- **Context Update**: The sampled token is appended to the context, and the process repeats until the desired number of tokens is generated.
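
The sampling loop can be sketched in NumPy with a stand-in model. Here `toy_model` is hypothetical and simply returns random logits, and the softmax followed by `rng.choice` is used to mimic sampling from logits; the real method calls the GPT forward pass and `mx.random.categorical` instead.

```python
import numpy as np

vocab_size, ctx_len = 5, 3
rng = np.random.default_rng(0)

def toy_model(ctx):
    # Hypothetical stand-in for the GPT forward pass: (1, T) -> (1, T, vocab_size)
    return rng.normal(size=(ctx.shape[0], ctx.shape[1], vocab_size))

def generate(max_new_tokens):
    ctx = np.zeros((1, 1), dtype=np.int64)  # start from token id 0
    for _ in range(max_new_tokens):
        # Feed at most the last ctx_len tokens, keep the last position's logits
        logits = toy_model(ctx[:, -ctx_len:])[:, -1, :]
        probs = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs /= probs.sum(axis=-1, keepdims=True)
        next_tok = rng.choice(vocab_size, size=(1, 1), p=probs[0])
        ctx = np.concatenate((ctx, next_tok), axis=1)  # append the new token
    return ctx

out = generate(10)
print(out.shape)
```

The returned context includes the initial token, so generating `max_new_tokens` tokens yields a sequence of length `max_new_tokens + 1`.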

### Parameter Initialization Method (`_init_parameters`)

This method initializes the model's parameters with specific distributions:

- **Normal Initialization (`normal_init`)**: Used for most parameters, with a mean of 0 and a standard deviation of 0.02.
- **Residual Initialization (`residual_init`)**: Applied to the `c_proj` output projections, whose outputs are added to the residual stream, with a mean of 0 and a standard deviation of `0.02 / sqrt(2 * n_layers)`.
- **Bias Initialization**: Biases are initialized to zero.
- **Updating Parameters**: The initialized parameters are applied to the model.
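
The scaled standard deviation is easy to compute by hand. The sketch below uses the notebook's `n_layers = 3`; the accompanying rationale is the usual motivation given for this GPT-2-style scaling, stated here as an assumption rather than something the notebook derives.

```python
import math

n_layers = 3
base_std = 0.02
residual_std = base_std / math.sqrt(2 * n_layers)
print(f"{residual_std:.5f}")

# Each block adds two residual branches (attention and MLP), so 2 * n_layers
# branches are summed in total. If each contributed variance residual_std**2
# independently, the accumulated standard deviation would equal base_std.
accumulated = math.sqrt(2 * n_layers * residual_std ** 2)
print(f"{accumulated:.5f}")
```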

### Forward Pass Method (`__call__`)

The `__call__` method defines the forward pass of the GPT model:

- **Token Embeddings**: The input tokens are converted to embeddings using the word embedding layer.
- **Position Embeddings**: Positional information is added by summing token embeddings with position embeddings.
- **Transformer Blocks**: The combined embeddings are passed through the series of transformer blocks.
- **Layer Normalization**: The output of the transformer blocks is normalized.
- **Output Logits**: The final output is projected into the vocabulary space to produce logits for each token in the vocabulary.

This method returns the logits with shape `(B, T, vocab_size)`. In this notebook they are used to compute the next-token cross-entropy loss during training and to sample new tokens during generation.
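
The first two steps of the forward pass rely on broadcasting: the position embeddings have shape `(T, n_emb)` and are added to every sequence in the batch. A NumPy sketch with random tables and toy sizes (all values are assumptions for illustration):

```python
import numpy as np

vocab_size, ctx_len, n_emb = 65, 8, 4
rng = np.random.default_rng(0)
wte = rng.normal(0, 0.02, (vocab_size, n_emb))  # token embedding table
wpe = rng.normal(0, 0.02, (ctx_len, n_emb))     # position embedding table

x = np.array([[18, 47, 56], [47, 56, 57]])      # (B, T) token ids
B, T = x.shape
tok_emb = wte[x]                # (B, T, n_emb): one row lookup per token
pos_emb = wpe[np.arange(T)]     # (T, n_emb): one row per position
h = tok_emb + pos_emb           # broadcasts over the batch axis
print(h.shape)
```

The same token id therefore gets a different input vector depending on where it appears: token 47 is at position 1 in the first sequence and at position 0 in the second.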

In [439]:

class GPT(nn.Module): # Define the GPT model
    def __init__(self):
        super().__init__() # Call the __init__ of the parent class
        self.wte = nn.Embedding(vocab_size, n_emb) # Lookup table for embeddings of each token in the vocab (word to embedding) -  n_emb means the size of the embedding vector
        self.wpe = nn.Embedding(ctx_len, n_emb) # Lookup table for embeddings of each position in the context (position to embedding)
        self.blocks = nn.Sequential(
            *[Block() for _ in range(n_layers)],
        ) # transformer blocks - n_layers means the number of transformer blocks
        self.ln_f = nn.LayerNorm(dims=n_emb) # final layernorm
        # print(f"layernorm: \n {self.ln_f} \n")
        self.lm_head = nn.Linear(n_emb, vocab_size) # output projection
        # print(f"lm_head: \n {self.lm_head} \n")
        self._init_parameters() # Initialize the parameters of the model
        # print total number of params on initialization
        total_params = sum([p.size for n,p in utils.tree_flatten(self.parameters())]) # Get the total number of parameters
        # print(f"Total params: {(total_params / 1e6):.3f}M") # Print the total number of parameters in millions

    # method of GPT class
    def generate(self, max_new_tokens):
        ctx = mx.zeros((1, 1), dtype=mx.int32) # (1, 1) - Create a context tensor with zeros
        for _ in range(max_new_tokens): # Loop through the number of tokens to generate
            logits = self(ctx[:, -ctx_len:]) # pass in last ctx_len characters to get the next token
            logits = logits[:, -1, :] # get logits for the next token only
            next_tok = mx.random.categorical(logits, num_samples=1) # sample the next token
            ctx = mx.concatenate((ctx, next_tok), axis=1) # append the next token to the context
        return ctx # return the context
    
    # method of GPT
    def _init_parameters(self):
        normal_init = nn.init.normal(mean=0.0, std=0.02) # Initialize the weights with a normal distribution
        residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers))) # Initialize the residuals with a normal distribution
        new_params = [] # Create a list to store the new parameters
        # print(f"named_modules: \n {self.named_modules()} \n")
        for name, module in self.named_modules(): # Loop through the modules of the model
            if isinstance(module, nn.layers.linear.Linear): # Check if the module is a linear layer
                if 'c_proj' in name: # residual projection layer
                    new_params.append((name + '.weight', residual_init(module.weight))) # Initialize the weights of the residual projection layer
                else:
                    new_params.append((name + '.weight', normal_init(module.weight))) # Initialize the weights of the linear layer
                if 'bias' in module: # Check if the module has a bias
                    new_params.append((name + '.bias', mx.zeros(module.bias.shape))) # Initialize the bias with zeros
            elif isinstance(module, nn.layers.embedding.Embedding): # Check if the module is an embedding layer
                new_params.append((name + '.weight', normal_init(module.weight))) # Initialize the weights of the embedding layer
        self.update(utils.tree_unflatten(new_params)) # Update the model in place with the new parameters

    # Define the forward pass of the model
    def __call__(self, x):
        B, T = x.shape # (B = batch_size, T = ctx_len). x is the input tensor
        # print(f"input tensor: \n {x} \n")
        # print(f"x.shape: \n {x.shape} \n")
        tok_emb = self.wte(x) # (B, T, n_emb) - Get the embeddings of the tokens
        # print(f"token embedding: \n {tok_emb} \n")
        pos_emb = self.wpe(mx.arange(T)) # (T, n_emb) - Get the embeddings of the positions.  arange(T) creates a tensor with values from 0 to T-1 because T is the length of the context and minus 1 because the index starts from 0.
        # how it works is that the first position will have the first embedding, the second position will have the second embedding, and so on.
        # print(f"position embedding: \n {pos_emb} \n")
        x = tok_emb + pos_emb # (B, T, n_emb) - Add the token and position embeddings
        # print(f"token + position embedding: \n {x} \n")
        x = self.blocks(x) # (B, T, n_emb) - Pass the embeddings through the transformer blocks
        x = self.ln_f(x) # (B, T, n_emb) - Apply the final layer norm
        logits = self.lm_head(x) # (B, T, vocab_size) - Get the logits for the next token prediction
        return logits

In [440]:
model = GPT()
mx.eval(model.parameters()) # Create the model params (mlx is lazy evaluation)
loss_and_grad = nn.value_and_grad(model, loss_fn)
lr = 0.1 # Note: this overrides the lr = 1e-3 set with the other hyperparameters above
optimizer = optim.AdamW(learning_rate=lr) # We use the AdamW optimizer

## Training Loop Implementation

### Training Phase

1. **Epoch Iteration**: 
   - The loop begins by iterating over a predefined number of epochs (`num_epochs`).
   - At the start of each epoch, the model is set to training mode using `model.train(True)`.

2. **Batch Processing**:
   - The training data is divided into batches using the `get_batches` function, which yields pairs of input data and corresponding labels.
   - For each batch:
     - The loss and gradients are computed by passing the input and label through the model using the `loss_and_grad` function.
     - The optimizer updates the model parameters based on the computed gradients.
     - The running loss, which accumulates the total loss over the batches, is updated by adding the loss of the current batch.

3. **Parameter and Optimizer State Evaluation**:
   - After processing each batch, the model parameters and optimizer state are evaluated using `mx.eval`. Because MLX evaluates lazily, this call forces the updated parameters and optimizer state to actually be computed before the next batch is processed.

4. **Average Training Loss**:
   - At the end of the training phase for the epoch, the average training loss is calculated by dividing the accumulated running loss by the number of batches (`batch_cnt`).

### Evaluation Phase

1. **Set Evaluation Mode**:
   - After the training phase, the model is set to evaluation mode using `model.train(False)`. This disables dropout, the only layer in this model that behaves differently during training.

2. **Validation Batch Processing**:
   - The validation data is also divided into batches using the `get_batches` function.
   - For each validation batch:
     - The loss is computed using the `loss_fn` function, which does not calculate gradients (since we are only evaluating the model).
     - The running loss accumulates the total loss over the validation batches.

3. **Average Validation Loss**:
   - At the end of the validation phase for the epoch, the average validation loss is calculated similarly to the training loss.

### Logging

- **Epoch Summary**:
  - At the end of each epoch, the average training and validation losses are printed out. This helps in monitoring the model's performance and identifying potential overfitting or underfitting.


This training loop effectively combines gradient descent optimization with batch processing and periodic evaluation to train a model over multiple epochs. The use of running averages and epoch-level summaries helps in assessing the model's learning progress.

In [441]:
for epoch in range(num_epochs):
    model.train(True)
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_train, y_train, batch_size):
        batch_cnt += 1
        loss, grads = loss_and_grad(model, input, label)
        optimizer.update(model, grads)
        running_loss += loss.item()
        # compute new parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)
    avg_train_loss = running_loss / batch_cnt
    model.train(False) # set eval mode
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_val, y_val, batch_size):
        batch_cnt += 1
        loss = loss_fn(model, input, label)
        running_loss += loss.item()
    avg_val_loss = running_loss / batch_cnt
    print(f"Epoch {epoch:2} | train = {avg_train_loss:.4f} | val = {avg_val_loss:.4f}")


In [None]:
completion = decode(model.generate(1000)[0].tolist())
print(completion)
with open('completions.txt', 'w') as f:
    f.write(completion)