In [31]:
!pip install -q sentencepiece tiktoken torch blobfile matplotlib huggingface_hub ipywidgets

In [32]:
from huggingface_hub import login

login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [34]:
# Import the necessary function from the huggingface_hub library
from huggingface_hub import hf_hub_download

# Define the repository information
repo_id = "meta-llama/Meta-Llama-3-8B"
subfolder = "original"  # Specify the subfolder within the repository

# List of filenames to download
filenames = ["params.json", "tokenizer.model", "consolidated.00.pth"] 

# Specify the directory where you want to save the downloaded files
save_directory = "llama-3-8B/"  # Replace with your desired path

# Download each file
for filename in filenames:
    hf_hub_download(
        repo_id=repo_id,       # Repository ID
        filename=filename,     # Name of the file to download
        subfolder=subfolder,   # Subfolder within the repository
        local_dir=save_directory  # Directory to save the downloaded file
    )

GatedRepoError: 401 Client Error. (Request ID: Root=1-66627b40-1548c3f56c86a6650a319e23;d04a5add-420c-407e-b340-e489da9860d1)

Cannot access gated repo for url https://huggingface.co/meta-llama/Meta-Llama-3-8B/resolve/main/original/params.json.
Access to model meta-llama/Meta-Llama-3-8B is restricted. You must be authenticated to access it.

In [None]:
# Tokenization library
import tiktoken

# BPE loading function
from tiktoken.load import load_tiktoken_bpe

# PyTorch library
import torch

# JSON handling
import json

In [None]:
# Loading the tokenizer from llama-3-8B
tokenizer_model = load_tiktoken_bpe("models/tokenizer.model")

# Get the length of the tokenizer model 
len(tokenizer_model)
# OUTPUT: 128000

# Get the type of the `tokenizer_model` object.
type(tokenizer_model)
# OUTPUT: dictionary

dict

In [None]:
# Printing the first 10 items of tokenizer model
dict(list(tokenizer_model.items())[5600:5610])

{b'mitted': 5600,
 b" $('#": 5601,
 b' saw': 5602,
 b' approach': 5603,
 b'ICE': 5604,
 b' saying': 5605,
 b' anyone': 5606,
 b'meta': 5607,
 b'SD': 5608,
 b' song': 5609}

In [None]:
# Loading a PyTorch model of LLaMA-3-8B
model = torch.load("models/original/consolidated.00.pth")

# printing first 11 layers of the architecture
list(model.keys())[:11]

FileNotFoundError: [Errno 2] No such file or directory: 'models/original/consolidated.00.pth'

In [None]:
# Opening the parameters JSON file
with open("models/original_params.json", "r") as f:
    config = json.load(f)

# Printing the content
print(config)

{'dim': 4096, 'n_layers': 32, 'n_heads': 32, 'n_kv_heads': 8, 'vocab_size': 128256, 'multiple_of': 1024, 'ffn_dim_multiplier': 1.3, 'norm_eps': 1e-05, 'rope_theta': 500000.0}


In [None]:
# Dimension
dim = config["dim"]

# Layers
n_layers = config["n_layers"]

# Heads
n_heads = config["n_heads"]

# KV_heads
n_kv_heads = config["n_kv_heads"]

# Vocabulary
vocab_size = config["vocab_size"]

# Multiple
multiple_of = config["multiple_of"]

# Multiplier
ffn_dim_multiplier = config["ffn_dim_multiplier"]

# Epsilon
norm_eps = config["norm_eps"]

# RoPE
rope_theta = torch.tensor(config["rope_theta"])

In [None]:
special_tokens = [
    "<|begin_of_text|>",  # Marks the beginning of a text sequence.
    "<|end_of_text|>",  # Marks the end of a text sequence.
    "<|reserved_special_token_0|>",  # Reserved for future use.
    "<|reserved_special_token_1|>",  # Reserved for future use.
    "<|reserved_special_token_2|>",  # Reserved for future use.
    "<|reserved_special_token_3|>",  # Reserved for future use.
    "<|start_header_id|>",  # Indicates the start of a header ID.
    "<|end_header_id|>",  # Indicates the end of a header ID.
    "<|reserved_special_token_4|>",  # Reserved for future use.
    "<|eot_id|>",  # Marks the end of a turn (in a conversational context).
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]  # A large set of tokens reserved for future use.+

In [None]:
# patterns based on which text will be break into tokens
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

In [None]:
# Initialize tokenizer with specified parameters
tokenizer = tiktoken.Encoding(

    # make sure to set path to tokenizer.model file
    name = "models/tokenizer.model",

    # Define tokenization pattern string
    pat_str = tokenize_breaker,

    # Assign BPE mergeable ranks from tokenizer_model of LLaMA-3
    mergeable_ranks = tokenizer_model,

    # Set special tokens with indices
    special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)},
)

# Encode "hello world!" and decode tokens to string
tokenizer.decode(tokenizer.encode("hello world!"))

'hello world!'

In [None]:
# input prompt
prompt = "the best joke is?"

# Encode the prompt using the tokenizer and prepend a special token (128000)
tokens = [128000] + tokenizer.encode(prompt)

print(tokens)  # Print the encoded tokens

# Convert the list of tokens into a PyTorch tensor
tokens = torch.tensor(tokens)

# Decode each token back into its corresponding string
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]

print(prompt_split_as_tokens)  # Print the decoded tokens

[128000, 1820, 1888, 22380, 374, 30]
['<|begin_of_text|>', 'the', ' best', ' joke', ' is', '?']


In [None]:
# checking dimension of input vector
len(tokens)

6

In [None]:
# checking dimension of embedding vector from llama-3 architecture
print(dim)

4096


In [None]:
# Define embedding layer with vocab size and embedding dimension
embedding_layer = torch.nn.Embedding(vocab_size, dim)

# Copy pre-trained token embeddings to the embedding layer
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])

# Get token embeddings for given tokens, converting to torch.bfloat16 format
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

# Print shape of resulting token embeddings
token_embeddings_unnormalized.shape

torch.Size([6, 4096])

In [None]:
# Calculating RMSNorm
def rms_norm(tensor, norm_weights):

    # Calculate the mean of the square of tensor values along the last dimension
    squared_mean = tensor.pow(2).mean(-1, keepdim=True)
    
    # Add a small value to avoid division by zero
    normalized = torch.rsqrt(squared_mean + norm_eps)
    
    # Multiply normalized tensor by the provided normalization weights
    return (tensor * normalized) * norm_weights

In [None]:
# using RMS normalization and provided normalization weights
token_embeddings = rms_norm(token_embeddings_unnormalized, 
                            model["layers.0.attention_norm.weight"])

# Print the shape of the resulting token embeddings
token_embeddings.shape

torch.Size([17, 4096])

In [None]:
# Print the shapes of different weights
print(
    # Query weight shape
    model["layers.0.attention.wq.weight"].shape,
    
    # Key weight shape
    model["layers.0.attention.wk.weight"].shape,
    
    # Value weight shape
    model["layers.0.attention.wv.weight"].shape,
    
    # Output weight shape
    model["layers.0.attention.wo.weight"].shape
)

torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])


In [None]:
# Retrieve query weight for the first layer of attention
q_layer0 = model["layers.0.attention.wq.weight"]

# Calculate dimension per head
head_dim = q_layer0.shape[0] // n_heads

# Reshape query weight to separate heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)

# Print the shape of the reshaped query weight tensor
q_layer0.shape

torch.Size([32, 128, 4096])

In [None]:
# Extract the query weight for the first head of the first layer of attention
q_layer0_head0 = q_layer0[0]

# Print the shape of the extracted query weight tensor for the first head
q_layer0_head0.shape

torch.Size([128, 4096])

In [None]:
# Matrix multiplication: token embeddings with transpose of query weight for first head
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)

# Shape of resulting tensor: queries per token
q_per_token.shape

torch.Size([17, 128])

In [None]:
# Convert queries per token to float and split into pairs
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
q_per_token_split_into_pairs.shape

torch.Size([17, 64, 2])

In [None]:
# Generate values from 0 to 1 split into 64 parts
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64

# Print the resulting tensor
zero_to_one_split_into_64_parts

tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
        0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
        0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
        0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
        0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
        0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
        0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
        0.9844])

In [None]:
# Calculate frequencies using a power operation
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)

# Display the resulting frequencies
freqs

tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
        2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
        8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
        2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
        7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
        2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
        6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
        1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
        5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
        1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
        4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])

In [None]:
# Convert queries per token to complex numbers
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)

q_per_token_as_complex_numbers.shape
# Output: torch.Size([17, 64])

# Calculate frequencies for each token using outer product of arange(17) and freqs
freqs_for_each_token = torch.outer(torch.arange(6), freqs)

# Calculate complex numbers from frequencies_for_each_token using polar coordinates
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

# Rotate complex numbers by frequencies
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis

q_per_token_as_complex_numbers_rotated.shape
# Output: torch.Size([17, 64])

RuntimeError: The size of tensor a (17) must match the size of tensor b (6) at non-singleton dimension 0

In [None]:
# Convert rotated complex numbers back to real numbers
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)

# Print the shape of the resulting tensor
q_per_token_split_into_pairs_rotated.shape

torch.Size([17, 64, 2])

In [None]:
# Reshape rotated token queries to match the original shape
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)

# Print the shape of the resulting tensor
q_per_token_rotated.shape

torch.Size([17, 128])

In [None]:
# Extract the weight tensor for the attention mechanism's key in the first layer of the model
k_layer0 = model["layers.0.attention.wk.weight"]

# Reshape key weight for the first layer of attention to separate heads
k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)

# Print the shape of the reshaped key weight tensor
k_layer0.shape  # Output: torch.Size([8, 128, 4096])

# Extract the key weight for the first head of the first layer of attention
k_layer0_head0 = k_layer0[0]

# Print the shape of the extracted key weight tensor for the first head
k_layer0_head0.shape  # Output: torch.Size([128, 4096])

# Calculate key per token by matrix multiplication
k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)

# Print the shape of the resulting tensor representing keys per token
k_per_token.shape  # Output: torch.Size([17, 128])

# Split key per token into pairs and convert to float
k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
k_per_token_split_into_pairs.shape  # Output: torch.Size([17, 64, 2])

# Convert key per token to complex numbers
k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)

# Print the shape of the resulting tensor representing key per token as complex numbers
k_per_token_as_complex_numbers.shape  # Output: torch.Size([17, 64])

# Rotate complex key per token by frequencies
k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)

# Print the shape of the rotated complex key per token
k_per_token_split_into_pairs_rotated.shape  # Output: torch.Size([17, 64, 2])

# Reshape rotated key per token to match the original shape
k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)

# Print the shape of the rotated key per token
k_per_token_rotated.shape  # Output: torch.Size([17, 128])

torch.Size([17, 128])

In [None]:
# Calculate query-key dot products per token
qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5

# Print the shape of the resulting tensor representing query-key dot products per token
qk_per_token.shape

torch.Size([17, 17])

In [None]:
# Create a mask tensor filled with negative infinity values
mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)

# Set upper triangular part of the mask tensor to negative infinity
mask = torch.triu(mask, diagonal=1)

# Print the resulting mask tensor
mask

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -in

In [None]:
# Add the mask to the query-key dot products per token
qk_per_token_after_masking = qk_per_token + mask

# Apply softmax along the second dimension after masking
qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)

In [None]:
# Retrieve the value weight for the first layer of attention
v_layer0 = model["layers.0.attention.wv.weight"]

# Reshape value weight for the first layer of attention to separate heads
v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)

# Print the shape of the reshaped value weight tensor
v_layer0.shape

torch.Size([8, 128, 4096])

In [None]:
# Extract the value weight for the first head of the first layer of attention
v_layer0_head0 = v_layer0[0]

# Print the shape of the extracted value weight tensor for the first head
v_layer0_head0.shape

torch.Size([128, 4096])

In [None]:
# Calculate value per token by matrix multiplication
v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)

# Print the shape of the resulting tensor representing values per token
v_per_token.shape

torch.Size([17, 128])

In [None]:
# Calculate QKV attention by matrix multiplication
qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

# Print the shape of the resulting tensor
qkv_attention.shape

torch.Size([17, 128])

In [None]:
# Store QKV attention for each head in a list
qkv_attention_store = []

# Iterate through each head
for head in range(n_heads):
    # Extract query, key, and value weights for the current head
    q_layer0_head = q_layer0[head]
    k_layer0_head = k_layer0[head//4]  # Key weights are shared across 4 heads
    v_layer0_head = v_layer0[head//4]  # Value weights are shared across 4 heads
    
    # Calculate query per token by matrix multiplication
    q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)
    
    # Calculate key per token by matrix multiplication
    k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)
    
    # Calculate value per token by matrix multiplication
    v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)
    
    # Split query per token into pairs and rotate them
    q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
    q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
    q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
    
    # Split key per token into pairs and rotate them
    k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
    k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
    k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])
    k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
    
    # Calculate query-key dot products per token
    qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
    
    # Create a mask tensor filled with negative infinity values
    mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device)
    # Set upper triangular part of the mask tensor to negative infinity
    mask = torch.triu(mask, diagonal=1)
    # Add the mask to the query-key dot products per token
    qk_per_token_after_masking = qk_per_token + mask
    
    # Apply softmax along the second dimension after masking
    qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
    
    # Calculate QKV attention by matrix multiplication
    qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
    
    # Store QKV attention for the current head
    qkv_attention_store.append(qkv_attention)

# Print the number of QKV attentions stored
len(qkv_attention_store)

32

In [None]:
# Concatenate QKV attentions from all heads along the last dimension
stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)

# Print the shape of the resulting tensor
stacked_qkv_attention.shape

torch.Size([17, 4096])

In [None]:
# Calculate the embedding delta by matrix multiplication with the output weight
embedding_delta = torch.matmul(stacked_qkv_attention, model["layers.0.attention.wo.weight"].T)

# Print the shape of the resulting tensor
embedding_delta.shape

torch.Size([17, 4096])

In [None]:
# Add the embedding delta to the unnormalized token embeddings to get the final embeddings
embedding_after_edit = token_embeddings_unnormalized + embedding_delta

# Print the shape of the resulting tensor
embedding_after_edit.shape

torch.Size([17, 4096])

In [None]:
# Normalize edited embeddings using root mean square normalization and provided weights
embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"])

# Print the shape of resulting normalized embeddings
embedding_after_edit_normalized.shape

torch.Size([17, 4096])

In [None]:
# Retrieve weights for feedforward layer
w1 = model["layers.0.feed_forward.w1.weight"]
w2 = model["layers.0.feed_forward.w2.weight"]
w3 = model["layers.0.feed_forward.w3.weight"]

# Perform operations for feedforward layer
output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)

# Print the shape of the resulting tensor after feedforward
output_after_feedforward.shape

torch.Size([17, 4096])

In [None]:
# Initialize final embedding with unnormalized token embeddings
final_embedding = token_embeddings_unnormalized

# Iterate through each layer
for layer in range(n_layers):
    # Initialize list to store QKV attentions for each head
    qkv_attention_store = []
    
    # Normalize the final embedding using root mean square normalization and weights from the current layer
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    
    # Retrieve query, key, value, and output weights for the attention mechanism of the current layer
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    
    # Iterate through each head
    for head in range(n_heads):
        # Extract query, key, and value weights for the current head
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]  # Key weights are shared across 4 heads
        v_layer_head = v_layer[head//4]  # Value weights are shared across 4 heads
        
        # Calculate query per token by matrix multiplication
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        
        # Calculate key per token by matrix multiplication
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        
        # Calculate value per token by matrix multiplication
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        
        # Split query per token into pairs and rotate them
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        
        # Split key per token into pairs and rotate them
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        
        # Calculate query-key dot products per token
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5
        
        # Create a mask tensor filled with negative infinity values
        mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        # Set upper triangular part of the mask tensor to negative infinity
        mask = torch.triu(mask, diagonal=1)
        # Add the mask to the query-key dot products per token
        qk_per_token_after_masking = qk_per_token + mask
        
        # Apply softmax along the second dimension after masking
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)
        
        # Calculate QKV attention by matrix multiplication
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        
        # Store QKV attention for the current head
        qkv_attention_store.append(qkv_attention)
    
    # Concatenate QKV attentions from all heads along the last dimension
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    
    # Calculate embedding delta by matrix multiplication with the output weight
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    
    # Add the embedding delta to the current embedding to get the edited embedding
    embedding_after_edit = final_embedding + embedding_delta
    
    # Normalize the edited embedding using root mean square normalization and weights from the current layer
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    
    # Retrieve weights for the feedforward layer
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    
    # Perform operations for the feedforward layer
    output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    
    # Update the final embedding with the edited embedding plus the output from the feedforward layer
    final_embedding = embedding_after_edit + output_after_feedforward

In [None]:
# Normalize the final embedding using root mean square normalization and provided weights
final_embedding = rms_norm(final_embedding, model["norm.weight"])

# Print the shape of the resulting normalized final embedding
final_embedding.shape

torch.Size([17, 4096])

In [None]:
# Print the shape of the output weight tensor
model["output.weight"].shape

torch.Size([128256, 4096])

In [None]:
# Calculate logits by matrix multiplication between the final embedding and the transpose of the output weight tensor
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)

# Print the shape of the resulting logits tensor
logits.shape

torch.Size([128256])

In [None]:
# Find the index of the maximum value along the last dimension to determine the next token
next_token = torch.argmax(logits, dim=-1)

# Output the index of the next token
next_token

tensor(2983)

In [None]:
# Decode the index of the next token using the tokenizer
tokenizer.decode([next_token.item()])

'42'