In [1]:
# Hide every GPU from TensorFlow before anything else touches the device
# runtime; the intent is that TensorFlow does not initialise/reserve the GPU
# that JAX uses below. ``tf.config.set_visible_devices`` is the stable
# (non-experimental) name of this API, and TensorFlow only allows changing
# visible devices before its runtime has initialised them.
import tensorflow as tf
tf.config.set_visible_devices([], "GPU")

# Now inspect what JAX can see.
import jax

print("JAX version:", jax.__version__)

# Query the device list once and reuse it for both the listing and the filter.
all_devices = jax.devices()
print("JAX devices:", all_devices)

# ``gpu_devices`` is reused by later cells to decide where to run.
gpu_devices = [d for d in all_devices if d.platform == "gpu"]
print("GPU devices:", len(gpu_devices))

if gpu_devices:
    print("JAX GPU detected")
    for i, device in enumerate(gpu_devices):
        print(f"  GPU {i}: {device} ({device.device_kind})")
else:
    print("No GPU devices detected")

2026-01-05 17:34:07.929860: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2026-01-05 17:34:07.988122: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


JAX version: 0.4.7
JAX devices: [StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]
GPU devices: 1
JAX GPU detected
  GPU 0: gpu:0 (Tesla V100-PCIE-16GB)


In [2]:
from jax import numpy as jnp


def _add_vectors():
    """Add two small constant vectors on JAX's default device.

    Returns:
        (x, y, z, devices): the two inputs, their sum, and a comma-separated
        string naming the device(s) that hold the sum.
    """
    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([4.0, 5.0, 6.0])
    # JAX dispatches work asynchronously; block here so a device failure is
    # raised at the computation rather than on some later, unrelated line.
    z = (x + y).block_until_ready()
    # ``Array.devices()`` works across JAX releases, whereas the
    # ``Array.device()`` method was later replaced by a ``.device`` property.
    devices = ", ".join(str(d) for d in z.devices())
    return x, y, z, devices


# Smoke-test a tiny computation. ``gpu_devices`` comes from the previous cell.
if len(gpu_devices) > 0:
    try:
        x, y, z, z_devices = _add_vectors()
        print("✅ GPU computation successful!")
        print(f"   Input x: {x}")
        print(f"   Input y: {y}")
        print(f"   Result z: {z}")
        print(f"   Device: {z_devices}")
    except Exception as e:
        # Deliberately best-effort: report the failure and keep the notebook going.
        print(f"❌ GPU computation failed: {e}")
else:
    print("⚠️  Running on CPU (no GPU available)")
    x, y, z, z_devices = _add_vectors()
    print(f"   CPU computation: {z}")
    print(f"   Device: {z_devices}")


✅ GPU computation successful!
   Input x: [1. 2. 3.]
   Input y: [4. 5. 6.]
   Result z: [5. 7. 9.]
   Device: gpu:0
