Commit
implement lth and refactor names
xwinxu committed Jun 25, 2020
1 parent bad40fb commit f0ae80d
Showing 5 changed files with 295 additions and 10 deletions.
14 changes: 7 additions & 7 deletions tensorflow_model_optimization/python/core/sparsity_tf2/BUILD
@@ -72,12 +72,12 @@ py_library(
)

py_library(
name = "pruning_impl",
srcs = ["pruning_impl.py"],
name = "pruner",
srcs = ["pruner.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pruning_utils",
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_utils",
# tensorflow dep1,
# python:summary tensorflow dep2,
"//tensorflow_model_optimization/python/core/keras:compat",
@@ -196,14 +196,14 @@ py_test(
)

py_test(
name = "pruning_impl_test",
name = "pruner_test",
size = "medium",
srcs = ["pruning_impl_test.py"],
srcs = ["pruner_test.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":pruning_impl",
":pruning_schedule",
":pruner",
"//tensorflow_model_optimization/python/core/sparsity/keras:pruning_schedule",
# numpy dep1,
# python/keras tensorflow dep2,
"//tensorflow_model_optimization/python/core/keras:compat",
@@ -117,7 +117,7 @@ def _maybe_update_block_mask(self, weights):
# TODO(pulkitb): Check if squeeze operations should now be removed since
# we are only accepting 2-D weights.

-    squeezed_weights = tf.squeeze(weights)
-    abs_weights = tf.math.abs(squeezed_weights)
+    # Weights are already rank-2, so the squeeze is no longer needed.
+    abs_weights = tf.math.abs(weights)
pooled_weights = pruning_utils.factorized_pool(
abs_weights,
@@ -167,7 +167,7 @@ def add_pruning_summaries(self, step, pruning_vars):
summary.scalar(mask.name + '/sparsity', 1.0 - tf.math.reduce_mean(mask))
summary.scalar(threshold.name + '/threshold', threshold)

-  def _mask_weight(self, weight, mask):
+  def _apply_mask(self, weight, mask):
"""Directly masks the weights (updating the weight variables)."""

# TODO(Kaftan/xwinxu): figure out if this is totally unneeded now
@@ -211,4 +211,61 @@ def prune(self, optimizer, var, grad):
mask = optimizer.get_slot(var, 'mask')
threshold = optimizer.get_slot(var, 'threshold')
self.update_masks([(var, mask, threshold)], step=optimizer.iterations)
-    self._mask_weight(var, mask)
+    self._apply_mask(var, mask)

class LTHPruner(Pruner):
  """Implementation of Lottery Ticket Hypothesis experiments."""

  def __init__(self,
               pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
               reload_schedule=None,
               save_schedule=None,
               block_size=(1, 1),
               block_pooling_type='AVG',
               ):
    """The logic for magnitude-based pruning of weight tensors, with saving
    and reloading of initializations for LTH experiments.

    Args:
      pruning_schedule: A `PruningSchedule` object that controls pruning rate
        throughout training.
      reload_schedule: A `PruningSchedule` object that controls reloading of
        weights throughout training. Defaults to the pruning schedule.
      save_schedule: A `PruningSchedule` object that controls the saving of
        weights for reloading after checkpointing in LTH experiments.
        Defaults to saving only at step 0.
      block_size: The dimensions (height, width) for the block sparse pattern
        in rank-2 weight tensors.
      block_pooling_type: (optional) The function to use to pool weights in the
        block. Must be 'AVG' or 'MAX'.
    """
    super(LTHPruner, self).__init__(pruning_schedule, block_size,
                                    block_pooling_type)
    self.reload_schedule = reload_schedule if reload_schedule else pruning_schedule
    self.save_schedule = save_schedule if save_schedule else pruning_sched.ConstantSparsity(0.0, 0, 0)

def create_slots(self, optimizer, var):
optimizer.add_slot(var, 'mask', initializer='ones')
optimizer.add_slot(var, 'threshold', initializer=tf.zeros(shape=()))
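    # The Glorot values here are only a placeholder: `_maybe_save_weights`
    # overwrites this slot with the variable's actual initial value (at step 0
    # under the default save_schedule).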
optimizer.add_slot(var, 'original_initialization', initializer='GlorotNormal')

  def _maybe_save_weights(self, optimizer, var):
    if self.save_schedule._should_prune_in_step(
        optimizer.iterations, self.save_schedule.begin_step,
        self.save_schedule.end_step, self.save_schedule.frequency):
      optimizer.get_slot(var, 'original_initialization').assign(var)

  def _maybe_reload_weights(self, optimizer, var, mask):
    if self.reload_schedule._should_prune_in_step(
        optimizer.iterations, self.reload_schedule.begin_step,
        self.reload_schedule.end_step, self.reload_schedule.frequency):
      # Rewind to the saved initialization, keeping only the surviving weights.
      reload_weights = tf.math.multiply(
          optimizer.get_slot(var, 'original_initialization'), mask)
      var.assign(reload_weights)


  def prune(self, optimizer, var, grad):
    # The gradient is unused for lottery ticket pruning.
    self._maybe_save_weights(optimizer, var)
    mask = optimizer.get_slot(var, 'mask')
    threshold = optimizer.get_slot(var, 'threshold')
    self.update_masks([(var, mask, threshold)], step=optimizer.iterations)
    self._maybe_reload_weights(optimizer, var, mask)
    self._apply_mask(var, mask)
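
For orientation, a minimal sketch of how this pruner might be driven from a
custom training loop. Only `LTHPruner`, `create_slots`, and `prune` come from
this commit; the model, loss, and optimizer setup below are assumptions made
for illustration, and the `pruning_sched` import path mirrors the
sparsity/keras:pruning_schedule dependency used elsewhere in this commit.

import tensorflow as tf
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule as pruning_sched

# Placeholder model and loss, assumed for illustration only.
model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(4,))])
loss_fn = tf.keras.losses.MeanSquaredError()

pruner = LTHPruner(
    pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
    block_size=(1, 1),
    block_pooling_type='AVG')

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
# Create the 'mask', 'threshold', and 'original_initialization' slots.
for var in model.trainable_variables:
  pruner.create_slots(optimizer, var)

def train_step(x, y):
  with tf.GradientTape() as tape:
    loss = loss_fn(y, model(x))
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  # Save, prune, and possibly rewind each variable per its schedules.
  for var, grad in zip(model.trainable_variables, grads):
    pruner.prune(optimizer, var, grad)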
228 changes: 228 additions & 0 deletions tensorflow_model_optimization/python/core/sparsity_tf2/pruner_test.py
@@ -0,0 +1,228 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the key functions in pruner library."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# import g3

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

# TODO(b/139939526): move to public API.
from tensorflow.python.keras import keras_parameterized
from tensorflow_model_optimization.python.core.keras import compat
from tensorflow_model_optimization.python.core.sparsity_tf2 import pruner
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_schedule
from tensorflow_model_optimization.python.core.sparsity.keras import pruning_utils

K = tf.keras.backend
dtypes = tf.dtypes
test = tf.test


def assign_add(ref, value):
if hasattr(tf, "assign_add"):
return tf.assign_add(ref, value)
else:
return ref.assign_add(value)

class PruningTest(test.TestCase, parameterized.TestCase):

def setUp(self):
super(PruningTest, self).setUp()
self.block_size = (1, 1)
self.block_pooling_type = "AVG"

self.constant_sparsity = pruning_schedule.ConstantSparsity(0.5, 0, 100, 1)

# Variable initialization outside of setUp() is needed for compatibility with
# run_all_keras_modes.
#
# setUp() lies outside of the "eager scope" that wraps the test cases
# themselves, resulting in initializing graph tensors instead of eager
# tensors when testing eager execution.
def initialize(self):
self.global_step = tf.Variable(
tf.zeros([], dtype=dtypes.int32),
dtype=dtypes.int32,
name="global_step")

def training_step_fn():
return self.global_step
self.training_step_fn = training_step_fn

compat.initialize_variables(self)

def testUpdateSingleMask(self):
weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
weight_dtype = weight.dtype.base_dtype
mask = tf.Variable(
tf.ones(weight.get_shape(), dtype=weight_dtype),
name="mask",
dtype=weight_dtype)
threshold = tf.Variable(
tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
self.initialize()
pruning_vars = [(weight, mask, threshold)]
next_step = self.training_step_fn() + 1

    p = pruner.Pruner(
pruning_schedule=self.constant_sparsity,
block_size=self.block_size,
block_pooling_type=self.block_pooling_type)

mask_before_pruning = K.get_value(mask)
self.assertAllEqual(np.count_nonzero(mask_before_pruning), 100)

if tf.executing_eagerly():
p.update_masks(pruning_vars, next_step)
else:
K.get_session().run(p.update_masks(pruning_vars, next_step))

mask_after_pruning = K.get_value(mask)
self.assertAllEqual(np.count_nonzero(mask_after_pruning), 50)

def testConstructsMaskAndThresholdCorrectly(self):
self.initialize()
    p = pruner.Pruner(
# Sparsity math often returns values with small tolerances.
lambda x: (True, 0.200000018),
(1, 1), None)
step = self.global_step

# input matrix is [ 1.0, 2.0, ..., 8.0, 9.0, 10.0 ]
threshold, mask = p._update_mask(step, np.arange(1, 11))

self.assertEqual(3, K.get_value(threshold))
self.assertAllEqual(
# expected matrix is [ 0.0, 0.0, 1.0, 1.0 ... 1.0 ]
np.concatenate((np.zeros(2), np.ones(8))), K.get_value(mask))

def _blockMasking(self, block_size, block_pooling_type, weight,
expected_mask):
mask = tf.Variable(
tf.ones(weight.get_shape(), dtype=weight.dtype),
name="mask",
dtype=weight.dtype)
threshold = tf.Variable(
tf.zeros([], dtype=weight.dtype), name="threshold", dtype=weight.dtype)
self.initialize()
step = self.training_step_fn()

# Set up pruning
    p = pruner.Pruner(
pruning_schedule=self.constant_sparsity,
block_size=block_size,
block_pooling_type=block_pooling_type)

_, new_mask = p._maybe_update_block_mask(step, weight)
# Check if the mask is the same size as the weights
self.assertAllEqual(new_mask.get_shape(), weight.get_shape())
mask_after_pruning = K.get_value(new_mask)
self.assertAllEqual(mask_after_pruning, expected_mask)

def testBlockMaskingAvg(self):
block_size = (2, 2)
block_pooling_type = "AVG"
weight = tf.constant([[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
[0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]])
expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[1., 1., 1., 1.], [1., 1., 1., 1.]]

self._blockMasking(block_size, block_pooling_type, weight, expected_mask)

def testBlockMaskingMax(self):
block_size = (2, 2)
block_pooling_type = "MAX"
    weight = tf.constant([[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2],
                          [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0, -0.4]])
expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[1., 1., 1., 1.], [1., 1., 1., 1.]]

self._blockMasking(block_size, block_pooling_type, weight, expected_mask)

def testBlockMaskingWithHigherDimensionsRaisesError(self):
self.initialize()
block_size = (2, 2)
block_pooling_type = "AVG"
# Weights as in testBlockMasking, but with one extra dimension.
    weight = tf.constant([[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2],
                           [0.3, 0.3, 0.4, 0.4], [0.3, 0.3, 0.4, 0.4]]])
expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
[1., 1., 1., 1.], [1., 1., 1., 1.]]]

# Block masking should only be used with 2 Dimensional weights.
with self.assertRaises(ValueError):
self._blockMasking(block_size, block_pooling_type, weight, expected_mask)

def testConditionalMaskUpdate(self):
weight = tf.Variable(np.linspace(1.0, 100.0, 100), name="weights")
weight_dtype = weight.dtype.base_dtype
mask = tf.Variable(
tf.ones(weight.get_shape(), dtype=weight_dtype),
name="mask",
dtype=weight_dtype)
threshold = tf.Variable(
tf.zeros([], dtype=weight_dtype), name="threshold", dtype=weight_dtype)
self.initialize()
pruning_vars = [(weight, mask, threshold)]

def linear_sparsity(step):
sparsity_val = tf.convert_to_tensor(
[0.0, 0.1, 0.1, 0.3, 0.3, 0.5, 0.5, 0.5, 0.5, 0.5])
return tf.convert_to_tensor(True), sparsity_val[step]

    def weight_mask_op(pruning_vars):
      assign_ops = []
      for weight, mask, _ in pruning_vars:
        assign_ops.append(weight.assign(tf.math.multiply(weight, mask)))
      # Group the assignments so the graph-mode branch can run them as one op.
      return tf.group(assign_ops)


# Set up pruning
    p = pruner.Pruner(
pruning_schedule=linear_sparsity,
block_size=self.block_size,
block_pooling_type=self.block_pooling_type)

step = self.training_step_fn

non_zero_count = []
for _ in range(10):
if tf.executing_eagerly():
p.update_masks(pruning_vars, step())
weight_mask_op(pruning_vars)
assign_add(self.global_step, 1)
else:
K.get_session().run(p.update_masks(pruning_vars, step()))
K.get_session().run(weight_mask_op(pruning_vars))
K.get_session().run(assign_add(self.global_step, 1))

non_zero_count.append(np.count_nonzero(K.get_value(weight)))

# Weights pruned at steps 1,3,5
expected_non_zero_count = [100, 90, 90, 70, 70, 50, 50, 50, 50, 50]
self.assertAllEqual(expected_non_zero_count, non_zero_count)


if __name__ == "__main__":
test.main()
