Skip to content

Commit

Permalink
Use the new Estimator.get_variable_value() method to get the kmeans c…
Browse files Browse the repository at this point in the history
…luster centers.

PiperOrigin-RevId: 171320755
  • Loading branch information
tensorflower-gardener committed Oct 6, 2017
1 parent 7fceb8d commit 3110185
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 29 deletions.
2 changes: 1 addition & 1 deletion tensorflow/contrib/factorization/examples/mnist.py
Expand Up @@ -142,7 +142,7 @@ def inference(inp, num_clusters, hidden1_units, hidden2_units):
# initial_clusters=tf.contrib.factorization.KMEANS_PLUS_PLUS_INIT,
use_mini_batch=True)

(all_scores, _, clustering_scores, _, _, kmeans_init,
(all_scores, _, clustering_scores, _, kmeans_init,
kmeans_training_op) = kmeans.training_graph()
# Some heuristics to approximately whiten this output.
all_scores = (all_scores[0] - 0.5) * 5
Expand Down
8 changes: 5 additions & 3 deletions tensorflow/contrib/factorization/python/ops/clustering_ops.py
Expand Up @@ -51,6 +51,9 @@
RANDOM_INIT = 'random'
KMEANS_PLUS_PLUS_INIT = 'kmeans_plus_plus'

# The name of the variable holding the cluster centers. Used by the Estimator.
CLUSTERS_VAR_NAME = 'clusters'


class KMeans(object):
"""Creates the graph for k-means clustering."""
Expand Down Expand Up @@ -279,7 +282,7 @@ def _create_variables(self, num_clusters):
"""
init_value = array_ops.constant([], dtype=dtypes.float32)
cluster_centers = variable_scope.variable(
init_value, name='clusters', validate_shape=False)
init_value, name=CLUSTERS_VAR_NAME, validate_shape=False)
cluster_centers_initialized = variable_scope.variable(
False, dtype=dtypes.bool, name='initialized')

Expand Down Expand Up @@ -337,7 +340,6 @@ def training_graph(self):
assigned cluster instead.
cluster_centers_initialized: scalar indicating whether clusters have been
initialized.
cluster_centers_var: a Variable holding the cluster centers.
init_op: an op to initialize the clusters.
training_op: an op that runs an iteration of training.
"""
Expand Down Expand Up @@ -381,7 +383,7 @@ def training_graph(self):
inputs, num_clusters, cluster_idx, cluster_centers_var)

return (all_scores, cluster_idx, scores, cluster_centers_initialized,
cluster_centers_var, init_op, training_op)
init_op, training_op)

def _mini_batch_sync_updates_op(self, update_in_steps, cluster_centers_var,
cluster_centers_updated, total_counts):
Expand Down
28 changes: 4 additions & 24 deletions tensorflow/contrib/factorization/python/ops/kmeans.py
Expand Up @@ -21,12 +21,10 @@
from __future__ import print_function

import time
import numpy as np

from tensorflow.contrib.factorization.python.ops import clustering_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
Expand Down Expand Up @@ -161,8 +159,7 @@ def model_fn(self, features, mode, config):
* `eval_metric_ops`: Maps `SCORE` to `loss`.
* `predictions`: Maps `ALL_DISTANCES` to the distance from each input
point to each cluster center; maps `CLUSTER_INDEX` to the index of
the closest cluster center for each input point; maps `CLUSTERS` to
the cluster centers (which ignores the input points).
the closest cluster center for each input point.
"""
# input_points is a single Tensor. Therefore, the sharding functionality
# in clustering_ops is unused, and some of the values below are lists of a
Expand All @@ -184,8 +181,8 @@ def model_fn(self, features, mode, config):
# training_op: an op that runs an iteration of training, either an entire
# Lloyd iteration or a mini-batch of a Lloyd iteration. Multiple workers
# may execute this op, but only after is_initialized becomes True.
(all_distances, model_predictions, losses, is_initialized,
cluster_centers_var, init_op, training_op) = clustering_ops.KMeans(
(all_distances, model_predictions, losses, is_initialized, init_op,
training_op) = clustering_ops.KMeans(
inputs=input_points,
num_clusters=self._num_clusters,
initial_clusters=self._initial_clusters,
Expand Down Expand Up @@ -215,7 +212,6 @@ def model_fn(self, features, mode, config):
predictions={
KMeansClustering.ALL_DISTANCES: all_distances[0],
KMeansClustering.CLUSTER_INDEX: model_predictions[0],
KMeansClustering.CLUSTERS: cluster_centers_var.value(),
},
loss=loss,
train_op=training_op,
Expand All @@ -242,9 +238,7 @@ class KMeansClustering(estimator.Estimator):
# Keys returned by predict().
# ALL_DISTANCES: The distance from each input point to each cluster center.
# CLUSTER_INDEX: The index of the closest cluster center for each input point.
# CLUSTERS: The cluster centers (which ignores the input points).
CLUSTER_INDEX = 'cluster_index'
CLUSTERS = 'clusters'
ALL_DISTANCES = 'all_distances'

def __init__(self,
Expand Down Expand Up @@ -400,18 +394,4 @@ def transform(self, input_fn):

def cluster_centers(self):
"""Returns the cluster centers."""

# TODO(ccolby): Fix this clunky code once cl/168262087 is submitted.
# Discussion: go/estimator-get-variable-value
class RunOnceHook(session_run_hook.SessionRunHook):
"""Stops after a single run."""

def after_run(self, run_context, run_values):
del run_values # unused
run_context.request_stop()

result = self.predict(
input_fn=lambda: (constant_op.constant([], shape=[0, 1]), None),
predict_keys=[KMeansClustering.CLUSTERS],
hooks=[RunOnceHook()])
return np.array([r[KMeansClustering.CLUSTERS] for r in result])
return self.get_variable_value(clustering_ops.CLUSTERS_VAR_NAME)
2 changes: 1 addition & 1 deletion tensorflow/contrib/learn/python/learn/estimators/kmeans.py
Expand Up @@ -106,7 +106,7 @@ def _kmeans_clustering_model_fn(features, labels, mode, params, config):
"""Model function for KMeansClustering estimator."""
assert labels is None, labels
(all_scores, model_predictions, losses,
is_initialized, _, init_op, training_op) = clustering_ops.KMeans(
is_initialized, init_op, training_op) = clustering_ops.KMeans(
_parse_tensor_or_dict(features),
params.get('num_clusters'),
initial_clusters=params.get('training_initial_clusters'),
Expand Down

1 comment on commit 3110185

@IliasIB
Copy link

@IliasIB IliasIB commented on 3110185 May 3, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if it's appropriate to comment on an older commit, but I don't know where else to put this. This change is not reflected in the documentation, which means it is almost impossible (except for this commit), to find how you can get the cluster centers.

Please sign in to comment.