In [1]:
# @title Setup
import os

# Must be set before JAX/XLA init to partition host CPU for pmap testing.
# Re-run after restarting the runtime if you need to change this.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# Keep TF's hands off the GPU memory; JAX is the primary compute engine here.
tf.config.set_visible_devices([], 'GPU')

def report_environment():
    backend = jax.default_backend()
    devices = jax.devices()

    print(f"JAX Backend: {backend.upper()}")
    print(f"Primary Devices: {len(devices)}")
    for d in devices:
        print(f" - {d.device_kind} (ID: {d.id})")

    if backend == 'gpu':
        print("\nHardware Driver Status:")
        # Direct check for driver/CUDA alignment
        try:
            !nvidia-smi --query-gpu=driver_version,compute_cap --format=csv,noheader
        except:
            print("nvidia-smi check failed.")

    print(f"\nSoftware Stack:")
    print(f" - JAX: {jax.__version__}")
    print(f" - Local Device Count: {jax.local_devices()}")

report_environment()

JAX Backend: CPU
Primary Devices: 4
 - cpu (ID: 0)
 - cpu (ID: 1)
 - cpu (ID: 2)
 - cpu (ID: 3)

Software Stack:
 - JAX: 0.7.2
 - Local Device Count: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


In [2]:
# @title Data pipeline
import tensorflow_datasets as tfds

def load_cifar10(batch_size, train=True):
    split = 'train' if train else 'test'
    ds, info = tfds.load('cifar10', split=split, with_info=True, as_supervised=True)

    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        label = tf.one_hot(label, 10)
        return image, label

    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    if train:
        ds = ds.shuffle(10000).repeat()

    ds = ds.batch(batch_size, drop_remainder=True)
    # Ensure the host stays ahead of the accelerator.
    ds = ds.prefetch(tf.data.AUTOTUNE)

    # Use as_numpy to avoid TF tensor overhead in JAX.
    return tfds.as_numpy(ds), info

# Initialize generators.
BATCH_SIZE = 64
train_ds_iterable, ds_info = load_cifar10(BATCH_SIZE, train=True)
test_ds_iterable, _ = load_cifar10(BATCH_SIZE, train=False)

# Create iterators for manual stepping.
train_ds = iter(train_ds_iterable)
test_ds = iter(test_ds_iterable)

# Verification.
sample_batch = next(train_ds)
print(f"Batch shapes: Images {sample_batch[0].shape}, Labels {sample_batch[1].shape}")
print(f"Data types:  Images {sample_batch[0].dtype}, Labels {sample_batch[1].dtype}")



Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /root/tensorflow_datasets/cifar10/3.0.2...


Dl Completed...: 0 url [00:00, ? url/s]

Dl Size...: 0 MiB [00:00, ? MiB/s]

Extraction completed...: 0 file [00:00, ? file/s]

Generating splits...:   0%|          | 0/2 [00:00<?, ? splits/s]

Generating train examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/incomplete.8GC940_3.0.2/cifar10-train.tfrecord*...:   0%|         …

Generating test examples...: 0 examples [00:00, ? examples/s]

Shuffling /root/tensorflow_datasets/cifar10/incomplete.8GC940_3.0.2/cifar10-test.tfrecord*...:   0%|          …

Dataset cifar10 downloaded and prepared to /root/tensorflow_datasets/cifar10/3.0.2. Subsequent calls will reuse this data.
Batch shapes: Images (64, 32, 32, 3), Labels (64, 10)
Data types:  Images float32, Labels float32


In [3]:
# @title Model definition
import flax.linen as nn

class CIFAR10CNN(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))

        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        x = nn.Dense(features=10)(x)
        return x

model = CIFAR10CNN()
key = jax.random.PRNGKey(0)

# Initialize with a dummy input to fix parameter shapes.
variables = model.init(key, jnp.ones((1, 32, 32, 3)))
params = variables['params']

# Quick shape verification.
jax.tree_util.tree_map(lambda x: print(f"Layer shape: {x.shape}"), params)

Layer shape: (32,)
Layer shape: (3, 3, 3, 32)
Layer shape: (64,)
Layer shape: (3, 3, 32, 64)
Layer shape: (256,)
Layer shape: (4096, 256)
Layer shape: (10,)
Layer shape: (256, 10)


{'Conv_0': {'bias': None, 'kernel': None},
 'Conv_1': {'bias': None, 'kernel': None},
 'Dense_0': {'bias': None, 'kernel': None},
 'Dense_1': {'bias': None, 'kernel': None}}

In [5]:
# @title Loss and update logic

import optax

# Standard Adam optimizer setup.
learning_rate = 1e-3
optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(params)

def compute_loss(params, images, labels):
    """
    Computes cross-entropy loss between model logits and one-hot labels.
    """
    logits = model.apply({'params': params}, images)
    # Using the current optax API for categorical cross entropy.
    loss = optax.softmax_cross_entropy(logits=logits, labels=labels)
    return jnp.mean(loss)

@jax.jit
def train_step(params, opt_state, images, labels):
    """
    Fuses the gradient computation and parameter update into one XLA kernel.
    """
    # Calculate scalar loss and the gradients for the parameter Pytree.
    loss, grads = jax.value_and_grad(compute_loss)(params, images, labels)

    # Transform gradients into updates based on the optimizer state.
    updates, opt_state = optimizer.update(grads, opt_state, params)

    # Apply those updates to the current parameters.
    params = optax.apply_updates(params, updates)

    return params, opt_state, loss

# Single execution to trigger JIT compilation and verify the step.
params, opt_state, loss = train_step(params, opt_state, sample_batch[0], sample_batch[1])
print(f"Initial loss: {loss:.4f}")

Initial loss: 2.3238


In [6]:
# @title Training loop and evaluation

import time

def train_model(params, opt_state, dataset, num_steps=500):
    """
    Main training loop for CIFAR-10.
    """
    step_losses = []
    start_time = time.time()

    print(f"Starting training for {num_steps} steps...")

    for step in range(num_steps):
        # Pull the next pre-processed batch from the iterator.
        batch_images, batch_labels = next(dataset)

        # Execute the compiled train_step on the GPU.
        params, opt_state, loss = train_step(params, opt_state, batch_images, batch_labels)
        step_losses.append(loss)

        if step % 100 == 0:
            avg_loss = np.mean(step_losses[-100:])
            print(f"Step {step:4d} | Loss: {avg_loss:.4f}")

    end_time = time.time()
    total_time = end_time - start_time

    print(f"\nTraining complete.")
    print(f"Total time: {total_time:.2f}s | Speed: {num_steps/total_time:.2f} steps/s")

    return params, opt_state, step_losses

# Execute the training.
params, opt_state, history = train_model(params, opt_state, train_ds)

Starting training for 500 steps...
Step    0 | Loss: 2.7791
Step  100 | Loss: 1.8862
Step  200 | Loss: 1.5241
Step  300 | Loss: 1.3880
Step  400 | Loss: 1.2813

Training complete.
Total time: 69.32s | Speed: 7.21 steps/s
