Commit dbae704

Migrate pruning assign to 2.X/1.X public APIs.

PiperOrigin-RevId: 286029751

alanchiao authored and tensorflower-gardener committed Dec 17, 2019
1 parent ff464f9 commit dbae704

Showing 4 changed files with 48 additions and 30 deletions.
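The heart of the change is a small version dispatch: TF 1.x exposes assignment as the free function tf.assign, the public counterpart of the private state_ops.assign used before, while TF 2.x drops the free function in favor of the Variable.assign method. A minimal standalone sketch of that dispatch, mirroring the _assign helper added below (the variable v is illustrative):

import tensorflow as tf

def assign(ref, value):
  # TF 1.x: tf.assign is the public API (removed in TF 2.x).
  if tf.__version__[0] == '1':
    return tf.assign(ref, value)
  # TF 2.x: assignment is a method on the variable itself.
  return ref.assign(value)

v = tf.Variable([1.0, 2.0])
assign(v, [0.0, 0.0])  # an op under graph execution, a tensor under eager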
tensorflow_model_optimization/python/core/sparsity/keras/BUILD

@@ -132,7 +132,6 @@ py_library(
     deps = [
         ":pruning_utils",
         # tensorflow dep1,
-        # python:state_ops tensorflow dep2,
         # python:summary tensorflow dep2,
     ],
 )
tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl.py

@@ -20,12 +20,11 @@
 
 import tensorflow as tf
 
 # b/(139939526): update assign ops to v2 API.
-from tensorflow.python.ops import state_ops
 from tensorflow.python.ops import summary_ops_v2
 from tensorflow.python.summary import summary as summary_ops_v1
 from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils
 
 
 class Pruning(object):
   """Implementation of magnitude-based weight pruning."""
@@ -54,6 +53,13 @@ def __init__(self, training_step_fn, pruning_vars, pruning_schedule,
 
     self._validate_block()
 
+  @staticmethod
+  def _assign(ref, value):
+    if tf.__version__[0] == '1':
+      return tf.assign(ref, value)
+    else:
+      return ref.assign(value)
+
   def _validate_block(self):
     if self._block_size != [1, 1]:
       for weight, _, _ in self._pruning_vars:
@@ -144,8 +150,15 @@ def _maybe_update_block_mask(self, weights):
            squeezed_weights.get_shape()[1]])
     return new_threshold, tf.reshape(sliced_mask, tf.shape(weights))
 
-  def _get_weight_assign_ops(self):
-    """Gather the assign ops for assigning weights<=weights*mask."""
+  def _weight_assign_objs(self):
+    """Gather the assign objs for assigning weights<=weights*mask.
+
+    The objs are ops for graph execution and tensors for eager
+    execution.
+
+    Returns:
+      group of objs for weight assignment.
+    """
 
     def update_fn(distribution, values_and_vars):
       # TODO(yunluli): Need this ReduceOp because the weight is created by the
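As the new docstring notes, the returned "objs" are graph ops under graph execution and concrete tensors under eager execution; tf.group accepts either. A small illustrative sketch (the names weight and mask here are ours, not the library's):

import tensorflow as tf

weight = tf.Variable([1.0, 2.0, 3.0, 4.0])
mask = tf.constant([1.0, 0.0, 1.0, 0.0])

# Under graph execution assign() yields an op to run in a session; under
# eager execution it applies immediately and returns a tensor. tf.group
# bundles either kind into a single unit.
objs = [weight.assign(weight * mask)]
tf.group(objs)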
@@ -158,34 +171,34 @@ def update_fn(distribution, values_and_vars):
       values_and_vars = zip(reduced_values, var_list)
 
       def update_var(variable, reduced_value):
-        return state_ops.assign(variable, reduced_value)
+        return self._assign(variable, reduced_value)
 
-      update_ops = []
+      update_objs = []
       for value, var in values_and_vars:
-        update_ops.append(
+        update_objs.append(
             distribution.extended.update(var, update_var, args=(value,)))
 
-      return tf.group(update_ops)
+      return tf.group(update_objs)
 
-    assign_ops = []
+    assign_objs = []
 
     if tf.distribute.get_replica_context():
       values_and_vars = []
       for weight, mask, _ in self._pruning_vars:
         masked_weight = tf.math.multiply(weight, mask)
         values_and_vars.append((masked_weight, weight))
       if values_and_vars:
-        assign_ops.append(tf.distribute.get_replica_context().merge_call(
+        assign_objs.append(tf.distribute.get_replica_context().merge_call(
             update_fn, args=(values_and_vars,)))
     else:
       for weight, mask, _ in self._pruning_vars:
         masked_weight = tf.math.multiply(weight, mask)
-        assign_ops.append(state_ops.assign(weight, masked_weight))
+        assign_objs.append(self._assign(weight, masked_weight))
 
-    return assign_ops
+    return assign_objs
 
   def weight_mask_op(self):
-    return tf.group(self._get_weight_assign_ops())
+    return tf.group(self._weight_assign_objs())
 
   def conditional_mask_update(self):
     """Returns an op to update masks as per the pruning schedule."""
@@ -200,35 +213,39 @@ def mask_update():
       """Updates mask without distribution strategy."""
 
       def update():
-        assign_ops = []
+        assign_objs = []
 
         for weight, mask, threshold in self._pruning_vars:
           new_threshold, new_mask = self._maybe_update_block_mask(weight)
-          assign_ops.append(state_ops.assign(threshold, new_threshold))
-          assign_ops.append(state_ops.assign(mask, new_mask))
+          assign_objs.append(self._assign(threshold, new_threshold))
+          assign_objs.append(self._assign(mask, new_mask))
 
-        return tf.group(assign_ops)
+        return tf.group(assign_objs)
 
       return tf.cond(maybe_update_masks(), update, no_update)
 
     def mask_update_distributed(distribution):
       """Updates mask with distribution strategy."""
 
       def update(var, value):
-        return state_ops.assign(var, value)
+        return self._assign(var, value)
 
       def update_distributed():
-        """Gather distributed update ops."""
-        assign_ops = []
+        """Gather distributed update objs.
+
+        The objs are ops for graph execution and tensors for eager
+        execution.
+        """
+        assign_objs = []
 
         for weight, mask, threshold in self._pruning_vars:
           new_threshold, new_mask = self._maybe_update_block_mask(weight)
-          assign_ops.append(
+          assign_objs.append(
               distribution.extended.update(mask, update, (new_mask,)))
-          assign_ops.append(
+          assign_objs.append(
               distribution.extended.update(threshold, update, (new_threshold,)))
 
-        return tf.group(assign_ops)
+        return tf.group(assign_objs)
 
       return tf.cond(maybe_update_masks(), update_distributed, no_update)
 
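Both branches hang off a tf.cond gated on the pruning schedule: when maybe_update_masks() is false, no_update runs and the masks keep their current values. A toy illustration of the same control-flow shape (the schedule and mask values are invented):

import tensorflow as tf

step = tf.Variable(0, dtype=tf.int64)
mask = tf.Variable(tf.ones([4]))

def should_update():
  return tf.equal(step % 100, 0)  # e.g. prune every 100 steps

def update():
  # Stand-in for the new mask computed by _maybe_update_block_mask.
  return tf.group(mask.assign(tf.constant([1.0, 0.0, 1.0, 0.0])))

def no_update():
  return tf.no_op()

tf.cond(should_update(), update, no_update)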
tensorflow_model_optimization/python/core/sparsity/keras/pruning_impl_test.py

@@ -59,7 +59,7 @@ def setUp(self):
   # setUp() lies outside of the "eager scope" that wraps the test cases
   # themselves, resulting in initializing graph tensors instead of eager
   # tensors when testing eager execution.
-  def initialize_training_step_fn_and_all_variables(self):
+  def initialize(self):
     self.global_step = tf.Variable(
         tf.zeros([], dtype=dtypes.int32),
         dtype=dtypes.int32,
@@ -81,7 +81,7 @@ def testUpdateSingleMask(self):
         dtype=weight_dtype)
     threshold = tf.Variable(
         tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
-    self.initialize_training_step_fn_and_all_variables()
+    self.initialize()
 
     p = pruning_impl.Pruning(
         pruning_vars=[(weight, mask, threshold)],
@@ -102,7 +102,7 @@ def testUpdateSingleMask(self):
     self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50)
 
   def testConstructsMaskAndThresholdCorrectly(self):
-    self.initialize_training_step_fn_and_all_variables()
+    self.initialize()
     p = pruning_impl.Pruning(
         lambda: 0, None,
         # Sparsity math often returns values with small tolerances.
@@ -125,7 +125,7 @@ def _blockMasking(self, block_size, block_pooling_type, weight,
         dtype=weight.dtype)
     threshold = tf.Variable(
         tf.zeros([], dtype=weight.dtype), name="threshold", dtype=weight.dtype)
-    self.initialize_training_step_fn_and_all_variables()
+    self.initialize()
 
     # Set up pruning
     p = pruning_impl.Pruning(
@@ -163,7 +163,7 @@ def testBlockMaskingMax(self):
     self._blockMasking(block_size, block_pooling_type, weight, expected_mask)
 
   def testBlockMaskingWithHigherDimensionsRaisesError(self):
-    self.initialize_training_step_fn_and_all_variables()
+    self.initialize()
     block_size = (2, 2)
     block_pooling_type = "AVG"
     # Weights as in testBlockMasking, but with one extra dimension.
@@ -186,7 +186,7 @@ def testConditionalMaskUpdate(self):
         dtype=weight_dtype)
     threshold = tf.Variable(
         tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
-    self.initialize_training_step_fn_and_all_variables()
+    self.initialize()
 
     def linear_sparsity(step):
       sparsity_val = tf.convert_to_tensor(
tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

@@ -236,6 +236,8 @@ def no_op():
     # Always execute the op that performs weights = weights * mask
     # Relies on UpdatePruningStep callback to ensure the weights
     # are sparse after the final backpropagation.
+    #
+    # self.add_update does nothing during eager execution.
     self.add_update(self.pruning_obj.weight_mask_op())
 
     return self.layer.call(inputs)
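Since add_update is a no-op under eager execution, eager training has to run the mask op some other way; the diff's comments point at the UpdatePruningStep callback. A hypothetical, simplified stand-in for such a callback (the class name and body are ours; only weight_mask_op comes from the diff above):

import tensorflow as tf

class ApplyPruningCallback(tf.keras.callbacks.Callback):
  """Hypothetical sketch: re-apply weights = weights * mask each batch."""

  def __init__(self, pruning_obj):
    super(ApplyPruningCallback, self).__init__()
    self._pruning_obj = pruning_obj

  def on_train_batch_end(self, batch, logs=None):
    # Runs eagerly during fit(); the assignment executes immediately.
    self._pruning_obj.weight_mask_op()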