diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
index 199608ad5..4ed2bb39e 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
@@ -194,6 +194,7 @@ py_test(
         # absl/testing:parameterized dep1,
         # numpy dep1,
         # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/keras:compat",
        "//tensorflow_model_optimization/python/core/keras:test_utils",
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
index 5eb72803a..d95e79a0a 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune.py
@@ -194,22 +194,24 @@ def _add_pruning_wrapper(layer):
       'an object of type: {input}.'.format(input=to_prune.__class__.__name__))
 
 
-def strip_pruning(model):
-  """Strip pruning wrappers from the model.
+def strip_pruning(to_strip):
+  """Strip pruning wrappers from the model or layer.
 
-  Once a model has been pruned to required sparsity, this method can be used
-  to restore the original model with the sparse weights.
+  Once a model or layer has been pruned to the required sparsity, this method
+  can be used to restore the original model or layer with the sparse weights.
 
   Only sequential and functional models are supported for now.
 
   Arguments:
-      model: A `tf.keras.Model` instance with pruned layers.
+      to_strip: A `tf.keras.Model` instance with pruned layers or a
+        `tf.keras.layers.Layer` instance.
 
   Returns:
-    A keras model with pruning wrappers removed.
+    A keras model or layer with pruning wrappers removed.
 
   Raises:
-    ValueError: if the model is not a `tf.keras.Model` instance.
+    ValueError: if `to_strip` is not a `tf.keras.Model` or
+      `tf.keras.layers.Layer` instance.
     NotImplementedError: if the model is a subclass model.
 
   Usage:
@@ -222,9 +224,11 @@ def strip_pruning(model):
   ```
   The exported_model and the orig_model share the same structure.
""" - if not isinstance(model, keras.Model): + if not isinstance(to_strip, keras.Model) and not isinstance( + to_strip, keras.layers.Layer): raise ValueError( - 'Expected model to be a `tf.keras.Model` instance but got: ', model) + 'Expected `to_strip` to be a `tf.keras.Model` or `tf.keras.layers.Layer` instance but got: ', + to_strip) def _strip_pruning_wrapper(layer): if isinstance(layer, tf.keras.Model): @@ -241,5 +246,9 @@ def _strip_pruning_wrapper(layer): return layer.layer return layer - return keras.models.clone_model( - model, input_tensors=None, clone_function=_strip_pruning_wrapper) + if isinstance(to_strip, keras.Model): + return keras.models.clone_model( + to_strip, input_tensors=None, clone_function=_strip_pruning_wrapper) + + if isinstance(to_strip, keras.layers.Layer): + return _strip_pruning_wrapper(to_strip) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py index 3b0e3ac1c..1ab24a6c5 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py @@ -15,6 +15,7 @@ """End to End tests for the Pruning API.""" import tempfile +import os from absl.testing import parameterized import numpy as np @@ -57,6 +58,48 @@ class PruneIntegrationTest(tf.test.TestCase, parameterized.TestCase, if not weights ] + @staticmethod + def _save_as_saved_model(model): + saved_model_dir = tempfile.mkdtemp() + model.save(saved_model_dir) + return saved_model_dir + + @staticmethod + def get_gzipped_model_size(model): + # It returns the size of the gzipped model in bytes. + import os + import zipfile + + _, keras_file = tempfile.mkstemp('.h5') + model.save_weights(keras_file) + + _, zipped_file = tempfile.mkstemp('.zip') + with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f: + f.write(keras_file) + return os.path.getsize(zipped_file) + + + @staticmethod + def _get_directory_size_in_bytes(directory): + import os + + total = 0 + try: + for entry in os.scandir(directory): + if entry.is_file(): + # if it's a file, use stat() function + total += entry.stat().st_size + elif entry.is_dir(): + # if it's a directory, recursively call this function + total += PruneIntegrationTest._get_directory_size_in_bytes(entry.path) + except NotADirectoryError: + # if `directory` isn't a directory, get the file size then + return os.path.getsize(directory) + except PermissionError: + # if for whatever reason we can't open the folder, return 0 + return 0 + return total + @staticmethod def _batch(dims, batch_size): """Adds provided batch_size to existing dims. @@ -120,23 +163,6 @@ def setUp(self): 'block_pooling_type': 'AVG' } - # TODO(pulkitb): Also assert correct weights are pruned. - # TODO(tfmot): this should not be verified in all the unit tests. - # As long as there are a few unit tests for strip_pruning, - # these checks are redundant. 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
index 3b0e3ac1c..1ab24a6c5 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
@@ -15,6 +15,7 @@
 """End to End tests for the Pruning API."""
 
+import os
 import tempfile
 
 from absl.testing import parameterized
 import numpy as np
@@ -57,6 +58,47 @@ class PruneIntegrationTest(tf.test.TestCase, parameterized.TestCase,
       if not weights
   ]
 
+  @staticmethod
+  def _save_as_saved_model(model):
+    saved_model_dir = tempfile.mkdtemp()
+    model.save(saved_model_dir)
+    return saved_model_dir
+
+  @staticmethod
+  def get_gzipped_model_size(model):
+    """Returns the size of the gzipped model weights in bytes."""
+    import zipfile
+
+    _, keras_file = tempfile.mkstemp('.h5')
+    model.save_weights(keras_file)
+
+    _, zipped_file = tempfile.mkstemp('.zip')
+    with zipfile.ZipFile(
+        zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
+      f.write(keras_file)
+    return os.path.getsize(zipped_file)
+
+  @staticmethod
+  def _get_directory_size_in_bytes(directory):
+    """Returns the total size of the files under `directory` in bytes."""
+    total = 0
+    try:
+      for entry in os.scandir(directory):
+        if entry.is_file():
+          # If the entry is a file, use stat() to get its size.
+          total += entry.stat().st_size
+        elif entry.is_dir():
+          # If the entry is a directory, recursively sum its contents.
+          total += PruneIntegrationTest._get_directory_size_in_bytes(
+              entry.path)
+    except NotADirectoryError:
+      # If `directory` is actually a file, return its size directly.
+      return os.path.getsize(directory)
+    except PermissionError:
+      # If the folder cannot be opened, count it as zero bytes.
+      return 0
+    return total
+
   @staticmethod
   def _batch(dims, batch_size):
     """Adds provided batch_size to existing dims.
@@ -120,23 +162,6 @@ def setUp(self):
         'block_pooling_type': 'AVG'
     }
 
-  # TODO(pulkitb): Also assert correct weights are pruned.
-  # TODO(tfmot): this should not be verified in all the unit tests.
-  # As long as there are a few unit tests for strip_pruning,
-  # these checks are redundant.
-  def _check_strip_pruning_matches_original(
-      self, model, sparsity, input_data=None):
-    stripped_model = prune.strip_pruning(model)
-    test_utils.assert_model_sparsity(self, sparsity, stripped_model)
-
-    if input_data is None:
-      input_data = np.random.randn(
-          *self._batch(model.input.get_shape().as_list(), 1))
-
-    model_result = model.predict(input_data)
-    stripped_model_result = stripped_model.predict(input_data)
-    np.testing.assert_almost_equal(model_result, stripped_model_result)
-
   @staticmethod
   def _is_pruned(model):
     for layer in model.layers:
@@ -146,7 +171,7 @@ def _is_pruned(model):
   @staticmethod
   def _train_model(model, epochs=1, x_train=None, y_train=None, callbacks=None):
     if x_train is None:
-      x_train = np.random.rand(20, 10),
+      x_train = np.random.rand(20, 10)
     if y_train is None:
       y_train = keras.utils.to_categorical(
           np.random.randint(5, size=(20, 1)), 5)
@@ -165,411 +190,158 @@ def _train_model(model, epochs=1, x_train=None, y_train=None, callbacks=None):
     model.fit(
         x_train, y_train, epochs=epochs, batch_size=20, callbacks=callbacks)
 
-  def _get_pretrained_model(self):
-    model = keras_test_utils.build_simple_dense_model()
-    self._train_model(model, epochs=1)
-    return model
-
-  ###################################################################
-  # Sanity checks and special cases for training with pruning.
-
-  def testPrunesZeroSparsity_IsNoOp(self):
-    model = keras_test_utils.build_simple_dense_model()
-
-    model2 = keras_test_utils.build_simple_dense_model()
-    model2.set_weights(model.get_weights())
-
-    params = self.params
-    params['pruning_schedule'] = pruning_schedule.ConstantSparsity(
-        target_sparsity=0, begin_step=0, frequency=1)
-    pruned_model = prune.prune_low_magnitude(model2, **params)
-
-    x_train = np.random.rand(20, 10),
-    y_train = keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5)
-
-    self._train_model(model, epochs=1, x_train=x_train, y_train=y_train)
-    self._train_model(pruned_model, epochs=1, x_train=x_train, y_train=y_train)
-
-    self._assert_weights_different_objects(model, pruned_model)
-    self._assert_weights_same_values(model, pruned_model)
-
-  # TODO(tfmot): https://github.com/tensorflow/model-optimization/issues/215
-  def testPruneWithHighSparsity_Fails(self):
-    params = self.params
-    params['pruning_schedule'] = pruning_schedule.ConstantSparsity(
-        target_sparsity=0.99, begin_step=0, frequency=1)
-
-    model = prune.prune_low_magnitude(
-        keras_test_utils.build_simple_dense_model(), **params)
-
-    with self.assertRaises(tf.errors.InvalidArgumentError):
-      self._train_model(model, epochs=1)
-
-  ###################################################################
-  # Tests for training with pruning with pretrained models or weights.
-
-  def testPrunePretrainedModel_RemovesOptimizer(self):
-    model = self._get_pretrained_model()
-
-    self.assertIsNotNone(model.optimizer)
-    pruned_model = prune.prune_low_magnitude(model, **self.params)
-    self.assertIsNone(pruned_model.optimizer)
-
-  def testPrunePretrainedModel_PreservesWeightObjects(self):
-    model = self._get_pretrained_model()
-
-    pruned_model = prune.prune_low_magnitude(model, **self.params)
-    self._assert_weights_same_objects(model, pruned_model)
+  @staticmethod
+  def _get_subclassed_model():
 
-  def testPrunePretrainedModel_SameInferenceWithoutTraining(self):
-    model = self._get_pretrained_model()
-    pruned_model = prune.prune_low_magnitude(model, **self.params)
+    class TestSubclassedModel(keras.Model):
+      """A model subclass."""
 
-    input_data = np.random.rand(10, 10)
+      def __init__(self):
+        """A test subclass model with two dense layers."""
+        super(TestSubclassedModel, self).__init__(name='test_model')
+        self.layer1 = keras.layers.Dense(8, activation='relu')
+        self.layer2 = keras.layers.Dense(5, activation='softmax')
 
-    out = model.predict(input_data)
-    pruned_out = pruned_model.predict(input_data)
+      def call(self, inputs):
+        x = self.layer1(inputs)
+        return self.layer2(x)
 
-    self.assertTrue((out == pruned_out).all())
+    return TestSubclassedModel()
 
-  def testLoadTFWeightsThenPrune_SameInferenceWithoutTraining(self):
-    model = self._get_pretrained_model()
-    _, tf_weights = tempfile.mkstemp('.tf')
-    model.save_weights(tf_weights)
+  def testPrunePretrainedSubclassedModelAttributes_WrapperAdded_SizeIncreases(
+      self):
+    # Size increases since the wrapper adds new weights and we don't call
+    # `strip_pruning`.
 
-    # load weights into model then prune.
-    same_architecture_model = keras_test_utils.build_simple_dense_model()
-    same_architecture_model.load_weights(tf_weights)
-    pruned_model = prune.prune_low_magnitude(same_architecture_model,
-                                             **self.params)
+    model = self._get_subclassed_model()
+    self._train_model(model, epochs=1)
 
     input_data = np.random.rand(10, 10)
 
-    out = model.predict(input_data)
-    pruned_out = pruned_model.predict(input_data)
-
-    self.assertTrue((out == pruned_out).all())
-
-  # Test this and _DifferentInferenceWithoutTraining
-  # because pruning and then loading pretrained weights
-  # is unusual behavior and extra coverage is safer.
-  def testPruneThenLoadTFWeights_DoesNotPreserveWeights(self):
-    model = self._get_pretrained_model()
+    pruned_model = self._get_subclassed_model()
 
-    _, tf_weights = tempfile.mkstemp('.tf')
-    model.save_weights(tf_weights)
+    # Build the model and copy weights over.
+    pruned_model.build(input_data.shape)
+    pruned_model.set_weights(model.get_weights())
 
-    # load weights into pruned model.
-    same_architecture_model = keras_test_utils.build_simple_dense_model()
-    pruned_model = prune.prune_low_magnitude(same_architecture_model,
-                                             **self.params)
-    pruned_model.load_weights(tf_weights)
+    # Apply pruning.
+    pruned_model.layer1 = prune.prune_low_magnitude(pruned_model.layer1,
+                                                    **self.params)
+    pruned_model.layer2 = prune.prune_low_magnitude(pruned_model.layer2,
+                                                    **self.params)
 
-    self._assert_weights_different_values(model, pruned_model)
+    # Rebuild given added wrappers from `prune_low_magnitude`.
+    pruned_model.build(input_data.shape)
 
-  def testPruneThenLoadTFWeights_DifferentInferenceWithoutTraining(self):
-    model = self._get_pretrained_model()
+    print('h5 weights:', self.get_gzipped_model_size(model))
+    print('pruned h5 weights:', self.get_gzipped_model_size(pruned_model))
 
-    _, tf_weights = tempfile.mkstemp('.tf')
-    model.save_weights(tf_weights)
+  def testPrunePretrainedSubclassedModelAttributes_WrapperAdded_CallChanged(
+      self):
 
-    # load weights into pruned model.
-    same_architecture_model = keras_test_utils.build_simple_dense_model()
-    pruned_model = prune.prune_low_magnitude(same_architecture_model,
-                                             **self.params)
-    pruned_model.load_weights(tf_weights)
+    model = self._get_subclassed_model()
+    self._train_model(model, epochs=1)
 
     input_data = np.random.rand(10, 10)
 
-    out = model.predict(input_data)
-    pruned_out = pruned_model.predict(input_data)
-
-    self.assertFalse((out == pruned_out).any())
-
-  def testPruneThenLoadsKerasWeights_Fails(self):
-    model = self._get_pretrained_model()
-
-    _, keras_weights = tempfile.mkstemp('.h5')
-    model.save_weights(keras_weights)
-
-    # load weights into pruned model.
-    same_architecture_model = keras_test_utils.build_simple_dense_model()
-    pruned_model = prune.prune_low_magnitude(same_architecture_model,
-                                             **self.params)
+    pruned_model = self._get_subclassed_model()
 
-    # error since number of keras_weights is fewer than weights in pruned model
-    # because pruning introduces weights.
-    with self.assertRaises(ValueError):
-      pruned_model.load_weights(keras_weights)
+    # Build the model and copy weights over.
+    pruned_model.build(input_data.shape)
+    pruned_model.set_weights(model.get_weights())
 
-  ###################################################################
-  # Tests for training with pruning from scratch.
+    # Apply pruning.
 
-  @parameterized.parameters(_PRUNABLE_LAYERS)
-  def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
-    model = keras.Sequential()
-    args, input_shape = self._get_params_for_layer(layer_type)
-    if args is None:
-      return  # Test for layer not supported yet.
-    model.add(prune.prune_low_magnitude(
-        layer_type(*args), input_shape=input_shape, **self.params))
+    pruned_model.layer1 = prune.prune_low_magnitude(pruned_model.layer1,
+                                                    **self.params)
+    pruned_model.layer2 = prune.prune_low_magnitude(pruned_model.layer2,
+                                                    **self.params)
 
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    test_utils.assert_model_sparsity(self, 0.0, model)
-    model.fit(
-        np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
-        np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
-        callbacks=[pruning_callbacks.UpdatePruningStep()])
-
-    test_utils.assert_model_sparsity(self, 0.5, model)
+    print('starting training: should throw a debugging message from the '
+          'wrapper, given that the UpdatePruningStep callback is not being '
+          'called.')
+    self._train_model(pruned_model, epochs=1, callbacks=[])
 
-    self._check_strip_pruning_matches_original(model, 0.5)
+  def testPrunePretrainedSubclassedModelAttributes_WrapperAddedAfterPredict_TrainingCallChanged(
+      self):
 
-  @parameterized.parameters(prune_registry.PruneRegistry._RNN_LAYERS -
-                            {keras.layers.RNN})
-  def testRNNLayersSingleCell_ReachesTargetSparsity(self, layer_type):
-    model = keras.Sequential()
-    model.add(
-        prune.prune_low_magnitude(
-            layer_type(10), input_shape=(3, 4), **self.params))
-
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    test_utils.assert_model_sparsity(self, 0.0, model)
-    model.fit(
-        np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
-        np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
-        callbacks=[pruning_callbacks.UpdatePruningStep()])
-
-    test_utils.assert_model_sparsity(self, 0.5, model)
-
-    self._check_strip_pruning_matches_original(model, 0.5)
-
-  def testRNNLayersWithRNNCellParams_ReachesTargetSparsity(self):
-    model = keras.Sequential()
-    model.add(
-        prune.prune_low_magnitude(
-            keras.layers.RNN([
-                layers.LSTMCell(10),
-                layers.GRUCell(10),
-                tf.keras.experimental.PeepholeLSTMCell(10),
-                layers.SimpleRNNCell(10)
-            ]),
-            input_shape=(3, 4),
-            **self.params))
-
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    test_utils.assert_model_sparsity(self, 0.0, model)
-    model.fit(
-        np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
-        np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
-        callbacks=[pruning_callbacks.UpdatePruningStep()])
-
-    test_utils.assert_model_sparsity(self, 0.5, model)
-
-    self._check_strip_pruning_matches_original(model, 0.5)
-
-  def testPrunesEmbedding_ReachesTargetSparsity(self):
-    model = keras.Sequential()
-    model.add(
-        prune.prune_low_magnitude(
-            layers.Embedding(input_dim=10, output_dim=3),
-            input_shape=(5,),
-            **self.params))
-    model.add(layers.Flatten())
-    model.add(layers.Dense(1, activation='sigmoid'))
-
-    model.compile(
-        loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    test_utils.assert_model_sparsity(self, 0.0, model)
-    model.fit(
-        np.random.randint(10, size=(32, 5)),
-        np.random.randint(2, size=(32, 1)),
-        callbacks=[pruning_callbacks.UpdatePruningStep()])
-
-    test_utils.assert_model_sparsity(self, 0.5, model)
-
-    input_data = np.random.randint(10, size=(32, 5))
-    self._check_strip_pruning_matches_original(model, 0.5, input_data)
-
-  @parameterized.parameters(test_utils.model_type_keys())
-  def testPrunesMnist_ReachesTargetSparsity(self, model_type):
-    model = test_utils.build_mnist_model(model_type, self.params)
-    if model_type == 'layer_list':
-      model = keras.Sequential(prune.prune_low_magnitude(model, **self.params))
-    elif model_type in ['sequential', 'functional']:
-      model = prune.prune_low_magnitude(model, **self.params)
-
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    test_utils.assert_model_sparsity(self, 0.0, model, rtol=1e-4, atol=1e-4)
-    model.fit(
-        np.random.rand(32, 28, 28, 1),
-        keras.utils.to_categorical(np.random.randint(10, size=(32, 1)), 10),
-        callbacks=[pruning_callbacks.UpdatePruningStep()])
-
-    test_utils.assert_model_sparsity(self, 0.5, model, rtol=1e-4, atol=1e-4)
-
-    self._check_strip_pruning_matches_original(model, 0.5)
-
-  ###################################################################
-  # Tests for pruning with checkpointing.
-
-  # TODO(tfmot): https://github.com/tensorflow/model-optimization/issues/206.
-  #
-  # Note the following:
-  # 1. This test doesn't exactly reproduce the bug. Test should sometimes
-  #    pass when ModelCheckpoint save_freq='epoch'. The behavior was seen when
-  #    training mobilenet.
-  # 2. testPruneStopAndRestart_PreservesSparsity passes, indicating
-  #    checkpointing in general works. Just don't use the checkpoint for
-  #    serving.
-  def testPruneCheckpoints_CheckpointsNotSparse(self):
-    is_model_sparsity_not_list = []
-
-    # Run multiple times since problem doesn't always happen.
-    for _ in range(3):
-      model = keras_test_utils.build_simple_dense_model()
-      pruned_model = prune.prune_low_magnitude(model, **self.params)
-
-      checkpoint_dir = tempfile.mkdtemp()
-      checkpoint_path = checkpoint_dir + '/weights.{epoch:02d}.tf'
-
-      callbacks = [
-          pruning_callbacks.UpdatePruningStep(),
-          tf.keras.callbacks.ModelCheckpoint(
-              filepath=checkpoint_path, save_weights_only=True, save_freq=1)
-      ]
-
-      # Train one step. Sparsity reaches final sparsity.
-      self._train_model(pruned_model, epochs=1, callbacks=callbacks)
-      test_utils.assert_model_sparsity(self, 0.5, pruned_model)
-
-      latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir)
-
-      same_architecture_model = keras_test_utils.build_simple_dense_model()
-      pruned_model = prune.prune_low_magnitude(same_architecture_model,
-                                               **self.params)
-
-      # Sanity check.
-      test_utils.assert_model_sparsity(self, 0, pruned_model)
-
-      pruned_model.load_weights(latest_checkpoint)
-      is_model_sparsity_not_list.append(
-          test_utils.is_model_sparsity_not(0.5, pruned_model))
-
-    self.assertTrue(any(is_model_sparsity_not_list))
-
-  @parameterized.parameters(test_utils.save_restore_fns())
-  def testPruneStopAndRestart_PreservesSparsity(self, save_restore_fn):
-    # TODO(tfmot): re-enable once SavedModel preserves step again.
-    # This existed in TF 2.0 and 2.1 and should be re-enabled in
-    # TF 2.3. b/151755698
-    if save_restore_fn.__name__ == '_save_restore_tf_model':
-      return
-
-    begin_step, end_step = 0, 4
-    params = self.params
-    params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
-        0.2, 0.6, begin_step, end_step, 3, 1)
-
-    model = prune.prune_low_magnitude(
-        keras_test_utils.build_simple_dense_model(), **params)
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
-    # Model hasn't been trained yet. Sparsity 0.0
-    test_utils.assert_model_sparsity(self, 0.0, model)
-
-    # Train only 1 step. Sparsity 0.2 (initial_sparsity)
+    model = self._get_subclassed_model()
     self._train_model(model, epochs=1)
-    test_utils.assert_model_sparsity(self, 0.2, model)
 
-    model = save_restore_fn(model)
+    input_data = np.random.rand(10, 10)
 
-    # Training has run all 4 steps. Sparsity 0.6 (final_sparsity)
-    self._train_model(model, epochs=3)
-    test_utils.assert_model_sparsity(self, 0.6, model)
+    pruned_model = self._get_subclassed_model()
 
-    self._check_strip_pruning_matches_original(model, 0.6)
+    # Build the model and copy weights over.
+    pruned_model.build(input_data.shape)
+    pruned_model.set_weights(model.get_weights())
 
-  @parameterized.parameters(test_utils.save_restore_fns())
-  def testPruneWithPolynomialDecayPastEndStep_PreservesSparsity(
-      self, save_restore_fn):
-    # TODO(tfmot): re-enable once SavedModel preserves step again.
-    # This existed in TF 2.0 and 2.1 and should be re-enabled in
-    # TF 2.3. b/151755698
-    if save_restore_fn.__name__ == '_save_restore_tf_model':
-      return
+    # Call `call` to see if it makes it so that setting the attributes
+    # no longer does anything.
+    pruned_out = pruned_model.predict(input_data)
 
-    begin_step, end_step = 0, 2
-    params = self.params
-    params['pruning_schedule'] = pruning_schedule.PolynomialDecay(
-        0.2, 0.6, begin_step, end_step, 3, 1)
+    # Apply pruning.
+    print('applying wrappers after predict')
+    pruned_model.layer1 = prune.prune_low_magnitude(pruned_model.layer1,
+                                                    **self.params)
+    pruned_model.layer2 = prune.prune_low_magnitude(pruned_model.layer2,
+                                                    **self.params)
 
-    model = prune.prune_low_magnitude(
-        keras_test_utils.build_simple_dense_model(), **params)
-    model.compile(
-        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
+    print('starting training: if `call` changed, should throw a debugging '
+          'error from the wrapper, given that the UpdatePruningStep callback '
+          'is not being called.')
+    self._train_model(pruned_model, epochs=1, callbacks=[])
+    # Error is indeed thrown and print statements (not tf.Print) in wrapper's
+    # `call` are still executed.
 
-    # Model hasn't been trained yet. Sparsity 0.0
-    test_utils.assert_model_sparsity(self, 0.0, model)
+  def testPrunePretrainedSubclassedModelAttributes_WrapperAddedAfterPredict_PredictCallMaybeChanged(
+      self):
 
-    # Train 3 steps, past end_step. Sparsity 0.6 (final_sparsity)
-    self._train_model(model, epochs=3)
-    test_utils.assert_model_sparsity(self, 0.6, model)
+    model = self._get_subclassed_model()
+    self._train_model(model, epochs=1)
 
-    model = save_restore_fn(model)
+    input_data = np.random.rand(10, 10)
 
-    # Ensure sparsity is preserved.
-    test_utils.assert_model_sparsity(self, 0.6, model)
+    pruned_model = self._get_subclassed_model()
 
-    # Train one more step to ensure nothing happens that brings sparsity
-    # back below 0.6.
-    self._train_model(model, epochs=1)
-    test_utils.assert_model_sparsity(self, 0.6, model)
+    # Build the model and copy weights over.
+    pruned_model.build(input_data.shape)
+    pruned_model.set_weights(model.get_weights())
 
-    self._check_strip_pruning_matches_original(model, 0.6)
+    # Call `call` to see if it makes it so that pruning isn't applied.
+    pruned_out = pruned_model.predict(input_data)
 
+    # Apply pruning.
+    print('applying wrappers after predict')
+    pruned_model.layer1 = prune.prune_low_magnitude(pruned_model.layer1,
+                                                    **self.params)
+    pruned_model.layer2 = prune.prune_low_magnitude(pruned_model.layer2,
+                                                    **self.params)
 
-@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
-class PruneIntegrationCustomTrainingLoopTest(tf.test.TestCase,
-                                             parameterized.TestCase):
+    pruned_out = pruned_model.predict(input_data)
+    # print statements in `call` are no longer executed.
-  def testPrunesModel_CustomTrainingLoop_ReachesTargetSparsity(self):
-    pruned_model = prune.prune_low_magnitude(
-        keras_test_utils.build_simple_dense_model())
-
-    batch_size = 20
-    x_train = np.random.rand(20, 10)
-    y_train = keras.utils.to_categorical(
-        np.random.randint(5, size=(batch_size, 1)), 5)
-
-    loss = keras.losses.categorical_crossentropy
-    optimizer = keras.optimizers.SGD()
+    # pruned_model.layer1 = prune.strip_pruning(pruned_model.layer1)
+    # pruned_model.layer2 = prune.strip_pruning(pruned_model.layer2)
 
-    unused_arg = -1
+    # print('prune predict')
+    # out = model.predict(input_data)
+    # pruned_out = pruned_model.predict(input_data)
+    # self.assertTrue((out == pruned_out).all())
 
-    step_callback = pruning_callbacks.UpdatePruningStep()
-    step_callback.set_model(pruned_model)
-    pruned_model.optimizer = optimizer
+    # print('tf weights:', self.get_gzipped_model_size(model))
+    # print('pruned tf weights:', self.get_gzipped_model_size(pruned_model))
 
-    step_callback.on_train_begin()
-    # 2 epochs
-    for _ in range(2):
-      step_callback.on_train_batch_begin(batch=unused_arg)
-      inp = np.reshape(x_train, [batch_size, 10])  # original shape: from [10].
-      with tf.GradientTape() as tape:
-        logits = pruned_model(inp, training=True)
-        loss_value = loss(y_train, logits)
-        grads = tape.gradient(loss_value, pruned_model.trainable_variables)
-        optimizer.apply_gradients(zip(grads, pruned_model.trainable_variables))
-      step_callback.on_epoch_end(batch=unused_arg)
+    # original_saved_model_dir = self._save_as_saved_model(model)
+    # saved_model_dir = self._save_as_saved_model(pruned_model)
 
-    test_utils.assert_model_sparsity(self, 0.5, pruned_model)
+    # original_size = self._get_directory_size_in_bytes(original_saved_model_dir)
+    # compressed_size = self._get_directory_size_in_bytes(saved_model_dir)
+    # print('original size:', original_size)
+    # print('compressed size:', compressed_size)
 
 
 if __name__ == '__main__':
   tf.test.main()
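
The pattern these tests probe, reassigning a subclassed model's layer attributes to pruning wrappers, is not an officially supported path; sequential and functional models go through `prune_low_magnitude(model)` instead. A sketch of the attribute-level flow the tests rely on, assuming the patch above is applied (`TestSubclassedModel` is a hypothetical stand-in for the class built by `_get_subclassed_model`):

```python
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = TestSubclassedModel()  # hypothetical stand-in, see above
model.build((None, 10))

# Swap each layer attribute for a pruning wrapper, then rebuild so the
# wrappers create their masks before training starts.
model.layer1 = tfmot.sparsity.keras.prune_low_magnitude(model.layer1)
model.layer2 = tfmot.sparsity.keras.prune_low_magnitude(model.layer2)
model.build((None, 10))

model.compile(optimizer='sgd', loss='categorical_crossentropy')
model.fit(
    np.random.rand(20, 10),
    tf.keras.utils.to_categorical(np.random.randint(5, size=(20, 1)), 5),
    # Without UpdatePruningStep the wrapper raises an error, which is
    # exactly what the *_CallChanged tests above check for.
    callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])
```
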
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
index 2fd245014..2d3d0b5b7 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_test.py
@@ -23,6 +23,7 @@
 # TODO(b/139939526): move to public API.
 from tensorflow.python.keras import keras_parameterized
+from tensorflow_model_optimization.python.core.keras import compat
 from tensorflow_model_optimization.python.core.keras import test_utils as keras_test_utils
 from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
 from tensorflow_model_optimization.python.core.sparsity.keras import prune
@@ -312,6 +313,12 @@ def testPruneSubclassModel(self):
         str(e.exception),
         self.INVALID_TO_PRUNE_PARAM_ERROR.format(input='TestSubclassedModel'))
 
+  def testPruneSubclassedModelAttributes(self):
+    model = TestSubclassedModel()
+    model.layer1 = prune.prune_low_magnitude(model.layer1)
+
+    self.assertEqual(self._count_pruned_layers(model), 1)
+
   def testPruneMiscObject(self):
     model = object()
 
@@ -346,6 +353,13 @@ def testStripPruningFunctionalModel(self):
     self.assertEqual(self._count_pruned_layers(stripped_model), 0)
     self.assertEqual(model.get_config(), stripped_model.get_config())
 
+  def testStripPruningSubclassedModelAttributes(self):
+    model = TestSubclassedModel()
+    model.layer1 = prune.prune_low_magnitude(model.layer1)
+    model.layer1 = prune.strip_pruning(model.layer1)
+
+    self.assertEqual(self._count_pruned_layers(model), 0)
+
   def testPruneScope_NeededForKerasModel(self):
     model = keras_test_utils.build_simple_dense_model()
     pruned_model = prune.prune_low_magnitude(model)
@@ -387,9 +401,7 @@ def testPruneScope_NotNeededForTFCheckpoint(self):
     same_architecture_model.load_weights(tf_weights)
 
   def testPruneScope_NotNeededForTF2SavedModel(self):
-    # TODO(tfmot): replace with shared v1 test_util.
-    is_v1_apis = hasattr(tf, 'assign')
-    if is_v1_apis:
+    if compat.is_v1_apis():
       return
 
     model = keras_test_utils.build_simple_dense_model()
@@ -402,10 +414,21 @@ def testPruneScope_NotNeededForTF2SavedModel(self):
     # would error if `prune_scope` was needed.
     tf.saved_model.load(saved_model_dir)
 
+  def testSerializePrunedSubclassedModel_TF2(self):
+    if compat.is_v1_apis():
+      return
+
+    pruned_model = TestSubclassedModel()
+    pruned_model.layer1 = prune.prune_low_magnitude(pruned_model.layer1)
+
+    saved_model_dir = tempfile.mkdtemp()
+
+    tf.saved_model.save(pruned_model, saved_model_dir)
+
+    tf.saved_model.load(saved_model_dir)
+
   def testPruneScope_NeededForTF1SavedModel(self):
-    # TODO(tfmot): replace with shared v1 test_util.
-    is_v1_apis = hasattr(tf, 'assign')
-    if not is_v1_apis:
+    if not compat.is_v1_apis():
       return
 
     model = keras_test_utils.build_simple_dense_model()
@@ -421,6 +444,5 @@ def testPruneScope_NeededForTF1SavedModel(self):
     with prune.prune_scope():
       tf.keras.experimental.load_from_saved_model(saved_model_dir)
 
-
 if __name__ == '__main__':
   test.main()
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
index 23639e7d3..3c392195a 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
@@ -231,6 +231,7 @@ def training_step_fn():
         block_pooling_type=self.block_pooling_type)
 
   def call(self, inputs, training=None):
+    print('pruning wrapper call')
     if training is None:
       training = K.learning_phase()
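
A caveat on using the `print` added to the wrapper's `call` as a probe, which the test comments above ("print statements (not tf.Print)") allude to: once Keras wraps training in a `tf.function`, Python-level side effects run only while the function is traced, whereas `tf.print` runs on every executed step. A minimal illustration of the difference:

```python
import tensorflow as tf

@tf.function
def step(x):
  print('traced')       # Python side effect: runs once per trace.
  tf.print('executed')  # Graph op: runs on every call.
  return x + 1

step(tf.constant(1))  # prints 'traced' and 'executed'
step(tf.constant(2))  # prints only 'executed' (same input signature)
```
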