Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

ENH: KMeans tolerance parameter renamed tol (as in coordinate descent) and made public
  • Loading branch information...
commit 544e5317a0d858e68946e4777fd0112f06884823 1 parent 09c53ae
Olivier Grisel ogrisel authored
Showing with 14 additions and 10 deletions.
  1. +14 −10 scikits/learn/cluster/k_means_.py
24 scikits/learn/cluster/k_means_.py
View
@@ -82,7 +82,7 @@ def k_init(X, k, n_samples_max=500, rng=None):
# K-means estimation by EM (expectation maximisation)
def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
- delta=1e-4, rng=None, copy_x=True):
+ tol=1e-4, rng=None, copy_x=True):
""" K-means clustering algorithm.
Parameters
@@ -118,7 +118,7 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
If an ndarray is passed, it should be of shape (k, p) and gives
the initial centers.
- delta: float, optional
+ tol: float, optional
The relative increment in the results before declaring convergence.
verbose: boolean, optional
@@ -189,8 +189,8 @@ def k_means(X, k, init='k-means++', n_init=10, max_iter=300, verbose=0,
labels, inertia = _e_step(X, centers)
centers = _m_step(X, labels, k)
if verbose:
- print 'Iteration %i, intertia %s' % (i, inertia)
- if np.sum((centers_old - centers) ** 2) < delta * vdata:
+ print 'Iteration %i, inertia %s' % (i, inertia)
+ if np.sum((centers_old - centers) ** 2) < tol * vdata:
if verbose:
print 'Converged to similar centers at iteration', i
break
@@ -319,6 +319,9 @@ class KMeans(BaseEstimator):
'matrix': interpret the k parameter as a k by M (or length k
array for one-dimensional data) array of initial centroids.
+ tol: float, optional default: 1e-4
+ Relative tolerance w.r.t. inertia to declare convergence
+
Methods
-------
@@ -355,23 +358,24 @@ class KMeans(BaseEstimator):
it can be useful to restart it several times.
"""
- def __init__(self, k=8, init='random', n_init=10, max_iter=300,
+ def __init__(self, k=8, init='random', n_init=10, max_iter=300, tol=1e-4,
verbose=0, rng=None, copy_x=True):
self.k = k
self.init = init
self.max_iter = max_iter
+ self.tol = tol
self.n_init = n_init
self.verbose = verbose
self.rng = rng
self.copy_x = copy_x
def fit(self, X, **params):
- """ Compute k-means"""
+ """Compute k-means"""
X = np.asanyarray(X)
self._set_params(**params)
- self.cluster_centers_, self.labels_, self.inertia_ = k_means(X,
- k=self.k, init=self.init, n_init=self.n_init,
- max_iter=self.max_iter, verbose=self.verbose,
- rng=self.rng, copy_x=self.copy_x)
+ self.cluster_centers_, self.labels_, self.inertia_ = k_means(
+ X, k=self.k, init=self.init, n_init=self.n_init,
+ max_iter=self.max_iter, verbose=self.verbose,
+ tol=self.tol, rng=self.rng, copy_x=self.copy_x)
return self
Please sign in to comment.
Something went wrong with that request. Please try again.