Support JAX for lambdify #21232

Open · jjaraalm opened this issue Apr 2, 2021 · 1 comment

Comments

jjaraalm commented Apr 2, 2021

JAX is very useful, and it would be nice if SymPy supported lambdifying expressions with it. I can see a few possible features:

  • Swap numpy and scipy calls for jax.numpy and jax.scipy when available.
  • Replace derivatives of scalar-valued functions with JAX derivatives, so that, e.g., Derivative(f(x,y), (x,2)) maps to jax.grad(jax.grad(f, 0), 0).
  • Jacobians or Jacobian-vector products of Array expressions. This seems less straightforward, and possibly less meaningful, since each array element is treated separately and is not required to have the same arguments.
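The first item can be approximated today through lambdify's existing modules argument, which accepts a dictionary mapping printed function names to implementations. A minimal sketch, assuming JAX and SymPy are installed: JAX_NAMESPACE is a hypothetical, hand-written mapping (not something SymPy ships), and it has to cover every function the expression uses, since names it omits may fall back to the standard math module, which cannot operate on traced JAX arrays.

```python
import jax
import jax.numpy as jnp
from sympy import exp, lambdify, log, sin, symbols

# Hypothetical namespace: route the SymPy function names we use to jax.numpy.
# List every function the expression needs; anything missing here may be
# resolved from `math` instead, which breaks jax.jit / jax.grad tracing.
JAX_NAMESPACE = {
    "sin": jnp.sin,
    "cos": jnp.cos,
    "exp": jnp.exp,
    "log": jnp.log,
}

x, y = symbols("x y")
expr = sin(x) * exp(y) + log(1 + x**2)

# With only a dict in `modules`, lambdify prints plain names (sin, exp, log)
# and looks them up in JAX_NAMESPACE when the generated function runs.
f = lambdify((x, y), expr, modules=[JAX_NAMESPACE])

f_jit = jax.jit(f)                      # compiled scalar evaluation
f_vec = jax.vmap(f, in_axes=(0, None))  # vectorized over x, y held fixed

value = f_jit(1.0, 0.0)
values = f_vec(jnp.linspace(0.0, 1.0, 5), 0.0)
```

Because the generated function only calls jax.numpy routines, the usual JAX transformations (jit, vmap, grad) compose with it without any change to SymPy itself.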
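For the derivative item, one way to map an unevaluated Derivative onto nested jax.grad calls is to lambdify the undifferentiated expression once and wrap it in one jax.grad per differentiation order, reading the (variable, order) pairs from Derivative.variable_count. This is a sketch, assuming JAX and SymPy are installed; derivative_to_jax and JAX_NAMESPACE are hypothetical names, and the sketch uses an explicit expression g rather than an undefined function such as f(x, y). The JAX result is cross-checked against SymPy's own symbolic derivative.

```python
import jax
import jax.numpy as jnp
from sympy import Derivative, lambdify, sin, symbols

# Hypothetical mapping from printed SymPy names to jax.numpy functions.
JAX_NAMESPACE = {"sin": jnp.sin}


def derivative_to_jax(deriv, args, namespace=JAX_NAMESPACE):
    """Hypothetical helper: Derivative(g, (v1, n1), ...) -> nested jax.grad.

    `args` is the ordered tuple of symbols the compiled function takes; every
    differentiation variable must appear in it.
    """
    fn = lambdify(args, deriv.expr, modules=[namespace])
    # variable_count lists (variable, order) pairs in differentiation order,
    # so the innermost jax.grad corresponds to the first variable.
    for var, count in deriv.variable_count:
        argnum = args.index(var)
        for _ in range(int(count)):
            fn = jax.grad(fn, argnums=argnum)
    return fn


x, y = symbols("x y")
g = x**3 * y + sin(x)

# Derivative(g, (x, 2)) becomes jax.grad(jax.grad(g_fn, 0), 0).
d2g_dx2 = derivative_to_jax(Derivative(g, (x, 2)), (x, y))
# Mixed partial Derivative(g, x, y) becomes jax.grad(jax.grad(g_fn, 0), 1).
d2g_dxdy = derivative_to_jax(Derivative(g, x, y), (x, y))

# SymPy's symbolic second derivative, compiled the same way, for comparison.
d2g_dx2_symbolic = lambdify(
    (x, y), Derivative(g, (x, 2)).doit(), modules=[JAX_NAMESPACE]
)
```

Inputs must be floats (e.g. d2g_dx2(1.0, 2.0)), since jax.grad rejects integer arguments by default.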
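For the Jacobian item, the concern that array elements need not share arguments can be sidestepped by assuming that every element of the Array is a function of one shared argument tuple. Under that assumption, a sketch: lambdify the elements as a list, stack them into a single vector-valued function F, and let jax.jacfwd and jax.jvp supply the Jacobian and the Jacobian-vector product. F, f_elems, and JAX_NAMESPACE are illustrative names only, and JAX and SymPy are assumed to be installed.

```python
import jax
import jax.numpy as jnp
from sympy import Array, lambdify, sin, symbols

# Hypothetical mapping from printed SymPy names to jax.numpy functions.
JAX_NAMESPACE = {"sin": jnp.sin}

x, y = symbols("x y")
args = (x, y)
# Assumption: every element of the Array depends only on the shared `args`.
exprs = Array([x * y, sin(x) + y**2])

# lambdify returns a Python list for a list input; stacking it gives a single
# function F: R^2 -> R^2 that JAX can differentiate as a whole.
f_elems = lambdify(args, list(exprs), modules=[JAX_NAMESPACE])


def F(v):
    return jnp.stack(f_elems(*v))


v0 = jnp.array([1.0, 2.0])
J = jax.jacfwd(F)(v0)                      # full 2x2 Jacobian at v0
t = jnp.array([1.0, 0.0])
F_v0, jvp_out = jax.jvp(F, (v0,), (t,))   # J @ t without forming J
```

jax.jvp computes J @ t without materializing J, which is cheaper when only a few directional derivatives are needed; jax.jacrev is usually preferred over jax.jacfwd when there are many more inputs than outputs.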
so-rose commented Apr 18, 2024
