In [2]:
import os

# Configure the environment for the JAX import in the next cell: expose the
# host CPU as several virtual devices and hide any GPUs.
XLA_DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count=8"  # Simulate 8 devices

flags = os.environ.get("XLA_FLAGS", "")
# Append the flag only when it is missing, so re-running this cell does not
# keep growing XLA_FLAGS with duplicate entries.
if XLA_DEVICE_COUNT_FLAG not in flags.split():
    flags = f"{flags} {XLA_DEVICE_COUNT_FLAG}"
# Enforce CPU-only execution
os.environ["CUDA_VISIBLE_DEVICES"] = ""
os.environ["XLA_FLAGS"] = flags

In [3]:
import functools
from typing import Any, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P

In [27]:
import ml_collections

In [4]:
# List the process-local devices first, then every device JAX reports; in the
# recorded run both listings show the same eight simulated CPU devices.
for list_devices in (jax.local_devices, jax.devices):
    print(list_devices())

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]


In [29]:
a = jnp.arange(8)
print("Array", a)
# help(a)
print("Device", a.device)
print("Sharding", a.sharding)

Array [0 1 2 3 4 5 6 7]
Device TFRT_CPU_0
Sharding SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)


In [12]:
# One-dimensional device mesh spanning every available device; its single
# axis is named "i".
mesh = Mesh(devices=np.array(jax.devices()), axis_names=("i",))
mesh

Mesh(device_ids=array([0, 1, 2, 3, 4, 5, 6, 7]), axis_names=('i',), axis_types={Auto: ('i',)})

In [14]:
# Sharding that partitions an array's first dimension along mesh axis "i".
sharding = NamedSharding(mesh, P("i"))
sharding

NamedSharding(mesh=Mesh('i': 8), spec=PartitionSpec('i',), memory_kind=unpinned_host)

In [30]:
# Place `a` on the devices according to `sharding`.
a_sharded = jax.device_put(a, sharding)

# Report the values, the set of devices now holding the array and its
# sharding.
for label, value in (
    ("Sharded array", a_sharded),
    ("Device", a_sharded.devices()),
    ("Sharding", a_sharded.sharding),
):
    print(label, value)

Sharded array [0 1 2 3 4 5 6 7]
Device {CpuDevice(id=0), CpuDevice(id=7), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=1), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6)}
Sharding NamedSharding(mesh=Mesh('i': 8), spec=PartitionSpec('i',), memory_kind=unpinned_host)


In [31]:
# Draw how the shards of `a_sharded` are laid out across the devices.
jax.debug.visualize_array_sharding(a_sharded)

In [32]:
# Apply tanh element-wise to the sharded input.
out = nn.tanh(a_sharded)
print("Output array", out)
# Visualise the result's layout to check whether it kept the input's sharding.
jax.debug.visualize_array_sharding(out)

Output array [0.         0.7615942  0.9640276  0.9950547  0.9993292  0.99990916
 0.9999876  0.99999833]
