## Preparing the data

Install mlx and run the following imports.

In [353]:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import mlx.utils as utils
import numpy as np
import math

The first step to training an LLM is collecting a large corpus of text data and then tokenizing it. Tokenization is the process of mapping text to integers, which can be fed into the LLM. Our training corpus for this model will be the works of Shakespeare concatenated into one file. This is roughly 1 million characters and looks like this:

In [354]:
with open('../input.txt', 'r') as f:
    text = f.read()

print(text[:200])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you


First, we read the file as a single long string into the text variable. Then we use the set() function to get all the unique characters in the text which will be our vocabulary. By printing vocab you can see all the characters in our vocabulary as one string, and we have a total of 65 characters which till be our tokens.

In [355]:
vocab = set(text)
vocab

{'\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z'}

We'll then wrap a list around it:

In [356]:
vocab = list(set(text))
vocab

['G',
 '\n',
 'B',
 'v',
 'L',
 "'",
 ';',
 '!',
 't',
 'a',
 ',',
 'Z',
 'C',
 ':',
 'W',
 'R',
 '$',
 'F',
 'x',
 'I',
 'i',
 ' ',
 'c',
 'J',
 'Y',
 'Q',
 'l',
 'p',
 '?',
 'O',
 '&',
 'X',
 'w',
 'g',
 'T',
 'q',
 'd',
 '-',
 'K',
 'y',
 '.',
 'D',
 'U',
 'P',
 'N',
 'e',
 'j',
 'z',
 'o',
 'k',
 'M',
 'b',
 'm',
 'E',
 'A',
 '3',
 'h',
 'r',
 'n',
 'S',
 'H',
 'V',
 'u',
 'f',
 's']

Finally, we'll sort it

In [357]:
vocab = sorted(list(set(text)))
vocab

['\n',
 ' ',
 '!',
 '$',
 '&',
 "'",
 ',',
 '-',
 '.',
 '3',
 ':',
 ';',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z']

In [358]:
vocab_size = len(vocab)
print(f"vocab_size: {vocab_size}")

print(''.join(vocab))

vocab_size: 65

 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [359]:
# Create mapping from vocab to integers
itos = {i:c for i,c in enumerate(vocab)} # int to string
stoi = {c:i for i,c in enumerate(vocab)} # string to int
encode = lambda x: [stoi[c] for c in x] # encode string to int
decode = lambda x: ''.join([itos[i] for i in x]) # decode int to string

print(encode("hello world"))
# [46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
print(decode(encode("hello world")))
# hello world

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [360]:
data = encode(text) # encode the entire text
split = int(0.9 * len(data)) # 90% train, 10% validation
train_data = data[:split] # first 90%
val_data = data[split:] # last 10%

In [361]:
ctx_len = 4
print(train_data[:ctx_len + 1])
# [18, 47, 56, 57, 58,  1, 15, 47, 58]
# x: [18, 47, 56, 57, 58,  1, 15, 47] | y: 58

# 8 sub examples
# [18] --> 47
# [18, 47] --> 56
# [18, 47, 56] --> 57
# [18, 47, 56, 57] --> 58
# [18, 47, 56, 57, 58] --> 1
# [18, 47, 56, 57, 58, 1] --> 15
# [18, 47, 56, 57, 58, 1, 15] --> 47
# [18, 47, 56, 57, 58, 1, 15, 47] --> 58

[18, 47, 56, 57, 58]


In [362]:
print("inputs: ", train_data[:ctx_len])
print("labels: ", train_data[1:ctx_len+1]) # labels = inputs indexed 1 higher
# inputs: [18, 47, 56, 57, 58,  1, 15, 47]
# labels: [47, 56, 57, 58,  1, 15, 47, 58]

inputs:  [18, 47, 56, 57]
labels:  [47, 56, 57, 58]


In [363]:
# Creating training and validation datasets
ctx_len = 4
X_train = mx.array([train_data[i:i+ctx_len] for i in range(0, len(train_data) - ctx_len, ctx_len)]) 
y_train = mx.array([train_data[i+1:i+ctx_len+1] for i in range(0, len(train_data) - ctx_len, ctx_len)]) 
X_val = mx.array([val_data[i:i+ctx_len] for i in range(0, len(val_data) - ctx_len, ctx_len)])
y_val = mx.array([val_data[i+1:i+ctx_len+1] for i in range(0, len(val_data) - ctx_len, ctx_len)])

In [364]:
def get_batches(X, y, b_size, shuffle=True):
    if shuffle:
        ix = np.arange(X.shape[0])
        np.random.shuffle(ix)
        ix = mx.array(ix)
        X = X[ix]
        y = y[ix]
    for i in range(0, X.shape[0], b_size):
        input = X[i:i+b_size]
        label = y[i:i+b_size]
        yield input, label

You'll notice the output below for the label array, 'y', is shifted to the left by 1 in order to predict the next token

In [365]:
batch = get_batches(X_train, y_train, 1)
for X, y in batch:
    print(X)
    print(y)
    break

array([[43, 57, 58, 1]], dtype=int32)
array([[57, 58, 1, 58]], dtype=int32)


In [366]:
ctx_len = 128
n_emb = 128
dropout = 0.1
head_size = 128
n_heads = 4 
n_layers = 3 
num_epochs = 20
batch_size = 64
lr = 1e-3

In [367]:
# class Attention(nn.Module):
#     def __init__(self, head_size):
#         super().__init__()
#         self.head_size = head_size # Define the head size of the attention mechanism 
#         self.k_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the key projection
#         self.q_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the query projection
#         self.v_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the value projection
#         indices = mx.arange(ctx_len) # Create a tensor with values from 0 to ctx_len - 1
#         print(f"indices: \n {indices} \n")
#         mask = indices[:, None] < indices[None] # If the value of the first tensor is less than the value of the second tensor, the value of the mask tensor is True, otherwise False which means that the mask tensor is a lower triangular matrix
#         print(f"mask: \n {mask} \n")
#         self._causal_mask = mask * -1e9 # Multiply the mask tensor by -1e9 to get a tensor with -1e9 where the value of the first tensor is less than the value of the second tensor
#         print(f"mask: \n {self._causal_mask} \n")
#         self.c_proj = nn.Linear(head_size, n_emb) # output projection layer to get the output of the attention mechanism
#         self.resid_dropout = nn.Dropout(dropout) # Define the dropout layer for the residual connection
#     def __call__(self, x): # shapes commented
#         B, T, C = x.shape # (batch_size, ctx_len, n_emb) - x is the input tensor
#         K = self.k_proj(x) # (B, T, head_size) - Project the keys
#         Q = self.q_proj(x) # (B, T, head_size) - Project the queries
#         V = self.v_proj(x) # (B, T, head_size) - Project the values
#         attn_weights = (Q @ K.transpose([0, 2, 1])) / math.sqrt(self.head_size) # We use K.transpose([0, 2, 1]) to transpose the second and third dimensions of K. This is because we want to multiply the queries with the keys. The shape of the attention weights is (B, T, T) 
#         # attn_weights.shape = (B, T, T)
#         attn_weights = attn_weights + self._causal_mask # Add the causal mask to the attention weights
#         attn_weights = mx.softmax(attn_weights, axis=-1) # Apply the softmax function to the attention weights to get the attention scores
#         o = (attn_weights @ V) # (B, T, head_size) - Multiply the attention scores with the values to get the output
#         o = self.c_proj(self.resid_dropout(o)) # (B, T, n_emb) - Apply the output projection layer to the output
#         return o # Return the output of the attention mechanism which will be used as the input to the feedforward neural network

In [368]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.head_size = head_size # Define the head size of the attention mechanism 
        self.k_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the key projection
        self.q_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the query projection
        self.v_proj = nn.Linear(n_emb, head_size, bias=False) # Linear layer for the value projection
        indices = mx.arange(ctx_len) # Create a tensor with values from 0 to ctx_len - 1
        # print(f"indices: \n {indices} \n")
        mask = indices[:, None] < indices[None] # If the value of the first tensor is less than the value of the second tensor, the value of the mask tensor is True, otherwise False which means that the mask tensor is a lower triangular matrix
        # print(f"mask: \n {mask} \n")
        self._causal_mask = mask * -1e9 # Multiply the mask tensor by -1e9 to get a tensor with -1e9 where the value of the first tensor is less than the value of the second tensor
        # print(f"mask: \n {self._causal_mask} \n")
        self.c_proj = nn.Linear(head_size, n_emb) # output projection layer to get the output of the attention mechanism
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout) # Define the dropout layer for the residual connection

    # Define the forward pass of the model
    def __call__(self, x): # shapes commented
        B, T, C = x.shape # (batch_size, ctx_len, n_emb) - x is the input tensor
        K = self.k_proj(x) # (B, T, head_size) - Project the keys
        Q = self.q_proj(x) # (B, T, head_size) - Project the queries
        V = self.v_proj(x) # (B, T, head_size) - Project the values
        mha_shape = (B, T, n_heads, head_size//n_heads) # This is the shape of the multi-head attention mechanism because we want to split the head_size into n_heads
        # print(f"mha_shape: \n {mha_shape} \n")
        K = mx.as_strided(K, (mha_shape)).transpose([0, 2, 1, 3]) # We use mx.as_strided to create a view of the K tensor with the shape (B, n_heads, T, head_size//n_heads) and then transpose the dimensions to get the desired shape
        Q = mx.as_strided(Q, (mha_shape)).transpose([0, 2, 1, 3]) # We use mx.as_strided to create a view of the Q tensor with the shape (B, n_heads, T, head_size//n_heads) and then transpose the dimensions to get the desired shape
        V = mx.as_strided(V, (mha_shape)).transpose([0, 2, 1, 3]) # We use mx.as_strided to create a view of the V tensor with the shape (B, n_heads, T, head_size//n_heads) and then transpose the dimensions to get the desired shape
        attn_weights = (Q @ K.transpose([0, 1, 3, 2])) / math.sqrt(Q.shape[-1]) # We use K.transpose([0, 1, 3, 2]) to transpose the second and third dimensions of K. This is because we want to multiply the queries with the keys. The shape of the attention weights is (B, n_heads, T, T)
        # print(f"attn_weights: \n {attn_weights} \n")
        attn_weights = attn_weights + self._causal_mask[:T, :T] # Add the causal mask to the attention weights
        # print(f"attn_weights + casual mask: \n {attn_weights} \n")
        attn_weights = mx.softmax(attn_weights, axis=-1) # Apply the softmax function to the attention weights to get the attention scores
        # print(f"softmax attn_weights: \n {attn_weights} \n")
        attn_weights = self.attn_dropout(attn_weights) # Apply the dropout layer to the attention weights
        # print(f"dropout attn_weights: \n {attn_weights} \n")
        o = (attn_weights @ V) # (B, n_heads, T, head_size//n_heads) - Multiply the attention scores with the values to get the output
        # print(f"output: \n {o} \n")
        o = o.transpose([0, 2, 1, 3]).reshape((B, T, head_size)) # We transpose the dimensions of the output and then reshape it to get the desired shape
        # print(f"output reshaped: \n {o} \n")
        o = self.c_proj(self.resid_dropout(o)) # (B, T, n_emb) - Apply the output projection layer to the output
        # print(f"output projection: \n {o} \n")
        return o # Return the output of the attention mechanism which will be used as the input to the feedforward neural network

In [369]:
class MLP(nn.Module):
    """
    Simple implementation of a feedforward neural network.

    Attributes:
        c_fc (nn.Linear): Linear layer for the fully connected operation.
        gelu (nn.GELU): GELU activation function.
        c_proj (nn.Linear): Linear layer for the projection operation.
        dropout (nn.Dropout): Dropout layer for regularization.

    Methods:
        __call__(x): Forward pass of the neural network.

    """
    def __init__(self):
        super().__init__()
        self.c_fc = nn.Linear(n_emb, 4 * n_emb)
        self.gelu = nn.GELU()
        self.c_proj = nn.Linear(4 * n_emb, n_emb)
        self.dropout = nn.Dropout(dropout)
    def __call__(self, x):
        x = self.gelu(self.c_fc(x))
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [370]:
class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.mlp = MLP() # Feedforward neural network
        self.mha = MultiHeadAttention() # Multi-head attention mechanism
        print(self.mha)
        self.ln_1 = nn.LayerNorm(dims=n_emb) # Layer normalization layer
        self.ln_2 = nn.LayerNorm(dims=n_emb) # Layer normalization layer
    def __call__(self, x): 
        x = x + self.mha(self.ln_1(x)) # Add the output of the multi-head attention mechanism to the input tensor
        x = x + self.mlp(self.ln_2(x)) # Add the output of the feedforward neural network to the input tensor 
        return x # Return the output of the block

In [371]:
def loss_fn(model, x, y):
    logits = model(x)
    B, T, C = logits.shape # (batch_size, seq_len, vocab_size)
    logits = logits.reshape(B*T, C)
    y = y.reshape(B*T)
    loss = nn.losses.cross_entropy(logits, y, reduction='mean')
    return loss

In [372]:

class GPT(nn.Module): # Define the GPT model
    def __init__(self):
        super().__init__() # Call the __init__ of the parent class
        self.wte = nn.Embedding(vocab_size, n_emb) # Lookup table for embeddings of each token in the vocab (word to embedding) -  n_emb means the size of the embedding vector
        self.wpe = nn.Embedding(ctx_len, n_emb) # Lookup table for embeddings of each position in the context (position to embedding)
        self.blocks = nn.Sequential(
            *[Block() for _ in range(n_layers)],
        ) # transformer blocks - n_layers means the number of transformer blocks
        self.ln_f = nn.LayerNorm(dims=n_emb) # final layernorm
        # print(f"layernorm: \n {self.ln_f} \n")
        self.lm_head = nn.Linear(n_emb, vocab_size) # output projection
        # print(f"lm_head: \n {self.lm_head} \n")
        self._init_parameters() # Initialize the parameters of the model
        # print total number of params on initialization
        total_params = sum([p.size for n,p in utils.tree_flatten(self.parameters())]) # Get the total number of parameters
        # print(f"Total params: {(total_params / 1e6):.3f}M") # Print the total number of parameters in millions

    # method of GPT class
    def generate(self, max_new_tokens):
        ctx = mx.zeros((1, 1), dtype=mx.int32) # (1, 1) - Create a context tensor with zeros
        for _ in range(max_new_tokens): # Loop through the number of tokens to generate
            logits = self(ctx[:, -ctx_len:]) # pass in last ctx_len characters to get the next token
            logits = logits[:, -1, :] # get logits for the next token only
            next_tok = mx.random.categorical(logits, num_samples=1) # sample the next token
            ctx = mx.concatenate((ctx, next_tok), axis=1) # append the next token to the context
        return ctx # return the context
    
    # method of GPT
    def _init_parameters(self):
        normal_init = nn.init.normal(mean=0.0, std=0.02) # Initialize the weights with a normal distribution
        residual_init = nn.init.normal(mean=0.0, std=(0.02 / math.sqrt(2 * n_layers))) # Initialize the residuals with a normal distribution
        new_params = [] # Create a list to store the new parameters
        # print(f"named_modules: \n {self.named_modules()} \n")
        for name, module in self.named_modules(): # Loop through the modules of the model
            if isinstance(module, nn.layers.linear.Linear): # Check if the module is a linear layer
                if 'c_proj' in name: # residual projection layer
                    new_params.append((name + '.weight', residual_init(module.weight))) # Initialize the weights of the residual projection layer
                else:
                    new_params.append((name + '.weight', normal_init(module.weight))) # Initialize the weights of the linear layer
                if 'bias' in module: # Check if the module has a bias
                    new_params.append((name + '.bias', mx.zeros(module.bias.shape))) # Initialize the bias with zeros
            elif isinstance(module, nn.layers.embedding.Embedding): # Check if the module is an embedding layer
                new_params.append((name + '.weight', normal_init(module.weight))) # Initialize the weights of the embedding layer
        self = self.update(utils.tree_unflatten(new_params)) # Update the model with the new parameters

    # Define the forward pass of the model
    def __call__(self, x):
        B, T = x.shape # (B = batch_size, T = ctx_len). x is the input tensor
        # print(f"input tensor: \n {x} \n")
        # print(f"x.shape: \n {x.shape} \n")
        tok_emb = self.wte(x) # (B, T, n_emb) - Get the embeddings of the tokens
        # print(f"token embedding: \n {tok_emb} \n")
        pos_emb = self.wpe(mx.arange(T)) # (T, n_emb) - Get the embeddings of the positions.  arange(T) creates a tensor with values from 0 to T-1 because T is the length of the context and minus 1 because the index starts from 0.
        # how it works is that the first position will have the first embedding, the second position will have the second embedding, and so on.
        # print(f"position embedding: \n {pos_emb} \n")
        x = tok_emb + pos_emb # (B, T, n_emb) - Add the token and position embeddings
        # print(f"token + position embedding: \n {x} \n")
        x = self.blocks(x) # (B, T, n_emb) - Pass the embeddings through the transformer blocks
        x = self.ln_f(x) # (B, T, b_emb) - Apply the final layer norm
        logits = self.lm_head(x) # (B, T, vocab_size) - Get the logits for the next token prediction
        return logits

In [373]:
model = GPT()
mx.eval(model.parameters()) # Create the model params (mlx is lazy evaluation)
loss_and_grad = nn.value_and_grad(model, loss_fn)
lr = 0.1
optimizer = optim.AdamW(learning_rate=lr)

MultiHeadAttention(
  (k_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (q_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (v_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (c_proj): Linear(input_dims=128, output_dims=128, bias=True)
  (attn_dropout): Dropout(p=0.09999999999999998)
  (resid_dropout): Dropout(p=0.09999999999999998)
)
MultiHeadAttention(
  (k_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (q_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (v_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (c_proj): Linear(input_dims=128, output_dims=128, bias=True)
  (attn_dropout): Dropout(p=0.09999999999999998)
  (resid_dropout): Dropout(p=0.09999999999999998)
)
MultiHeadAttention(
  (k_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (q_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (v_proj): Linear(input_dims=128, output_dims=128, bias=False)
  (c_proj): Linear(input_dims=128, o

In [374]:
num_epochs=1
batch_size=1
for epoch in range(num_epochs):
    model.train(True)
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_train, y_train, batch_size):
        batch_cnt += 1
        loss, grads = loss_and_grad(model, input, label)
        optimizer.update(model, grads)
        running_loss += loss.item()
        # compute new parameters and optimizer state
        mx.eval(model.parameters(), optimizer.state)
    avg_train_loss = running_loss / batch_cnt
    model.train(False) # set eval mode
    running_loss = 0
    batch_cnt = 0
    for input, label in get_batches(X_val, y_val, batch_size):
        batch_cnt += 1
        loss = loss_fn(model, input, label)
        running_loss += loss.item()
    avg_val_loss = running_loss / batch_cnt
    print(f"Epoch {epoch:2} | loss = {loss.item():.f4}|train = {avg_train_loss:.4f} | val = {avg_val_loss:.4f}")

KeyboardInterrupt: 

In [None]:
completion = decode(model.generate(1000)[0].tolist())
print(completion)
with open('completions.txt', 'w') as f:
    f.write(completion)

In [None]:
# model = GPT()
model = GPT()

print(f"token embedding shape: {model.wte.weight.shape}")
# vocab_size x n_emb (65, 6)

print(f"positional embedding shape: {model.wpe.weight.shape}\n")
# ctx_len x n_emb (8, 6)

model(X)
# You'll see the token embeddings, positional embeddings, and the sum of the two embeddings for the first batch of the training data.

indices: 
 array([0, 1, 2, 3], dtype=int32) 

mask: 
 array([[False, True, True, True],
       [False, False, True, True],
       [False, False, False, True],
       [False, False, False, False]], dtype=bool) 

mask: 
 array([[-0, -1e+09, -1e+09, -1e+09],
       [-0, -0, -1e+09, -1e+09],
       [-0, -0, -0, -1e+09],
       [-0, -0, -0, -0]], dtype=float32) 

indices: 
 array([0, 1, 2, 3], dtype=int32) 

mask: 
 array([[False, True, True, True],
       [False, False, True, True],
       [False, False, False, True],
       [False, False, False, False]], dtype=bool) 

mask: 
 array([[-0, -1e+09, -1e+09, -1e+09],
       [-0, -0, -1e+09, -1e+09],
       [-0, -0, -0, -1e+09],
       [-0, -0, -0, -0]], dtype=float32) 

indices: 
 array([0, 1, 2, 3], dtype=int32) 

mask: 
 array([[False, True, True, True],
       [False, False, True, True],
       [False, False, False, True],
       [False, False, False, False]], dtype=bool) 

mask: 
 array([[-0, -1e+09, -1e+09, -1e+09],
       [-0, -0, -1e+09

AttributeError: 'MultiHeadAttention' object has no attribute 'attn_dropout'