# Nanochat JAX/Flax Training on TPU

This notebook sets up and runs nanochat's JAX/Flax training on a Google Colab TPU runtime.

**Important:** Before running, go to `Runtime > Change runtime type` and select **TPU** as the hardware accelerator.

## 1. Verify TPU Runtime

First, let's check that we have TPU access.

In [None]:
import os

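# Check if TPU is available
# Note: some Colab TPU runtimes may not set COLAB_TPU_ADDR; the JAX device
# check later in this notebook is the definitive test.
if 'COLAB_TPU_ADDR' in os.environ:
    print(f"TPU address: {os.environ['COLAB_TPU_ADDR']}")
else:
    print("WARNING: COLAB_TPU_ADDR is not set; a TPU may not be attached.")
    print("Go to Runtime > Change runtime type > Hardware accelerator > TPU")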

## 2. Clone the Repository

In [None]:
!git clone https://github.com/snaidu/nanochat.git
%cd nanochat

## 3. Install uv Package Manager

In [None]:
!curl -LsSf https://astral.sh/uv/install.sh | sh

import os  # re-imported so this cell also works after a kernel restart

# Add uv to PATH for this session
os.environ['PATH'] = f"{os.environ['HOME']}/.local/bin:{os.environ['PATH']}"

## 4. Install Dependencies

In [None]:
# Install Python and sync dependencies with TPU support
!uv python install 3.13
!uv sync --extra tpu

## 5. Install JAX with TPU Support

In [None]:
# Install JAX with TPU support
# Quote the requirement so the shell does not treat the brackets as a glob pattern
!uv pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

## 6. Verify JAX Installation

In [None]:
!uv run python -c "
import jax
print('JAX version:', jax.__version__)
print('Devices:', jax.devices())
print('Number of devices:', len(jax.devices()))
print('Device type:', jax.devices()[0].platform if jax.devices() else 'None')
"

## 7. Prepare Dataset

This downloads the FineWeb-Edu dataset. For a quick test, we download just 10 shards (~1-2 GB).

**Note:** The full dataset is ~200 GB (1823 shards). Keep `-n 10` while testing and raise it only once the rest of the pipeline works.

In [None]:
# Download 10 dataset shards for testing (~1-2 GB)
# Increase -n for more data, or remove it for the full dataset
!uv run python -m nanochat.dataset -n 10
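
As a rough sizing aid, the figures quoted above (about 200 GB across 1823 shards) imply roughly 0.11 GB per shard. The sketch below turns that into an estimate for any `-n`. The helper name `estimate_download_gb` is hypothetical, and it assumes all shards are about the same size, which may not hold exactly.

```python
# Rough download-size estimate from the figures quoted in this notebook.
FULL_DATASET_GB = 200   # approximate size of the full dataset
TOTAL_SHARDS = 1823     # number of shards in the full dataset


def estimate_download_gb(num_shards: int) -> float:
    """Estimate download size in GB, assuming equally sized shards."""
    if not 0 <= num_shards <= TOTAL_SHARDS:
        raise ValueError(f"num_shards must be between 0 and {TOTAL_SHARDS}")
    return num_shards * FULL_DATASET_GB / TOTAL_SHARDS


for n in (10, 100, TOTAL_SHARDS):
    print(f"-n {n}: ~{estimate_download_gb(n):.1f} GB")
```

For `-n 10` this lands at the low end of the 1-2 GB range mentioned above.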

## 8. Train Tokenizer

Train a BPE tokenizer on the downloaded data.

In [None]:
# Train the BPE tokenizer
!uv run python -m scripts.tok_train
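
For intuition about what BPE training does, here is a toy version of the merge loop: count adjacent token pairs, merge the most frequent pair into a new token, and repeat. This is an illustrative sketch only, not the code in `scripts.tok_train`; the function names are made up for this example, and a real tokenizer adds pre-tokenization, special tokens, and a much faster implementation.

```python
from collections import Counter


def most_frequent_pair(ids):
    """Return the most common adjacent pair of token ids, or None if there is none."""
    pairs = Counter(zip(ids, ids[1:]))
    return pairs.most_common(1)[0][0] if pairs else None


def merge_pair(ids, pair, new_id):
    """Replace every non-overlapping occurrence of `pair` with `new_id`."""
    out, i = [], 0
    while i < len(ids):
        if i + 1 < len(ids) and (ids[i], ids[i + 1]) == pair:
            out.append(new_id)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out


def train_toy_bpe(text, num_merges):
    """Learn up to `num_merges` merges over the UTF-8 bytes of `text`."""
    ids = list(text.encode("utf-8"))
    merges = {}  # (left_id, right_id) -> new token id
    for step in range(num_merges):
        pair = most_frequent_pair(ids)
        if pair is None:
            break
        new_id = 256 + step  # ids 0..255 are reserved for raw bytes
        merges[pair] = new_id
        ids = merge_pair(ids, pair, new_id)
    return ids, merges


ids, merges = train_toy_bpe("aaabdaaabac", num_merges=3)
print("tokens after merging:", len(ids))
print("learned merges:", merges)
```

Each merge shortens the sequence, which is why a trained tokenizer represents the same text with fewer tokens than raw bytes.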

## 9. Test JAX Model

Quick sanity check that the model works.

In [None]:
!uv run python -c "
import jax
import jax.numpy as jnp
from flax import nnx
from nanochat.jax.gpt import GPT, GPTJaxConfig

config = GPTJaxConfig(
    sequence_len=128,
    vocab_size=50304,
    n_layer=4,
    n_head=4,
    n_kv_head=4,
    n_embd=128,
    dtype=jnp.bfloat16,
)
model = GPT(config, rngs=nnx.Rngs(0))
print('Model created successfully!')

# Test forward pass
x = jax.random.randint(jax.random.key(0), (2, 64), 0, 1000)
y = jax.random.randint(jax.random.key(1), (2, 64), 0, 1000)
loss = model(x, y)
print(f'Forward pass works! Loss: {loss}')
"

## 10. Run Training

Now run the actual training script. Adjust parameters as needed.

In [None]:
# Small test run (adjust parameters for longer training)
!uv run python -m scripts.jax.base_train \
    --depth=4 \
    --max-seq-len=512 \
    --device-batch-size=8 \
    --num-iterations=100 \
    --eval-every=50 \
    --warmup-steps=10 \
    --learning-rate=3e-4
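
The `--warmup-steps=10` and `--learning-rate=3e-4` flags suggest the learning rate ramps up over the first 10 steps. A minimal sketch of one common choice, linear warmup to a constant peak, is below. Whether `scripts.jax.base_train` uses exactly this shape (or adds a decay phase afterwards) is an assumption, and the function name `lr_at_step` is hypothetical.

```python
def lr_at_step(step: int, peak_lr: float = 3e-4, warmup_steps: int = 10) -> float:
    """Linear warmup from peak_lr / warmup_steps up to peak_lr, then constant."""
    if warmup_steps > 0 and step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    return peak_lr


for step in (0, 4, 9, 10, 99):
    print(f"step {step:3d}: lr = {lr_at_step(step):.2e}")
```

With 100 iterations, the warmup covers the first 10% of this test run.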

## 11. Multi-Device Training (Optional)

If the JAX check above reported more than one device, pass `--multi-device` to train across all TPU cores.

In [None]:
# Multi-device training (uses all TPU cores)
!uv run python -m scripts.jax.base_train \
    --depth=4 \
    --max-seq-len=512 \
    --device-batch-size=32 \
    --num-iterations=100 \
    --eval-every=50 \
    --multi-device
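
To see what `--device-batch-size` means for throughput, it helps to compute the tokens processed per training step. The sketch below assumes pure data parallelism (each device processes its own batch of `device_batch_size` sequences) and no gradient accumulation. The device count of 8 is a placeholder; substitute the number reported by `jax.devices()`. The helper name is hypothetical.

```python
def tokens_per_step(device_batch_size: int, max_seq_len: int, num_devices: int = 1) -> int:
    """Tokens consumed per step when every device processes its own batch."""
    return device_batch_size * max_seq_len * num_devices


# Settings from the small single-device test run
single = tokens_per_step(8, 512)
# Settings from the multi-device run, with 8 devices assumed
multi = tokens_per_step(32, 512, num_devices=8)

print(f"single-device: {single:,} tokens/step")
print(f"multi-device:  {multi:,} tokens/step ({multi // single}x)")
```

A larger effective batch often calls for retuning the learning rate, so treat the single-device hyperparameters as a starting point only.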

## 12. Larger Model Training

Once everything works, try a larger model.

In [None]:
# Larger model (depth=12, ~85M params)
!uv run python -m scripts.jax.base_train \
    --depth=12 \
    --max-seq-len=1024 \
    --device-batch-size=16 \
    --num-iterations=1000 \
    --eval-every=100 \
    --warmup-steps=100 \
    --multi-device
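
The ~85M figure in the comment above is consistent with the usual estimate of about 12·d² non-embedding parameters per transformer layer: 4·d² for the attention projections plus 8·d² for an MLP with a 4x hidden width. The sketch assumes the model width is 64 × depth (768 for depth 12), full multi-head attention with as many key/value heads as query heads, and a 4x MLP. All three are assumptions about this port rather than facts stated in this notebook.

```python
def approx_block_params(depth: int, n_embd: int) -> int:
    """Approximate non-embedding parameter count: 12 * d^2 per layer."""
    attn = 4 * n_embd * n_embd  # query, key, value, and output projections
    mlp = 8 * n_embd * n_embd   # up- and down-projection with 4x hidden width
    return depth * (attn + mlp)


depth = 12
n_embd = 64 * depth  # assumed width rule; check the training script for the real one
total = approx_block_params(depth, n_embd)
print(f"~{total / 1e6:.1f}M non-embedding parameters")
```

Embedding and output-projection weights come on top of this and scale with the vocabulary size.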