Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def setUp(self):
self.custom_clusterable_layer = CustomClusterableLayer(10)
self.custom_non_clusterable_layer = CustomNonClusterableLayer(10)
self.keras_depthwiseconv2d_layer = layers.DepthwiseConv2D((3, 3), (1, 1))
self.clusterable_layer = MyClusterableLayer(10)
self.keras_custom_layer = KerasCustomLayer()
self.clusterable_layer = MyClusterableLayer(10)

Expand Down Expand Up @@ -242,6 +243,34 @@ def testClusterCustomNonClusterableLayer(self):
cluster_wrapper.ClusterWeights(custom_non_clusterable_layer,
**self.params)

def testClusterMyClusterableLayer(self):
# we have weights to cluster.
clusterable_layer = self.clusterable_layer
clusterable_layer.build(input_shape=(10, 10))

wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
**self.params)

self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)

def testKerasCustomLayerClusterable(self):
"""
Verifies that we can wrap keras custom layer that is customerable.
"""
clusterable_layer = KerasCustomLayerClusterable()
wrapped_layer = cluster_wrapper.ClusterWeights(clusterable_layer,
**self.params)

self.assertIsInstance(wrapped_layer, cluster_wrapper.ClusterWeights)

def testClusterMyClusterableLayerInvalid(self):
"""
Verifies that assertion is thrown when function
get_clusterable_weights is not provided.
"""
with self.assertRaises(TypeError):
MyClusterableLayerInvalid(10) # pylint: disable=abstract-class-instantiated

def testClusterKerasCustomLayer(self):
"""Verifies that attempting to cluster a keras custom layer raises an exception."""
# If layer is not built, it has not weights, so
Expand Down Expand Up @@ -275,8 +304,7 @@ def testClusterMyClusterableLayerInvalid(self):
@keras_parameterized.run_all_keras_modes
def testClusterSequentialModelSelectively(self):
clustered_model = keras.Sequential()
clustered_model.add(
cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
clustered_model.add(cluster.cluster_weights(self.keras_clusterable_layer, **self.params))
clustered_model.add(self.keras_clusterable_layer)
clustered_model.build(input_shape=(1, 10))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

@six.add_metaclass(abc.ABCMeta)
class AbstractClusteringAlgorithm(object):
"""Abstrac class to implement highly efficient vectorised look-ups.
"""Abstract class to implement highly efficient vectorised look-ups.

We do not utilise looping for that purpose, instead we `smartly` reshape and
tile arrays. The trade-off is that we are potentially using way more memory
Expand Down