Description
Describe the bug
When a layer has more than one clusterable weight, wrapping it for clustering works, but the wrapped layer fails to build.
System information
TensorFlow version: 2.6.0.dev20210330
TensorFlow Model Optimization version: 0.5.0.dev20210407
Python version: 3.7.10
Describe the expected behavior
The wrapped layer with more than one clusterable weight should build successfully and should be able to accept inputs.
Describe the current behavior
Building the wrapped layer currently fails with the following error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Code to reproduce the issue
import tensorflow as tf
from tensorflow_model_optimization.clustering.keras import CentroidInitialization
from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry

class MyCustomClusterableLayer(tf.keras.layers.Dense, clusterable_layer.ClusterableLayer):

    def get_clusterable_weights(self):
        # Two clusterable weights: this is what triggers the bug.
        return [('kernel', self.kernel), ('bias', self.bias)]

my_layer = MyCustomClusterableLayer(16)
my_layer.build((32, 8))

clustering_registry.ClusteringLookupRegistry.register_new_implementation(
    {
        MyCustomClusterableLayer: {
            'kernel': clustering_registry.DenseWeightsCA,
            'bias': clustering_registry.BiasWeightsCA
        }
    }
)

params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.LINEAR
}

wrapped_layer = cluster.cluster_weights(my_layer, **params)
wrapped_layer.build((32, 8))  # Raises the ValueError shown above
Additional information
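TensorFlow's equality comparison (==) works element-wise, so the list .index() method used within the ClusterWeights class cannot be applied directly to a list of tensors. Python's list.index() checks identity before falling back to ==, so looking up the first weight succeeds. Looking up the second weight, however, first compares it with the first weight, which yields a multi-element boolean tensor instead of a single True/False, and taking the truth value of that tensor raises the ValueError above. I believe this is what causes the error.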
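The suspected mechanism can be reproduced outside TensorFlow. The sketch below uses NumPy arrays as stand-ins for the kernel and bias. This rests on an assumption: eager TensorFlow tensors also compare element-wise and raise the same ValueError when their truth value is taken. index_by_identity is a hypothetical helper, not part of TensorFlow Model Optimization; it only illustrates looking up a weight by identity instead of by ==.

```python
import numpy as np

# Stand-ins for a Dense layer's kernel (8 x 16) and bias (16,).
# Assumption: NumPy arrays behave like eager TF tensors under == and bool().
kernel = np.zeros((8, 16))
bias = np.zeros(16)
weights = [kernel, bias]

# list.index() checks identity first, so the first weight is found.
first_index = weights.index(kernel)

# For the second weight, index() evaluates kernel == bias, which broadcasts
# to an 8 x 16 boolean array; converting that array to bool raises ValueError.
error_message = None
try:
    weights.index(bias)
except ValueError as err:
    error_message = str(err)


def index_by_identity(items, target):
    """Hypothetical helper: return the position of `target` by object identity."""
    for i, item in enumerate(items):
        if item is target:
            return i
    raise ValueError("target not found in items")


print(first_index)                         # 0
print(error_message)                       # NumPy's "truth value ... is ambiguous" message
print(index_by_identity(weights, bias))    # 1
```

Identity comparison sidesteps element-wise ==, which is why the lookup succeeds for every weight. In TensorFlow code, comparing the hashable handles returned by Variable.ref() might be another way to get the same identity-based lookup.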