# Building LLaMA3.1 from scratch
Following https://levelup.gitconnected.com/building-llama-3-from-scratch-with-python-e0cf4dbbc306

## Part 1 - Get the Llama 3 model files from Hugging Face and explore them


### Download the model files
We only need to run these cells once, to download the files

### Import the libraries we need for tiktoken, PyTorch and json handling

In [139]:
# Import libraries

# Tokenisation library
import tiktoken

# Byte Pair Encoding function
from tiktoken.load import load_tiktoken_bpe

# PyTorch library
import torch

# JSON 
import json

### Load and inspect the tokenizer model (tokenizer.model)

In [140]:
# Load the tokenizer model
model_directory = "llama-3.1-8B/original/"
tokenizer_model = load_tiktoken_bpe(model_directory + "tokenizer.model") 

In [141]:
# How many tokens are in the tokenizer model?  128000
len(tokenizer_model)

128000

In [142]:
# What type of object is it? dict
type(tokenizer_model)

dict

In [143]:
# Print 10 items of the model
dict(list(tokenizer_model.items())[5000:5010])

{b' Web': 5000,
 b'Des': 5001,
 b'BC': 5002,
 b'ancial': 5003,
 b'Route': 5004,
 b'Dec': 5005,
 b'ferences': 5006,
 b' purch': 5007,
 b' Model': 5008,
 b'ctor': 5009}

### Load and inspect the model's learned weights (consolidated.00.pth)

In [144]:
# Load the Llama-3.1-8B checkpoint.  What comes back is a mapping from
# parameter-name strings to tensors (see the inspection cells that follow).

#   map_location=torch.device('cpu'): load every tensor onto the CPU rather than a CUDA compatible GPU
#   weights_only=True: restrict unpickling to tensors and other plain data, so the checkpoint
#       file cannot execute arbitrary code while loading (this also prevents torch's security warning)
model = torch.load(model_directory + "consolidated.00.pth",  map_location=torch.device('cpu'), weights_only=True)

In [145]:
print("len(model)=",len(model))

len(model)= 291


In [146]:
print("model keys are of type ",type(next(iter(model.keys()), None)))
print("model values are of type ",type(next(iter(model.values()), None)))

model keys are of type  <class 'str'>
model values are of type  <class 'torch.Tensor'>


In [147]:
# Let's look at some of the keys of the model to see the architecture of the LLM
list(model.keys())[:20]

['tok_embeddings.weight',
 'layers.0.attention.wq.weight',
 'layers.0.attention.wk.weight',
 'layers.0.attention.wv.weight',
 'layers.0.attention.wo.weight',
 'layers.0.feed_forward.w1.weight',
 'layers.0.feed_forward.w3.weight',
 'layers.0.feed_forward.w2.weight',
 'layers.0.attention_norm.weight',
 'layers.0.ffn_norm.weight',
 'layers.1.attention.wq.weight',
 'layers.1.attention.wk.weight',
 'layers.1.attention.wv.weight',
 'layers.1.attention.wo.weight',
 'layers.1.feed_forward.w1.weight',
 'layers.1.feed_forward.w3.weight',
 'layers.1.feed_forward.w2.weight',
 'layers.1.attention_norm.weight',
 'layers.1.ffn_norm.weight',
 'layers.2.attention.wq.weight']

In [148]:
list(model.keys())[280:]

['layers.31.attention.wq.weight',
 'layers.31.attention.wk.weight',
 'layers.31.attention.wv.weight',
 'layers.31.attention.wo.weight',
 'layers.31.feed_forward.w1.weight',
 'layers.31.feed_forward.w3.weight',
 'layers.31.feed_forward.w2.weight',
 'layers.31.attention_norm.weight',
 'layers.31.ffn_norm.weight',
 'norm.weight',
 'output.weight']

In [149]:
# Let's look at the first two entries (parameter name -> tensor)
dict(list(model.items())[0:2])

{'tok_embeddings.weight': tensor([[ 1.2436e-03,  5.6763e-03, -3.2501e-03,  ...,  4.0588e-03,
          -2.6245e-03, -6.8283e-04],
         [-3.4332e-03,  1.3351e-03, -1.6556e-03,  ...,  9.4604e-04,
          -1.8616e-03, -2.1515e-03],
         [ 6.7520e-04, -1.6113e-02,  2.5635e-03,  ...,  4.3030e-03,
           7.8735e-03,  4.8523e-03],
         ...,
         [ 2.2230e-23,  3.9291e-24,  2.1713e-23,  ...,  6.4106e-23,
          -2.6625e-24, -2.3678e-23],
         [ 2.2954e-23, -2.2230e-24, -2.2334e-23,  ...,  2.8124e-23,
           8.7371e-24, -3.7223e-23],
         [-8.8922e-23, -7.6101e-23,  6.5140e-24,  ...,  5.9195e-24,
          -6.4934e-23, -2.7271e-24]], dtype=torch.bfloat16),
 'layers.0.attention.wq.weight': tensor([[ 0.0053, -0.0291, -0.0058,  ...,  0.0095, -0.0420, -0.0272],
         [ 0.0284,  0.0008, -0.0093,  ..., -0.0092, -0.0078,  0.0048],
         [-0.0142, -0.0679, -0.0049,  ..., -0.0142, -0.0498,  0.0192],
         ...,
         [-0.0035, -0.0101,  0.0459,  ...,  0.00

### Load and inspect model hyperparameters (params.json)

In [150]:
# Read the model hyperparameters.  The encoding is stated explicitly so the file
# is decoded as UTF-8 regardless of the platform's default text encoding.
with open(model_directory + "params.json", "r", encoding="utf-8") as f:
    config = json.load(f)

print(config)

{'dim': 4096, 'ffn_dim_multiplier': 1.3, 'multiple_of': 1024, 'n_heads': 32, 'n_kv_heads': 8, 'n_layers': 32, 'norm_eps': 1e-05, 'rope_theta': 500000.0, 'use_scaled_rope': True, 'vocab_size': 128256}


In [151]:
type(config)

dict

## Part 2 - Build our Llama-3 model

### Store model hyperparameters for use later in building our LLM

In [152]:
# Width of each token's embedding vector, and the number of transformer layers
dim, n_layers = config["dim"], config["n_layers"]

# Number of query heads per layer, and number of key/value heads per layer
n_heads, n_kv_heads = config["n_heads"], config["n_kv_heads"]

# Vocabulary size: the number of rows in the token-embedding table
vocab_size = config["vocab_size"]

# Feed-forward sizing hyperparameters, kept as given in params.json
multiple_of, ffn_dim_multiplier = config["multiple_of"], config["ffn_dim_multiplier"]

# Small epsilon added inside RMSNorm before taking the reciprocal square root
norm_eps = config["norm_eps"]

# RoPE base (theta), wrapped in a tensor so it can take part in tensor arithmetic
rope_theta = torch.tensor(config["rope_theta"])

### Create a tokenizer to tokenize our input data

In [153]:
# Special tokens in id order: ten named/reserved entries first, then the
# remaining reserved placeholders, for 256 entries in total.
special_tokens = [
    "<|begin_of_text|>",  # Marks the beginning of a text sequence.
    "<|end_of_text|>",  # Marks the end of a text sequence.
    "<|reserved_special_token_0|>",  # Reserved for future use.
    "<|reserved_special_token_1|>",  # Reserved for future use.
    "<|reserved_special_token_2|>",  # Reserved for future use.
    "<|reserved_special_token_3|>",  # Reserved for future use.
    "<|start_header_id|>",  # Indicates the start of a header ID.
    "<|end_header_id|>",  # Indicates the end of a header ID.
    "<|reserved_special_token_4|>",  # Reserved for future use.
    "<|eot_id|>",  # Marks the end of a turn (in a conversational context).
    # <|reserved_special_token_5|> ... <|reserved_special_token_250|>
    *(f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)),
]

In [154]:
# Set a regex for breaking input text into chunks before tokenization
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

In [155]:
# Here's how the tokenize_breaker regex works on an example sentence

# The pattern uses Unicode property escapes like \p{L}, which the built-in 're' module
# does not support, so we use the third-party 'regex' module instead (imported under the
# name 're').  A separate `import re` beforehand would only be shadowed, so it is omitted.
import regex as re

# Example sentence
sentence = "I'm happy! It's 100% true."

# Split the sentence into chunks using the regex pattern
tokens = re.findall(tokenize_breaker, sentence)

# Print the chunks
print(tokens)

['I', "'m", ' happy', '!', ' It', "'s", ' ', '100', '%', ' true', '.']


In [156]:
# Build the tiktoken encoding from the pieces prepared above
tokenizer = tiktoken.Encoding(
    # Name of the encoding
    name="tokenizer.model",
    # Regex pattern string used to split the input text
    pat_str=tokenize_breaker,
    # BPE mergeable ranks loaded from the LLaMA-3 tokenizer_model
    mergeable_ranks=tokenizer_model,
    # Special tokens are numbered consecutively, starting right after the last BPE rank
    special_tokens={
        token: token_id
        for token_id, token in enumerate(special_tokens, start=len(tokenizer_model))
    },
)

# Round-trip check: encode a sentence, then decode the token ids back to a string
tokenizer.decode(tokenizer.encode("hello world, I'm 10% pleased to see you!"))

"hello world, I'm 10% pleased to see you!"

In [157]:
tokenizer.encode("hello world, I'm 10% pleased to see you!")


[15339, 1917, 11, 358, 2846, 220, 605, 4, 18949, 311, 1518, 499, 0]

In [158]:
# Let's have a look at those special_tokens.  
# These are additional tokens we're adding to the "end" of tokenizer_model: each one is given
# the id len(tokenizer_model) + its position in the list.
# NOTE: this rebinds special_tokens from the list of token strings to a dict of {token: id}.
special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)}

# Show first 5 entries
list(special_tokens.items())[0: 5]

[('<|begin_of_text|>', 128000),
 ('<|end_of_text|>', 128001),
 ('<|reserved_special_token_0|>', 128002),
 ('<|reserved_special_token_1|>', 128003),
 ('<|reserved_special_token_2|>', 128004)]

In [159]:
# Input prompt.  An alternative prompt to try instead:
#   "the answer to the ultimate question of life, the universe, and everything is "
prompt = "If I could "

# Encode the prompt using the tokenizer and prepend the <|begin_of_text|> special token (128000)
tokens = [128000] + tokenizer.encode(prompt)

print(tokens)  # Print the encoded tokens

# Convert the list of tokens into a PyTorch tensor
tokens = torch.tensor(tokens)

# Decode each token id on its own to see how the prompt was split
prompt_split_as_tokens = [tokenizer.decode([token_id]) for token_id in tokens.tolist()]

print(prompt_split_as_tokens)  # Print the decoded tokens

[128000, 53770, 4848, 8254, 8223, 11, 889, 656, 584]
['<|begin_of_text|>', 'five', ' six', ' seven', ' eight', ',', ' who', ' do', ' we']


In [160]:
# Check length of our input vector
input_seq_length = len(tokens)
input_seq_length

9

In [161]:
# What are the dimensions of the embedding vector for llama-3.1?
print(dim)

4096


In [162]:
tokens.shape

torch.Size([9])

### Transform input sequence to embeddings for each word in the sequence
We need to transform our input vector (a sequence of nine token ids in this run) from 
its current shape of (9,) to (9 x 4096), so that each of the nine tokens in our input 
vector is represented by its 4096-dimensional embedding

In [163]:
# Define an embedding table with one row per vocabulary entry and `dim` columns
embedding_layer = torch.nn.Embedding(vocab_size, dim)

# Overwrite the table's weights in place with the pre-trained token embeddings
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])

# Look up the embedding row for each token id, then cast the result to torch.bfloat16
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

# Show the shape of the resulting token embeddings: one row per token, `dim` columns
token_embeddings_unnormalized.shape


torch.Size([9, 4096])

### Normalise our input vector using RMSNorm


In [164]:
# Calculating RMSNorm
def rms_norm(tensor, norm_weights, eps=None):
    """Apply RMS normalisation over the last dimension, then scale by learned weights.

    Every vector along the last dimension is multiplied by the reciprocal of its
    root-mean-square, and the result is multiplied element-wise by ``norm_weights``.

    Args:
        tensor: Input tensor; normalisation runs over its last dimension.
        norm_weights: Scale weights, broadcast against ``tensor``.
        eps: Small constant added to the mean of squares so the reciprocal
            square root stays finite for an all-zero vector.  When omitted,
            the notebook-level ``norm_eps`` (read from params.json) is used.

    Returns:
        A tensor of the same shape as ``tensor``.
    """
    if eps is None:
        # Fall back to the hyperparameter loaded from the model's params.json
        eps = norm_eps

    # Mean of the squared values along the last dimension (kept as a size-1 axis
    # so it broadcasts back over the vector)
    mean_of_squares = tensor.pow(2).mean(-1, keepdim=True)

    # rsqrt gives 1 / sqrt(x), i.e. the reciprocal of the root-mean-square
    inverse_rms = torch.rsqrt(mean_of_squares + eps)

    # Normalise, then apply the provided normalisation weights
    return (tensor * inverse_rms) * norm_weights

Use the attention-norm (RMSNorm) weights from the first layer of our transformer architecture to normalise our embeddings

In [165]:
# using RMS normalization and provided normalization weights
token_embeddings = rms_norm(token_embeddings_unnormalized, 
                            model["layers.0.attention_norm.weight"])

# Print the shape of the resulting token embeddings
token_embeddings.shape

torch.Size([9, 4096])

## Part 3: Start building our attention heads.  Build the first head for the first layer

In [166]:
# Print the shapes of different weights
print(
    # Query weight shape
    model["layers.0.attention.wq.weight"].shape,
    
    # Key weight shape
    model["layers.0.attention.wk.weight"].shape,
    
    # Value weight shape
    model["layers.0.attention.wv.weight"].shape,
    
    # Output weight shape
    model["layers.0.attention.wo.weight"].shape
)

torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])


Now remember that each layer (in this case layer zero) has a number of attention heads.  These weights combine the weights for all the heads in this layer.  Let's reshape so that each attention head's weights are separate...

In [167]:
# Retrieve query weight for the first layer of attention
q_layer0 = model["layers.0.attention.wq.weight"]

# Calculate dimension per head
head_dim = q_layer0.shape[0] // n_heads
print("head_dim = ", head_dim)

# Reshape query weight to separate heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)

# Print the shape of the reshaped query weight tensor
q_layer0.shape

head_dim =  128


torch.Size([32, 128, 4096])

In [168]:
# Extract the query weight for the first head of the first layer of attention
q_layer0_head0 = q_layer0[0]

# Print the shape of the extracted query weight tensor for the first head
q_layer0_head0.shape

torch.Size([128, 4096])

Multiply the query weights with the token embedding for each token, to get our query vector

In [169]:
# Matrix multiplication: token embeddings with transpose of query weight for first head
print("token_embeddings.shape = ", token_embeddings.shape)
print("q_layer0_head0.T.shape = ", q_layer0_head0.T.shape)

q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)

# Shape of resulting tensor: queries per token
print("q_per_token.shape = ",q_per_token.shape)

token_embeddings.shape =  torch.Size([9, 4096])
q_layer0_head0.T.shape =  torch.Size([4096, 128])
q_per_token.shape =  torch.Size([9, 128])


### Implement positional encoding with RoPE
We now need to implement positional encoding so that the query vectors encode something about their position in the sequence.  
Split the query vectors into pairs and apply rotational angle shift to each pair:

In [170]:
# Convert queries per token to float and split into pairs
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
q_per_token_split_into_pairs.shape

torch.Size([9, 64, 2])

In [171]:
q_per_token

tensor([[-0.0840, -0.1611,  0.2910,  ...,  0.9219,  0.5000,  0.3457],
        [ 0.4473, -0.2129,  0.8164,  ...,  1.3047,  1.5234,  1.4766],
        [ 0.9375, -0.0076,  2.1094,  ...,  1.1797,  1.7031,  1.4219],
        ...,
        [ 0.8867,  0.1729,  2.5781,  ...,  2.8594,  1.6172,  1.1094],
        [ 0.9023,  0.2383,  2.6406,  ...,  2.5000,  1.2969,  0.8047],
        [ 0.7461,  0.1377,  2.1875,  ...,  2.2344,  1.5391,  1.1016]],
       dtype=torch.bfloat16, grad_fn=<MmBackward0>)

In [172]:
q_per_token_split_into_pairs

tensor([[[-0.0840, -0.1611],
         [ 0.2910, -0.2910],
         [ 0.3184, -0.4141],
         ...,
         [ 0.4980, -0.1328],
         [ 0.4336,  0.9219],
         [ 0.5000,  0.3457]],

        [[ 0.4473, -0.2129],
         [ 0.8164, -0.7383],
         [ 0.6016, -1.0859],
         ...,
         [ 1.0312,  0.0830],
         [ 0.5898,  1.3047],
         [ 1.5234,  1.4766]],

        [[ 0.9375, -0.0076],
         [ 2.1094, -1.1484],
         [ 1.3750, -1.8516],
         ...,
         [ 1.2344, -0.4648],
         [ 0.1094,  1.1797],
         [ 1.7031,  1.4219]],

        ...,

        [[ 0.8867,  0.1729],
         [ 2.5781, -1.1719],
         [ 1.7969, -2.5156],
         ...,
         [ 1.4609, -0.5469],
         [ 0.4844,  2.8594],
         [ 1.6172,  1.1094]],

        [[ 0.9023,  0.2383],
         [ 2.6406, -1.1641],
         [ 1.5703, -2.2031],
         ...,
         [ 1.3594, -0.0270],
         [ 0.8516,  2.5000],
         [ 1.2969,  0.8047]],

        [[ 0.7461,  0.1377],
       

In [173]:
# Each query/key vector is rotated as head_dim // 2 two-element pairs, so we need one
# exponent per pair: evenly spaced values 0, 1/64, 2/64, ... (64 values for head_dim = 128)
zero_to_one_split_into_64_parts = torch.arange(head_dim // 2) / (head_dim // 2)

# Print the resulting tensor
zero_to_one_split_into_64_parts

tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])

In [174]:
rope_theta

tensor(500000.)

In [175]:
# One rotation frequency per pair: rope_theta raised to minus the exponent computed above,
# so the first pair gets frequency 1.0 and later pairs get progressively smaller ones
# NOTE(review): params.json reports 'use_scaled_rope': True, but no extra scaling is applied
# to these frequencies here -- confirm against the Llama 3.1 reference implementation whether
# a frequency-scaling step is expected before they are used.
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)

# Display the resulting frequencies
freqs

tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])

In [176]:
1/(500000**0.0156)

0.8148845202809214

In [177]:
# Convert queries per token to complex numbers (each [real, imag] pair becomes one complex value)
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)

print("q_per_token_as_complex_numbers.shape = ", q_per_token_as_complex_numbers.shape)

# Calculate the rotation angle for every (token position, pair) combination: the outer
# product of the positions 0 .. input_seq_length - 1 with freqs
freqs_for_each_token = torch.outer(torch.arange(input_seq_length), freqs)
print("freqs_for_each_token.shape = ", freqs_for_each_token.shape)
print("freqs_for_each_token = ", freqs_for_each_token)

# Build unit-magnitude complex numbers from those angles using polar coordinates
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

# Rotate the complex queries by multiplying with the unit complex numbers
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis

print("q_per_token_as_complex_numbers_rotated.shape = ", q_per_token_as_complex_numbers_rotated.shape)
# Expected shape: [input_seq_length, 64]

q_per_token_as_complex_numbers.shape =  torch.Size([9, 64])
freqs_for_each_token.shape =  torch.Size([9, 64])
freqs_for_each_token =  tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
         0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]

In [178]:
# Convert rotated complex numbers back to real numbers
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)

# Print the shape of the resulting tensor
q_per_token_split_into_pairs_rotated.shape

torch.Size([9, 64, 2])

In [179]:
# Reshape rotated token queries to match the original shape
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

# Print the shape of the resulting tensor
q_per_token_rotated.shape

torch.Size([9, 128])

In [180]:
q_per_token_split_into_pairs_rotated

tensor([[[-0.0840, -0.1611],
         [ 0.2910, -0.2910],
         [ 0.3184, -0.4141],
         ...,
         [ 0.4980, -0.1328],
         [ 0.4336,  0.9219],
         [ 0.5000,  0.3457]],

        [[ 0.4208,  0.2613],
         [ 1.0972,  0.0873],
         [ 1.1428, -0.4849],
         ...,
         [ 1.0312,  0.0830],
         [ 0.5898,  1.3047],
         [ 1.5234,  1.4766]],

        [[-0.3833,  0.8556],
         [ 1.0233,  2.1728],
         [ 2.1285,  0.8878],
         ...,
         [ 1.2344, -0.4648],
         [ 0.1094,  1.1797],
         [ 1.7031,  1.4219]],

        ...,

        [[ 0.8997, -0.0818],
         [-0.7042, -2.7430],
         [-3.0726,  0.3410],
         ...,
         [ 1.4609, -0.5468],
         [ 0.4843,  2.8594],
         [ 1.6172,  1.1094]],

        [[ 0.5237,  0.7725],
         [ 1.5688, -2.4222],
         [-2.3036, -1.4189],
         ...,
         [ 1.3594, -0.0269],
         [ 0.8515,  2.5000],
         [ 1.2969,  0.8047]],

        [[-0.2448,  0.7181],
       

In [181]:
q_per_token_rotated

tensor([[-0.0840, -0.1611,  0.2910,  ...,  0.9219,  0.5000,  0.3457],
        [ 0.4208,  0.2613,  1.0972,  ...,  1.3047,  1.5234,  1.4766],
        [-0.3833,  0.8556,  1.0233,  ...,  1.1797,  1.7031,  1.4219],
        ...,
        [ 0.8997, -0.0818, -0.7042,  ...,  2.8594,  1.6172,  1.1094],
        [ 0.5237,  0.7725,  1.5688,  ...,  2.5000,  1.2969,  0.8047],
        [-0.2448,  0.7181,  2.3262,  ...,  2.2344,  1.5390,  1.1016]],
       grad_fn=<ViewBackward0>)

### Now do the same for Keys

In [182]:
# Extract the weight tensor for the attention mechanism's key in the first layer of the model
k_layer0 = model["layers.0.attention.wk.weight"]

# Reshape the key weights to separate the key/value heads -> [n_kv_heads, 128, dim] = [8, 128, 4096]
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)

# Key weights for the first key/value head of the first layer -> [128, 4096]
k_layer0_head0 = k_layer0[0]

# Key vector per token by matrix multiplication -> [seq_len, 128] (seq_len is 9 in this run)
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)

# Convert to float and split each key vector into pairs -> [seq_len, 64, 2]
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)

# View each pair as one complex number -> [seq_len, 64]
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)

# Rotate by the same per-position angles used for the queries, then go back to
# real pairs -> [seq_len, 64, 2]
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)

# Flatten the pairs so the rotated keys match the original shape -> [seq_len, 128]
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

# Show the shape of the rotated keys (only a cell's final expression is displayed, so
# the intermediate shapes are recorded in the comments above instead)
k_per_token_rotated.shape

torch.Size([9, 128])

### Implement self attention

In [183]:
print(head_dim)
print(q_per_token_rotated.shape, k_per_token_rotated.T.shape)


128
torch.Size([9, 128]) torch.Size([128, 9])


In [184]:
# Calculate query-key dot products per token
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5

# Print the shape of the resulting tensor representing query-key dot products per token
qk_per_token.shape

torch.Size([9, 9])

In [185]:
# Result is a seq_len x seq_len matrix (9 x 9 for this prompt) with the scaled query-key score of
# every token in the input sequence with every other token (including itself)
qk_per_token

tensor([[3.7595, 1.5931, 1.2744, 1.2071, 1.4154, 2.1142, 1.8662, 1.5409, 1.4746],
        [9.5143, 6.2809, 4.6963, 3.5233, 3.4170, 3.8271, 3.9491, 4.2899, 3.9911],
        [9.9251, 8.2805, 6.6103, 3.8961, 1.9111, 1.9750, 1.3963, 3.1151, 4.2871],
        [9.5106, 9.0878, 8.4867, 6.4063, 3.8344, 2.7558, 0.3039, 1.2163, 3.0849],
        [8.9338, 8.8067, 9.1704, 8.2739, 6.4438, 4.5944, 1.0971, 0.1569, 1.6121],
        [8.1783, 7.2689, 7.3711, 7.2926, 6.7289, 7.2121, 4.6933, 2.5593, 2.3273],
        [9.1369, 8.5826, 8.6075, 8.8851, 8.7706, 8.6583, 6.7642, 2.7599, 0.9424],
        [8.1441, 7.1941, 7.0830, 7.3355, 7.6451, 8.8202, 7.7363, 5.4743, 2.5069],
        [8.4862, 7.1415, 7.0239, 7.1970, 7.4780, 8.5692, 8.6379, 6.8205, 6.5807]],
       grad_fn=<DivBackward0>)

In [186]:
# Create a mask tensor filled with negative infinity values
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)

# Keep upper triangular part of the mask tensor as negative infinity, set the rest to zeros
mask = torch.triu(mask, diagonal=1)

# Print the resulting mask tensor
mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0.]])

In [187]:
# Add the mask to the query-key dot products per token
qk_per_token_after_masking = qk_per_token + mask

# Apply softmax along the second dimension after masking
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

qk_per_token_after_masking_after_softmax

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.9609, 0.0378, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.8125, 0.1572, 0.0295, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4863, 0.3184, 0.1748, 0.0217, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2676, 0.2354, 0.3379, 0.1377, 0.0221, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3477, 0.1396, 0.1553, 0.1436, 0.0815, 0.1318, 0.0000, 0.0000, 0.0000],
        [0.2305, 0.1318, 0.1357, 0.1787, 0.1592, 0.1426, 0.0215, 0.0000, 0.0000],
        [0.1826, 0.0703, 0.0630, 0.0811, 0.1108, 0.3594, 0.1211, 0.0126, 0.0000],
        [0.2119, 0.0552, 0.0491, 0.0583, 0.0771, 0.2305, 0.2461, 0.0400, 0.0315]],
       dtype=torch.bfloat16, grad_fn=<ToCopyBackward0>)

### Apply the attention scores to the value matrix
* So far we've produced an attention score matrix that captures only the relationships between tokens (how much each token should focus on the others) but does not carry the original token embedding or positional information forward.  
* The next step applies these attention scores to the values (which still retain both the word embeddings and positional encodings), re-incorporating these elements back into the representation, resulting in a contextually aware embedding for each token.

For the value matrix, which marks the end of the self-attention part, similar to keys, value weights are also shared across every 4 attention heads to save computation. As a result, the shape of the value weight matrix is [8x128x4096].

In [188]:

# Retrieve the value weight for the first layer of attention
v_layer0 = model["layers.0.attention.wv.weight"]
print(v_layer0.shape)

print("v_layer0.shape[0] = ", v_layer0.shape[0])
print("n_kv_heads = ", n_kv_heads)
print("v_layer0.shape[0] // n_kv_heads = ", v_layer0.shape[0] // n_kv_heads)

# Reshape value weight for the first layer of attention to separate heads
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)

# Print the shape of the reshaped value weight tensor
v_layer0.shape

torch.Size([1024, 4096])
v_layer0.shape[0] =  1024
n_kv_heads =  8
v_layer0.shape[0] // n_kv_heads =  128


torch.Size([8, 128, 4096])

Obtain value weight matrix for first layer and first head...

In [189]:
# Extract the value weight for the first head of the first layer of attention
v_layer0_head0 = v_layer0[0]

# Print the shape of the extracted value weight tensor for the first head
v_layer0_head0.shape

torch.Size([128, 4096])

Using the value weights, we compute the attention values for each token, resulting in a matrix of size [9x128]. Here, 9 denotes the number of tokens in the prompt, and 128 indicates the dimension of the value vector for each token.

In [190]:
# Calculate value per token by matrix multiplication
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)

# Print the shape of the resulting tensor representing values per token
v_per_token.shape

torch.Size([9, 128])

Obtain the resulting attention matrix...

In [191]:
# Calculate QKV attention by matrix multiplication
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

# Print the shape of the resulting tensor
qkv_attention.shape

torch.Size([9, 128])

***We now have the attention values for the first layer and first head or in other words self attention.***

## Part 4: Rinse and repeat: implement multi-head attention

Do all of that again in a loop, once for each head on the layer

In [192]:
# Store QKV attention for each head in a list
qkv_attention_store = []

# Number of query heads that share one key/value head (32 // 8 = 4 for this model)
heads_per_kv_head = n_heads // n_kv_heads

# The causal mask is identical for every head, so build it once before the loop:
# -inf strictly above the diagonal (later tokens), zeros elsewhere
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
mask = torch.triu(mask, diagonal=1)

# Iterate through each head
for head in range(n_heads):
    # Extract query, key, and value weights for the current head
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head // heads_per_kv_head]  # Key weights are shared across a group of heads
    v_layer0_head = v_layer0[head // heads_per_kv_head]  # Value weights are shared across a group of heads

    # Calculate query per token by matrix multiplication
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)

    # Calculate key per token by matrix multiplication
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)

    # Calculate value per token by matrix multiplication
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)

    # Split query per token into pairs and rotate them
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

    # Split key per token into pairs and rotate them
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

    # Calculate query-key dot products per token, scaled by sqrt(head_dim)
    # (head_dim is 128 here), matching the single-head calculation
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5

    # Add the mask to the query-key dot products per token
    qk_per_token_after_masking = qk_per_token + mask

    # Apply softmax along the second dimension after masking
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

    # Calculate QKV attention by matrix multiplication
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

    # Store QKV attention for the current head
    qkv_attention_store.append(qkv_attention)

# Print the number of QKV attentions stored
len(qkv_attention_store)


32

In [193]:
print(qkv_attention_store[0].shape)

torch.Size([9, 128])


Now that the QKV attention matrix has been obtained for all 32 heads in the first layer, the per-head attention outputs are merged into one large matrix of size [9x4096] (32 heads x 128 values per head).

In [194]:
# Join the per-head attention outputs side by side along the last dimension,
# so each token's row holds the outputs of all heads
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

# Display the shape of the concatenated tensor
stacked_qkv_attention.shape

torch.Size([9, 4096])

Now multiply the weight matrix with the stacked QKV matrix...

In [195]:
# Attention output projection weights (wo) of layer 0
wo_layer0 = model["layers.0.attention.wo.weight"]

# Project the concatenated head outputs through wo to get the embedding delta
embedding_delta = stacked_qkv_attention @ wo_layer0.T

# Display the shape of the projected tensor
embedding_delta.shape

torch.Size([9, 4096])

This gives us the change in the embedding values after attention.  Now add these to the original token embeddings...

In [196]:
# Residual connection: add the attention output (embedding delta) back onto the
# unnormalized token embeddings. This is the post-attention embedding of layer 0,
# not yet the final embedding -- the feed-forward step still follows.
embedding_after_edit = token_embeddings_unnormalized + embedding_delta

# Display the shape of the resulting tensor
embedding_after_edit.shape

torch.Size([9, 4096])

Normalise...

In [197]:
# Weights of layer 0's pre-feed-forward RMS norm
ffn_norm_weight_layer0 = model["layers.0.ffn_norm.weight"]

# RMS-normalize the post-attention embeddings before they enter the feed-forward network
embedding_after_edit_normalized = rms_norm(embedding_after_edit, ffn_norm_weight_layer0)

# Display the shape of the normalized embeddings
embedding_after_edit_normalized.shape


torch.Size([9, 4096])

### Now implement the feed forward neural network with a SwiGLU activation function

In [198]:
# Retrieve the three weight matrices of layer 0's feed-forward network
w1 = model["layers.0.feed_forward.w1.weight"]  # input to the SiLU-activated (gate) branch
w2 = model["layers.0.feed_forward.w2.weight"]  # projects the gated product back out
w3 = model["layers.0.feed_forward.w3.weight"]  # input to the linear (un-activated) branch

# SwiGLU feed-forward: silu(x @ w1.T) * (x @ w3.T), then projected with w2.T.
# torch.nn.functional.silu is the public API; torch.functional.F only resolves
# through a private import alias inside torch and should not be relied upon.
ffn_gate = torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T))
ffn_up = torch.matmul(embedding_after_edit_normalized, w3.T)
output_after_feedforward = torch.matmul(ffn_gate * ffn_up, w2.T)

# Display the shape of the feed-forward output
output_after_feedforward.shape


torch.Size([9, 4096])

In [199]:
# Run the token embeddings through every transformer layer in turn.
# Reads the notebook globals model, n_layers, n_heads, n_kv_heads, dim, freqs_cis,
# rms_norm and token_embeddings_unnormalized; leaves the result in final_embedding.

# Start the residual stream from the unnormalized token embeddings
final_embedding = token_embeddings_unnormalized

# Number of tokens in the sequence; identical for every layer and head
seq_len = len(token_embeddings_unnormalized)

# Number of query heads that share one key/value head (grouped-query attention);
# expected to evaluate to 4 for this checkpoint
heads_per_kv_head = n_heads // n_kv_heads

# Rotary factors for exactly the positions present in the sequence
# (a no-op slice when freqs_cis already has seq_len rows)
freqs_cis_for_tokens = freqs_cis[:seq_len]

# Causal mask, built once because it is the same for every layer and head.
# torch.triu(..., diagonal=1) keeps -inf strictly above the diagonal and zeroes the
# rest, so after softmax each token attends only to itself and earlier tokens.
mask = torch.full((seq_len, seq_len), float("-inf"), device=final_embedding.device)
mask = torch.triu(mask, diagonal=1)

# Iterate through each layer
for layer in range(n_layers):
    # Collect the attention output of every head in this layer
    qkv_attention_store = []

    # RMS-normalize the layer input with this layer's attention-norm weights
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])

    # Retrieve the attention weights and reshape q/k/v so the first axis indexes the head
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]

    # Per-head width, read from the reshaped query weights; used to scale the scores
    head_dim = q_layer.shape[1]

    # Iterate through each query head
    for head in range(n_heads):
        # Query weights are per head; key/value weights are shared by a group of heads
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head // heads_per_kv_head]
        v_layer_head = v_layer[head // heads_per_kv_head]

        # Project the normalized embeddings to queries, keys and values
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)

        # Rotate the queries: view consecutive pairs as complex numbers and
        # multiply by the per-position rotary factors
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis_for_tokens)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

        # Rotate the keys in the same way
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis_for_tokens)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

        # Scaled query-key dot products (scaled by the square root of the head width)
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / head_dim ** 0.5

        # Hide future positions by adding the causal mask
        qk_per_token_after_masking = qk_per_token + mask

        # Softmax over the key positions, then cast back to bfloat16 to match the values
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

        # Weighted sum of the values gives this head's attention output
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

        # Store the attention output for the current head
        qkv_attention_store.append(qkv_attention)

    # Concatenate the outputs of all heads along the last dimension
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

    # Project the concatenated heads through the attention output weights
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)

    # Residual connection around the attention block
    embedding_after_edit = final_embedding + embedding_delta

    # RMS-normalize with this layer's feed-forward norm weights
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])

    # Retrieve weights for the feed-forward layer
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]

    # SwiGLU feed-forward: silu(x @ w1.T) * (x @ w3.T), projected with w2.T.
    # torch.nn.functional.silu is the public API (torch.functional.F is a private alias).
    ffn_gate = torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T))
    ffn_up = torch.matmul(embedding_after_edit_normalized, w3.T)
    output_after_feedforward = torch.matmul(ffn_gate * ffn_up, w2.T)

    # Residual connection around the feed-forward block; this is the next layer's input
    final_embedding = embedding_after_edit + output_after_feedforward

## Part 5: put it all together for all 32 layers

### Generate the model's output
The final embedding should represent the model's guess for the next token.  
It has the same shape as the token embeddings [9x4096]

In [200]:
# Apply the model's final RMS norm (weights stored under "norm.weight") to the
# output of the last layer.
# NOTE: this overwrites final_embedding, so re-running this cell without re-running
# the layer loop first would normalize the embedding a second time.
final_embedding = rms_norm(final_embedding, model["norm.weight"])

# Display the shape of the normalized final embedding
final_embedding.shape

torch.Size([9, 4096])

In [201]:
# Display the shape of the output weight tensor; the cells below multiply the last
# token's embedding by its transpose and decode the argmax index with the tokenizer,
# so each row corresponds to one candidate token id
model["output.weight"].shape

torch.Size([128256, 4096])

To predict the next value, use the embedding of the last token...

In [202]:
# Only the last token's embedding is used to predict what comes next
last_token_embedding = final_embedding[-1]

# Score every candidate token: project the last embedding through the transposed output weights
logits = last_token_embedding @ model["output.weight"].T

# Display the shape of the logits vector
logits.shape

torch.Size([128256])

In [203]:
# Greedy choice: the position of the largest logit is the predicted next token id
next_token = logits.argmax(dim=-1)

# Display the predicted token id (a 0-d tensor)
next_token

tensor(15763)

In [204]:
# Turn the predicted token id back into text with the tokenizer;
# .item() converts the 0-d tensor into a plain Python number first
tokenizer.decode([next_token.item()])

' appreciate'

In [205]:
# Display the predicted token id again (still a tensor)
next_token

tensor(15763)

In [206]:
# Print the predicted token id as a plain Python number rather than a tensor
print(next_token.item())

15763


In [207]:
# Decode a hand-picked token id (1041) to see which text it maps to
tokenizer.decode([1041])

'100'