# Basic example

<div class="admonition note alert alert-warning">
<p class="admonition-title">Note</p>

All examples are expected to run from the `examples/<example_name>` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).
</div>

Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable and differentiable.

We start by building the `vectoradd_jax` Tesseract image. We then import `tesseract_core.Tesseract` and instantiate it from that image.

```bash
# Build the vectoradd_jax Tesseract so we can use it below
tesseract build vectoradd_jax/
```

```
[i] Building image ...
[i] Built image sha256:c7902d7912d7, ['vectoradd_jax:latest']

["vectoradd_jax:latest"]
```


```python
from tesseract_core import Tesseract

vectoradd = Tesseract.from_image("vectoradd_jax")

# Start serving the Tesseract so it stays alive across cells;
# otherwise every call would need to run inside a `with vectoradd:` block
vectoradd.serve()
```

We can inspect the operations a specific Tesseract implements like this:

```python
vectoradd.available_endpoints
```

```
['apply',
 'jacobian',
 'jacobian_vector_product',
 'vector_jacobian_product',
 'health',
 'input_schema',
 'output_schema',
 'abstract_eval']
```
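The endpoint list hints at which JAX transformations a Tesseract can take part in. The sketch below encodes one plausible mapping from transformations to endpoints. The dictionary `REQUIRED_ENDPOINTS` and the helper `supported_transforms` are hypothetical, and the mapping itself is our assumption, not a documented `tesseract_jax` rule.

```python
# Hypothetical mapping from JAX transformations to the Tesseract endpoints
# they are ASSUMED to need (not a documented tesseract_jax rule).
REQUIRED_ENDPOINTS = {
    "jax.jit": {"apply", "abstract_eval"},
    "jax.jvp": {"apply", "jacobian_vector_product"},
    "jax.vjp / jax.grad": {"apply", "vector_jacobian_product"},
}


def supported_transforms(endpoints):
    """Return the transformations whose assumed endpoints are all available."""
    available = set(endpoints)
    return [name for name, needed in REQUIRED_ENDPOINTS.items() if needed <= available]


# The endpoint list printed above
endpoints = [
    "apply", "jacobian", "jacobian_vector_product", "vector_jacobian_product",
    "health", "input_schema", "output_schema", "abstract_eval",
]
print(supported_transforms(endpoints))
# ['jax.jit', 'jax.jvp', 'jax.vjp / jax.grad']
```

Under this assumption, a Tesseract that exposes only `apply` could still be called directly, but not traced or differentiated.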

The main entry point to `tesseract_jax` is the function `apply_tesseract`. Let's try it out on some inputs:

```python
import numpy as np

from tesseract_jax import apply_tesseract

a = {"v": np.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": np.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": np.array(2.0, dtype="float32"),
}

apply_tesseract(vectoradd, inputs={"a": a, "b": b})
```

```
{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),
  'result': Array([ 9., 12., 15.], dtype=float32)},
 'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),
  'result': Array([-7., -8., -9.], dtype=float32)}}
```

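It checks out, as $(1,2,3)+2\cdot(4,5,6)=(9,12,15)$. Likewise, `vector_min` gives $(1,2,3)-2\cdot(4,5,6)=(-7,-8,-9)$, and each `normed_result` is the corresponding `result` divided by its Euclidean norm.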
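To make the arithmetic explicit, here is a minimal NumPy sketch that reproduces the printed outputs. The function `vector_add_reference` is hypothetical; its formulas, including the assumption that `a` is scaled by 1 because it has no `s` entry, are inferred from the outputs above rather than taken from the Tesseract's source code.

```python
import numpy as np


def vector_add_reference(a_v, b_v, b_s, a_s=1.0):
    """Hypothetical NumPy stand-in for vectoradd_jax, inferred from its outputs."""
    add = a_s * a_v + b_s * b_v  # scaled sum
    sub = a_s * a_v - b_s * b_v  # scaled difference
    return {
        "vector_add": {"result": add, "normed_result": add / np.linalg.norm(add)},
        "vector_min": {"result": sub, "normed_result": sub / np.linalg.norm(sub)},
    }


out = vector_add_reference(
    np.array([1.0, 2.0, 3.0], dtype="float32"),
    np.array([4.0, 5.0, 6.0], dtype="float32"),
    np.float32(2.0),
)
print(out["vector_add"]["result"])  # [ 9. 12. 15.]
```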

We can also compose Tesseracts as if they were local functions, and mix them freely with regular array operations:

```python
import jax


def fancy_operation(a: dict, b: dict) -> jax.Array:
    """Chain several Tesseract calls with a local operation; return a scalar."""
    result = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
    result = apply_tesseract(
        vectoradd, inputs={"a": {"v": result["vector_add"]["result"]}, "b": b}
    )
    # We can mix and match Tesseract calls with local operations
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    result = apply_tesseract(vectoradd, inputs={"a": {"v": result}, "b": b})
    # Return the second component of the final sum as a scalar
    return result["vector_add"]["result"][1]


fancy_operation(a, b)
```

```
Array(16.135319, dtype=float32)
```

All of this is also compatible with `jax.jit`:

```python
import jax

jitted_op = jax.jit(fancy_operation)
jitted_op(a, b)
```

```
Array(16.135319, dtype=float32)
```
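Under `jax.jit`, JAX first traces the Python function with abstract values (shapes and dtypes only), then reuses the compiled result for later calls whose inputs have matching shapes and dtypes. The sketch below shows this on a plain JAX function, not a Tesseract: the hypothetical `scaled_sum` increments a Python counter, which only happens during tracing. Our reading, which is an assumption, is that this tracing step is where the `abstract_eval` endpoint listed earlier comes in, letting `tesseract_jax` report output shapes without running the Tesseract.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # counts how often the Python body actually runs


def scaled_sum(x, s):
    """Hypothetical example function; the counter is a Python side effect."""
    global trace_count
    trace_count += 1  # executes only while JAX traces the function
    return jnp.sum(x + s * x)


jitted = jax.jit(scaled_sum)
x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
first = jitted(x, 2.0)   # traces and compiles, then runs
second = jitted(x, 3.0)  # same shapes/dtypes: reuses the compiled function
print(trace_count)  # 1
```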

It also supports automatic differentiation, with or without `jax.jit`. For instance, here is how to compute a Jacobian-vector product (JVP):

```python
jax.jvp(fancy_operation, (a, b), (a, b))
```

```
(Array(16.135319, dtype=float32), Array(25.004124, dtype=float32))
```

Here, the first element of the returned tuple is the primal value, and the second is the Jacobian of `fancy_operation` evaluated at $(a,b)$, multiplied by the tangent vector $(a,b)$. Similarly, to compute VJPs we can do the following:

```python
primal, vjp = jax.vjp(fancy_operation, a, b)
vjp(primal)
```

```
({'v': Array([-0.20733577,  0.56435245, -0.329298  ], dtype=float32)},
 {'s': Array(80.709854, dtype=float32),
  'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})
```

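Each element of the returned tuple corresponds to one argument, `a` or `b`, and has the same dictionary structure as that argument. Note that we used the primal output itself as the cotangent vector.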
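To make the JVP and VJP semantics concrete without a running Tesseract, the sketch below applies the same calls to a hypothetical pure-JAX function `local_op` (not part of Tesseract-JAX) and checks them against the full Jacobian from `jax.jacfwd`. The JVP applies the Jacobian to the tangent; the VJP applies the cotangent to the Jacobian from the left, which for a scalar output is just the cotangent times the Jacobian.

```python
import jax
import jax.numpy as jnp


def local_op(x, y):
    """Hypothetical scalar-valued stand-in for fancy_operation."""
    return jnp.sum(jnp.sin(x) * y)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# Jacobian w.r.t. each argument; each has shape (3,) because the output is a scalar
jac_x, jac_y = jax.jacfwd(local_op, argnums=(0, 1))(x, y)

# JVP with tangent (x, y): equals jac_x @ x + jac_y @ y
primal, tangent_out = jax.jvp(local_op, (x, y), (x, y))

# VJP with cotangent u = primal: equals (u * jac_x, u * jac_y)
_, vjp_fn = jax.vjp(local_op, x, y)
ct_x, ct_y = vjp_fn(primal)
```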

Finally, to compute the gradient of `fancy_operation` with respect to its first argument `a` at the point $(a,b)$, we can simply do:

```python
jax.grad(fancy_operation)(a, b)
```

```
{'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)}
```
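For a scalar-valued function, `jax.grad` is the VJP seeded with a cotangent of 1. This explains why the `a` entry of the VJP computed earlier is this gradient scaled by the primal value, e.g. $-0.01284981 \times 16.135319 \approx -0.20734$. The sketch below checks the relationship on a hypothetical pure-JAX stand-in `local_op`, and also shows `argnums` for differentiating with respect to both arguments at once.

```python
import jax
import jax.numpy as jnp


def local_op(x, y):
    """Hypothetical scalar-valued stand-in for fancy_operation."""
    return jnp.sum(jnp.sin(x) * y)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

# jax.grad differentiates w.r.t. the first argument by default (argnums=0)
grad_x = jax.grad(local_op)(x, y)

# The same gradient obtained as a VJP seeded with cotangent 1.0
primal, vjp_fn = jax.vjp(local_op, x, y)
vjp_x, _ = vjp_fn(jnp.ones_like(primal))

# Seeding with the primal instead scales the gradient by the primal value
scaled_x, _ = vjp_fn(primal)

# Gradients w.r.t. both arguments at once
grad_x2, grad_y = jax.grad(local_op, argnums=(0, 1))(x, y)
```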

All of the above also works in conjunction with `jax.jit`:

```python
# jax.jit inside differentiation
jax.jvp(jitted_op, (a, b), (a, b))

primal, vjp = jax.vjp(jitted_op, a, b)
vjp(primal)

jax.grad(jitted_op)(a, b)

# jax.jit can also wrap the differentiated function
jax.jit(jax.grad(jitted_op))(a, b)
```

```
{'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)}
```
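The order in which `jax.jit` and the differentiation transforms are composed changes what gets compiled, not the numerical result. A minimal sketch on a hypothetical pure-JAX stand-in `local_op`:

```python
import jax
import jax.numpy as jnp


def local_op(x, y):
    """Hypothetical scalar-valued stand-in for fancy_operation."""
    return jnp.sum(jnp.sin(x) * y)


x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0, 6.0])

g_plain = jax.grad(local_op)(x, y)           # no compilation
g_inner = jax.grad(jax.jit(local_op))(x, y)  # jit inside differentiation
g_outer = jax.jit(jax.grad(local_op))(x, y)  # jit wrapping the gradient
```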

```python
# Stop the Tesseract server started by vectoradd.serve()
vectoradd.teardown()
```