Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse pca patch #5493

Merged
merged 32 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
4943285
Fixes for sparse PCA
Intron7 Jul 6, 2023
7c76ede
added support for 64bit indexing
Intron7 Jul 6, 2023
20ce7ec
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
ba8ed16
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
91eed9a
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
67834aa
pre-commit changes
Intron7 Jul 8, 2023
c7a0324
switched to only import cov_sparse
Intron7 Jul 9, 2023
75885fa
updated `cov_sparse`
Intron7 Jul 9, 2023
00ef242
added `test_cov_sparse`
Intron7 Jul 9, 2023
572942c
pre-commit update
Intron7 Jul 9, 2023
39a95b9
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 10, 2023
5650ce1
made `cov_sparse` private
Intron7 Jul 10, 2023
d08dff6
call `_cov_sparse` from `cov`
Intron7 Jul 10, 2023
25b8563
docstring fix
Intron7 Jul 10, 2023
39ded03
updated test
Intron7 Jul 10, 2023
0b293dc
fixed typo with `self.mean_`
Intron7 Jul 10, 2023
e8df915
added hint to issue #5475
Intron7 Jul 10, 2023
6a421e7
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 11, 2023
0024078
fixes issue number
Intron7 Jul 11, 2023
308e947
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 12, 2023
f5018fa
Merge branch 'rapidsai:branch-23.08' into sparse_pca_patch
Intron7 Jul 13, 2023
106b4c8
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 18, 2023
9a3fb01
Updated Tests
Intron7 Jul 24, 2023
a9155c1
improved csr kernel and added coo kernel
Intron7 Jul 24, 2023
7a58773
Merge branch 'rapidsai:branch-23.08' into sparse_pca_patch
Intron7 Jul 27, 2023
b9d73b3
fixed bug with `_cov_sparse`
Intron7 Jul 27, 2023
f744a6c
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 31, 2023
e24f187
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 31, 2023
a9c76b3
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Aug 1, 2023
b321a51
Revert change to python/cuml/prims/stats/__init__.py .
csadorf Aug 1, 2023
cca3f81
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Aug 1, 2023
1b9a2f7
added reference to cuml issue
Intron7 Aug 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
203 changes: 203 additions & 0 deletions python/cuml/prims/stats/covariance.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,77 @@
}
"""

gramm_kernel_csr = r"""
(const int *indptr, const int *index, {0} *data, int nrows, int ncols, {0} *out) {
int row = blockIdx.x;
int col = threadIdx.x;

if(row >= nrows) return;

int start = indptr[row];
int end = indptr[row + 1];

for (int idx1 = start; idx1 < end; idx1++){
int index1 = index[idx1];
{0} data1 = data[idx1];
for(int idx2 = idx1 + col; idx2 < end; idx2 += blockDim.x){
int index2 = index[idx2];
{0} data2 = data[idx2];
atomicAdd(&out[index1 * ncols + index2], data1 * data2);
}
}
}
"""

gramm_kernel_coo = r"""
(const int *rows, const int *cols, {0} *data, int nnz, int ncols, int nrows, {0} * out) {
int i = blockIdx.x;
if (i >= nnz) return;
int row1 = rows[i];
int col1 = cols[i];
{0} data1 = data[i];
int limit = min(i + nrows, nnz);

for(int j = i + threadIdx.x; j < limit; j += blockDim.x){
if(row1 < rows[j]) return;

if(col1 <= cols[j]){
atomicAdd(&out[col1 * ncols + cols[j]], data1 * data[j]);
}
csadorf marked this conversation as resolved.
Show resolved Hide resolved
}
}
"""

copy_kernel = r"""
({0} *out, int ncols) {
int row = blockIdx.y * blockDim.y + threadIdx.y;
int col = blockIdx.x * blockDim.x + threadIdx.x;

if (row >= ncols || col >= ncols) return;

if (row > col) {
out[row * ncols + col] = out[col * ncols + row];
}
}
"""


def _cov_kernel(dtype):
return cuda_kernel_factory(cov_kernel_str, (dtype,), "cov_kernel")


def _gramm_kernel_csr(dtype):
return cuda_kernel_factory(gramm_kernel_csr, (dtype,), "gramm_kernel_csr")


def _gramm_kernel_coo(dtype):
return cuda_kernel_factory(gramm_kernel_coo, (dtype,), "gramm_kernel_coo")


def _copy_kernel(dtype):
return cuda_kernel_factory(copy_kernel, (dtype,), "copy_kernel")


@cuml.internals.api_return_any()
def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
"""
Expand Down Expand Up @@ -102,6 +168,16 @@ def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
"X and Y must have same shape %s != %s" % (x.shape, y.shape)
)

# Fix for cuml issue #5475 & cupy issue #7699
# addressing problems with sparse matrix multiplication (spGEMM)
if (
x is y
and cupyx.scipy.sparse.issparse(x)
and mean_x is None
and mean_y is None
):
return _cov_sparse(x, return_gram=return_gram, return_mean=return_mean)
csadorf marked this conversation as resolved.
Show resolved Hide resolved

if mean_x is not None and mean_y is not None:
if mean_x.dtype != mean_y.dtype:
raise ValueError(
Expand Down Expand Up @@ -156,3 +232,130 @@ def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
return cov_result, mean_x, mean_y
elif return_gram and return_mean:
return cov_result, gram_matrix, mean_x, mean_y


@cuml.internals.api_return_any()
def _cov_sparse(x, return_gram=False, return_mean=False):
"""
Computes the mean and the covariance of matrix X of
the form Cov(X, X) = E(XX) - E(X)E(X)
csadorf marked this conversation as resolved.
Show resolved Hide resolved

This is a temporary fix for
cuml issue #5475 and cupy issue #7699,
where the operation `x.T.dot(x)` did not work for
larger sparse matrices.

Parameters
----------

x : cupyx.scipy.sparse of size (m, n)
return_gram : boolean (default = False)
If True, gram matrix of the form (1 / n) * X.T.dot(X)
will be returned.
When True, a copy will be created
to store the results of the covariance.
When False, the local gram matrix result
will be overwritten
return_mean: boolean (default = False)
If True, the Maximum Likelihood Estimate used to
calculate the mean of X and X will be returned,
of the form (1 / n) * mean(X) and (1 / n) * mean(X)

Returns
-------

result : cov(X, X) when return_gram and return_mean are False
cov(X, X), gram(X, X) when return_gram is True,
return_mean is False
cov(X, X), mean(X), mean(X) when return_gram is False,
return_mean is True
cov(X, X), gram(X, X), mean(X), mean(X)
when return_gram is True and return_mean is True
"""

gram_matrix = cp.zeros((x.shape[1], x.shape[1]), dtype=x.data.dtype)
if cupyx.scipy.sparse.isspmatrix_csr(x):
block = (128,)
grid = (x.shape[0],)
compute_mean_cov = _gramm_kernel_csr(x.data.dtype)
compute_mean_cov(
grid,
block,
(
x.indptr,
x.indices,
x.data,
x.shape[0],
x.shape[1],
gram_matrix,
),
)

elif cupyx.scipy.sparse.isspmatrix_coo(x):
x.sum_duplicates()
nnz = len(x.row)
block = (128,)
grid = (nnz,)
compute_gram_coo = _gramm_kernel_coo(x.data.dtype)
compute_gram_coo(
grid,
block,
(x.row, x.col, x.data, nnz, x.shape[1], x.shape[0], gram_matrix),
)

else:
x = x.tocsr()
block = (128,)
grid = (math.ceil(x.shape[0] / block[0]),)
compute_mean_cov = _gramm_kernel_csr(x.data.dtype)
compute_mean_cov(
grid,
block,
(
x.indptr,
x.indices,
x.data,
x.shape[0],
x.shape[1],
gram_matrix,
),
)

copy_gram = _copy_kernel(x.data.dtype)
block = (32, 32)
grid = (math.ceil(x.shape[1] / block[0]), math.ceil(x.shape[1] / block[1]))
copy_gram(
grid,
block,
(gram_matrix, x.shape[1]),
)

mean_x = x.sum(axis=0) * (1 / x.shape[0])
gram_matrix *= 1 / x.shape[0]

if return_gram:
cov_result = cp.zeros(
(gram_matrix.shape[0], gram_matrix.shape[0]),
dtype=gram_matrix.dtype,
)
else:
cov_result = gram_matrix

compute_cov = _cov_kernel(x.dtype)

block_size = (32, 32)
grid_size = (math.ceil(gram_matrix.shape[0] / 8),) * 2
compute_cov(
grid_size,
block_size,
(cov_result, gram_matrix, mean_x, mean_x, gram_matrix.shape[0]),
)

if not return_gram and not return_mean:
return cov_result
elif return_gram and not return_mean:
return cov_result, gram_matrix
elif not return_gram and return_mean:
return cov_result, mean_x, mean_x
elif return_gram and return_mean:
return cov_result, gram_matrix, mean_x, mean_x
25 changes: 25 additions & 0 deletions python/cuml/tests/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from cuml.testing.utils import array_equal
from cuml.prims.stats import cov
from cuml.prims.stats.covariance import _cov_sparse
import pytest
from cuml.internals.safe_imports import gpu_only_import

Expand Down Expand Up @@ -43,3 +44,27 @@ def test_cov(nrows, ncols, sparse, dtype):
local_cov = cp.cov(x, rowvar=False, ddof=0)

assert array_equal(cov_result, local_cov, 1e-6, with_sign=True)


@pytest.mark.parametrize("nrows", [1000])
@pytest.mark.parametrize("ncols", [500, 1500])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
@pytest.mark.parametrize("mtype", ["csr", "coo"])
def test_cov_sparse(nrows, ncols, dtype, mtype):

x = cupyx.scipy.sparse.random(
nrows, ncols, density=0.07, format=mtype, dtype=dtype
)
cov_result = _cov_sparse(x, return_mean=True)

# check cov
assert cov_result[0].shape == (ncols, ncols)

x = x.todense()
local_cov = cp.cov(x, rowvar=False, ddof=0)

assert array_equal(cov_result[0], local_cov, 1e-6, with_sign=True)

# check mean
local_mean = x.mean(axis=0)
assert array_equal(cov_result[1], local_mean, 1e-6, with_sign=True)