From 1beb8eda6aaadacf5a7f0fc943892f285f588f60 Mon Sep 17 00:00:00 2001 From: Heiko Zimmermann Date: Wed, 23 Apr 2025 08:53:55 +0200 Subject: [PATCH 1/3] fix: fixed typo in jax recipe --- tesseract_core/sdk/templates/jax/tesseract_api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tesseract_core/sdk/templates/jax/tesseract_api.py b/tesseract_core/sdk/templates/jax/tesseract_api.py index 04df5b24..8cee0499 100644 --- a/tesseract_core/sdk/templates/jax/tesseract_api.py +++ b/tesseract_core/sdk/templates/jax/tesseract_api.py @@ -107,16 +107,16 @@ def vector_jacobian_product( def abstract_eval(abstract_inputs): """Calculate output shape of apply from the shape of its inputs.""" - is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) - is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) + is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) + is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) jaxified_inputs = jax.tree.map( - lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x, + lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x, abstract_inputs.model_dump(), - is_leaf=is_shapedtye_dict, + is_leaf=is_shapedtype_dict, ) dynamic_inputs, static_inputs = eqx.partition( - jaxified_inputs, filter_spec=is_shapedtye_struct + jaxified_inputs, firter_spec=is_shapedtype_struct ) def wrapped_apply(dynamic_inputs): @@ -126,10 +126,10 @@ def wrapped_apply(dynamic_inputs): jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs) return jax.tree.map( lambda x: ( - {"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtye_struct(x) else x + {"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x ), jax_shapes, - is_leaf=is_shapedtye_struct, + is_leaf=is_shapedtype_struct, ) From 119d88759ac4deb1a6dfaf38047ad68cd6b7e42f Mon Sep 17 00:00:00 2001 From: Heiko Zimmermann Date: Wed, 23 Apr 2025 09:06:27 +0200 Subject: [PATCH 2/3] fixed typo in equinox partition argument --- tesseract_core/sdk/templates/jax/tesseract_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tesseract_core/sdk/templates/jax/tesseract_api.py b/tesseract_core/sdk/templates/jax/tesseract_api.py index 8cee0499..65c45597 100644 --- a/tesseract_core/sdk/templates/jax/tesseract_api.py +++ b/tesseract_core/sdk/templates/jax/tesseract_api.py @@ -116,7 +116,7 @@ def abstract_eval(abstract_inputs): is_leaf=is_shapedtype_dict, ) dynamic_inputs, static_inputs = eqx.partition( - jaxified_inputs, firter_spec=is_shapedtype_struct + jaxified_inputs, filter_spec=is_shapedtype_struct ) def wrapped_apply(dynamic_inputs): From 68346ff5b3bebd76b29020797c9699344b8c3bb6 Mon Sep 17 00:00:00 2001 From: Jonathan Brodrick Date: Wed, 23 Apr 2025 09:58:01 +0100 Subject: [PATCH 3/3] Typo scalar -> static --- tesseract_core/sdk/templates/jax/tesseract_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tesseract_core/sdk/templates/jax/tesseract_api.py b/tesseract_core/sdk/templates/jax/tesseract_api.py index 65c45597..65fdb2a4 100644 --- a/tesseract_core/sdk/templates/jax/tesseract_api.py +++ b/tesseract_core/sdk/templates/jax/tesseract_api.py @@ -22,7 +22,7 @@ # inputs/outputs as static. As Tesseract scalar objects (e.g. Float32) are # essentially just wrappers around numpy 0D arrays, they will be considered to # be dynamic and will be traced by JAX. -# If you want to treat numerical values as scalar you will need to use +# If you want to treat scalar numerical values as static you will need to use # built-in Python types (e.g. float, int) instead of Float32.