# Converting the Llama 3.2 1B model from Hugging Face to JAX

This tutorial demonstrates how to convert Meta's [Llama 3.2 1B model](https://huggingface.co/meta-llama/Llama-3.2-1B) (we use its instruction-tuned variant) from Hugging Face to a JAX model and run it on a T4 GPU.

## Setup

Let's install the `jax-ai-stack`; we'll use its `jax` and `flax` libraries in this tutorial. We'll also need `huggingface_hub` to download the model weights and `transformers` for tokenization.

In [1]:
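!pip install -q jax-ai-stack
!pip install -Uq transformers huggingface_hub

# On Colab, read the Hugging Face access token from the notebook's secrets.
# Elsewhere, set the HF_TOKEN environment variable yourself or run `huggingface-cli login`.
from google.colab import userdata
import os
os.environ['HF_TOKEN'] = userdata.get('HF_TOKEN')
# !huggingface-cli login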


Next, import the libraries we need.

In [2]:
import jax
import jax.numpy as jnp
from flax import nnx
from safetensors import safe_open
from pathlib import Path
import os
from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
from dataclasses import dataclass



## Define the configuration

The Hugging Face Transformers library has [some tips for using the Llama3 model](https://huggingface.co/docs/transformers/main/en/model_doc/llama3). For further reference, the modeling code in the transformers library lives in the [`models/llama/modeling_llama.py` file](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py).

Before we create the model in JAX, we need to define some parameters. You can refer to the [Llama2 documentation for the configuration options](https://huggingface.co/docs/transformers/main/en/model_doc/llama2#transformers.LlamaConfig).

In [3]:
@dataclass
class LlamaConfig:
    dim: int = 2048                  # hidden size
    n_layers: int = 16
    n_heads: int = 32
    n_kv_heads: int = 8              # grouped-query attention: 4 query heads per KV head
    intermediate_size: int = 8192    # MLP hidden size of the 1B model
    vocab_size: int = 128256
    multiple_of: int = 256
    norm_eps: float = 1e-05
    rope_theta: float = 500000.0

    @property
    def head_dim(self):
        return self.dim // self.n_heads

config = LlamaConfig()
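The Llama 3.1 and 3.2 checkpoints also ship a `rope_scaling` entry (rope type `"llama3"`) in their `config.json`, which rescales the RoPE frequencies. Neither the `LlamaConfig` above nor the RoPE layer below models it, so results can drift slightly from the Hugging Face implementation. The sketch below follows the logic of the `"llama3"` RoPE initialization in `transformers`. The default arguments are what we believe the Llama 3.2 1B config uses (factor 32, low/high frequency factors 1 and 4, original context length 8192); please check them against the downloaded `config.json`. The function name `llama3_scale_inv_freq` is ours; one would apply it to `inv_freq` inside `LlamaRotaryEmbedding`.

```python
import math
import jax.numpy as jnp

def llama3_scale_inv_freq(inv_freq, factor=32.0, low_freq_factor=1.0,
                          high_freq_factor=4.0, original_max_position=8192):
    """Sketch of Llama 3-style RoPE frequency scaling (defaults assumed from config.json)."""
    low_freq_wavelen = original_max_position / low_freq_factor
    high_freq_wavelen = original_max_position / high_freq_factor
    wavelen = 2 * math.pi / inv_freq
    # Long wavelengths (low frequencies) are slowed down by `factor`
    scaled = jnp.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # Medium wavelengths interpolate smoothly between the scaled and unscaled values
    smooth = (original_max_position / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed = (1 - smooth) * scaled / factor + smooth * scaled
    is_medium = (wavelen >= high_freq_wavelen) & (wavelen <= low_freq_wavelen)
    return jnp.where(is_medium, smoothed, scaled)

head_dim, rope_theta = 64, 500000.0
inv_freq = 1.0 / (rope_theta ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
scaled = llama3_scale_inv_freq(inv_freq)
print(inv_freq[:3], scaled[:3])    # highest frequencies: unchanged
print(inv_freq[-3:], scaled[-3:])  # lowest frequencies: divided by 32
```

High-frequency components (short wavelengths) are left untouched, the lowest frequencies are divided by `factor`, and the band in between is interpolated.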

## Load the model weights

We'll use `snapshot_download` from the `huggingface_hub` library to download the model weights.

Meta requires [acceptance of the license](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/130) before you can access the files. You will also need a Hugging Face access token; see the [Hugging Face documentation](https://huggingface.co/docs/hub/en/security-tokens) to set one up.

In [4]:
model_id = "meta-llama/Llama-3.2-1B-instruct"
if os.path.exists('/kaggle'):
    weights_base_dir = '/kaggle/tmp'
elif os.path.exists('/content'):
    # Colab
    weights_base_dir = '/content'
else:
    # Local machine
    weights_base_dir = '.'

path_to_model_weights = os.path.join(weights_base_dir, model_id)

snapshot_download(repo_id=model_id, local_dir=path_to_model_weights)


'/kaggle/tmp/meta-llama/Llama-3.2-1B-instruct'

Then extract the model weights from the safetensors file and store them in the `weights` dict. These weights will be loaded into our JAX model soon.

In [5]:
def load_safetensors():
    weights = {}
    safetensors_files = Path(path_to_model_weights).glob('*.safetensors')

    for file in safetensors_files:
        with safe_open(file, framework="jax", device="cpu") as f:
            for key in f.keys():
                weights[key] = f.get_tensor(key)
    return weights

weights = load_safetensors()

Note that the weights are stored as `bfloat16`.
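As a sanity check on the configuration and the dtype, we can count the parameters by hand. This back-of-the-envelope sketch assumes tied input/output embeddings (no separate `lm_head` weight) and an MLP width of 8192; at 2 bytes per `bfloat16` value it lands at about 2.47 GB, in line with the size of `model.safetensors` in the download output above.

```python
# Back-of-the-envelope parameter count for Llama 3.2 1B
dim, n_layers, n_heads, n_kv_heads = 2048, 16, 32, 8
intermediate_size, vocab_size = 8192, 128256
head_dim = dim // n_heads

attention = dim * n_heads * head_dim          # q_proj
attention += 2 * dim * n_kv_heads * head_dim  # k_proj and v_proj (grouped-query attention)
attention += n_heads * head_dim * dim         # o_proj
mlp = 3 * dim * intermediate_size             # gate_proj, up_proj, down_proj
norms = 2 * dim                               # input and post-attention RMSNorm weights
per_layer = attention + mlp + norms

embedding = vocab_size * dim                  # shared with lm_head (tied weights)
total = n_layers * per_layer + embedding + dim  # + final RMSNorm
print(f"{total:,} parameters, {total * 2 / 1e9:.2f} GB in bfloat16")
# 1,235,814,400 parameters, 2.47 GB in bfloat16
```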

## Define the Flax model

Now we can define the model in Flax.

[This Transformer vs Llama diagram](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/_images/transformer_vs_llama.svg) from NVIDIA gives a good visual overview of the model architecture. We will define each layer as a [Flax `nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module).

We will start by defining the RMS normalization layer. Note how we load the parameters from the `weights` dict.

In [6]:
class LlamaRMSNorm(nnx.Module):

    def __init__(self, name=None, layer_idx=None, rngs=None):
        if name is None and layer_idx is None:
            # Final normalization layer
            self.norm_weights = nnx.Param(weights["model.norm.weight"], rngs=rngs)
        else:
            self.norm_weights = nnx.Param(weights[f"model.layers.{layer_idx}.{name}.weight"], rngs=rngs)

    def __call__(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.astype(jnp.float32)
        squared_mean = jnp.mean(jnp.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * jnp.reciprocal(jnp.sqrt(squared_mean + config.norm_eps))
        return self.norm_weights * hidden_states.astype(input_dtype)
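RMSNorm divides each hidden vector by its root-mean-square and multiplies the result by a learned per-channel weight $g$: $y = \frac{x}{\sqrt{\frac{1}{d}\sum_i x_i^2 + \epsilon}} \odot g$. Unlike LayerNorm, it subtracts no mean and adds no bias. A standalone version of the same computation (the function name `rms_norm` is ours) makes the effect easy to check: with unit weights, every output vector has an RMS of almost exactly 1.

```python
import jax
import jax.numpy as jnp

def rms_norm(x, weight, eps=1e-5):
    # Normalize in float32 for numerical stability, then cast back (as LlamaRMSNorm does)
    x32 = x.astype(jnp.float32)
    rms = jnp.sqrt(jnp.mean(jnp.square(x32), axis=-1, keepdims=True) + eps)
    return weight * (x32 / rms).astype(x.dtype)

x = 3.0 * jax.random.normal(jax.random.PRNGKey(0), (2, 5, 8))
y = rms_norm(x, jnp.ones(8))
# Root-mean-square of each output vector: all values are very close to 1
print(jnp.sqrt(jnp.mean(jnp.square(y), axis=-1)))
```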

Llama 3 uses [Rotary Position Embedding (RoPE)](https://arxiv.org/abs/2104.09864) to encode positional information: rather than adding a position embedding to the token embeddings, it rotates the query and key vectors by angles that depend on each token's position. For a gentle introduction to RoPE, see the [CMU lecture slides](https://www.cs.cmu.edu/~mgormley/courses/10423-s24//slides/lecture5-vit-ink.pdf) and this [EleutherAI blog post](https://blog.eleuther.ai/rotary-embeddings/).

In [7]:
class LlamaRotaryEmbedding(nnx.Module):

    def __init__(self, dim, base=10000, rngs=None):
        self.dim = dim
        self.base = base

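    def __call__(self, position_ids):
        # position_ids: (seq_len,)
        # One inverse frequency per pair of channels: (head_dim // 2,)
        inv_freq = 1.0 / (self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim))
        inv_freq_expanded = jnp.expand_dims(inv_freq, axis=(0, 1))  # (1, 1, head_dim // 2)
        position_ids_expanded = jnp.expand_dims(position_ids, axis=(0, 2)).astype(jnp.float32)  # (1, seq_len, 1)
        # Outer product of positions and frequencies: (1, seq_len, 1, head_dim // 2)
        freqs = jnp.einsum('bij,bjk->bijk', position_ids_expanded, inv_freq_expanded)
        # Duplicate so the angles line up with the two halves used by rotate_half: (1, seq_len, 1, head_dim)
        emb = jnp.concatenate([freqs, freqs], axis=-1)
        # Drop the singleton axis: (1, seq_len, head_dim)
        cos = jnp.cos(emb).squeeze(2).astype(jnp.bfloat16)
        sin = jnp.sin(emb).squeeze(2).astype(jnp.bfloat16)
        return cos, sin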
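A useful property to check: after RoPE, the dot product between a query at position m and a key at position n depends only on the offset m - n. The sketch below re-implements the same cos/sin and `rotate_half` logic as standalone functions (the names are ours, and they work on single vectors rather than batched heads) and compares two query/key pairs with the same offset.

```python
import jax
import jax.numpy as jnp

def rope_cos_sin(positions, head_dim, base=500000.0):
    # Same frequencies as LlamaRotaryEmbedding: one per pair of channels
    inv_freq = 1.0 / (base ** (jnp.arange(0, head_dim, 2, dtype=jnp.float32) / head_dim))
    freqs = positions[:, None].astype(jnp.float32) * inv_freq[None, :]  # (seq, head_dim // 2)
    emb = jnp.concatenate([freqs, freqs], axis=-1)                      # (seq, head_dim)
    return jnp.cos(emb), jnp.sin(emb)

def rotate_half(x):
    half = x.shape[-1] // 2
    return jnp.concatenate([-x[..., half:], x[..., :half]], axis=-1)

def apply_rope(x, cos, sin):
    # Rotates each channel pair (i, i + head_dim // 2) by a position-dependent angle
    return x * cos + rotate_half(x) * sin

head_dim = 64
key_q, key_k = jax.random.split(jax.random.PRNGKey(0))
q = jax.random.normal(key_q, (head_dim,))
k = jax.random.normal(key_k, (head_dim,))
cos, sin = rope_cos_sin(jnp.arange(32), head_dim)

def score(m, n):
    """Attention logit between a query at position m and a key at position n."""
    return jnp.dot(apply_rope(q, cos[m], sin[m]), apply_rope(k, cos[n], sin[n]))

# Both pairs are 2 positions apart, so the scores agree up to float32 rounding
print(score(3, 1), score(23, 21))
```

Because each channel pair is rotated rather than shifted, RoPE also leaves vector norms unchanged.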

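Now we create the attention layer. Note how we load the weights into the query, key, value, and output projections. The `.T` is needed because PyTorch's `nn.Linear` stores its weight as `(out_features, in_features)`, whereas the `nnx.Linear` kernel has shape `(in_features, out_features)`. Llama 3 also uses grouped-query attention: with 32 query heads and 8 key/value heads, each key/value head is shared by 4 query heads, which `repeat_kv` implements by repeating the key and value heads.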

In [8]:
class LlamaAttention(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.q_proj = nnx.Linear(config.dim, config.n_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.q_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.q_proj.weight"].T
        self.k_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.k_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.k_proj.weight"].T
        self.v_proj = nnx.Linear(config.dim, config.n_kv_heads * config.head_dim, use_bias=False, rngs=rngs)
        self.v_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.v_proj.weight"].T
        self.o_proj = nnx.Linear(config.n_heads * config.head_dim, config.dim, use_bias=False, rngs=rngs)
        self.o_proj.kernel.value = weights[f"model.layers.{layer_idx}.self_attn.o_proj.weight"].T
        self.rotary_emb = LlamaRotaryEmbedding(config.head_dim, base=config.rope_theta, rngs=rngs)

    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim=1):
        cos = jnp.expand_dims(cos, axis=unsqueeze_dim)
        sin = jnp.expand_dims(sin, axis=unsqueeze_dim)
        q_embed = (q * cos) + (self.rotate_half(q) * sin)
        k_embed = (k * cos) + (self.rotate_half(k) * sin)
        return q_embed, k_embed

    def rotate_half(self, x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return jnp.concatenate([-x2, x1], axis=-1)

    def repeat_kv(self, hidden_states, n_repeat):
        batch, n_kv_heads, seq_len, head_dim = hidden_states.shape
        if n_repeat == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].repeat(n_repeat, axis=2)
        return hidden_states.reshape(batch, n_kv_heads * n_repeat, seq_len, head_dim)

    def __call__(self, x, position_ids):
        batch_size, seq_len, _ = x.shape
        query = self.q_proj(x).reshape(batch_size, seq_len, config.n_heads, config.head_dim).transpose((0, 2, 1, 3))
        key = self.k_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
        value = self.v_proj(x).reshape(batch_size, seq_len, config.n_kv_heads, config.head_dim).transpose((0, 2, 1, 3))
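        # Assuming batch_size=1
        cos, sin = self.rotary_emb(position_ids[0])
        query, key = self.apply_rotary_pos_emb(query, key, cos, sin)

        # Grouped-query attention: share each KV head across n_heads // n_kv_heads query heads
        key = self.repeat_kv(key, config.n_heads // config.n_kv_heads)
        value = self.repeat_kv(value, config.n_heads // config.n_kv_heads)

        # Scaled dot-product scores, computed in float32: (batch, n_heads, seq_len, seq_len)
        attn_weights = jnp.matmul(query, jnp.transpose(key, (0, 1, 3, 2))).astype(jnp.float32)
        attn_weights = attn_weights / jnp.sqrt(config.head_dim)
        # Causal mask: position i may only attend to positions <= i
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        attn_weights = jnp.where(causal_mask, attn_weights, jnp.finfo(jnp.float32).min)
        attn_weights = jax.nn.softmax(attn_weights, axis=-1).astype(jnp.bfloat16)
        attn_output = jnp.matmul(attn_weights, value).transpose((0, 2, 1, 3)).reshape(batch_size, seq_len, -1)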
        output = self.o_proj(attn_output)
        return output
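Two ideas behind causal decoder attention are easy to test in isolation: grouped-query attention, where each key/value head serves several query heads, and causal masking, where position i may only attend to positions up to and including i. Here is a minimal standalone sketch (the function name `causal_gqa_attention` is ours) that uses `jnp.repeat` to get the same head layout as `repeat_kv`, with toy sizes.

```python
import jax
import jax.numpy as jnp

def causal_gqa_attention(q, k, v):
    """q: (batch, n_heads, seq, head_dim); k, v: (batch, n_kv_heads, seq, head_dim)."""
    n_rep = q.shape[1] // k.shape[1]
    # Query head h uses KV head h // n_rep, the same layout repeat_kv produces
    k = jnp.repeat(k, n_rep, axis=1)
    v = jnp.repeat(v, n_rep, axis=1)
    seq_len, head_dim = q.shape[2], q.shape[3]
    scores = (q @ jnp.swapaxes(k, -1, -2)) / jnp.sqrt(head_dim)
    # Causal mask: entries above the diagonal (future positions) get a huge negative score
    mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
    scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    return jax.nn.softmax(scores, axis=-1) @ v

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (1, 8, 5, 16))  # 8 query heads
k = jax.random.normal(key_k, (1, 2, 5, 16))  # 2 KV heads, each shared by 4 query heads
v = jax.random.normal(key_v, (1, 2, 5, 16))
out = causal_gqa_attention(q, k, v)

# Changing the last token's key/value leaves every earlier position's output unchanged
out_changed = causal_gqa_attention(q, k.at[:, :, -1].set(0.0), v.at[:, :, -1].set(100.0))
print(out.shape, jnp.allclose(out[:, :, :-1], out_changed[:, :, :-1]))
```

Without the mask, earlier positions would see later tokens, which a causal language model never does during training.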

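The MLP layer follows the attention layer. It is a gated (SwiGLU) feed-forward block that computes `down_proj(silu(gate_proj(x)) * up_proj(x))`. As in the attention layer, we load the transposed weights into the gate, up, and down projection layers.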

In [9]:
class LlamaMLP(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.gate_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs)
        self.gate_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.gate_proj.weight"].T
        self.up_proj = nnx.Linear(config.dim, config.intermediate_size, use_bias=False, rngs=rngs)
        self.up_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.up_proj.weight"].T
        self.down_proj = nnx.Linear(config.intermediate_size, config.dim, use_bias=False, rngs=rngs)
        self.down_proj.kernel.value = weights[f"model.layers.{layer_idx}.mlp.down_proj.weight"].T

    def __call__(self, x):
        return self.down_proj(jax.nn.silu(self.gate_proj(x)) * self.up_proj(x))
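In matrix form the MLP computes $\left(\operatorname{silu}(x W_\text{gate}) \odot x W_\text{up}\right) W_\text{down}$, where $\operatorname{silu}(z) = z \cdot \sigma(z)$. A toy-sized standalone sketch with plain matrices (the function and variable names are ours) shows how the shapes flow: the input is expanded to the intermediate size and projected back.

```python
import jax
import jax.numpy as jnp

def swiglu_mlp(x, w_gate, w_up, w_down):
    # Kernels use the (in_features, out_features) layout, as in nnx.Linear
    return (jax.nn.silu(x @ w_gate) * (x @ w_up)) @ w_down

dim, hidden = 8, 32  # toy sizes; the 1B model uses 2048 and 8192
keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (1, 3, dim))
w_gate = jax.random.normal(keys[1], (dim, hidden))
w_up = jax.random.normal(keys[2], (dim, hidden))
w_down = jax.random.normal(keys[3], (hidden, dim))

y = swiglu_mlp(x, w_gate, w_up, w_down)
print(y.shape)  # (1, 3, 8)
```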

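Next, we assemble the decoder block. It uses a pre-norm residual layout: RMSNorm is applied to the input of the attention and MLP sublayers, and each sublayer's output is added back to its input.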

In [10]:
class LlamaTransformerBlock(nnx.Module):

    def __init__(self, layer_idx, rngs=None):
        self.input_layernorm = LlamaRMSNorm(name="input_layernorm", layer_idx=layer_idx, rngs=rngs)
        self.attention = LlamaAttention(layer_idx=layer_idx, rngs=rngs)
        self.post_attention_layernorm = LlamaRMSNorm(name="post_attention_layernorm", layer_idx=layer_idx, rngs=rngs)
        self.mlp = LlamaMLP(layer_idx=layer_idx, rngs=rngs)

    def __call__(self, x, position_ids):
        residual = x
        x = self.input_layernorm(x)
        x = self.attention(x, position_ids)
        x = residual + x

        residual = x
        x = self.post_attention_layernorm(x)
        x = self.mlp(x)
        x = residual + x
        return x

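Finally, we define the entire model. Llama 3.2 1B ties its input embedding and output projection weights, so the `lm_head` kernel is simply the transpose of `model.embed_tokens.weight`.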

In [11]:
class LlamaForCausalLM(nnx.Module):

    def __init__(self, rngs=None):
        self.token_embed = nnx.Embed(num_embeddings=config.vocab_size, features=config.dim, dtype=jnp.bfloat16, rngs=rngs)
        self.token_embed.embedding.value = weights["model.embed_tokens.weight"]

        self.layers = [LlamaTransformerBlock(layer_idx=idx, rngs=rngs) for idx in range(config.n_layers)]
        self.lm_head = nnx.Linear(config.dim, config.vocab_size, use_bias=False, rngs=rngs)
        self.lm_head.kernel.value = weights["model.embed_tokens.weight"].T
        self.norm = LlamaRMSNorm(name=None, layer_idx=None, rngs=rngs)

    def __call__(self, input_ids, position_ids):
        assert input_ids.shape[0] == 1, "Only batch size 1 is supported"
        x = self.token_embed(input_ids)
        for layer in self.layers:
            x = layer(x, position_ids)
        x = self.norm(x)
        logits = self.lm_head(x)
        return logits

## Run the Flax model

Let's take it for a spin! We are still going to use the tokenizer from Hugging Face (since our primary focus is re-building the model instead of the tokenizer).

In [12]:
model = LlamaForCausalLM(rngs=nnx.Rngs(0))

tokenizer = AutoTokenizer.from_pretrained(model_id)


The instruction-tuned Llama model expects prompts formatted with a [chat template](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct/discussions/14), so we need to follow the template. Here we build the prompt by hand for demonstration purposes; in practice, you should use the tokenizer's `apply_chat_template` method from the `transformers` library.
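The hand-written prompt in the next cell follows a simple pattern, sketched here with a hypothetical helper `build_llama3_prompt` (our name, not a library function). The tokenizer prepends `<|begin_of_text|>` on its own (id 128000 in the printed ids below). The official template applied by `tokenizer.apply_chat_template(messages, add_generation_prompt=True)` may add further text, such as date lines in the system message, so its output is not guaranteed to match this string exactly.

```python
def build_llama3_prompt(messages):
    """Hypothetical helper: formats chat messages like the hand-written prompt below."""
    prompt = ""
    for message in messages:
        prompt += (
            f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n"
            f"{message['content']}<|eot_id|>"
        )
    # Leave an open assistant header so the model writes the reply
    prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"
    return prompt

messages = [
    {"role": "system", "content": "You are a helpful assistant"},
    {"role": "user", "content": "How do you make a pancake?"},
]
prompt = build_llama3_prompt(messages)
print(prompt)
```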

In [13]:
input_text = """<|start_header_id|>system<|end_header_id|>

You are a helpful assistant<|eot_id|><|start_header_id|>user<|end_header_id|>

How do you make a pancake?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

input_ids = tokenizer(input_text, return_tensors="jax")["input_ids"]
position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
print(input_ids)

for _ in range(200):
    logits = model(input_ids, position_ids)
    next_token = jnp.argmax(logits[:, -1, :], axis=-1)
    input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)
    position_ids = jnp.asarray([jnp.arange(input_ids.shape[1])])
    print(f"Generated token: {next_token[0]}")

print(tokenizer.decode(input_ids[0]))

TensorFlow and JAX classes are deprecated and will be removed in Transformers v5. We recommend migrating to PyTorch classes or pinning your version of Transformers.


[[128000 128006   9125 128007    271   2675    527    264  11190  18328
  128009 128006    882 128007    271   4438    656    499   1304    264
   54574    731     30 128009 128006  78191 128007    271]]
Generated token: 43346
Generated token: 264
Generated token: 54574
Generated token: 731
Generated token: 374
Generated token: 264
Generated token: 4382
Generated token: 1920
Generated token: 430
Generated token: 7612
Generated token: 1120
Generated token: 264
Generated token: 2478
Generated token: 14293
Generated token: 323
Generated token: 1063
Generated token: 6913
Generated token: 17677
Generated token: 7512
Generated token: 13
Generated token: 5810
Generated token: 596
Generated token: 264
Generated token: 3094
Generated token: 14656
Generated token: 30308
Generated token: 8641
Generated token: 1473
Generated token: 334
Generated token: 46847
Generated token: 25
Generated token: 57277
Generated token: 9
Generated token: 220
Generated token: 16
Generated token: 10747
Generated token

There you have it. We have successfully converted the Hugging Face model weights from the safetensors file, loaded them up in our JAX model, and run the model.
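The loop above always generates 200 tokens, even after the model has finished its answer. A small variant stops as soon as it produces a stop token such as `<|eot_id|>`. This is a sketch: `greedy_generate` and `toy_forward` are our own names, and the toy function only stands in for the model so the example runs on its own.

```python
import jax.numpy as jnp

def greedy_generate(forward, input_ids, max_new_tokens, stop_token_ids):
    """Greedy decoding that stops as soon as a stop token is generated.

    `forward(input_ids, position_ids)` must return logits of shape
    (1, seq_len, vocab_size), like the `model` defined above.
    """
    for _ in range(max_new_tokens):
        position_ids = jnp.arange(input_ids.shape[1])[None, :]
        logits = forward(input_ids, position_ids)
        next_token = jnp.argmax(logits[:, -1, :], axis=-1)
        input_ids = jnp.concatenate([input_ids, next_token[:, None]], axis=1)
        if int(next_token[0]) in stop_token_ids:
            break
    return input_ids

# Toy stand-in for the model: always predicts the current sequence length as the next token
def toy_forward(input_ids, position_ids):
    seq_len = input_ids.shape[1]
    return jnp.zeros((1, seq_len, 10)).at[:, -1, seq_len % 10].set(1.0)

out = greedy_generate(toy_forward, jnp.array([[0, 1, 2]]), 100, {5})
print(out)  # [[0 1 2 3 4 5]]

# With the real model and the tokenized prompt (`prompt_ids` is a placeholder name):
# eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
# output_ids = greedy_generate(model, prompt_ids, 200, {eot_id})
```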

You might have noticed that execution is quite slow. For simplicity, we left out common optimizations such as JIT compilation, a KV cache, and SPMD sharding; without a KV cache, for example, every generation step re-runs the model over the entire sequence. Feel free to implement them as an exercise.
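If you try adding JIT compilation, keep in mind that `jax.jit` traces and compiles a function once per distinct input shape. In the generation loop, `input_ids` grows by one token every step, so a jitted forward pass would be recompiled at every step. The toy sketch below counts traces with a Python side effect, which runs only while JAX traces the function. A common workaround is to keep inputs in a fixed-size, padded buffer (and mask out the padding), which also fits naturally with a KV cache.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing, not on cached calls
    return x * 2

# A growing sequence, like input_ids in the generation loop, has a new shape every step
for seq_len in range(1, 4):
    double(jnp.ones((1, seq_len)))
print(trace_count)  # 3: one trace (and compilation) per new shape

# A fixed-size buffer keeps the shape constant, so the cached compilation is reused
for _ in range(3):
    double(jnp.ones((1, 8)))
print(trace_count)  # 4: only the first call with shape (1, 8) traced again
```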