In [1]:
# Reload edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Use the interactive widget backend for matplotlib figures.
%matplotlib widget

In [2]:
import time

import numpy as np

from jaxfi import jaxm
from jaxfi.experimental.auto_pmap import auto_pmap
from mpcjax import solve, solve_problems
from mpcjax import utils

import jax

from dynamics import f_fx_fu_fn

In [3]:
def fn(x, y, z):
    """Return the pair ``(x * y["shape"] - z, 2 - z)``.

    ``y`` is a mapping whose ``"shape"`` entry is the factor applied to ``x``.
    """
    scaled = x * y["shape"]
    shifted = scaled - z
    offset = 2 - z
    return shifted, offset


# Example inputs: a length-10 random vector, a dict holding one random scalar,
# and a scalar constant.
x = jaxm.randn(10)
y = {"shape": jaxm.randn(())}
z = jaxm.array(10)

# Earlier variant kept for reference (jit + utils.auto_sharding instead of auto_pmap):
#jaxm.jit(utils.auto_sharding(fn, in_axes=(0, dict(shape=None), None)))(x, y, z)
# in_axes mirrors the argument structure: 0 for x, None for y["shape"] and for z
# (presumably vmap-style semantics: map over axis 0 of x, share the rest — confirm).
auto_pmap(fn, in_axes=(0, dict(shape=None), None))(x, y, z)

(Array([-10.062562 , -11.92967  ,  -9.772968 ,  -7.0987687,  -9.095894 ,
         -8.501585 ,  -9.0253725, -11.4022   , -10.289578 ,  -8.640935 ],      dtype=float32),
 Array([-8., -8., -8., -8., -8., -8., -8., -8., -8., -8.], dtype=float32))

In [4]:
def wait(x):
    """Sleep for one second, then return a scalar int32 ``1``; ``x`` is ignored."""
    delay_seconds = 1
    time.sleep(delay_seconds)
    one = jaxm.array(1, dtype=np.int32)
    return one

def pure_fn(x):
    """Invoke ``wait`` through ``jax.pure_callback``.

    The callback result is declared as a scalar (shape ``()``) with dtype int32.
    """
    result_spec = jax.ShapeDtypeStruct((), np.int32)
    return jax.pure_callback(wait, result_spec, x)

In [5]:
# Direct call without any mapping over the length-10 input.
pure_fn(jaxm.randn((10)))

Array(1, dtype=int32)

In [42]:
# Map pure_fn over axis 0 of the length-10 input.
# NOTE(review): this calls utils.auto_pmap, whereas the earlier demo cell uses
# auto_pmap imported from jaxfi.experimental.auto_pmap — confirm which is intended.
utils.auto_pmap(pure_fn, in_axes=0)(jaxm.randn((10)))

floor_size: 6, batch_size: 10


Array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=int32)

In [8]:
# N is the leading axis of every stacked array below; xdim / udim are the
# trailing sizes of the X_* / U_* arrays.
N, xdim, udim = 20, 4, 2

# One weight matrix per step, stacked along axis 0.
Q = np.stack(N * [np.eye(xdim)])
R = np.stack(N * [1e-2 * np.eye(udim)])
x0 = np.ones(xdim)

# References and previous iterates all start at zero.
X_ref = np.zeros((N, xdim))
U_ref = np.zeros((N, udim))
X_prev = np.zeros((N, xdim))
U_prev = np.zeros((N, udim))

# Symmetric lower / upper bounds of magnitude u_lim.
u_lim = 1e0
u_l = np.full((N, udim), -u_lim)
u_u = np.full((N, udim), u_lim)

problem = {
    "f_fx_fu_fn": f_fx_fu_fn,
    "Q": Q,
    "R": R,
    "x0": x0,
    "X_ref": X_ref,
    "U_ref": U_ref,
    "X_prev": X_prev,
    "U_prev": U_prev,
    "u_l": u_l,
    "u_u": u_u,
    "solver_settings": {
        "smooth_alpha": 1e2,
        "solver": "sqp",
        "linesearch": "scan",
        "maxls": 50,
    },
    "reg_x": 1e0,
    "reg_u": 1e-1,
    "max_it": 10,
    "res_tol": 1e-7,
    "verbose": True,
    "slew_rate": 1e-2,
    "P": 1.0 * jaxm.ones((N,)),
    "dtype": np.float64,
    "device": "cpu",
}
# Pass the whole problem dict through jaxm.to with a float64 target dtype.
problem = jaxm.to(problem, dtype=jaxm.float64)

In [21]:
# Alternative entry points kept for reference (single solve / one-problem batch):
#X1, U1, _ = solve(**problem, direct_solve=True)
#X1, U1, _ = solve_problems([problem], direct_solve=True)
# Solve seven copies of the same problem in one call and unpack the first result.
sols = solve_problems(7 * [problem], direct_solve=True)
X1, U1, _ = sols[0]

+------+------------+------------+------------+----------+----------+----------+
|  it  |   elaps    |    obj     |   resid    |  reg_x   |  reg_u   |  alpha   |
+------+------------+------------+------------+----------+----------+----------+
| 0001 |  3.642e-02 | -3.758e+00 |  1.934e+00 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0002 |  6.806e-02 | -3.998e+00 |  1.553e-02 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0003 |  1.003e-01 | -3.998e+00 |  3.470e-03 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0004 |  1.327e-01 | -3.998e+00 |  2.126e-03 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0005 |  1.652e-01 | -3.998e+00 |  1.452e-03 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0006 |  1.953e-01 | -3.998e+00 |  9.921e-04 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0007 |  2.272e-01 | -3.998e+00 |  6.790e-04 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0008 |  2.601e-01 | -3.998e+00 |  4.654e-04 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0009 |  2.936e-01 | -3.998e+00 |  3.194e-04 |  1.0e+00 |  1.0e-01 |  1.0e+02 |
| 0010 |  3.253e-01 | -3.998

In [None]:
# Check the dtype of X1 returned by the solve above.
X1.dtype

dtype('float64')

In [23]:
from jax.experimental.maps import xmap


def f(x, y):
    """Return the product of ``x`` and ``y``."""
    product = x * y
    return product


import jax.numpy as jnp

xs = jnp.arange(12).reshape(4, 3)  # shape (4, 3); axis 0 is mapped as the named axis "i"
y = jaxm.randn((4, 3))  # requested shape (4, 3); also mapped along "i" below, so it is NOT shared across the batch


# xmap configuration: both inputs and the output carry the named axis "i" on axis 0
f_xmap = xmap(
    f,
    # Earlier dict-style axis specification, kept for reference:
    #in_axes=({0: "batch"}, {0: None}),
    #out_axes={0: "batch"},
    in_axes=(["i", ...], ["i", ...]),
    out_axes=["i", ...],
)

result_xmap = f_xmap(xs, y)

In [24]:
# Display the product computed by f_xmap.
result_xmap

Array([[  0.        ,   1.4197668 ,   0.48174685],
       [ -4.028483  ,   1.479233  ,   5.863267  ],
       [  4.3966804 ,   5.27136   ,  -1.712921  ],
       [-12.868704  , -14.880747  ,  -7.1762233 ]], dtype=float32)

In [32]:
# NOTE: x and y are rebound here and replace the values from earlier cells.
x = jaxm.randn((7, 12))
y = jnp.arange(1) + 1e3  # single-element array holding 1000.0
# jit wrapping of fn_call is currently disabled:
#jaxm.jit(

def fn_call(x, y):
    """Multiply ``x`` by ``y`` under ``xmap``.

    ``x`` carries the named axis ``"i"`` on its leading dimension, ``y`` is given
    no named axes, and the result carries ``"i"`` on its leading dimension.
    """
    mapped_mul = xmap(
        lambda a, b: a * b,
        in_axes=(["i", ...], [...]),
        # in_axes=(['i', ...], None),
        out_axes=["i", ...],
    )
    return mapped_mul(x, y)

In [34]:
# Call the xmap-based helper directly on x and y. The stray ")" after the
# function name made this cell unparsable; it is removed here.
# NOTE(review): the AttributeError recorded as this cell's output presumably
# comes from an earlier run with a different call form — re-run to confirm.
fn_call(x, y)

AttributeError: 'NoneType' object has no attribute 'empty'

In [92]:
import jax
from jax.experimental import mesh_utils
from jax.sharding import PositionalSharding

# Build a 1-D mesh from the first four CPU devices and a positional sharding over it.
# NOTE(review): this needs at least four CPU devices to be visible to JAX —
# confirm how the host device count is configured.
devices = mesh_utils.create_device_mesh((4,), devices=jax.devices("cpu")[:4])
sharding = PositionalSharding(devices)

In [93]:
# View the same sharding arranged as a column of devices (one device per row).
sharding.reshape((-1, 1))

PositionalSharding([[{CPU 0}]
                    [{CPU 1}]
                    [{CPU 2}]
                    [{CPU 3}]])

In [95]:
@jaxm.jit
def matmul(x, W, b):
    """Apply ``jnp.dot(., W) + b`` three times, with ``sin`` after the 2nd and 3rd.

    The first affine step has no nonlinearity; the two steps that follow are each
    wrapped in ``jaxm.sin``.
    """
    z = jnp.dot(x, W) + b
    # The two remaining layers are identical, so express them as a short loop.
    for _ in range(2):
        z = jaxm.sin(jnp.dot(z, W) + b)
    return z

In [96]:
# Random benchmark inputs requested on the CPU: x is (1024, 2048), W is (2048, 2048),
# b is (2048,). NOTE: x is rebound again here.
x = jaxm.randn((1024, 2048), device="cpu")
W, b = jaxm.randn((2048, 2048), device="cpu"), jaxm.randn((2048,), device="cpu")

In [99]:
# Split x into 4 chunks along a new leading axis and place it with the sharding
# reshaped so that its 4 devices lie along that leading axis.
x_reshape = jax.device_put(x.reshape((4, -1, x.shape[-1])), sharding.reshape((-1, 1, 1)))
# Visualize the placement after flattening back to 2-D with x.shape[0] rows.
jax.debug.visualize_array_sharding(x_reshape.reshape((x.shape[0], -1)))

In [118]:
# Baseline: vmap matmul over axis 0 of x while sharing W and b, then jit the result.
matmul_vmap = jaxm.jit(jaxm.vmap(matmul, in_axes=(0, None, None)))

In [124]:
# Time the vmapped baseline; block_until_ready makes the timing wait for the result.
%timeit matmul_vmap(x, W, b).block_until_ready()

39.6 ms ± 798 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [125]:
# Place W and b using the replicated form of the sharding.
W_replicate = jax.device_put(W, sharding.replicate())
b_replicate = jax.device_put(b, sharding.replicate())

In [126]:
# Time the jitted matmul on the sharded input with the replicated W and b.
%timeit matmul(x_reshape, W_replicate, b_replicate).block_until_ready()

36.7 ms ± 1.58 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [123]:
import jax
# Visualize how the un-reshaped x is laid out across devices, for comparison.
jax.debug.visualize_array_sharding(x)

In [10]:
# Unpack the first entry of utils.arg_hist into x, y and axes.
# NOTE(review): arg_hist is presumably filled as a side effect of an earlier call
# in this session — this cell will not work on a fresh kernel unless that call ran.
x, y, axes = utils.arg_hist[0]

In [20]:
# For every entry whose axis is not None, check that x_ has exactly one more
# dimension than y_, that its leading size is 2, and that its dimensions from
# index 2 onward equal y_'s dimensions from index 1 onward.
for x_, y_, ax in zip(x, y, axes):
    if ax is not None:
        assert x_.ndim - 1 == y_.ndim
        assert x_.shape[0] == 2
        assert x_.shape[2:] == y_.shape[1:]
