Skip to content

Commit

Permalink
Adds support and tests for KMeans/MiniBatchKMeans to work with float3…
Browse files Browse the repository at this point in the history
…2 to save memory
  • Loading branch information
Sebastian Saeger authored and yenchenlin committed Jun 3, 2016
1 parent d161bfa commit 9567ef5
Show file tree
Hide file tree
Showing 5 changed files with 369 additions and 53 deletions.
10 changes: 9 additions & 1 deletion doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,12 @@ Enhancements
- The :func: `ignore_warnings` now accept a category argument to ignore only
the warnings of a specified type. By `Thierry Guillemot`_.

- :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans` now works
with ``np.float32`` and ``np.float64`` input data without converting it.
This allows to reduce the memory consumption by using ``np.float32``.
(`#6430 <https://github.com/scikit-learn/scikit-learn/pull/6430>`_)
By `Sebastian Säger`_.

Bug fixes
.........

Expand Down Expand Up @@ -1693,7 +1699,7 @@ List of contributors for release 0.15 by number of commits.
* 4 Alexis Metaireau
* 4 Ignacio Rossi
* 4 Virgile Fritsch
* 4 Sebastian Saeger
* 4 Sebastian Säger
* 4 Ilambharathi Kanniah
* 4 sdenton4
* 4 Robert Layton
Expand Down Expand Up @@ -4189,3 +4195,5 @@ David Huard, Dave Morrill, Ed Schofield, Travis Oliphant, Pearu Peterson.
.. _Sears Merritt: https://github.com/merritts

.. _Wenhua Yang: https://github.com/geekoala

.. _Sebastian Säger: https://github.com/ssaeger
106 changes: 73 additions & 33 deletions sklearn/cluster/_k_means.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import numpy as np
import scipy.sparse as sp
cimport numpy as np
cimport cython
from cython cimport floating

from ..utils.extmath import norm
from sklearn.utils.sparsefuncs_fast import assign_rows_csr
Expand All @@ -23,18 +24,19 @@ ctypedef np.int32_t INT

cdef extern from "cblas.h":
double ddot "cblas_ddot"(int N, double *X, int incX, double *Y, int incY)
float sdot "cblas_sdot"(int N, float *X, int incX, float *Y, int incY)

np.import_array()


@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
cpdef DOUBLE _assign_labels_array(np.ndarray[floating, ndim=2] X,
np.ndarray[floating, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""Compute label assignment and inertia for a dense array
Return the inertia (sum of squared distances to the centers).
Expand All @@ -43,33 +45,52 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
unsigned int n_clusters = centers.shape[0]
unsigned int n_features = centers.shape[1]
unsigned int n_samples = X.shape[0]
unsigned int x_stride = X.strides[1] / sizeof(DOUBLE)
unsigned int center_stride = centers.strides[1] / sizeof(DOUBLE)
unsigned int x_stride
unsigned int center_stride
unsigned int sample_idx, center_idx, feature_idx
unsigned int store_distances = 0
unsigned int k
np.ndarray[floating, ndim=1] center_squared_norms
# the following variables are always double cause make them floating
# does not save any memory, but makes the code much bigger
DOUBLE inertia = 0.0
DOUBLE min_dist
DOUBLE dist
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
n_clusters, dtype=np.float64)

if floating is float:
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
x_stride = X.strides[1] / sizeof(float)
center_stride = centers.strides[1] / sizeof(float)
else:
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)
x_stride = X.strides[1] / sizeof(DOUBLE)
center_stride = centers.strides[1] / sizeof(DOUBLE)

if n_samples == distances.shape[0]:
store_distances = 1

for center_idx in range(n_clusters):
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)
if floating is float:
center_squared_norms[center_idx] = sdot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)
else:
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], center_stride,
&centers[center_idx, 0], center_stride)

for sample_idx in range(n_samples):
min_dist = -1
for center_idx in range(n_clusters):
dist = 0.0
# hardcoded: minimize euclidean distance to cluster center:
# ||a - b||^2 = ||a||^2 + ||b||^2 -2 <a, b>
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
if floating is float:
dist += sdot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
else:
dist += ddot(n_features, &X[sample_idx, 0], x_stride,
&centers[center_idx, 0], center_stride)
dist *= -2
dist += center_squared_norms[center_idx]
dist += x_squared_norms[sample_idx]
Expand All @@ -87,16 +108,16 @@ cpdef DOUBLE _assign_labels_array(np.ndarray[DOUBLE, ndim=2] X,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
cpdef DOUBLE _assign_labels_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] labels,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""Compute label assignment and inertia for a CSR input
Return the inertia (sum of squared distances to the centers).
"""
cdef:
np.ndarray[DOUBLE, ndim=1] X_data = X.data
np.ndarray[floating, ndim=1] X_data = X.data
np.ndarray[INT, ndim=1] X_indices = X.indices
np.ndarray[INT, ndim=1] X_indptr = X.indptr
unsigned int n_clusters = centers.shape[0]
Expand All @@ -105,18 +126,28 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
unsigned int store_distances = 0
unsigned int sample_idx, center_idx, feature_idx
unsigned int k
np.ndarray[floating, ndim=1] center_squared_norms
# the following variables are always double cause make them floating
# does not save any memory, but makes the code much bigger
DOUBLE inertia = 0.0
DOUBLE min_dist
DOUBLE dist
np.ndarray[DOUBLE, ndim=1] center_squared_norms = np.zeros(
n_clusters, dtype=np.float64)

if floating is float:
center_squared_norms = np.zeros(n_clusters, dtype=np.float32)
else:
center_squared_norms = np.zeros(n_clusters, dtype=np.float64)

if n_samples == distances.shape[0]:
store_distances = 1

for center_idx in range(n_clusters):
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
if floating is float:
center_squared_norms[center_idx] = sdot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)
else:
center_squared_norms[center_idx] = ddot(
n_features, &centers[center_idx, 0], 1, &centers[center_idx, 0], 1)

for sample_idx in range(n_samples):
min_dist = -1
Expand All @@ -142,18 +173,18 @@ cpdef DOUBLE _assign_labels_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
np.ndarray[DOUBLE, ndim=2] centers,
def _mini_batch_update_csr(X, np.ndarray[floating, ndim=1] x_squared_norms,
np.ndarray[floating, ndim=2] centers,
np.ndarray[INT, ndim=1] counts,
np.ndarray[INT, ndim=1] nearest_center,
np.ndarray[DOUBLE, ndim=1] old_center,
np.ndarray[floating, ndim=1] old_center,
int compute_squared_diff):
"""Incremental update of the centers for sparse MiniBatchKMeans.
Parameters
----------
X: CSR matrix, dtype float64
X: CSR matrix, dtype float
The complete (pre allocated) training set as a CSR matrix.
centers: array, shape (n_clusters, n_features)
Expand All @@ -179,7 +210,7 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
of the algorithm.
"""
cdef:
np.ndarray[DOUBLE, ndim=1] X_data = X.data
np.ndarray[floating, ndim=1] X_data = X.data
np.ndarray[int, ndim=1] X_indices = X.indices
np.ndarray[int, ndim=1] X_indptr = X.indptr
unsigned int n_samples = X.shape[0]
Expand Down Expand Up @@ -245,9 +276,9 @@ def _mini_batch_update_csr(X, np.ndarray[DOUBLE, ndim=1] x_squared_norms,
@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
def _centers_dense(np.ndarray[floating, ndim=2] X,
np.ndarray[INT, ndim=1] labels, int n_clusters,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm
Computation of cluster centers / means.
Expand Down Expand Up @@ -275,7 +306,12 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
n_samples = X.shape[0]
n_features = X.shape[1]
cdef int i, j, c
cdef np.ndarray[DOUBLE, ndim=2] centers = np.zeros((n_clusters, n_features))
cdef np.ndarray[floating, ndim=2] centers
if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)

n_samples_in_cluster = bincount(labels, minlength=n_clusters)
empty_clusters = np.where(n_samples_in_cluster == 0)[0]
# maybe also relocate small clusters?
Expand Down Expand Up @@ -303,7 +339,7 @@ def _centers_dense(np.ndarray[DOUBLE, ndim=2] X,
@cython.wraparound(False)
@cython.cdivision(True)
def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
np.ndarray[DOUBLE, ndim=1] distances):
np.ndarray[floating, ndim=1] distances):
"""M step of the K-means EM algorithm
Computation of cluster centers / means.
Expand All @@ -329,19 +365,23 @@ def _centers_sparse(X, np.ndarray[INT, ndim=1] labels, n_clusters,
cdef int n_features = X.shape[1]
cdef int curr_label

cdef np.ndarray[DOUBLE, ndim=1] data = X.data
cdef np.ndarray[floating, ndim=1] data = X.data
cdef np.ndarray[int, ndim=1] indices = X.indices
cdef np.ndarray[int, ndim=1] indptr = X.indptr

cdef np.ndarray[DOUBLE, ndim=2, mode="c"] centers = \
np.zeros((n_clusters, n_features))
cdef np.ndarray[floating, ndim=2, mode="c"] centers
cdef np.ndarray[np.npy_intp, ndim=1] far_from_centers
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] n_samples_in_cluster = \
bincount(labels, minlength=n_clusters)
cdef np.ndarray[np.npy_intp, ndim=1, mode="c"] empty_clusters = \
np.where(n_samples_in_cluster == 0)[0]
cdef int n_empty_clusters = empty_clusters.shape[0]

if floating is float:
centers = np.zeros((n_clusters, n_features), dtype=np.float32)
else:
centers = np.zeros((n_clusters, n_features), dtype=np.float64)

# maybe also relocate small clusters?

if n_empty_clusters > 0:
Expand Down

0 comments on commit 9567ef5

Please sign in to comment.