
Clustering wrapper will not build with more than one clusterable weight #659

@metinsuloglu

Describe the bug
When a layer returns more than one clusterable weight from get_clusterable_weights(), the layer wrapped by cluster_weights (and therefore any model containing it) will not build.

System information

TensorFlow version: 2.6.0.dev20210330

TensorFlow Model Optimization version: 0.5.0.dev20210407

Python version: 3.7.10

Describe the expected behavior
A wrapped layer with more than one clusterable weight should build successfully and accept inputs.

Describe the current behavior
Building the wrapped layer currently fails with the following error:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Code to reproduce the issue

```python
import tensorflow as tf

from tensorflow_model_optimization.clustering.keras import CentroidInitialization
from tensorflow_model_optimization.python.core.clustering.keras import cluster
from tensorflow_model_optimization.python.core.clustering.keras import clusterable_layer
from tensorflow_model_optimization.python.core.clustering.keras import clustering_registry

# A Dense layer that exposes both its kernel and its bias as clusterable weights.
class MyCustomClusterableLayer(tf.keras.layers.Dense, clusterable_layer.ClusterableLayer):
    def get_clusterable_weights(self):
        return [('kernel', self.kernel), ('bias', self.bias)]

my_layer = MyCustomClusterableLayer(16)
my_layer.build((32, 8))

# Register a clustering algorithm for each of the two clusterable weights.
clustering_registry.ClusteringLookupRegistry.register_new_implementation(
    {
        MyCustomClusterableLayer: {
            'kernel': clustering_registry.DenseWeightsCA,
            'bias': clustering_registry.BiasWeightsCA,
        }
    }
)

params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.LINEAR,
}

wrapped_layer = cluster.cluster_weights(my_layer, **params)
wrapped_layer.build((32, 8))  # Raises the ValueError shown above
```

Additional information
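TensorFlow's equality comparison works element-wise, so the list .index() method used within the ClusterWeights class cannot be applied directly to a list of tensors. list.index() compares the target with each list element using ==; when it compares two different multi-element weights (for example, the bias with the kernel), the result is a boolean tensor rather than a single True/False, and converting that tensor to a Python bool raises the ValueError above. I believe this is what causes the error.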
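The failure mode can be reproduced without TensorFlow Model Optimization. The sketch below is not the library's code: it uses NumPy arrays as a stand-in for the layer's weight tensors (they also compare element-wise with ==), and try_list_index and index_by_identity are hypothetical helpers introduced only for illustration. It shows that list.index() succeeds for the first weight (CPython checks identity before calling ==) but fails for the second, while a lookup that compares with `is` returns the expected position.

```python
import numpy as np

# NumPy arrays stand in for the kernel and bias tensors; like TensorFlow
# tensors, they compare element-wise with ==.
kernel = np.zeros((8, 16))
bias = np.zeros((16,))
weights = [kernel, bias]


def try_list_index(items, target):
    """Call list.index() and return the error message instead of raising."""
    try:
        return items.index(target)
    except ValueError as err:
        return str(err)


def index_by_identity(items, target):
    """Return the position of `target` in `items`, comparing with `is`, not ==."""
    for i, item in enumerate(items):
        if item is target:
            return i
    raise ValueError("target is not in the list")


# The kernel is found because list.index() checks identity first.
print(try_list_index(weights, kernel))     # 0
# Looking up the bias evaluates kernel == bias, an (8, 16) boolean array,
# whose truth value is ambiguous.
print(try_list_index(weights, bias))
# Comparing by identity never evaluates ==, so it works for every weight.
print(index_by_identity(weights, bias))    # 1
```

An identity-based lookup is only one possible workaround; looking weights up by name (the first element of each pair returned by get_clusterable_weights) would also avoid comparing tensors. Neither is necessarily how the issue was fixed upstream.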

Labels: bug (Something isn't working), technique:clustering (Regarding tfmot.clustering.keras APIs and docs)