tf.keras.models.save_model not saving the probabilistic_model #50774
Comments
Could you please provide a Colab gist with the required dependencies so we can analyse the issue better? Thanks!
@UsharaniPagadala https://github.com/rrklearn2020/probabilistic_model_trials.git
@UsharaniPagadala https://github.com/rrklearn2020/probabilistic_model_trials.git Complete details are shared at this link.
Please post this issue on the keras-team/keras repo.
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
The issue is not yet solved, waiting for support.
It looks like this issue relates to the Keras component. Please submit it to the github.com/keras-team/keras repository instead. As previously announced, all future development of Keras is expected to happen in the keras-team/keras repository. If your issue lies with the TF-Core area, please comment back with your explanation and we can look into it further. Thanks!
@UsharaniPagadala It would be a great help if the TensorFlow and Keras teams could work together to solve this issue of saving a TF probabilistic_model as a TensorFlow SavedModel.
@rrklearn2020
@UsharaniPagadala I had already posted this issue in keras-team/keras 19 days ago. The link is shared below.
Could you please move this to closed status, as it is being tracked in the keras-team/keras repo? Thanks.
Before closing the ticket: it would be a great help if the TensorFlow and Keras teams could work together to solve this issue of saving a TF probabilistic_model as a TensorFlow SavedModel.
Thanks for the update, sure, we will do that.
This issue has been automatically marked as stale because it has no recent activity. It will be closed if no further activity occurs. Thank you.
The issue is not yet solved, waiting for support.
This is still a problem that significantly limits the utility of the TensorFlow Probability layers. It would be great if we could save Keras models with TensorFlow Probability layers, please.
System information
Describe the current behavior
A TensorFlow model containing tensorflow_probability layers raises an error when it is saved in the SavedModel format.
The model is created using the code below:
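```python
import tensorflow_probability as tfp
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.optimizers import RMSprop

tfpl = tfp.layers

# Assumed definition (not shown in the report): negative log-likelihood
# of the one-hot labels under the predicted OneHotCategorical distribution.
def nll(y_true, y_pred):
    return -y_pred.log_prob(y_true)

def get_probabilistic_model(input_shape, loss, optimizer, metrics):
    model = Sequential([
        Conv2D(8, 5, activation='relu', padding='valid', input_shape=input_shape),
        MaxPooling2D(6),
        Flatten(),
        Dense(10),
        tfpl.OneHotCategorical(10)  # outputs a distribution over 10 classes
    ])
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    return model

probabilistic_model = get_probabilistic_model(
    input_shape=(28, 28, 1),
    loss=nll,
    optimizer=RMSprop(),
    metrics=['accuracy'])

# x_train / y_train_oh: training images and one-hot labels (defined elsewhere)
probabilistic_model.fit(x_train, y_train_oh, epochs=5)
```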
The model is saved with `probabilistic_model.save('/tmp/model/probabilistic_model')`, which raises the following error:
`
OperatorNotAllowedInGraphError Traceback (most recent call last)
/tmp/ipykernel_11377/1109926494.py in
----> 1 probabilistic_model.save('/tmp/model/probabilistic_model')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in save(self, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
2109 """
2110 # pylint: enable=line-too-long
-> 2111 save.save_model(self, filepath, overwrite, include_optimizer, save_format,
2112 signatures, options, save_traces)
2113
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in save_model(model, filepath, overwrite, include_optimizer, save_format, signatures, options, save_traces)
148 else:
149 with generic_utils.SharedObjectSavingScope():
--> 150 saved_model_save.save(model, filepath, overwrite, include_optimizer,
151 signatures, options, save_traces)
152
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save.py in save(model, filepath, overwrite, include_optimizer, signatures, options, save_traces)
87 with K.deprecated_internal_learning_phase_scope(0):
88 with utils.keras_option_scope(save_traces):
---> 89 saved_nodes, node_paths = save_lib.save_and_return_nodes(
90 model, filepath, signatures, options)
91
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in save_and_return_nodes(obj, export_dir, signatures, options, raise_metadata_warning, experimental_skip_checkpoint)
1101
1102 _, exported_graph, object_saver, asset_info, saved_nodes, node_paths = (
-> 1103 _build_meta_graph(obj, signatures, options, meta_graph_def,
1104 raise_metadata_warning))
1105 saved_model.saved_model_schema_version = constants.SAVED_MODEL_SCHEMA_VERSION
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1288
1289 with save_context.save_context(options):
-> 1290 return _build_meta_graph_impl(obj, signatures, options, meta_graph_def,
1291 raise_metadata_warning)
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in _build_meta_graph_impl(obj, signatures, options, meta_graph_def, raise_metadata_warning)
1205 checkpoint_graph_view = _AugmentedGraphView(obj)
1206 if signatures is None:
-> 1207 signatures = signature_serialization.find_function_to_export(
1208 checkpoint_graph_view)
1209
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/signature_serialization.py in find_function_to_export(saveable_view)
97 # If the user did not specify signatures, check the root object for a function
98 # that can be made into a signature.
---> 99 functions = saveable_view.list_functions(saveable_view.root)
100 signature = functions.get(DEFAULT_SIGNATURE_ATTR, None)
101 if signature is not None:
~/tf2/lib/python3.8/site-packages/tensorflow/python/saved_model/save.py in list_functions(self, obj)
152 obj_functions = self._functions.get(obj, None)
153 if obj_functions is None:
--> 154 obj_functions = obj._list_functions_for_serialization( # pylint: disable=protected-access
155 self._serialization_cache)
156 self._functions[obj] = obj_functions
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py in _list_functions_for_serialization(self, serialization_cache)
2711 self.test_function = None
2712 self.predict_function = None
-> 2713 functions = super(
2714 Model, self)._list_functions_for_serialization(serialization_cache)
2715 self.train_function = train_function
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/base_layer.py in _list_functions_for_serialization(self, serialization_cache)
3014
3015 def _list_functions_for_serialization(self, serialization_cache):
-> 3016 return (self._trackable_saved_model_saver
3017 .list_functions_for_serialization(serialization_cache))
3018
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/base_serialization.py in list_functions_for_serialization(self, serialization_cache)
90 return {}
91
---> 92 fns = self.functions_to_serialize(serialization_cache)
93
94 # The parent AutoTrackable class saves all user-defined tf.functions, and
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in functions_to_serialize(self, serialization_cache)
71
72 def functions_to_serialize(self, serialization_cache):
---> 73 return (self._get_serialized_attributes(
74 serialization_cache).functions_to_serialize)
75
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes(self, serialization_cache)
87 return serialized_attr
88
---> 89 object_dict, function_dict = self._get_serialized_attributes_internal(
90 serialization_cache)
91
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/model_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
51 # the ones serialized by Layer.
52 objects, functions = (
---> 53 super(ModelSavedModelSaver, self)._get_serialized_attributes_internal(
54 serialization_cache))
55 functions['_default_save_signature'] = default_signature
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/layer_serialization.py in _get_serialized_attributes_internal(self, serialization_cache)
97 """Returns dictionary of serialized attributes."""
98 objects = save_impl.wrap_layer_objects(self.obj, serialization_cache)
---> 99 functions = save_impl.wrap_layer_functions(self.obj, serialization_cache)
100 # Attribute validator requires that the default save signature is added to
101 # function dict, even if the value is None.
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrap_layer_functions(layer, serialization_cache)
202 if isinstance(fn, LayerCall):
203 fn = fn.wrapped_call
--> 204 fn.get_concrete_function()
205
206 # Restore overwritten functions and losses
/usr/lib/python3.8/contextlib.py in __exit__(self, type, value, traceback)
118 if type is None:
119 try:
--> 120 next(self.gen)
121 except StopIteration:
122 return False
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in tracing_scope()
365 if training is not None:
366 with K.deprecated_internal_learning_phase_scope(training):
--> 367 fn.get_concrete_function(*args, **kwargs)
368 else:
369 fn.get_concrete_function(*args, **kwargs)
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
1365 ValueError: if this object has not yet been called on concrete values.
1366 """
-> 1367 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
1368 concrete._garbage_collector.release() # pylint: disable=protected-access
1369 return concrete
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
1282 # In this case we have not created variables on the first call. So we can
1283 # run the first trace but we should fail if variables are created.
-> 1284 concrete = self._stateful_fn._get_concrete_function_garbage_collected( # pylint: disable=protected-access
1285 *args, **kwargs)
1286 if self._created_variables:
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
3098 args, kwargs = None, None
3099 with self._lock:
-> 3100 graph_function, _ = self._maybe_define_function(args, kwargs)
3101 seen_names = set()
3102 captured = object_identity.ObjectIdentitySet(
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: func_outputs contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().wrapped(*args, **kwds)
673 return out
674
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597 with autocast_variable.enable_auto_cast_variables(
598 layer._compute_dtype_object): # pylint: disable=protected-access
--> 599 ret = method(*args, **kwargs)
600 _restore_layer_losses(original_losses)
601 return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163 return wrapped_call(*args, **kwargs)
164
--> 165 return control_flow_util.smart_cond(
166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107 return control_flow_ops.cond(
108 pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109 return smart_module.smart_cond(
110 pred, true_fn=true_fn, false_fn=false_fn, name=name)
111
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in ()
164
165 return control_flow_util.smart_cond(
--> 166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
168
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161 def replace_training_and_call(training):
162 set_training_arg(training, training_arg_index, args, kwargs)
--> 163 return wrapped_call(*args, **kwargs)
164
165 return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call(inputs, *args, **kwargs)
679 return layer.keras_api.call # pylint: disable=protected-access
680 def call(inputs, *args, **kwargs):
--> 681 return call_and_return_conditional_losses(inputs, *args, **kwargs)[0]
682 return _create_call_fn_decorator(layer, call)
683
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call(self, *args, **kwargs)
637 def call(self, *args, **kwargs):
638 self._maybe_trace(args, kwargs)
--> 639 return self.wrapped_call(*args, **kwargs)
640
641 def get_concrete_function(self, *args, **kwargs):
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in call(self, *args, **kwds)
887
888 with OptionalXlaContext(self._jit_compile):
--> 889 result = self._call(*args, **kwds)
890
891 new_tracing_count = self.experimental_get_tracing_count()
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in _call(self, *args, **kwds)
922 # In this case we have not created variables on the first call. So we can
923 # run the first trace but we should fail if variables are created.
--> 924 results = self._stateful_fn(*args, **kwds)
925 if self._created_variables:
926 raise ValueError("Creating variables on a non-first call to a function"
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in call(self, *args, **kwargs)
3020 with self._lock:
3021 (graph_function,
-> 3022 filtered_flat_args) = self._maybe_define_function(args, kwargs)
3023 return graph_function._call_flat(
3024 filtered_flat_args, captured_inputs=graph_function.captured_inputs) # pylint: disable=protected-access
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
3442
3443 self._function_cache.missed.add(call_context_key)
-> 3444 graph_function = self._create_graph_function(args, kwargs)
3445 self._function_cache.primary[cache_key] = graph_function
3446
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
3277 arg_names = base_arg_names + missing_arg_names
3278 graph_function = ConcreteFunction(
-> 3279 func_graph_module.func_graph_from_py_func(
3280 self._name,
3281 self._python_function,
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
997 _, original_func = tf_decorator.unwrap(python_func)
998
--> 999 func_outputs = python_func(*func_args, **func_kwargs)
1000
1001 # invariant: func_outputs contains only Tensors, CompositeTensors,
~/tf2/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
670 # the function a weak reference to itself to avoid a reference cycle.
671 with OptionalXlaContext(compile_with_xla):
--> 672 out = weak_wrapped_fn().wrapped(*args, **kwds)
673 return out
674
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in wrapper(*args, **kwargs)
597 with autocast_variable.enable_auto_cast_variables(
598 layer._compute_dtype_object): # pylint: disable=protected-access
--> 599 ret = method(*args, **kwargs)
600 _restore_layer_losses(original_losses)
601 return ret
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
163 return wrapped_call(*args, **kwargs)
164
--> 165 return control_flow_util.smart_cond(
166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/control_flow_util.py in smart_cond(pred, true_fn, false_fn, name)
107 return control_flow_ops.cond(
108 pred, true_fn=true_fn, false_fn=false_fn, name=name)
--> 109 return smart_module.smart_cond(
110 pred, true_fn=true_fn, false_fn=false_fn, name=name)
111
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in ()
164
165 return control_flow_util.smart_cond(
--> 166 training, lambda: replace_training_and_call(True),
167 lambda: replace_training_and_call(False))
168
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
161 def replace_training_and_call(training):
162 set_training_arg(training, training_arg_index, args, kwargs)
--> 163 return wrapped_call(*args, **kwargs)
164
165 return control_flow_util.smart_cond(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/saved_model/save_impl.py in call_and_return_conditional_losses(*args, **kwargs)
661 def call_and_return_conditional_losses(*args, **kwargs):
662 """Returns layer (call_output, conditional losses) tuple."""
--> 663 call_output = layer_call(*args, **kwargs)
664 if version_utils.is_v1_layer_or_model(layer):
665 conditional_losses = layer.get_losses_for(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in call(self, inputs, training, mask)
378 if not self.built:
379 self._init_graph_network(self.inputs, self.outputs)
--> 380 return super(Sequential, self).call(inputs, training=training, mask=mask)
381
382 outputs = inputs # handle the corner case where self.layers is empty
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in call(self, inputs, training, mask)
418 a list of tensors if there are more than one outputs.
419 """
--> 420 return self._run_internal_graph(
421 inputs, training=training, mask=mask)
422
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/functional.py in _run_internal_graph(self, inputs, training, mask)
554
555 args, kwargs = node.map_arguments(tensor_dict)
--> 556 outputs = node.layer(*args, **kwargs)
557
558 # Update tensor_dict.
~/tf2/lib/python3.8/site-packages/tensorflow_probability/python/layers/distribution_layer.py in call(self, inputs, *args, **kwargs)
228 def call(self, inputs, *args, **kwargs):
229 self._enter_dunder_call = True
--> 230 distribution, _ = super(DistributionLambda, self).call(
231 inputs, *args, **kwargs)
232 self._enter_dunder_call = False
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in __iter__(self)
518 def __iter__(self):
519 if not context.executing_eagerly():
--> 520 self._disallow_iteration()
521
522 shape = self._shape_tuple()
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_iteration(self)
511 self._disallow_when_autograph_disabled("iterating over tf.Tensor")
512 elif ag_ctx.control_status_ctx().status == ag_ctx.Status.ENABLED:
--> 513 self._disallow_when_autograph_enabled("iterating over tf.Tensor")
514 else:
515 # Default: V1-style Graph execution.
~/tf2/lib/python3.8/site-packages/tensorflow/python/framework/ops.py in _disallow_when_autograph_enabled(self, task)
487
488 def _disallow_when_autograph_enabled(self, task):
--> 489 raise errors.OperatorNotAllowedInGraphError(
490 "{} is not allowed: AutoGraph did convert this function. This might"
491 " indicate you are trying to use an unsupported feature.".format(task))
OperatorNotAllowedInGraphError: iterating over tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.`
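The last frames pin the failure to the tuple unpacking on line 230 of distribution_layer.py (`distribution, _ = super(DistributionLambda, self).call(...)`): the value being unpacked is a symbolic `tf.Tensor`, and `Tensor.__iter__` refuses to iterate outside eager execution. The sketch below is not TFP code; it is a minimal illustration, assuming TensorFlow 2.x, with a hypothetical helper `unpack_pair` that reproduces the same class of error in isolation.

```python
import tensorflow as tf

def unpack_pair(t):
    # Hypothetical helper: tuple-unpacking iterates over the first dimension.
    a, b = t
    return a + b

pair = tf.constant([1.0, 2.0])

# Eager execution: iterating over a concrete tensor is allowed.
eager_result = unpack_pair(pair)
print(float(eager_result))  # 3.0

# Graph tracing via tf.function: unpacking a symbolic tensor is rejected.
graph_error = None
try:
    tf.function(unpack_pair)(pair)
except tf.errors.OperatorNotAllowedInGraphError as err:
    graph_error = err
    print(type(err).__name__)
```

The sketch only shows what the error means; the traceback alone does not establish why the Keras SavedModel tracing path hands `DistributionLambda` a tensor instead of a tuple.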
A workaround with limited capability is described in https://github.com/tensorflow/probability/issues/325#issuecomment-477213850.
However, it only saves the weights, not the rest of the model.
Workaround with the h5 format
Saving in the h5 format works, but the saved model cannot be loaded. Loading it with `tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')` raises the following error:
`
ValueError Traceback (most recent call last)
/tmp/ipykernel_11377/686337657.py in
----> 1 loaded_model = tf.keras.models.load_model('/tmp/model/probabilistic_model.h5')
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
199 if (h5py is not None and
200 (isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 201 return hdf5_format.load_model_from_hdf5(filepath, custom_objects,
202 compile)
203
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
178 model_config = model_config.decode('utf-8')
179 model_config = json_utils.decode(model_config)
--> 180 model = model_config_lib.model_from_config(model_config,
181 custom_objects=custom_objects)
182
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
57 'Sequential.from_config(config)?')
58 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 59 return deserialize(config, custom_objects=custom_objects)
60
61
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157 """
158 populate_deserializable_objects()
--> 159 return generic_utils.deserialize_keras_object(
160 config,
161 module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
666
667 if 'custom_objects' in arg_spec.args:
--> 668 deserialized_obj = cls.from_config(
669 cls_config,
670 custom_objects=dict(
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/engine/sequential.py in from_config(cls, config, custom_objects)
495 model = cls(name=name)
496 for layer_config in layer_configs:
--> 497 layer = layer_module.deserialize(layer_config,
498 custom_objects=custom_objects)
499 model.add(layer)
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
157 """
158 populate_deserializable_objects()
--> 159 return generic_utils.deserialize_keras_object(
160 config,
161 module_objects=LOCAL.ALL_OBJECTS,
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
651 # In this case we are dealing with a Keras config dictionary.
652 config = identifier
--> 653 (cls, cls_config) = class_and_config_for_serialized_keras_object(
654 config, module_objects, custom_objects, printable_module_name)
655
~/tf2/lib/python3.8/site-packages/tensorflow/python/keras/utils/generic_utils.py in class_and_config_for_serialized_keras_object(config, module_objects, custom_objects, printable_module_name)
554 cls = get_registered_object(class_name, custom_objects, module_objects)
555 if cls is None:
--> 556 raise ValueError(
557 'Unknown {}: {}. Please ensure this object is '
558 'passed to the custom_objects argument. See '
ValueError: Unknown layer: OneHotCategorical. Please ensure this object is passed to the custom_objects argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.`
Describe the expected behavior
`probabilistic_model.save(...)` saves the model, including its tensorflow_probability layers, as a TensorFlow SavedModel.