diff --git a/tesseract_core/sdk/templates/jax/tesseract_api.py b/tesseract_core/sdk/templates/jax/tesseract_api.py index 04df5b24..65fdb2a4 100644 --- a/tesseract_core/sdk/templates/jax/tesseract_api.py +++ b/tesseract_core/sdk/templates/jax/tesseract_api.py @@ -22,7 +22,7 @@ # inputs/outputs as static. As Tesseract scalar objects (e.g. Float32) are # essentially just wrappers around numpy 0D arrays, they will be considered to # be dynamic and will be traced by JAX. -# If you want to treat numerical values as scalar you will need to use +# If you want to treat scalar numerical values as static you will need to use # built-in Python types (e.g. float, int) instead of Float32. @@ -107,16 +107,16 @@ def vector_jacobian_product( def abstract_eval(abstract_inputs): """Calculate output shape of apply from the shape of its inputs.""" - is_shapedtye_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) - is_shapedtye_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) + is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"}) + is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct) jaxified_inputs = jax.tree.map( - lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtye_dict(x) else x, + lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x, abstract_inputs.model_dump(), - is_leaf=is_shapedtye_dict, + is_leaf=is_shapedtype_dict, ) dynamic_inputs, static_inputs = eqx.partition( - jaxified_inputs, filter_spec=is_shapedtye_struct + jaxified_inputs, filter_spec=is_shapedtype_struct ) def wrapped_apply(dynamic_inputs): @@ -126,10 +126,10 @@ def wrapped_apply(dynamic_inputs): jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs) return jax.tree.map( lambda x: ( - {"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtye_struct(x) else x + {"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x ), jax_shapes, - is_leaf=is_shapedtye_struct, + is_leaf=is_shapedtype_struct, )