High-level API for `torch.sparse.mm` with optimized `spmm_reduce` kernel using CSR format #6699
for more information, see https://pre-commit.ci
…_geometric into spmm_reduce_api
Removes duplicate edges from the given homogeneous or heterogeneous graph. This changes the original edge order of the dataset, because one copy of each duplicated edge is concatenated at the end. It can be used to clean up the known issue of repeated self-connecting edges in ogbn-products. See the note on the ogbn-products leaderboard: [here](https://ogb.stanford.edu/docs/nodeprop/#:~:text=Note%3A%20A%20very%20small%20number%20of%20self%2Dconnecting%20edges%20are%20repeated%20(see%20here)%3B%20you%20may%20remove%20them%20if%20necessary)

Moved this to a separate PR from #6699.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <matthias.fey@tu-dortmund.de>
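As a rough illustration of the behavior described in this commit message, the sketch below works on plain Python edge tuples rather than on a graph object. The function name `remove_duplicate_edges` and the exact ordering rule (edges that occur once keep their order, one copy of each duplicated edge is appended at the end) are assumptions made for the example, not the transform's actual implementation.

```python
from collections import Counter
from typing import List, Tuple

Edge = Tuple[int, int]


def remove_duplicate_edges(edges: List[Edge]) -> List[Edge]:
    """Hypothetical helper: drop repeated edges from an edge list.

    Edges that occur exactly once keep their original relative order.
    For every edge that occurs more than once, a single copy is
    appended at the end, so the overall edge order changes.
    """
    counts = Counter(edges)
    # Edges seen exactly once stay where they are.
    unique = [e for e in edges if counts[e] == 1]
    # dict.fromkeys keeps first-seen order while dropping repeats.
    duplicated = list(dict.fromkeys(e for e in edges if counts[e] > 1))
    return unique + duplicated


# A repeated self-connecting edge (0, 0), as in the ogbn-products issue.
edges = [(0, 0), (0, 1), (0, 0), (1, 2)]
cleaned = remove_duplicate_edges(edges)
```

After the call, `cleaned` holds each edge once, with the formerly duplicated self-loop moved to the end.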
…into spmm_reduce_api
    reduce = 'sum' if reduce == 'add' else reduce

    if reduce not in ['sum', 'mean', 'min', 'max']:
        raise ValueError(f"`reduce` argument '{reduce}' not supported")

    if isinstance(src, SparseTensor):
        return torch_sparse.matmul(src, other, reduce)
Hi @JakubPietrakIntel @rusty1s, may I know the logic here? If `src` is a `SparseTensor`, which is the default data type, it will never reach the optimized spmm implementation in PyTorch.
Yes, I changed this back for now since `torch.sparse.mm` is missing CUDA support. I think we can either patch this in `SparseTensor` or patch this here such that we only call `torch.sparse.mm` in case PyTorch >= 2.0 and CPU. Otherwise, I think this solution is fine, as we support `torch.sparse.Tensor` now anyway (and are in the process of removing the `torch-sparse` dependency).
Understood. So for now, we need to add a patch that calls `torch.sparse.mm` when PyTorch >= 2.0 and the device is CPU. Do you know when CUDA support will be added?
I don't have any insights into that. Let me add the optimized routine in a separate PR.
Fixed this in #6759
Thanks. The fix LGTM.
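The dispatch rule discussed in this thread (use `torch.sparse.mm` only for PyTorch >= 2.0 on CPU, otherwise keep the existing path) could be expressed as a small predicate. This is only a sketch: `use_native_spmm` is a hypothetical helper, it takes the version and device type as plain strings so it runs without PyTorch installed, and it is not the code that landed in #6759.

```python
import re


def use_native_spmm(torch_version: str, device_type: str) -> bool:
    """Hypothetical check: should the native sparse matmul path be used?

    Returns True only when the major PyTorch version is at least 2 and
    the tensors live on CPU, mirroring the condition from the review.
    """
    # Accept version strings with local suffixes such as "2.0.1+cu117".
    match = re.match(r"(\d+)\.(\d+)", torch_version)
    if match is None:
        return False  # Unknown version format: stay on the fallback path.
    major = int(match.group(1))
    return major >= 2 and device_type == "cpu"


on_cpu_new = use_native_spmm("2.0.1+cu117", "cpu")
on_cpu_old = use_native_spmm("1.13.1", "cpu")
on_cuda_new = use_native_spmm("2.0.0", "cuda")
```

In real code the two arguments would typically come from `torch.__version__` and the `device.type` of the input tensor.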
Related to an important optimization in PyTorch:
✅ port sparse_mm.reduce to pytorch and optimize it on CPU #83727
Updating the simplified high-level API for the `spmm_reduce()` kernel and its tests.

The current kernel implementation can only process `src` of type `torch.Tensor` in `torch.sparse_csr` format. I've therefore added an option to auto-convert `src` to CSR format using `src.to_sparse_csr()`. The option is `False` by default, in which case a `ValueError` is raised if the input is not provided in the correct format.

The conversion from `SparseTensor` to `torch.Tensor` is enabled by default for PyTorch > 1.13.

Added a transform to remove duplicated edges in the ogbn-products dataset, because the new kernel can't handle duplicate entries (useful for benchmarks).
Re-opened this PR because the draft (#6689) needed to be scrapped after a rebase.