In [1]:
import jax

jax.config.update('jax_num_cpu_devices', 12)

In [2]:
# List every device JAX can see; with jax_num_cpu_devices set to 12 above, this shows 12 CPU devices.
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5),
 CpuDevice(id=6),
 CpuDevice(id=7),
 CpuDevice(id=8),
 CpuDevice(id=9),
 CpuDevice(id=10),
 CpuDevice(id=11)]

In [3]:
import jax.numpy as jnp
import numpy as np

# A small 4x8 float matrix to experiment with. Created without any explicit
# placement, it lives on a single device, so .devices() is a one-element set.
arr = jnp.arange(32.0).reshape((4, 8))
arr.devices()

{CpuDevice(id=0)}

In [4]:
# How `arr` is laid out across devices; for this unplaced array it is a SingleDeviceSharding.
arr.sharding

SingleDeviceSharding(device=CpuDevice(id=0), memory_kind=unpinned_host)

In [5]:
# visualize_array_sharding renders its table with the `rich` package; install it first if needed:
# %pip install rich

jax.debug.visualize_array_sharding(arr)

In [6]:
from jax.sharding import PartitionSpec as P

# A 2x4 logical mesh over 8 devices; 'x' and 'y' are the names of its two axes.
mesh = jax.make_mesh((2, 4), ('x', 'y'))

# Shard an array's dim 0 over mesh axis 'x' and its dim 1 over mesh axis 'y'.
# jax.NamedSharding is the same class as jax.sharding.NamedSharding.
sharding = jax.NamedSharding(mesh, P('x', 'y'))
print(sharding)

NamedSharding(mesh=Mesh('x': 2, 'y': 4, axis_types=(Auto, Auto)), spec=PartitionSpec('x', 'y'), memory_kind=unpinned_host)


In [7]:
# Place `arr` on the 2x4 mesh using the P('x', 'y') sharding: rows are split over
# mesh axis 'x' and columns over mesh axis 'y', so each of the 8 devices holds
# one 2x2 tile of the 4x8 array.
arr_sharded = jax.device_put(arr, device=sharding)

print(arr_sharded)  # shows the full logical 4x8 array
jax.debug.visualize_array_sharding(arr_sharded)

[[ 0.  1.  2.  3.  4.  5.  6.  7.]
 [ 8.  9. 10. 11. 12. 13. 14. 15.]
 [16. 17. 18. 19. 20. 21. 22. 23.]
 [24. 25. 26. 27. 28. 29. 30. 31.]]


# Automatic parallelism via jit

In [8]:
@jax.jit
def f_elementwise(x):
  """Elementwise 2*sin(x) + 1 (the check below shows jit keeps the input's sharding)."""
  scaled_sine = 2 * jnp.sin(x)
  return scaled_sine + 1

result = f_elementwise(arr_sharded)

same_sharding = result.sharding == arr_sharded.sharding
print("shardings match:", same_sharding)

shardings match: True


In [9]:
# Show the elementwise result's values, then how it is laid out across devices.
print(result)
jax.debug.visualize_array_sharding(result)

[[ 1.          2.682942    2.818595    1.28224    -0.513605   -0.9178486
   0.44116902  2.3139732 ]
 [ 2.9787164   1.824237   -0.08804226 -0.99998045 -0.07314587  1.840334
   2.9812148   2.3005757 ]
 [ 0.42419338 -0.92279494 -0.50197446  1.2997544   2.8258905   2.6733112
   0.98229736 -0.69244087]
 [-0.81115675  0.7352965   2.525117    2.912752    1.5418116  -0.32726777
  -0.97606325  0.19192469]]


In [10]:
@jax.jit
def f_contract(x):
  """Sum over axis 0 (the rows).

  For arr_sharded, axis 0 is the dimension sharded over mesh axis 'x', so this
  reduction removes that sharded dimension.
  """
  return jnp.sum(x, axis=0)

result = f_contract(arr_sharded)
jax.debug.visualize_array_sharding(result)
print(result)

[48. 52. 56. 60. 64. 68. 72. 76.]


ChatGPT告诉我

为什么 visualize_array_sharding(result) 会显示 CPU 0,4？
这是因为：

每个 y 列对应多个 x 行的设备（如 x=0 和 x=1）

当你从 P('x', 'y') 变为 P('y')，JAX 会默认让 x 方向上的两个设备共享该 shard

所以结果 shard 在 mesh 的每列上的所有设备中“共享”，比如：

(x=0,y=0) 和 (x=1,y=0) 都持有第一个结果 shard → CPU 0,4

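这是一种 sharding 与 replication 的混合布局：结果在 y 方向上仍然分片，在 x 方向上则是复制的。原因是 sum(axis=0) 消去了沿 x 分片的那一维，x 方向上各设备的部分和需要合并（通常通过一次 all-reduce），合并之后 x 方向上的每台设备都持有同一份完整结果。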

# Explicit sharding

In [11]:
some_array = np.arange(8)
print(f"JAX-level type of some_array: {jax.typeof(some_array)}")

JAX-level type of some_array: ShapedArray(int32[8])


In [12]:
@jax.jit
def foo(x):
  """Return x + x; the print shows the abstract type jit traces x with."""
  # Python side effects such as print run while jit traces the function,
  # not when a cached compiled version is reused.
  traced_type = jax.typeof(x)
  print(f"JAX-level type of x during tracing: {traced_type}")
  doubled = x + x
  return doubled

foo(some_array)

JAX-level type of x during tracing: ShapedArray(int32[8])


Array([ 0,  2,  4,  6,  8, 10, 12, 14], dtype=int32)

In [13]:
from jax.sharding import AxisType

# Same 2x4 mesh shape as before, but both axes are Explicit rather than the
# default Auto: shardings over these axes become part of an array's JAX-level
# type (visible through jax.typeof, e.g. int32[4@X,2]) instead of being left
# entirely to the compiler.
explicit_axis_types = (AxisType.Explicit, AxisType.Explicit)
mesh = jax.make_mesh((2, 4), ("X", "Y"), axis_types=explicit_axis_types)

In [14]:
# `replicated_array` is plain host (NumPy) data, with no sharding in its type.
replicated_array = np.arange(8).reshape(4, 2)
# Split dim 0 over mesh axis "X"; None leaves dim 1 unsplit (replicated over "Y").
x_row_sharding = jax.NamedSharding(mesh, P("X", None))
sharded_array = jax.device_put(replicated_array, x_row_sharding)

print(f"replicated_array type: {jax.typeof(replicated_array)}")
print(f"sharded_array type: {jax.typeof(sharded_array)}")
print(sharded_array)

replicated_array type: ShapedArray(int32[4,2])
sharded_array type: ShapedArray(int32[4@X,2])
[[0 1]
 [2 3]
 [4 5]
 [6 7]]


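We should read a type such as f32[4@X,2] as “a 4-by-2 array of 32-bit floats whose first dimension is sharded along mesh axis ‘X’; the array is replicated along all other mesh axes (here, ‘Y’).” Our `sharded_array` holds integers, so its type prints as int32[4@X,2]. It reads the same way, with 32-bit integers instead of floats.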

In [15]:
# Rows are split 2-way over "X" and replicated over the 4 devices along "Y".
jax.debug.visualize_array_sharding(sharded_array)

In [16]:
# A (4, 1) column sharded over "X" and a (1, 8) row sharded over "Y".
rows_on_x = jax.NamedSharding(mesh, P("X", None))
cols_on_y = jax.NamedSharding(mesh, P(None, "Y"))
arg0 = jax.device_put(np.arange(4).reshape(4, 1), rows_on_x)
arg1 = jax.device_put(np.arange(8).reshape(1, 8), cols_on_y)

@jax.jit
def add_arrays(x, y):
  """Broadcast-add x and y, printing their JAX-level types at trace time.

  For arg0/arg1 below, the result type combines both operands' shardings:
  (4@X, 1) + (1, 8@Y) -> (4@X, 8@Y).
  """
  ans = x + y  # broadcasting: (4, 1) + (1, 8) -> (4, 8)
  print(f"x sharding: {jax.typeof(x)}")
  print(f"y sharding: {jax.typeof(y)}")
  print(f"ans sharding: {jax.typeof(ans)}")
  return ans

# Trace and run with `mesh` installed as the ambient mesh.
with jax.sharding.use_mesh(mesh):
  temp = add_arrays(arg0, arg1)
  jax.debug.visualize_array_sharding(temp)

x sharding: ShapedArray(int32[4@X,1])
y sharding: ShapedArray(int32[1,8@Y])
ans sharding: ShapedArray(int32[4@X,8@Y])


In [17]:
# Call add_arrays again, this time outside the `jax.sharding.use_mesh` block.
# The trace-time prints appear again in the output, so this call was traced
# anew. Bind the result to `temp` so the inspection below describes this call's
# output, not a value left over from the previous cell.
temp = add_arrays(arg0, arg1)

print(f"整个的求和结果是 {temp}")

print("试着查看 CPU0 上的数据")
# With type int32[4@X,8@Y] on the 2x4 mesh, each device holds a 2x2 block.
for s in temp.addressable_shards:
    if s.device.id == 0:  # only look at CPU 0
        print(f"Data on CPU 0:\n{s.data}")

x sharding: ShapedArray(int32[4@X,1])
y sharding: ShapedArray(int32[1,8@Y])
ans sharding: ShapedArray(int32[4@X,8@Y])
整个的求和结果是 [[ 0  1  2  3  4  5  6  7]
 [ 1  2  3  4  5  6  7  8]
 [ 2  3  4  5  6  7  8  9]
 [ 3  4  5  6  7  8  9 10]]
试着查看 CPU0 上的数据
Data on CPU 0:
[[0 1]
 [1 2]]


# Manual parallelism with shard_map

In [18]:
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

# Newer JAX releases expose shard_map at the top level as `jax.shard_map` and
# deprecate the `jax.experimental.shard_map` path. Prefer the new location and
# fall back to the experimental one on older JAX versions.
try:
  from jax import shard_map
except ImportError:
  from jax.experimental.shard_map import shard_map


# A simple per-shard function for this demo. NOTE: this rebinds the name
# `f_elementwise`, replacing the jitted 2*sin(x)+1 version defined earlier.
def f_elementwise(x):
  return x * 2

# A 1-D mesh over 8 devices with a single axis named 'x'.
mesh = jax.make_mesh((8,), ('x',))

# shard_map runs f_elementwise once per device on that device's block: the
# input is split along mesh axis 'x' and the per-device outputs are
# concatenated back along 'x'.
f_elementwise_sharded = shard_map(
    f_elementwise,
    mesh=mesh,
    in_specs=P('x'),
    out_specs=P('x'))

arr = jnp.arange(32)
arr_doubled = f_elementwise_sharded(arr)
# Visualize the shard_map *output*, which is sharded over 'x'. The input `arr`
# is an ordinary unplaced (single-device) array, so visualizing it would not
# show anything shard_map did.
jax.debug.visualize_array_sharding(arr_doubled)

In [19]:
x = jnp.arange(32)
print(f"global shape: {x.shape=}")

def f(x):
  """Double each element; the print shows the per-device block shape inside shard_map."""
  print(f"device local shape: {x.shape=}")
  return x * 2

doubled = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))
y = doubled(x)
jax.debug.visualize_array_sharding(y)

global shape: x.shape=(32,)
device local shape: x.shape=(4,)


In [20]:
def f(x):
  """Per-device partial sum, kept as a length-1 array.

  keepdims=True matters here: with out_specs=P('x') the per-device outputs are
  concatenated along axis 0, so each block needs at least one dimension. A
  plain scalar result would make this shard_map call fail.
  """
  return x.sum(keepdims=True)

# Place x on the mesh explicitly so we can inspect which block each device holds.
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
for s in x_sharded.addressable_shards:
    print(f"Data on {s.device}: {s.data}")

jax.debug.visualize_array_sharding(x_sharded)
# Pass x_sharded (not the unplaced x) so the computation uses exactly the
# layout printed above.
z = shard_map(f, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x_sharded)
print(z)
jax.debug.visualize_array_sharding(z)

Data on TFRT_CPU_0: [0 1 2 3]
Data on TFRT_CPU_1: [4 5 6 7]
Data on TFRT_CPU_2: [ 8  9 10 11]
Data on TFRT_CPU_3: [12 13 14 15]
Data on TFRT_CPU_4: [16 17 18 19]
Data on TFRT_CPU_5: [20 21 22 23]
Data on TFRT_CPU_6: [24 25 26 27]
Data on TFRT_CPU_7: [28 29 30 31]


[  6  22  38  54  70  86 102 118]


In [21]:
def g(x):
  """Sum across all shards: a local partial sum, then psum over mesh axis 'x'.

  After psum every device holds the same total. keepdims=True keeps it a
  length-1 array, which the out_specs=P('x') call below needs (blocks are
  concatenated along axis 0). With out_specs=P() a scalar would also work.
  """
  local_total = jnp.sum(x, keepdims=True)
  return jax.lax.psum(local_total, 'x')

# out_specs=P(): the result is replicated, so a single copy is returned -> [496].
print(shard_map(g, mesh=mesh, in_specs=P('x'), out_specs=P())(x))
# out_specs=P('x'): every device's (identical) block is concatenated along 'x',
# giving 8 copies of the total.
print(shard_map(g, mesh=mesh, in_specs=P('x'), out_specs=P('x'))(x))

[496]
[496 496 496 496 496 496 496 496]


# Comparing the three approaches

In [22]:
@jax.jit
def layer(x, weights, bias):
  """One dense layer with a sigmoid activation: sigmoid(x @ weights + bias)."""
  pre_activation = jnp.matmul(x, weights) + bias
  return jax.nn.sigmoid(pre_activation)

In [23]:
import numpy as np
rng = np.random.default_rng(0)

# Draw x (32,), weights (32, 4) and bias (4,) in that order from the seeded
# generator, converting each to a JAX array.
x, weights, bias = (
    jnp.array(rng.normal(size=shape)) for shape in ((32,), (32, 4), (4,))
)

layer(x, weights, bias)

Array([0.02138916, 0.8931118 , 0.5989196 , 0.9774251 ], dtype=float32)

In [24]:
# The 8-device 1-D mesh with axis 'x'.
mesh = jax.make_mesh((8,), ('x',))
# Split the input vector over the 8 devices; keep a full copy of the weights on each.
x_sharded = jax.device_put(x, jax.NamedSharding(mesh, P('x')))
weights_sharded = jax.device_put(weights, jax.NamedSharding(mesh, P()))

# The same jitted `layer` as before; jit handles the distributed computation and
# the values match the single-device result.
print(layer(x_sharded, weights_sharded, bias))

for placed in (x_sharded, weights_sharded):
  jax.debug.visualize_array_sharding(placed)

[0.02138916 0.8931118  0.5989196  0.9774251 ]


In [25]:
# Same 1-D layout as above, but the single mesh axis 'X' is Explicit, so the
# shardings show up in the JAX-level types printed by jax.typeof.
explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))

x_sharded = jax.device_put(x, jax.NamedSharding(explicit_mesh, P('X')))  # split over 'X'
jax.debug.visualize_array_sharding(x_sharded)
weights_sharded = jax.device_put(weights, jax.NamedSharding(explicit_mesh, P()))  # replicated

@jax.jit
def layer_auto(x, weights, bias):
  """Apply `layer`, printing the operand and result types at trace time.

  NOTE: despite the `_auto` suffix, x here lives on the Explicit mesh, so its
  type carries the sharding (float32[32@X]); weights and bias carry no `@`.
  """
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  activations = layer(x, weights, bias)
  print(f"out sharding: {jax.typeof(activations)}")
  return activations

with jax.sharding.use_mesh(explicit_mesh):
  layer_auto(x_sharded, weights_sharded, bias)

x sharding: ShapedArray(float32[32@X])
weights sharding: ShapedArray(float32[32,4])
bias sharding: ShapedArray(float32[4])
out sharding: ShapedArray(float32[4])


In [26]:
explicit_mesh = jax.make_mesh((8,), ('X',), axis_types=(AxisType.Explicit,))
z = jnp.arange(32)
# Split z into 8 blocks of 4 along the explicit mesh axis 'X'; its type records this as 32@X.
z_sharded = jax.device_put(z, jax.NamedSharding(explicit_mesh, P('X')))
print(f"z sharding: {jax.typeof(z_sharded)}")

z sharding: ShapedArray(int32[32@X])


上一个单元格的输出中，变量 z 是一个 JAX 数组，它：
- 数据类型为 int32 (32位整数)。
- 有一个维度，其全局大小为 32。
- 这个维度是沿着一个名为 X 的设备网格轴进行分片 (sharded) 的。

简单来说，z 是一个包含 32 个整数的数组，沿网格轴 X 分片。X 轴的大小为 8，所以每台设备持有 4 个连续元素（例如设备 0 持有 0–3，设备 1 持有 4–7，依此类推）。

In [30]:
from functools import partial

@jax.jit
@partial(shard_map, mesh=mesh,
         in_specs=(P('x'), P('x', None), P(None)),
         out_specs=P(None))
def layer_sharded(x, weights, bias):
  """Manually sharded `layer`: a per-device partial matmul followed by a psum.

  Inside shard_map each device sees only its blocks. x is split over 'x'
  (32 -> 4 elements), weights are split row-wise over 'x' (32x4 -> 4x4), and
  bias is replicated. x_i @ W_i is this device's contribution to x @ W; psum
  over 'x' adds the contributions so every device holds the full pre-activation.
  """
  print(f"x sharding: {jax.typeof(x)}")
  print(f"weights sharding: {jax.typeof(weights)}")
  print(f"bias sharding: {jax.typeof(bias)}")
  partial_product = x @ weights
  full_product = jax.lax.psum(partial_product, 'x')
  return jax.nn.sigmoid(full_product + bias)

print(layer_sharded(x, weights, bias))

print("注意到 x 的长度是32，而 mesh 的长度是8，所以 x 会被分成8份，每份的长度是4")

x sharding: ShapedArray(float32[4])
weights sharding: ShapedArray(float32[4,4])
bias sharding: ShapedArray(float32[4])
[0.02138916 0.8931118  0.5989196  0.9774251 ]
注意到 x 的长度是32，而 mesh 的长度是8，所以 x 会被分成8份，每份的长度是4
