Merge branch 'master' into conv3d_bug_clustering
wwwind committed Apr 23, 2021
2 parents d2a916d + 006e377 commit 3cc9cf7
Showing 51 changed files with 1,212 additions and 1,642 deletions.
2 changes: 1 addition & 1 deletion ci/kokoro/gcp_ubuntu/Dockerfile
@@ -15,7 +15,7 @@
# An image for building prerequisites for testing TensorFlow Model Optimization
# on ubuntu.
#
-# TODO(tfmot): generalize to different versions of TensorFlow to
+# TODO(b/185727356): generalize to different versions of TensorFlow to
# run CI against.
#
# Build as follows:
2 changes: 1 addition & 1 deletion ci/kokoro/gcp_ubuntu/build.sh
@@ -18,7 +18,7 @@
# be used to reproduce errors locally by modifying WORKDIR to be
# the top-level directory of the checked out TFMOT Github repository.

-# TODO(tfmot): switch to prebuilt Docker image to speed this up.
+# TODO(b/185727163): switch to prebuilt Docker image to speed this up.

# Fail on any error.
set -e
3 changes: 3 additions & 0 deletions tensorflow_model_optimization/__init__.py
@@ -81,6 +81,7 @@ def _ensure_tf_install():  # pylint: disable=g-statement-before-imports

# To ensure users only access the expected public API, the API structure is
# created in the `api` directory. Import all api modules.
+from tensorflow_model_optimization.python.core import version
# pylint: disable=wildcard-import
from tensorflow_model_optimization.python.core.api import *
# pylint: enable=wildcard-import
@@ -110,3 +111,5 @@ def _ensure_tf_install():
except NameError:
  pass
# pylint: enable=undefined-variable
+
+__version__ = version.__version__
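
With `version` imported above, the package now exposes its version at the top level. A minimal usage sketch, assuming an installed release (the printed value depends on that release):

```
import tensorflow_model_optimization as tfmot

# __version__ is re-exported from
# tensorflow_model_optimization.python.core.version.
print(tfmot.__version__)
```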
@@ -28,7 +28,7 @@
For example, here is how to specify 8-bit integer weight quantization:
```
import tensorflow as tf
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
-converter.optimizations = [tf.lite.Optimize.OPTIMIZE_FOR_SIZE]
+converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant_model = converter.convert()
```
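
As a follow-on to the snippet above, the converted model is a flatbuffer held in memory; a short sketch for persisting it, with an illustrative filename:

```
# Write the quantized model to disk; the path is illustrative.
with open("model_quant.tflite", "wb") as f:
    f.write(tflite_quant_model)
```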

@@ -82,6 +82,7 @@ py_library(
    visibility = ["//visibility:public"],
    deps = [
        # tensorflow dep1,
+        "//tensorflow_model_optimization/python/core/clustering/keras:cluster_config",
    ],
)

@@ -293,36 +293,25 @@ def _strip_clustering_wrapper(layer):
  if isinstance(layer, keras.Model):
    return keras.models.clone_model(
        layer, input_tensors=None, clone_function=_strip_clustering_wrapper)

  elif isinstance(layer, cluster_wrapper.ClusterWeights):
    if not hasattr(layer.layer, '_batch_input_shape') and\
        hasattr(layer, '_batch_input_shape'):
      # pylint: disable=protected-access
      layer.layer._batch_input_shape = layer._batch_input_shape
+    # Update cluster associations in order to get the latest weights
+    layer.update_clustered_weights_associations()
+
+    # Construct a list of weights to initialize the clean layer
+    updated_weights = layer.layer.get_weights()  # non-clusterable weights only
+    for position_variable, weight_name in layer.position_original_weights.items():
+      # Add the clustered weights at the correct position
+      clustered_weight = getattr(layer.layer, weight_name)
+      updated_weights.insert(position_variable, clustered_weight)
+
+    # Construct a clean layer with the updated weights
+    clean_layer = layer.layer.from_config(layer.layer.get_config())
+    clean_layer.build(layer.build_input_shape)
+    clean_layer.set_weights(updated_weights)
+
+    return clean_layer
-    # We reset both arrays of weights, so that we can guarantee the correct
-    # order of newly created weights
-    # pylint: disable=protected-access
-    layer.layer._trainable_weights = []
-    layer.layer._non_trainable_weights = []
-    for i in range(len(layer.restore)):
-      # This is why we used integers as keys
-      name, weight_name, weight = layer.restore[i]
-      # In both cases we use k.batch_get_value since we need physical copies
-      # of the arrays to initialize a new tensor
-      if i in layer.gone_variables:
-        # If the variable was removed because it was clustered, we restore it
-        # by using the updater we created earlier
-        new_weight_value = k.batch_get_value([weight()])[0]
-      else:
-        # If the value was not clustered (e.g. bias), we still store a valid
-        # reference to the tensor. We use this reference to get the value
-        new_weight_value = k.batch_get_value([weight])[0]
-      setattr(layer.layer,
-              name,
-              k.variable(new_weight_value, name=weight_name))
-    # When all weights are filled with the values, just return the underlying
-    # layer since it is now fully autonomous from its wrapper
-    return layer.layer
  return layer

# Just copy the model with the right callback
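
For context, a minimal sketch of the public entry point that exercises this helper; the model and the clustering parameters are illustrative:

```
import numpy as np
import tensorflow as tf
import tensorflow_model_optimization as tfmot

model = tf.keras.Sequential([tf.keras.layers.Dense(10, input_shape=(5,))])
clustered = tfmot.clustering.keras.cluster_weights(
    model,
    number_of_clusters=4,
    cluster_centroids_init=tfmot.clustering.keras.CentroidInitialization.LINEAR)

# strip_clustering maps _strip_clustering_wrapper over the layers and returns
# clean layers whose kernels hold the clustered values.
stripped = tfmot.clustering.keras.strip_clustering(clustered)
print(len(np.unique(stripped.layers[0].kernel.numpy())))  # at most 4
```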
@@ -31,7 +31,19 @@ class CentroidInitialization(str, enum.Enum):
    initialize the cluster centroids.
  * `KMEANS_PLUS_PLUS`: cluster centroids using the kmeans++ algorithm
  """
LINEAR = "LINEAR"
RANDOM = "RANDOM"
DENSITY_BASED = "DENSITY_BASED"
KMEANS_PLUS_PLUS = "KMEANS_PLUS_PLUS"
LINEAR = "CentroidInitialization.LINEAR"
RANDOM = "CentroidInitialization.RANDOM"
DENSITY_BASED = "CentroidInitialization.DENSITY_BASED"
KMEANS_PLUS_PLUS = "CentroidInitialization.KMEANS_PLUS_PLUS"


+class GradientAggregation(str, enum.Enum):
+  """Specifies how the cluster centroid gradients should be aggregated.
+  * `SUM`: The gradient of each cluster centroid is the sum of its child
+    weights' gradients.
+  * `AVG`: The gradient of each cluster centroid is the average of its child
+    weights' gradients.
+  """
+  SUM = "GradientAggregation.SUM"
+  AVG = "GradientAggregation.AVG"
@@ -43,7 +43,6 @@ def setUp(self):
"cluster_centroids_init": CentroidInitialization.LINEAR
}


@parameterized.parameters(_distribution_strategies())
def testClusterSimpleDenseModel(self, distribution):
"""End-to-end test."""
@@ -64,7 +63,7 @@ def testClusterSimpleDenseModel(self, distribution):
    model.predict(np.random.rand(20, 10))

    stripped_model = cluster.strip_clustering(model)
-    weights_as_list = stripped_model.get_weights()[0].reshape(-1,).tolist()
+    weights_as_list = stripped_model.layers[0].kernel.numpy().reshape(-1,).tolist()
    unique_weights = set(weights_as_list)
    self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])

@@ -87,7 +86,7 @@ def testAssociationValuesPerReplica(self, distribution):
    self.assertEqual(len(clusterable_weights), 1)
    weights_name = clusterable_weights[0][0]
    self.assertEqual(weights_name, 'kernel')
-    centroids1 = l.cluster_centroids_tf[weights_name]
+    centroids1 = l.cluster_centroids[weights_name]

    mean_weight = tf.reduce_mean(l.layer.kernel)
    min_weight = tf.reduce_min(l.layer.kernel)
@@ -119,18 +118,18 @@ def update_fn(v, val):
        centroids1, update_fn, args=(initial_val,))
    l.call(tf.ones(shape=input_shape))

-    clst_indices = l.pulling_indices_tf[weights_name]
+    clst_indices = l.pulling_indices[weights_name]
    per_replica = distribution.experimental_local_results(clst_indices)
    assert_all_cluster_indices(per_replica, 0)

    second_val = tf.Variable([mean_weight - 2.0 * max_dist, mean_weight], \
                             aggregation=tf.VariableAggregation.MEAN)
-    centroids2 = l.cluster_centroids_tf[weights_name]
+    centroids2 = l.cluster_centroids[weights_name]
    centroids2 = distribution.extended.update(
        centroids2, update_fn, args=(second_val,))
    l.call(tf.ones(shape=input_shape))

-    clst_indices = l.pulling_indices_tf[weights_name]
+    clst_indices = l.pulling_indices[weights_name]
    per_replica = distribution.experimental_local_results(clst_indices)
    assert_all_cluster_indices(per_replica, 1)

@@ -174,7 +174,7 @@ def testSparsityIsPreservedDuringTraining(self):
        original_model, **clustering_params)

    stripped_model_before_tuning = cluster.strip_clustering(clustered_model)
-    weights_before_tuning = stripped_model_before_tuning.get_weights()[0]
+    weights_before_tuning = stripped_model_before_tuning.layers[0].kernel
    non_zero_weight_indices_before_tuning = np.nonzero(weights_before_tuning)

    clustered_model.compile(
@@ -185,9 +185,9 @@ def testSparsityIsPreservedDuringTraining(self):
    clustered_model.fit(x=self.dataset_generator2(), steps_per_epoch=1)

    stripped_model_after_tuning = cluster.strip_clustering(clustered_model)
-    weights_after_tuning = stripped_model_after_tuning.get_weights()[0]
+    weights_after_tuning = stripped_model_after_tuning.layers[0].kernel
    non_zero_weight_indices_after_tuning = np.nonzero(weights_after_tuning)
-    weights_as_list_after_tuning = weights_after_tuning.reshape(-1,).tolist()
+    weights_as_list_after_tuning = weights_after_tuning.numpy().reshape(-1,).tolist()
    unique_weights_after_tuning = set(weights_as_list_after_tuning)

    # Check that the null weights stayed the same before and after tuning.
@@ -299,6 +299,54 @@ def clusters_check(stripped_model):

    self.end_to_end_testing(original_model, clusters_check)

+  @keras_parameterized.run_all_keras_modes
+  def testWeightsAreLearningDuringClustering(self):
+    """Verifies that training a clustered model updates the original
+    weights, the cluster centroids, and the bias."""
+    original_model = keras.Sequential([
+        layers.Dense(5, input_shape=(5,))
+    ])
+
+    clustered_model = cluster.cluster_weights(original_model, **self.params)
+
+    clustered_model.compile(
+        loss=keras.losses.categorical_crossentropy,
+        optimizer="adam",
+        metrics=["accuracy"],
+    )
+
+    class CheckWeightsCallback(keras.callbacks.Callback):
+      def on_train_batch_begin(self, batch, logs=None):
+        # Save weights before batch
+        self.original_weight_kernel = (
+            self.model.layers[0].original_clusterable_weights['kernel'].numpy()
+        )
+        self.cluster_centroids_kernel = (
+            self.model.layers[0].cluster_centroids['kernel'].numpy()
+        )
+        self.bias = (
+            self.model.layers[0].layer.bias.numpy()
+        )
+
+      def on_train_batch_end(self, batch, logs=None):
+        # Check weights are different after batch
+        assert not np.array_equal(
+            self.original_weight_kernel,
+            self.model.layers[0].original_clusterable_weights['kernel'].numpy()
+        )
+        assert not np.array_equal(
+            self.cluster_centroids_kernel,
+            self.model.layers[0].cluster_centroids['kernel'].numpy()
+        )
+        assert not np.array_equal(
+            self.bias,
+            self.model.layers[0].layer.bias.numpy()
+        )
+
+    clustered_model.fit(x=self.dataset_generator(),
+                        steps_per_epoch=5,
+                        callbacks=[CheckWeightsCallback()])


if __name__ == "__main__":
  test.main()