Implement helper @as_jax_op to wrap JAX functions in PyTensor #537

Open
ricardoV94 opened this issue Dec 7, 2023 · 0 comments

Description

This blog post walks through the logic for 3 different examples (https://github.com/pymc-labs/communications/issues/21) and shows that the logic is always the same (sketched in code after the list):

  1. Wrap the jitted forward pass in an Op
  2. Wrap the jitted vjp (vector-Jacobian product; reverse-mode gradients need the vjp, not the jvp) in a GradOp to provide the gradient implementation
  3. Dispatch unjitted versions of the two Ops for integration with `function(..., mode="JAX")`
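
A minimal sketch of steps 1–3, assuming a toy JAX function `f`; the `Op`/`Apply`/`jax_funcify` machinery is PyTensor's, but `FOp`, `FVJPOp`, and the variable names are made up for illustration:

```python
import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.basic import Apply
from pytensor.graph.op import Op
from pytensor.link.jax.dispatch import jax_funcify


def f(x):
    return jnp.sin(x) ** 2


jitted_f = jax.jit(f)
# vjp of f at x, applied to the output cotangent gz
jitted_f_vjp = jax.jit(lambda x, gz: jax.vjp(f, x)[1](gz)[0])


class FOp(Op):
    """Step 1: jitted forward pass wrapped in an Op."""

    def make_node(self, x):
        x = pt.as_tensor_variable(x)
        # Output type assumed identical to the input type; in general
        # this is exactly the opt-in information discussed below (item 4)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = np.asarray(
            jitted_f(inputs[0]), dtype=node.outputs[0].dtype
        )

    def grad(self, inputs, output_gradients):
        # Step 2: delegate the gradient to the vjp Op
        return [FVJPOp()(inputs[0], output_gradients[0])]


class FVJPOp(Op):
    """Step 2: jitted vjp wrapped in a second Op."""

    def make_node(self, x, gz):
        x = pt.as_tensor_variable(x)
        gz = pt.as_tensor_variable(gz)
        return Apply(self, [x, gz], [x.type()])

    def perform(self, node, inputs, output_storage):
        output_storage[0][0] = np.asarray(
            jitted_f_vjp(*inputs), dtype=node.outputs[0].dtype
        )


# Step 3: register unjitted versions so that mode="JAX" can trace and
# jit the whole graph instead of calling back into the jitted wrappers
@jax_funcify.register(FOp)
def jax_funcify_FOp(op, **kwargs):
    return f


@jax_funcify.register(FVJPOp)
def jax_funcify_FVJPOp(op, **kwargs):
    return lambda x, gz: jax.vjp(f, x)[1](gz)[0]


x = pt.dvector("x")
g = pt.grad(FOp()(x).sum(), x)  # builds FVJPOp via FOp.grad
```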

Things that cannot be obtained automatically (or maybe they can?) and should therefore be opt-in, as in `@as_op` (a hypothetical interface is sketched after this list):
4. Input and output types
5. `infer_shape`
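
The existing `@as_op` helper (`pytensor.compile.ops.as_op`) already treats these as opt-in keyword arguments, so the proposed decorator could mirror its interface. A hypothetical sketch; `as_jax_op` does not exist yet, and the `itypes`/`otypes`/`infer_shape` names are assumptions borrowed from `as_op`:

```python
import jax.numpy as jnp
import pytensor.tensor as pt


@as_jax_op(  # hypothetical decorator proposed by this issue
    itypes=[pt.dvector],  # item 4: input types
    otypes=[pt.dvector],  # item 4: output types
    # item 5: user-supplied shape inference, assuming PyTensor's
    # Op.infer_shape(fgraph, node, input_shapes) signature
    infer_shape=lambda fgraph, node, input_shapes: [input_shapes[0]],
)
def f(x):
    return jnp.sin(x) ** 2
```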

ricardoV94 added the jax and feature request labels and removed the enhancement label on Dec 7, 2023