Adding sparse addmv and triangular_solve support on CPU - Mac OS - Apple Silicon M2 #96972

Open
tvercaut opened this issue Mar 16, 2023 · 21 comments
Assignees
Labels
module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 module: macos Mac OS related issues module: sparse Related to torch.sparse triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@tvercaut

tvercaut commented Mar 16, 2023

🚀 The feature, motivation and pitch

As discussed in #77764 (comment) it would be helpful to have a CPU fallback for torch.addmv on Mac OS when using sparse matrices.

Alternatives

Run on google colab or similar

Additional context

Steps to reproduce the issue on Mac OS:

import torch
print(f"Running PyTorch version: {torch.__version__}")

dtype = torch.float32
device = torch.device("cpu")
#device = torch.device("mps")
print(f"Using device: {device}")

#mat = torch.randn((4,4), dtype=dtype, device=device)
mat = torch.randn((4,4), dtype=dtype, device=device).relu().to_sparse_csr()
mvec = torch.randn((4,), dtype=dtype, device=device)
avec = torch.randn((4,), dtype=dtype, device=device)

ovec = torch.addmv(avec, mat, mvec)
print(ovec)

leading to

UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/SparseCsrTensorImpl.cpp:56.)
  mat = torch.randn((4,4), dtype=dtype, device=device).relu().to_sparse_csr()
Traceback (most recent call last):
  File "[...]/test.py", line 14, in <module>
    ovec = torch.addmv(avec, mat, mvec)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Calling addmv on a sparse CPU tensor requires compiling PyTorch with MKL. Please use PyTorch built MKL support.

It may be that MKL can be compiled for Mac OS (and thus shipped in the default PyTorch distribution for Mac), or maybe a less optimised alternative needs to be found (e.g. Eigen).

cc @alexsamardzic @nikitaved @pearu @cpuhrsch @amjames @bhosmer @malfet @albanD

@leovinus2001

leovinus2001 commented Mar 16, 2023

Have you looked at whether you can do it via the Apple Accelerate framework? There are extensions beyond dense BLAS for sparse matrices, but it has been a long time since I used them. IMHO, thinking about Intel MKL on Apple Silicon is a waste of time, as Apple has a team for such CPU libraries, just like with Metal/MPS/GPU. Might be easier to just ask them. https://developer.apple.com/documentation/accelerate/creating_sparse_matrices
https://developer.apple.com/documentation/accelerate/sparse_solvers/sparse_matrix_and_dense_matrix_multiplication

@tvercaut
Author

Thanks @leovinus2001. As discussed by @cpuhrsch in #77764 (comment), my understanding is that the PyTorch team currently doesn't have the bandwidth to support the integration of additional platform-specific libraries to support sparse tensors.

So yes, tapping into Apple Accelerate would be neat but given these constraints and given that MKL is already used throughout PyTorch, even if not optimal on ARM / Silicon Macs, it could make sense to try it. If MKL can be used on Mac, it would certainly help "quickly" bridge the current gap in functionality.

On a related note, I just encountered another such missing functionality:

RuntimeError: Calling triangular_solve on a sparse CPU tensor requires compiling PyTorch with MKL. Please use PyTorch built MKL support.

Also interestingly, I realised that a workaround for the missing addmv can be to use torch.mm (with unsqueeze/squeeze as needed).
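
For concreteness, a minimal sketch of that workaround, reusing the variable names from the snippet above (this assumes torch.mm supports the sparse CSR layout on this build, and uses the default alpha=beta=1 semantics of addmv):

# Hedged sketch: emulate torch.addmv(avec, mat, mvec) via torch.mm by
# promoting the vector to a 1-column matrix and squeezing the result back
ovec_mm = avec + torch.mm(mat, mvec.unsqueeze(-1)).squeeze(-1)
print(ovec_mm)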

@malfet malfet added module: macos Mac OS related issues module: arm Related to ARM architectures builds of PyTorch. Includes Apple M1 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module module: sparse Related to torch.sparse labels Mar 17, 2023
@malfet
Contributor

malfet commented Mar 17, 2023

Is there a generic performant sparse library? For example, can ARM Compute Library be used here? See https://developer.arm.com/documentation/101004/2030/Sparse-Linear-Algebra/Example-of-SpMV-usage?lang=en

@cpuhrsch
Contributor

@malfet - Can we ship mkldnn on Mac? It sounds like this is more of a packaging problem, assuming mkldnn is supported on Mac.

@tvercaut
Author

Maybe a point of helpful clarification: What does MKL exactly mean in the context of PyTorch?

If it's what is now referred to as OneMKL, then this doesn't seem to have an ARM64 backend unfortunately (yet?). This is also confirmed here.

If it's what is now being referred to as OneDNN (and previously MKL-DNN or DNNL), then this seems to support ARM64.

@leovinus2001

Thanks for the clarifications!

@tvercaut

Also interestingly, I realised that a workaround for the missing addmv() can be to use torch.mm (with unsqueeze/squeeze as needed).

Firstly, in that context, to go for a quick fix, it seems that IF torch.addmm() works fine on MPS, then that is a potential solution for now. Just reshape the vector and voilà. PS: I have mapped sgemv() to sgemm() like that in the past ;)

Secondly, regarding MKL and lack of resources to support a new backend, it is a challenge. Obviously a topic for another thread. The recent MPS support is a great example though of what is possible.

However, I would say that using Accelerate/CPU would be a quality improvement for PyTorch on macOS in the future, and also for iOS. When I build PyTorch from scratch and see three possible BLAS options on macOS but not the native Accelerate, it just makes me sad. Using one alternative BLAS on macOS for debugging makes sense, but to get the most speed and performance, the Accelerate solutions are the best and most forward compatible.

Finally, I am happy to help where I can. If you would like a simple macOS code prototype or proof-of-concept of addmv() or gemv() for dense and sparse, I am happy to contribute along the lines of what we discussed earlier for torch.mm() and GEMM. I made some prototypes for earlier issues:
#81185 (comment)
and
#81185 (comment)

@leovinus2001

@malfet - Can we ship mkldnn on Mac? It sounds like this is more of a packaging problem, assuming mkldnn is supported on Mac.

A word of caution, please. Whether right or wrong, it seems to me that the sole purpose of using MKLDNN on macOS is to get "something going". That might seem appealing short term, but long term it is counterproductive. MKLDNN was built with Intel in mind, and Apple has moved on to Apple Silicon and more.

In principle, the goal of PyTorch macOS support is to give PyTorch users the best performance on macOS, right? That is always the Apple user perspective anyway ("the best for the user"), and the recent MPS support, built on Apple and community work, seems a good example. IMHO, MKLDNN is not part of that road, similar to my earlier observation on Accelerate's sparse matrix support and native support.

Happy to help where I can. A nice question for another thread would be something like "what MKLDNN calls would be necessary anyway on macOS?"

@malfet
Contributor

malfet commented Mar 17, 2023

@malfet - Can we ship mkldnn on Mac? It sounds like this is more of a packaging problem, assuming mkldnn is supported on Mac.

@cpuhrsch yeah, we can, but I thought you need MKL (which is exclusively Intel), rather than MKLDNN (which indeed has some ARM support)

@cpuhrsch
Contributor

@malfet - Yes, sorry, you're right. I meant MKL.

@malfet
Contributor

malfet commented Mar 17, 2023

@cpuhrsch - MKL is developed by Intel and distributed in binary-only form, so there isn't much one can do to make it work on M1. But integrating with Accelerate doesn't sound like such a bad idea, as we are already doing that for BLAS/LAPACK on M1.

@cpuhrsch
Contributor

@malfet - Sure, but maybe it can still run on M1, even if very inefficiently. That seems like potentially an easier fix to unblock baseline coverage than integrating with a new backend library. We'll need to spend some time internally discussing how we best go about M1 coverage for sparsity.

@tvercaut
Author

Thanks @malfet, @cpuhrsch. I didn't realise Apple Accelerate was already integrated in PyTorch. A quick search shows that it is actually even used in critical functions such as ReLU:

#ifdef CAFFE2_USE_ACCELERATE
template <>
template <>
bool ReluFunctor<CPUContext>::operator()<float>(
    const int N,
    const float* X,
    float* Y,
    CPUContext* /* context */) const {
  const float zero = 0.0f;
  vDSP_vthres(X, 1, &zero, Y, 1, N);
  return true;
}
#endif // CAFFE2_USE_ACCELERATE

So given that OneMKL doesn't (yet?) provide a generic or Apple-specific backend, relying on Apple Accelerate for addmv, triangular_solve and the like on Mac CPUs seems like a reasonable option to me.

Alternatively, one could also rely on Eigen for a more generic but probably less optimised approach.

Anyway, I will not be the person implementing and maintaining this, so not my call :)

@tvercaut tvercaut changed the title Adding sparse addmv support on CPU - Mac OS - Apple Silicon M2 Adding sparse addmv and triangular_solve support on CPU - Mac OS - Apple Silicon M2 Mar 21, 2023
@malfet
Contributor

malfet commented Mar 21, 2023

@malfet - Sure, but maybe it can still run on M1, even if very inefficiently. That seems like potentially an easier fix to unblock baseline coverage than integrating with a new backend library. We'll need to spend some time internally discussing how we best go about M1 coverage for sparsity.

One cannot mix and match architectures. I.e. one can always download the x86 build of PyTorch on their M1, and it will run, but it will be quite slow. As @tvercaut correctly pointed out, Accelerate is already used for the BLAS/LAPACK implementation, and can be extended for some sparse ops as well.

On the other hand, it would also seem beneficial to have a reference implementation for, say, the Linux aarch64 platform.

@malfet malfet self-assigned this Mar 21, 2023
@tvercaut
Author

If I understand the suggestion of using the x86 build correctly, I must say that running a complete PyTorch-based project through Rosetta 2 doesn't sound very appealing to me.

I however do agree that having a reference portable CPU-based implementation of the sparse operations currently supported through OneMKL would make a lot of sense. I guess this would mean prioritising porting /mkl/SparseBlasImpl.cpp to Eigen before porting it to say Apple Accelerate or Arm Performance Libraries.

On the GPU side, unfortunately, there doesn't seem to be any sparse matrix support in MPS, but maybe this will eventually be possible through a more generic Vulkan-based backend?

malfet added a commit that referenced this issue Mar 22, 2023
Partially addresses the problem raised in #96972

Add `test_addmv` and enable `test_block_addmv` on all platforms

TODO: Make sure that test_block_addmv non-contiguous mode actually
generates non-contiguous values, as right now it probably does not, as the test
passes assuming values are contiguous.
@malfet
Contributor

malfet commented Mar 22, 2023

@tvercaut Eigen used to have very complicated licensing rules, but I see that the latest one is just the MPL license, so indeed this is a good idea. Also, Accelerate supports only float and double tensors, while Eigen would probably work with complex ones as well. In the meantime, I've made a very simple addmv implementation in #97353

As for MPS support, you are more than welcome to write a Metal kernel and I will review the PR

@leovinus2001

@tvercaut Eigen used to have very complicated licensing rules, but I see that the latest one is just the MPL license, so indeed this is a good idea.

Agreed on Eigen. I used to avoid it like the plague for exactly this license reason as well.

@tvercaut Also, Accelerate supports only float and double tensors, while Eigen would probably work with complex ones as well.

That is not correct. Accelerate has BLAS implementations for single and double complex as well. As you know, the naming scheme is SGEMM/DGEMM for single and double precision float, and then CGEMM/ZGEMM for single and double precision complex calculations.

For example, ZGEMM is documented here:
https://developer.apple.com/documentation/accelerate/1513094-cblas_zgemm?language=objc

Granted, the docs can be byzantine, but I have used the Accelerate SGEMM/DGEMM for more than 5 years on macOS and iOS, which makes me think that CGEMM/ZGEMM are just as old.

@tvercaut In the meantime, I've made a very simple addmv implementation in #97353

Sounds like fun, I'll have a look :)

Some other relevant thoughts

  1. It might be obvious to you all, but there is another advantage to using the Apple Accelerate framework instead of a third-party library such as Eigen or OpenBLAS. The advantage is simply "speed" or, as Apple likes to say, "performance". An article like this one on AMX matrix processing, specifically for GEMM-like matrix calculations and an achievable 2x speed-up on matrix calculations, is a fun example.

In other words, when you compile generic ARM plus vector instructions from Eigen and OpenBLAS, you will not always be able to achieve the same speed as Accelerate with its internal use of AMX instructions. And, like BNNS, possibly even ANE use for some data types.

The Secret Apple M1 Coprocessor AMX
https://medium.com/swlh/apples-m1-secret-coprocessor-6599492fc1e1

  2. For GEMM operations on int8 and fp16, I used to use BNNS with a linear layer for the same reasons as (1).

  3. Finally, I wonder whether you are familiar with the Apple SIMD data types such as vector_int16 or vector_float16, i.e. vectors of sixteen 32-bit signed integer or float elements. The advantage of using them with Apple clang for macOS and iOS is that they can be compiled into the "best" instruction sequence, while using more generic data types and third-party libraries would not always achieve that speed. These data types are probably not so relevant for the sparse addmv discussion here, but then again, they are another tool to achieve speed and simple, maintainable high-performance code.

https://developer.apple.com/documentation/accelerate/simd?language=objc

@tvercaut
Author

Accelerate has BLAS implementations for single and double complex as well.

The issue discussed here relates to sparse matrix operations. These are not covered by BLAS / LAPACK. As far as the Accelerate documentation goes, only float and double are supported for sparse matrices:
https://developer.apple.com/documentation/accelerate/creating_sparse_matrices?language=objc

It might be obvious to you all but there is another advantge to using the Apple Accelerate framework instead of a 3rd party Eigen, OpenBlas, etc library.

Benchmarking these is complex but sure, I would expect Apple libs to perform better on Apple hardware than third-party ones. My take is that for a project like PyTorch, it makes sense to get a generic implementation that works mostly everywhere before focusing on getting the most performance out of a single platform. So to me (but I have no say in PyTorch governance or strategy), it's not "one or the other" but maybe just "one before the other"...

That said, I am sure the devs would love to get more PRs to improve performance on platform-specific paths.

Also my understanding is that for dense ops, PyTorch can already benefit from the most efficient BLAS / LAPACK implementation on any given platform by linking with the corresponding library. There shouldn't be any need to tap into an Apple-specific API there. I would even guess that PyTorch is already using Apple Accelerate's BLAS / LAPACK in the default binary distribution:

otool -L .pyenv/versions/3.11.2/envs/my-virtual-env-3.11.2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib | grep Accelerate        
	/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate (compatibility version 1.0.0, current version 4.0.0)
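
One can also print the compile-time configuration from Python; a minimal check (the exact output depends on the build, but it should include the BLAS/LAPACK backend it was built against):

import torch
# Prints the build configuration string, e.g. compiler flags and BLAS/LAPACK backend
print(torch.__config__.show())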

@leovinus2001

Accelerate has BLAS implementations for single and double complex as well.

The issue discussed here relates to sparse matrix operations. These are not covered by BLAS / LAPACK. As far as the Accelerate documentation goes, only float and double are supported for sparse matrices: https://developer.apple.com/documentation/accelerate/creating_sparse_matrices?language=objc

Good point :) Indeed, I forgot for a moment that we are discussing sparse implementations here, and there the complex type is missing in Accelerate, it seems. My apologies. I will see whether I can file an improvement request with Apple on that.

It might be obvious to you all but there is another advantge to using the Apple Accelerate framework instead of a 3rd party Eigen, OpenBlas, etc library.

Benchmarking these is complex but sure, I would expect Apple libs to perform better on Apple hardware than third-party ones. My take is that for a project like PyTorch, it makes sense to get a generic implementation that works mostly everywhere before focusing on getting the most performance out of a single platform. So to me (but I have no say in PyTorch governance or strategy), it's not "one or the other" but maybe just "one before the other"...

We are very much in agreement here. For something like PyTorch, two setups and CPU/BLAS libraries make sense. One cross-platform and well-known library is great for bootstrapping, reference cases and unit tests; it would even work with Linux-on-M1-Mac. And one that uses the most high-performance stuff on the target platform, like what has been done now with GPU/MPS and PyTorch.

Also my understanding is that for dense ops, PyTorch can already benefit from the most efficient BLAS / LAPACK implementation on any given platform by linking with the corresponding library. There shouldn't be any need to tap into an Apple-specific API there. I would even guess that PyTorch is already using Apple Accelerate's BLAS / LAPACK in the default binary distribution:

otool -L .pyenv/versions/3.11.2/envs/my-virtual-env-3.11.2/lib/python3.11/site-packages/torch/lib/libtorch_cpu.dylib | grep Accelerate        
	/System/Library/Frameworks/Accelerate.framework/Versions/A/Accelerate (compatibility version 1.0.0, current version 4.0.0)

To be honest, I am not sure about that. For example, I am not sure that the cblas_dgemm() from Accelerate is used by the latest PyTorch on any Apple macOS or iOS setup.

The "otool" command you mention is a great check but it only says that Accelerate is used for "something". As we saw earlier in our discussion, the current PyTorch uses a reference to vDSP_vthres() in Accelerate for ReLU which could entirely explain your otool output.

I did investigate this very question of cblas_dgemm()-via-Accelerate a few years ago, probably with PyTorch 1.07 or similar. I also used "otool", built PyTorch from scratch on macOS, and grep'd the source for the relevant calls. At that time, I was unable to find any reference to Accelerate GEMM operations in use on macOS. This might have changed later, of course. Also, when building PyTorch from scratch, the build scripts allowed me to build on macOS with MKL and OpenBLAS GEMM, but as for the fleeting hint at the Accelerate GEMMs, I was unable to make it work; it seemed like that was incomplete. Hence I concluded that cblas_sgemm() et al. via Accelerate was unsupported at the time in PyTorch.

Personally, I would be most happy if even the dense-only cblas_sgemm() and cblas_dgemm() from Accelerate/CPU are in use on Apple Silicon by the latest PyTorch. I will have another look at that. Or if anyone knows for sure, feel free to confirm, thanks.

Back to sparse addmv() examples :)

pytorchmergebot pushed a commit that referenced this issue Mar 24, 2023
Partially addresses the problem raised in #96972

Add `test_addmv` and enable `test_block_addmv` on all platforms (so the test could be run on M1)

TODO: Make sure that test_block_addmv non-contiguous mode actually
generates non-contiguous values, as right now it probably does not, as the test
passes assuming values are contiguous.

Pull Request resolved: #97353
Approved by: https://github.com/cpuhrsch
@leovinus2001

leovinus2001 commented Mar 24, 2023

Accelerate has BLAS implementations for single and double complex as well.

The issue discussed here relates to sparse matrix operations. These are not covered by BLAS / LAPACK. As far as the Accelerate documentation goes, only float and double are supported for sparse matrices: https://developer.apple.com/documentation/accelerate/creating_sparse_matrices?language=objc

Good point :) Indeed, I forgot for a moment that we are discussing sparse implementations here, and there the complex type is missing in Accelerate, it seems. My apologies. I will see whether I can file an improvement request with Apple on that.

PS: To request missing complex and more sparse CPU functionality in Accelerate, I have opened a bug report aka FeedbackRequest (FB) with Apple as number FB12078666 at https://feedbackassistant.apple.com/feedback/12078666

@malfet
Contributor

malfet commented Mar 30, 2023

@tvercaut so addmv should be available on Apple silicon now, will see what can be done about triangular_solve later.
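
A quick way to double-check on a build that includes #97353 is to rerun the snippet from the issue description:

import torch
mat = torch.randn((4, 4), dtype=torch.float32).relu().to_sparse_csr()
mvec = torch.randn(4)
avec = torch.randn(4)
# Should now print a dense result instead of raising the "requires compiling PyTorch with MKL" RuntimeError
print(torch.addmv(avec, mat, mvec))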

@tvercaut
Author

tvercaut commented Apr 15, 2023

EDIT: Removed the need for PYTORCH_ENABLE_MPS_FALLBACK=1 since scatter_add is supported on MPS (but not the more generic scatter_reduce).

Somewhat off topic for this issue, which is focused on CPU support, but anyway. For what it's worth, I tried implementing a workaround for the lack of sparse-dense multiplication on MPS and got a proof of concept running. If anyone is interested, see below.

Code snippet
import torch

print(f"Running PyTorch version: {torch.__version__}")

torchdevice = torch.device("cpu")
if torch.cuda.is_available():
    torchdevice = torch.device("cuda")
    print("Default GPU is " + torch.cuda.get_device_name(torch.device("cuda")))
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    torchdevice = torch.device("mps")
    print("mps backend available")
print("Running on " + str(torchdevice))

import numpy as np
import typing

# Hacky workaround for lack of CSR matrices on mps
class SparseCsrHack:
    _crow_indices: torch.Tensor
    _col_indices: torch.Tensor
    _values: torch.Tensor
    _size: typing.Union[list, tuple, torch.Size]
    is_sparse = True
    layout = torch.sparse_csr

    def __init__(
        self,
        crow_indices: torch.Tensor,
        col_indices: torch.Tensor,
        values: torch.Tensor,
        size: typing.Union[list, tuple, torch.Size],
    ):
        self._crow_indices = crow_indices
        self._col_indices = col_indices
        self._values = values
        self._size = torch.Size(size)

    def crow_indices(self) -> torch.Tensor:
        return self._crow_indices

    def col_indices(self) -> torch.Tensor:
        return self._col_indices

    def values(self) -> torch.Tensor:
        return self._values

    def size(self) -> torch.Size:
        return self._size

# Hacky workaround for lack of COO matrices on mps
class SparseCooHack:
    _indices: torch.Tensor
    _values: torch.Tensor
    _size: typing.Union[list, tuple, torch.Size]
    is_sparse = True
    layout = torch.sparse_coo

    def __init__(self, indices: torch.Tensor, values: torch.Tensor, size: typing.Union[list, tuple, torch.Size]):
        self._indices = indices
        self._values = values
        self._size = torch.Size(size)

    def indices(self) -> torch.Tensor:
        return self._indices

    def values(self) -> torch.Tensor:
        return self._values

    def size(self) -> torch.Size:
        return self._size


# Hacky and incomplete workaround for lack of spmv / spmm on mps
def spmv_hack(A, x):
    if not A.is_sparse:
        return A @ x
    if A.layout == torch.sparse_coo:
        A_row_idx, A_col_idx = A.indices()
    elif A.layout == torch.sparse_csr:
        A_col_idx = A.col_indices()
        A_crow_idx = A.crow_indices()
        # Uncompress row indices:
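        # Note: torch.repeat_interleave with a single (counts) tensor argument expands
        # per-row nonzero counts into per-nonzero row indices, e.g. [1, 2, 0] -> [0, 1, 1]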
        A_row_idx = torch.repeat_interleave(A_crow_idx[1:] - A_crow_idx[:-1])
        # workaround for "UserWarning: MPS: no support for int64 repeats mask, casting it to int32"
        if A_row_idx.dtype != A_crow_idx.dtype:
            A_row_idx = A_row_idx.to(dtype=A_crow_idx.dtype)
    rep_x = torch.index_select(x, dim=0, index=A_col_idx)
    Arepx = A.values()[..., None] * rep_x
    # Use zeros rather than empty so scatter_add accumulates onto a well-defined
    # buffer (rows of A without any stored values stay zero)
    etmp = torch.zeros((A.size()[0], x.size()[1]), device=x.device, dtype=x.dtype)
    b = etmp.scatter_add(0, A_row_idx[..., None].repeat(1, x.size()[1]), Arepx)
    return b


# Super simple test
s1 = 10
s2 = 3
s3 = 2
A_dense = torch.randn((s1, s2), dtype=torch.float32, device=torchdevice).relu()
if A_dense.device.type != "mps":
    A_csr = A_dense.to_sparse_csr()
    A_coo = A_dense.to_sparse_coo()
else:
    A_csr_cpu = A_dense.cpu().to_sparse_csr()
    A_csr = SparseCsrHack(
        crow_indices=A_csr_cpu.crow_indices().to(A_dense.device),
        col_indices=A_csr_cpu.col_indices().to(A_dense.device),
        values=A_csr_cpu.values().to(A_dense.device),
        size=A_csr_cpu.size(),
    )
    A_coo_cpu = A_dense.cpu().to_sparse_coo()
    A_coo = SparseCooHack(
        indices=A_coo_cpu.indices().to(A_dense.device),
        values=A_coo_cpu.values().to(A_dense.device),
        size=A_coo_cpu.size(),
    )
x = torch.randn((s2, s3), dtype=torch.float32, device=torchdevice)

np.set_printoptions(precision=3)

b_ref = A_dense @ x
print(f"b_ref=\n{b_ref.T.cpu().numpy()}\n")

if x.device.type != "mps":
    b_fromcsr = A_csr @ x
    print(f"b_fromcsr=\n{b_fromcsr.T.cpu().numpy()}")
    print(f"allclose:{torch.allclose(b_fromcsr,b_ref)}\n")

b_fromcsrhack = spmv_hack(A_csr, x)
print(f"b_fromcsrhack=\n{b_fromcsrhack.T.cpu().numpy()}")
print(f"allclose:{torch.allclose(b_fromcsrhack,b_ref)}\n")

if x.device.type != "mps":
    b_fromcoo = A_coo @ x
    print(f"b_fromcoo=\n{b_fromcoo.T.cpu().numpy()}")
    print(f"allclose:{torch.allclose(b_fromcoo,b_ref)}\n")

b_fromcoohack = spmv_hack(A_coo, x)
print(f"b_fromcoohack=\n{b_fromcoohack.T.cpu().numpy()}")
print(f"allclose:{torch.allclose(b_fromcoohack,b_ref)}\n")

b_fromdensehack = spmv_hack(A_dense, x)
print(f"b_fromdensehack=\n{b_fromdensehack.T.cpu().numpy()}")
print(f"allclose:{torch.allclose(b_fromdensehack,b_ref)}\n")

Let me know if you think anything in there warrants creating a new issue. And apologies for the duplicate but I also mentioned this in #77764 (comment)
