In [2]:
import jax

def fwd_and_bwd(f):
  """Split reverse-mode differentiation of ``f`` into two separately callable halves.

  Returns:
    A pair ``(fwd, bwd)`` where
      * ``fwd(*args)`` returns ``jax.vjp(f, *args)``, i.e. ``(primal_out, vjp_fn)``;
      * ``bwd(f_vjp, out_grad)`` applies the pullback ``f_vjp`` to the output
        cotangent ``out_grad`` and returns whatever the pullback returns.
  """
  def forward_half(*primals):
    return jax.vjp(f, *primals)

  def backward_half(f_vjp, out_grad):
    pullback = f_vjp
    return pullback(out_grad)

  return forward_half, backward_half

In [3]:
import jax.numpy as jnp

def foo(inp):
  """Return ``sin(a) @ cos(b)`` for a pair ``inp = (a, b)`` of matrices."""
  lhs, rhs = inp
  sin_lhs = jnp.sin(lhs)
  cos_rhs = jnp.cos(rhs)
  return sin_lhs @ cos_rhs

# Split reverse-mode AD of `foo` into a forward half (returns outputs + pullback)
# and a backward half (applies that pullback to an output cotangent).
fwd, bwd = fwd_and_bwd(foo)

In [4]:
# Two 2x2 inputs drawn with jax.random.normal from fixed PRNG keys (1 and 2),
# so re-running the cell reproduces the same values.
a = jax.random.normal(jax.random.key(1), shape=(2, 2))
b = jax.random.normal(jax.random.key(2), shape=(2, 2))

In [6]:
# `foo` takes a single tuple argument, so (a, b) is passed as one primal.
# NOTE(review): `partial` here is the VJP pullback returned by jax.vjp, not functools.partial.
out_primals, partial = fwd((a, b))

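The partially evaluated VJP function returned by `jax.vjp` is actually a pytree. Its leaves are the residual arrays saved by the forward pass. The jaxpr that is run backwards later lives in the tree's static metadata (the treedef). See discussion: https://github.com/jax-ml/jax/issues/26579#issuecomment-2670531713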

In [18]:
from jax.tree_util import tree_flatten

# Split the pullback pytree into its array leaves and its static treedef.
flat_residual, spec = tree_flatten(partial)

In [22]:
# The leaves of the flattened pullback are plain jax Arrays (the saved residuals),
# so they can be stored or converted like any other array.
flat_residual

[Array([[0.98809826, 0.9964145 ],
        [0.9907689 , 0.98800594]], dtype=float32),
 Array([[ 0.3528105 ,  0.9594295 ],
        [-0.6733448 ,  0.92576915]], dtype=float32),
 Array([[0.9356948 , 0.28194872],
        [0.7393286 , 0.3780893 ]], dtype=float32),
 Array([[-0.15382399,  0.08460601],
        [-0.13556181, -0.15441589]], dtype=float32)]

In [23]:
# The treedef holds the non-array part of the pullback as static metadata,
# including the backward jaxpr and the input/output tree structures.
spec

PyTreeDef(CustomNode(Partial[_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x7f1f82f305e0>, 'foo', [ShapedArray(float32[2,2])], (PyTreeDef(*), PyTreeDef(((*, *),)))))], [(CustomNode(Partial[_HashableCallableShim(functools.partial(<function vjp.<locals>.unbound_vjp at 0x7f1f8290add0>, [(ShapedArray(float32[2,2]), None)], { lambda a:f32[2,2] b:f32[2,2] c:f32[2,2] d:f32[2,2]; e:f32[2,2] f:f32[2,2]. let
    g:f32[2,2] = pjit[
      name=sin
      jaxpr={ lambda ; e:f32[2,2] a:f32[2,2]. let g:f32[2,2] = mul e a in (g,) }
    ] e a
    h:f32[2,2] = pjit[
      name=cos
      jaxpr={ lambda ; f:f32[2,2] b:f32[2,2]. let
          i:f

In [29]:
# Inspired by https://github.com/lucidrains/jax2torch/blob/main/jax2torch/jax2torch.py, which is
# inspired by https://gist.github.com/mattjj/e8b51074fed081d765d2f3ff90edf0e9

import torch
from torch.utils import dlpack as torch_dlpack

import jax
from jax import dlpack as jax_dlpack
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_flatten, tree_unflatten

from inspect import signature
from functools import wraps


def j2t(x_jax):
  """Convert a JAX array to a torch tensor via DLPack (zero-copy where possible).

  The array is handed directly to ``torch.utils.dlpack.from_dlpack``, which
  consumes any object implementing the ``__dlpack__`` protocol. This avoids
  ``jax.dlpack.to_dlpack``, which is deprecated in recent JAX releases.

  Args:
    x_jax: A JAX array.

  Returns:
    A torch.Tensor viewing the same data.
  """
  if hasattr(x_jax, "__dlpack__"):
    return torch_dlpack.from_dlpack(x_jax)
  # Fallback for old JAX versions whose arrays lack the DLPack protocol.
  return torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x_jax))


def t2j(x_torch):
  """Convert a torch tensor to a JAX array via DLPack (zero-copy where possible).

  The tensor is detached and made contiguous first:

    * ``detach()`` is needed because ``Tensor.__dlpack__`` refuses to export
      tensors with ``requires_grad=True``. That is the case for inputs seen
      inside a ``torch.autograd.Function.forward``. Detaching shares storage,
      so no copy is made.
    * ``contiguous()`` works around https://github.com/google/jax/issues/8082.

  The tensor itself is then passed to ``jax.dlpack.from_dlpack`` through the
  ``__dlpack__`` protocol, because passing a DLPack capsule is deprecated in
  recent JAX releases. The capsule path is kept only as a fallback for older
  JAX versions.

  Args:
    x_torch: A torch.Tensor.

  Returns:
    A JAX array viewing (or, if a contiguous copy was needed, copying) the data.
  """
  x_torch = x_torch.detach().contiguous()  # https://github.com/google/jax/issues/8082
  try:
    return jax_dlpack.from_dlpack(x_torch)
  except TypeError:
    # NOTE(review): presumably older JAX rejects non-capsule inputs with
    # TypeError; fall back to an explicit DLPack capsule there.
    return jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))


def tree_t2j(x_torch):
  """Apply ``t2j`` to every torch.Tensor leaf of a pytree; other leaves pass through unchanged."""
  def convert_leaf(leaf):
    if not isinstance(leaf, torch.Tensor):
      return leaf
    return t2j(leaf)

  return tree_map(convert_leaf, x_torch)


def tree_j2t(x_jax):
  """Apply ``j2t`` to every JAX-array leaf of a pytree; other leaves pass through unchanged."""
  def convert_leaf(leaf):
    is_jax_array = isinstance(leaf, jnp.ndarray)
    return j2t(leaf) if is_jax_array else leaf

  return tree_map(convert_leaf, x_jax)


def jax2torch(fn):
  """Wrap a JAX function so it accepts torch tensors and supports torch autograd.

  Calling the returned function binds the arguments against ``fn``'s signature
  (applying defaults) and converts torch.Tensor leaves to JAX arrays. The
  forward pass runs ``jax.vjp(fn, ...)``. The resulting pullback is split with
  ``tree_flatten``:

    * its residual arrays are saved with ``ctx.save_for_backward``;
    * its static treedef is kept on ``ctx``.

  The backward pass rebuilds the pullback with ``tree_unflatten`` and applies
  it to the incoming gradients.

  Positional (and positional-or-keyword) arguments are differentiable. A
  gradient is returned for each one whose JAX cotangent is a JAX array, and
  ``None`` otherwise. Keyword-only and ``**kwargs`` arguments are passed through
  to ``fn`` but treated as non-differentiable.

  Args:
    fn: A JAX-traceable callable.

  Returns:
    A callable with the same call signature as ``fn`` that operates on torch tensors.
  """

  # Defined once per wrapped ``fn`` instead of once per call.
  class JaxFun(torch.autograd.Function):

    @staticmethod
    def forward(ctx, kwargs, *args):
      args = tree_t2j(args)
      kwargs = tree_t2j(kwargs)
      # Close over keyword-only / **kwargs values so jax.vjp differentiates only
      # with respect to the positional arguments.
      target = (lambda *a: fn(*a, **kwargs)) if kwargs else fn
      y_, fun_vjp = jax.vjp(target, *args)
      residuals, ctx.vjp_spec = tree_flatten(fun_vjp)
      ctx.save_for_backward(*map(j2t, residuals))
      return tree_j2t(y_)

    @staticmethod
    def backward(ctx, *grad_args):
      fun_vjp = tree_unflatten(ctx.vjp_spec, [t2j(t) for t in ctx.saved_tensors])
      # The pullback expects a cotangent structured like fn's output: a tuple when
      # forward returned several tensors, a single array otherwise.
      grad_args = tree_t2j(grad_args) if len(grad_args) > 1 else t2j(grad_args[0])
      grads = fun_vjp(grad_args)
      # Non-array cotangents (e.g. the float0 cotangents JAX returns for integer
      # inputs) have no torch equivalent and are reported as None.
      grads = tuple(t if isinstance(t, jnp.ndarray) else None for t in grads)
      # Leading None is the gradient slot for forward()'s ``kwargs`` input.
      return (None, *tree_j2t(grads))

  @wraps(fn)
  def inner(*args, **kwargs):
    bound = signature(fn).bind(*args, **kwargs)
    bound.apply_defaults()
    # ``bound.args`` expands *args correctly. ``bound.kwargs`` holds keyword-only
    # and **kwargs values, which are forwarded as a single dict because
    # Function.apply takes positional arguments only.
    return JaxFun.apply(bound.kwargs, *bound.args)

  return inner

Demo: call a jitted JAX function on torch tensors, then differentiate it with torch autograd.

In [None]:
import jax
import torch

# JAX function (jitted; `y` is traced rather than static)

@jax.jit
def jax_pow(x, y=2):
  """Return ``x ** y`` elementwise."""
  return x ** y

# Convert to a torch function

torch_pow = jax2torch(jax_pow)

# Run it on torch data!

x = torch.tensor([1., 2., 3.])
y = torch_pow(x, y=3)
print(y)  # tensor([ 1.,  8., 27.])
# .cpu() keeps the check device-agnostic in case JAX runs on an accelerator.
torch.testing.assert_close(y.cpu(), torch.tensor([1., 8., 27.]))

# And differentiate: d/dx sum(x ** 3) = 3 * x ** 2

x = torch.tensor([2., 3.], requires_grad=True)
y = torch.sum(torch_pow(x, y=3))
y.backward()
print(x.grad)  # tensor([12., 27.])
torch.testing.assert_close(x.grad.cpu(), torch.tensor([12., 27.]))

tensor([ 1.,  8., 27.])
tensor([12., 27.])


  x_jax = jax_dlpack.from_dlpack(torch_dlpack.to_dlpack(x_torch))
