# FrawdLLM Inference on TPU (JAX)

This notebook implements transformer inference from scratch using JAX, running on Google TPU.

**Setup:** Runtime > Change runtime type > TPU

In [15]:
# Install dependencies
!pip install huggingface_hub safetensors tokenizers

zsh:1: command not found: pip


  pid, fd = os.forkpty()


In [16]:
# Verify TPU is available
import jax
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

Devices: [CpuDevice(id=0)]
Device count: 1


## Load Weights from HuggingFace

In [17]:
from huggingface_hub import hf_hub_download

# Download weights and tokenizer from HuggingFace
weights_path = hf_hub_download(repo_id="tsingla98/frawdllm-100m", filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id="tsingla98/frawdllm-100m", filename="tokenizer.json")

print(f"Weights: {weights_path}")
print(f"Tokenizer: {tokenizer_path}")

Weights: /Users/tushar/.cache/huggingface/hub/models--tsingla98--frawdllm-100m/snapshots/f74ba30cf03fb8928c261795cd7ee6c69e0f0e21/model.safetensors
Tokenizer: /Users/tushar/.cache/huggingface/hub/models--tsingla98--frawdllm-100m/snapshots/f74ba30cf03fb8928c261795cd7ee6c69e0f0e21/tokenizer.json


In [18]:
from safetensors import safe_open
import jax.numpy as jnp

# Load weights into JAX arrays
weights = {}
with safe_open(weights_path, framework="numpy") as f:
    for key in f.keys():
        weights[key] = jnp.array(f.get_tensor(key))
        print(f"{key}: {weights[key].shape}")

model.blocks.0.attn.mask: (1, 1, 4096, 4096)
model.blocks.0.attn.out_proj.bias: (768,)
model.blocks.0.attn.out_proj.weight: (768, 768)
model.blocks.0.attn.qkv_proj.bias: (2304,)
model.blocks.0.attn.qkv_proj.weight: (2304, 768)
model.blocks.0.ln1.bias: (768,)
model.blocks.0.ln1.weight: (768,)
model.blocks.0.ln2.bias: (768,)
model.blocks.0.ln2.weight: (768,)
model.blocks.0.mlp.fc1.bias: (3072,)
model.blocks.0.mlp.fc1.weight: (3072, 768)
model.blocks.0.mlp.fc2.bias: (768,)
model.blocks.0.mlp.fc2.weight: (768, 3072)
model.blocks.1.attn.mask: (1, 1, 4096, 4096)
model.blocks.1.attn.out_proj.bias: (768,)
model.blocks.1.attn.out_proj.weight: (768, 768)
model.blocks.1.attn.qkv_proj.bias: (2304,)
model.blocks.1.attn.qkv_proj.weight: (2304, 768)
model.blocks.1.ln1.bias: (768,)
model.blocks.1.ln1.weight: (768,)
model.blocks.1.ln2.bias: (768,)
model.blocks.1.ln2.weight: (768,)
model.blocks.1.mlp.fc1.bias: (3072,)
model.blocks.1.mlp.fc1.weight: (3072, 768)
model.blocks.1.mlp.fc2.bias: (768,)
model.b

In [19]:
from tokenizers import Tokenizer

tokenizer = Tokenizer.from_file(tokenizer_path)

## Model Constants

In [20]:
import jax.numpy as jnp
import jax
import math

EMBEDDINGS_WEIGHT_KEY = "model.embeddings.token_emb.weight"

LN1_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.weight"
LN1_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln1.bias"

LN2_WEIGHT_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.weight"
LN2_BIAS_FORMAT_PRE_ATTN_KEY = "model.blocks.{}.ln2.bias"

LNF_WEIGHT_FORMAT_KEY = "model.ln_f.weight"
LNF_BIAS_FORMAT_KEY = "model.ln_f.bias"

LM_HEAD_WEIGHT_KEY = "model.embeddings.token_emb.weight"  # tied weights

QKV_WEIGHTS_KEY = "model.blocks.{}.attn.qkv_proj.weight"
QKV_BIAS_KEY = "model.blocks.{}.attn.qkv_proj.bias"

OUTPUT_PROJ_BIAS_KEY = "model.blocks.{}.attn.out_proj.bias"
OUTPUT_PROJ_WEIGHT_KEY = "model.blocks.{}.attn.out_proj.weight"

FC1_WEIGHT_KEY = "model.blocks.{}.mlp.fc1.weight"
FC1_BIAS_KEY = "model.blocks.{}.mlp.fc1.bias"

FC2_WEIGHT_KEY = "model.blocks.{}.mlp.fc2.weight"
FC2_BIAS_KEY = "model.blocks.{}.mlp.fc2.bias"

STOP_TOKEN_ID = 3

EPSILON = 1e-5
N_HEADS = 12
N_LAYERS = 12
HEAD_DIM = 64
EMBEDDINGS_DIM = N_HEADS * HEAD_DIM
ROPE_THETA = 10000.0
TEMPERATURE = 0.5
TOP_P = 0.9

MAX_OUTPUT_TOKENS = 300
MAX_CONTEXT_LENGTH = 4096

## Inference Implementation (JAX)

In [21]:
import functools

def format_prompt(user_message: str) -> str:
    return f"<|bos|><|user|>{user_message}<|assistant|>"


def get_weights_tensor(key: str) -> jnp.ndarray:
    return weights[key]


def get_tokens_for_prompt(prompt: str) -> jnp.ndarray:
    return jnp.array(tokenizer.encode(prompt, add_special_tokens=False).ids)


# Precompute RoPE frequencies for max context length
def _compute_rope_freqs(max_len: int) -> jnp.ndarray:
    dim_indices = jnp.arange(0, HEAD_DIM, 2).astype(jnp.float32)
    freqs = 1.0 / (ROPE_THETA ** (dim_indices / HEAD_DIM))
    positions = jnp.arange(max_len).astype(jnp.float32)
    angles = jnp.outer(positions, freqs)
    return angles

ROPE_FREQS = _compute_rope_freqs(MAX_CONTEXT_LENGTH)


# === Stack all layer weights for vectorized access ===
def _stack_weights():
    """Stack weights across layers for efficient access."""
    ln1_weights = jnp.stack([get_weights_tensor(LN1_WEIGHT_FORMAT_PRE_ATTN_KEY.format(i)) for i in range(N_LAYERS)])
    ln1_biases = jnp.stack([get_weights_tensor(LN1_BIAS_FORMAT_PRE_ATTN_KEY.format(i)) for i in range(N_LAYERS)])
    ln2_weights = jnp.stack([get_weights_tensor(LN2_WEIGHT_FORMAT_PRE_ATTN_KEY.format(i)) for i in range(N_LAYERS)])
    ln2_biases = jnp.stack([get_weights_tensor(LN2_BIAS_FORMAT_PRE_ATTN_KEY.format(i)) for i in range(N_LAYERS)])
    qkv_weights = jnp.stack([get_weights_tensor(QKV_WEIGHTS_KEY.format(i)) for i in range(N_LAYERS)])
    qkv_biases = jnp.stack([get_weights_tensor(QKV_BIAS_KEY.format(i)) for i in range(N_LAYERS)])
    out_weights = jnp.stack([get_weights_tensor(OUTPUT_PROJ_WEIGHT_KEY.format(i)) for i in range(N_LAYERS)])
    out_biases = jnp.stack([get_weights_tensor(OUTPUT_PROJ_BIAS_KEY.format(i)) for i in range(N_LAYERS)])
    fc1_weights = jnp.stack([get_weights_tensor(FC1_WEIGHT_KEY.format(i)) for i in range(N_LAYERS)])
    fc1_biases = jnp.stack([get_weights_tensor(FC1_BIAS_KEY.format(i)) for i in range(N_LAYERS)])
    fc2_weights = jnp.stack([get_weights_tensor(FC2_WEIGHT_KEY.format(i)) for i in range(N_LAYERS)])
    fc2_biases = jnp.stack([get_weights_tensor(FC2_BIAS_KEY.format(i)) for i in range(N_LAYERS)])
    
    return {
        'ln1_w': ln1_weights, 'ln1_b': ln1_biases,
        'ln2_w': ln2_weights, 'ln2_b': ln2_biases,
        'qkv_w': qkv_weights, 'qkv_b': qkv_biases,
        'out_w': out_weights, 'out_b': out_biases,
        'fc1_w': fc1_weights, 'fc1_b': fc1_biases,
        'fc2_w': fc2_weights, 'fc2_b': fc2_biases,
        'ln_f_w': get_weights_tensor(LNF_WEIGHT_FORMAT_KEY),
        'ln_f_b': get_weights_tensor(LNF_BIAS_FORMAT_KEY),
        'embed': get_weights_tensor(EMBEDDINGS_WEIGHT_KEY),
        'lm_head': get_weights_tensor(LM_HEAD_WEIGHT_KEY),
    }

STACKED_WEIGHTS = _stack_weights()


def apply_rope(x: jnp.ndarray, angles: jnp.ndarray) -> jnp.ndarray:
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    cos = jnp.cos(angles)
    sin = jnp.sin(angles)
    x_even_rot = x_even * cos - x_odd * sin
    x_odd_rot = x_even * sin + x_odd * cos
    out = jnp.stack([x_even_rot, x_odd_rot], axis=-1)
    out = out.reshape(x.shape)
    return out


def layer_norm(x, gamma, beta):
    mean = x.mean(axis=-1, keepdims=True)
    std = x.std(axis=-1, keepdims=True)
    return (x - mean) / (std + EPSILON) * gamma + beta


def decode_layer(layer_idx, x, k_cache, v_cache, cache_len, W):
    """Process one layer during decode (single token)."""
    # LayerNorm 1
    normed = layer_norm(x, W['ln1_w'][layer_idx], W['ln1_b'][layer_idx])
    
    # QKV projection
    qkv = normed @ W['qkv_w'][layer_idx].T + W['qkv_b'][layer_idx]
    q, k, v = jnp.split(qkv, 3, axis=-1)
    
    q = q.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    k = k.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v = v.reshape(1, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    
    # RoPE
    angles = jax.lax.dynamic_slice(ROPE_FREQS, (cache_len, 0), (1, HEAD_DIM // 2))
    q = apply_rope(q, angles)
    k = apply_rope(k, angles)
    
    # Update cache
    k_cache = k_cache.at[layer_idx, cache_len, :, :].set(k.swapaxes(0, 1)[0])
    v_cache = v_cache.at[layer_idx, cache_len, :, :].set(v.swapaxes(0, 1)[0])
    
    # Attention over full cache with mask
    k_full = k_cache[layer_idx].swapaxes(0, 1)
    v_full = v_cache[layer_idx].swapaxes(0, 1)
    
    scores = (q @ k_full.swapaxes(-2, -1)) / math.sqrt(HEAD_DIM)
    positions = jnp.arange(MAX_CONTEXT_LENGTH)
    mask = jnp.where(positions <= cache_len, 0.0, -1e9)
    scores = scores + mask
    
    attn = jax.nn.softmax(scores, axis=-1)
    out = (attn @ v_full).swapaxes(0, 1).reshape(1, EMBEDDINGS_DIM)
    
    # Output projection + residual
    out = out @ W['out_w'][layer_idx].T + W['out_b'][layer_idx] + x
    
    # MLP
    normed2 = layer_norm(out, W['ln2_w'][layer_idx], W['ln2_b'][layer_idx])
    hidden = jax.nn.gelu(normed2 @ W['fc1_w'][layer_idx].T + W['fc1_b'][layer_idx])
    out = hidden @ W['fc2_w'][layer_idx].T + W['fc2_b'][layer_idx] + out
    
    return out, k_cache, v_cache


def prefill_layer(layer_idx, x, k_cache, v_cache, W):
    """Process one layer during prefill (full sequence)."""
    seq_len = x.shape[0]
    
    normed = layer_norm(x, W['ln1_w'][layer_idx], W['ln1_b'][layer_idx])
    
    qkv = normed @ W['qkv_w'][layer_idx].T + W['qkv_b'][layer_idx]
    q, k, v = jnp.split(qkv, 3, axis=-1)
    
    q = q.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    k = k.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    v = v.reshape(seq_len, N_HEADS, HEAD_DIM).swapaxes(0, 1)
    
    angles = ROPE_FREQS[:seq_len]
    q = apply_rope(q, angles)
    k = apply_rope(k, angles)
    
    k_cache = k_cache.at[layer_idx, :seq_len, :, :].set(k.swapaxes(0, 1))
    v_cache = v_cache.at[layer_idx, :seq_len, :, :].set(v.swapaxes(0, 1))
    
    scores = (q @ k.swapaxes(-2, -1)) / math.sqrt(HEAD_DIM)
    mask = jnp.triu(jnp.ones((seq_len, seq_len)), k=1) * -1e9
    scores = scores + mask
    
    attn = jax.nn.softmax(scores, axis=-1)
    out = (attn @ v).swapaxes(0, 1).reshape(seq_len, EMBEDDINGS_DIM)
    
    out = out @ W['out_w'][layer_idx].T + W['out_b'][layer_idx] + x
    
    normed2 = layer_norm(out, W['ln2_w'][layer_idx], W['ln2_b'][layer_idx])
    hidden = jax.nn.gelu(normed2 @ W['fc1_w'][layer_idx].T + W['fc1_b'][layer_idx])
    out = hidden @ W['fc2_w'][layer_idx].T + W['fc2_b'][layer_idx] + out
    
    return out, k_cache, v_cache


def sample_token(logits, key):
    """Sample next token with temperature and top-p."""
    probs = jax.nn.softmax(logits / TEMPERATURE)
    top_probs, top_indices = jax.lax.top_k(probs, k=100)
    sorted_idx = jnp.argsort(top_probs)[::-1]
    sorted_probs = top_probs[sorted_idx]
    cum_probs = jnp.cumsum(sorted_probs)
    mask = (cum_probs <= TOP_P).at[0].set(True)
    sorted_probs = jnp.where(mask, sorted_probs, 0.0)
    
    key, subkey = jax.random.split(key)
    sampled = jax.random.categorical(subkey, jnp.log(sorted_probs + 1e-10))
    next_token = top_indices[sorted_idx[sampled]]
    
    return next_token, key


@functools.partial(jax.jit, static_argnums=(1, 2))
def generate(prompt_tokens, prompt_len, max_new_tokens, W, key):
    """
    Full generation loop in XLA - no Python sync points.
    
    Returns: (output_tokens, num_generated)
    """
    # Initialize KV cache
    k_cache = jnp.zeros((N_LAYERS, MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM))
    v_cache = jnp.zeros((N_LAYERS, MAX_CONTEXT_LENGTH, N_HEADS, HEAD_DIM))
    
    # === PREFILL ===
    x = W['embed'][prompt_tokens]
    for layer_idx in range(N_LAYERS):
        x, k_cache, v_cache = prefill_layer(layer_idx, x, k_cache, v_cache, W)
    
    x = layer_norm(x, W['ln_f_w'], W['ln_f_b'])
    logits = x[-1] @ W['lm_head'].T
    
    # Sample first token
    first_token, key = sample_token(logits, key)
    
    # Initialize output buffer
    output_tokens = jnp.zeros(max_new_tokens, dtype=jnp.int32)
    output_tokens = output_tokens.at[0].set(first_token)
    
    # === DECODE with while_loop ===
    # State: (output_tokens, num_generated, k_cache, v_cache, cache_len, key, done)
    init_state = (
        output_tokens,
        jnp.array(1, dtype=jnp.int32),  # num_generated (already have 1)
        k_cache,
        v_cache,
        jnp.array(prompt_len, dtype=jnp.int32),  # cache_len
        key,
        first_token == STOP_TOKEN_ID,  # done
    )
    
    def cond_fn(state):
        _, num_generated, _, _, _, _, done = state
        return jnp.logical_and(num_generated < max_new_tokens, ~done)
    
    def body_fn(state):
        output_tokens, num_generated, k_cache, v_cache, cache_len, key, _ = state
        
        # Get the last generated token
        last_token = output_tokens[num_generated - 1]
        
        # Forward pass through all layers
        x = W['embed'][last_token].reshape(1, -1)
        
        for layer_idx in range(N_LAYERS):
            x, k_cache, v_cache = decode_layer(layer_idx, x, k_cache, v_cache, cache_len, W)
        
        x = layer_norm(x, W['ln_f_w'], W['ln_f_b'])
        logits = x[0] @ W['lm_head'].T
        
        # Sample
        next_token, key = sample_token(logits, key)
        
        # Update state
        output_tokens = output_tokens.at[num_generated].set(next_token)
        done = next_token == STOP_TOKEN_ID
        
        return (output_tokens, num_generated + 1, k_cache, v_cache, cache_len + 1, key, done)
    
    final_state = jax.lax.while_loop(cond_fn, body_fn, init_state)
    
    output_tokens, num_generated, _, _, _, _, _ = final_state
    return output_tokens, num_generated


def main() -> None:
    import time
    
    prompt = "Hello, World!"
    prompt = format_prompt(prompt)
    tokens = get_tokens_for_prompt(prompt)
    prompt_len = len(tokens)
    max_new = MAX_OUTPUT_TOKENS
    
    print(f"Prompt: {prompt}")
    print(f"Prompt length: {prompt_len} tokens")
    print(f"Max new tokens: {max_new}")
    print()
    
    # Warmup / compile
    print("Compiling (first run)...")
    key = jax.random.PRNGKey(42)
    start = time.time()
    output_tokens, num_generated = generate(tokens, prompt_len, max_new, STACKED_WEIGHTS, key)
    output_tokens.block_until_ready()
    first_run = time.time() - start
    print(f"First run (includes compilation): {first_run:.3f}s")
    
    # Decode output
    num_gen = int(num_generated)
    out_list = [int(output_tokens[i]) for i in range(num_gen) if int(output_tokens[i]) != STOP_TOKEN_ID]
    print(f"Generated {len(out_list)} tokens")
    print(f"Output: {tokenizer.decode(out_list)}")
    print()
    
    # Second run (cached compilation)
    print("Second run (cached)...")
    key = jax.random.PRNGKey(123)
    start = time.time()
    output_tokens, num_generated = generate(tokens, prompt_len, max_new, STACKED_WEIGHTS, key)
    output_tokens.block_until_ready()
    second_run = time.time() - start
    
    num_gen = int(num_generated)
    out_list = [int(output_tokens[i]) for i in range(num_gen) if int(output_tokens[i]) != STOP_TOKEN_ID]
    
    print(f"Second run: {second_run:.3f}s")
    print(f"Tokens: {len(out_list)}, Speed: {len(out_list)/second_run:.1f} tok/s")
    print(f"Output: {tokenizer.decode(out_list)}")


main()

Prompt: <|bos|><|user|>Hello, World!<|assistant|>
Prompt length: 7 tokens
Max new tokens: 300

Compiling (first run)...
First run (includes compilation): 8.459s
Generated 300 tokens
Output: Here's a step-by-step guide to creating a World of Warcraft Online:

1. **Choose your location**: Pick a location that suits your interests, needs, and needs. Choose a location that suits your interests.

2. **Choose a location**: Pick a location that suits your interests, needs, and needs. Choose a location that suits your interests.

3. **Choose a location**: Pick a location that suits your interests, needs, and needs. Choose a location that suits your interests.

4. **Pick a location**: Pick a location that suits your interests, needs, and needs. Choose a location that suits your interests.

5. **Pick a location**: Pick a location that suits your interests, needs, and needs. Choose a location that suits your interests.

6. **Pick a location**: Pick a location that suits your interests, needs, and