In [1]:
import os
import numpy as np
# sets number of CPU device count
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=20"

In [2]:
import jax
jax.config.update('jax_platform_name', 'cpu')
cpus = jax.devices('cpu')
cpu_count = jax.device_count()

print(f"CPU count = {cpu_count}")
print(f"CPU list = {cpus}")

CPU count = 20
CPU list = [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9), CpuDevice(id=10), CpuDevice(id=11), CpuDevice(id=12), CpuDevice(id=13), CpuDevice(id=14), CpuDevice(id=15), CpuDevice(id=16), CpuDevice(id=17), CpuDevice(id=18), CpuDevice(id=19)]


In [4]:
from functools import partial
import jax
import jax.numpy as jnp
@partial(jax.pmap, axis_name='rows')
@partial(jax.pmap, axis_name='cols')
def normalize(x):
    row_normed = x / jax.lax.psum(x, 'rows')
    col_normed = x / jax.lax.psum(x, 'cols')
    doubly_normed = x / jax.lax.psum(x, ('rows', 'cols'))
    return row_normed, col_normed, doubly_normed

x = jnp.arange(6.).reshape((3, 2))
row_normed, col_normed, doubly_normed = normalize(x)  
print(row_normed.sum(0))  
print(col_normed.sum(1))  
print(doubly_normed.sum((0, 1)))  

[1. 1.]
[1. 1. 1.]
1.0000001


In [5]:
f = lambda x: x + jax.lax.psum(x, axis_name='i')
data = jnp.arange(6) if jax.process_index() == 0 else jnp.arange(6, 12)
out = pmap(f, axis_name='i')(data)
print(out)  

AttributeError: module 'jax' has no attribute 'process_index'

In [13]:
from functools import partial
@partial(jax.pmap, axis_name='i', devices=jax.devices()[:6])
def f1(x):
    return x / jax.lax.psum(x, axis_name='i')

@partial(jax.pmap, axis_name='i', devices=jax.devices()[-2:])
def f2(x):
    return jax.lax.psum(x ** 2, axis_name='i')

def func_par(x, devices=jax.devices()):
    @partial(jax.pmap, axis_name='i', devices=devices)
    def f2(x):
        return jax.lax.psum(x ** 2, axis_name='i')
    
    return f2(x)

def func_eig(x, devices=jax.devices()):
    @partial(jax.pmap, devices=devices)
    def f3(x):
        return jax.scipy.linalg.eigh(x)
    
    return f3(x)

@jax.jit
def calceig(x):
    return jax.scipy.linalg.eigh(x)

print(f1(jnp.arange(6.)))  
print(func_par(jnp.array([2., 3.]), jax.devices()[-2:]))  

[0.         0.06666667 0.13333334 0.2        0.26666668 0.33333334]
[13. 13.]


In [15]:
a = np.random.rand(128, 128)
aj = jnp.asarray(a)
calceig(aj)
#func_eig(aj, devices=cpus[:5])

(Buffer([-4.47476959e+00, -4.32579851e+00, -4.13233232e+00,
         -4.02423859e+00, -3.95493865e+00, -3.90555954e+00,
         -3.76627660e+00, -3.69392967e+00, -3.44470358e+00,
         -3.41919041e+00, -3.30953288e+00, -3.27461457e+00,
         -3.16211414e+00, -3.04298115e+00, -3.00135088e+00,
         -2.94170976e+00, -2.93453813e+00, -2.80347967e+00,
         -2.66035557e+00, -2.58380842e+00, -2.47034979e+00,
         -2.40401959e+00, -2.37228274e+00, -2.35917568e+00,
         -2.28182244e+00, -2.22655416e+00, -2.18811488e+00,
         -2.12785149e+00, -2.03857446e+00, -1.98961008e+00,
         -1.90002215e+00, -1.84667826e+00, -1.83442557e+00,
         -1.75481093e+00, -1.63829613e+00, -1.56548703e+00,
         -1.48740411e+00, -1.44238770e+00, -1.37465513e+00,
         -1.35969830e+00, -1.27306783e+00, -1.22310758e+00,
         -1.17699921e+00, -1.06548786e+00, -1.02089632e+00,
         -9.72605348e-01, -9.12236869e-01, -8.29795182e-01,
         -7.88799226e-01, -7.63289511e-0

In [None]:
jax.devices()[0]