# Tesseract-JAX basic example: vector addition

<div class="admonition note alert alert-warning">
<p class="admonition-title">Note</p>

All examples are expected to run from the `examples/<example_name>` directory of the [Tesseract-JAX repository](https://github.com/pasteurlabs/tesseract-jax).
</div>

Tesseract-JAX is a lightweight extension to Tesseract Core that makes Tesseracts look and feel like regular JAX primitives, and makes them jittable and differentiable.

In this demo, you will learn how to:
1. Build a Tesseract
1. Access its endpoints via Tesseract-JAX's `apply_tesseract()` entrypoint
1. Compose Tesseracts into more complex functions, blending multiple Tesseract applications with local operations
1. Use JAX with the resulting function composition to JIT-compile and/or automatically differentiate it (via JVP, VJP, and explicit derivatives)

## Build the Tesseract

You can build the Tesseract either from the command line or by running the cell below (skip this step if the image is already built).

In [1]:
%%bash
# Build vectoradd_jax Tesseract so we can use it below
tesseract build vectoradd_jax/

[i] Building image ...
[i] Built image sha256:c7902d7912d7, ['vectoradd_jax:latest']


["vectoradd_jax:latest"]


## Run the Tesseract

The main entrypoint to `tesseract_jax` is `apply_tesseract()`.

Using the `vectoradd_jax` Tesseract image we built earlier, let's add two vectors together, with the second one scaled by a factor of 2.
The result should be:

$$\begin{pmatrix} 1 \\ 2 \\ 3 \end{pmatrix} + 2 \cdot \begin{pmatrix} 4 \\ 5 \\ 6 \end{pmatrix} = \begin{pmatrix} 9 \\ 12 \\ 15 \end{pmatrix}$$
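As a quick local sanity check (plain NumPy, no Tesseract involved), the expected result is $\mathbf{v}_a + s\,\mathbf{v}_b$ with $s = 2$:

```python
import numpy as np

# The two input vectors; the second one is scaled by s = 2
v_a = np.array([1.0, 2.0, 3.0], dtype="float32")
v_b = np.array([4.0, 5.0, 6.0], dtype="float32")
s_b = np.float32(2.0)

expected = v_a + s_b * v_b
print(expected)  # [ 9. 12. 15.]
```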

We start by importing `tesseract_core.Tesseract` and creating an instance from the `vectoradd_jax` image we built.

Tesseract Core offers a `.serve()` method that keeps the Tesseract running, as an alternative to repeatedly opening `with` blocks to access its endpoints.

In [2]:
from tesseract_core import Tesseract

vectoradd = Tesseract.from_image("vectoradd_jax")
vectoradd.serve()

The set of endpoints varies from one Tesseract to another. We can inspect which ones this Tesseract provides via the `.available_endpoints` attribute.

In [3]:
vectoradd.available_endpoints

['apply',
 'jacobian',
 'jacobian_vector_product',
 'vector_jacobian_product',
 'health',
 'input_schema',
 'output_schema',
 'abstract_eval']

To perform our vector addition, we make use of the `apply_tesseract()` function mentioned earlier.

The API of `vectoradd_jax` has an `InputSchema` with two parameters, `a` and `b`.

`a` and `b` share the same schema: a vector `v` and a scalar scale factor `s`, which is optional and defaults to `s = 1`.

This can be passed to `apply_tesseract()` with a dict of dicts:

In [5]:
from pprint import pprint

import numpy as np

from tesseract_jax import apply_tesseract

a = {"v": np.array([1.0, 2.0, 3.0], dtype="float32")}
b = {
    "v": np.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": np.array(2.0, dtype="float32"),
}

outputs = apply_tesseract(vectoradd, inputs={"a": a, "b": b})

pprint(outputs)

{'vector_add': {'normed_result': Array([0.42426407, 0.56568545, 0.70710677], dtype=float32),
                'result': Array([ 9., 12., 15.], dtype=float32)},
 'vector_min': {'normed_result': Array([-0.5025707 , -0.5743665 , -0.64616233], dtype=float32),
                'result': Array([-7., -8., -9.], dtype=float32)}}


As we can see, `outputs['vector_add']['result']` gives the expected $(9, 12, 15)$.
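The printed values are consistent with `vector_add` computing $s_a\mathbf{v}_a + s_b\mathbf{v}_b$ and `vector_min` computing $s_a\mathbf{v}_a - s_b\mathbf{v}_b$ (here $\mathbf{a} - 2\mathbf{b}$), each returned together with its unit-norm version. The sketch below is a hypothetical pure-JAX stand-in (`reference_vectoradd` is not part of Tesseract-JAX), written only from those observed outputs; the Tesseract's actual implementation may differ.

```python
import jax.numpy as jnp
import numpy as np


def reference_vectoradd(a: dict, b: dict) -> dict:
    """Hypothetical pure-JAX stand-in for vectoradd_jax, inferred from its outputs."""

    def scaled(x: dict) -> jnp.ndarray:
        # Scale the vector v by s, which defaults to 1 when omitted
        return jnp.asarray(x.get("s", 1.0)) * jnp.asarray(x["v"])

    def with_norm(x: jnp.ndarray) -> dict:
        # Raw result plus its unit-length (L2-normalized) version
        return {"result": x, "normed_result": x / jnp.linalg.norm(x)}

    return {
        "vector_add": with_norm(scaled(a) + scaled(b)),
        "vector_min": with_norm(scaled(a) - scaled(b)),
    }


a_demo = {"v": np.array([1.0, 2.0, 3.0], dtype="float32")}
b_demo = {
    "v": np.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": np.array(2.0, dtype="float32"),
}
ref_outputs = reference_vectoradd(a_demo, b_demo)
print(ref_outputs["vector_add"]["result"])  # [ 9. 12. 15.]
print(ref_outputs["vector_min"]["result"])  # [-7. -8. -9.]
```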

## Function composition via Tesseracts

Here you'll learn how Tesseract-JAX enables you to compose chains of Tesseract evaluations, blended with local operations, while retaining the benefits of JAX.

The function below applies `vectoradd` twice, i.e. it computes $(\mathbf{a} + s\,\mathbf{b}) + s\,\mathbf{b}$ (where $s$ is the scale factor stored in `b`), then performs local arithmetic on the output, applies `vectoradd` once more, and finally returns a single element of the result.

As you will see, we keep this fine-grained control over our Tesseract evaluations without sacrificing JAX's JIT compiler or automatic differentiation.

In [6]:
def fancy_operation(a: dict, b: dict):
    """Chain three vectoradd calls with a local operation in between."""
    # First Tesseract call: a + s*b
    result = apply_tesseract(vectoradd, inputs={"a": a, "b": b})
    # Feed the previous result back in as the new `a`: (a + s*b) + s*b
    result = apply_tesseract(
        vectoradd, inputs={"a": {"v": result["vector_add"]["result"]}, "b": b}
    )
    # We can mix and match with local (non-Tesseract) operations
    result = 2.0 * result["vector_add"]["normed_result"] + b["v"]
    result = apply_tesseract(vectoradd, inputs={"a": {"v": result}, "b": b})
    # Return a single (scalar) element of the final result
    return result["vector_add"]["result"][1]


fancy_operation(a, b)

Array(16.135319, dtype=float32)
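To make the composed computation explicit, here is a hypothetical pure-JAX mirror of `fancy_operation` that replaces each Tesseract call with a local stand-in. It assumes, as the outputs printed earlier suggest, that `vector_add` returns $s_a\mathbf{v}_a + s_b\mathbf{v}_b$ and its normalized version; `_vector_add` and `reference_fancy_operation` are illustrative names, not Tesseract-JAX API. Under that assumption, its value and its gradient with respect to `a` should agree with the Tesseract-based results in this notebook up to floating-point rounding.

```python
import jax
import jax.numpy as jnp
import numpy as np


def _scaled(x: dict) -> jax.Array:
    # Scale the vector v by s, which defaults to 1 when omitted
    return jnp.asarray(x.get("s", 1.0)) * jnp.asarray(x["v"])


def _vector_add(a: dict, b: dict) -> dict:
    """Hypothetical stand-in for the `vector_add` output of vectoradd_jax."""
    x = _scaled(a) + _scaled(b)
    return {"result": x, "normed_result": x / jnp.linalg.norm(x)}


def reference_fancy_operation(a: dict, b: dict) -> jax.Array:
    """Same chain of operations as fancy_operation, without any Tesseract."""
    result = _vector_add(a, b)
    result = _vector_add({"v": result["result"]}, b)
    result = 2.0 * result["normed_result"] + b["v"]
    result = _vector_add({"v": result}, b)
    return result["result"][1]


a_demo = {"v": np.array([1.0, 2.0, 3.0], dtype="float32")}
b_demo = {
    "v": np.array([4.0, 5.0, 6.0], dtype="float32"),
    "s": np.array(2.0, dtype="float32"),
}

value = reference_fancy_operation(a_demo, b_demo)
ref_grad_a = jax.grad(reference_fancy_operation)(a_demo, b_demo)
print(value, ref_grad_a)
```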

### This is compatible with `jax.jit()`

In [8]:
import jax

jitted_op = jax.jit(fancy_operation)
jitted_op(a, b)

Array(16.135319, dtype=float32)
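`jax.jit` traces a Python function once for each combination of input tree structure, shapes, and dtypes, then reuses the compiled computation on later calls. The pure-JAX sketch below (no Tesseract involved; `scaled_sum` is a hypothetical example function) uses a Python side-effect counter to make the tracing visible.

```python
import jax
import jax.numpy as jnp

trace_count = 0


def scaled_sum(x: dict) -> jax.Array:
    global trace_count
    trace_count += 1  # Python side effect: only runs while JAX traces the function
    return jnp.sum(x["s"] * x["v"])


jitted_sum = jax.jit(scaled_sum)
x_demo = {"v": jnp.arange(3.0), "s": jnp.float32(2.0)}

y = jitted_sum(x_demo)  # first call: trace and compile
y = jitted_sum(x_demo)  # same structure, shapes and dtypes: cached, no retrace
print(y, trace_count)
```

Side effects such as the counter are not part of the compiled computation, which is why functions meant for `jax.jit` should avoid relying on them.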

### We can use automatic differentiation

This is possible with or without using JIT.

#### Computing JVP

In [10]:
primal, jvp = jax.jvp(fancy_operation, (a, b), (a, b))
print(f"{primal=}, {jvp=}")

primal=Array(16.135319, dtype=float32), jvp=Array(25.004124, dtype=float32)


Here `jvp` is the Jacobian of `fancy_operation` evaluated at $(a, b)$, multiplied by the tangent vector $(a, b)$; we simply reuse the inputs as tangents for convenience.
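To illustrate what `jax.jvp` returns, this minimal pure-JAX sketch uses a hypothetical scalar function `f` in place of `fancy_operation` and checks that the JVP equals the explicit Jacobian (here a gradient) contracted with the tangent vector.

```python
import jax
import jax.numpy as jnp


def f(x: jax.Array) -> jax.Array:
    # Hypothetical scalar-valued stand-in for fancy_operation
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([1.0, 2.0, 3.0])
t = x  # reuse the input as the tangent, as done above with (a, b)

primal_out, tangent_out = jax.jvp(f, (x,), (t,))

# Same directional derivative via the explicit Jacobian (a gradient, since f is scalar)
explicit = jnp.dot(jax.jacfwd(f)(x), t)
print(jnp.allclose(tangent_out, explicit))  # True
```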

#### Computing VJP

In [11]:
primal, vjp = jax.vjp(fancy_operation, a, b)
pprint(vjp(primal))

({'v': Array([-0.20733577,  0.56435245, -0.329298  ], dtype=float32)},
 {'s': Array(80.709854, dtype=float32),
  'v': Array([-0.8293431, 50.663364 , -1.317192 ], dtype=float32)})


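Here `vjp(primal)` pulls back the cotangent `primal` (the output value itself, reused for convenience); each element of the returned tuple corresponds to the argument `a` or `b`, respectively.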
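Because `fancy_operation` returns a scalar, pulling back the cotangent `primal` simply scales the gradient by `primal`. For example, $16.135319 \times (-0.01284981) \approx -0.20734$, which matches the first entry of the `a` result above when compared with the gradient with respect to `a` reported further below. A minimal pure-JAX sketch of this identity, using a hypothetical function `f`:

```python
import jax
import jax.numpy as jnp


def f(x: jax.Array) -> jax.Array:
    # Hypothetical scalar-valued stand-in for fancy_operation
    return jnp.sum(jnp.sin(x) * x)


x = jnp.array([1.0, 2.0, 3.0])
primal_out, vjp_fn = jax.vjp(f, x)

# vjp_fn returns one cotangent per primal argument, hence the 1-tuple
(pulled_back,) = vjp_fn(primal_out)  # cotangent = the output value itself
grad_x = jax.grad(f)(x)
print(jnp.allclose(pulled_back, primal_out * grad_x))  # True
```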

#### Computing the gradient

Let's compute the gradient of `fancy_operation` with respect to the argument `a` at the point $(a, b)$. Since `a` is the first argument, we pass `argnums=0` to `jax.grad()` (this is also the default).

In [12]:
jax.grad(fancy_operation, argnums=0)(a, b)

{'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)}

Alternatively, as in the VJP calculation, we can compute the gradients with respect to both `a` and `b` at once by passing a list of argument indices:

In [14]:
jax.grad(fancy_operation, argnums=[0, 1])(a, b)

({'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)},
 {'s': Array(5.002062, dtype=float32),
  'v': Array([-0.05139923,  3.139905  , -0.08163408], dtype=float32)})
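Note that the gradients mirror the dict structure of the inputs: JAX treats nested dicts as pytrees, so each gradient has exactly the same keys as the argument it belongs to. A minimal pure-JAX sketch with a hypothetical function `g` that takes inputs shaped like `a` and `b`:

```python
import jax
import jax.numpy as jnp


def g(a: dict, b: dict) -> jax.Array:
    # Hypothetical stand-in with the same dict-of-arrays input structure
    return jnp.sum(a["v"] * b.get("s", 1.0) * b["v"])


a_demo = {"v": jnp.array([1.0, 2.0, 3.0])}
b_demo = {"v": jnp.array([4.0, 5.0, 6.0]), "s": jnp.array(2.0)}

grad_a, grad_b = jax.grad(g, argnums=(0, 1))(a_demo, b_demo)
print(grad_a)  # same keys as a_demo: {'v': ...}
print(grad_b)  # same keys as b_demo: {'s': ..., 'v': ...}
```

Analytically, $\partial g/\partial \mathbf{v}_a = s\,\mathbf{v}_b$, $\partial g/\partial \mathbf{v}_b = s\,\mathbf{v}_a$ and $\partial g/\partial s = \mathbf{v}_a \cdot \mathbf{v}_b$, which is what the arrays inside those dicts contain.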

### Combining JIT and autodiff!

All of the above also works in combination with `jax.jit()`, whether the JIT-compiled function sits inside the differentiation transform or wraps it:

In [16]:
# JIT inside differentiation
jax.jvp(jitted_op, (a, b), (a, b))

primal, vjp = jax.vjp(jitted_op, a, b)
vjp(primal)

jax.grad(jitted_op)(a, b)

# jax.jit can also wrap everything (only this last result is displayed below)
jax.jit(jax.grad(jitted_op))(a, b)

{'v': Array([-0.01284981,  0.03497622, -0.02040852], dtype=float32)}
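Nesting `jax.jit` inside or outside `jax.grad` changes what gets compiled, not the result. A minimal pure-JAX sketch with a hypothetical function `h`:

```python
import jax
import jax.numpy as jnp


def h(x: jax.Array) -> jax.Array:
    # Hypothetical scalar-valued example function
    return jnp.sum(jnp.tanh(x) ** 2)


x = jnp.array([0.1, -0.4, 0.7])

# JIT inside differentiation vs. JIT wrapping the differentiated function
g_inner = jax.grad(jax.jit(h))(x)
g_outer = jax.jit(jax.grad(h))(x)
print(jnp.allclose(g_inner, g_outer))  # True
```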

## Teardown and conclusions

Since we kept the Tesseract running with `.serve()`, we now need to stop it with `.teardown()`:

In [18]:
vectoradd.teardown()

And that's it!
You've worked through building up differentiable pipelines with Tesseracts that blend seamlessly with JAX's API, thanks to Tesseract-JAX.