diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
index f17e442cf..9315c9fa5 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/BUILD
@@ -132,7 +132,6 @@ py_library(
     deps = [
         ":pruning_utils",
         # tensorflow dep1,
-        # python:state_ops tensorflow dep2,
         # python:summary tensorflow dep2,
     ],
 )
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
index ef2005595..32ecf3d3f 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py
@@ -20,12 +20,11 @@
 import tensorflow as tf
 
-# b/139939526: update assign ops to v2 API.
-from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.summary import summary as summary_ops_v1
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 
+
 class Pruning(object):
   """Implementation of magnitude-based weight pruning."""
 
@@ -54,6 +53,13 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
 
     self._validate_block()
 
+  @staticmethod
+  def _assign(ref, value):
+    if tf.__version__[0] == '1':
+      return tf.assign(ref, value)
+    else:
+      return ref.assign(value)
+
   def _validate_block(self):
     if self._block_size != [1, 1]:
       for weight, _, _ in self._pruning_vars:
@@ -144,8 +150,15 @@ def _maybe_update_block_mask(self, weights):
                                   squeezed_weights.get_shape()[1]])
     return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))
 
-  def _get_weight_assign_ops(self):
-    """Gather the assign ops for assigning weights<=weights*mask."""
+  def _weight_assign_objs(self):
+    """Gather the assign objs for assigning weights = weights * mask.
+
+    The objs are ops for graph execution and tensors for eager
+    execution.
+
+    Returns:
+      Group of objs for weight assignment.
+ """ def update_fn(distribution, values_and_vars): # TODO(yunluli): Need this ReduceOp because the weight is created by the @@ -158,16 +171,16 @@ def update_fn(distribution, values_and_vars): values_and_vars = zip(reduced_values, var_list) def update_var(variable, reduced_value): - return state_ops.assign(variable, reduced_value) + return self._assign(variable, reduced_value) - update_ops = [] + update_objs = [] for value, var in values_and_vars: - update_ops.append( + update_objs.append( distribution.extended.update(var, update_var, args=(value,))) - return tf.group(update_ops) + return tf.group(update_objs) - assign_ops = [] + assign_objs = [] if tf.distribute.get_replica_context(): values_and_vars = [] @@ -175,17 +188,17 @@ def update_var(variable, reduced_value): masked_weight = tf.math.multiply(weight, mask) values_and_vars.append((masked_weight, weight)) if values_and_vars: - assign_ops.append(tf.distribute.get_replica_context().merge_call( + assign_objs.append(tf.distribute.get_replica_context().merge_call( update_fn, args=(values_and_vars,))) else: for weight, mask, _ in self._pruning_vars: masked_weight = tf.math.multiply(weight, mask) - assign_ops.append(state_ops.assign(weight, masked_weight)) + assign_objs.append(self._assign(weight, masked_weight)) - return assign_ops + return assign_objs def weight_mask_op(self): - return tf.group(self._get_weight_assign_ops()) + return tf.group(self._weight_assign_objs()) def conditional_mask_update(self): """Returns an op to updates masks as per the pruning schedule.""" @@ -200,14 +213,14 @@ def mask_update(): """Updates mask without distribution strategy.""" def update(): - assign_ops = [] + assign_objs = [] for weight, mask, threshold in self._pruning_vars: new_threshold, new_mask = self._maybe_update_block_mask(weight) - assign_ops.append(state_ops.assign(threshold, new_threshold)) - assign_ops.append(state_ops.assign(mask, new_mask)) + assign_objs.append(self._assign(threshold, new_threshold)) + assign_objs.append(self._assign(mask, new_mask)) - return tf.group(assign_ops) + return tf.group(assign_objs) return tf.cond(maybe_update_masks(), update, no_update) @@ -215,20 +228,24 @@ def mask_update_distributed(distribution): """Updates mask with distribution strategy.""" def update(var, value): - return state_ops.assign(var, value) + return self._assign(var, value) def update_distributed(): - """Gather distributed update ops.""" - assign_ops = [] + """Gather distributed update objs. + + The objs are ops for graph execution and tensors for eager + execution. 
+ """ + assign_objs = [] for weight, mask, threshold in self._pruning_vars: new_threshold, new_mask = self._maybe_update_block_mask(weight) - assign_ops.append( + assign_objs.append( distribution.extended.update(mask, update, (new_mask,))) - assign_ops.append( + assign_objs.append( distribution.extended.update(threshold, update, (new_threshold,))) - return tf.group(assign_ops) + return tf.group(assign_objs) return tf.cond(maybe_update_masks(), update_distributed, no_update) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py index aafddb4b7..bb525785a 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py @@ -59,7 +59,7 @@ def setUp(self): # setUp() lies outside of the "eager scope" that wraps the test cases # themselves, resulting in initializing graph tensors instead of eager # tensors when testing eager execution. - def initialize_training_step_fn_and_all_variables(self): + def initialize(self): self.global_step = tf.Variable( tf.zeros([], dtype=dtypes.int32), dtype=dtypes.int32, @@ -81,7 +81,7 @@ def testUpdateSingleMask(self): dtype=weight_dtype) threshold = tf.Variable( tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype) - self.initialize_training_step_fn_and_all_variables() + self.initialize() p = pruning_impl.Pruning( pruning_vars=[(weight, mask, threshold)], @@ -102,7 +102,7 @@ def testUpdateSingleMask(self): self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50) def testConstructsMaskAndThresholdCorrectly(self): - self.initialize_training_step_fn_and_all_variables() + self.initialize() p = pruning_impl.Pruning( lambda: 0, None, # Sparsity math often returns values with small tolerances. @@ -125,7 +125,7 @@ def _blockMasking(self, block_size, block_pooling_type, weight, dtype=weight.dtype) threshold = tf.Variable( tf.zeros([], dtype=weight.dtype), name="threshold", dtype=weight.dtype) - self.initialize_training_step_fn_and_all_variables() + self.initialize() # Set up pruning p = pruning_impl.Pruning( @@ -163,7 +163,7 @@ def testBlockMaskingMax(self): self._blockMasking(block_size, block_pooling_type, weight, expected_mask) def testBlockMaskingWithHigherDimensionsRaisesError(self): - self.initialize_training_step_fn_and_all_variables() + self.initialize() block_size = (2, 2) block_pooling_type = "AVG" # Weights as in testBlockMasking, but with one extra dimension. @@ -186,7 +186,7 @@ def testConditionalMaskUpdate(self): dtype=weight_dtype) threshold = tf.Variable( tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype) - self.initialize_training_step_fn_and_all_variables() + self.initialize() def linear_sparsity(step): sparsity_val = tf.convert_to_tensor( diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index b7a6a57ed..67bb8ca88 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -236,6 +236,8 @@ def no_op(): # Always execute the op that performs weights = weights * mask # Relies on UpdatePruningStep callback to ensure the weights # are sparse after the final backpropagation. + # + # self.add_update does nothing during eager execution. 
     self.add_update(self.pruning_obj.weight_mask_op())
 
     return self.layer.call(inputs)
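
The core pattern this change introduces is a version dispatch for variable assignment, so a single call site works under both the TF 1.x graph API (module-level tf.assign, which returns an op to be run in a session) and the TF 2.x API (the Variable.assign method, which executes immediately under eager execution and returns a tensor). Below is a minimal standalone sketch of that pattern, independent of this patch; the variable names and the masking multiply are illustrative only, not part of the change.

import tensorflow as tf


def _assign(ref, value):
  # tf.assign exists only in TF 1.x; TF 2.x removed it in favor of the
  # Variable.assign method. Checking the major version keeps one call
  # site valid under both APIs.
  if tf.__version__[0] == '1':
    return tf.assign(ref, value)
  return ref.assign(value)


weight = tf.Variable([[1.0, -2.0], [3.0, -4.0]])
mask = tf.Variable([[1.0, 0.0], [1.0, 0.0]])

# Under eager execution this assignment happens immediately and the
# returned object is a tensor; under TF 1.x graph mode it is an assign
# op that takes effect only when run in a session.
update_obj = _assign(weight, weight * mask)

This is also why the rename from assign_ops to assign_objs matters: under eager execution the returned objects are tensors whose assignments have already happened by the time tf.group sees them, so the no-op behavior of self.add_update noted in pruning_wrapper.py is harmless there.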