@@ -48,7 +48,7 @@ def cluster_scope():
"""
return CustomObjectScope(
{
-       'ClusterWeights' : cluster_wrapper.ClusterWeights
+       'ClusterWeights': cluster_wrapper.ClusterWeights
}
)
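As a side note on why this scope exists: deserializing a model whose layers are wrapped in `ClusterWeights` requires the custom class to be resolvable by Keras. A minimal usage sketch, assuming a previously saved clustered model (the file name is illustrative, not part of this change):

from tensorflow import keras

# Hypothetical usage: the saved model contains ClusterWeights wrappers,
# so it has to be loaded inside the scope defined above.
with cluster_scope():
    model = keras.models.load_model('clustered_model.h5')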

@@ -27,6 +27,7 @@
k = keras.backend
Layer = keras.layers.Layer
Wrapper = keras.layers.Wrapper
+ CentroidInitialization = cluster_config.CentroidInitialization


class ClusterWeights(Wrapper):
@@ -105,7 +106,7 @@ def __init__(self,
self.number_of_clusters = number_of_clusters

# Stores the pairs of weight names and references to their tensors
- self.clustered_vars = []
+ self.ori_weights_vars_tf = {}

# Stores references to class instances that implement different clustering
# behaviour for different shapes of objects
@@ -227,24 +228,33 @@ def build(self, input_shape):
)

# We store these pairs to easily update these variables later on
- self.clustered_vars.append((weight_name, weight))
+ self.ori_weights_vars_tf[weight_name] = self.add_weight(
+     'ori_weights_vars_tf',
+     shape=weight.shape,
+     dtype=weight.dtype,
+     trainable=True,
+     initializer=initializers.Constant(
+         value=k.batch_get_value([weight])[0]
+     )
+ )

# We use currying here to get an updater that can be triggered at any time
# in the future and returns the latest version of the clustered weights
def get_updater(for_weight_name):
  def fn():
-   return self.clustering_impl[for_weight_name].get_clustered_weight(
-       self.pulling_indices_tf[for_weight_name]
-   )
+   # Get the clustered weights
+   pulling_indices = self.pulling_indices_tf[for_weight_name]
+   clustered_weights = self.clustering_impl[for_weight_name].\
+       get_clustered_weight(pulling_indices)
+   return clustered_weights

return fn

- # This will allow us to restore the order of weights later
+ # This loop stores pairs of weight names and how to restore them

for ct, weight in enumerate(self.layer.weights):
name = self._weight_name(weight.name)
- full_name = self.layer.name + "/" + name
+ full_name = '{}{}{}'.format(self.layer.name, '/', name)
if ct in self.gone_variables:
# Again, not sure if this is needed
weight_name = clusterable_weights_to_variables[name]
@@ -253,14 +263,26 @@ def fn():
self.restore.append((name, full_name, weight))

def call(self, inputs):
+ # In the forward pass we need to update the cluster associations manually,
+ # since they are integers and not differentiable: gradients won't flow
+ # back through tf.argmin.
  # Go through all tensors and replace them with their clustered copies.
- for weight_name, _ in self.clustered_vars:
-   setattr(
-       self.layer, weight_name,
-       self.clustering_impl[weight_name].get_clustered_weight(
-           self.pulling_indices_tf[weight_name]
-       )
-   )
+ for weight_name in self.ori_weights_vars_tf:
+   pulling_indices = self.pulling_indices_tf[weight_name]
+
+   # Update the cluster associations
+   pulling_indices.assign(tf.dtypes.cast(
+       self.clustering_impl[weight_name].get_pulling_indices(
+           self.ori_weights_vars_tf[weight_name]),
+       pulling_indices.dtype
+   ))
+
+   clustered_weights = self.clustering_impl[weight_name].\
+       get_clustered_weight_forward(
+           pulling_indices, self.ori_weights_vars_tf[weight_name])
+
+   # Replace the weights with their clustered counterparts
+   setattr(self.layer, weight_name, clustered_weights)

return self.layer.call(inputs)
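For intuition about what the new `call` does on every forward pass, here is a minimal standalone sketch of the re-association step, assuming a flat weight vector and absolute distance to the centroids (all names are illustrative, not the wrapper's API):

import tensorflow as tf

centroids = tf.constant([0.0, 1.0])          # cluster centroids
weights = tf.constant([0.1, 0.9, 0.4, 0.8])  # original (trainable) weights

# Distance of every weight to every centroid -> index of nearest centroid.
# tf.argmin yields integers, so no gradient flows through this step.
dists = tf.abs(tf.expand_dims(weights, 1) - tf.expand_dims(centroids, 0))
pulling_indices = tf.argmin(dists, axis=1)

# Forward value of the layer's weights: centroid lookup by index.
clustered_weights = tf.gather(centroids, pulling_indices)
print(clustered_weights)  # [0. 1. 0. 1.]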

@@ -271,7 +293,7 @@ def get_config(self):
base_config = super(ClusterWeights, self).get_config()
config = {
'number_of_clusters': self.number_of_clusters,
-   'cluster_centroids_init': self.cluster_centroids_init,
+   'cluster_centroids_init': self.cluster_centroids_init
}
return dict(list(base_config.items()) + list(config.items()))

@@ -15,7 +15,6 @@
"""Tests for keras ClusterWeights wrapper API."""

import itertools
- import numpy as np
import tensorflow as tf

from absl.testing import parameterized
@@ -155,9 +154,9 @@ def testIfLayerHasBatchShapeClusterWeightsMustHaveIt(self):
*itertools.product(
range(2, 16, 4),
(
-       CentroidInitialization.LINEAR,
-       CentroidInitialization.RANDOM,
-       CentroidInitialization.DENSITY_BASED
+     CentroidInitialization.LINEAR,
+     CentroidInitialization.RANDOM,
+     CentroidInitialization.DENSITY_BASED
)
)
)
@@ -194,5 +193,67 @@ def testValuesAreClusteredAfterStripping(self,
self.assertEqual(stripped_model.layers[0].weights[0].name, weights_name)
self.assertEqual(stripped_model.layers[0].weights[1].name, bias_name)

+ def testClusterReassociation(self):
+   """
+   Verifies that the associations of weights to cluster centroids are
+   updated every iteration.
+   """
+
+   # Create a dummy layer for this test
+   input_shape = (1, 2,)
+   l = cluster_wrapper.ClusterWeights(
+       keras.layers.Dense(8, input_shape=input_shape),
+       number_of_clusters=2,
+       cluster_centroids_init=CentroidInitialization.LINEAR
+   )
+   # Build a layer with the given shape
+   l.build(input_shape)
+
+   # Get the name of the clusterable weights
+   clusterable_weights = l.layer.get_clusterable_weights()
+   self.assertEqual(len(clusterable_weights), 1)
+   weights_name = clusterable_weights[0][0]
+   self.assertEqual(weights_name, 'kernel')
+   # Get the cluster centroids
+   centroids = l.cluster_centroids_tf[weights_name]
+
+   # Calculate some statistics of the weights to set the centroids later on
+   mean_weight = tf.reduce_mean(l.layer.kernel)
+   min_weight = tf.reduce_min(l.layer.kernel)
+   max_weight = tf.reduce_max(l.layer.kernel)
+   max_dist = max_weight - min_weight
+
+   def assert_all_weights_associated(weights, centroid_index):
+     """Helper to check that all weights are associated with a single
+     centroid."""
+     all_associated = tf.reduce_all(
+         tf.equal(
+             weights,
+             tf.constant(centroids[centroid_index], shape=weights.shape)
+         )
+     )
+     self.assertTrue(all_associated)
+
+   # Set centroids so that all weights should be re-associated with centroid 0
+   centroids[0].assign(mean_weight)
+   centroids[1].assign(mean_weight + 2.0 * max_dist)
+
+   # Update the associations of weights to centroids
+   l.call(tf.ones(shape=input_shape))
+
+   # All weights should now be clustered with centroid 0
+   assert_all_weights_associated(l.layer.kernel, centroid_index=0)
+
+   # Set centroids so that all weights should be re-associated with centroid 1
+   centroids[0].assign(mean_weight - 2.0 * max_dist)
+   centroids[1].assign(mean_weight)
+
+   # Update the associations of weights to centroids
+   l.call(tf.ones(shape=input_shape))
+
+   # All weights should now be clustered with centroid 1
+   assert_all_weights_associated(l.layer.kernel, centroid_index=1)


if __name__ == '__main__':
test.main()
@@ -64,6 +64,22 @@ def get_pulling_indices(self, weight):
"""
pass

+ @tf.custom_gradient
+ def add_custom_gradients(self, clst_weights, weights):
+   """
+   Overrides the gradients in the backprop stage: the original multiply
+   behaves like an add, and tf.sign behaves like tf.identity. This lets the
+   original weights be updated directly with the gradient updates coming
+   from the wrapped layer. We assume the gradient updates of individual
+   elements inside a cluster will differ, so there is no point in mapping
+   the gradient updates back to the original weight matrix using the LUT.
+   """
+   override_weights = tf.sign(tf.reshape(weights, shape=(-1,)) + 1e+6)
+   z = clst_weights * override_weights
+   def grad(dz):
+     return dz, dz
+   return z, grad
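For intuition, the same pattern in a self-contained toy (illustrative names, not the library API): the `+ 1e+6` shift makes every element positive, so `tf.sign` yields all ones and the forward product returns the clustered weights unchanged, while the custom `grad` hands the upstream gradient to both inputs so the original weights receive the layer's gradient directly:

import tensorflow as tf

@tf.custom_gradient
def route_gradients(clustered, original):
    # Forward: sign(original + 1e+6) is all ones (assuming |original| < 1e+6),
    # so this is effectively an identity on `clustered`.
    out = clustered * tf.sign(original + 1e+6)
    def grad(dz):
        # Backward: the multiply behaves like an add, routing the same
        # upstream gradient to both inputs.
        return dz, dz
    return out, grad

original = tf.constant([0.5, -0.5])
clustered = tf.constant([1.0, -1.0])
with tf.GradientTape() as tape:
    tape.watch(original)
    loss = tf.reduce_sum(route_gradients(clustered, original))
print(tape.gradient(loss, original))  # [1. 1.]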

def get_clustered_weight(self, pulling_indices):
"""
Takes an array of integer numbers that represent lookup indices and forms a
@@ -75,9 +91,23 @@ def get_clustered_weight(self, pulling_indices):
return tf.reshape(
tf.gather(self.cluster_centroids,
tf.reshape(pulling_indices, shape=(-1,))),
-       pulling_indices.shape
+       shape=pulling_indices.shape
)
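A quick worked example of the lookup above, with illustrative values: the indices select centroids element-wise, and the result is reshaped back to the shape of the index matrix:

import tensorflow as tf

cluster_centroids = tf.constant([-1.0, 1.0])
pulling_indices = tf.constant([[0, 1], [1, 0]])

clustered = tf.reshape(
    tf.gather(cluster_centroids, tf.reshape(pulling_indices, shape=(-1,))),
    shape=pulling_indices.shape)
print(clustered)  # [[-1.  1.]
                  #  [ 1. -1.]]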

+ def get_clustered_weight_forward(self, pulling_indices, weight):
+   """
+   Takes indices (pulling_indices) and original weights (weight) as inputs
+   and forms a new array according to the given indices. The original
+   weights (weight) are added to the graph here because we want the
+   backprop to update their values through the new tf.custom_gradient
+   implementation.
+   :param pulling_indices: an array of indices used for lookup.
+   :param weight: the original weights of the wrapped layer.
+   :return: array with the same shape as `pulling_indices`. Each array
+     element is a member of self.cluster_centroids
+   """
+   x = tf.reshape(self.get_clustered_weight(pulling_indices), shape=(-1,))
+   return tf.reshape(self.add_custom_gradients(
+       x, tf.reshape(weight, shape=(-1,))), pulling_indices.shape)

class ConvolutionalWeightsCA(AbstractClusteringAlgorithm):
"""
@@ -45,6 +45,48 @@ def _pull_values(self, ca, pulling_indices, expected_output):

self.assertSequenceEqual(res_np_list, expected_output)

+ def _check_gradients(self, ca, weight, pulling_indices, expected_output):
+   pulling_indices_tf = tf.convert_to_tensor(pulling_indices)
+   weight_tf = tf.convert_to_tensor(weight)
+   with tf.GradientTape(persistent=True) as t:
+     t.watch(pulling_indices_tf)
+     t.watch(weight_tf)
+     cls_weights_tf = tf.reshape(
+         ca.get_clustered_weight(pulling_indices_tf), shape=(-1,))
+     t.watch(cls_weights_tf)
+     out_forward = ca.add_custom_gradients(cls_weights_tf, weight_tf)
+   grad_cls_weight = t.gradient(out_forward, cls_weights_tf)
+   grad_weight = t.gradient(out_forward, weight_tf)
+
+   chk_output = tf.math.equal(grad_cls_weight, grad_weight)
+   chk_output_np = k.batch_get_value(chk_output)
+
+   self.assertSequenceEqual(chk_output_np, expected_output)
+
+ @parameterized.parameters(
+   ([-0.800450444, 0.864694357],
+    [[0.220442653, 0.854694366, 0.0328432359, 0.506857157],
+     [0.0527950861, -0.659555554, -0.849919915, -0.54047],
+     [-0.305815876, 0.0865516588, 0.659202456, -0.355699599],
+     [-0.348868281, -0.662001, 0.6171574, -0.296582848]],
+    [[1, 1, 1, 1],
+     [1, 0, 0, 0],
+     [0, 1, 1, 0],
+     [0, 0, 1, 0]],
+    [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+   )
+ )
+ def testDenseWeightsCAGrad(self,
+                            clustering_centroids,
+                            weight,
+                            pulling_indices,
+                            expected_output):
+   """
+   Verifies that the gradients of DenseWeightsCA work as expected.
+   """
+   ca = clustering_registry.DenseWeightsCA(clustering_centroids)
+   self._check_gradients(ca, weight, pulling_indices, expected_output)

@parameterized.parameters(
([-1, 1], [[0, 0, 1], [1, 1, 1]], [[-1, -1, 1], [1, 1, 1]]),
([-1, 0, 1], [[1, 1, 1], [1, 1, 1]], [[0, 0, 0], [0, 0, 0]]),
@@ -73,6 +115,29 @@ def testBiasWeightsCA(self,
ca = clustering_registry.BiasWeightsCA(clustering_centroids)
self._pull_values(ca, pulling_indices, expected_output)

+ @parameterized.parameters(
+   ([0.0, 3.0],
+    [[0.1, 0.1, 0.1],
+     [3.0, 3.0, 3.0],
+     [0.2, 0.2, 0.2]],
+    [[0, 0, 0],
+     [1, 1, 1],
+     [0, 0, 0]],
+    [1, 1, 1, 1, 1, 1, 1, 1, 1]
+   )
+ )
+ def testConvolutionalWeightsCAGrad(self,
+                                    clustering_centroids,
+                                    weight,
+                                    pulling_indices,
+                                    expected_output):
+   """
+   Verifies that the gradients of ConvolutionalWeightsCA work as expected.
+   """
+   ca = clustering_registry.ConvolutionalWeightsCA(clustering_centroids)
+   self._check_gradients(ca, weight, pulling_indices, expected_output)


@parameterized.parameters(
([0, 3], [[[[0, 0, 0], [1, 1, 1], [0, 0, 0]]]],
[[[[0, 0, 0], [3, 3, 3], [0, 0, 0]]]]),