diff --git a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py index e87d04e0d..7869f8afe 100644 --- a/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py +++ b/tensorflow_model_optimization/python/examples/clustering/keras/mnist/mnist_cnn.py @@ -22,10 +22,11 @@ from absl import app as absl_app from absl import flags -import tensorflow.compat.v1 as tf -from tensorflow.python import keras +import tensorflow as tf from tensorflow_model_optimization.python.core.clustering.keras import cluster +from tensorflow_model_optimization.python.core.clustering.keras import cluster_config +keras = tf.keras l = keras.layers FLAGS = flags.FLAGS @@ -96,7 +97,7 @@ def train_and_save(models, x_train, y_train, x_test, y_test): clustering_params = { 'number_of_clusters': 8, - 'cluster_centroids_init': 'density-based' + 'cluster_centroids_init': cluster_config.CentroidInitialization.DENSITY_BASED } # Cluster model