Sparse CSR: Add torch.sin #68123

Closed · wants to merge 17 commits · showing changes from 15 commits
12 changes: 12 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
@@ -274,6 +274,18 @@ Tensor empty_like(

auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Preserve);

if (options.layout() == kSparseCsr && self.is_sparse_csr()) {
auto result = at::native::_sparse_csr_tensor_unsafe(
self.crow_indices().clone(),
self.col_indices().clone(),
at::empty(self.values().sizes(), options.layout(kStrided)),
self.sizes(),
c10::typeMetaToScalarType(options.dtype()),
options.layout(),
options.device()
);
return result;
}

if (self.is_quantized()) {

// TODO: To support all features of MemoryFormat::Preserve we need to add
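For orientation, here is a minimal Python sketch of what the new empty_like branch above enables, assuming a PyTorch build that includes this PR (the tensor contents are illustrative):

import torch

# Small CSR tensor to mirror the C++ path above.
crow = torch.tensor([0, 2, 4])
col = torch.tensor([0, 1, 0, 1])
values = torch.tensor([1.0, 2.0, 3.0, 4.0])
csr = torch.sparse_csr_tensor(crow, col, values, size=(2, 2))

like = torch.empty_like(csr)

# The result keeps the sparse CSR layout, clones the index tensors,
# and allocates an uninitialized values tensor of the same shape.
assert like.layout == torch.sparse_csr
assert torch.equal(like.crow_indices(), csr.crow_indices())
assert torch.equal(like.col_indices(), csr.col_indices())
assert like.values().shape == csr.values().shape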
5 changes: 5 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -3819,18 +3819,23 @@
device_check: NoCheck # TensorIterator
structured_delegate: sin.out
variants: function, method
dispatch:
SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr

- func: sin_(Tensor(a!) self) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured_delegate: sin.out
variants: function, method
dispatch:
SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_

- func: sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
device_check: NoCheck # TensorIterator
structured: True
structured_inherits: TensorIteratorBase
dispatch:
CPU, CUDA: sin_out
SparseCsrCPU, SparseCsrCUDA: sin_sparse_csr_out

- func: sinc(Tensor self) -> Tensor
structured_delegate: sinc.out
6 changes: 4 additions & 2 deletions aten/src/ATen/native/sparse/SparseCsrTensor.cpp
@@ -311,9 +311,11 @@ const SparseCsrTensor& resize_as_sparse_csr_(
const SparseCsrTensor& src) {
TORCH_CHECK(
src.is_sparse_csr() && self.is_sparse_csr(),
"resize_as_sparse_csr_: layout for self and src must be sparse_csr but got self, src: ",
"resize_as_sparse_csr_: layout for self and src must be sparse_csr but got ",
self.layout(),
src.layout());
" for self, and ",
src.layout(),
" for src");
Comment on lines +316 to +318 (Contributor Author):

Earlier, the message ended with something like:

    but got self, src: kStridedkStrided

if (!_is_same_size_as_sparse_csr(self, src)) {
get_sparse_csr_impl(self)->resize_as_sparse_csr_tensor_(src);
}
60 changes: 60 additions & 0 deletions aten/src/ATen/native/sparse/SparseCsrTensorMath.cpp
@@ -60,6 +60,29 @@ void convert_indices_from_coo_to_csr_cpu(const Tensor& result, const Tensor& inp
data_out[i] = static_cast<output_t>(numel);
}

template <typename F, typename ...Args>
Tensor& unary_op_out(F op_out, const Tensor& self, Tensor& result, Args&&... args) {
TORCH_INTERNAL_ASSERT(self.is_sparse_csr());
TORCH_INTERNAL_ASSERT(result.is_sparse_csr());

if (!result.is_same(self)) {
// For the case of (0x0) result tensor, manually resize `result` tensor
// to the size of `self` tensor
if (result.numel() == 0) {
at::native::resize_as_sparse_csr_(result, self);
}
Comment on lines +81 to +83 (Contributor Author):

at::native::resize_output(result.col_indices(), self.col_indices().sizes());
// OR
result.col_indices().resize_(self.col_indices().sizes());

Neither of these resizes the result.col_indices() tensor. Is this a bug, or is it expected? cc: @IvanYashchuk @cpuhrsch

Reply (Collaborator):

If we write

auto col_indices = result.col_indices();
col_indices.resize_(self.col_indices().sizes());

we would see that col_indices is resized, but result.col_indices() would still have the original size. To replace result.col_indices() we would need to use the set_member_tensors method (result.unsafeGetTensorImpl()->set_member_tensors(...)).
It's expected; this issue is related to #63549.

// copy_sparse_csr_ internally checks the sizes of result and self tensors
// Hence no external size check required
at::native::copy_sparse_csr_(result, self);
}

auto self_values = self.values();
auto result_values = result.values();

op_out(self_values, std::forward<Args>(args)..., result_values);
return result;
}

} // end anonymous namespace

namespace native {
@@ -68,6 +91,30 @@ using namespace at::sparse_csr;
// certain utility functions are usable from sparse COO.
using namespace at::sparse;

namespace {

template <typename F>
inline Tensor get_result_tensor_for_unary_op(F op, const Tensor& input) {
auto values = input.values();

// To handle type promotion for inputs to unary ops,
// we first get the result from the underlying op, and use the result
// to create a sparse CSR tensor, which is used as the input to the out= variant
auto result_values = op(values);

auto result = at::native::_sparse_csr_tensor_unsafe(
input.crow_indices().clone(),
input.col_indices().clone(),
result_values,
input.sizes(),
result_values.scalar_type(),
input.layout(),
result_values.device());

return result;
}
}

static constexpr bool is_mkl_supported() {
#ifdef _MSC_VER
return false;
@@ -85,6 +132,19 @@ bool is_square_or_vec(int64_t dim_i, int64_t dim_j, int64_t dim_k) {
return (dim_i == dim_k && dim_k == dim_j) || (dim_i == dim_j && dim_k == 1);
}

Tensor& sin_sparse_csr_out(const Tensor& self, Tensor& result) {
return unary_op_out(&at::sin_outf, self, result);
}

Tensor sin_sparse_csr(const Tensor& self) {
auto result = get_result_tensor_for_unary_op(&at::sin, self);
return sin_sparse_csr_out(self, result);
}

Tensor& sin_sparse_csr_(Tensor& self) {
return sin_sparse_csr_out(self, self);
}

template <typename scalar_t>
void addmm_out_sparse_csr_native_cpu(const Tensor& sparse, const Tensor& dense, const Tensor& r, Scalar alpha, Scalar beta) {

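As a hedged usage sketch of what the sin kernels above provide (again assuming a build that includes this PR), torch.sin on a sparse CSR tensor operates on the values only and preserves the sparsity pattern:

import torch

crow = torch.tensor([0, 2, 4])
col = torch.tensor([0, 1, 0, 1])
values = torch.tensor([0.0, 0.5, 1.0, 1.5])
csr = torch.sparse_csr_tensor(crow, col, values, size=(2, 2))

out = torch.sin(csr)                # dispatched to sin_sparse_csr / sin_sparse_csr_out
assert out.layout == torch.sparse_csr
assert torch.allclose(out.values(), torch.sin(values))

csr.sin_()                          # in-place variant, dispatched to sin_sparse_csr_
assert torch.allclose(csr.values(), out.values())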
24 changes: 23 additions & 1 deletion test/test_sparse_csr.py
@@ -8,8 +8,9 @@
from torch.testing._internal.common_utils import \
(TEST_WITH_ROCM, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric,
+    (ops, instantiate_device_type_tests, dtypes, dtypesIfCUDA, onlyCPU, onlyCUDA, skipCUDAIfNoCusparseGeneric,
precisionOverride, skipMeta, skipCUDAIf, skipCPUIfNoMklSparse)
from torch.testing._internal.common_methods_invocations import (unary_ufuncs, )
from torch.testing._internal.common_cuda import _get_torch_cuda_version
from torch.testing._internal.common_dtype import floating_types, get_all_dtypes
from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED
@@ -30,6 +31,7 @@ def _check_cusparse_spgemm_available():
min_supported_version = (11, 0)
return version >= min_supported_version

_sparse_csr_unary_ops = list(filter(lambda op: op.supports_sparse_csr, unary_ufuncs))

# This should be just an import from test_linalg instead of code duplication
# but https://github.com/pytorch/pytorch/pull/63511#discussion_r733989701
@@ -940,6 +942,26 @@ def test_coo_csr_conversion(self, device, dtype):

self.assertEqual(csr_sparse.to_dense(), dense)

@ops(_sparse_csr_unary_ops)
def test_sparse_csr_unary(self, device, dtype, op):
samples = op.sample_inputs(device, dtype)

if len(samples) == 0:
self.skipTest("Skipped! No sample inputs!")

for sample in samples:
assert torch.is_tensor(sample.input)
# Sparse CSR only supports 2D tensors as inputs
# Fail early to prevent silent success with this test
if sample.input.ndim != 2:
raise ValueError(f"Expected 2D tensor but got tensor with dimension: {sample.input.ndim}.")

expected = op(sample.input)
assert torch.is_tensor(expected)
Comment (Contributor):

nit: it might be better to use self.assertTrue to get a better error message (we can do that in a follow-up PR or together with other changes for more unary ops)

output = op(sample.input.to_sparse_csr())
assert torch.is_tensor(output)
self.assertEqual(output.to_dense(), expected)
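Condensed and hedged, the OpInfo-driven test above effectively checks the following for each supported op, shown here for sin with a concrete 2D input (the real test iterates over every op with supports_sparse_csr=True and over its sample inputs):

import torch

dense = torch.randn(5, 5)
expected = torch.sin(dense)                  # reference result on the strided input
output = torch.sin(dense.to_sparse_csr())    # same op on the CSR-converted input

assert torch.is_tensor(output)
assert torch.allclose(output.to_dense(), expected)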


# e.g., TestSparseCSRCPU and TestSparseCSRCUDA
instantiate_device_type_tests(TestSparseCSR, globals())
22 changes: 16 additions & 6 deletions torch/testing/_internal/common_methods_invocations.py
@@ -581,6 +581,8 @@ def __init__(self,
supports_sparse=False, # whether the op supports sparse inputs

supports_scripting=True, # only run tracing tests
# the following metadata relates to sparse csr support and is used in test_sparse_csr.py
supports_sparse_csr=False, # whether the op supports sparse csr inputs
# the following metadata relates to complex support and is checked in test_ops.py
test_conjugated_samples=True,
test_neg_view=True,
@@ -707,6 +709,7 @@ def __init__(self,
self.supports_inplace_autograd = supports_inplace_autograd

self.supports_sparse = supports_sparse
self.supports_sparse_csr = supports_sparse_csr

self.aliases = ()
if aliases is not None:
@@ -1085,12 +1088,18 @@ def sample_inputs_unary(op_info, device, dtype, requires_grad, **kwargs):
low = low if low is None else low + op_info._domain_eps
high = high if high is None else high - op_info._domain_eps

-    return (SampleInput(make_tensor((L,), device=device, dtype=dtype,
-                                    low=low, high=high,
-                                    requires_grad=requires_grad)),
-            SampleInput(make_tensor((), device=device, dtype=dtype,
-                                    low=low, high=high,
-                                    requires_grad=requires_grad)))
+    if op_info.supports_sparse_csr:
+        # Tensors with dim=2 for sparse CSR testing
+        return (SampleInput(make_tensor((L, L), device=device, dtype=dtype,
+                                        low=low, high=high,
+                                        requires_grad=requires_grad)),)
+    else:
+        return (SampleInput(make_tensor((L,), device=device, dtype=dtype,
+                                        low=low, high=high,
+                                        requires_grad=requires_grad)),
+                SampleInput(make_tensor((), device=device, dtype=dtype,
+                                        low=low, high=high,
+                                        requires_grad=requires_grad)))

# Metadata class for unary "universal functions (ufuncs)" that accept a single
# tensor and have common properties like:
@@ -10417,6 +10426,7 @@ def ref_pairwise_distance(input1, input2):
handles_large_floats=False,
handles_complex_extremals=False,
safe_casts_outputs=True,
supports_sparse_csr=True,
supports_forward_ad=True,
decorators=(precisionOverride({torch.bfloat16: 1e-2}),)),
UnaryUfuncInfo('sinc',