torch.mm(dense, sparse_csr) #73686

Closed
wants to merge 13 commits
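
For orientation, here is a minimal Python-level sketch of the operation this PR enables (tensor names and shapes are illustrative, and it assumes a build that includes this change):

```python
import torch

dense = torch.randn(3, 4)                    # strided (dense) matrix
sparse = torch.randn(4, 5).to_sparse_csr()   # sparse CSR matrix

# Dense @ sparse CSR on CPU: the combination this PR adds.
out = torch.mm(dense, sparse)
print(out.shape)    # torch.Size([3, 5])
print(out.layout)   # torch.strided; the result is a dense tensor
```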
5 changes: 5 additions & 0 deletions aten/src/ATen/native/BinaryOps.cpp
@@ -618,6 +618,11 @@ Tensor& mul_(Tensor& self, const Scalar& other) {
return at::mul_out(self, wrapped_scalar_tensor(other), self); // redispatch!
}

Tensor& mul__scalar_sparse_csr(Tensor& self, const Scalar& other) {
self.values().mul_(other);
return self;
}

Device correct_out_device(const Tensor& self, const Tensor& other) {
if (self.device() == at::kCPU){
return other.device();
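The new mul__scalar_sparse_csr kernel above scales only the CSR values buffer in place. A short sketch of the Python-level behavior it backs (illustrative, assuming a build with this change):

```python
import torch

a = torch.tensor([[0., 2.], [3., 0.]]).to_sparse_csr()
a.mul_(2.0)               # in-place scalar multiply, dispatched to the CSR kernel
print(a.values())         # tensor([4., 6.]); only the stored values are scaled
print(a.crow_indices())   # the sparsity structure is left untouched
```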
19 changes: 11 additions & 8 deletions aten/src/ATen/native/mkl/SparseBlasImpl.cpp
@@ -340,18 +340,21 @@ void addmm_out_sparse_csr(
const Scalar& alpha,
const Tensor& result) {
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mat1.dim() == 2 && mat2.dim() == 2 && result.dim() == 2);
if (mat2.layout() == kStrided && result.layout() == kStrided) {
if (mat1.is_sparse_csr() && mat2.layout() == kStrided && result.layout() == kStrided) {
return addmm_dense_result(mat1, mat2, beta, alpha, result);
} else if (
mat1.is_sparse_csr() && mat2.is_sparse_csr() &&
result.layout() == kStrided) {
}
if (mat1.layout() == kStrided && mat2.is_sparse_csr() && result.layout() == kStrided) {
// TODO: We can use MKL's transposition flags once we have CSC support.
return addmm_dense_result(mat2.transpose(0, 1), mat1.transpose(0, 1), beta, alpha, result.transpose(0, 1));
}
if (mat1.is_sparse_csr() && mat2.is_sparse_csr() && result.layout() == kStrided) {
return addmm_sparse_input_dense_result(mat1, mat2, beta, alpha, result);
} else if (mat2.is_sparse_csr() && result.is_sparse_csr()) {
}
if (mat1.is_sparse_csr() && mat2.is_sparse_csr() && result.is_sparse_csr()) {
return addmm_sparse_result(mat1, mat2, beta, alpha, result);
} else {
TORCH_CHECK(false, "addmm: computation on CPU is not implemented for ",
result.layout(), " + ", mat1.layout(), " @ ", mat2.layout());
}
TORCH_CHECK(false, "addmm: computation on CPU is not implemented for ",
result.layout(), " + ", mat1.layout(), " @ ", mat2.layout());
}
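
The new strided-times-CSR branch reuses the existing CSR-times-dense kernel through the identity (A·B)ᵀ = Bᵀ·Aᵀ, since CSC support is not available yet. A quick numerical check of that identity (purely illustrative; done in dense arithmetic so it does not depend on CSR transpose support):

```python
import torch

A = torch.randn(3, 4)   # plays the role of the strided mat1
B = torch.randn(4, 5)   # stands in for the CSR mat2, kept dense for the check

direct = A @ B
via_transpose = (B.t() @ A.t()).t()   # the transposition trick used above
assert torch.allclose(direct, via_transpose, atol=1e-6)
```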

/*
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3209,6 +3209,7 @@
variants: method
dispatch:
CompositeExplicitAutograd: mul_
SparseCsrCPU, SparseCsrCUDA: mul__scalar_sparse_csr

# multiply, alias for mul
- func: multiply.Tensor(Tensor self, Tensor other) -> Tensor
146 changes: 66 additions & 80 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -11,6 +11,7 @@
#include <ATen/native/Resize.h>
#include <ATen/native/mkl/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseBlasImpl.h>
#include <ATen/native/sparse/SparseCsrTensorMath.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
@@ -411,91 +412,64 @@ void addmm_out_sparse_csr_native_cpu(
}

// Functions for matrix multiplication.
// result = beta * self + alpha (mat1 @ mat2)
Tensor& addmm_out_sparse_csr_cpu(
const Tensor& self,
const Tensor& mat1,
const Tensor& mat2,
const Scalar& beta,
const Scalar& alpha,
Tensor& result) {
TORCH_INTERNAL_ASSERT(mat1.is_sparse_csr());

// TODO: remove this, there are no codegenerated checks for devices yet
TORCH_CHECK(
!self.is_cuda(),
"Expected all tensors to be on the same device. addmm expected 't' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
!result.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'out' to be CPU tensor, but got CUDA tensor");
TORCH_CHECK(
!mat1.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat1' to be a CPU tensor, but got a CUDA tensor");
TORCH_CHECK(
!mat2.is_cuda(),
"Expected all tensors to be on the same device. addmm: expected 'mat2' to be a CPU tensor, but got a CUDA tensor");
sparse::impl::_check_is_cpu(self, "self");
sparse::impl::_check_is_cpu(mat1, "mat1");
sparse::impl::_check_is_cpu(mat2, "mat2");
sparse::impl::_check_is_cpu(result, "result");
Comment on lines +424 to +427

Collaborator:
This might be already in the generated code. The easy way to check this is to pass tensors on different devices and see whether you get an error from these lines or from somewhere higher in the call chain.

Contributor Author:
Yes, but then we have to update all the error message tests etc. Sounds like good follow-up work.
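
A minimal probe along the lines the reviewer suggests (hypothetical snippet, requires a CUDA build; the point is only to see whether the device-mismatch error is raised by these checks or by generated code higher in the call chain):

```python
import torch

self_ = torch.randn(2, 4)                   # CPU dense
mat1 = torch.randn(2, 3).to_sparse_csr()    # CPU sparse CSR
mat2 = torch.randn(3, 4, device="cuda")     # CUDA dense

try:
    torch.addmm(self_, mat1, mat2)
except RuntimeError as err:
    # The message (and traceback) shows which layer rejected the mix of devices.
    print(err)
```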


// All the checks are from addmm_out_cuda_impl (ATen/native/cuda/Blas.cpp) and
// TORCH_META_FUNC(addmm) (ATen/native/LinearAlgebra.cpp)
// TODO: remove code duplication and unify code
TORCH_CHECK(
mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor");
TORCH_CHECK(
mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor");
TORCH_CHECK(
mat1.sizes()[1] == mat2.sizes()[0],
"mat1 and mat2 shapes cannot be multiplied (",
mat1.sizes()[0],
"x",
mat1.sizes()[1],
" and ",
mat2.sizes()[0],
"x",
mat2.sizes()[1],
")");

IntArrayRef mat1_sizes = mat1.sizes();
IntArrayRef mat2_sizes = mat2.sizes();
IntArrayRef self__sizes;
c10::MaybeOwned<Tensor> self_;
if (&result != &self && self.layout() == kStrided) {
self_ = expand_size(self, {mat1_sizes[0], mat2_sizes[1]}, "addmm");
self__sizes = self_->sizes();
} else {
self_ = c10::MaybeOwned<Tensor>::borrowed(self);
self__sizes = self_->sizes();
}
sparse::impl::_check_dim(mat1, 2, "mat1");
sparse::impl::_check_dim(mat2, 2, "mat2");

TORCH_CHECK(
((self_->dim() == 2) && (self_->sizes()[0] == mat1.sizes()[0]) &&
(self_->sizes()[1] == mat2.sizes()[1])),
"The input tensor must be a matrix with size ",
mat1.sizes()[0],
"x",
mat2.sizes()[1],
", but got a ",
self_->dim(),
"-D tensor with size ",
self__sizes[0],
"x",
self__sizes[1]);
mat1.size(1) == mat2.size(0), "mat1 and mat2 shapes cannot be multiplied (",
mat1.size(0), "x", mat1.size(1), " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")");

c10::MaybeOwned<at::Tensor> self_ =
expand_size(self, {mat1.size(0), mat2.size(1)}, "addmm");


TORCH_CHECK(((self_->dim() == 2) &&
(self_->size(0) == mat1.size(0)) &&
(self_->size(1) == mat2.size(1))),
"The input tensor must be a matrix with size ",
mat1.size(0),
"x",
mat2.size(1),
", but got a ",
self_->dim(),
"-D tensor with size ",
self_->size(0),
"x",
self_->size(1));

if (&result != &self) {
if (result.layout() == kStrided) {
at::native::resize_output(result, self__sizes);
result.resize_as_(*self_);
} else {
at::native::resize_as_sparse_csr_(result, *self_);
result.resize_as_sparse_(*self_);
}
result.copy_(*self_);
}

IntArrayRef result_sizes = result.sizes();
if ((result_sizes[0] == 0) || (result_sizes[1] == 0)) {
if (result.numel() == 0) {
return result;
}

if (mat1._nnz() == 0 && mat2.layout() == kStrided) {
// According to docs, when beta==0 values in self should be ignored. nans
// and infs should not propagate
if (sparse::impl::_is_all_zero(mat1) || sparse::impl::_is_all_zero(mat2)) {
// According to docs, when beta==0 values in self should be ignored.
// nans and infs should not propagate
if (beta.toComplexDouble() == 0.) {
result.zero_();
} else {
@@ -504,22 +478,13 @@ Tensor& addmm_out_sparse_csr_cpu(
return result;
}

if (mat2.is_sparse_csr() && (mat1._nnz() == 0 || mat2._nnz() == 0)) {
if (beta.toComplexDouble() == 0.) {
result.values().zero_();
} else {
result.values().mul_(beta);
}
return result;
}

#if !AT_USE_MKL_SPARSE()
if (mat2.is_sparse_csr() && result.is_sparse_csr()) {
TORCH_CHECK(
false,
"Calling addmm on sparse CPU tensors requires Linux platform. ",
"Please use PyTorch built with MKL on Linux.");
}
TORCH_CHECK(
(mat1.is_sparse_csr() ||
(mat2.is_sparse_csr() && result.is_sparse_csr())),
false,
"Calling addmm on sparse CPU tensors requires Linux platform. ",
"Please use PyTorch built with MKL on Linux.");
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(result.layout() == kStrided);
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
result.scalar_type(), "addmm_sparse_dense", [&] {
@@ -558,14 +523,35 @@ Tensor& _sparse_csr_mm_out(
}

Tensor _sparse_csr_mm(const Tensor& mat1, const Tensor& mat2) {
Tensor zero;
if (mat1.is_sparse_csr() && mat2.is_sparse_csr()) {
// Return sparse
// TODO: replace with at::zeros when it's implemented for sparse csr
zero = at::empty({mat1.size(0), mat2.size(1)}, mat2.options());
} else {
zero = at::zeros({mat1.size(0), mat2.size(1)}, mat2.options());
return at::addmm(
at::empty({mat1.size(0), mat2.size(1)}, mat2.options()),
mat1,
mat2,
0.0,
1.0);
}
if (mat1.is_sparse_csr() && mat2.layout() == c10::kStrided) {
// Return dense
return at::addmm(
at::zeros({mat1.size(0), mat2.size(1)}, mat2.options()),
mat1,
mat2,
0.0,
1.0);
}
if (mat1.layout() == c10::kStrided && mat2.is_sparse_csr()) {
// Return dense
return at::addmm(
at::zeros({mat1.size(0), mat2.size(1)}, mat1.options()),
mat1,
mat2,
0.0,
1.0);
}
return at::addmm(zero, mat1, mat2, 0.0, 1.0);
TORCH_INTERNAL_ASSERT(false, "Shouldn't get here. Please open an issue.");
}
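
Summarizing the dispatch above from the Python side: the result of torch.mm is CSR only when both inputs are CSR, and dense otherwise. A hedged sketch (assuming a build with this change):

```python
import torch

d1, d2 = torch.randn(2, 3), torch.randn(3, 4)
s1, s2 = d1.to_sparse_csr(), d2.to_sparse_csr()

print(torch.mm(s1, s2).layout)   # torch.sparse_csr  (sparse @ sparse -> sparse)
print(torch.mm(s1, d2).layout)   # torch.strided     (sparse @ dense  -> dense)
print(torch.mm(d1, s2).layout)   # torch.strided     (dense  @ sparse -> dense, new in this PR)
```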

Tensor _sparse_csr_addmm(
64 changes: 64 additions & 0 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.h
@@ -0,0 +1,64 @@
#pragma once

#include <ATen/Tensor.h>
#include <ATen/core/Scalar.h>

namespace at {
namespace native {
namespace sparse {
namespace impl {

// Returns true if all entries of self are zero
// TODO: This has potential to be a generic helper
inline bool _is_all_zero(const Tensor& self) {
Collaborator:
The strided path of this helper function introduces an unnecessary synchronization. I think it's possible to restructure the code and remove this check. Dense addmm doesn't have checks for zero. nnz == 0 is checked because underlying backend libraries might not handle this case correctly, raising an error, so we need to manually fix the results.

if (self.is_sparse_csr() || self.is_sparse()) {
if (self._nnz() == 0) {
return true;
}
return (self.values().count_nonzero().item<int64_t>() == 0);
}
return (self.count_nonzero().item<int64_t>() == 0);
}
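
On the reviewer's synchronization point above: _nnz() only reads metadata, while count_nonzero().item() launches a reduction and copies the scalar back to the host, which forces a device synchronization on CUDA. A rough illustration of the two costs (helper names here are hypothetical, not part of the PR):

```python
import torch

def cheap_emptiness_check(csr):
    # Metadata only: no kernel launch and no host/device synchronization.
    return csr._nnz() == 0

def full_all_zero_check(t):
    # Runs a reduction over the values and calls .item(), which blocks
    # until the device has produced the scalar result.
    return t.count_nonzero().item() == 0
```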

inline void _check_is_cpu(const Tensor& self, c10::string_view name) {
TORCH_CHECK(
self.is_cpu(),
"Expected all tensors to be on the same device. addmm expected '",
name,
"' to be CPU tensor, but got ",
self.device(),
" tensor");
}

inline void _check_is_cuda(const Tensor& self, c10::string_view name) {
TORCH_CHECK(
self.is_cuda(),
"Expected all tensors to be on the same device. addmm expected '",
name,
"' to be CUDA tensor, but got ",
self.device(),
" tensor");
}

inline void _check_dim(const Tensor& self, int64_t target_dim, c10::string_view name) {
if (target_dim == 2) {
TORCH_CHECK(
self.dim() == target_dim,
name, " must be a matrix, ",
"got ", self.dim(), "-D tensor");
}
TORCH_CHECK(
self.dim() == target_dim,
"Expected ",
name,
" to be of dimension ",
target_dim,
" but got ",
self.dim(),
" instead.");
}

}
}
}
}
4 changes: 2 additions & 2 deletions aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -709,8 +709,8 @@ Tensor& mul_sparse_(Tensor& self, const Tensor& other) {

Tensor& mul_out_sparse_csr(const Tensor& t_, const Tensor& src_, Tensor& r) {
// // TODO: Use a specialized CSR kernel for performance if needed
TORCH_CHECK(t_.is_sparse_csr(), "mul(dense, sparse_csr) is not supported");
TORCH_CHECK(src_.is_sparse_csr(), "mul(sparse_csr, dense) is not supported");
TORCH_CHECK(t_.is_sparse_csr() || (t_.layout() == c10::kStrided && t_.dim() == 0), "mul(dense, sparse_csr) is not supported");
TORCH_CHECK(src_.is_sparse_csr() || (src_.layout() == c10::kStrided && src_.dim() == 0), "mul(sparse_csr, dense) is not supported");
TORCH_CHECK(r.is_sparse_csr(), "Expected result Tensor to be of format CSR");
Tensor t = t_.to_sparse();
Tensor src = src_.to_sparse();
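The relaxed checks above additionally accept a zero-dimensional strided (scalar) tensor as one operand instead of requiring both sides to be CSR. A hedged sketch of what that is expected to permit (assuming a build with this change):

```python
import torch

a = torch.tensor([[0., 2.], [3., 0.]]).to_sparse_csr()
b = a.mul(torch.tensor(2.0))   # CSR times a 0-dim dense tensor now passes the layout check
print(b.layout)                # torch.sparse_csr
print(b.values())              # tensor([4., 6.])
```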