Skip to content

Commit

Permalink
KMeans.training_graph() now returns an additional value, currently un…
Browse files Browse the repository at this point in the history
…used.

PiperOrigin-RevId: 170083271
  • Loading branch information
tensorflower-gardener committed Sep 26, 2017
1 parent 272a2c8 commit 26928c6
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion tensorflow/contrib/factorization/examples/mnist.py
Expand Up @@ -142,7 +142,7 @@ def inference(inp, num_clusters, hidden1_units, hidden2_units):
# initial_clusters=tf.contrib.factorization.KMEANS_PLUS_PLUS_INIT,
use_mini_batch=True)

(all_scores, _, clustering_scores, _, kmeans_init,
(all_scores, _, clustering_scores, _, _, kmeans_init,
kmeans_training_op) = kmeans.training_graph()
# Some heuristics to approximately whiten this output.
all_scores = (all_scores[0] - 0.5) * 5
Expand Down
Expand Up @@ -337,6 +337,7 @@ def training_graph(self):
assigned cluster instead.
cluster_centers_initialized: scalar indicating whether clusters have been
initialized.
cluster_centers_var: a Variable holding the cluster centers.
init_op: an op to initialize the clusters.
training_op: an op that runs an iteration of training.
"""
Expand Down Expand Up @@ -380,7 +381,7 @@ def training_graph(self):
inputs, num_clusters, cluster_idx, cluster_centers_var)

return (all_scores, cluster_idx, scores, cluster_centers_initialized,
init_op, training_op)
cluster_centers_var, init_op, training_op)

def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_updated, total_counts):
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/contrib/learn/python/learn/estimators/kmeans.py
Expand Up @@ -106,7 +106,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config):
"""Model function for KMeansClustering estimator."""
assert labels is None, labels
(all_scores, model_predictions, losses,
is_initialized, init_op, training_op) = clustering_ops.KMeans(
is_initialized, _, init_op, training_op) = clustering_ops.KMeans(
_parse_tensor_or_dict(features),
params.get('num_clusters'),
initial_clusters=params.get('training_initial_clusters'),
Expand Down

0 comments on commit 26928c6

Please sign in to comment.