## Llama3 8B using JAX

Converted from this [PyTorch Lightning tutorial](https://lightning.ai/fareedhassankhan12/studios/building-llama-3-from-scratch) to use JAX. You will need a Kaggle VM or a high-RAM colab VM to run this, although no GPU is needed.


First install dependencies.

In [1]:
!pip install -q jax-ai-stack
!pip install -Uq transformers huggingface_hub tiktoken blobfile

Download model weights.

Imports.

In [2]:
from pathlib import Path
import tiktoken
from tiktoken.load import load_tiktoken_bpe
import torch
import json, os
import jax
import jax.numpy as jnp
from flax import nnx
from huggingface_hub import snapshot_download

In [3]:
# import os
# model_id = "meta-llama/Meta-Llama-3-8B"
# path_to_model = os.path.join("/content", model_id, 'original')
# snapshot_download(repo_id=model_id, local_dir=path_to_model, subfolder='original')

from huggingface_hub import hf_hub_download

repo_id = "meta-llama/Meta-Llama-3-8B"
subfolder = "original"
filenames = ["params.json", "tokenizer.model", "consolidated.00.pth"]

path_to_model = os.path.join("/content", repo_id)

for filename in filenames:
    hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        subfolder=subfolder,
        local_dir=path_to_model
    )

path_to_model = os.path.join(path_to_model, 'original')

Tokenizer.

In [4]:
tokenizer_model = load_tiktoken_bpe(path_to_model+"/tokenizer.model")
model_weights = torch.load(path_to_model+"/consolidated.00.pth")


with open(path_to_model+"/params.json", "r") as f:
    config = json.load(f)

dim = config["dim"]
n_layers = config["n_layers"]
n_heads = config["n_heads"]
n_kv_heads = config["n_kv_heads"]
vocab_size = config["vocab_size"]
multiple_of = config["multiple_of"]
ffn_dim_multiplier = config["ffn_dim_multiplier"]
norm_eps = config["norm_eps"]
rope_theta = config["rope_theta"]

special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]

tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

tokenizer = tiktoken.Encoding(
    name = path_to_model+"tokenizer.model",
    pat_str = tokenize_breaker,
    mergeable_ranks = tokenizer_model,
    special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)},
)

# prompt = "the answer to the ultimate question of life, the universe, and everything is "
prompt = "the capital of China is "
tokens = [128000] + tokenizer.encode(prompt)

  model_weights = torch.load(path_to_model+"/consolidated.00.pth")


Embeddings.

In [5]:
embedding_layer = nnx.Embed(vocab_size, dim, rngs=nnx.Rngs(0))
embedding_layer.embedding.value = model_weights["tok_embeddings.weight"].float().numpy()
token_embeddings_unnormalized = embedding_layer(jnp.asarray(tokens)).astype(jnp.bfloat16)
hidden_state = token_embeddings_unnormalized

RMS layer norm.

In [6]:
def rms_norm(tensor, norm_weights):
    squared_mean = jnp.mean(jnp.square(tensor), axis=-1, keepdims=True)
    normalized = jnp.reciprocal(jnp.sqrt(squared_mean + norm_eps))
    return (tensor * normalized) * norm_weights

Prep for RoPE calculation.

In [7]:
head_dim = dim // n_heads

zero_to_one_split_into_64_parts = jnp.arange(64)/64
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
freqs_for_each_token = jnp.outer(jnp.arange(token_embeddings_unnormalized.shape[0]), freqs)
freqs_cis = jnp.complex64(jnp.exp(1j * freqs))

See [NVidia diagram](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/_images/transformer_vs_llama.svg) for model architecture.

In [8]:
for layer in range(n_layers):
    qkv_attention_store = []

    layer_embedding_norm = rms_norm(hidden_state, jnp.asarray(model_weights[f"layers.{layer}.attention_norm.weight"].float().numpy()).astype(jnp.bfloat16))

    q_layer = jnp.asarray(model_weights[f"layers.{layer}.attention.wq.weight"].float().numpy()).astype(jnp.bfloat16)
    q_layer = jnp.reshape(q_layer, (n_heads, q_layer.shape[0] // n_heads, dim))
    k_layer = jnp.asarray(model_weights[f"layers.{layer}.attention.wk.weight"].float().numpy()).astype(jnp.bfloat16)
    k_layer = k_layer.reshape(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = jnp.asarray(model_weights[f"layers.{layer}.attention.wv.weight"].float().numpy()).astype(jnp.bfloat16)
    v_layer = v_layer.reshape(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = jnp.asarray(model_weights[f"layers.{layer}.attention.wo.weight"].float().numpy()).astype(jnp.bfloat16)

    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head//4]
        v_layer_head = v_layer[head//4]

        q_per_token = jnp.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = jnp.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = jnp.matmul(layer_embedding_norm, v_layer_head.T)

        # apply RoPe below
        freqs_for_each_token = jnp.outer(jnp.arange(token_embeddings_unnormalized.shape[0]), freqs)
        freqs_cis = jnp.exp(1j * freqs_for_each_token)
        q_per_token_split_into_pairs = q_per_token.astype(jnp.float32).reshape(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = q_per_token_split_into_pairs[..., 0] + 1j * q_per_token_split_into_pairs[..., 1]
        q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis
        q_per_token_split_into_pairs_rotated = jnp.stack([q_per_token_as_complex_numbers_rotated.real, q_per_token_as_complex_numbers_rotated.imag], axis=-1)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.reshape(q_per_token.shape)

        # Repeat the process for k_per_token
        k_per_token_split_into_pairs = k_per_token.astype(jnp.float32).reshape(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = k_per_token_split_into_pairs[..., 0] + 1j * k_per_token_split_into_pairs[..., 1]
        k_per_token_as_complex_numbers_rotated = k_per_token_as_complex_numbers * freqs_cis
        k_per_token_split_into_pairs_rotated = jnp.stack([k_per_token_as_complex_numbers_rotated.real, k_per_token_as_complex_numbers_rotated.imag], axis=-1)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.reshape(k_per_token.shape)

        # TODO: update 128
        qk_per_token = jnp.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5

        mask = jnp.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float("-inf"))
        mask = jnp.triu(mask, k=1)
        qk_per_token_after_masking = qk_per_token + mask

        qk_per_token_after_masking_after_softmax = jax.nn.softmax(qk_per_token_after_masking, axis=1)

        qkv_attention = jnp.matmul(qk_per_token_after_masking_after_softmax, v_per_token)

        qkv_attention_store.append(qkv_attention)

    stacked_qkv_attention = jnp.concatenate(qkv_attention_store, axis=-1)

    embedding_delta = jnp.matmul(stacked_qkv_attention, w_layer.T)

    embedding_after_edit = hidden_state + embedding_delta

    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model_weights[f"layers.{layer}.ffn_norm.weight"].float().numpy())

    w1 = jnp.asarray(model_weights[f"layers.{layer}.feed_forward.w1.weight"].float().numpy()).astype(jnp.bfloat16)
    w2 = jnp.asarray(model_weights[f"layers.{layer}.feed_forward.w2.weight"].float().numpy()).astype(jnp.bfloat16)
    w3 = jnp.asarray(model_weights[f"layers.{layer}.feed_forward.w3.weight"].float().numpy()).astype(jnp.bfloat16)

    output_after_feedforward = jnp.matmul(nnx.silu(jnp.matmul(embedding_after_edit_normalized, w1.T)) * jnp.matmul(embedding_after_edit_normalized, w3.T), w2.T)

    hidden_state = embedding_after_edit + output_after_feedforward

logits = jnp.matmul(hidden_state[-1], jnp.asarray(model_weights["output.weight"].float().numpy()).astype(jnp.bfloat16).T)

Predit the next token.

In [9]:
next_token = jnp.argmax(logits, axis=-1)

print(tokenizer.decode([next_token]))

 Beijing
