diff --git a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py index 37e3681a5..b5d47b447 100644 --- a/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py +++ b/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py @@ -62,6 +62,9 @@ def __init__(self, 'Please initialize `Cluster` layer with a ' '`Layer` instance. You passed: {input}'.format(input=layer)) + if 'name' not in kwargs: + kwargs['name'] = self._make_layer_name(layer) + if isinstance(layer, clusterable_layer.ClusterableLayer): # A user-defined custom layer super(ClusterWeights, self).__init__(layer, **kwargs) @@ -133,6 +136,10 @@ def __init__(self, and hasattr(layer, '_batch_input_shape'): self._batch_input_shape = self.layer._batch_input_shape + @staticmethod + def _make_layer_name(layer): + return '{}_{}'.format('cluster', layer.name) + @staticmethod def _weight_name(name): """Extracts the weight name from the full TensorFlow variable name.