From 45f1a86aceca5c13c09bf375e88ad3610828f004 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:47:28 -0800 Subject: [PATCH] merge g3 cl (#8138) --- .../python/tensorflowjs/converters/jax_conversion.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py index d9b1209ac07..ef7ca7bb194 100644 --- a/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py +++ b/tfjs-converter/python/tensorflowjs/converters/jax_conversion.py @@ -17,7 +17,6 @@ from typing import Any, Callable, Optional, Sequence, Tuple, Union from jax.experimental import jax2tf -from jax.experimental.export import shape_poly import tensorflow as tf from tensorflowjs.converters import tf_saved_model_conversion_v2 as saved_model_conversion @@ -25,7 +24,6 @@ _TF_SERVING_KEY = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY Array = Any DType = Any -PolyShape = shape_poly.PolyShape class _ReusableSavedModelWrapper(tf.train.Checkpoint): @@ -60,7 +58,7 @@ def convert_jax( *, input_signatures: Sequence[Tuple[Sequence[Union[int, None]], DType]], model_dir: str, - polymorphic_shapes: Optional[Sequence[Union[str, PolyShape]]] = None, + polymorphic_shapes: Optional[Sequence[str]] = None, **tfjs_converter_params): """Converts a JAX function `jax_apply_fn` and model parameters to a TensorflowJS model.