Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

### Tesseract-JAX

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of [JAX](https://github.com/jax-ml/jax) programs, with full support for function transformations like JIT, `grad`, and more.
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.

[Read the docs](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/) |
[Explore the examples](https://github.com/pasteurlabs/tesseract-jax/tree/main/examples) |
Expand All @@ -12,7 +12,16 @@

---

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.
The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](https://docs.pasteurlabs.ai/projects/tesseract-jax/latest/content/api.html#tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:

```python
@jax.jit
def vector_sum(x, y):
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

jax.grad(vector_sum)(x, y) # 🎉
```

## Quick start

Expand All @@ -31,7 +40,8 @@ The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesser
2. Build an example Tesseract:

```bash
$ tesseract build examples/simple/vectoradd_jax
$ git clone https://github.com/pasteurlabs/tesseract-jax
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program via the JAX-native `apply_tesseract` function:
Expand Down
7 changes: 2 additions & 5 deletions docs/content/get-started.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
# Get started

`tesseract-jax` executes [Tesseracts](https://github.com/pasteurlabs/tesseract-core) as part of JAX programs, with full support for function transformations like JIT, `grad`, `jvp`, and more.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines.

## Quick start

```{note}
Expand All @@ -23,7 +19,8 @@ For more detailed installation instructions, please refer to the [Tesseract Core
2. Build an example Tesseract:

```bash
$ tesseract build examples/simple/vectoradd_jax
$ git clone https://github.com/pasteurlabs/tesseract-jax
$ tesseract build tesseract-jax/examples/simple/vectoradd_jax
```

3. Use it as part of a JAX program:
Expand Down
15 changes: 13 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
# Tesseract-JAX

```{include} content/get-started.md
:start-line: 2
Tesseract-JAX is a lightweight extension to [Tesseract Core](https://github.com/pasteurlabs/tesseract-core) that makes Tesseracts look and feel like regular [JAX](https://github.com/jax-ml/jax) primitives, and makes them jittable, differentiable, and composable.

The API of Tesseract-JAX consists of a single function, [`apply_tesseract(tesseract_client, inputs)`](tesseract_jax.apply_tesseract), which is fully traceable by JAX. This enables end-to-end autodifferentiation and JIT compilation of Tesseract-based pipelines:

```python
@jax.jit
def vector_sum(x, y):
res = apply_tesseract(vectoradd_tesseract, {"a": {"v": x}, "b": {"v": y}})
return res["vector_add"]["result"].sum()

jax.grad(vector_sum)(x, y) # 🎉
```

Want to learn more? See how to [get started](content/get-started.md) with Tesseract-JAX, explore the [API reference](content/api.md), or learn by [example](demo_notebooks/simple.ipynb).

## License

Tesseract JAX is licensed under the [Apache License 2.0](https://github.com/pasteurlabs/tesseract-jax/LICENSE) and is free to use, modify, and distribute (under the terms of the license).
Expand Down