Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,11 @@ def testStripClusteringSequentialModel(self):
stripped_model = cluster.strip_clustering(clustered_model)

self.assertEqual(self._count_clustered_layers(stripped_model), 0)
self.assertEqual(model.get_config(), stripped_model.get_config())
model_config = model.get_config()
for layer in model_config['layers']:
# New serialization format includes `build_config` in wrapper
layer.pop('build_config', None)
self.assertEqual(model_config, stripped_model.get_config())

def testClusterStrippingFunctionalModel(self):
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,5 +66,6 @@ py_strict_test(
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/quantization/keras:utils",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import model_transformer
from tensorflow_model_optimization.python.core.quantization.keras.graph_transformations import transforms

Expand Down Expand Up @@ -159,7 +160,9 @@ def replacement(self, match_layer):
match_layer_config = match_layer.layer['config']
my_dense_layer = self.MyDense(**match_layer_config)

replace_layer = keras.layers.serialize(my_dense_layer)
replace_layer = quantize_utils.serialize_layer(
my_dense_layer, use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']

return LayerNode(replace_layer, match_layer.weights, [])
Expand All @@ -176,8 +179,11 @@ def testReplaceSingleLayerWithSingleLayer_OneOccurrence(self, model_type):

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
self._assert_config(model.get_config(), transformed_model.get_config(),
['class_name', 'build_input_shape'])
self._assert_config(
model.get_config(),
transformed_model.get_config(),
['class_name', 'build_input_shape', 'module', 'registered_name'],
)

self.assertEqual(
'MyDense',
Expand Down Expand Up @@ -209,8 +215,11 @@ def testReplaceSingleLayerWithSingleLayer_MultipleOccurrences(

# build_input_shape is a TensorShape object and the two objects are not
# considered the same even though the shapes are the same.
self._assert_config(model.get_config(), transformed_model.get_config(),
['class_name', 'build_input_shape'])
self._assert_config(
model.get_config(),
transformed_model.get_config(),
['class_name', 'build_input_shape', 'module', 'registered_name'],
)

self.assertEqual(
'MyDense',
Expand Down Expand Up @@ -268,7 +277,9 @@ def replacement(self, match_layer):
match_layer_config['use_bias'] = False
new_dense_layer = keras.layers.Dense(**match_layer_config)

replace_layer = keras.layers.serialize(new_dense_layer)
replace_layer = quantize_utils.serialize_layer(
new_dense_layer, use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']

return LayerNode(replace_layer, match_layer_weights, [])
Expand Down Expand Up @@ -311,7 +322,9 @@ def replacement(self, match_layer):
match_layer_config = match_layer.layer['config']
my_dense_layer = QuantizedCustomDense(**match_layer_config)

replace_layer = keras.layers.serialize(my_dense_layer)
replace_layer = quantize_utils.serialize_layer(
my_dense_layer, use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']

return LayerNode(replace_layer, match_layer.weights, [])
Expand Down Expand Up @@ -355,7 +368,9 @@ def pattern(self):

def replacement(self, match_layer):
activation_layer = keras.layers.Activation('linear')
layer_config = keras.layers.serialize(activation_layer)
layer_config = quantize_utils.serialize_layer(
activation_layer, use_legacy_format=True
)
layer_config['name'] = activation_layer.name

activation_layer_node = LayerNode(
Expand Down Expand Up @@ -397,7 +412,9 @@ def pattern(self):

def replacement(self, match_layer):
activation_layer = keras.layers.Activation('linear')
layer_config = keras.layers.serialize(activation_layer)
layer_config = quantize_utils.serialize_layer(
activation_layer, use_legacy_format=True
)
layer_config['name'] = activation_layer.name

activation_layer_node = LayerNode(
Expand Down Expand Up @@ -435,7 +452,9 @@ def replacement(self, match_layer):

new_dense_layer = keras.layers.Dense(**dense_layer_config)

replace_layer = keras.layers.serialize(new_dense_layer)
replace_layer = quantize_utils.serialize_layer(
new_dense_layer, use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']

return LayerNode(replace_layer, dense_layer_weights, [])
Expand Down Expand Up @@ -569,7 +588,9 @@ def pattern(self):
return LayerPattern('ReLU')

def replacement(self, match_layer):
replace_layer = keras.layers.serialize(keras.layers.Softmax())
replace_layer = quantize_utils.serialize_layer(
keras.layers.Softmax(), use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']
return LayerNode(replace_layer)

Expand All @@ -579,7 +600,9 @@ def pattern(self):
return LayerPattern('Softmax')

def replacement(self, match_layer):
replace_layer = keras.layers.serialize(keras.layers.ELU())
replace_layer = quantize_utils.serialize_layer(
keras.layers.ELU(), use_legacy_format=True
)
replace_layer['name'] = replace_layer['config']['name']
return LayerNode(replace_layer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ py_strict_library(
deps = [
":pruning_wrapper",
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/quantization/keras:utils",
],
)

Expand All @@ -89,6 +90,7 @@ py_strict_test(
# absl/testing:parameterized dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:compat",
"//tensorflow_model_optimization/python/core/quantization/keras:utils",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import abc
import tensorflow as tf

from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_wrapper

layers = tf.keras.layers
Expand Down Expand Up @@ -216,9 +217,9 @@ def _check_layer_support(self, layer):
elif isinstance(layer, layers.UpSampling2D):
return layer.interpolation == 'bilinear'
elif isinstance(layer, layers.Activation):
return activations.serialize(layer.activation) in ('relu', 'relu6',
'leaky_relu', 'elu',
'sigmoid')
return quantize_utils.serialize_activation(
layer.activation, use_legacy_format=True
) in ('relu', 'relu6', 'leaky_relu', 'elu', 'sigmoid')
elif layer.__class__.__name__ == 'TFOpLambda':
return layer.function in (tf.identity, tf.__operators__.add, tf.math.add,
tf.math.subtract, tf.math.multiply)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

# TODO(b/139939526): move to public API.
from tensorflow_model_optimization.python.core.keras import compat
from tensorflow_model_optimization.python.core.quantization.keras import utils as quantize_utils
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule


Expand Down Expand Up @@ -242,12 +243,13 @@ def testSerializeDeserialize(self):
sparsity = pruning_schedule.ConstantSparsity(0.7, 10, 20, 10)

config = sparsity.get_config()
sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
sparsity_deserialized = quantize_utils.deserialize_keras_object(
config,
custom_objects={
'ConstantSparsity': pruning_schedule.ConstantSparsity,
'PolynomialDecay': pruning_schedule.PolynomialDecay
})
'PolynomialDecay': pruning_schedule.PolynomialDecay,
},
)

self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)

Expand Down Expand Up @@ -278,12 +280,13 @@ def testSerializeDeserialize(self):
sparsity = pruning_schedule.PolynomialDecay(0.2, 0.6, 10, 20, 5, 10)

config = sparsity.get_config()
sparsity_deserialized = tf.keras.utils.deserialize_keras_object(
sparsity_deserialized = quantize_utils.deserialize_keras_object(
config,
custom_objects={
'ConstantSparsity': pruning_schedule.ConstantSparsity,
'PolynomialDecay': pruning_schedule.PolynomialDecay
})
'PolynomialDecay': pruning_schedule.PolynomialDecay,
},
)

self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,17 @@ def testPruneModel(self):

# Test serialization
model_config = self.model.get_config()
for layer in model_config['layers']:
layer.pop('build_config', None)
self.assertEqual(
model_config,
self.model.__class__.from_config(
model_config,
self.model.get_config(),
custom_objects={
'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude
}).get_config())
},
).get_config(),
)

def testCustomLayerNonPrunable(self):
layer = CustomLayer(input_dim=16, output_dim=32)
Expand Down