# FrawdLLM Inference on TPU (JAX)

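This notebook implements inference for FrawdLLM, a 12-layer decoder-only transformer, from scratch in JAX and runs it on a Google TPU. The weights and tokenizer are downloaded from the Hugging Face Hub, and the entire generation loop (padded prefill plus KV-cached decode) is compiled into a single XLA program.

**Setup (Colab):** Runtime → Change runtime type → TPU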

In [None]:
# Install dependencies
!pip install huggingface_hub safetensors tokenizers

In [None]:
# Verify TPU is available
import warnings

import jax

print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

# Warn (rather than fail) on a CPU/GPU runtime: the notebook still works there,
# just much more slowly, and the user may not notice the wrong runtime type.
if jax.default_backend() != "tpu":
    warnings.warn(
        f"No TPU detected (backend={jax.default_backend()!r}); inference will be slow. "
        "In Colab use Runtime > Change runtime type > TPU."
    )

## Load Weights and Tokenizer from Hugging Face

In [None]:
from huggingface_hub import hf_hub_download

# Both artifacts come from the same Hub repository.
REPO_ID = "tsingla98/frawdllm-100m"

weights_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=REPO_ID, filename="tokenizer.json")

print(f"Weights: {weights_path}")
print(f"Tokenizer: {tokenizer_path}")

In [None]:
from safetensors import safe_open
import jax.numpy as jnp

# Read every checkpoint tensor as NumPy, convert it to a JAX array and log its shape.
weights = {}
with safe_open(weights_path, framework="numpy") as checkpoint:
    for tensor_name in checkpoint.keys():
        tensor = jnp.array(checkpoint.get_tensor(tensor_name))
        weights[tensor_name] = tensor
        print(f"{tensor_name}: {tensor.shape}")

In [None]:
from tokenizers import Tokenizer

# Build the `tokenizers` tokenizer from the tokenizer.json downloaded from the Hub.
tokenizer = Tokenizer.from_file(tokenizer_path)

## Model Constants

In [None]:
import jax.numpy as jnp
import jax
import math

# --- Checkpoint tensor names (safetensors keys) ---
# Keys containing "{}" are format strings filled with the 0-based layer index.

# Token embedding table (also reused as the LM head, see LM_HEAD_WEIGHT_KEY).
EMBEDDINGS_WEIGHT_KEY = "model.embeddings.token_emb.weight"

# LayerNorm applied to the block input before self-attention.
LN1_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.weight"
LN1_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.bias"

# LayerNorm applied before the MLP. Despite the "_PRE_ATTN" suffix in these constant
# names, the forward pass applies ln2 to the post-attention residual, right before fc1.
LN2_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.weight"
LN2_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.bias"

# Final LayerNorm after the last block, before the LM head.
LNF_WEIGHT_FORMAT_KEY = "model.ln_f.weight"
LNF_BIAS_FORMAT_KEY = "model.ln_f.bias"

LM_HEAD_WEIGHT_KEY = "model.embeddings.token_emb.weight"  # tied weights

# Fused Q/K/V projection; its output is split into three equal parts (q, k, v).
QKV_WEIGHTS_KEY = "model.blocks.{}.attn.qkv_proj.weight"
QKV_BIAS_KEY = "model.blocks.{}.attn.qkv_proj.bias"

# Attention output projection.
OUTPUT_PROJ_BIAS_KEY = "model.blocks.{}.attn.out_proj.bias"
OUTPUT_PROJ_WEIGHT_KEY = "model.blocks.{}.attn.out_proj.weight"

# MLP: fc1 -> GELU -> fc2.
FC1_WEIGHT_KEY = "model.blocks.{}.mlp.fc1.weight"
FC1_BIAS_KEY = "model.blocks.{}.mlp.fc1.bias"

FC2_WEIGHT_KEY = "model.blocks.{}.mlp.fc2.weight"
FC2_BIAS_KEY = "model.blocks.{}.mlp.fc2.bias"

# Generation stops as soon as this token id is sampled.
# NOTE(review): presumably the tokenizer's end-of-turn/eos special token — confirm
# against tokenizer.json rather than hard-coding.
STOP_TOKEN_ID = 3

# --- Architecture hyperparameters (must match the checkpoint) ---
EPSILON = 1e-5  # LayerNorm epsilon
N_HEADS = 12
N_LAYERS = 12
HEAD_DIM = 64
EMBEDDINGS_DIM = N_HEADS * HEAD_DIM  # model width (768)
ROPE_THETA = 10000.0  # RoPE base frequency

# --- Sampling ---
TEMPERATURE = 0.5  # logits are divided by this before the softmax
TOP_P = 0.9  # nucleus (top-p) probability mass

MAX_OUTPUT_TOKENS = 300  # generation budget passed to generate() by run_inference
MAX_CONTEXT_LENGTH = 4096  # length of the RoPE angle table and of the KV cache

## Inference Implementation (JAX)

In [None]:
import functools

def format_prompt(user_message: str) -> str:
    """Wrap a user message in the chat template: bos, user turn, then the assistant tag."""
    return "<|bos|><|user|>{}<|assistant|>".format(user_message)


def get_weights_tensor(key: str) -> jnp.ndarray:
    """Look up a checkpoint tensor by name in the global ``weights`` dict.

    Raises:
        KeyError: if the checkpoint has no tensor called ``key``.
    """
    tensor = weights[key]
    return tensor


def get_tokens_for_prompt(prompt: str) -> jnp.ndarray:
    """Tokenize an already-formatted prompt into a 1-D JAX array of token ids.

    The tokenizer is told not to add special tokens; the chat markers are expected
    to already be present in ``prompt`` (as produced by ``format_prompt``).
    """
    encoding = tokenizer.encode(prompt, add_special_tokens=False)
    return jnp.array(encoding.ids)


# Precompute RoPE frequencies for max context length
def _compute_rope_freqs(max_len: int) -> jnp.ndarray:
    """Return the RoPE rotation-angle table, shape (max_len, HEAD_DIM // 2).

    Entry [p, i] equals p / ROPE_THETA ** (2i / HEAD_DIM): the angle by which the
    i-th (even, odd) feature pair of a vector at position p is rotated.
    """
    even_dims = jnp.arange(0, HEAD_DIM, 2).astype(jnp.float32)
    inv_freq = 1.0 / (ROPE_THETA ** (even_dims / HEAD_DIM))
    pos = jnp.arange(max_len).astype(jnp.float32)
    return jnp.outer(pos, inv_freq)

# Angle table indexed as ROPE_FREQS[position, pair]; shape (MAX_CONTEXT_LENGTH, HEAD_DIM // 2).
ROPE_FREQS = _compute_rope_freqs(MAX_CONTEXT_LENGTH)

# Fixed prompt padding size - compile once, run with any prompt up to this size.
# Longer prompts are rejected by run_inference.
PROMPT_PAD_SIZE = 256


# === Stack all layer weights for vectorized access ===
def _stack_weights():
    """Collect checkpoint tensors into the weight dict used by the forward pass.

    Every per-block parameter is stacked across the N_LAYERS blocks along a new
    leading axis, so ``W[name][layer_idx]`` selects one layer's tensor. The final
    LayerNorm, the embedding table and the (tied) LM head are stored unstacked.
    """
    per_layer_formats = (
        ('ln1_w', LN1_WEIGHT_FORMAT_PRE_ATTN_KEY),
        ('ln1_b', LN1_BIAS_FORMAT_PRE_ATTN_KEY),
        ('ln2_w', LN2_WEIGHT_FORMAT_PRE_ATTN_KEY),
        ('ln2_b', LN2_BIAS_FORMAT_PRE_ATTN_KEY),
        ('qkv_w', QKV_WEIGHTS_KEY),
        ('qkv_b', QKV_BIAS_KEY),
        ('out_w', OUTPUT_PROJ_WEIGHT_KEY),
        ('out_b', OUTPUT_PROJ_BIAS_KEY),
        ('fc1_w', FC1_WEIGHT_KEY),
        ('fc1_b', FC1_BIAS_KEY),
        ('fc2_w', FC2_WEIGHT_KEY),
        ('fc2_b', FC2_BIAS_KEY),
    )

    def stack_over_layers(key_format):
        # One tensor per layer, stacked into shape (N_LAYERS, *tensor_shape).
        return jnp.stack([get_weights_tensor(key_format.format(i)) for i in range(N_LAYERS)])

    stacked = {name: stack_over_layers(key_format) for name, key_format in per_layer_formats}
    stacked['ln_f_w'] = get_weights_tensor(LNF_WEIGHT_FORMAT_KEY)
    stacked['ln_f_b'] = get_weights_tensor(LNF_BIAS_FORMAT_KEY)
    stacked['embed'] = get_weights_tensor(EMBEDDINGS_WEIGHT_KEY)
    stacked['lm_head'] = get_weights_tensor(LM_HEAD_WEIGHT_KEY)
    return stacked


STACKED_WEIGHTS = _stack_weights()


def apply_rope(x: jnp.ndarray, angles: jnp.ndarray) -> jnp.ndarray:
    """Apply rotary position embedding to interleaved feature pairs.

    Features (0, 1), (2, 3), ... of the last axis of ``x`` are treated as 2-D vectors
    and rotated by the matching entry of ``angles``. ``angles`` must broadcast against
    ``x[..., ::2]`` (e.g. x: (heads, seq, dim), angles: (seq, dim // 2)).
    The output has the same shape as ``x``.
    """
    cos_a, sin_a = jnp.cos(angles), jnp.sin(angles)
    evens, odds = x[..., 0::2], x[..., 1::2]
    rotated_pairs = jnp.stack(
        (evens * cos_a - odds * sin_a, evens * sin_a + odds * cos_a),
        axis=-1,
    )
    # (..., dim // 2, 2) -> (..., dim): re-interleave the rotated pairs.
    return rotated_pairs.reshape(x.shape)


def layer_norm(x, gamma, beta, eps=None):
    """Layer normalization over the last axis, followed by an affine transform.

    Computes ``(x - mean) / sqrt(var + eps) * gamma + beta`` with the biased (population)
    variance, the same formula as ``torch.nn.LayerNorm``. Adding ``eps`` inside the
    square root (rather than to the standard deviation) keeps low-variance rows from
    being effectively un-normalized.

    Args:
        x: input array; statistics are taken over its last axis.
        gamma: scale, broadcastable to ``x``.
        beta: shift, broadcastable to ``x``.
        eps: variance epsilon; defaults to the module-level ``EPSILON``.

    Returns:
        Array with the same shape as ``x``.
    """
    if eps is None:
        eps = EPSILON
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps) * gamma + beta


def decode_layer(layer_idx, x, k_cache, v_cache, cache_len, W):
    """Process one layer during decode (single token).

    Runs one pre-LayerNorm transformer block for the newest token: its key/value are
    written into the KV cache at position ``cache_len``, and it attends over every
    cached position ``<= cache_len``.

    Args:
        layer_idx: Python int index into the stacked weights and the cache's layer axis.
        x: hidden state of the new token, shape (1, EMBEDDINGS_DIM).
        k_cache, v_cache: caches laid out as [layer, position, head, head_dim]
            (generate allocates them as (N_LAYERS, MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM)).
        cache_len: position of the new token; may be a traced int (inside while_loop).
        W: stacked weight dict (see _stack_weights).

    Returns:
        (out, k_cache, v_cache): new hidden state (1, EMBEDDINGS_DIM) and updated caches.
    """
    normed = layer_norm(x, W['ln1_w'][layer_idx], W['ln1_b'][layer_idx])
    
    # Fused projection, then split into equal q / k / v parts of width EMBEDDINGS_DIM.
    qkv = normed @ W['qkv_w'][layer_idx].T + W['qkv_b'][layer_idx]
    q, k, v = jnp.split(qkv, 3, axis=-1)
    
    # (1, EMBEDDINGS_DIM) -> (N_HEADS, 1, HEAD_DIM)
    q = q.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    k = k.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v = v.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    
    # RoPE angles for this single position: (1, HEAD_DIM // 2). dynamic_slice is
    # needed because cache_len is traced.
    angles = jax.lax.dynamic_slice(ROPE_FREQS, (cache_len, 0), (1, HEAD_DIM // 2))
    q = apply_rope(q, angles)
    k = apply_rope(k, angles)
    
    # Write this token's (N_HEADS, HEAD_DIM) key/value into row `cache_len`.
    k_cache = k_cache.at[layer_idx, cache_len, :, :].set(k.swapaxes(0, 1)[0])
    v_cache = v_cache.at[layer_idx, cache_len, :, :].set(v.swapaxes(0, 1)[0])
    
    # Whole cache buffer for this layer, as (N_HEADS, positions, HEAD_DIM).
    k_full = k_cache[layer_idx].swapaxes(0, 1)
    v_full = v_cache[layer_idx].swapaxes(0, 1)
    
    # Scores over the full fixed-size buffer (static shape); positions after
    # cache_len (unwritten zeros or stale prefill padding) are masked out.
    scores = (q @ k_full.swapaxes(-2, -1)) / math.sqrt(HEAD_DIM)
    positions = jnp.arange(MAX_CONTEXT_LENGTH)
    mask = jnp.where(positions <= cache_len, 0.0, -1e9)
    scores = scores + mask
    
    attn = jax.nn.softmax(scores, axis=-1)
    # (N_HEADS, 1, HEAD_DIM) -> (1, EMBEDDINGS_DIM)
    out = (attn @ v_full).swapaxes(0, 1).reshape(1, EMBEDDINGS_DIM)
    
    # Output projection + residual connection.
    out = out @ W['out_w'][layer_idx].T + W['out_b'][layer_idx] + x
    
    # MLP sub-block with its own pre-LayerNorm and residual.
    # NOTE(review): jax.nn.gelu defaults to the tanh approximation — confirm this
    # matches the activation used at training time.
    normed2 = layer_norm(out, W['ln2_w'][layer_idx], W['ln2_b'][layer_idx])
    hidden = jax.nn.gelu(normed2 @ W['fc1_w'][layer_idx].T + W['fc1_b'][layer_idx])
    out = hidden @ W['fc2_w'][layer_idx].T + W['fc2_b'][layer_idx] + out
    
    return out, k_cache, v_cache


def prefill_layer_padded(layer_idx, x, k_cache, v_cache, actual_len, W):
    """Process one layer during prefill with padding support.

    Runs one pre-LayerNorm transformer block over the whole fixed-size prompt buffer.
    Attention is causal and additionally masks key positions ``>= actual_len`` (padding),
    so real prompt rows never see padding. Padding rows are still computed; generate
    only reads row ``actual_len - 1``, and decode masks cache positions beyond its
    current length.

    Args:
        layer_idx: Python int index into the stacked weights and the cache's layer axis.
        x: hidden states of shape (PROMPT_PAD_SIZE, EMBEDDINGS_DIM).
        k_cache, v_cache: caches laid out as [layer, position, head, head_dim].
        actual_len: number of real prompt tokens; may be a traced int.
        W: stacked weight dict (see _stack_weights).

    Returns:
        (out, k_cache, v_cache): new hidden states (PROMPT_PAD_SIZE, EMBEDDINGS_DIM)
        and caches with positions [0, PROMPT_PAD_SIZE) of this layer filled.
    """
    # x: [PROMPT_PAD_SIZE, 768], actual_len is traced
    pad_size = PROMPT_PAD_SIZE
    
    normed = layer_norm(x, W['ln1_w'][layer_idx], W['ln1_b'][layer_idx])
    
    # Fused projection, then split into equal q / k / v parts of width EMBEDDINGS_DIM.
    qkv = normed @ W['qkv_w'][layer_idx].T + W['qkv_b'][layer_idx]
    q, k, v = jnp.split(qkv, 3, axis=-1)
    
    # (pad_size, EMBEDDINGS_DIM) -> (N_HEADS, pad_size, HEAD_DIM)
    q = q.reshape(pad_size, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    k = k.reshape(pad_size, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v = v.reshape(pad_size, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    
    # RoPE - use full padded size
    angles = ROPE_FREQS[:pad_size]
    q = apply_rope(q, angles)
    k = apply_rope(k, angles)
    
    # Store in cache (only actual_len matters, but store all for simplicity).
    # Padding entries at positions >= actual_len are later overwritten or masked by decode.
    k_cache = k_cache.at[layer_idx, :pad_size, :, :].set(k.swapaxes(0, 1))
    v_cache = v_cache.at[layer_idx, :pad_size, :, :].set(v.swapaxes(0, 1))
    
    # Attention with combined causal + padding mask
    scores = (q @ k.swapaxes(-2, -1)) / math.sqrt(HEAD_DIM)
    
    # Causal mask: position i can only attend to j where j <= i
    row_idx = jnp.arange(pad_size)[:, None]  # [pad_size, 1]
    col_idx = jnp.arange(pad_size)[None, :]  # [1, pad_size]
    causal_mask = jnp.where(col_idx <= row_idx, 0.0, -1e9)  # [pad_size, pad_size]
    
    # Padding mask: can only attend to positions < actual_len
    padding_mask = jnp.where(col_idx < actual_len, 0.0, -1e9)  # [1, pad_size]
    
    # Combined mask
    combined_mask = causal_mask + padding_mask  # broadcasts to [pad_size, pad_size]
    scores = scores + combined_mask
    
    attn = jax.nn.softmax(scores, axis=-1)
    # (N_HEADS, pad_size, HEAD_DIM) -> (pad_size, EMBEDDINGS_DIM)
    out = (attn @ v).swapaxes(0, 1).reshape(pad_size, EMBEDDINGS_DIM)
    
    # Output projection + residual connection.
    out = out @ W['out_w'][layer_idx].T + W['out_b'][layer_idx] + x
    
    # MLP sub-block with its own pre-LayerNorm and residual.
    # NOTE(review): jax.nn.gelu defaults to the tanh approximation — confirm this
    # matches the activation used at training time.
    normed2 = layer_norm(out, W['ln2_w'][layer_idx], W['ln2_b'][layer_idx])
    hidden = jax.nn.gelu(normed2 @ W['fc1_w'][layer_idx].T + W['fc1_b'][layer_idx])
    out = hidden @ W['fc2_w'][layer_idx].T + W['fc2_b'][layer_idx] + out
    
    return out, k_cache, v_cache


def sample_token(logits, key, temperature=None, top_p=None, top_k=100):
    """Sample next token with temperature, top-k and nucleus (top-p) filtering.

    Args:
        logits: 1-D array of unnormalized scores over the vocabulary.
        key: JAX PRNG key. It is split internally and a fresh key is returned.
        temperature: softmax temperature; defaults to the module-level ``TEMPERATURE``.
        top_p: nucleus probability mass; defaults to the module-level ``TOP_P``.
        top_k: size of the candidate pool considered before nucleus filtering
            (clipped to the vocabulary size).

    Returns:
        (next_token, key): the sampled token id (int32 scalar) and the new PRNG key.
    """
    if temperature is None:
        temperature = TEMPERATURE
    if top_p is None:
        top_p = TOP_P
    k = min(top_k, logits.shape[-1])

    probs = jax.nn.softmax(logits / temperature)
    # lax.top_k already returns values in descending order, so no extra sort is needed.
    sorted_probs, sorted_indices = jax.lax.top_k(probs, k=k)

    # Nucleus sampling keeps the smallest prefix whose mass reaches top_p. A candidate
    # is kept if the mass ranked strictly above it is still below top_p. This keeps the
    # token that crosses the threshold, which a `cumsum <= top_p` test would drop.
    mass_before = jnp.cumsum(sorted_probs) - sorted_probs
    keep = (mass_before < top_p).at[0].set(True)
    # -inf gives filtered candidates exactly zero probability.
    masked_logits = jnp.where(keep, jnp.log(sorted_probs), -jnp.inf)

    key, subkey = jax.random.split(key)
    choice = jax.random.categorical(subkey, masked_logits)
    next_token = sorted_indices[choice]

    return next_token, key


@functools.partial(jax.jit, static_argnums=(2,))
def generate(padded_tokens, actual_len, max_new_tokens, W, key):
    """
    Full generation loop in XLA - no Python sync points.

    Runs a padded prefill over the prompt buffer, samples the first token from the
    last real prompt position, then decodes one token per ``lax.while_loop`` step with
    a KV cache. It stops when ``max_new_tokens`` tokens exist or STOP_TOKEN_ID is sampled.

    Args:
        padded_tokens: [PROMPT_PAD_SIZE] - prompt tokens padded with zeros
        actual_len: traced int - actual prompt length (must be >= 1)
        max_new_tokens: static int - max tokens to generate (each distinct value
            triggers a recompilation)
        W: weights dict (as produced by _stack_weights)
        key: PRNG key

    Returns: (output_tokens, num_generated)
        output_tokens: int32 [max_new_tokens]. Only the first num_generated entries
            are meaningful. If generation stopped early, the last of them is STOP_TOKEN_ID.
        num_generated: int32 scalar.

    Raises:
        ValueError: at trace time, if max_new_tokens < 1, or if the prompt buffer plus
            max_new_tokens does not fit in the MAX_CONTEXT_LENGTH KV cache. Without this
            check JAX would silently drop out-of-range cache writes and clamp the RoPE
            lookups, corrupting the output instead of failing.
    """
    # max_new_tokens and the buffer shape are static, so these checks run once in Python.
    prompt_buffer = padded_tokens.shape[0]
    if max_new_tokens < 1:
        raise ValueError(f"max_new_tokens must be >= 1, got {max_new_tokens}")
    if prompt_buffer + max_new_tokens > MAX_CONTEXT_LENGTH:
        raise ValueError(
            f"prompt buffer ({prompt_buffer}) + max_new_tokens ({max_new_tokens}) "
            f"exceeds MAX_CONTEXT_LENGTH ({MAX_CONTEXT_LENGTH})"
        )

    # Initialize KV cache: [layer, position, head, head_dim]
    k_cache = jnp.zeros((N_LAYERS, MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM))
    v_cache = jnp.zeros((N_LAYERS, MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM))

    # === PREFILL (padded) ===
    x = W['embed'][padded_tokens]  # [PROMPT_PAD_SIZE, 768]

    # Python loop: unrolled at trace time, so layer_idx stays a static int.
    for layer_idx in range(N_LAYERS):
        x, k_cache, v_cache = prefill_layer_padded(layer_idx, x, k_cache, v_cache, actual_len, W)

    x = layer_norm(x, W['ln_f_w'], W['ln_f_b'])

    # Get logits for last REAL token (not last padded position)
    last_hidden = jax.lax.dynamic_slice(x, (actual_len - 1, 0), (1, EMBEDDINGS_DIM))[0]
    logits = last_hidden @ W['lm_head'].T

    # Sample first token
    first_token, key = sample_token(logits, key)

    # Initialize output buffer
    output_tokens = jnp.zeros(max_new_tokens, dtype=jnp.int32)
    output_tokens = output_tokens.at[0].set(first_token)

    # === DECODE with while_loop ===
    # State: (output_tokens, num_generated, k_cache, v_cache, cache_len, key, done)
    init_state = (
        output_tokens,
        jnp.array(1, dtype=jnp.int32),  # num_generated
        k_cache,
        v_cache,
        actual_len,  # cache_len starts at actual prompt length
        key,
        first_token == STOP_TOKEN_ID,  # done
    )

    def cond_fn(state):
        _, num_generated, _, _, _, _, done = state
        return jnp.logical_and(num_generated < max_new_tokens, ~done)

    def body_fn(state):
        output_tokens, num_generated, k_cache, v_cache, cache_len, key, _ = state

        # Feed back the most recently generated token at position cache_len.
        last_token = output_tokens[num_generated - 1]
        x = W['embed'][last_token].reshape(1, -1)

        for layer_idx in range(N_LAYERS):
            x, k_cache, v_cache = decode_layer(layer_idx, x, k_cache, v_cache, cache_len, W)

        x = layer_norm(x, W['ln_f_w'], W['ln_f_b'])
        logits = x[0] @ W['lm_head'].T

        next_token, key = sample_token(logits, key)

        output_tokens = output_tokens.at[num_generated].set(next_token)
        done = next_token == STOP_TOKEN_ID

        return (output_tokens, num_generated + 1, k_cache, v_cache, cache_len + 1, key, done)

    final_state = jax.lax.while_loop(cond_fn, body_fn, init_state)

    output_tokens, num_generated, _, _, _, _, _ = final_state
    return output_tokens, num_generated


def pad_tokens(tokens, pad_size=PROMPT_PAD_SIZE):
    """Pad tokens to fixed size for JIT cache reuse.

    Args:
        tokens: 1-D sequence/array of token ids.
        pad_size: length of the output buffer.

    Returns:
        (padded, n): an int32 array of length ``pad_size`` holding ``tokens`` followed
        by zeros, and ``n = len(tokens)``.

    Raises:
        ValueError: if ``tokens`` is longer than ``pad_size``.
    """
    num_tokens = len(tokens)
    if num_tokens > pad_size:
        # Without this check the slice assignment fails with an opaque shape error.
        raise ValueError(f"Cannot pad {num_tokens} tokens into a buffer of size {pad_size}")
    padded = jnp.zeros(pad_size, dtype=jnp.int32)
    padded = padded.at[:num_tokens].set(jnp.asarray(tokens, dtype=jnp.int32))
    return padded, num_tokens


def run_inference(prompt: str, key):
    """Helper to run inference on a prompt.

    Formats and tokenizes ``prompt``, pads it to PROMPT_PAD_SIZE, runs the jitted
    ``generate`` and decodes the result.

    Args:
        prompt: raw user message (the chat template is applied here).
        key: JAX PRNG key for sampling.

    Returns:
        (text, num_tokens): decoded output text and the number of generated tokens,
        excluding STOP_TOKEN_ID.

    Raises:
        ValueError: if the formatted prompt is longer than PROMPT_PAD_SIZE tokens.
    """
    formatted = format_prompt(prompt)
    tokens = get_tokens_for_prompt(formatted)

    if len(tokens) > PROMPT_PAD_SIZE:
        raise ValueError(f"Prompt too long: {len(tokens)} > {PROMPT_PAD_SIZE}")

    padded, actual_len = pad_tokens(tokens)
    actual_len = jnp.array(actual_len, dtype=jnp.int32)

    output_tokens, num_generated = generate(padded, actual_len, MAX_OUTPUT_TOKENS, STACKED_WEIGHTS, key)

    # Copy the results to the host once. Indexing the device array element by element
    # would cost a dispatch + device->host transfer per access, inside the timed path.
    num_gen = int(num_generated)
    host_tokens = jax.device_get(output_tokens)[:num_gen].tolist()
    out_list = [tok for tok in host_tokens if tok != STOP_TOKEN_ID]

    return tokenizer.decode(out_list), len(out_list)


def main() -> None:
    """Run three demo prompts and report latency and throughput.

    The first prompt triggers XLA compilation of ``generate``, so its time is reported
    as including compilation. Later prompts reuse the compiled program (same padded
    shape and MAX_OUTPUT_TOKENS) and report tokens per second.
    """
    import time

    print(f"Prompt pad size: {PROMPT_PAD_SIZE}")
    print(f"Max new tokens: {MAX_OUTPUT_TOKENS}")
    print()

    # (section header, prompt, PRNG seed)
    demo_runs = [
        ("=== First prompt (triggers compilation) ===", "Hello, World!", 42),
        ("=== Second prompt (uses cached compilation) ===", "What is the capital of France?", 123),
        ("=== Third prompt (cached) ===", "Explain quantum computing in simple terms.", 456),
    ]

    for run_idx, (header, prompt, seed) in enumerate(demo_runs):
        if run_idx > 0:
            print()
        print(header)
        key = jax.random.PRNGKey(seed)

        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        output, num_tokens = run_inference(prompt, key)
        elapsed = time.perf_counter() - start

        if run_idx == 0:
            print(f"Time: {elapsed:.3f}s (includes compilation)")
            print(f"Generated: {num_tokens} tokens")
        else:
            print(f"Time: {elapsed:.3f}s")
            print(f"Speed: {num_tokens/elapsed:.1f} tok/s")
        print(f"Output: {output}")


# Notebook kernels (and `%run`) execute with __name__ == "__main__", so the demo still runs here.
if __name__ == "__main__":
    main()