
Sparse pca patch #5493

Merged: 32 commits from sparse_pca_patch into branch-23.08, Aug 2, 2023

Commits
4943285
Fixes for sparse PCA
Intron7 Jul 6, 2023
7c76ede
added support for 64bit indexing
Intron7 Jul 6, 2023
20ce7ec
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
ba8ed16
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
91eed9a
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 8, 2023
67834aa
pre-commit changes
Intron7 Jul 8, 2023
c7a0324
switched to only import cov_sparse
Intron7 Jul 9, 2023
75885fa
updated `cov_sparse`
Intron7 Jul 9, 2023
00ef242
added `test_cov_sparse`
Intron7 Jul 9, 2023
572942c
pre-commit update
Intron7 Jul 9, 2023
39a95b9
Update python/cuml/prims/stats/covariance.py
Intron7 Jul 10, 2023
5650ce1
made `cov_sparse` private
Intron7 Jul 10, 2023
d08dff6
call `_cov_sparse` from `cov`
Intron7 Jul 10, 2023
25b8563
docstring fix
Intron7 Jul 10, 2023
39ded03
updated test
Intron7 Jul 10, 2023
0b293dc
fixed typo with `self.mean_`
Intron7 Jul 10, 2023
e8df915
added hint to issue #5475
Intron7 Jul 10, 2023
6a421e7
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 11, 2023
0024078
fixes issue number
Intron7 Jul 11, 2023
308e947
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 12, 2023
f5018fa
Merge branch 'rapidsai:branch-23.08' into sparse_pca_patch
Intron7 Jul 13, 2023
106b4c8
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 18, 2023
9a3fb01
Updated Tests
Intron7 Jul 24, 2023
a9155c1
improved csr kernel and added coo kernel
Intron7 Jul 24, 2023
7a58773
Merge branch 'rapidsai:branch-23.08' into sparse_pca_patch
Intron7 Jul 27, 2023
b9d73b3
fixed bug with `_cov_sparse`
Intron7 Jul 27, 2023
f744a6c
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 31, 2023
e24f187
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Jul 31, 2023
a9c76b3
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Aug 1, 2023
b321a51
Revert change to python/cuml/prims/stats/__init__.py .
csadorf Aug 1, 2023
cca3f81
Merge branch 'branch-23.08' into sparse_pca_patch
Intron7 Aug 1, 2023
1b9a2f7
added reference to cuml issue
Intron7 Aug 2, 2023
2 changes: 1 addition & 1 deletion python/cuml/prims/stats/__init__.py
@@ -1,5 +1,5 @@
#
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
131 changes: 131 additions & 0 deletions python/cuml/prims/stats/covariance.py
@@ -36,11 +36,42 @@
}
"""

mean_cov_kernel_str = r"""
(const int *indptr, const int *index, {0} *data, int nrows, int ncols, {0} *out, {0} *mean) {
int row = blockDim.x * blockIdx.x + threadIdx.x;
if(row >= nrows) return;
int start_idx = indptr[row];
int stop_idx = indptr[row+1];

for(int idx = start_idx; idx < stop_idx; idx++){
cjnolet (Member), Jul 18, 2023:
I have some concerns about this implementation, but the good news is that I think we have some options that are fairly trivial to fix. If you think about the way SIMT architectures like GPUs work at the hardware level, each warp (a grouping of 32 threads) is only able to execute a single instruction at a time. If two threads within a warp need to execute different instructions, the rest of the threads need to stall to wait for those instructions, even if they aren't executing anything. Because of this, we try to design our kernels so that the threads within each warp are 1) able to do a uniform amount of work, and 2) able to execute the same instructions as much as possible. Things like atomics and conditional branching can have an impact on this, which is called warp divergence.

The degree distributions (number of columns for each row) are almost never uniform and are most often highly skewed, sometimes even by power laws. Because of this, you cannot expect good performance by simply having each thread loop through the columns within each row. Sometimes folks perform a permutation of the matrix in order to sort the rows by their degree distributions. This can help a little, but it's not a feasible solution here because we can't afford to copy the data.

The other piece here is the atomics: they are expensive, and they also cause the warps to diverge because the amount of time each atomic takes to execute is non-deterministic and based on the number of competing concurrent writes. These collisions are going to compound with the means, and I would highly suggest removing the fused mean computation by using cupy to compute the mean instead.

Memory reads are also impacted by this model, because with each thread reading sequential memory locations from the sparse arrays, you aren't able to benefit from coalescing within each warp, since the threads won't be reading from sequential locations on each instruction cycle.

For CSR matrices, an efficient way to do this would be to schedule some number of warps per block (let's start with somewhere between 1 and 8) and have each warp work on its own row at a time. A block that contains multiple warps will need to wait for any straggler warps, but 1 warp per block could end up causing load-balancing issues. To get a little more intricate, we could perform a differencing of the indptr array, which would give us the degree of each row, and then perform a couple of kernel launches with different numbers of warps per block to make sure we're keeping the warps uniform (enough) for good performance. For a first pass, though, we can skip launching multiple kernels and just find a good block size that yields reasonable performance on power-law graphs.
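The indptr differencing mentioned above is a one-liner on the host. A small numpy sketch (an illustrative stand-in for the cupy arrays used in this PR; the bucketing threshold is arbitrary):

```python
import numpy as np

# Hand-built CSR row-pointer array for a 4-row matrix:
# row 1 is empty, row 2 holds 4 nonzeros.
indptr = np.array([0, 3, 3, 7, 8])

# Differencing indptr yields the degree (nonzero count) of each row.
degrees = np.diff(indptr)
print(degrees)  # -> [3 0 4 1]

# Rows could then be bucketed by degree so each kernel launch sees
# warps with roughly uniform work.
heavy_rows = np.flatnonzero(degrees > 3)
print(heavy_rows)  # -> [2]
```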

For COO: this is the easiest case, since we can essentially compute the output Gram matrix embarrassingly parallel: map each thread to an edge in the array and perform your atomicAdd.
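A CPU sketch of that COO scheme, with `np.add.at` standing in for `atomicAdd` and one conceptual thread per pair of entries sharing a row (illustrative only; the real kernel is a CUDA RawKernel on the GPU):

```python
import numpy as np

# COO triplets of a 3x4 sparse matrix (one "edge" per stored entry).
rows = np.array([0, 0, 1, 2, 2])
cols = np.array([1, 3, 2, 0, 3])
vals = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
ncols = 4

# Every pair of entries sharing a row contributes v_a * v_b to
# gram[col_a, col_b]; np.add.at mimics the kernel's atomicAdd.
gram = np.zeros((ncols, ncols))
for a in range(len(vals)):
    for b in range(len(vals)):
        if rows[a] == rows[b]:
            np.add.at(gram, (cols[a], cols[b]), vals[a] * vals[b])

# Dense reference: the accumulated gram must equal X.T @ X.
x = np.zeros((3, ncols))
x[rows, cols] = vals
assert np.allclose(gram, x.T @ x)
```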

I would also make sure to test the performance of this implementation with a power-law graph. You can use the RMAT generator in pylibraft to generate such a graph. The problem with cupy's sparse generator is that the resulting arrays will have a uniform degree distribution, and thus will not match real-world sparse datasets. Power-law is a worst case, so if we work well on those, uniform degree distributions will just yield better perf.

Further, since we are replacing a highly optimized primitive from cuSPARSE here, we should do our due diligence and gather some benchmarks to make sure we aren't introducing any significant regressions in the meantime.

Intron7 (Contributor, Author) replied:

Hey Corey, thank you so much for your insights and tips. I worked on new kernels for both COO and CSR. In my testing (so far) these are much faster than the cupy x.T.dot(x) versions, at least for matrix sizes where those still worked. I kept the reduced atomicAdds so that only the upper half of the matrix gets filled in. I'm currently getting the power-law graph benchmarks done.
So far, for cpx.scipy.sparse.random(100000, 2000, density=0.05, format="csr/coo", dtype=cp.float64, random_state=42), I'm going from around 800 ms for x.T.dot(x) to 8-10 ms for the RawKernels. For my real-world single-cell data they also perform really well. I think that 2-3 warps might be the most performant version.

Intron7 (Contributor, Author):

@cjnolet can we assume that the COO matrix is sorted?

cjnolet (Member):

@Intron7 I don't know that we can assume that with cupy/scipy. But if we use an element-wise kernel, we shouldn't necessarily have to assume that, should we?

Intron7 (Contributor, Author):

Dear @cjnolet,
here are the performance numbers you asked for. I tested my raw kernels with multiple batch sizes against the standard SpGEMM algorithms from CuPy. I ran each function 10 times and averaged the runtimes (in ms). The coo_kernel results include the x.sum_duplicates() runtime. For 50000000 edges and more, the standard libraries stop working. I updated the kernels in the branch.

Testing for 50000 edges
(51073, 16382)
csr 10.99555492401123
coo 17.003202438354492
csr_kernel 32 11.6835355758667
csr_kernel 64 9.113883972167969
csr_kernel 128 8.887648582458496
csr_kernel 256 36.73577308654785
csr_kernel 512 54.798269271850586
csr_kernel 1024 55.03363609313965
coo_kernel 32 61.06009483337402
coo_kernel 64 51.66192054748535
coo_kernel 128 61.94412708282471
coo_kernel 256 62.290191650390625
coo_kernel 512 62.026119232177734
coo_kernel 1024 61.32020950317383

Testing for 500000 edges
(29509, 16374)
csr 65.50588607788086
coo 71.40600681304932
csr_kernel 32 54.044485092163086
csr_kernel 64 60.14046669006348
csr_kernel 128 56.35478496551514
csr_kernel 256 46.6019868850708
csr_kernel 512 54.707956314086914
csr_kernel 1024 55.11133670806885
coo_kernel 32 63.839530944824226
coo_kernel 64 64.72692489624023
coo_kernel 128 65.07103443145752
coo_kernel 256 69.69590187072754
coo_kernel 512 64.99478816986084
coo_kernel 1024 64.7336483001709

Testing for 5000000 edges
(43418, 16377)
csr 360.93554496765137
coo 219.34540271759033
csr_kernel 32 34.18407440185547
csr_kernel 64 53.055429458618164
csr_kernel 128 58.417463302612305
csr_kernel 256 52.16631889343262
csr_kernel 512 49.57113265991211
csr_kernel 1024 54.834651947021484
coo_kernel 32 78.15330028533936
coo_kernel 64 76.80673599243164
coo_kernel 128 77.03337669372559
coo_kernel 256 73.25513362884521
coo_kernel 512 76.92813873291016
coo_kernel 1024 77.04839706420898

Testing for 50000000 edges
(64359, 16384)
csr_kernel 32 188.805890083313
csr_kernel 64 158.88402462005615
csr_kernel 128 87.86423206329346
csr_kernel 256 78.75776290893555
csr_kernel 512 76.42595767974854
csr_kernel 1024 70.73354721069336
coo_kernel 32 283.26284885406494
coo_kernel 64 282.31539726257324
coo_kernel 128 282.2613477706909
coo_kernel 256 282.2352886199951
coo_kernel 512 282.2505474090576
coo_kernel 1024 282.18557834625244

Testing for 500000000 edges
(11200, 16384)
csr_kernel 32 585.2307081222534
csr_kernel 64 414.9267911911011
csr_kernel 128 305.6136131286621
csr_kernel 256 256.9816827774048
csr_kernel 512 221.5327024459839
csr_kernel 1024 204.65679168701172
coo_kernel 32 3360.795545578003
coo_kernel 64 3355.547070503235
coo_kernel 128 3356.409192085266
coo_kernel 256 3355.93740940094
coo_kernel 512 3356.1617851257324
coo_kernel 1024 3356.2454223632812
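As an aside, the "run 10 times and average" methodology above can be sketched with a minimal CPU timing harness (the `bench` helper and the dense numpy workload are illustrative stand-ins; the actual benchmarks ran cupyx sparse kernels on a GPU, where you would also synchronize the device around the timed region):

```python
import time
import numpy as np

def bench(fn, reps=10):
    # Average wall-clock milliseconds over `reps` runs.
    t0 = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - t0) / reps * 1000.0

# Toy stand-in workload for x.T.dot(x).
x = np.random.rand(512, 64)
ms = bench(lambda: x.T @ x)
print(f"x.T @ x: {ms:.3f} ms")
```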

int index1 = index[idx];
{0} data1 = data[idx];
long long int outidx = \
static_cast<long long int>(index1) * ncols + index1;
atomicAdd(&out[outidx], data1 * data1);
atomicAdd(&mean[index1], data1);
for(int idx2 = idx+1; idx2 < stop_idx; idx2++){
int index2 = index[idx2];
{0} data2 = data[idx2];
long long int outidx2 = \
static_cast<long long int>(index1) * ncols + index2;
atomicAdd(&out[outidx2], data1 * data2);
}
}
}
"""


def _cov_kernel(dtype):
return cuda_kernel_factory(cov_kernel_str, (dtype,), "cov_kernel")


def _mean_cov_kernel(dtype):
return cuda_kernel_factory(
mean_cov_kernel_str, (dtype,), "mean_cov_kernel"
)


@cuml.internals.api_return_any()
def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
"""
@@ -102,6 +133,15 @@ def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
"X and Y must have same shape %s != %s" % (x.shape, y.shape)
)

# Fix for cupy issue #7699: addressing problems with sparse matrix multiplication (spGEMM)
cjnolet (Member):
Please reference a cuml github issue here in a TODO (and create one if not already created) so that we can track it and know where to apply the fix in the code.

Intron7 (Contributor, Author):
#5475 is the cuml issue, and I can also reference it in the comment. Would that be sufficient, @cjnolet?

csadorf (Contributor):
@cjnolet I've asked that we reference the underlying issue here instead of the cuML issue that will be closed with this PR. Do you want a separate cuML issue that references cupy#7699?

cjnolet (Member):
@csadorf yes, I would prefer to reference the cuml issue (and also reference the cupy issue), since the cuml issue is local to the repository where the code is hosted and thus has a stronger link to tracking the work's progress.

Intron7 (Contributor, Author):
I added a reference to cuml issue #5475.

if (
x is y
and cupyx.scipy.sparse.issparse(x)
and mean_x is None
and mean_y is None
):
return _cov_sparse(x, return_gram=return_gram, return_mean=return_mean)

if mean_x is not None and mean_y is not None:
if mean_x.dtype != mean_y.dtype:
raise ValueError(
@@ -156,3 +196,94 @@
return cov_result, mean_x, mean_y
elif return_gram and return_mean:
return cov_result, gram_matrix, mean_x, mean_y


@cuml.internals.api_return_any()
def _cov_sparse(x, return_gram=False, return_mean=False):
"""
    Computes the mean and the covariance of matrix X, using
    the form Cov(X, X) = E(X^T X) - E(X) E(X)^T

This is a temporary fix for cupy issue #7699, where the
operation `x.T.dot(x)` did not work for larger
sparse matrices.

Parameters
----------

x : cupyx.scipy.sparse of size (m, n)
return_gram : boolean (default = False)
If True, gram matrix of the form (1 / n) * X.T.dot(X)
will be returned.
When True, a copy will be created
to store the results of the covariance.
When False, the local gram matrix result
will be overwritten
return_mean: boolean (default = False)
    If True, the Maximum Likelihood Estimate of the mean of X
    will be returned (twice, mirroring the interface of `cov`),
    of the form (1 / n) * sum(X)

Returns
-------

result : cov(X, X) when return_gram and return_mean are False
cov(X, X), gram(X, X) when return_gram is True,
return_mean is False
cov(X, X), mean(X), mean(X) when return_gram is False,
return_mean is True
cov(X, X), gram(X, X), mean(X), mean(X)
when return_gram is True and return_mean is True
"""
if not cupyx.scipy.sparse.isspmatrix_csr(x):
x = x.tocsr()
gram_matrix = cp.zeros((x.shape[1], x.shape[1]), dtype=x.data.dtype)
mean_x = cp.zeros((x.shape[1],), dtype=x.data.dtype)

block = (8,)
grid = (math.ceil(x.shape[0] / block[0]),)
compute_mean_cov = _mean_cov_kernel(x.data.dtype)
compute_mean_cov(
grid,
block,
(
x.indptr,
x.indices,
x.data,
x.shape[0],
x.shape[1],
gram_matrix,
mean_x,
),
)
gram_matrix = gram_matrix + gram_matrix.T
gram_matrix -= cp.diag(cp.diag(gram_matrix) / 2)
gram_matrix *= 1 / x.shape[0]
mean_x *= 1 / x.shape[0]

if return_gram:
cov_result = cp.zeros(
(gram_matrix.shape[0], gram_matrix.shape[0]),
dtype=gram_matrix.dtype,
)
else:
cov_result = gram_matrix

compute_cov = _cov_kernel(x.dtype)

block_size = (8, 8)
grid_size = (math.ceil(gram_matrix.shape[0] / 8),) * 2
compute_cov(
grid_size,
block_size,
(cov_result, gram_matrix, mean_x, mean_x, gram_matrix.shape[0]),
)

if not return_gram and not return_mean:
return cov_result
elif return_gram and not return_mean:
return cov_result, gram_matrix
elif not return_gram and return_mean:
return cov_result, mean_x, mean_x
elif return_gram and return_mean:
return cov_result, gram_matrix, mean_x, mean_x
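As a sanity check on the identity `_cov_sparse` relies on, here is a dense numpy sketch (illustrative, not part of the PR diff) verifying that the scaled Gram matrix minus the outer product of the mean reproduces the population covariance:

```python
import numpy as np

# Cov(X, X) = (1/n) * X.T @ X - outer(mean, mean), with ddof=0.
rng = np.random.default_rng(0)
x = rng.random((50, 4))
n = x.shape[0]

gram = x.T @ x / n       # (1/n) * Gram, as accumulated by the kernel
mean = x.mean(axis=0)    # MLE mean
cov = gram - np.outer(mean, mean)

assert np.allclose(cov, np.cov(x, rowvar=False, ddof=0))
```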
24 changes: 24 additions & 0 deletions python/cuml/tests/test_stats.py
@@ -15,6 +15,7 @@

from cuml.testing.utils import array_equal
from cuml.prims.stats import cov
from cuml.prims.stats.covariance import _cov_sparse
import pytest
from cuml.internals.safe_imports import gpu_only_import

@@ -43,3 +44,26 @@ def test_cov(nrows, ncols, sparse, dtype):
local_cov = cp.cov(x, rowvar=False, ddof=0)

assert array_equal(cov_result, local_cov, 1e-6, with_sign=True)


@pytest.mark.parametrize("nrows", [1000])
@pytest.mark.parametrize("ncols", [500, 1500])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_cov_sparse(nrows, ncols, dtype):

x = cupyx.scipy.sparse.random(
nrows, ncols, density=0.07, format="csr", dtype=dtype
)
cov_result = _cov_sparse(x, return_mean=True)

# check cov
assert cov_result[0].shape == (ncols, ncols)

x = x.todense()
local_cov = cp.cov(x, rowvar=False, ddof=0)

assert array_equal(cov_result[0], local_cov, 1e-6, with_sign=True)

# check mean
local_mean = x.mean(axis=0)
assert array_equal(cov_result[1], local_mean, 1e-6, with_sign=True)