diff --git a/doc/whats_new/v0.21.rst b/doc/whats_new/v0.21.rst index 0525a40467e74..579a7ff492b94 100644 --- a/doc/whats_new/v0.21.rst +++ b/doc/whats_new/v0.21.rst @@ -315,8 +315,15 @@ Support for Python 3.4 and below has been officially dropped. - |API| The parameter ``labels`` in :func:`metrics.hamming_loss` is deprecated in version 0.21 and will be removed in version 0.23. - :issue:`10580` by :user:`Reshama Shaikh ` and `Sandra - Mitrovic `. + :issue:`10580` by :user:`Reshama Shaikh ` and + `Sandra Mitrovic `. + +- |Fix| The function :func:`euclidean_distances`, and therefore the functions + :func:`pairwise_distances` and :func:`pairwise_distances_chunked` with + ``metric=euclidean``, suffered from numerical precision issues. Precision has + been increased at the cost of a small drop in performance in some cases. + :issue:`13410` by :user:`Celelibi ` and + :user:`Jérémie du Boisberranger ` :mod:`sklearn.mixture` ...................... diff --git a/sklearn/metrics/_safe_euclidean_sparse.pyx b/sklearn/metrics/_safe_euclidean_sparse.pyx new file mode 100644 index 0000000000000..2085508f66611 --- /dev/null +++ b/sklearn/metrics/_safe_euclidean_sparse.pyx @@ -0,0 +1,70 @@ +#cython: language_level=3 +#cython: boundscheck=False, cdivision=True, wraparound=False + +import numpy as np +cimport numpy as np +from cython cimport floating +from libc.math cimport fmax + +np.import_array() + +ctypedef fused INT: + np.int32_t + np.int64_t + + +def _euclidean_sparse_dense_exact(floating[::1] X_data, + INT[::1] X_indices, + INT[::1] X_indptr, + np.ndarray[floating, ndim=2] Y, + floating[::1] y_squared_norms): + """Euclidean distances between X (CSR matrix) and Y (dense).""" + cdef: + int n_samples_X = X_indptr.shape[0] - 1 + int n_samples_Y = Y.shape[0] + int n_features = Y.shape[1] + int incy = Y.strides[1] / Y.itemsize + + int i, j + + floating[:, ::1] D = np.empty((n_samples_X, n_samples_Y), Y.dtype) + + for i in range(n_samples_X): + for j in range(n_samples_Y): + D[i, j] = _euclidean_sparse_dense_exact_1d( + &X_data[X_indptr[i]], + &X_indices[X_indptr[i]], + X_indptr[i + 1] - X_indptr[i], + &Y[j, 0], + incy, + y_squared_norms[j]) + + return np.asarray(D) + + +cdef floating _euclidean_sparse_dense_exact_1d(floating *x_data, + INT *x_indices, + int x_nnz, + floating *y, + int incy, + floating y_squared_norm) nogil: + """Euclidean distance between vectors x sparse and y dense""" + cdef: + int i + floating yi + floating tmp = 0.0 + floating result = 0.0 + floating partial_y_squared_norm = 0.0 + + # Split the loop to avoid unsafe compiler auto optimizations + for i in range(x_nnz): + yi = y[x_indices[i] * incy] + partial_y_squared_norm += yi * yi + + for i in range(x_nnz): + tmp = x_data[i] - y[x_indices[i] * incy] + result += tmp * tmp + + result += y_squared_norm - partial_y_squared_norm + + return fmax(result, 0) diff --git a/sklearn/metrics/pairwise.py b/sklearn/metrics/pairwise.py index 315e3c8460b06..502127912f02a 100644 --- a/sklearn/metrics/pairwise.py +++ b/sklearn/metrics/pairwise.py @@ -30,6 +30,8 @@ from ..utils._joblib import effective_n_jobs from .pairwise_fast import _chi2_kernel_fast, _sparse_manhattan +from .pairwise_fast import _euclidean_dense_dense_exact +from ._safe_euclidean_sparse import _euclidean_sparse_dense_exact # Utility Functions @@ -168,20 +170,6 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, Considering the rows of X (and Y=X) as vectors, compute the distance matrix between each pair of vectors. - For efficiency reasons, the euclidean distance between a pair of row - vector x and y is computed as:: - - dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) - - This formulation has two advantages over other ways of computing distances. - First, it is computationally efficient when dealing with sparse data. - Second, if one argument varies but the other remains unchanged, then - `dot(x, x)` and/or `dot(y, y)` can be pre-computed. - - However, this is not the most precise way of doing this computation, and - the distance matrix returned by this function may not be exactly - symmetric as required by, e.g., ``scipy.spatial.distance`` functions. - Read more in the :ref:`User Guide `. Parameters @@ -193,6 +181,7 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, Y_norm_squared : array-like, shape (n_samples_2, ), optional Pre-computed dot-products of vectors in Y (e.g., ``(Y**2).sum(axis=1)``) + May be ignored in some cases, see the note below. squared : boolean, optional Return squared Euclidean distances. @@ -200,10 +189,35 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, X_norm_squared : array-like, shape = [n_samples_1], optional Pre-computed dot-products of vectors in X (e.g., ``(X**2).sum(axis=1)``) + May be ignored in some cases, see the note below. Returns ------- - distances : {array, sparse matrix}, shape (n_samples_1, n_samples_2) + distances : array, shape (n_samples_1, n_samples_2) + + Notes + ----- + When ``n_features > 16``, the euclidean distance between a pair of row + vector x and y is computed as:: + + dist(x, y) = sqrt(dot(x, x) - 2 * dot(x, y) + dot(y, y)) + + This formulation is computationaly more efficient than the usual one and + can benefit from pre-computed ``dot(x, x)`` and/or ``dot(y, y)``. When the + input is stored in float32, computations are done by first upcasting ``X`` + and ``Y`` to float64 (by chunks to limit memory usage). In that case, + ``X_norm_squared`` and ``Y_norm_squared`` are ignored and computed based on + upcast ``X`` and ``Y`` to keep good precision. + + However, this is not the most precise way of doing this computation, and + the distance matrix returned by this function may not be exactly + symmetric as required by, e.g., ``scipy.spatial.distance`` functions. + + When ``n_features <= 16``, the previous method is not as efficient and is + more likely to suffer from numerical instabilities, so the euclidean + distance between a pair of row vector x and y is computed as:: + + dist(x, y) = sqrt(dot(x - y)) Examples -------- @@ -224,39 +238,169 @@ def euclidean_distances(X, Y=None, Y_norm_squared=None, squared=False, """ X, Y = check_pairwise_arrays(X, Y) + XX, YY = _check_norms(X, Y, X_norm_squared, Y_norm_squared) + + n_features = X.shape[1] + + # For n_features > 16 we use the 'fast 'method to compute the euclidean + # distance, i.e. d(x,y)² = ||x||² + ||y||² - 2 * x.y + # It's faster but less precise. + if n_features > 16: + + # To minimize precision issues with float32, we compute the distance + # matrix on chunks of X and Y upcast to float64 + if X.dtype == np.float32: + distances = _euclidean_distances_upcast_fast(X, XX, Y, YY) + + # if dtype is already float64, no need to chunk and upcast + else: + distances = - 2 * safe_sparse_dot(X, Y.T, dense_output=True) + distances += XX[:, np.newaxis] + distances += YY[np.newaxis, :] + + # ensure no negative squared distance. + np.maximum(distances, 0, out=distances) + + # For n_features <= 16, we use the 'exact' method, i.e. the usual method, + # d(x,y)² = ||x - y||². + else: + + # Euclidean distance between 2 sparse vectors is very slow. It's much + # faster to densify one. We densify the smaller one for lower memory + # usage. + if issparse(X) and issparse(Y): + if Y.shape[0] > X.shape[0]: + X = X.toarray() + else: + Y = Y.toarray() + + if issparse(X): + distances = _euclidean_sparse_dense_exact( + X.data, X.indices, X.indptr, Y, YY) + elif issparse(Y): + distances = _euclidean_sparse_dense_exact( + Y.data, Y.indices, Y.indptr, X, XX).T + else: + distances = _euclidean_dense_dense_exact(X, Y) + + # Ensure that distances between vectors and themselves are set to 0.0. + # This may not be the case due to floating point rounding errors. + if X is Y: + np.fill_diagonal(distances, 0) + + return distances if squared else np.sqrt(distances, out=distances) + + +def _check_norms(X, Y=None, X_norm_squared=None, Y_norm_squared=None): + n_features = X.shape[1] + + # In this case, we compute euclidean distances by upcasting to float64. + # Computing norms on float32 and then upcast to float64 loses precision. + # We either accept float64 precomputed norms or delay their computation to + # the moment X (resp. Y) is upcast to float64. + special_case = n_features > 16 and X.dtype == np.float32 + if X_norm_squared is not None: - XX = check_array(X_norm_squared) - if XX.shape == (1, X.shape[0]): - XX = XX.T - elif XX.shape != (X.shape[0], 1): + XX = np.atleast_1d(X_norm_squared).reshape(-1) + if XX.shape != (X.shape[0],): raise ValueError( "Incompatible dimensions for X and X_norm_squared") + if special_case and XX.dtype == np.float32: + XX = None + elif special_case: + XX = None else: - XX = row_norms(X, squared=True)[:, np.newaxis] + XX = row_norms(X, squared=True) if X is Y: # shortcut in the common case euclidean_distances(X, X) - YY = XX.T + YY = XX elif Y_norm_squared is not None: - YY = np.atleast_2d(Y_norm_squared) - - if YY.shape != (1, Y.shape[0]): + YY = np.atleast_1d(Y_norm_squared).reshape(-1) + if YY.shape != (Y.shape[0],): raise ValueError( "Incompatible dimensions for Y and Y_norm_squared") + if special_case and YY.dtype == np.float32: + YY = None + elif special_case: + YY = None else: - YY = row_norms(Y, squared=True)[np.newaxis, :] + YY = row_norms(Y, squared=True) - distances = safe_sparse_dot(X, Y.T, dense_output=True) - distances *= -2 - distances += XX - distances += YY - np.maximum(distances, 0, out=distances) + if not special_case: + XX = XX.astype(X.dtype, copy=False) + YY = YY.astype(Y.dtype, copy=False) - if X is Y: - # Ensure that distances between vectors and themselves are set to 0.0. - # This may not be the case due to floating point rounding errors. - distances.flat[::distances.shape[0] + 1] = 0.0 + return XX, YY - return distances if squared else np.sqrt(distances, out=distances) + +def _euclidean_distances_upcast_fast(X, XX=None, Y=None, YY=None): + """Euclidean distances between X and Y + + Assumes X and Y have float32 dtype. + Assumes XX and YY have float64 dtype or are None. + + X and Y are upcast to float64 by chunks, which size is chosen to limit + memory increase by approximately 10MiB. + """ + n_samples_X = X.shape[0] + n_samples_Y = Y.shape[0] + n_features = X.shape[1] + + distances = np.empty((n_samples_X, n_samples_Y), dtype=np.float32) + + maxmem = 10 * 2**17 # this number of float64 take 10MiB memory. + + x_density = X.getnnz() / np.prod(X.shape) if issparse(X) else 1 + y_density = Y.getnnz() / np.prod(Y.shape) if issparse(Y) else 1 + + # The increase amount of memory is: + # - x_density * chunk_size * n_features (copy of chunk of X) + # - y_density * chunk_size * n_features (copy of chunk of Y) + # - chunk_size * chunk_size (chunk of distance matrix) + # Hence x² + (xd+yd)kx = M, where x=chunk_size, k=n_features, M=maxmem + # xd=x_density and yd=y_density + tmp = (x_density + y_density) * n_features + chunk_size = (-tmp + np.sqrt(tmp**2 + 4 * maxmem)) / 2 + chunk_size = max(int(chunk_size), 1) + + n_samples_X_rem = n_samples_X % chunk_size + n_chunks_X = n_samples_X // chunk_size + (n_samples_X_rem != 0) + n_samples_Y_rem = n_samples_Y % chunk_size + n_chunks_Y = n_samples_Y // chunk_size + (n_samples_Y_rem != 0) + + for i in range(n_chunks_X): + xs = i * chunk_size + xe = xs + (chunk_size if i < n_chunks_X - 1 else n_samples_X_rem) + + X_chunk = X[xs:xe].astype(np.float64) + if XX is None: + XX_chunk = row_norms(X_chunk, squared=True) + else: + XX_chunk = XX[xs:xe] + + for j in range(n_chunks_Y): + ys = j * chunk_size + ye = ys + (chunk_size if j < n_chunks_Y - 1 else n_samples_Y_rem) + + if X is Y and j < i: + # when X is Y the distance matrix is symmetric so we only need + # to compute half of it. + d = distances[ys:ye, xs:xe].T + + else: + Y_chunk = Y[ys:ye].astype(np.float64) + if YY is None: + YY_chunk = row_norms(Y_chunk, squared=True) + else: + YY_chunk = YY[ys:ye] + + d = -2 * safe_sparse_dot(X_chunk, Y_chunk.T, dense_output=True) + d += XX_chunk[:, np.newaxis] + d += YY_chunk[np.newaxis, :] + + distances[xs:xe, ys:ye] = d.astype(np.float32) + + return distances def _argmin_min_reduce(dist, start): diff --git a/sklearn/metrics/pairwise_fast.pyx b/sklearn/metrics/pairwise_fast.pyx index 1a77aad2e6aa1..b09f7d7ab618e 100644 --- a/sklearn/metrics/pairwise_fast.pyx +++ b/sklearn/metrics/pairwise_fast.pyx @@ -8,10 +8,10 @@ # # License: BSD 3 clause -from libc.string cimport memset import numpy as np cimport numpy as np from cython cimport floating +from libc.string cimport memset from ..utils._cython_blas cimport _asum @@ -67,3 +67,48 @@ def _sparse_manhattan(floating[::1] X_data, int[:] X_indices, int[:] X_indptr, row[Y_indices[j]] -= Y_data[j] D[ix, iy] = _asum(n_features, &row[0], 1) + + +def _euclidean_dense_dense_exact(np.ndarray[floating, ndim=2] X, + np.ndarray[floating, ndim=2] Y): + cdef: + int n_samples_X = X.shape[0] + int n_samples_Y = Y.shape[0] + int n_features = X.shape[1] + int incx = X.strides[1] / X.itemsize + int incy = Y.strides[1] / Y.itemsize + + int i, j + + floating[:, ::1] D = np.empty((n_samples_X, n_samples_Y), X.dtype) + + for i in range(n_samples_X): + for j in range(n_samples_Y): + D[i, j] = _euclidean_dense_dense_exact_1d( + &X[i, 0], incx, &Y[j, 0], incy, n_features) + + return np.asarray(D) + + +cdef floating _euclidean_dense_dense_exact_1d(floating *x, + int incx, + floating *y, + int incy, + int n_features) nogil: + """Euclidean distance between x dense and y dense""" + cdef: + int i + floating tmp = 0.0 + floating result = 0.0 + + if incx == incy == 1: + # special case for c contiguous arrays for better vectorization. + for i in range(n_features): + tmp = x[i] - y[i] + result += tmp * tmp + else: + for i in range(n_features): + tmp = x[i * incx] - y[i * incy] + result += tmp * tmp + + return result diff --git a/sklearn/metrics/setup.py b/sklearn/metrics/setup.py index 97175456220cd..c4c9f78ddcd7d 100644 --- a/sklearn/metrics/setup.py +++ b/sklearn/metrics/setup.py @@ -14,6 +14,11 @@ def configuration(parent_package="", top_path=None): config.add_extension("pairwise_fast", sources=["pairwise_fast.pyx"], + libraries=libraries, + extra_compile_args=['-Ofast']) + + config.add_extension("_safe_euclidean_sparse", + sources=["_safe_euclidean_sparse.pyx"], libraries=libraries) config.add_subpackage('tests') diff --git a/sklearn/metrics/tests/test_pairwise.py b/sklearn/metrics/tests/test_pairwise.py index fe52ce0244dd0..b9dccd382df0d 100644 --- a/sklearn/metrics/tests/test_pairwise.py +++ b/sklearn/metrics/tests/test_pairwise.py @@ -540,41 +540,159 @@ def test_pairwise_distances_chunked(): assert_raises(StopIteration, next, gen) -def test_euclidean_distances(): - # Check the pairwise Euclidean distances computation - X = [[0]] - Y = [[1], [2]] +@pytest.mark.parametrize("dim", [16, 64], ids=["dim<32", "dim>32"]) +@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances_known_result(dim, x_array_constr, y_array_constr): + # Check the pairwise Euclidean distances computation on known result + X = x_array_constr([[0]*dim]) + Y = y_array_constr([[1] + [0]*(dim - 1), [0]*(dim - 1) + [2]]) D = euclidean_distances(X, Y) - assert_array_almost_equal(D, [[1., 2.]]) + assert_allclose(D, [[1., 2.]]) - X = csr_matrix(X) - Y = csr_matrix(Y) - D = euclidean_distances(X, Y) - assert_array_almost_equal(D, [[1., 2.]]) +@pytest.mark.parametrize("dim", [16, 64], ids=["dim<32", "dim>32"]) +@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances_with_norms(dim, y_array_constr): + # check that we still get the right answers with {X,Y}_norm_squared + # and that we get a wrong answer with wrong {X,Y}_norm_squared rng = np.random.RandomState(0) - X = rng.random_sample((10, 4)) - Y = rng.random_sample((20, 4)) + X = rng.random_sample((10, dim)) + Y = rng.random_sample((20, dim)) X_norm_sq = (X ** 2).sum(axis=1).reshape(1, -1) Y_norm_sq = (Y ** 2).sum(axis=1).reshape(1, -1) - # check that we still get the right answers with {X,Y}_norm_squared + Y = y_array_constr(Y) + D1 = euclidean_distances(X, Y) D2 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq) D3 = euclidean_distances(X, Y, Y_norm_squared=Y_norm_sq) D4 = euclidean_distances(X, Y, X_norm_squared=X_norm_sq, Y_norm_squared=Y_norm_sq) - assert_array_almost_equal(D2, D1) - assert_array_almost_equal(D3, D1) - assert_array_almost_equal(D4, D1) + assert_allclose(D2, D1) + assert_allclose(D3, D1) + assert_allclose(D4, D1) # check we get the wrong answer with wrong {X,Y}_norm_squared - X_norm_sq *= 0.5 - Y_norm_sq *= 0.5 + # (note that if n_features <= 32 and both X and Y are dense, squared norms + # are not used) wrong_D = euclidean_distances(X, Y, X_norm_squared=np.zeros_like(X_norm_sq), Y_norm_squared=np.zeros_like(Y_norm_sq)) - assert_greater(np.max(np.abs(wrong_D - D1)), .01) + if dim > 32 or issparse(Y): + with pytest.raises(AssertionError): + assert_allclose(wrong_D, D1) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dim", [2, 20], ids=["dim<32", "dim>32"]) +@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances(dtype, dim, x_array_constr, y_array_constr): + # check that euclidean distances gives same result as scipy cdist + # when X and Y != X are provided + rng = np.random.RandomState(0) + X = rng.random_sample((100, dim)).astype(dtype, copy=False) + X[X < 0.8] = 0 + Y = rng.random_sample((10, dim)).astype(dtype, copy=False) + Y[Y < 0.8] = 0 + + expected = cdist(X, Y) + + X = x_array_constr(X) + Y = y_array_constr(Y) + distances = euclidean_distances(X, Y) + + assert_allclose(distances, expected, rtol=1e-6) + assert distances.dtype == dtype + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("dim", [2, 20], ids=["dim<32", "dim>32"]) +@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances_sym(dtype, dim, x_array_constr): + # check that euclidean distances gives same result as scipy pdist + # when only X is provided + rng = np.random.RandomState(0) + X = rng.random_sample((100, dim)).astype(dtype, copy=False) + X[X < 0.8] = 0 + + expected = squareform(pdist(X)) + + X = x_array_constr(X) + distances = euclidean_distances(X) + + assert_allclose(distances, expected, rtol=1e-6) + assert distances.dtype == dtype + + +@pytest.mark.parametrize("x_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +@pytest.mark.parametrize("y_array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances_upcast(x_array_constr, y_array_constr): + # check euclidean distances upcasted when dtype=float32 and dim > 32. + # dimensions chosen to have at least 2 chunks for X and Y. + rng = np.random.RandomState(0) + X = rng.random_sample((800, 2000)).astype(np.float32) + X[X < 0.2] = 0 + Y = rng.random_sample((700, 2000)).astype(np.float32) + Y[Y < 0.2] = 0 + + expected = cdist(X, Y) + + X = x_array_constr(X) + Y = y_array_constr(Y) + distances = euclidean_distances(X, Y) + + assert_allclose(distances, expected, rtol=1e-6) + assert distances.dtype == np.float32 + + +@pytest.mark.parametrize("array_constr", [np.array, csr_matrix], + ids=["dense", "sparse"]) +def test_euclidean_distances_upcast_symmetric(array_constr): + # check euclidean distances upcasted when dtype=float32 and dim > 32, when + # only X is provided. Dimensions chosen to have at least 2 chunks for X. + rng = np.random.RandomState(0) + X = rng.random_sample((800, 2000)).astype(np.float32) + X[X < 0.2] = 0 + + expected = squareform(pdist(X)) + + X = array_constr(X) + distances = euclidean_distances(X) + + assert_allclose(distances, expected, rtol=1e-6) + assert distances.dtype == np.float32 + + +@pytest.mark.parametrize("dtype, x", + [(np.float32, 10000), (np.float64, 100000000)], + ids=["float32", "float64"]) +@pytest.mark.parametrize("dim", [1, 100]) +def test_euclidean_distances_extreme_values(dtype, x, dim): + # check that euclidean distances is correct where 'fast' method wouldn't be + X = np.array([[x + 1] + [0] * (dim - 1)], dtype=dtype) + Y = np.array([[x] + [0] * (dim - 1)], dtype=dtype) + + expected = [[1]] + + distances = euclidean_distances(X, Y) + + if dtype == np.float64 and dim == 100: + # This is expected to fail for float64 and dimension > 32 due to lack + # of precision of the fast method of euclidean distances. + with pytest.raises(AssertionError, match='Not equal to tolerance'): + assert_allclose(distances, expected) + else: + assert_allclose(distances, expected) def test_cosine_distances():