torch.mm(dense, sparse_csr) #73686
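What this PR adds: a torch.mm path where the first argument is a strided (dense) matrix and the second is a sparse CSR matrix. As a rough illustration, a minimal libtorch sketch follows; it assumes a build that includes this PR and the torch::sparse_csr_tensor factory, and the shapes and values are made up for the example.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // Dense 2x3 matrix.
      auto dense = torch::randn({2, 3});

      // 3x2 CSR matrix with two stored values:
      // (0, 0) = 1.0 and (2, 1) = 2.0.
      auto crow_indices = torch::tensor({0, 1, 1, 2}, torch::kInt64);
      auto col_indices  = torch::tensor({0, 1}, torch::kInt64);
      auto values       = torch::tensor({1.0f, 2.0f});
      auto sparse_csr   = torch::sparse_csr_tensor(
          crow_indices, col_indices, values, {3, 2},
          torch::TensorOptions().dtype(torch::kFloat32));

      // The path exercised by this PR: dense @ sparse_csr.
      auto result = torch::mm(dense, sparse_csr);
      std::cout << result << std::endl;
      return 0;
    }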
@@ -0,0 +1,64 @@
#pragma once

#include <ATen/Tensor.h>
#include <ATen/core/Scalar.h>

namespace at {
namespace native {
namespace sparse {
namespace impl {

// Returns true if all entries of self are zero
// TODO: This has potential to be a generic helper
inline bool _is_all_zero(const Tensor& self) {
Review comment: The strided path of this helper function introduces an unnecessary synchronization. I think it's possible to restructure the code and remove this check.
  if (self.is_sparse_csr() || self.is_sparse()) {
    if (self._nnz() == 0) {
      return true;
    }
    // Note: .item() copies the count back to the host, synchronizing
    // on CUDA; see the review comment above.
    return (self.values().count_nonzero().item<int64_t>() == 0);
  }
  return (self.count_nonzero().item<int64_t>() == 0);
}

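// [Editor's sketch, not part of this diff] One hypothetical way to act on
// the review comment above: answer the "all zero" question from sparse
// metadata alone, avoiding the device-to-host copy (and CUDA sync) that
// count_nonzero().item<int64_t>() forces. The name and the conservative
// fallback are assumptions, not code from this PR.
inline bool _is_sparse_all_zero_nosync(const Tensor& self) {
  if (self.is_sparse_csr() || self.is_sparse()) {
    return self._nnz() == 0;  // metadata only; misses explicit zeros
  }
  return false;  // treat strided tensors as possibly nonzero
}
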
inline void _check_is_cpu(const Tensor& self, c10::string_view name) {
  TORCH_CHECK(
      self.is_cpu(),
      "Expected all tensors to be on the same device. addmm expected '",
      name,
      "' to be CPU tensor, but got ",
      self.device(),
      " tensor");
}

inline void _check_is_cuda(const Tensor& self, c10::string_view name) {
  TORCH_CHECK(
      self.is_cuda(),
      "Expected all tensors to be on the same device. addmm expected '",
      name,
      "' to be CUDA tensor, but got ",
      self.device(),
      " tensor");
}

inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view name) {
  if (target_dim == 2) {
    // Matrices get a more specific message; if this check throws, the
    // generic check below is never reached.
    TORCH_CHECK(
        self.dim() == target_dim,
        name, " must be a matrix, ",
        "got ", self.dim(), "-D tensor");
  }
  TORCH_CHECK(
      self.dim() == target_dim,
      "Expected ",
      name,
      " to be of dimension ",
      target_dim,
      " but got ",
      self.dim(),
      " instead.");
}

} // namespace impl
} // namespace sparse
} // namespace native
} // namespace at
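For context, a hypothetical call site for these helpers might look like the sketch below. The entry-point name and argument names are illustrative assumptions, not part of this PR; only the helpers themselves are defined above.

    // Hypothetical validation for an addmm-style CPU kernel (names assumed).
    void check_addmm_inputs_cpu(const at::Tensor& self,
                                const at::Tensor& mat1,
                                const at::Tensor& mat2) {
      using namespace at::native::sparse::impl;
      _check_is_cpu(self, "self");
      _check_is_cpu(mat1, "mat1");
      _check_is_cpu(mat2, "mat2");
      _check_dim(mat1, 2, "mat1");
      _check_dim(mat2, 2, "mat2");
    }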
Review comment: This might already be in the generated code. The easy way to check is to pass tensors on different devices and see whether you get an error from these lines or from somewhere higher in the call chain.
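As a rough illustration of that probe, the snippet below (a sketch assuming a CUDA-enabled libtorch build; not from the PR) deliberately mixes devices and prints whichever error surfaces first. The wording of the message tells you whether it came from these helpers or from the generated dispatch code.

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      auto cpu_mat  = torch::randn({4, 4});
      auto cuda_mat = torch::randn({4, 4}, torch::kCUDA);
      try {
        // Mixed devices: some layer must reject this.
        auto out = torch::addmm(cpu_mat, cuda_mat, cpu_mat);
      } catch (const c10::Error& e) {
        std::cout << e.what() << std::endl;  // message reveals which check fired
      }
      return 0;
    }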
Reply: Yes, but then we'd have to update all the error-message tests, etc. Sounds like good follow-up work.