# FrawdLLM Inference on TPU (JAX)

This notebook implements transformer inference from scratch using JAX, running on Google TPU.

**Setup:** Runtime > Change runtime type > TPU

In [None]:
# Install dependencies
!pip install huggingface_hub safetensors tokenizers

In [None]:
# Verify TPU is available: show which accelerators JAX can see, and how many.
import jax

visible_devices = jax.devices()
print(f"Devices: {visible_devices}")
print(f"Device count: {jax.device_count()}")

## Load Weights from HuggingFace

In [None]:
from huggingface_hub import hf_hub_download

# Fetch the checkpoint and the tokenizer definition from the HuggingFace Hub.
# Both files live in the same repository, so the repo id is named once.
HF_REPO_ID = "tsingla98/frawdllm-100m"

weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=HF_REPO_ID, filename="tokenizer.json")

# Show where each file was placed.
print(f"Weights: {weights_path}")
print(f"Tokenizer: {tokenizer_path}")

In [None]:
from safetensors import safe_open
import jax.numpy as jnp

# Read every tensor in the safetensors file into a JAX array, keyed by its
# parameter name, and log each name with its shape as it is loaded.
weights = {}
with safe_open(weights_path, framework="numpy") as tensor_file:
    for param_name in tensor_file.keys():
        tensor = jnp.array(tensor_file.get_tensor(param_name))
        weights[param_name] = tensor
        print(f"{param_name}: {tensor.shape}")

In [None]:
from tokenizers import Tokenizer

# Build the tokenizer from the downloaded tokenizer.json. The module-level
# `tokenizer` is what the helpers below use to encode prompts and decode output ids.
tokenizer = Tokenizer.from_file(tokenizer_path)

## Model Constants

In [None]:
import jax.numpy as jnp
import jax
import math

# --- Checkpoint parameter names ---------------------------------------------
# Names containing "{}" are templates: fill in the block index with
# .format(layer_num) before looking the tensor up.

# Token embedding table; the LM head reuses the same tensor (tied weights).
EMBEDDINGS_WEIGHT_KEY = "model.embeddings.token_emb.weight"
LM_HEAD_WEIGHT_KEY = EMBEDDINGS_WEIGHT_KEY  # tied weights

# Per-block LayerNorms. ln1 is applied before attention. Despite the PRE_ATTN
# suffix in their names, the ln2 keys are the ones read by process_mlp.
LN1_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.weight"
LN1_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.bias"
LN2_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.weight"
LN2_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.bias"

# Final LayerNorm applied before the LM head (no layer index).
LNF_WEIGHT_FORMAT_KEY = "model.ln_f.weight"
LNF_BIAS_FORMAT_KEY = "model.ln_f.bias"

# Attention: fused QKV projection and output projection.
QKV_WEIGHTS_KEY = "model.blocks.{}.attn.qkv_proj.weight"
QKV_BIAS_KEY = "model.blocks.{}.attn.qkv_proj.bias"
OUTPUT_PROJ_WEIGHT_KEY = "model.blocks.{}.attn.out_proj.weight"
OUTPUT_PROJ_BIAS_KEY = "model.blocks.{}.attn.out_proj.bias"

# MLP: two linear layers.
FC1_WEIGHT_KEY = "model.blocks.{}.mlp.fc1.weight"
FC1_BIAS_KEY = "model.blocks.{}.mlp.fc1.bias"
FC2_WEIGHT_KEY = "model.blocks.{}.mlp.fc2.weight"
FC2_BIAS_KEY = "model.blocks.{}.mlp.fc2.bias"

# --- Architecture -----------------------------------------------------------
N_LAYERS = 12
N_HEADS = 12
HEAD_DIM = 64
EMBEDDINGS_DIM = N_HEADS * HEAD_DIM
EPSILON = 1e-5  # LayerNorm epsilon
ROPE_THETA = 10_000.0  # base used when computing rotary position angles

# --- Generation settings ----------------------------------------------------
STOP_TOKEN_ID = 3  # generation stops as soon as this id is sampled
TEMPERATURE = 0.5
TOP_P = 0.9
MAX_OUTPUT_TOKENS = 300
MAX_CONTEXT_LENGTH = 4096  # also the number of positions in each KV cache

## Inference Implementation (JAX)

In [None]:
import functools

def format_prompt(user_message: str) -> str:
    """Wrap a user message in the chat template: BOS, user turn, then the assistant tag."""
    template = "<|bos|><|user|>{}<|assistant|>"
    return template.format(user_message)


def get_weights_tensor(key: str) -> jnp.ndarray:
    """Look up a checkpoint tensor by its full parameter name in the module-level ``weights`` dict.

    Raises:
        KeyError: if ``key`` is not present in ``weights``.
    """
    return weights[key]


def get_tokens_for_prompt(prompt: str) -> jnp.ndarray:
    """Encode ``prompt`` with the module-level tokenizer and return the ids as a 1-D int32 array.

    Special tokens are not added by the tokenizer (``add_special_tokens=False``);
    any markers must already be present in ``prompt``.
    """
    token_ids = tokenizer.encode(prompt, add_special_tokens=False).ids
    # Pin the dtype: without it an empty id list becomes a float array, which
    # cannot be used to index the embedding table.
    return jnp.array(token_ids, dtype=jnp.int32)


def get_embeddings_for_selection(selection: jnp.ndarray) -> jnp.ndarray:
    """Gather the rows of the token-embedding table indexed by ``selection``."""
    embedding_table = get_weights_tensor(EMBEDDINGS_WEIGHT_KEY)
    return embedding_table[selection]


# Precompute the RoPE rotation angles for every position up to the max context length
def _compute_rope_freqs(max_len: int) -> jnp.ndarray:
    """Return the table of rotary angles, one row per position and one column per dimension pair.

    Entry ``[p, i]`` is ``p / ROPE_THETA ** (2 * i / HEAD_DIM)``.
    """
    pair_offsets = jnp.arange(0, HEAD_DIM, 2).astype(jnp.float32)
    inv_freq = 1.0 / (ROPE_THETA ** (pair_offsets / HEAD_DIM))
    pos = jnp.arange(max_len).astype(jnp.float32)
    # Outer product: every position multiplied by every inverse frequency.
    return pos[:, None] * inv_freq[None, :]


# Despite the name, this holds angles (position * frequency), not bare frequencies.
ROPE_FREQS = _compute_rope_freqs(MAX_CONTEXT_LENGTH)


def apply_rope(x: jnp.ndarray, angles: jnp.ndarray) -> jnp.ndarray:
    """Rotate each (even, odd) pair along the last axis of ``x`` by the matching angle.

    Args:
        x: Array whose last axis is read as interleaved pairs
            ``(x[..., 0], x[..., 1]), (x[..., 2], x[..., 3]), ...``.
        angles: Rotation angles; must broadcast against ``x[..., 0::2]``.

    Returns:
        Array with the same shape as ``x`` holding the rotated pairs, re-interleaved.
    """
    cos_theta, sin_theta = jnp.cos(angles), jnp.sin(angles)
    first, second = x[..., 0::2], x[..., 1::2]
    # Standard 2-D rotation applied to every pair; stacking on a new trailing
    # axis and reshaping puts the results back in even/odd interleaved order.
    rotated_pairs = jnp.stack(
        (
            first * cos_theta - second * sin_theta,
            first * sin_theta + second * cos_theta,
        ),
        axis=-1,
    )
    return rotated_pairs.reshape(x.shape)


# === PREFILL: Process full prompt (first pass) ===
@functools.partial(jax.jit, static_argnums=(0,))
def prefill_attention(
    layer_num: int, x: jnp.ndarray, k_cache: jnp.ndarray, v_cache: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int]:
    """Run pre-norm causal self-attention over the whole prompt and fill the KV cache.

    Args:
        layer_num: Index of the transformer block. Static under jit, so each
            layer is compiled separately.
        x: Activations for every prompt position; the first axis is the
            sequence and the last axis is reshaped into (N_HEADS, HEAD_DIM).
        k_cache: Key cache for this layer, indexed [position, head, head_dim].
            Rows ``[:seq_len]`` are overwritten.
        v_cache: Value cache for this layer, same layout as ``k_cache``.

    Returns:
        ``(out, k_cache, v_cache, seq_len)``. ``out`` is the attention output
        plus the residual ``x``. ``seq_len`` is the number of positions now held
        in the cache; because it is returned through jit, callers receive it as
        a JAX scalar rather than a Python int.
    """
    seq_len = x.shape[0]

    # Pre-attention LayerNorm: (x - mean) / sqrt(var + eps), with eps inside
    # the square root. Dividing by (std + eps) gives slightly different values.
    # NOTE(review): assumes the checkpoint's ln1 follows torch.nn.LayerNorm
    # semantics - confirm against the training code.
    gamma = get_weights_tensor(LN1_WEIGHT_FORMAT_PRE_ATTN_KEY.format(layer_num))
    beta = get_weights_tensor(LN1_BIAS_FORMAT_PRE_ATTN_KEY.format(layer_num))
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + EPSILON)
    normed = normed * gamma + beta

    # Fused QKV projection; weights are stored [out_features, in_features].
    w_qkv = get_weights_tensor(QKV_WEIGHTS_KEY.format(layer_num))
    qkv_bias = get_weights_tensor(QKV_BIAS_KEY.format(layer_num))
    qkv = normed @ w_qkv.T + qkv_bias

    # Split into q/k/v and move the head axis first: [N_HEADS, seq_len, HEAD_DIM].
    q, k, v = jnp.split(qkv, 3, axis=-1)
    q_batched = q.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    k_batched = k.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v_batched = v.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)

    # Use precomputed RoPE angles for positions 0..seq_len-1.
    angles = ROPE_FREQS[:seq_len]
    q_batched = apply_rope(q_batched, angles)
    k_batched = apply_rope(k_batched, angles)

    # Store the (already rotated) keys and the values in the cache, position-major.
    k_cache = k_cache.at[:seq_len, :, :].set(k_batched.swapaxes(0, 1))
    v_cache = v_cache.at[:seq_len, :, :].set(v_batched.swapaxes(0, 1))

    # Scaled dot-product attention with a causal mask. Scores are scaled first
    # and masked second, matching decode_attention; the large negative value
    # drives masked positions to zero weight after softmax.
    scores = q_batched @ k_batched.swapaxes(-2, -1)
    scores = scores / math.sqrt(HEAD_DIM)
    causal_mask = jnp.triu(jnp.ones([seq_len, seq_len]), k=1) * -1e9
    scores = scores + causal_mask
    attn_weights = jax.nn.softmax(scores, axis=-1)

    # Merge heads back into a single feature axis.
    out = attn_weights @ v_batched
    out = out.swapaxes(0, 1).reshape(seq_len, EMBEDDINGS_DIM)

    # Output projection plus the residual connection.
    output_proj_weights = get_weights_tensor(OUTPUT_PROJ_WEIGHT_KEY.format(layer_num))
    output_proj_bias = get_weights_tensor(OUTPUT_PROJ_BIAS_KEY.format(layer_num))
    out = out @ output_proj_weights.T + output_proj_bias + x

    return out, k_cache, v_cache, seq_len


# === DECODE: Process single token (subsequent passes) ===
# Note: cache_len is NOT static - we use masking instead of dynamic slicing
@functools.partial(jax.jit, static_argnums=(0,))
def decode_attention(
    layer_num: int, x: jnp.ndarray, k_cache: jnp.ndarray, v_cache: jnp.ndarray, cache_len: jnp.ndarray
) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Run pre-norm self-attention for one new token against the KV cache.

    Args:
        layer_num: Index of the transformer block (static under jit).
        x: Activations for the single new token, shape [1, EMBEDDINGS_DIM].
        k_cache: Key cache for this layer, [MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM].
        v_cache: Value cache for this layer, same layout as ``k_cache``.
        cache_len: Traced scalar giving the number of positions already in the
            cache; this is also the position index of the new token.

    Returns:
        ``(out, k_cache, v_cache, new_seq_len)``. ``out`` is the attention
        output plus the residual ``x``, the caches contain the new token's
        key/value at row ``cache_len``, and ``new_seq_len`` is ``cache_len + 1``.
    """
    # x is [1, 768] - single token
    # cache_len is a scalar JAX array (traced)
    new_seq_len = cache_len + 1

    # Pre-attention LayerNorm: (x - mean) / sqrt(var + eps), with eps inside
    # the square root. Dividing by (std + eps) gives slightly different values.
    # NOTE(review): assumes the checkpoint's ln1 follows torch.nn.LayerNorm
    # semantics - confirm against the training code.
    gamma = get_weights_tensor(LN1_WEIGHT_FORMAT_PRE_ATTN_KEY.format(layer_num))
    beta = get_weights_tensor(LN1_BIAS_FORMAT_PRE_ATTN_KEY.format(layer_num))
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + EPSILON)
    normed = normed * gamma + beta

    # Fused QKV projection; weights are stored [out_features, in_features].
    w_qkv = get_weights_tensor(QKV_WEIGHTS_KEY.format(layer_num))
    qkv_bias = get_weights_tensor(QKV_BIAS_KEY.format(layer_num))
    qkv = normed @ w_qkv.T + qkv_bias

    q, k, v = jnp.split(qkv, 3, axis=-1)
    q_batched = q.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)  # [N_HEADS, 1, HEAD_DIM]
    k_batched = k.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v_batched = v.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)

    # RoPE for position cache_len (the new token's position) - use precomputed, slice dynamically
    angles = jax.lax.dynamic_slice(ROPE_FREQS, (cache_len, 0), (1, HEAD_DIM // 2))
    q_batched = apply_rope(q_batched, angles)
    k_batched = apply_rope(k_batched, angles)

    # Append the new key/value at row cache_len of the cache.
    k_cache = k_cache.at[cache_len, :, :].set(k_batched.swapaxes(0, 1)[0])
    v_cache = v_cache.at[cache_len, :, :].set(v_batched.swapaxes(0, 1)[0])

    # Attend over the FULL fixed-size cache so shapes stay static under jit;
    # rows at or beyond new_seq_len are masked out below.
    k_full = k_cache.swapaxes(0, 1)  # [N_HEADS, MAX_CONTEXT_LENGTH, HEAD_DIM]
    v_full = v_cache.swapaxes(0, 1)  # [N_HEADS, MAX_CONTEXT_LENGTH, HEAD_DIM]

    scores = q_batched @ k_full.swapaxes(-2, -1)  # [N_HEADS, 1, MAX_CONTEXT_LENGTH]
    scores = scores / math.sqrt(HEAD_DIM)

    # Mask out positions >= new_seq_len (unwritten cache rows).
    positions = jnp.arange(MAX_CONTEXT_LENGTH)
    mask = jnp.where(positions < new_seq_len, 0.0, -1e9)  # [MAX_CONTEXT_LENGTH]
    scores = scores + mask  # broadcast to [N_HEADS, 1, MAX_CONTEXT_LENGTH]

    attn_weights = jax.nn.softmax(scores, axis=-1)

    out = attn_weights @ v_full  # [N_HEADS, 1, HEAD_DIM]
    out = out.swapaxes(0, 1).reshape(1, EMBEDDINGS_DIM)

    # Output projection plus the residual connection.
    output_proj_weights = get_weights_tensor(OUTPUT_PROJ_WEIGHT_KEY.format(layer_num))
    output_proj_bias = get_weights_tensor(OUTPUT_PROJ_BIAS_KEY.format(layer_num))
    out = out @ output_proj_weights.T + output_proj_bias + x

    return out, k_cache, v_cache, new_seq_len


@functools.partial(jax.jit, static_argnums=(0,))
def process_mlp(layer_num: int, x: jnp.ndarray) -> jnp.ndarray:
    """Apply the block's pre-norm MLP (LayerNorm -> fc1 -> GELU -> fc2) and add the residual.

    Args:
        layer_num: Index of the transformer block (static under jit).
        x: Activations; normalisation runs over the last axis.

    Returns:
        ``mlp(layer_norm(x)) + x``, with the same shape as ``x``.
    """
    # LayerNorm: (x - mean) / sqrt(var + eps), with eps inside the square root.
    # Dividing by (std + eps) gives slightly different values.
    # NOTE(review): assumes the checkpoint's ln2 follows torch.nn.LayerNorm
    # semantics - confirm against the training code.
    gamma = get_weights_tensor(LN2_WEIGHT_FORMAT_PRE_ATTN_KEY.format(layer_num))
    beta = get_weights_tensor(LN2_BIAS_FORMAT_PRE_ATTN_KEY.format(layer_num))
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + EPSILON)
    normed = normed * gamma + beta

    fc1_weights = get_weights_tensor(FC1_WEIGHT_KEY.format(layer_num))
    fc1_bias = get_weights_tensor(FC1_BIAS_KEY.format(layer_num))
    fc2_weights = get_weights_tensor(FC2_WEIGHT_KEY.format(layer_num))
    fc2_bias = get_weights_tensor(FC2_BIAS_KEY.format(layer_num))

    # Linear weights are stored [out_features, in_features], hence the transposes.
    hidden = normed @ fc1_weights.T + fc1_bias
    # NOTE(review): jax.nn.gelu defaults to the tanh approximation, while
    # torch.nn.GELU defaults to the exact erf form - confirm which one the
    # checkpoint was trained with.
    hidden = jax.nn.gelu(hidden)
    projected = hidden @ fc2_weights.T + fc2_bias
    return projected + x


@jax.jit
def process_final_layer(x: jnp.ndarray) -> jnp.ndarray:
    """Apply the final LayerNorm and project onto the vocabulary with the tied LM head.

    Args:
        x: Activations; normalisation runs over the last axis.

    Returns:
        Logits, one row per row of ``x`` and one column per row of the
        embedding table.
    """
    # LayerNorm: (x - mean) / sqrt(var + eps), with eps inside the square root.
    # Dividing by (std + eps) gives slightly different values.
    # NOTE(review): assumes ln_f follows torch.nn.LayerNorm semantics -
    # confirm against the training code.
    gamma = get_weights_tensor(LNF_WEIGHT_FORMAT_KEY)
    beta = get_weights_tensor(LNF_BIAS_FORMAT_KEY)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    normed = (x - mean) / jnp.sqrt(var + EPSILON)
    normed = normed * gamma + beta

    # The LM head shares the token-embedding matrix, so project with its transpose.
    lm_head = get_weights_tensor(LM_HEAD_WEIGHT_KEY)
    return normed @ lm_head.T


def _sample_next_token(last_logits: jnp.ndarray, key: jnp.ndarray, top_k: int = 100) -> jnp.ndarray:
    """Sample one token id from ``last_logits`` using temperature, top-k and top-p filtering.

    Args:
        last_logits: 1-D logits over the vocabulary for the final position.
        key: PRNG key consumed by this draw.
        top_k: Number of highest-probability candidates considered before the
            top-p cut; clamped to the vocabulary size.

    Returns:
        Scalar array holding the sampled token id.
    """
    probs = jax.nn.softmax(last_logits / TEMPERATURE, axis=-1)

    # jax.lax.top_k returns values in descending order, so no extra sort is needed.
    num_candidates = min(top_k, probs.shape[-1])
    top_probs, top_indices = jax.lax.top_k(probs, k=num_candidates)

    # Keep candidates whose running mass is still within TOP_P; the most
    # likely candidate is always kept so the set is never empty.
    cum_probs = jnp.cumsum(top_probs, axis=0)
    keep = (cum_probs <= TOP_P).at[0].set(True)

    # Removed candidates get a -inf logit so they can never be drawn. Adding a
    # small constant before the log would leave them a non-zero probability.
    masked_logits = jnp.where(keep, jnp.log(top_probs), -jnp.inf)
    choice = jax.random.categorical(key, masked_logits)
    return top_indices[choice]


def main(prompt: str = "Hello, World!", seed: int = 42) -> None:
    """Generate a completion for ``prompt`` and print the decoded text.

    The prompt is wrapped in the chat template and processed in one prefill
    pass that fills the per-layer KV caches. Tokens are then generated one at a
    time until STOP_TOKEN_ID is sampled or the budget of
    ``min(MAX_OUTPUT_TOKENS, MAX_CONTEXT_LENGTH - prompt_len)`` is used up. If
    the prompt already fills the context, nothing is generated and an empty
    completion is printed.

    Args:
        prompt: User message to complete.
        seed: Seed for the sampling PRNG.
    """
    prompt_tokens = get_tokens_for_prompt(format_prompt(prompt))
    prompt_len = len(prompt_tokens)

    k_caches = [jnp.zeros((MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM)) for _ in range(N_LAYERS)]
    v_caches = [jnp.zeros((MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM)) for _ in range(N_LAYERS)]
    # Use JAX arrays for cache_lens so they can be traced
    cache_lens = [jnp.array(0, dtype=jnp.int32) for _ in range(N_LAYERS)]

    key = jax.random.PRNGKey(seed)
    generated: list[int] = []

    # The first step feeds the whole prompt (prefill); later steps feed only
    # the token that was just sampled (decode).
    step_tokens = prompt_tokens
    is_prefill = True

    for _ in range(min(MAX_OUTPUT_TOKENS, MAX_CONTEXT_LENGTH - prompt_len)):
        x = get_embeddings_for_selection(step_tokens)
        for layer_num in range(N_LAYERS):
            if is_prefill:
                x, k_caches[layer_num], v_caches[layer_num], filled = prefill_attention(
                    layer_num, x, k_caches[layer_num], v_caches[layer_num]
                )
                cache_lens[layer_num] = jnp.array(filled, dtype=jnp.int32)
            else:
                x, k_caches[layer_num], v_caches[layer_num], cache_lens[layer_num] = decode_attention(
                    layer_num, x, k_caches[layer_num], v_caches[layer_num], cache_lens[layer_num]
                )
            x = process_mlp(layer_num, x)
        is_prefill = False

        # Only the last position is needed to predict the next token, so skip
        # projecting every prompt position onto the vocabulary during prefill.
        logits = process_final_layer(x[-1:])

        key, subkey = jax.random.split(key)
        # int() pulls the id to the host once; it is reused for the stop check,
        # the output list and the next decode input.
        next_token = int(_sample_next_token(logits[-1], subkey))

        if next_token == STOP_TOKEN_ID:
            break
        generated.append(next_token)
        step_tokens = jnp.array([next_token], dtype=jnp.int32)

    output_words = tokenizer.decode(generated)
    print(output_words)


# Run it! Generates a completion for the built-in demo prompt and prints the decoded text.
main()