JAX is very useful, and it would be nice if SymPy supported lambdifying expressions with it. I can see a couple of different features:
- Swap `numpy` and `scipy` calls for `jax.numpy` and `jax.scipy` when available.
- Replace derivatives of scalar-valued functions with JAX derivatives, so that e.g. `Derivative(f(x,y), (x,2))` maps to `jax.grad(jax.grad(f, 0), 0)`.
- Jacobians or Jacobian-vector products of `Array`. This seems less straightforward or meaningful, since each array element is treated separately and is not required to have the same arguments.
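A minimal sketch of what the second and third items could map to, written directly in JAX. It assumes the SymPy expression has already been turned into plain Python functions `f` and `F`. Those function bodies are illustrative stand-ins, not something `lambdify` generates today. The `Array` case assumes every element shares the same arguments, which the issue notes is not guaranteed in general.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for a lambdified scalar expression:
# f(x, y) = x**3 * y + sin(y)
def f(x, y):
    return x**3 * y + jnp.sin(y)

# Derivative(f(x, y), (x, 2)): differentiate twice w.r.t. argument 0 (x).
d2f_dx2 = jax.grad(jax.grad(f, 0), 0)

# Derivative(f(x, y), x, y): inner grad w.r.t. x (argument 0),
# outer grad w.r.t. y (argument 1).
d2f_dxdy = jax.grad(jax.grad(f, 0), 1)

# Array-valued case. This only works when every element depends on the same
# arguments, so they are packed into a single vector input v = [x, y].
def F(v):
    x, y = v[0], v[1]
    return jnp.stack([x * y, jnp.sin(x) + y**2])

v = jnp.array([1.0, 2.0])
u = jnp.array([1.0, 0.0])

J = jax.jacfwd(F)(v)             # full 2x2 Jacobian of F at v
Fv, Ju = jax.jvp(F, (v,), (u,))  # J @ u, computed without materializing J

print(d2f_dx2(2.0, 3.0))   # analytically 6*x*y
print(d2f_dxdy(2.0, 3.0))  # analytically 3*x**2
print(J)
print(Ju)
```

The nested `jax.grad` calls mirror the structure of `Derivative`'s variable/count pairs, one `grad` per differentiation, with the argument index taken from the variable's position in the function signature. For arrays, `jax.jvp` is usually preferable to forming the full Jacobian when only directional derivatives are needed.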