# Zenith AI + JAX: Neural Network Training

**Train a neural network with JAX using Zenith's high-performance DataLoader**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/vibeswithkk/Zenith-dataplane/blob/main/notebooks/02_jax_training.ipynb)

## What You'll Learn
- Load data with Zenith for JAX training
- Convert Arrow batches to JAX arrays
- Train a simple MLP with high-performance data loading

---
## 1. Install Dependencies

In [None]:
# Install zenith-ai and JAX
!pip install zenith-ai jax jaxlib flax optax datasets pyarrow --quiet

# Verify installation
import zenith
import jax
zenith.info()
print(f"JAX devices: {jax.devices()}")

---
## 2. Prepare Dataset

In [None]:
from datasets import load_dataset
import pyarrow.parquet as pq
import pyarrow as pa
import numpy as np

print("Downloading MNIST dataset...")
dataset = load_dataset("mnist", split="train")
print(f"Dataset size: {len(dataset)} samples")

# Convert to Parquet
print("\nConverting to Parquet format...")

# Walk the dataset once: every row access decodes its image, so collecting
# pixels and labels in the same pass avoids decoding each image twice.
pixel_blobs = []
labels = []
for sample in dataset:
    # Flatten the image, scale to [0, 1] as float32, and keep only the raw
    # bytes so no second list of arrays has to stay in memory.
    pixels = np.array(sample['image']).flatten().astype(np.float32) / 255.0
    pixel_blobs.append(pixels.tobytes())
    labels.append(sample['label'])

# Create Arrow table with flattened images stored as raw float32 bytes
table = pa.table({
    'pixels': pixel_blobs,
    'label': labels,
})

pq.write_table(table, 'mnist_train.parquet')

import os
size_mb = os.path.getsize('mnist_train.parquet') / (1024 * 1024)
print(f"Saved: mnist_train.parquet ({size_mb:.1f} MB)")

---
## 3. Define JAX/Flax Model

In [None]:
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax

class MLP(nn.Module):
    """Three-layer perceptron producing 10 logits per input row.

    Layout: Dense(256) -> ReLU -> Dropout(0.2) -> Dense(128) -> ReLU -> Dense(10).
    """

    @nn.compact
    def __call__(self, x, training: bool = True):
        """Return the logits for ``x``.

        Args:
            x: Input batch fed to the first dense layer.
            training: When true, dropout is applied; when false, the dropout
                layer is deterministic.

        Returns:
            Output of the final 10-unit dense layer.
        """
        # Submodules are created in the same order as before so Flax's
        # automatic names (Dense_0, Dropout_0, Dense_1, Dense_2) are unchanged.
        hidden = nn.relu(nn.Dense(256)(x))
        hidden = nn.Dropout(rate=0.2, deterministic=not training)(hidden)
        hidden = nn.relu(nn.Dense(128)(hidden))
        return nn.Dense(10)(hidden)

# Initialize model
model = MLP()
rng = jax.random.PRNGKey(0)
# Initialize with dropout disabled: dropout owns no parameters, and running it
# here would need a 'dropout' PRNG stream that a single params key does not
# reliably provide.
params = model.init(rng, jnp.ones([1, 784]), training=False)

# Create optimizer and bundle it with the parameters and apply function
tx = optax.adam(learning_rate=0.001)
state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx,
)

print("Model initialized")
# `jax.tree_leaves` was a deprecated alias that has been removed from JAX;
# `jax.tree_util.tree_leaves` is the supported spelling.
num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))
print(f"Parameters: {num_params:,}")

---
## 4. Load Data with Zenith

In [None]:
import zenith

# Build the Zenith DataLoader over the Parquet file, configured with
# batch_size=128 and shuffle=True.
loader = zenith.DataLoader('mnist_train.parquet', batch_size=128, shuffle=True)

print(f"DataLoader: {loader}")

---
## 5. Training Loop

In [None]:
import time

@jax.jit
def train_step(state, images, labels, dropout_rng=None):
    """Run one optimization step and report the batch loss and accuracy.

    Args:
        state: Train state exposing ``apply_fn``, ``params``, ``step`` and
            ``apply_gradients``.
        images: Batch of inputs forwarded to ``state.apply_fn``.
        labels: Integer class labels, one per input row.
        dropout_rng: Optional PRNG key for dropout. When omitted, a key is
            derived from ``state.step`` so each step uses a different mask
            while the run stays reproducible.

    Returns:
        A tuple ``(state, loss, accuracy)``: the state after applying the
        gradients, the mean cross-entropy loss, and the fraction of rows whose
        arg-max logit equals the label.
    """
    if dropout_rng is None:
        # Resolved at trace time; `state.step` advances on every
        # `apply_gradients`, so folding it in yields a fresh key per step.
        dropout_rng = jax.random.fold_in(jax.random.PRNGKey(0), state.step)

    def loss_fn(params):
        # The forward pass runs with training=True, so dropout is active and
        # needs a 'dropout' PRNG stream; applying without one fails.
        logits = state.apply_fn(
            params, images, training=True, rngs={'dropout': dropout_rng}
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        return loss, logits

    # has_aux=True lets the logits ride along with the loss for the accuracy.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    accuracy = (logits.argmax(-1) == labels).mean()
    return state, loss, accuracy

# Training loop: three passes over the loader, printing per-epoch averages.
print("Training with Zenith DataLoader...")
print("-" * 40)

for epoch in range(3):
    start = time.time()
    total_loss = 0
    total_acc = 0
    num_batches = 0

    for batch in loader:
        # Convert Zenith batch to NumPy columns
        # NOTE(review): assumes `to_numpy()` returns a mapping from column
        # name to per-row values - confirm against the Zenith API.
        data = batch.to_numpy()

        # Reconstruct images from their raw float32 bytes. Stacking on the
        # host first means the whole batch is handed to JAX as one array
        # instead of a Python list of per-row arrays.
        rows = [np.frombuffer(p, dtype=np.float32) for p in data['pixels']]
        images = jnp.asarray(np.stack(rows))
        labels = jnp.array(data['label'])

        # Train step
        state, loss, acc = train_step(state, images, labels)
        total_loss += float(loss)
        total_acc += float(acc)
        num_batches += 1

    if num_batches == 0:
        # Without this guard the averages below divide by zero, e.g. when the
        # loader is empty or cannot be iterated a second time.
        raise RuntimeError(
            f"DataLoader produced no batches in epoch {epoch + 1}; "
            "check the dataset file or re-create the loader for each epoch."
        )

    elapsed = time.time() - start
    print(f"Epoch {epoch+1}: Loss={total_loss/num_batches:.4f}, "
          f"Acc={100*total_acc/num_batches:.2f}%, Time={elapsed:.2f}s")

print("-" * 40)
print("Training complete")

---
## Summary

You've learned how to:
1. Load data with `zenith.DataLoader()`
2. Convert to JAX arrays with `batch.to_numpy()` then `jnp.array()`
3. Train JAX/Flax models with Zenith data loading

**GitHub:** https://github.com/vibeswithkk/Zenith-dataplane