# FrawdLLM Inference on TPU (JAX)

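This notebook implements transformer inference from scratch using JAX, running on Google TPU. It is a work in progress: the cells below load the weights, tokenizer, and model constants, while the inference code itself is still a TODO.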

**Setup (Google Colab):** Runtime > Change runtime type > TPU

In [None]:
# Install dependencies
!pip install huggingface_hub safetensors tokenizers

In [None]:
# Verify TPU is available
import jax

# Show what JAX can see.
print(f"Devices: {jax.devices()}")
print(f"Device count: {jax.device_count()}")

# Printing the device list alone does not verify anything: check the active
# backend explicitly so a non-TPU runtime is flagged instead of going unnoticed.
backend = jax.default_backend()
if backend != "tpu":
    print(
        f"WARNING: JAX is using the '{backend}' backend, not TPU. "
        "Select Runtime > Change runtime type > TPU and restart the runtime."
    )

## Load Weights from HuggingFace

In [None]:
from huggingface_hub import hf_hub_download

# Both artifacts come from the same HuggingFace repository, so name it once.
HF_REPO_ID = "tsingla1998/frawdllm-100m"

# Fetch the model checkpoint and the tokenizer definition; each call returns
# the local path of the fetched file.
weights_path = hf_hub_download(repo_id=HF_REPO_ID, filename="model.safetensors")
tokenizer_path = hf_hub_download(repo_id=HF_REPO_ID, filename="tokenizer.json")

print(f"Weights: {weights_path}")
print(f"Tokenizer: {tokenizer_path}")

In [None]:
from safetensors import safe_open
import jax.numpy as jnp

# Read every tensor in the checkpoint (as NumPy), convert it to a JAX array,
# and report each parameter's name and shape as it is loaded.
weights = {}
with safe_open(weights_path, framework="numpy") as checkpoint:
    for key in checkpoint.keys():
        host_array = checkpoint.get_tensor(key)
        weights[key] = jnp.array(host_array)
        print(f"{key}: {weights[key].shape}")

In [None]:
from tokenizers import Tokenizer

# Build the tokenizer from the tokenizer.json file fetched from the Hub
# (tokenizer_path is set by the download cell earlier in the notebook).
tokenizer = Tokenizer.from_file(tokenizer_path)

## Model Constants

In [None]:
# Model architecture
N_LAYERS = 12
N_HEADS = 12
HEAD_DIM = 64
N_EMBD = HEAD_DIM * N_HEADS  # 768
ROPE_THETA = 10_000.0  # presumably the RoPE base; confirm against the training config

# Generation
TEMPERATURE = 0.5
TOP_K = 100
TOP_P = 0.9
MAX_OUTPUT_TOKENS = 300
STOP_TOKEN_ID = 3  # NOTE(review): presumably the tokenizer's end-of-sequence id; confirm

## Inference Implementation

TODO: Implement the following in JAX:
- RoPE (Rotary Position Embeddings)
- LayerNorm
- Attention
- MLP
- Generation loop

In [None]:
# TODO: Implement inference