diff --git a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py index d9b1209ac07..ef7ca7bb194 100644 --- a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py +++ b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union from jax.experimental import jax2tf -from jax.experimental.export import shape_poly import tensorflow as tf from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion @@ -25,7 +24,6 @@ _TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY Array = Any DType = Any -PolyShape = shape_poly.PolyShape class _ReusableSavedModelWrapper(tf.train.Checkpoint): @@ -60,7 +58,7 @@ def convert_jax( *, input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]], model_dir: str, - polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None, + polymorphic_shapes: Optional[Sequence[str]] = None, **tfjs_converter_params): """Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.