# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A Keras wrapper to add pruning related variables to a layer."""
# pylint: disable=missing-docstring,g-multiple-import,unused-import,protected-access
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
from tensorflow_model_optimization.python.core.sparsity.keras import prunable_layer
from tensorflow_model_optimization.python.core.sparsity.keras import prune_registry
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_impl
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched
keras = tf.keras
K = keras.backend
Wrapper = keras.layers.Wrapper
class PruneLowMagnitude(Wrapper):
"""This wrapper augments a keras layer so the weight tensor may be pruned.
This wrapper implements magnitude-based pruning of the weight tensors.
Magnitude-based pruning achieves a target sparsity (s% of zeros) for a given
weight tensor by monitoring the distribution of the absolute values of the
weight tensor and determining the weight value (referred to as threshold)
below which s% of elements lie. For every weight tensor being pruned, the
wrapper maintains an identically shaped tensor (referred to as mask) which
stores 0 if the weight value lies below the threshold.
The mask and thresholds are computed during the training based on the
evolution of the weight values.
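
  For example, with a target sparsity of 50% and a weight tensor
  [0.1, -0.8, 0.3, -0.2], the sorted absolute values are
  [0.1, 0.2, 0.3, 0.8], so the threshold is 0.3 and the resulting mask is
  [0, 1, 1, 0].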

  Block sparse patterns:
  For certain SIMD hardware architectures, it may be beneficial to induce
  spatially correlated sparsity. To train models in which the weight tensors
  have a block sparse structure, configure the wrapper with the block_size
  parameter set to the desired (height, width) block configuration (2x2,
  4x4, 4x1, 1x8, etc.). This applies to rank-2 weight tensors only: the
  tensor is partitioned into non-overlapping blocks of the given size, and
  either the average or the maximum absolute value in each block (selected
  via the block_pooling_type parameter) is used as a proxy for the entire
  block when computing the distribution of the weight values and the
  pruning threshold.
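
  For example, to prune a rank-2 kernel in 1x4 blocks using max pooling
  (the layer and sizes here are illustrative):

    PruneLowMagnitude(keras.layers.Dense(256),
                      block_size=(1, 4),
                      block_pooling_type='MAX')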

  Custom Keras layers:
  The pruning wrapper can also be applied to a user-defined Keras layer.
  Such a layer may contain one or more weight tensors that may be pruned.
  To apply the wrapper to such a layer, have the layer implement the
  PrunableLayer interface and return the weight tensors to be pruned from
  its get_prunable_weights() method.
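
  A minimal sketch of such a layer (pruning only the kernel is an
  illustrative choice):

    class MyDense(keras.layers.Dense, prunable_layer.PrunableLayer):

      def get_prunable_weights(self):
        # Prune the kernel only; the bias stays dense.
        return [self.kernel]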

  Pruning schedule:
  The target sparsity for the weight tensors is set through the
  pruning_schedule parameter of the wrapper. It takes a PruningSchedule
  object, which determines the target sparsity for the layer's weight
  tensors at any given training step. The library provides the following
  pre-built schedules:
    ConstantSparsity
    PolynomialDecay

  Eg.
    pruning_schedule = pruning_sched.ConstantSparsity(0.9, begin_step=0)
    pruned_model = keras.Sequential([
        PruneLowMagnitude(keras.layers.Dense(256),
                          pruning_schedule=pruning_schedule,
                          input_shape=(256,)),
        PruneLowMagnitude(keras.layers.Dense(1024),
                          pruning_schedule=pruning_schedule),
    ])

  Note that training a wrapped model requires the UpdatePruningStep callback
  to be passed to model.fit.
"""
_PRUNE_CALLBACK_ERROR_MSG = (
'Prune() wrapper requires the UpdatePruningStep callback to be provided '
'during training. Please add it as a callback to your model.fit call.')
def __init__(self,
layer,
pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
block_size=(1, 1),
block_pooling_type='AVG',
**kwargs):
"""Create a pruning wrapper for a keras layer.
    # TODO(pulkitb): Consider if begin_step should be 0 by default.
Args:
layer: The keras layer to be pruned.
pruning_schedule: A `PruningSchedule` object that controls pruning rate
throughout training.
      block_size: (optional) The dimensions (height, width) for the block
        sparse pattern in rank-2 weight tensors.
block_pooling_type: (optional) The function to use to pool weights in the
block. Must be 'AVG' or 'MAX'.
**kwargs: Additional keyword arguments to be passed to the keras layer.
"""
self.pruning_schedule = pruning_schedule
self.block_size = block_size
self.block_pooling_type = block_pooling_type
# An instance of the Pruning class. This class contains the logic to prune
# the weights of this layer.
self.pruning_obj = None
    # A list of all (weight, mask, threshold) tuples for this layer.
self.pruning_vars = []
if block_pooling_type not in ['AVG', 'MAX']:
raise ValueError(
'Unsupported pooling type \'{}\'. Should be \'AVG\' or \'MAX\'.'
.format(block_pooling_type))
if not isinstance(layer, tf.keras.layers.Layer):
raise ValueError(
'Please initialize `Prune` layer with a '
'`Layer` instance. You passed: {input}'.format(input=layer))
# TODO(pulkitb): This should be pushed up to the wrappers.py
# Name the layer using the wrapper and underlying layer name.
# Prune(Dense) becomes prune_dense_1
kwargs.update({'name': '{}_{}'.format(
generic_utils.to_snake_case(self.__class__.__name__), layer.name)})
if isinstance(layer, prunable_layer.PrunableLayer):
# Custom layer in client code which supports pruning.
super(PruneLowMagnitude, self).__init__(layer, **kwargs)
elif prune_registry.PruneRegistry.supports(layer):
# Built-in keras layers which support pruning.
super(PruneLowMagnitude, self).__init__(
prune_registry.PruneRegistry.make_prunable(layer), **kwargs)
else:
raise ValueError(
'Please initialize `Prune` with a supported layer. Layers should '
'either be a `PrunableLayer` instance, or should be supported by the '
'PruneRegistry. You passed: {input}'.format(input=layer.__class__))
self._track_trackable(layer, name='layer')
# TODO(yunluli): Work-around to handle the first layer of Sequential model
# properly. Can remove this when it is implemented in the Wrapper base
# class.
    # The _batch_input_shape attribute in the first layer causes a Sequential
    # model to be built. This change makes sure that when we apply the wrapper
    # to the whole model, this attribute is pulled into the wrapper to preserve
    # the 'built' state of the model.
if not hasattr(self, '_batch_input_shape') and hasattr(
layer, '_batch_input_shape'):
self._batch_input_shape = self.layer._batch_input_shape
def build(self, input_shape):
super(PruneLowMagnitude, self).build(input_shape)
weight_vars, mask_vars, threshold_vars = [], [], []
self.prunable_weights = self.layer.get_prunable_weights()
# For each of the prunable weights, add mask and threshold variables
for weight in self.prunable_weights:
mask = self.add_variable(
'mask',
shape=weight.shape,
initializer=tf.keras.initializers.get('ones'),
dtype=weight.dtype,
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
threshold = self.add_variable(
'threshold',
shape=[],
initializer=tf.keras.initializers.get('zeros'),
dtype=weight.dtype,
trainable=False,
aggregation=tf.VariableAggregation.MEAN)
weight_vars.append(weight)
mask_vars.append(mask)
threshold_vars.append(threshold)
self.pruning_vars = list(zip(weight_vars, mask_vars, threshold_vars))
# Add a scalar tracking the number of updates to the wrapped layer.
self.pruning_step = self.add_variable(
'pruning_step',
shape=[],
initializer=tf.keras.initializers.Constant(-1),
dtype=tf.int64,
trainable=False)
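
    # pruning_step starts at -1 and is expected to be advanced by the
    # UpdatePruningStep callback during training; the pruning logic reads
    # the current step through this function.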
def training_step_fn():
return self.pruning_step
# Create a pruning object
self.pruning_obj = pruning_impl.Pruning(
training_step_fn=training_step_fn,
pruning_vars=self.pruning_vars,
pruning_schedule=self.pruning_schedule,
block_size=self.block_size,
block_pooling_type=self.block_pooling_type)
def call(self, inputs, training=None):
if training is None:
training = K.learning_phase()
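
    # Performs the schedule-driven mask/threshold update, after asserting
    # that pruning_step has been advanced past its initial value of -1
    # (i.e. that the UpdatePruningStep callback is in use).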
def add_update():
with tf.control_dependencies([
tf.debugging.assert_greater_equal(
self.pruning_step,
np.int64(0),
message=self._PRUNE_CALLBACK_ERROR_MSG)
]):
with tf.control_dependencies(
[self.pruning_obj.conditional_mask_update()]):
return tf.no_op('update')
def no_op():
return tf.no_op('no_update')
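
    # Only perform the mask/threshold update in training mode.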
update_op = tf_utils.smart_cond(training, add_update, no_op)
self.add_update(update_op)
# Always execute the op that performs weights = weights * mask
# Relies on UpdatePruningStep callback to ensure the weights
# are sparse after the final backpropagation.
#
# self.add_update does nothing during eager execution.
self.add_update(self.pruning_obj.weight_mask_op())
return self.layer.call(inputs)
def compute_output_shape(self, input_shape):
return self.layer.compute_output_shape(input_shape)
def get_config(self):
base_config = super(PruneLowMagnitude, self).get_config()
config = {
'pruning_schedule': self.pruning_schedule.get_config(),
'block_size': self.block_size,
'block_pooling_type': self.block_pooling_type
}
return dict(list(base_config.items()) + list(config.items()))
@classmethod
def from_config(cls, config):
config = config.copy()
pruning_schedule = config.pop('pruning_schedule')
    deserialize_keras_object = keras.utils.deserialize_keras_object
# TODO(pulkitb): This should ideally be fetched from pruning_schedule,
# which should maintain a list of all the pruning_schedules.
custom_objects = {
'ConstantSparsity': pruning_sched.ConstantSparsity,
'PolynomialDecay': pruning_sched.PolynomialDecay
}
config['pruning_schedule'] = deserialize_keras_object(
pruning_schedule,
module_objects=globals(),
custom_objects=custom_objects)
from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
layer = deserialize_layer(config.pop('layer'))
config['layer'] = layer
return cls(**config)
@property
def trainable(self):
return self.layer.trainable
@trainable.setter
def trainable(self, value):
self.layer.trainable = value
@property
def trainable_weights(self):
return self.layer.trainable_weights
@property
def non_trainable_weights(self):
return self.layer.non_trainable_weights + self._non_trainable_weights
@property
def updates(self):
return self.layer.updates + self._updates
@property
def losses(self):
return self.layer.losses + self._losses
def get_weights(self):
return self.layer.get_weights()
def set_weights(self, weights):
self.layer.set_weights(weights)