@@ -25,6 +25,8 @@
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_schedule import PruningSchedule
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_schedule import ConstantSparsity
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_schedule import PolynomialDecay
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_schedule import ConstantMbyNSparsity
from tensorflow_model_optimization.python.core.sparsity.keras.pruning_schedule import PolynomialDecayMbyNSparsity

from tensorflow_model_optimization.python.core.sparsity.keras.prunable_layer import PrunableLayer

@@ -373,7 +373,6 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
'm_by_n': (2, 4),
},
)

def testMbyNSparsityPruning_SupportedLayers(self,
layer_type,
layer_arg,
@@ -382,6 +381,8 @@ def testMbyNSparsityPruning_SupportedLayers(self,
sparsity_ratio=0.50):
"""Check that we prune supported layers with m by n sparsity."""
self.params.update({'sparsity_m_by_n': m_by_n})
self.params.update(
{'pruning_schedule': pruning_schedule.ConstantMbyNSparsity(prune_step=1)})

model = keras.Sequential()
model.add(
@@ -115,6 +115,11 @@ def _update_mask_sparsity_m_by_n(self, weights, m_by_n=(2, 4)):
that m elements with the lowest absolute values in each block
of n elements are set to zero. We don't return any threshold.

If the coverage ratio provided by the pruning schedule is less than 1.0,
a partial sparsity mask is calculated and added to the m_by_n
sparsity mask, so that the coverage ratio of the m_by_n sparsity
pattern on the mask is (coverage_ratio * 100%).

Args:
weights: The weight tensor that needs to be masked.
m_by_n: tuple of two integers, indicating m zeros out of n consecutive
@@ -125,17 +130,40 @@ def _update_mask_sparsity_m_by_n(self, weights, m_by_n=(2, 4)):
0 or 1 to indicate which of the values in weights should be set to zero.
It throws an error if the requested mask cannot be created.
"""
prepared_weights = pruning_utils.weights_rearrange(weights)
mask = pruning_utils.generate_m_by_n_mask(prepared_weights, m_by_n)
new_mask = pruning_utils.m_by_n_sparsity_mask_prepare(mask, weights.shape)
coverage_ratio = self._pruning_schedule(self._step_fn())[1]
with tf.name_scope('m_by_n_sparsity_pruning_ops'):
prepared_weights = pruning_utils.weights_rearrange(weights)

mask = pruning_utils.generate_m_by_n_mask(prepared_weights, m_by_n)
new_mask = pruning_utils.m_by_n_sparsity_mask_prepare(mask, weights.shape)

def update_mask_sparsity_m_by_n_with_coverage_ratio():
partial_covered_mask = pruning_utils.generate_partial_sparsity_mask(
prepared_weights, m_by_n[1], coverage_ratio)
new_partial_covered_mask = pruning_utils.m_by_n_sparsity_mask_prepare(
partial_covered_mask, weights.shape)

m_by_n_mask = tf.clip_by_value(
new_mask + new_partial_covered_mask,
clip_value_min=0.0,
clip_value_max=1.0
)

return m_by_n_mask

m_by_n_mask = tf.cond(
tf.math.less(coverage_ratio, 1.0),
update_mask_sparsity_m_by_n_with_coverage_ratio,
lambda: new_mask,
)

return new_mask
return m_by_n_mask
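# Editor's sketch (not part of the diff): a toy version of the clip-based
# combine above, assuming generate_partial_sparsity_mask returns ones for
# blocks outside the covered region and zeros inside it.
#
# Two 1x4 blocks; the 2:4 mask zeroes the two smallest weights everywhere:
#   new_mask = tf.constant([[0., 0., 1., 1.],
#                           [0., 0., 1., 1.]])
# Coverage 0.5: first block covered (zeros), second uncovered (ones):
#   partial = tf.constant([[0., 0., 0., 0.],
#                          [1., 1., 1., 1.]])
#   tf.clip_by_value(new_mask + partial, 0.0, 1.0)
#   -> [[0., 0., 1., 1.],   covered block keeps the 2:4 pattern
#       [1., 1., 1., 1.]]   uncovered block stays dense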

def _maybe_update_block_mask(self, weights):
"""Performs block-granular masking of the weights.

If sparsity_m_by_n is selected, then we return the relevant pruning mask,
which nullifies two out of four elements in the block.
which nullifies m out of n consecutive elements in the block.

Block pruning occurs only if the block_height or block_width is > 1 and
if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
@@ -240,11 +268,7 @@ def conditional_mask_update(self):
"""Returns an op to updates masks as per the pruning schedule."""

def maybe_update_masks():
if self._sparsity_m_by_n:
# Update structured sparsity masks only at step 1
return tf.math.equal(self._step_fn(), 1)
else:
return self._pruning_schedule(self._step_fn())[0]
return self._pruning_schedule(self._step_fn())[0]

def no_update():
return tf.no_op()
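# Editor's note (not part of the diff): the hard-coded "step 1" check is
# gone because the schedule now owns that decision, e.g.:
#   sched = pruning_schedule.ConstantMbyNSparsity(prune_step=1)
#   sched(tf.constant(1))  # -> (True, 1.0): update masks at step 1
#   sched(tf.constant(2))  # -> (False, 1.0): masks left untouched after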
@@ -249,7 +249,12 @@ def linear_sparsity(step):
expected_non_zero_count = [100, 90, 90, 70, 70, 50, 50, 50, 50, 50]
self.assertAllEqual(expected_non_zero_count, non_zero_count)

def _sparsity_m_by_n_masking(self, weight, m_by_n=(2, 4)):
def _sparsity_m_by_n_masking(
self,
weight,
m_by_n=(2, 4),
pruning_schedule=pruning_schedule.ConstantMbyNSparsity(prune_step=1),
):
mask = tf.Variable(tf.ones(weight.get_shape()), name="mask")
threshold = tf.Variable(1, name="threshold")
self.initialize()
@@ -258,7 +263,7 @@ def _sparsity_m_by_n_masking(self, weight, m_by_n=(2, 4)):
p = pruning_impl.Pruning(
pruning_vars=[(weight, mask, threshold)],
training_step_fn=self.training_step_fn,
pruning_schedule=self.constant_sparsity,
pruning_schedule=pruning_schedule,
block_size=(1, 1),
block_pooling_type="AVG",
sparsity_m_by_n=m_by_n,
@@ -323,6 +328,65 @@ def testSparsityMbyNMaskingSimpleRaises(self, weights_shape):
with self.assertRaises(ValueError):
self._sparsity_m_by_n_masking(weights_ts)

def testSparsityMbyNMaskingUpdateWithPolynomialDecay(self):
weights = tf.reshape(tf.linspace(1.0, 16.0, 16), [4, 4])
mask = tf.Variable(tf.ones(weights.get_shape()))
self.initialize()

polynomial_schedule = pruning_schedule.PolynomialDecayMbyNSparsity(
0.0, 1, 10, frequency=1,
)
p = pruning_impl.Pruning(
pruning_vars=[(weights, mask, tf.Variable(1))],
training_step_fn=self.training_step_fn,
pruning_schedule=polynomial_schedule,
block_size=(1, 1),
block_pooling_type="AVG",
sparsity_m_by_n=(2, 4),
)

case_list = [
{
"step": 1, # coverate ratio 0.0
"expected": [
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
],
},
{
"step": 3, # coverate ratio 0.530
"expected": [
[0.0, 0.0, 1.0, 1.0],
[0.0, 0.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
],
},
{
"step": 5, # coverate ratio 0.829
"expected": [
[0.0, 0.0, 0.0, 1.0],
[0.0, 0.0, 0.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
],
},
{
"step": 9, # coverate ratio 0.999
"expected": [
[0.0, 0.0, 0.0, 0.0],
[0.0, 0.0, 0.0, 0.0],
[1.0, 1.0, 1.0, 1.0],
[1.0, 1.0, 1.0, 1.0],
],
},
]
for case in case_list:
self.global_step.assign(case["step"])
_, new_mask = p._maybe_update_block_mask(weights)
self.assertAllEqual(new_mask, case["expected"])
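# Editor's note (not part of the diff): the coverage ratios in the
# comments above follow from the polynomial decay this test configures
# (initial ratio 0.0, begin_step=1, end_step=10, power=3.0):
#   coverage(step) = 1 - (1 - (step - 1) / 9) ** 3
# giving 0.0, 0.530, 0.829 and 0.999 at steps 1, 3, 5 and 9.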

if __name__ == "__main__":
test.main()
@@ -257,3 +257,123 @@ def get_config(self):
'frequency': self.frequency
}
}


class ConstantMbyNSparsity(PruningSchedule):
"""M_by_N Sparsity Pruning Schedule with 100% coverage of sparsity mask
throughout training.
"""

def __init__(self, prune_step=0):
"""Initializes a M_by_N Pruning schedule with fully coverage.

m_by_n sparsity masks calculate and apply on weights at prune_step,
they will not be updated after, but keep applying on the weights
during the training process.
The m_by_n sparsity coverage ratio remain constant as 100%.

Args:
prune_step: Step at which apply m_by_n sparsity pruning.
"""
self.prune_step = prune_step

if prune_step < 0:
raise ValueError('prune_step should be >= 0')

def _should_prune_in_step(self, step, prune_step):
return tf.math.equal(step, prune_step)

def __call__(self, step):
return (self._should_prune_in_step(step, self.prune_step),
tf.constant(1.0, dtype=tf.float32))

def get_config(self):
return {
'class_name': self.__class__.__name__,
'config': {
'prune_step': self.prune_step
}
}
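# Editor's sketch (not part of the diff): the config dict above feeds
# straight back into the constructor, Keras-serialization style:
#   sched = ConstantMbyNSparsity(prune_step=5)
#   clone = ConstantMbyNSparsity(**sched.get_config()['config'])
#   assert clone.prune_step == sched.prune_step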


class PolynomialDecayMbyNSparsity(PruningSchedule):
"""M_by_N Sparsity Pruning Schedule with polynomial decay of
coverage ratio.
"""

def __init__(
self,
initial_coverage_ratio,
begin_step,
end_step,
power=3.0,
frequency=100
):
"""Initializes a M_by_N Pruning schedule with polynomial decay of coverage
ratio.

Coverage ratio grows rapidly in the beginning from
initial_coverage_ratio, then plateaus slowly to 100% coverage.

The function applied a polynomial decay function.
This schedule applies a polynomial decay function to m_by_n sparsity
coverage ratio in the interval [`begin_step`, `end_step`],
given a provided `initial_coverage_ratio`, to reach an 100% coverage
ratio at the `end_step`.

Args:
initial_coverage_ratio: coverage ratio of m_by_n sparsity at which
the pruning begins.
begin_step: Step at which to begin m_by_n sparsity pruning.
end_step: Step at which to end m_by_n sparsity pruning, reach an 100%
coverage ratio.
power: The power of the polynomial, defaults to 3.0.
frequency: Only apply pruning every `frequency` steps, defaults to 100.
"""
self.initial_coverage_ratio = initial_coverage_ratio
self.power = power

self.begin_step = begin_step
self.end_step = end_step
self.frequency = frequency

self._has_build_polynomial_decay = False

self._validate_step(self.begin_step, self.end_step, self.frequency, False)
self._validate_sparsity(initial_coverage_ratio, 'initial_coverage_ratio')

def _build_polynomial_decay(self):
self._has_build_polynomial_decay = True
self._coverage_ratio_polynomial_decay = (
tf.keras.optimizers.schedules.PolynomialDecay(
self.initial_coverage_ratio,
self.end_step - self.begin_step, # decay steps
1.0, # final_coverage_ratio
power=self.power,
cycle=False,
name='MbyNCoverageRatioPolynomialDecay',
)
)

def __call__(self, step):
if not self._has_build_polynomial_decay:
self._build_polynomial_decay()

return (
self._should_prune_in_step(step, self.begin_step, self.end_step,
self.frequency),
self._coverage_ratio_polynomial_decay(step - self.begin_step),
)

def get_config(self):
    return {
'class_name': self.__class__.__name__,
'config': {
'initial_coverage_ratio': self.initial_coverage_ratio,
'power': self.power,
'begin_step': self.begin_step,
'end_step': self.end_step,
'frequency': self.frequency,
}
}
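# Editor's sketch (not part of the diff): querying the decayed coverage
# ratio at a few steps; values follow the power-3 decay curve:
#   sched = PolynomialDecayMbyNSparsity(
#       initial_coverage_ratio=0.0, begin_step=0, end_step=100,
#       power=3.0, frequency=10)
#   for step in (0, 50, 100):
#       _, coverage = sched(tf.constant(step))
#       print(step, float(coverage))  # 0.0, 0.875, 1.0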