### Download the model files
These cells only need to be run once, to download the model files.

### Import the libraries needed for tokenisation (tiktoken), tensors (PyTorch) and JSON handling

In [1]:
# Tokenisation library
import tiktoken

# Byte Pair Encoding function
from tiktoken.load import load_tiktoken_bpe

# PyTorch library
import torch

# JSON 
import json

### Load the model from the model files

In [2]:
# Directory holding the downloaded Llama-3.1-8B checkpoint files
model_directory = "llama-3.1-8B/original/"

# BPE merge ranks used to build the tiktoken tokenizer
tokenizer_model = load_tiktoken_bpe(model_directory + "tokenizer.model")

# Load the PyTorch weights of Llama-3.1-8B onto the CPU. weights_only=True restricts
# unpickling to tensors and plain containers, so the checkpoint cannot run arbitrary code.
model = torch.load(model_directory + "consolidated.00.pth", map_location=torch.device("cpu"), weights_only=True)

# Model hyperparameters. JSON is UTF-8 by specification, so do not depend on the
# platform's default text encoding when reading it.
with open(model_directory + "params.json", "r", encoding="utf-8") as f:
    config = json.load(f)

### Store model hyperparameters

In [3]:
# Transformer shape: hidden width, number of layers, query heads, key/value heads
dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]

# Number of entries in the token vocabulary
vocab_size = config["vocab_size"]

# Feed-forward sizing settings from params.json (presumably used to size the FFN hidden layer)
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]

# Small constant added to the mean square in RMSNorm to avoid division by zero
norm_eps = config["norm_eps"]

# RoPE base frequency, kept as a tensor so it can be raised to a tensor of exponents
rope_theta = torch.tensor(config["rope_theta"])

### Function to normalise our input vector using RMSNorm

In [17]:
# Calculating RMSNorm
def rms_norm(tensor, norm_weights, eps=None):
    """Apply RMSNorm along the last dimension of ``tensor``.

    Each vector along the last dimension is divided by its root mean square and
    then scaled element-wise by ``norm_weights``.

    The reduction is computed in float32 and the normalised values are cast back
    to the input dtype before scaling. This keeps precision when activations are
    bfloat16.

    Args:
        tensor: activations to normalise (over the last dimension).
        norm_weights: learned per-feature scale, broadcast against the last dimension.
        eps: constant added to the mean square to avoid division by zero.
            Defaults to the global ``norm_eps`` read from the model config.

    Returns:
        The normalised, scaled tensor. It has the input's dtype whenever
        ``norm_weights`` shares that dtype.
    """
    if eps is None:
        eps = norm_eps

    tensor_fp32 = tensor.float()
    # rsqrt gives the reciprocal of the square root, i.e. 1 / RMS
    inv_rms = torch.rsqrt(tensor_fp32.pow(2).mean(-1, keepdim=True) + eps)
    normalized = (tensor_fp32 * inv_rms).type_as(tensor)
    return normalized * norm_weights

### Function to transform tokens into embeddings

In [18]:
# Given our input tokens tokens (words, parts of words etc), transform these into vector embeddings for each token
def tranform_token_sequence_to_embeddings(tokens, embedding_weights=None):
    """Look up the embedding vector for each input token id.

    Args:
        tokens: integer tensor of token ids (words, parts of words, etc.).
        embedding_weights: embedding table with one row per vocabulary entry.
            Defaults to the pre-trained ``model["tok_embeddings.weight"]``.

    Returns:
        A bfloat16 tensor holding the embedding row of every token, with shape
        ``tokens.shape + (embedding_width,)``.
    """
    if embedding_weights is None:
        embedding_weights = model["tok_embeddings.weight"]

    # Index the pre-trained table directly. Building a new torch.nn.Embedding on every
    # call allocated and randomly initialised a full float32 vocab x dim matrix only to
    # overwrite it. Its parameter also made the output (and everything downstream)
    # track gradients.
    return torch.nn.functional.embedding(tokens, embedding_weights).to(torch.bfloat16)

### Transformer function: run the token embeddings through every layer of the model

In [19]:
def _rotate_pairs(x, freqs_cis):
    """Apply rotary position embedding (RoPE) to a per-token projection.

    Consecutive pairs along the last dimension of ``x`` are viewed as complex
    numbers and multiplied by ``freqs_cis`` (one unit complex number per token
    and pair). The result is float32 and has the same shape as ``x``.
    """
    x_as_complex = torch.view_as_complex(x.float().view(x.shape[0], -1, 2))
    return torch.view_as_real(x_as_complex * freqs_cis).view(x.shape)


def llama3_transformer(embedding, input_seq_length):
    """Run token embeddings through every transformer layer of the loaded model.

    Each layer applies these steps in order:
    - RMSNorm
    - grouped-query causal self-attention with rotary position embeddings
    - residual add
    - RMSNorm
    - gated feed-forward network ``w2(silu(w1 x) * w3 x)``
    - residual add

    The final norm and output projection are not applied here.

    Args:
        embedding: token embeddings, one row of width ``dim`` per token.
        input_seq_length: number of tokens (rows) in ``embedding``.

    Returns:
        The embeddings after the last layer, with the same shape as ``embedding``.
    """
    # Per-head width and the number of query heads sharing each key/value head,
    # derived from the config instead of hard-coding Llama-3.1-8B's 128 and 4.
    head_dim = dim // n_heads
    n_rep = n_heads // n_kv_heads
    n_pairs = head_dim // 2

    # Rotary frequencies theta^(-2i/head_dim), one per dimension pair i. They are
    # multiplied by every position and turned into unit complex numbers (freqs_cis).
    freqs = 1.0 / (rope_theta ** (torch.arange(n_pairs) / n_pairs))
    freqs_for_each_token = torch.outer(torch.arange(input_seq_length), freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

    # Causal mask: -inf strictly above the diagonal, so no token attends to a later
    # one. It depends only on the sequence length, so build it once.
    seq_len = len(embedding)
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    scale = head_dim ** 0.5

    for layer in range(n_layers):
        layer_embedding_norm = rms_norm(embedding, model[f"layers.{layer}.attention_norm.weight"])

        # Attention weights, reshaped so that index [h] selects one head's projection
        q_layer = model[f"layers.{layer}.attention.wq.weight"]
        q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
        k_layer = model[f"layers.{layer}.attention.wk.weight"]
        k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
        v_layer = model[f"layers.{layer}.attention.wv.weight"]
        v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
        w_layer = model[f"layers.{layer}.attention.wo.weight"]

        # Keys and values depend only on the key/value head. Compute them once per
        # kv head rather than once for every query head that shares it.
        k_rotated_per_kv_head = [
            _rotate_pairs(torch.matmul(layer_embedding_norm, k_layer[kv_head].T), freqs_cis)
            for kv_head in range(n_kv_heads)
        ]
        v_per_kv_head = [
            torch.matmul(layer_embedding_norm, v_layer[kv_head].T)
            for kv_head in range(n_kv_heads)
        ]

        qkv_attention_store = []
        for head in range(n_heads):
            # Query heads are grouped in consecutive runs of n_rep per kv head
            kv_head = head // n_rep
            q_per_token_rotated = _rotate_pairs(torch.matmul(layer_embedding_norm, q_layer[head].T), freqs_cis)

            # Scaled dot-product scores, causally masked, softmaxed over the key axis
            qk_per_token = torch.matmul(q_per_token_rotated, k_rotated_per_kv_head[kv_head].T) / scale
            attention_weights = torch.nn.functional.softmax(qk_per_token + mask, dim=1).to(torch.bfloat16)

            qkv_attention_store.append(torch.matmul(attention_weights, v_per_kv_head[kv_head]))

        # Concatenate the heads, project back to the model width, and add the residual
        stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
        embedding_after_edit = embedding + torch.matmul(stacked_qkv_attention, w_layer.T)

        # Gated feed-forward block w2(silu(w1 x) * w3 x), plus the residual
        ffn_input = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
        w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
        w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
        w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
        gate = torch.nn.functional.silu(torch.matmul(ffn_input, w1.T))
        up = torch.matmul(ffn_input, w3.T)
        embedding = embedding_after_edit + torch.matmul(gate * up, w2.T)

    return embedding
        

### Create a tokenizer to tokenize our input data

In [20]:
# Special tokens get ids directly after the BPE vocabulary, in this order:
# ten named/reserved slots, then reserved_special_token_5 .. reserved_special_token_250.
special_tokens = [
    "<|begin_of_text|>",  # start of a text sequence
    "<|end_of_text|>",  # end of a text sequence
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",  # start of a header ID
    "<|end_header_id|>",  # end of a header ID
    "<|reserved_special_token_4|>",
    "<|eot_id|>",  # end of a turn in a conversation
]
special_tokens += [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]

# Pre-tokenisation regex that splits text into chunks before BPE merging. The chunks are
# contractions, letter runs, 1-3 digit numbers, punctuation runs, newlines and whitespace.
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

# tiktoken encoder built from the Llama-3 BPE ranks; special-token ids start right
# after the last mergeable rank.
tokenizer = tiktoken.Encoding(
    name="tokenizer.model",
    pat_str=tokenize_breaker,
    mergeable_ranks=tokenizer_model,
    special_tokens={
        token: token_id
        for token_id, token in enumerate(special_tokens, start=len(tokenizer_model))
    },
)

### Function that takes an input prompt as a list of tokens and returns the predicted next token

In [39]:
def generate_next_token(tokens):
    """Greedily predict the token that follows ``tokens``.

    Args:
        tokens: list of integer token ids forming the prompt so far.

    Returns:
        The id (a Python int) of the highest-scoring next token.
    """
    # Generation never needs gradients. Disabling autograd stops it from recording
    # the whole forward pass, which saves memory and time on every call.
    with torch.no_grad():
        token_ids = torch.tensor(tokens)
        input_seq_length = len(token_ids)

        # Embed the prompt and run it through all transformer layers
        hidden = tranform_token_sequence_to_embeddings(token_ids)
        hidden = llama3_transformer(hidden, input_seq_length)

        # Final RMSNorm, then project only the last position onto the vocabulary
        hidden = rms_norm(hidden, model["norm.weight"])
        logits = torch.matmul(hidden[-1], model["output.weight"].T)

        # Greedy decoding: take the index of the largest logit
        next_token = torch.argmax(logits, dim=-1)

    return next_token.item()

In [50]:
# Input prompt
prompt = "To paraphrase the opening lines of Anna Karenina: "

# Encode the prompt and prepend the <|begin_of_text|> token. Its id is looked up from
# the tokenizer's special tokens instead of hard-coding 128000.
begin_of_text_id = tokenizer.encode_single_token("<|begin_of_text|>")
prompt_tokens = [begin_of_text_id] + tokenizer.encode(prompt)

print("Input prompt: ", tokenizer.decode(prompt_tokens))
print("Input prompt as tokens: ", prompt_tokens)

# Greedy generation: predict one token at a time and feed it back in as input
tokens_to_generate = 15
for step in range(tokens_to_generate):
    next_token = generate_next_token(prompt_tokens)
    prompt_tokens.append(next_token)
    print(tokenizer.decode([next_token]))

# Decode the prompt plus the generated tokens back to text
prompt = tokenizer.decode(prompt_tokens)

print("Output: ", prompt)
print("Output as tokens: ", prompt_tokens)

Input prompt:  <|begin_of_text|>To paraphrase the opening lines of Anna Karenina: 
Input prompt as tokens:  [128000, 1271, 63330, 10857, 279, 8736, 5238, 315, 24101, 35745, 2259, 25, 220]
 
All
 happy
 families
 are
 alike
;
 each
 unhappy
 family
 is
 unhappy
 in
 its
 own
Output:  <|begin_of_text|>To paraphrase the opening lines of Anna Karenina:  All happy families are alike; each unhappy family is unhappy in its own
Output as tokens:  [128000, 1271, 63330, 10857, 279, 8736, 5238, 315, 24101, 35745, 2259, 25, 220, 4194, 2460, 6380, 8689, 527, 27083, 26, 1855, 43251, 3070, 374, 43251, 304, 1202, 1866]
