Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ py_test(
# absl/testing:parameterized dep1,
# numpy dep1,
# tensorflow dep1,
"//tensorflow_model_optimization/python/core/keras:compat",
"//tensorflow_model_optimization/python/core/keras:test_utils",
],
)
Expand Down
31 changes: 20 additions & 11 deletions tensorflow_model_optimization/python/core/sparsity/keras/prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,22 +194,25 @@ def _add_pruning_wrapper(layer):
'an object of type: {input}.'.format(input=to_prune.__class__.__name__))


def strip_pruning(model):
"""Strip pruning wrappers from the model.
def strip_pruning(to_strip):
"""Strip pruning wrappers from the model or layer.

Once a model has been pruned to required sparsity, this method can be used
to restore the original model with the sparse weights.
Once a model or layer has been pruned to required sparsity, this method can be
used
to restore the original model or layer with the sparse weights.

Only sequential and functional models are supported for now.

Arguments:
model: A `tf.keras.Model` instance with pruned layers.
to_strip: A `tf.keras.Model` instance with pruned layers or a
`tf.keras.layers.Layer` instance.

Returns:
A keras model with pruning wrappers removed.
A keras model or layer with pruning wrappers removed.

Raises:
ValueError: if the model is not a `tf.keras.Model` instance.
ValueError: if the model is not a `tf.keras.Model` or
`tf.keras.layers.Layer` instance.
NotImplementedError: if the model is a subclass model.

Usage:
Expand All @@ -222,9 +225,11 @@ def strip_pruning(model):
The exported_model and the orig_model share the same structure.
"""

if not isinstance(model, keras.Model):
if not isinstance(to_strip, keras.Model) and not isinstance(
to_strip, keras.layers.Layer):
raise ValueError(
'Expected model to be a `tf.keras.Model` instance but got: ', model)
'Expected `to_strip` to be a `tf.keras.Model` or `tf.keras.layers.Layer` instance but got: ',
to_strip)

def _strip_pruning_wrapper(layer):
if isinstance(layer, tf.keras.Model):
Expand All @@ -241,5 +246,9 @@ def _strip_pruning_wrapper(layer):
return layer.layer
return layer

return keras.models.clone_model(
model, input_tensors=None, clone_function=_strip_pruning_wrapper)
if isinstance(to_strip, keras.Model):
return keras.models.clone_model(
to_strip, input_tensors=None, clone_function=_strip_pruning_wrapper)

if isinstance(to_strip, keras.layers.Layer):
return _strip_pruning_wrapper(to_strip)
Loading