
Sparse pca patch #5493

Merged: 32 commits into rapidsai:branch-23.08 on Aug 2, 2023

Conversation

@Intron7 (Contributor) commented Jul 6, 2023:

Fixes a bug in the sparse PCA function. #5475

I added a new cov_sparse function that replaces the cov function for _sparse_fit.
It uses a custom kernel to calculate the gram matrix within cov_sparse.
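
For context, the identity the patch relies on is cov(X) = X^T X / n - mean(X)^T mean(X). Below is a minimal sketch of that computation with CuPy (names and shapes illustrative only; the actual cov_sparse lives in python/cuml/prims/stats/covariance.py, and the PR replaces the SpGEMM line below with a custom kernel):

import cupy as cp
import cupyx.scipy.sparse as sparse

x = sparse.random(1000, 50, density=0.05, format="csr", dtype=cp.float64)
gram = x.T.dot(x).toarray() / x.shape[0]  # the step the custom kernel replaces
mean = cp.asarray(x.mean(axis=0)).ravel()
cov = gram - cp.outer(mean, mean)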

@Intron7 requested a review from a team as a code owner, July 6, 2023 22:18

@rapids-bot (bot) commented Jul 6, 2023:

Pull requests from external contributors require approval from a rapidsai organization member with write permissions or greater before CI can begin.

@github-actions bot added the "Cython / Python" label, Jul 6, 2023
@csadorf linked an issue Jul 7, 2023 that may be closed by this pull request
@csadorf added the "improvement" (enhancement to an existing function) and "non-breaking" (non-breaking change) labels, Jul 7, 2023
@csadorf (Contributor) commented Jul 7, 2023:

/ok to test

@csadorf (Contributor) commented Jul 7, 2023:

@Intron7 Thank you very much for the contribution! We'll review shortly; in the meantime, it would be great if you could address the linter issues. Thanks a lot!

@cjnolet (Member) left a review comment:

Thanks, Severin, for these changes! Unfortunately, I'm out of the office today and next week, but I have some ideas on ways this could be improved. I'll provide feedback when I'm back.

@csadorf (Contributor) left a review comment:

Thanks a lot for the contribution! I would appreciate it if you could address the linter complaints and my comments and suggestions.

Inline review threads on: python/cuml/decomposition/pca.pyx, python/cuml/prims/stats/__init__.py, python/cuml/prims/stats/covariance.py
@csadorf (Contributor) commented Jul 10, 2023:

/ok to test

Co-authored-by: Simon Adorf <sadorf@nvidia.com>
@Intron7 requested a review from csadorf, July 10, 2023 14:50
@dantegd requested a review from cjnolet, July 10, 2023 15:59
@Intron7 (Contributor, author) commented Jul 10, 2023:

In the docstring for cov, the return_mean entry says it returns (1 / n) * mean(X), but the function actually returns the mean itself, not that expression. Shall we also fix this?

@csadorf (Contributor) commented Jul 13, 2023:

/ok to test

@Intron7 requested a review from cjnolet, July 18, 2023 21:34
@cjnolet (Member) left a review comment:

Thanks for your patience, Severin. I finally got an opportunity to provide feedback about the kernel design.

int start_idx = indptr[row];
int stop_idx = indptr[row+1];

// each thread walks every nonzero of its assigned row serially
for(int idx = start_idx; idx < stop_idx; idx++){
@cjnolet (Member) commented Jul 18, 2023:

I have some concerns about this implementation, but the good news is that I think we have some options that are fairly trivial to fix. If you think about the way SIMT architectures like GPUs work at the hardware level, each warp (a grouping of 32 threads) is only able to execute a single instruction at a time. If two threads within a warp need to execute different instructions, the other threads in the warp have to stall and wait for those instructions, even though they aren't executing anything themselves. Because of this, we try to design our kernels so that the threads within each warp are 1) able to do a uniform amount of work, and 2) able to execute the same instructions as much as possible. Things like atomics and conditional branching work against this; the resulting stalling is called warp divergence.

The degree distributions (the number of nonzero columns in each row) are almost never uniform and are most often highly skewed, sometimes even by power laws. Because of this, you cannot expect good performance by simply having each thread loop through the columns within each row. Sometimes folks permute the matrix to sort the rows by their degrees. This can help a little, but it's not a feasible solution here because we can't afford to copy the data.

The other piece here is the atomics: they are expensive, and they also cause the warps to diverge because the time each atomic takes to execute is non-deterministic and depends on the number of competing concurrent writes. These collisions are going to compound w/ the means, so I would highly suggest removing the fused mean computation and using cupy to compute that instead (see the snippet below).
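
For instance, the column means can be computed up front with cupy's own reduction rather than with extra atomics inside the kernel (a minimal sketch, not the PR's code):

import cupy as cp
import cupyx.scipy.sparse as sparse

x = sparse.random(10_000, 512, density=0.01, format="csr", dtype=cp.float64)
mean = cp.asarray(x.mean(axis=0)).ravel()  # per-column means, no atomics involved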

Memory reads are also impacted by this model: with each thread reading sequential memory locations from the sparse arrays, you aren't able to benefit from coalescing within each warp, since the threads won't be reading from adjacent locations on each instruction cycle.

For CSR matrices, an efficient way to do this would be to schedule some number of warps per block (let's start with somewhere between 1 and 8) and have each warp work on its own row at a time. A block that contains multiple warps will need to wait for any straggler warps, but 1 warp per block could end up causing load-balancing issues. To get a little more intricate, we could perform a differencing of the indptr array, which would give us the degree of each row, and then perform a couple of kernel launches with different numbers of warps per block to keep the warps uniform (enough) for good performance. For a first pass, though, we can skip launching multiple kernels and just find a good block size that yields reasonable performance on power-law graphs (a sketch follows below).
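
A rough sketch of the warp-per-row idea as a CuPy RawKernel (an illustration of this suggestion, not the kernel that was merged; it assumes canonical CSR with sorted column indices, accumulates only the upper triangle, and would be launched with one warp per row):

import cupy as cp

_csr_gram = cp.RawKernel(r'''
extern "C" __global__
void csr_gram(const int* indptr, const int* indices, const double* data,
              int n_rows, int n_cols, double* gram) {
    // one warp per row; the 32 lanes stride over that row's nonzeros
    int warp_id = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    int lane = threadIdx.x % 32;
    if (warp_id >= n_rows) return;
    int start = indptr[warp_id];
    int stop = indptr[warp_id + 1];
    for (int i = start + lane; i < stop; i += 32) {
        for (int k = i; k < stop; k++) {
            // canonical CSR has sorted columns, so this fills the upper triangle
            atomicAdd(&gram[(size_t)indices[i] * n_cols + indices[k]],
                      data[i] * data[k]);
        }
    }
}
''', 'csr_gram')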

For COO: this is the easiest case, since we can essentially compute the output gram matrix embarrassingly parallel: map each thread to an edge in the array and perform your atomicAdd (see the sketch below).
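
And a hedged sketch of the COO variant (again illustrative, not the merged kernel; it assumes the entries are sorted by row, e.g. after sum_duplicates(), and mirrors the upper triangle at the end):

import cupy as cp
import cupyx.scipy.sparse as sparse

_coo_gram = cp.RawKernel(r'''
extern "C" __global__
void coo_gram(const int* rows, const int* cols, const double* vals,
              int nnz, int n_cols, double* gram) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= nnz) return;
    // diagonal contribution of this nonzero
    atomicAdd(&gram[(size_t)cols[i] * n_cols + cols[i]], vals[i] * vals[i]);
    // pair with every later nonzero in the same row (needs row-sorted COO)
    for (int k = i + 1; k < nnz && rows[k] == rows[i]; k++) {
        int c1 = min(cols[i], cols[k]);
        int c2 = max(cols[i], cols[k]);
        atomicAdd(&gram[(size_t)c1 * n_cols + c2], vals[i] * vals[k]);
    }
}
''', 'coo_gram')

x = sparse.random(10_000, 512, density=0.01, format="coo", dtype=cp.float64)
x.sum_duplicates()  # canonicalizes (and row-sorts) the COO representation
gram = cp.zeros((x.shape[1], x.shape[1]), dtype=cp.float64)
tpb = 128
_coo_gram(((x.nnz + tpb - 1) // tpb,), (tpb,),
          (x.row, x.col, x.data, cp.int32(x.nnz), cp.int32(x.shape[1]), gram))
gram = gram + gram.T - cp.diag(cp.diag(gram))  # mirror upper triangle -> full X^T X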

I would also make sure to test the performance of this implementation with a power-law graph. You can use the RMAT generator in pylibraft to generate such a graph (see the sketch at the end of this comment). The problem w/ cupy's sparse generator is that the resulting arrays will have a uniform degree distribution, and thus will not match real-world sparse datasets. Power-law is a worst case, so if we work well on those, uniform degree distributions will just yield better perf.

Further, since we are replacing a highly optimized primitive from cusparse here, we should do our due diligence and gather some benchmarks to make sure we aren't introducing any significant regressions.
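
For reference, generating such a power-law test input might look roughly like this (a sketch around pylibraft.random.rmat; treat the exact signature and the theta layout as assumptions to verify against the pylibraft docs):

import cupy as cp
import cupyx.scipy.sparse as sparse
from pylibraft.random import rmat

r_scale, c_scale, n_edges = 16, 14, 500_000
out = cp.empty((n_edges, 2), dtype=cp.int32)
# one (a, b, c, d) quadrant-probability tuple per scale level
theta = cp.random.random_sample(4 * max(r_scale, c_scale))
rmat(out, theta, r_scale, c_scale, seed=42)
vals = cp.ones(n_edges, dtype=cp.float64)
x = sparse.coo_matrix((vals, (out[:, 0], out[:, 1])),
                      shape=(1 << r_scale, 1 << c_scale))
x.sum_duplicates()  # merge duplicate edges before benchmarking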

@Intron7 (Contributor, author) replied:

Hey Corey, thank you so much for your insights and tips. I worked on new kernels for both COO and CSR. In my testing (so far) these are much faster than the cupy x.T.dot(x) versions, at least for matrix sizes where those still worked. I kept the reduced atomicAdds so that only the upper half of the matrix gets filled in. I'm currently getting the power-law graph benchmarks done.
So far, for cpx.scipy.sparse.random(100000, 2000, density=0.05, format="csr/coo", dtype=cp.float64, random_state=42), I'm going from around 800 ms for x.T.dot(x) down to 8-10 ms for the raw kernels. For my real-world single-cell data they also perform really well. I think that 2-3 warps might be the most performant configuration.
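
A comparison along these lines can be timed with cupyx's profiler (a sketch; the raw-kernel side would call the new gram kernels from this branch):

import cupy as cp
import cupyx.scipy.sparse as sparse
from cupyx.profiler import benchmark

x = sparse.random(100_000, 2000, density=0.05, format="csr",
                  dtype=cp.float64, random_state=42)
# baseline SpGEMM; swap in the raw gram kernel to compare
print(benchmark(lambda: x.T.dot(x), n_repeat=10))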

@Intron7 (Contributor, author) replied:

@cjnolet can we assume that the COO matrix is sorted?

@cjnolet (Member) replied:

@Intron7 I don't know that we can assume that w/ cupy/scipy. But if we use an element-wise kernel, we shouldn't necessarily have to assume that, should we?

@Intron7 (Contributor, author) commented:

Dear @cjnolet,
here are the performance numbers you asked for. I tested my raw kernels with multiple block sizes against the standard SpGEMM algorithms from CuPy. I ran each function 10 times and averaged the runtimes (in ms). The coo_kernel results include the x.sum_duplicates() runtimes. For 50000000 edges and more, the standard libraries stop working. I updated the kernels in the branch.

Testing for 50000 edges
(51073, 16382)
csr 10.99555492401123
coo 17.003202438354492
csr_kernel 32 11.6835355758667
csr_kernel 64 9.113883972167969
csr_kernel 128 8.887648582458496
csr_kernel 256 36.73577308654785
csr_kernel 512 54.798269271850586
csr_kernel 1024 55.03363609313965
coo_kernel 32 61.06009483337402
coo_kernel 64 51.66192054748535
coo_kernel 128 61.94412708282471
coo_kernel 256 62.290191650390625
coo_kernel 512 62.026119232177734
coo_kernel 1024 61.32020950317383

Testing for 500000 edges
(29509, 16374)
csr 65.50588607788086
coo 71.40600681304932
csr_kernel 32 54.044485092163086
csr_kernel 64 60.14046669006348
csr_kernel 128 56.35478496551514
csr_kernel 256 46.6019868850708
csr_kernel 512 54.707956314086914
csr_kernel 1024 55.11133670806885
coo_kernel 32 63.839530944824226
coo_kernel 64 64.72692489624023
coo_kernel 128 65.07103443145752
coo_kernel 256 69.69590187072754
coo_kernel 512 64.99478816986084
coo_kernel 1024 64.7336483001709

Testing for 5000000 edges
(43418, 16377)
csr 360.93554496765137
coo 219.34540271759033
csr_kernel 32 34.18407440185547
csr_kernel 64 53.055429458618164
csr_kernel 128 58.417463302612305
csr_kernel 256 52.16631889343262
csr_kernel 512 49.57113265991211
csr_kernel 1024 54.834651947021484
coo_kernel 32 78.15330028533936
coo_kernel 64 76.80673599243164
coo_kernel 128 77.03337669372559
coo_kernel 256 73.25513362884521
coo_kernel 512 76.92813873291016
coo_kernel 1024 77.04839706420898

Testing for 50000000 edges
(64359, 16384)
csr_kernel 32 188.805890083313
csr_kernel 64 158.88402462005615
csr_kernel 128 87.86423206329346
csr_kernel 256 78.75776290893555
csr_kernel 512 76.42595767974854
csr_kernel 1024 70.73354721069336
coo_kernel 32 283.26284885406494
coo_kernel 64 282.31539726257324
coo_kernel 128 282.2613477706909
coo_kernel 256 282.2352886199951
coo_kernel 512 282.2505474090576
coo_kernel 1024 282.18557834625244

Testing for 500000000 edges
(11200, 16384)
csr_kernel 32 585.2307081222534
csr_kernel 64 414.9267911911011
csr_kernel 128 305.6136131286621
csr_kernel 256 256.9816827774048
csr_kernel 512 221.5327024459839
csr_kernel 1024 204.65679168701172
coo_kernel 32 3360.795545578003
coo_kernel 64 3355.547070503235
coo_kernel 128 3356.409192085266
coo_kernel 256 3355.93740940094
coo_kernel 512 3356.1617851257324
coo_kernel 1024 3356.2454223632812

@csadorf (Contributor) commented Jul 27, 2023:

/ok to test

@Intron7 requested a review from cjnolet, July 27, 2023 17:57
@Intron7 requested a review from csadorf, August 1, 2023 10:54
@csadorf (Contributor) commented Aug 1, 2023:

/ok to test

@cjnolet (Member) left a review comment:

There's a lot more that could be done w/ the kernels, but I think they do look a lot better (and this is really only temporary, so we don't need to spend too much time / complexity on them).

@@ -102,6 +168,15 @@ def cov(x, y, mean_x=None, mean_y=None, return_gram=False, return_mean=False):
"X and Y must have same shape %s != %s" % (x.shape, y.shape)
)

# Fix for cupy issue #7699: addressing problems with sparse matrix multiplication (spGEMM)
@cjnolet (Member) commented:

Please reference a cuml github issue here in a TODO (and create one if not already created) so that we can track it and know where to apply the fix in the code.

@Intron7 (Contributor, author) replied:

#5475 is the cuml issue, and I can also reference it in the comment. Would that be sufficient, @cjnolet?

@csadorf (Contributor) replied:

@cjnolet I've asked that we reference the underlying issue here instead of the cuML issue that will be closed with this PR. Do you want a separate cuML issue that references cupy#7699?

@cjnolet (Member) replied:

@csadorf yes, I would prefer to reference the cuml issue (and also reference the cupy issue), as the cuml issue is local to the repository where the code is hosted and thus has a stronger link for tracking the work's progress.
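
For example, a TODO along these lines would reference both (wording hypothetical):

# TODO: Remove this workaround once the underlying spGEMM problem is fixed.
# See rapidsai/cuml#5475 and the upstream report cupy/cupy#7699.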

@Intron7 (Contributor, author) replied:

I added a reference to the cuml issue #5475.

@csadorf (Contributor) commented Aug 1, 2023:

/ok to test

@csadorf (Contributor) commented Aug 1, 2023:

@Intron7 Please do not merge the upstream branch unless the "Recently Updated" check fails. It generates unnecessary load for our CI system.

@csadorf (Contributor) commented Aug 1, 2023:

/ok to test

@cjnolet (Member) commented Aug 2, 2023:

/ok to test

@cjnolet (Member) left a review comment:

LGTM. Thanks Severin!

@csadorf (Contributor) left a review comment:

LGTM, thanks a lot!

@csadorf (Contributor) commented Aug 2, 2023:

/merge

@rapids-bot merged commit 14d931a into rapidsai:branch-23.08, Aug 2, 2023
52 checks passed
@Intron7 (Contributor, author) commented Aug 2, 2023:

Thanks for all your patience and support @csadorf @cjnolet

@csadorf (Contributor) commented Aug 2, 2023:

> Thanks for all your patience and support @csadorf @cjnolet

Thanks a lot for your contribution!

Labels
Cython / Python · improvement · non-breaking

Development
Successfully merging this pull request may close these issues:
[BUG] PCA doesn't work with sparse matrices