diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
index fa36dd00c..adeb46574 100644
--- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
+++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py
@@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self):
     stripped_model = cluster.strip_clustering(clustered_model)
     self.assertEqual(self._count_clustered_layers(stripped_model), 0)
-    self.assertEqual(model.get_config(), stripped_model.get_config())
+    model_config = model.get_config()
+    for layer in model_config['layers']:
+      # New serialization format includes `build_config` in wrapper
+      layer.pop('build_config', None)
+    self.assertEqual(model_config, stripped_model.get_config())

   def testClusterStrippingFunctionalModel(self):
     """Verifies that stripping the clustering wrappers from a functional
     model produces the expected config."""
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
index 2b2153549..c084c6451 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
+++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/BUILD
@@ -66,5 +66,6 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
index 063c3d113..853d14ae1 100644
--- a/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
+++ b/tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer_test.py
@@ -23,6 +23,7 @@
 import numpy as np
 import tensorflow as tf

+from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
 from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms

@@ -159,7 +160,9 @@ def replacement(self, match_layer):
       match_layer_config = match_layer.layer['config']
       my_dense_layer = self.MyDense(**match_layer_config)

-      replace_layer = keras.layers.serialize(my_dense_layer)
+      replace_layer = quantize_utils.serialize_layer(
+          my_dense_layer, use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer, match_layer.weights, [])
@@ -176,8 +179,11 @@ def testReplaceSingleLayerWithSingleLayer_OneOccurrence(self, model_type):

     # build_input_shape is a TensorShape object and the two objects are not
     # considered the same even though the shapes are the same.
-    self._assert_config(model.get_config(), transformed_model.get_config(),
-                        ['class_name', 'build_input_shape'])
+    self._assert_config(
+        model.get_config(),
+        transformed_model.get_config(),
+        ['class_name', 'build_input_shape', 'module', 'registered_name'],
+    )

     self.assertEqual(
         'MyDense',
@@ -209,8 +215,11 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(

     # build_input_shape is a TensorShape object and the two objects are not
     # considered the same even though the shapes are the same.
-    self._assert_config(model.get_config(), transformed_model.get_config(),
-                        ['class_name', 'build_input_shape'])
+    self._assert_config(
+        model.get_config(),
+        transformed_model.get_config(),
+        ['class_name', 'build_input_shape', 'module', 'registered_name'],
+    )

     self.assertEqual(
         'MyDense',
@@ -268,7 +277,9 @@ def replacement(self, match_layer):
       match_layer_config['use_bias'] = False

       new_dense_layer = keras.layers.Dense(**match_layer_config)
-      replace_layer = keras.layers.serialize(new_dense_layer)
+      replace_layer = quantize_utils.serialize_layer(
+          new_dense_layer, use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer, match_layer_weights, [])
@@ -311,7 +322,9 @@ def replacement(self, match_layer):
       match_layer_config = match_layer.layer['config']
       my_dense_layer = QuantizedCustomDense(**match_layer_config)

-      replace_layer = keras.layers.serialize(my_dense_layer)
+      replace_layer = quantize_utils.serialize_layer(
+          my_dense_layer, use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer, match_layer.weights, [])
@@ -355,7 +368,9 @@ def pattern(self):

     def replacement(self, match_layer):
       activation_layer = keras.layers.Activation('linear')
-      layer_config = keras.layers.serialize(activation_layer)
+      layer_config = quantize_utils.serialize_layer(
+          activation_layer, use_legacy_format=True
+      )
       layer_config['name'] = activation_layer.name

       activation_layer_node = LayerNode(
@@ -397,7 +412,9 @@ def pattern(self):

     def replacement(self, match_layer):
       activation_layer = keras.layers.Activation('linear')
-      layer_config = keras.layers.serialize(activation_layer)
+      layer_config = quantize_utils.serialize_layer(
+          activation_layer, use_legacy_format=True
+      )
       layer_config['name'] = activation_layer.name

       activation_layer_node = LayerNode(
@@ -435,7 +452,9 @@ def replacement(self, match_layer):

       new_dense_layer = keras.layers.Dense(**dense_layer_config)

-      replace_layer = keras.layers.serialize(new_dense_layer)
+      replace_layer = quantize_utils.serialize_layer(
+          new_dense_layer, use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer, dense_layer_weights, [])
@@ -569,7 +588,9 @@ def pattern(self):
       return LayerPattern('ReLU')

     def replacement(self, match_layer):
-      replace_layer = keras.layers.serialize(keras.layers.Softmax())
+      replace_layer = quantize_utils.serialize_layer(
+          keras.layers.Softmax(), use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer)
@@ -579,7 +600,9 @@ def pattern(self):
       return LayerPattern('Softmax')

     def replacement(self, match_layer):
-      replace_layer = keras.layers.serialize(keras.layers.ELU())
+      replace_layer = quantize_utils.serialize_layer(
+          keras.layers.ELU(), use_legacy_format=True
+      )
       replace_layer['name'] = replace_layer['config']['name']

       return LayerNode(replace_layer)
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
index 970764e4d..4097eab19 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
@@ -64,6 +64,7 @@ py_strict_library(
     deps = [
         ":pruning_wrapper",
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )

@@ -89,6 +90,7 @@ py_strict_test(
         # absl/testing:parameterized dep1,
         # tensorflow dep1,
         "//tensorflow_model_optimization/python/core/keras:compat",
+        "//tensorflow_model_optimization/python/core/quantization/keras:utils",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
index 340078e2e..39c5fe1fe 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_policy.py
@@ -18,6 +18,7 @@
 import abc

 import tensorflow as tf
+from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

 layers = tf.keras.layers
@@ -216,9 +217,9 @@ def _check_layer_support(self, layer):
     elif isinstance(layer, layers.UpSampling2D):
       return layer.interpolation == 'bilinear'
     elif isinstance(layer, layers.Activation):
-      return activations.serialize(layer.activation) in ('relu', 'relu6',
-                                                         'leaky_relu', 'elu',
-                                                         'sigmoid')
+      return quantize_utils.serialize_activation(
+          layer.activation, use_legacy_format=True
+      ) in ('relu', 'relu6', 'leaky_relu', 'elu', 'sigmoid')
     elif layer.__class__.__name__ == 'TFOpLambda':
       return layer.function in (tf.identity, tf.__operators__.add,
                                 tf.math.add, tf.math.subtract, tf.math.multiply)
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py
index 916d080ab..36562b792 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py
@@ -19,6 +19,7 @@
 # TODO(b/139939526): move to public API.
 from tensorflow_model_optimization.python.core.keras import compat
+from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule

@@ -242,12 +243,13 @@ def testSerializeDeserialize(self):
     sparsity = pruning_schedule.ConstantSparsity(0.7, 10, 20, 10)

     config = sparsity.get_config()
-    sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
+    sparsity_deserialized = quantize_utils.deserialize_keras_object(
         config,
         custom_objects={
             'ConstantSparsity': pruning_schedule.ConstantSparsity,
-            'PolynomialDecay': pruning_schedule.PolynomialDecay
-        })
+            'PolynomialDecay': pruning_schedule.PolynomialDecay,
+        },
+    )

     self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)
@@ -278,12 +280,13 @@ def testSerializeDeserialize(self):
     sparsity = pruning_schedule.PolynomialDecay(0.2, 0.6, 10, 20, 5, 10)

     config = sparsity.get_config()
-    sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
+    sparsity_deserialized = quantize_utils.deserialize_keras_object(
         config,
         custom_objects={
             'ConstantSparsity': pruning_schedule.ConstantSparsity,
-            'PolynomialDecay': pruning_schedule.PolynomialDecay
-        })
+            'PolynomialDecay': pruning_schedule.PolynomialDecay,
+        },
+    )

     self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
index 89ff765f7..35064a911 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
@@ -121,13 +121,17 @@ def testPruneModel(self):

     # Test serialization
     model_config = self.model.get_config()
+    for layer in model_config['layers']:
+      layer.pop('build_config', None)
     self.assertEqual(
         model_config,
         self.model.__class__.from_config(
-            model_config,
+            self.model.get_config(),
             custom_objects={
                 'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude
-            }).get_config())
+            },
+        ).get_config(),
+    )

   def testCustomLayerNonPrunable(self):
     layer = CustomLayer(input_dim=16, output_dim=32)
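
Context for reviewers: Keras 2.13 introduced a new serialization format that records extra keys next to each layer config (`module`, `registered_name`, and a per-layer `build_config`), which is why the tests above now pop `build_config` from configs or add `module`/`registered_name` to the ignored-keys list. The `quantize_utils` helpers used throughout this patch wrap the raw Keras calls so the same code runs on both old and new Keras releases. Below is a minimal sketch of that compat-shim pattern, assuming the helpers simply feature-detect the `use_legacy_format` kwarg; the actual implementations live in `python/core/quantization/keras/utils.py` and may differ in detail.

```python
# Sketch only: illustrates the feature-detection pattern this patch relies on,
# not the exact TFMOT implementation.
import inspect

import tensorflow as tf


def _has_legacy_kwarg(fn):
  # Keras >= 2.13 accepts `use_legacy_format`; older releases would raise a
  # TypeError if we passed the unknown kwarg, so probe the signature first.
  return 'use_legacy_format' in inspect.getfullargspec(fn).args


def serialize_layer(layer, use_legacy_format=False):
  """Serializes a layer, pinning the legacy config format when supported."""
  if _has_legacy_kwarg(tf.keras.layers.serialize):
    return tf.keras.layers.serialize(layer, use_legacy_format=use_legacy_format)
  return tf.keras.layers.serialize(layer)


def deserialize_keras_object(config, custom_objects=None):
  """Thin wrapper so call sites do not depend on the Keras module layout."""
  return tf.keras.utils.deserialize_keras_object(
      config, custom_objects=custom_objects)
```

Pinning `use_legacy_format=True` at the test call sites keeps the serialized dicts in the flat legacy shape that the graph-transformation code matches against, so the remaining config comparisons only need to ignore the handful of new top-level keys rather than handle two wrapper formats.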