In [1]:

from pprint import pprint
from tesseract_core import Tesseract
from tesseract_jax import apply_tesseract
import jax.numpy as jnp
import jax

vectoradd = Tesseract.from_tesseract_api("examples/simple/partial/tesseract_api.py")


input_dict = {"a": jnp.array([1.0, 2.0, 3.0], dtype="float32")}

outputs = jax.jit(apply_tesseract)(vectoradd, inputs=input_dict)
# outputs = apply_tesseract(vectoradd, inputs=input_dict)
pprint(outputs)



{'b': Array([2., 4., 6.], dtype=float32),
 'c': Array([1., 1., 1.], dtype=float32)}


In [2]:
grad = vectoradd.vector_jacobian_product(
    inputs={
        "a": jnp.array([1.0, 2.0, 3.0], dtype="float32"),
    },
    vjp_inputs=["a"],
    vjp_outputs=["b"],
    cotangent_vector={"b": jnp.ones((3,), dtype="float32")},
)["a"]

print("Gradient shape:", grad.shape)

Gradient shape: (3,)


In [21]:
from jax.extend import core
import numpy as np
import jax

multiply_add_p = core.Primitive("multiply_add")  # Create the primitive

def multiply_add_prim(x, y, z):
  """The JAX-traceable way to use the JAX primitive.

  Note that the traced arguments must be passed as positional arguments
  to `bind`.
  """
  return multiply_add_p.bind(x, y, z)

def square_add_prim(a, b):
  """A square-add function implemented using the new JAX-primitive."""
  return multiply_add_prim(a, a, b)

def multiply_add_impl(x, y, z):
  """Concrete implementation of the primitive.

  This function does not need to be JAX traceable.

  Args:
    x, y, z: The concrete arguments of the primitive. Will only be called with
      concrete values.

  Returns:
    the concrete result of the primitive.
  """
  # Note: you can use the ordinary (non-JAX) NumPy, which is not JAX-traceable.

  a = {
      "x": x
  }

  b, treedef = jax.tree.flatten(a)
  a = jax.tree.unflatten(treedef, b)

  return np.add(np.multiply(a["x"], y), z)

# Now, register the primal implementation with JAX:
multiply_add_p.def_impl(multiply_add_impl)

def multiply_add_abstract_eval(xs, ys, zs):
  """Abstract evaluation of the primitive.

  This function does not need to be JAX traceable. It will be invoked with
  abstractions of the actual arguments

  Args:
    xs, ys, zs: Abstractions of the arguments.

  Result:
    a ShapedArray for the result of the primitive.
  """
  assert xs.shape == ys.shape
  assert xs.shape == zs.shape

  raise NotImplementedError("Abstract eval not implemented yet")
  return core.ShapedArray(xs.shape, xs.dtype)

# Now, register the abstract evaluation with JAX:
multiply_add_p.def_abstract_eval(multiply_add_abstract_eval)

from jax._src.lib.mlir.dialects import hlo
def multiply_add_lowering(ctx, xc, yc, zc):
  """The compilation to XLA of the primitive.

  Given an mlir.ir.Value for each argument, return the mlir.ir.Values for
  the results of the function.

  Does not need to be a JAX-traceable function.
  """
  return [hlo.AddOp(hlo.MulOp(xc, yc), zc).result]

# Now, register the lowering rule with JAX.
# For GPU, refer to the https://docs.jax.dev/en/latest/Custom_Operation_for_GPUs.html
from jax.interpreters import mlir

mlir.register_lowering(multiply_add_p, multiply_add_lowering, platform='cpu')

In [27]:
assert square_add_prim(2., 10.) == 14.

In [25]:

import jax 
jax.jit(square_add_prim)(2., 10.)


NotImplementedError: Abstract eval not implemented yet