[MRG+1] Adding support for sample weights to K-Means (#10933)
jnhansen authored and TomDLT committed May 22, 2018
1 parent 399f1b2 commit 4b24fbe
Showing 6 changed files with 354 additions and 133 deletions.
6 changes: 6 additions & 0 deletions doc/modules/clustering.rst
@@ -195,6 +195,12 @@ k-means++ initialization scheme, which has been implemented in scikit-learn
(generally) distant from each other, leading to provably better results than
random initialization, as shown in the reference.

The algorithm supports sample weights, which can be given by a parameter
``sample_weight``. This allows assigning more weight to some samples when
computing cluster centers and values of inertia. For example, assigning a
weight of 2 to a sample is equivalent to adding a duplicate of that sample
to the dataset :math:`X`.
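
A minimal sketch of this equivalence (hypothetical toy data; fixing ``init``
makes both runs follow identical iterations)::

    import numpy as np
    from sklearn.cluster import KMeans

    X = np.array([[1., 1.], [1., 2.], [10., 10.], [10., 11.]])
    init = np.array([[1., 1.], [10., 10.]])

    # Weighting the first sample by 2 ...
    km_w = KMeans(n_clusters=2, init=init, n_init=1).fit(
        X, sample_weight=[2, 1, 1, 1])
    # ... behaves like duplicating that sample in X.
    km_dup = KMeans(n_clusters=2, init=init, n_init=1).fit(
        np.vstack([X[:1], X]))
    print(np.allclose(km_w.cluster_centers_, km_dup.cluster_centers_))  # True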

A parameter can be given to allow K-means to be run in parallel, called
``n_jobs``. Giving this parameter a positive value uses that many processors
(default: 1). A value of -1 uses all available processors, with -2 using one
98 changes: 59 additions & 39 deletions sklearn/cluster/_k_means.pyx
@@ -1,6 +1,6 @@
# cython: profile=True
# Profiling is enabled by default as the overhead does not seem to be measurable
# on this specific use case.
# Profiling is enabled by default as the overhead does not seem to be
# measurable on this specific use case.

# Author: Peter Prettenhofer <peter.prettenhofer@gmail.com>
# Olivier Grisel <olivier.grisel@ensta.org>
@@ -34,6 +34,7 @@ np.import_array()
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
np.ndarray[floating, ndim=1] sample_weight,
np.ndarray[floating, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
@@ -89,6 +90,7 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
dist *= -2
dist += center_squared_norms[center_idx]
dist += x_squared_norms[sample_idx]
dist *= sample_weight[sample_idx]
if min_dist == -1 or dist < min_dist:
min_dist = dist
labels[sample_idx] = center_idx
@@ -103,7 +105,8 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] sample_weight,
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
np.ndarray[floating, ndim=1] distances):
@@ -141,7 +144,8 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,

for center_idx in range(n_clusters):
center_squared_norms[center_idx] = dot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
n_features, &centers[center_idx, 0], 1,
&centers[center_idx, 0], 1)

for sample_idx in range(n_samples):
min_dist = -1
@@ -154,6 +158,7 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
dist *= -2
dist += center_squared_norms[center_idx]
dist += x_squared_norms[sample_idx]
dist *= sample_weight[sample_idx]
if min_dist == -1 or dist < min_dist:
min_dist = dist
labels[sample_idx] = center_idx
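
In both the dense and CSR loops above, every candidate distance for a given
sample is scaled by the same positive weight, so the argmin (the label) is
unchanged; only the accumulated inertia becomes a weighted sum. A vectorized
NumPy sketch of the same assignment step (not part of the commit):

    import numpy as np

    def assign_labels(X, sample_weight, centers):
        # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, as in the loops above
        x_sq = (X ** 2).sum(axis=1)[:, np.newaxis]
        c_sq = (centers ** 2).sum(axis=1)[np.newaxis, :]
        dist = x_sq - 2 * X @ centers.T + c_sq
        dist *= sample_weight[:, np.newaxis]  # same factor per row
        labels = dist.argmin(axis=1)
        inertia = dist[np.arange(X.shape[0]), labels].sum()
        return labels, inertia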
@@ -167,9 +172,10 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] sample_weight,
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] counts,
np.ndarray[floating, ndim=1] weight_sums,
np.ndarray[INT, ndim=1] nearest_center,
np.ndarray[floating, ndim=1] old_center,
int compute_squared_diff):
@@ -192,7 +198,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
-------
inertia : float
The inertia of the batch prior to centers update, i.e. the sum
of squared distances to the closest center for each sample. This
is the objective function being minimized by the k-means algorithm.
squared_diff : float
@@ -213,29 +219,29 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,

unsigned int sample_idx, center_idx, feature_idx
unsigned int k
int old_count, new_count
DOUBLE old_weight_sum, new_weight_sum
DOUBLE center_diff
DOUBLE squared_diff = 0.0

# move centers to the mean of both old and newly assigned samples
for center_idx in range(n_clusters):
old_count = counts[center_idx]
new_count = old_count
old_weight_sum = weight_sums[center_idx]
new_weight_sum = old_weight_sum

# sum the weights of the samples assigned to this center
for sample_idx in range(n_samples):
if nearest_center[sample_idx] == center_idx:
new_count += 1
new_weight_sum += sample_weight[sample_idx]

if new_count == old_count:
if new_weight_sum == old_weight_sum:
# no new sample: leave this center as it stands
continue

# rescale the old center to reflect its previous accumulated weight
# with regard to the new data that will be incrementally contributed
if compute_squared_diff:
old_center[:] = centers[center_idx]
centers[center_idx] *= old_count
centers[center_idx] *= old_weight_sum

# iterate over samples assigned to this cluster to move the center
# location by inplace summation
@@ -250,12 +256,12 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
centers[center_idx, X_indices[k]] += X_data[k]

# inplace rescale center with updated weight sum
if new_count > old_count:
if new_weight_sum > old_weight_sum:
# update the weight sum statistics for this center
counts[center_idx] = new_count
weight_sums[center_idx] = new_weight_sum

# re-scale the updated center with the new total weight
centers[center_idx] /= new_count
centers[center_idx] /= new_weight_sum

# update the incremental computation of the squared total
# centers position change
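
In effect each center is an incrementally maintained weighted mean: rescale
the stored center by its accumulated weight, add the weighted batch members,
then divide by the new weight sum. A dense NumPy sketch of this per-center
step (hypothetical helper; the real code updates CSR data in place):

    import numpy as np

    def update_center(center, weight_sum, X_assigned, w_assigned):
        # X_assigned, w_assigned: the batch samples (and their weights)
        # assigned to this center.
        new_weight_sum = weight_sum + w_assigned.sum()
        if new_weight_sum == weight_sum:
            return center, weight_sum  # no weight added: leave unchanged
        center = (center * weight_sum
                  + (X_assigned * w_assigned[:, np.newaxis]).sum(axis=0))
        return center / new_weight_sum, new_weight_sum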
@@ -271,6 +277,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_dense(np.ndarray[floating, ndim=2] X,
np.ndarray[floating, ndim=1] sample_weight,
np.ndarray[INT, ndim=1] labels, int n_clusters,
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm
@@ -281,6 +288,9 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,
----------
X : array-like, shape (n_samples, n_features)
sample_weight : array-like, shape (n_samples,)
The weights for each observation in X.
labels : array of integers, shape (n_samples)
Current label assignment
@@ -301,13 +311,16 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,
n_features = X.shape[1]
cdef int i, j, c
cdef np.ndarray[floating, ndim=2] centers
if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)
cdef np.ndarray[floating, ndim=1] weight_in_cluster

dtype = np.float32 if floating is float else np.float64
centers = np.zeros((n_clusters, n_features), dtype=dtype)
weight_in_cluster = np.zeros((n_clusters,), dtype=dtype)

n_samples_in_cluster = np.bincount(labels, minlength=n_clusters)
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
for i in range(n_samples):
c = labels[i]
weight_in_cluster[c] += sample_weight[i]
empty_clusters = np.where(weight_in_cluster == 0)[0]
# maybe also relocate small clusters?

if len(empty_clusters):
@@ -316,23 +329,25 @@ def _centers_dense(np.ndarray[floating, ndim=2] X,

for i, cluster_id in enumerate(empty_clusters):
# XXX two relocated clusters could be close to each other
new_center = X[far_from_centers[i]]
far_index = far_from_centers[i]
new_center = X[far_index]
centers[cluster_id] = new_center
n_samples_in_cluster[cluster_id] = 1
weight_in_cluster[cluster_id] = sample_weight[far_index]

for i in range(n_samples):
for j in range(n_features):
centers[labels[i], j] += X[i, j]
centers[labels[i], j] += X[i, j] * sample_weight[i]

centers /= n_samples_in_cluster[:, np.newaxis]
centers /= weight_in_cluster[:, np.newaxis]

return centers
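
Setting aside the empty-cluster relocation, the weighted M step reduces to a
weighted mean per cluster. A compact NumPy sketch (not the commit's code):

    import numpy as np

    def centers_dense(X, sample_weight, labels, n_clusters):
        # total weight assigned to each cluster
        weight_in_cluster = np.bincount(labels, weights=sample_weight,
                                        minlength=n_clusters)
        centers = np.zeros((n_clusters, X.shape[1]))
        # accumulate weighted samples into their assigned centers
        np.add.at(centers, labels, X * sample_weight[:, np.newaxis])
        return centers / weight_in_cluster[:, np.newaxis]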


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
def _centers_sparse(X, np.ndarray[floating, ndim=1] sample_weight,
np.ndarray[INT, ndim=1] labels, n_clusters,
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm
@@ -342,6 +357,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
----------
X : scipy.sparse.csr_matrix, shape (n_samples, n_features)
sample_weight : array-like, shape (n_samples,)
The weights for each observation in X.
labels : array of integers, shape (n_samples)
Current label assignment
@@ -356,7 +374,9 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
centers : array, shape (n_clusters, n_features)
The resulting centers
"""
cdef int n_features = X.shape[1]
cdef int n_samples, n_features
n_samples = X.shape[0]
n_features = X.shape[1]
cdef int curr_label

cdef np.ndarray[floating, ndim=1] data = X.data
@@ -365,17 +385,17 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,

cdef np.ndarray[floating, ndim=2, mode="c"] centers
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
np.bincount(labels, minlength=n_clusters)
cdef np.ndarray[floating, ndim=1] weight_in_cluster
dtype = np.float32 if floating is float else np.float64
centers = np.zeros((n_clusters, n_features), dtype=dtype)
weight_in_cluster = np.zeros((n_clusters,), dtype=dtype)
for i in range(n_samples):
c = labels[i]
weight_in_cluster[c] += sample_weight[i]
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
np.where(n_samples_in_cluster == 0)[0]
np.where(weight_in_cluster == 0)[0]
cdef int n_empty_clusters = empty_clusters.shape[0]

if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)

# maybe also relocate small clusters?

if n_empty_clusters > 0:
@@ -386,14 +406,14 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
assign_rows_csr(X, far_from_centers, empty_clusters, centers)

for i in range(n_empty_clusters):
n_samples_in_cluster[empty_clusters[i]] = 1
weight_in_cluster[empty_clusters[i]] = 1

for i in range(labels.shape[0]):
curr_label = labels[i]
for ind in range(indptr[i], indptr[i + 1]):
j = indices[ind]
centers[curr_label, j] += data[ind]
centers[curr_label, j] += data[ind] * sample_weight[i]

centers /= n_samples_in_cluster[:, np.newaxis]
centers /= weight_in_cluster[:, np.newaxis]

return centers
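
The CSR variant computes the same weighted means; a sketch using an
indicator-matrix product in place of the explicit loop (empty-cluster
relocation again omitted):

    import numpy as np
    from scipy import sparse

    def centers_sparse(X, sample_weight, labels, n_clusters):
        n_samples = X.shape[0]
        Xw = X.multiply(sample_weight[:, np.newaxis]).tocsr()  # row-scale
        # indicator[c, i] == 1 iff sample i belongs to cluster c
        indicator = sparse.csr_matrix(
            (np.ones(n_samples), (labels, np.arange(n_samples))),
            shape=(n_clusters, n_samples))
        weight_in_cluster = np.bincount(labels, weights=sample_weight,
                                        minlength=n_clusters)
        return (indicator @ Xw).toarray() / weight_in_cluster[:, np.newaxis]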
15 changes: 11 additions & 4 deletions sklearn/cluster/_k_means_elkan.pyx
@@ -103,7 +103,9 @@ cdef update_labels_distances_inplace(
upper_bounds[sample] = d_c


def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_,
np.ndarray[floating, ndim=1, mode='c'] sample_weight,
int n_clusters,
np.ndarray[floating, ndim=2, mode='c'] init,
float tol=1e-4, int max_iter=30, verbose=False):
"""Run Elkan's k-means.
@@ -112,6 +114,9 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
----------
X_ : nd-array, shape (n_samples, n_features)
sample_weight : nd-array, shape (n_samples,)
The weights for each observation in X.
n_clusters : int
Number of clusters to find.
@@ -133,7 +138,7 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
else:
dtype = np.float64

#initialize
# initialize
cdef np.ndarray[floating, ndim=2, mode='c'] centers_ = init
cdef floating* centers_p = <floating*>centers_.data
cdef floating* X_p = <floating*>X_.data
@@ -219,7 +224,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
print("end inner loop")

# compute new centers
new_centers = _centers_dense(X_, labels_, n_clusters, upper_bounds_)
new_centers = _centers_dense(X_, sample_weight, labels_,
n_clusters, upper_bounds_)
bounds_tight[:] = 0

# compute distance each center moved
@@ -237,7 +243,8 @@ def k_means_elkan(np.ndarray[floating, ndim=2, mode='c'] X_, int n_clusters,
center_half_distances = euclidean_distances(centers_) / 2.
if verbose:
print('Iteration %i, inertia %s'
% (iteration, np.sum((X_ - centers_[labels]) ** 2)))
% (iteration, np.sum((X_ - centers_[labels]) ** 2 *
sample_weight[:, np.newaxis])))
center_shift_total = np.sum(center_shift)
if center_shift_total ** 2 < tol:
if verbose:
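
The quantity logged above is the weighted inertia,
sum_i w_i * ||x_i - c_{label_i}||^2. A small standalone helper
(not part of the commit):

    import numpy as np

    def weighted_inertia(X, sample_weight, centers, labels):
        diffs = X - centers[labels]
        # row-wise squared norms, then a weighted sum
        return np.sum(sample_weight * np.einsum('ij,ij->i', diffs, diffs))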
