
Create vectorized, value_and_grad and shape versions of JAXOp #1645

@ricardoV94

Description

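When a JAXOp remains in the final graph compiled for a non-JAX backend, we may want to rewrite it for efficiency. For example, we could rewrite Blockwise(JAXOp) into a single JAXOp whose inner function is vectorized (e.g. with jax.vmap), instead of calling the wrapped function once per batch entry.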
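A minimal sketch of the idea in plain JAX, assuming the JAXOp wraps a function defined on a single "core" input. This is not PyTensor's JAXOp or rewrite API; `core_fn`, `looped`, and `vectorized` are hypothetical names used only to contrast a per-entry loop with a `jax.vmap`-vectorized inner function.

```python
import jax
import jax.numpy as jnp

# Hypothetical core function a JAXOp might wrap: takes one 1-D vector
# (the "core" input) and returns a scalar.
def core_fn(x):
    return jnp.sum(jnp.sin(x) ** 2)

# Roughly what Blockwise(JAXOp) amounts to: one call of the wrapped
# function per batch entry, driven by a Python-level loop.
def looped(xs):
    return jnp.stack([core_fn(x) for x in xs])

# What a rewritten JAXOp could do instead: vectorize the inner function
# over the leading batch axis and jit-compile it once.
vectorized = jax.jit(jax.vmap(core_fn))

xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 core inputs of length 3
print(looped(xs).shape, vectorized(xs).shape)  # (4,) (4,)
```

Both paths produce the same values; the vectorized version hands the whole batch to XLA in a single call.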

If the graph contains both the Op and its gradient, we could rewrite them into a single Op that uses jax.value_and_grad under the hood, so the forward computation is performed only once.
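A minimal sketch in plain JAX of why fusing helps, assuming the wrapped function returns a scalar. `loss`, `value_op`, `grad_op`, and `fused_op` are hypothetical names; how PyTensor would actually detect and replace the two Ops is not shown here.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar function a JAXOp might wrap.
def loss(params):
    return jnp.sum((params - 1.0) ** 2)

# Two separate Ops: one for the value, one for the gradient.
# jax.grad re-runs the forward pass internally, duplicating that work.
value_op = jax.jit(loss)
grad_op = jax.jit(jax.grad(loss))

# Fused Op: jax.value_and_grad returns (value, gradient) from one
# forward/backward pass, so the forward computation is shared.
fused_op = jax.jit(jax.value_and_grad(loss))

p = jnp.array([0.0, 2.0, 3.0])
value, grad = fused_op(p)
# value = 1 + 1 + 4 = 6.0; grad = 2 * (p - 1) = [-2., 2., 4.]
```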

Similarly, if only the output shape is needed, we could rewrite into an Op whose inner function computes just the shape (e.g. with jax.eval_shape). This last rewrite is only relevant if the original Op does not otherwise remain in the graph.
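A minimal sketch, assuming the wrapped function's output shape depends only on its input shapes and dtypes. `fn` is a hypothetical function; `jax.eval_shape` traces it abstractly and returns `ShapeDtypeStruct` results without doing the actual arithmetic.

```python
import jax
import jax.numpy as jnp

# Hypothetical function a JAXOp might wrap: outer product, then a matmul.
def fn(x, y):
    return jnp.outer(x, y) @ y

# Describe the inputs by shape and dtype only; no data is allocated.
x_spec = jax.ShapeDtypeStruct((5,), jnp.float32)
y_spec = jax.ShapeDtypeStruct((3,), jnp.float32)

# Abstract evaluation: yields the output shape/dtype without computing values.
out = jax.eval_shape(fn, x_spec, y_spec)
print(out.shape, out.dtype)  # (5,) float32
```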

Metadata

Assignees

No one assigned

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests