In [1]:
# @title Setup
import os

# Partition the host CPU into several logical XLA devices so pmap/sharding
# code can be exercised without accelerators. This must run before JAX is
# imported/initialized; changing it later in the same process has no effect,
# so restart the runtime after editing the count.
HOST_DEVICE_COUNT = 4
_FORCE_DEVICE_FLAG = '--xla_force_host_platform_device_count'

# Merge into any pre-existing XLA_FLAGS instead of clobbering them, dropping a
# previous device-count flag so ours is the only (and therefore winning) one.
_xla_flags = [
    flag for flag in os.environ.get('XLA_FLAGS', '').split()
    if flag.split('=', 1)[0] != _FORCE_DEVICE_FLAG
]
_xla_flags.append(f'{_FORCE_DEVICE_FLAG}={HOST_DEVICE_COUNT}')
os.environ['XLA_FLAGS'] = ' '.join(_xla_flags)

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow as tf

# JAX is the compute engine in this notebook. Hiding every GPU from TensorFlow
# keeps TF from claiming accelerator memory that JAX needs. Per the TF docs,
# this has to happen before TensorFlow initializes its devices.
tf.config.set_visible_devices(devices=[], device_type='GPU')

def report_environment():
    """Print a summary of the JAX backend, visible devices and software stack.

    When the default backend is ``'gpu'``, also runs ``nvidia-smi`` to show the
    driver version and compute capability. A missing or failing ``nvidia-smi``
    is reported in the output rather than raised, so the report always
    completes.
    """
    import subprocess  # Local import: only needed for the GPU driver probe.

    backend = jax.default_backend()
    devices = jax.devices()

    print(f"JAX Backend: {backend.upper()}")
    print(f"Primary Devices: {len(devices)}")
    for d in devices:
        print(f" - {d.device_kind} (ID: {d.id})")

    if backend == 'gpu':
        print("\nHardware Driver Status:")
        # Use a real subprocess instead of IPython `!` magic. Failures then
        # actually raise (and are caught here), the cell stays valid Python,
        # and the captured output is printed into the notebook cell.
        try:
            result = subprocess.run(
                [
                    'nvidia-smi',
                    '--query-gpu=driver_version,compute_cap',
                    '--format=csv,noheader',
                ],
                capture_output=True,
                text=True,
                check=True,
                timeout=30,
            )
        except (OSError, subprocess.SubprocessError) as exc:
            # OSError: binary not found; SubprocessError: non-zero exit/timeout.
            print(f"nvidia-smi check failed: {exc}")
        else:
            print(result.stdout.strip())

    print("\nSoftware Stack:")
    print(f" - JAX: {jax.__version__}")
    # local_device_count() is the number; local_devices() would print the list.
    print(f" - Local Device Count: {jax.local_device_count()}")


report_environment()

--- System Topology Report ---
Primary Backend: CPU
Physical/Logical units available: 4
 - CPU (ID: 0)
 - CPU (ID: 1)
 - CPU (ID: 2)
 - CPU (ID: 3)

Host Backend: CPU
Logical cores partitioned: 4

Software Stack:
 - JAX version: 0.7.2

Execution Context:
 - Process Index: 0
 - Global Device Count: 4
