In [14]:
import os
import numpy as np
# sets number of CPU device count
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=6"

In [15]:
import jax
jax.config.update('jax_platform_name', 'cpu')
jax.devices()

[CpuDevice(id=0),
 CpuDevice(id=1),
 CpuDevice(id=2),
 CpuDevice(id=3),
 CpuDevice(id=4),
 CpuDevice(id=5)]

In [3]:
jax.device_count()

6

In [6]:
from functools import partial
import jax
import jax.numpy as jnp
@partial(jax.pmap, axis_name='rows')
@partial(jax.pmap, axis_name='cols')
def normalize(x):
    row_normed = x / jax.lax.psum(x, 'rows')
    col_normed = x / jax.lax.psum(x, 'cols')
    doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
    return row_normed, col_normed, doubly_normed

x = jnp.arange(6.).reshape((3, 2))
row_normed, col_normed, doubly_normed = normalize(x)  
print(row_normed.sum(0))  
print(col_normed.sum(1))  
print(doubly_normed.sum((0, 1)))  

[1. 1.]
[1. 1. 1.]
1.0000001


In [7]:
f = lambda x: x + jax.lax.psum(x, axis_name='i')
data = jnp.arange(6) if jax.process_index() == 0 else jnp.arange(6, 12)
out = pmap(f, axis_name='i')(data)
print(out)  

AttributeError: module 'jax' has no attribute 'process_index'

In [26]:
from functools import partial
@partial(jax.pmap, axis_name='i', devices=jax.devices()[:6])
def f1(x):
    return x / jax.lax.psum(x, axis_name='i')

@partial(jax.pmap, axis_name='i', devices=jax.devices()[-2:])
def f2(x):
    return jax.lax.psum(x ** 2, axis_name='i')

def func_par(x, devices=jax.devices()):
    @partial(jax.pmap, axis_name='i', devices=devices)
    def f2(x):
        return jax.lax.psum(x ** 2, axis_name='i')
    
    return f2(x)

def func_eig(x, devices=jax.devices()):
    @partial(jax.pmap, devices=devices)
    def f3(x):
        return jax.scipy.linalg.eigh(x)
    
    return f3(x)

print(f1(jnp.arange(6.)))  
print(func_par(jnp.array([2., 3.]), jax.devices()[-2:]))  

[0.         0.06666667 0.13333334 0.2        0.26666668 0.33333334]
[13. 13.]


In [29]:
a = np.random.rand(128, 128)
aj = jnp.asarray(a)
func_eig(aj, jax.devices()[0])

TypeError: 'jaxlib.xla_extension.CpuDevice' object is not iterable

In [30]:
jax.devices()[0]

CpuDevice(id=0)