In [2]:
# @title Setup
import os

# Must be set before JAX/XLA init to partition host CPU for pmap testing.
# Re-run after restarting the runtime if you need to change this.
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=4'

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# Keep TF's hands off the GPU memory; JAX is the primary compute engine here.
tf.config.set_visible_devices([], 'GPU')

def report_environment():
    backend = jax.default_backend()
    devices = jax.devices()

    print(f"JAX Backend: {backend.upper()}")
    print(f"Primary Devices: {len(devices)}")
    for d in devices:
        print(f" - {d.device_kind} (ID: {d.id})")

    if backend == 'gpu':
        print("\nHardware Driver Status:")
        # Direct check for driver/CUDA alignment
        try:
            !nvidia-smi --query-gpu=driver_version,compute_cap --format=csv,noheader
        except:
            print("nvidia-smi check failed.")

    print(f"\nSoftware Stack:")
    print(f" - JAX: {jax.__version__}")
    print(f" - Local Device Count: {jax.local_devices()}")

report_environment()

JAX Backend: CPU
Primary Devices: 4
 - cpu (ID: 0)
 - cpu (ID: 1)
 - cpu (ID: 2)
 - cpu (ID: 3)

Software Stack:
 - JAX: 0.7.2
 - Local Device Count: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3)]


In [4]:
# @title Data pipeline
import tensorflow_datasets as tfds

def load_cifar10(batch_size, train=True):
    split = 'train' if train else 'test'
    ds, info = tfds.load('cifar10', split=split, with_info=True, as_supervised=True)

    def preprocess(image, label):
        image = tf.cast(image, tf.float32) / 255.0
        label = tf.one_hot(label, 10)
        return image, label

    ds = ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)

    if train:
        ds = ds.shuffle(10000).repeat()

    ds = ds.batch(batch_size, drop_remainder=True)
    # Ensure the host stays ahead of the accelerator.
    ds = ds.prefetch(tf.data.AUTOTUNE)

    # Use as_numpy to avoid TF tensor overhead in JAX.
    return tfds.as_numpy(ds), info

# Initialize generators.
BATCH_SIZE = 64
train_ds_iterable, ds_info = load_cifar10(BATCH_SIZE, train=True)
test_ds_iterable, _ = load_cifar10(BATCH_SIZE, train=False)

# Create iterators for manual stepping.
train_ds = iter(train_ds_iterable)
test_ds = iter(test_ds_iterable)

# Verification.
sample_batch = next(train_ds)
print(f"Batch shapes: Images {sample_batch[0].shape}, Labels {sample_batch[1].shape}")
print(f"Data types:  Images {sample_batch[0].dtype}, Labels {sample_batch[1].dtype}")

Batch shapes: Images (64, 32, 32, 3), Labels (64, 10)
Data types:  Images float32, Labels float32
