diff --git a/aten/src/ATen/native/TensorCompare.cpp b/aten/src/ATen/native/TensorCompare.cpp
index ca8f039fe964600..b9c4739bfafa00b 100644
--- a/aten/src/ATen/native/TensorCompare.cpp
+++ b/aten/src/ATen/native/TensorCompare.cpp
@@ -60,27 +60,42 @@ TORCH_META_FUNC(isneginf) (const Tensor& self) {
   build_unary_force_boolean_op(maybe_get_output(), self);
 }
 
-TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+void check_minmax_for_meta(
+    impl::MetaBase& meta,
+    const char* name,
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim) {
   TORCH_CHECK(
       self.layout() == Layout::Strided,
-      "max(): only supports strided layout, got: ", self.layout());
+      name, ": only supports strided layout, got: ", self.layout());
   TORCH_CHECK(
       !self.is_complex(),
-      "max(): does not support complex input");
+      name, ": does not support complex input");
   dim = maybe_wrap_dim(dim, self.dim());
   DimVector sizes(self.sizes());
   if (self.numel() == 0) {
-    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, "max()");
+    sizes = at::native::get_zero_numel_tensor_size(self, dim, keepdim, name);
   } else {
     sizes = get_reduction_shape(self, dim, keepdim);
   }
-  set_output(0, sizes, self.options());
-  set_output(1, sizes, self.options().dtype(kLong));
-  namedinference::propagate_names_for_reduction(maybe_get_output(0), self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(maybe_get_output(1), self, dim, keepdim);
+  meta.set_output(0, sizes, self.options());
+  meta.set_output(1, sizes, self.options().dtype(kLong));
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(0), self, dim, keepdim);
+  namedinference::propagate_names_for_reduction(
+      meta.maybe_get_output(1), self, dim, keepdim);
+}
+
+TORCH_META_FUNC2(max, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "max()", self, dim, keepdim);
+}
+
+TORCH_META_FUNC2(min, dim)(const Tensor& self, int64_t dim, bool keepdim) {
+  check_minmax_for_meta(*this, "min()", self, dim, keepdim);
 }
 
 } // namespace meta
@@ -420,23 +435,43 @@ std::tuple<Tensor&, Tensor&> mode_out(const Tensor& self, int64_t dim, bool kee
   }
 }
 
-TORCH_IMPL_FUNC(max_out)
-(const Tensor& self,
- int64_t dim,
- bool keepdim,
- const Tensor& values,
- const Tensor& indices) {
+template <class Stub>
+void minmax_out_impl(
+    const Tensor& self,
+    int64_t dim,
+    bool keepdim,
+    const Tensor& values,
+    const Tensor& indices,
+    Stub& stub) {
   NoNamesGuard guard;
   if (self.numel() > 0) {
     if (self.numel() == 1 && self.dim() == 0) {
      values.fill_(self);
      indices.fill_(0);
    } else {
-      max_stub(self.device().type(), values, indices, self, dim, keepdim);
+      stub(self.device().type(), values, indices, self, dim, keepdim);
    }
  }
 }
 
+TORCH_IMPL_FUNC(max_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, max_stub);
+}
+
+TORCH_IMPL_FUNC(min_out)
+(const Tensor& self,
+ int64_t dim,
+ bool keepdim,
+ const Tensor& values,
+ const Tensor& indices) {
+  minmax_out_impl(self, dim, keepdim, values, indices, min_stub);
+}
+
 std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor max_indices = at::empty({0}, self.options().dtype(kLong));
   Tensor max = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
@@ -446,16 +481,12 @@ std::tuple<Tensor, Tensor> qmax(const Tensor& self, int64_t dim, bool keepdim) {
       at::_make_per_tensor_quantized_tensor(max, self.q_scale(), self.q_zero_point()), max_indices);
 }
 
-std::tuple<Tensor, Tensor> min(const Tensor& self, int64_t dim, bool keepdim) {
+std::tuple<Tensor, Tensor> qmin(const Tensor& self, int64_t dim, bool keepdim) {
   Tensor min_indices = at::empty({0}, self.options().dtype(kLong));
-  if (self.is_quantized()) {
-    Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
-    at::native::min_out(self.int_repr(), dim, keepdim, min, min_indices);
-    return std::tuple<Tensor, Tensor>(at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
-  } else {
-    Tensor min = at::empty({0}, self.options());
-    return at::native::min_out(self, dim, keepdim, min, min_indices);
-  }
+  Tensor min = at::empty({0}, self.options().dtype(toUnderlying(self.scalar_type())));
+  at::min_outf(self.int_repr(), dim, keepdim, min, min_indices);
+  return std::tuple<Tensor, Tensor>(
+      at::_make_per_tensor_quantized_tensor(min, self.q_scale(), self.q_zero_point()), min_indices);
 }
 
 static std::tuple<Tensor&, Tensor&> _aminmax_out_impl(Tensor& min, Tensor& max,
@@ -491,49 +522,6 @@ std::tuple<Tensor, Tensor> _aminmax(const Tensor& self, int64_t dim, bool keepdi
   return result;
 }
 
-static std::tuple<Tensor&, Tensor&> min_out_impl(Tensor& min, Tensor& min_indices,
-                                                 const Tensor& self, int64_t dim, bool keepdim) {
-  TORCH_CHECK(self.device().is_cpu() || self.is_cuda(),
-              "min only supports CPU AND CUDA device type, got: ", self.device().type());
-  TORCH_CHECK(self.layout() == Layout::Strided,
-              "min only supports strided layout, got: ", self.layout());
-  TORCH_CHECK(self.device() == min.device(),
-              "expected device ", self.device(), " but got ",
-              min.device(), " for min values output");
-  TORCH_CHECK(self.device() == min_indices.device(),
-              "expected device ", self.device(), " but got ",
-              min_indices.device(), " for indices output");
-  dim = maybe_wrap_dim(dim, self.dim());
-  if (self.numel() == 0) {
-    zero_numel_tensor_resize(min, min_indices, self, dim, keepdim, "min()");
-    return std::tie(min, min_indices);
-  }
-  else if (_dimreduce_return_trivial_no_ident(min, self, dim, keepdim, "min")) {
-    TORCH_CHECK(!self.is_complex(), "min does not support complex inputs.");
-    AT_ASSERT(min.dim() == 0);
-    min_indices.resize_({}).fill_(0);
-    return std::forward_as_tuple(min, min_indices);
-  } else {
-    min_stub(self.device().type(), min, min_indices, self, dim, keepdim);
-    return std::tuple<Tensor&, Tensor&>{min, min_indices};
-  }
-}
-
-std::tuple<Tensor&, Tensor&> min_out(
-    const Tensor& self,
-    int64_t dim,
-    bool keepdim,
-    Tensor& min,
-    Tensor& min_indices) {
-  auto result = [&]() {
-    NoNamesGuard guard;
-    return min_out_impl(min, min_indices, self, dim, keepdim);
-  }();
-  namedinference::propagate_names_for_reduction(min, self, dim, keepdim);
-  namedinference::propagate_names_for_reduction(min_indices, self, dim, keepdim);
-  return result;
-}
-
 Tensor& clamp_out(const Tensor& self, const c10::optional<Scalar>& min, const c10::optional<Scalar>& max, Tensor& result) {
   if (min && max) {
     auto iter = TensorIterator::unary_op(result, self);
diff --git a/aten/src/ATen/native/TensorCompare.h b/aten/src/ATen/native/TensorCompare.h
index 24b6e9cebbb35e5..3014cade5f8e944 100644
--- a/aten/src/ATen/native/TensorCompare.h
+++ b/aten/src/ATen/native/TensorCompare.h
@@ -9,11 +9,11 @@ namespace at { namespace native {
 using reduce_minmax_fn =
     void (*)(Tensor&, Tensor&, const Tensor&, int64_t, bool);
-using reduce_max_fn =
+using structured_reduce_minmax_fn =
     void (*)(const Tensor&, const Tensor&, const Tensor&, int64_t, bool);
 
-DECLARE_DISPATCH(reduce_max_fn, max_stub);
-DECLARE_DISPATCH(reduce_minmax_fn, min_stub);
+DECLARE_DISPATCH(structured_reduce_minmax_fn, max_stub);
+DECLARE_DISPATCH(structured_reduce_minmax_fn, min_stub);
 DECLARE_DISPATCH(reduce_minmax_fn, _aminmax_stub);
 
 using where_fn = void (*)(TensorIterator &, ScalarType);
diff --git a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
index 1c67d7c8f9d72ef..64409d233ac404b 100644
--- a/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
+++ b/aten/src/ATen/native/cpu/TensorCompareKernel.cpp
@@ -88,17 +88,14 @@ static inline void compare_base_kernel(const Tensor& result1, const Tensor& resu
 }
 
 static void min_kernel_impl(
-    Tensor& result,
-    Tensor& indice,
+    const Tensor& result,
+    const Tensor& indice,
     const Tensor& self,
     int64_t dim,
     bool keepdim) {
   auto wrap_dim = maybe_wrap_dim(dim, self.dim());
   int64_t self_dim_size = ensure_nonempty_size(self, wrap_dim);
 
-  TORCH_CHECK(result.scalar_type() == self.scalar_type() && indice.scalar_type() == kLong,
-    "Expect dtype ", self.scalar_type(), "and torch.long, but got ", result.scalar_type(), "and", indice.scalar_type());
-
   AT_DISPATCH_ALL_TYPES_AND3(ScalarType::Half, ScalarType::BFloat16, ScalarType::Bool, self.scalar_type(), "min_cpu", [&] {
     compare_base_kernel<scalar_t>(result, indice, self, wrap_dim, keepdim, [&] (
       scalar_t* result_data, int64_t* indice_data,
diff --git a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu
index 1f40755ef79845e..7ba7f7272e21404 100644
--- a/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu
+++ b/aten/src/ATen/native/cuda/ReduceMinMaxKernel.cu
@@ -97,8 +97,8 @@ void argmin_kernel_cuda(TensorIterator& iter) {
   }
 }
 
-static void min_kernel_impl(Tensor& result, Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
-  at::TensorIterator iter = make_reduction("min", result, indice, self, dim, keepdim, self.scalar_type(), kLong);
+static void min_kernel_impl(const Tensor& result, const Tensor& indice, const Tensor& self, int64_t dim, bool keepdim) {
+  auto iter = meta::make_reduction(self, result, indice, dim, keepdim, self.scalar_type(), kLong);
   AT_DISPATCH_ALL_TYPES_AND3(kBFloat16, kHalf, kBool, iter.dtype(2), "min_cuda", [&]() {
     gpu_reduce_kernel<scalar_t, scalar_t>(
       iter,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 0992153cf3ca7cf..8d33454e9e519b9 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -2855,12 +2855,14 @@
 - func: nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)
 
 - func: min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)
+  structured_delegate: min.dim_min
   device_check: NoCheck   # TensorIterator
   variants: function, method
   dispatch:
-    CPU, CUDA, QuantizedCPU, QuantizedCUDA: min
+    QuantizedCPU, QuantizedCUDA: qmin
 
 - func: min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)
+  structured: True
   device_check: NoCheck   # TensorIterator
   dispatch:
     CPU, CUDA: min_out
diff --git a/test/test_autograd.py b/test/test_autograd.py
index cb24a136e064f89..b3d17d96c1763c0 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7595,6 +7595,7 @@ def fn(v):
             nnz = 0 if empty_nnz else 5
             _test(sparse_size + dense_size, len(sparse_size), nnz, device)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_backward(self, device, dtype):
         class FixedGradientFunction(Function):
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 291b3a0eeede5e2..99fee469d95caef 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -11,7 +11,7 @@
 from numbers import Number
 from typing import Dict, Any
 from torch.testing._internal.common_device_type import \
-    (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, onlyCPU, onlyCUDA, deviceCountAtLeast)
+    (instantiate_device_type_tests, ops, dtypes, dtypesIfCPU, onlyCPU, onlyCUDA, deviceCountAtLeast, skipMeta)
 from torch.testing._internal.common_methods_invocations import \
     (sparse_unary_ufuncs)
@@ -133,6 +133,7 @@ def _test_print(self, device, dtype, coalesced):
             printed.append('')
         self.assertExpected('\n'.join(printed))
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_basic(self, device, dtype, coalesced):
@@ -284,6 +285,7 @@ def test_ctor_size_checks(self, device, dtype):
            RuntimeError,
            lambda: self.sparse_tensor(indices, values, torch.Size([2, 4, 2, 1])))
 
+    @skipMeta
     @dtypes(*torch.testing.floating_and_complex_types_and(torch.float16))
     def test_to_dense(self, device, dtype):
         def test_tensor(x, res):
@@ -412,6 +414,7 @@ def test_scalar(self, device, dtype):
         self.assertEqual(torch.tensor(0, dtype=dtype, device=device), a.to_dense())
         self.assertEqual(a, a.to_dense().to_sparse())
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_shared(self, device, dtype):
         i = self.index_tensor([[2]], device=device)
@@ -428,6 +431,7 @@ def test_shared(self, device, dtype):
         i[0][0] = 0
         self.assertEqual(torch.empty((3, 0), dtype=dtype, device=device), self.safeToDense(x))
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_to_dense_hybrid(self, device, dtype):
         def test_tensor(x, res):
@@ -473,6 +477,7 @@ def fn(x):
         res = torch.empty((3, 4, 2, 0), dtype=dtype, device=device)
         test_tensor(x, res)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_contig(self, device, dtype):
         def test_tensor(x, exp_i, exp_v):
@@ -554,6 +559,7 @@ def test_tensor(x, exp_i, exp_v):
         exp_v = torch.empty([2, 0], dtype=dtype, device=device)
         test_tensor(x, exp_i, exp_v)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_contig_hybrid(self, device, dtype):
         def test_tensor(x, exp_i, exp_v):
@@ -641,6 +647,7 @@ def test_tensor(x, exp_i, exp_v):
         exp_v = torch.empty([2, 3, 0], dtype=dtype, device=device)
         test_tensor(x, exp_i, exp_v)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_clone(self, device, dtype, coalesced):
@@ -659,6 +666,7 @@ def test_shape(sparse_dims, nnz, with_size):
         test_shape(3, 10, [100, 100, 100, 5, 5, 5, 0])
         test_shape(3, 0, [0, 0, 100, 5, 5, 5, 0])
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
@@ -705,6 +713,7 @@ def test_Sparse_to_Sparse_copy_(self, device, dtype, coalesced):
         self.assertEqual(expected_grad.to_dense(), x2.grad.to_dense())
         self.assertEqual(None, x1.grad)
 
+    @skipMeta
     @coalescedonoff
     @unittest.skipIf(torch.cuda.device_count() < 2, "no multi-GPU")
     @dtypes(torch.double, torch.cdouble)
@@ -759,6 +768,7 @@ def test_tensor(x):
         x = torch.sparse.FloatTensor(2, 3, 4, 0)
         test_tensor(x)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_transpose(self, device, dtype, coalesced):
@@ -832,6 +842,7 @@ def test_not_in_place(x):
         test_in_place(x)
         test_not_in_place(x)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_add_zeros(self, device, dtype, coalesced):
@@ -860,6 +871,7 @@ def test_add_sub_nnz(self, device, dtype):
         x.sub_(2 * x)
         self.assertLessEqual(x._nnz(), 10)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_cat(self, device, dtype, coalesced):
@@ -902,6 +914,7 @@ def test_shapes(shapes, dim, fail_message=None):
                                     "Concatenating sparse tensors, but a dense tensor was found at position 1."):
             torch.cat((sp, dn))
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_unsqueeze(self, device, dtype, coalesced):
@@ -934,6 +947,7 @@ def test_shape(sparse_dims, nnz, sizes, unsqueeze_dim, fail_message=None):
         test_shape(3, 10, [5, 7, 11, 13, 17], -7, "Dimension out of range")
         test_shape(3, 10, [5, 7, 11, 13, 17], 6, "Dimension out of range")
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_select(self, device, dtype, coalesced):
@@ -966,6 +980,7 @@ def test_shape(sparse_dims, nnz, sizes, select_dim, select_index, fail_message=N
             for i in range(sizes[d]):
                 test_shape(1, 10, sizes, d, i)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_index_select(self, device, dtype, coalesced):
@@ -1031,6 +1046,7 @@ def test_shape(di, dj, dk, nnz):
         TEST_CUDA and _get_torch_cuda_version() < (10, 1),
         "bmm sparse-dense requires CUDA 10.1 or greater"
     )
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_bmm(self, device, dtype, coalesced):
@@ -1233,6 +1249,7 @@ def test_shape(di, dj, dk, nnz):
         true_result = (bias.to_dense() + torch.matmul(weight.to_dense(), x)).to_sparse()
         self.assertEqual(self.safeToDense(res), self.safeToDense(true_result))
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_addmm(self, device, dtype, coalesced):
@@ -1265,6 +1282,7 @@ def fn(S, D1, D2, beta=beta, alpha=alpha):
         test_shape(7, 8, 9, 20, False, (1, 1))
         test_shape(7, 8, 9, 20, True, (1, 1))
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_sparse_mm(self, device, dtype, coalesced):
@@ -1286,6 +1304,7 @@ def fn(S, D):
         test_shape(7, 8, 9, 20, False)
         test_shape(7, 8, 9, 20, True)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_dsmm(self, device, dtype, coalesced):
@@ -1305,6 +1324,7 @@ def test_shape(di, dj, dk, nnz):
         test_shape(1000, 100, 0, 0)
         test_shape(1000, 100, 0, 20)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_hsmm(self, device, dtype, coalesced):
@@ -1324,6 +1344,7 @@ def test_shape(di, dj, dk, nnz):
         test_shape(1000, 100, 0, 0)
         test_shape(1000, 100, 0, 20)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_spadd(self, device, dtype, coalesced):
@@ -1414,6 +1435,7 @@ def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
         res_bf16 = res_bf16.float()  # to compare with reference
         self.assertEqual(res_fp32, res_bf16, atol=1e-2, rtol=0)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_norm(self, device, dtype, coalesced):
@@ -1591,6 +1613,7 @@ def _test_basic_ops_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, device,
         # coalesced.
         self.assertEqual(z._values(), y._values())
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_basic_ops(self, device, dtype, coalesced):
@@ -1622,6 +1645,7 @@ def _test_basic_ops_hybrid():
         _test_basic_ops()
         _test_basic_ops_hybrid()
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_add_dense_sparse_mismatch(self, device, dtype):
         def test_shape(dense_size, sparse_dims_shape, dense_dims_shape, sparse_size):
@@ -1659,6 +1683,7 @@ def _test_sparse_mask_shape(self, nnz_x1, nnz_x2, shape_i, shape_v, dtype, devic
         self.assertEqual(self.safeToDense(y1), expected)
         self.assertEqual(self.safeToDense(y2), expected)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_mask(self, device, dtype, coalesced):
@@ -1704,6 +1729,7 @@ def _test_sparse_mask_fixed():
         self._test_sparse_mask_shape(0, 0, [10, 10, 10], [], dtype, device, coalesced)
         self._test_sparse_mask_shape(0, 0, [10, 10, 0], [], dtype, device, coalesced)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_sparse_mask_hybrid(self, device, dtype, coalesced):
@@ -1756,6 +1782,7 @@ def _test_sparse_mask_hybrid_fixed():
         self._test_sparse_mask_shape(0, 0, [10, 10, 10], [2, 0], dtype, device, coalesced)
         self._test_sparse_mask_shape(0, 0, [10, 10, 0], [2, 0], dtype, device, coalesced)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_zeros(self, device, dtype, coalesced):
@@ -1781,6 +1808,7 @@ def test_shape(i_shapes, v_shapes, shape, nnzs):
         test_shape([0, 3, 4], [3, 4, 5, 6], [2, 3, 0], [0])
         test_shape([2, 3, 4], [0, 4, 5, 6], [2, 3, 0], [9, 12])
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_zeros_like(self, device, dtype, coalesced):
@@ -1861,6 +1889,7 @@ def _test_empty_like(self, sparse_tensor, dtype, device, coalesced):
             dense_tensor = sparse_tensor.to_dense()
             result = torch.empty_like(dense_tensor, layout=torch.sparse_coo)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_empty_like(self, device, dtype, coalesced):
@@ -1917,6 +1946,7 @@ def _all_narrow_combs(self, shape):
                 for length in range(dim_sz - start):
                     yield [dim, start, length]
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_narrow(self, device, dtype, coalesced):
@@ -1971,6 +2001,7 @@ def is_integral(dtype):
         with self.assertRaisesRegex(RuntimeError, "only Tensors of floating point dtype can require gradients"):
             sparse_tensor.requires_grad_()
 
+    @skipMeta
     @coalescedonoff
     @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False,
                                           include_bfloat16=False, include_complex=False))
@@ -2034,6 +2065,7 @@ def _test_neg_negative(self, sparse_tensor):
             op(sparse_tensor, out=sparse_tensor_out)
             self.assertEqual(expected_output, sparse_tensor_out.to_dense())
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_neg_negative(self, device, dtype, coalesced):
@@ -2119,6 +2151,7 @@ def is_integral(dtype):
         with self.assertRaisesRegex(RuntimeError, "asin: result type cannot be Integral"):
             op(sparse_tensor)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False,
                                           include_bfloat16=False, include_complex=False))
@@ -2164,6 +2197,7 @@ def test_asin_arcsin(self, device, dtype, coalesced):
         )
         self._test_asin_arcsin(input_uncoalesced, coalesced)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_mv(self, device, dtype, coalesced):
@@ -2191,6 +2225,7 @@ def test_shape(di, dj, dk, nnz):
             y, _, _ = self._gen_sparse(2, 20, [10, 100], dtype, device, coalesced)
             res = x.mv(y)
 
+    @skipMeta
     @dtypes(*torch.testing.floating_and_complex_types())
     def test_sparse_add_coalesce(self, device, dtype):
         i = self.index_tensor([[1, 2, 1]], device=device)
@@ -2277,6 +2312,7 @@ def test_new_device_multi_gpu(self):
         self._test_new_device((30, 20, 10), 1)
         self._test_new_device((30, 20, 10, 0), 1)
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double, torch.cdouble)
     def test_new(self, device, dtype, coalesced):
@@ -2334,6 +2370,7 @@ def test_factory(self, device, dtype):
                 self.assertEqual(device, sparse_tensor._values().device)
                 self.assertEqual(True, sparse_tensor.requires_grad)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_factory_size_check(self, device, dtype):
         indices = self.index_tensor([[1, 2],
@@ -2399,6 +2436,7 @@ def test_factory_empty_indices(self, device):
         expected_indices = torch.empty((4, 0), dtype=torch.long, device=device)
         self.assertEqual(tensor._indices(), expected_indices)
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_factory_nnz(self, device, dtype):
         indices = self.index_tensor([[0]], device=device)  # (sparse_dim, nnz): (1, 1)
@@ -2434,6 +2472,7 @@ def test_shape(i_shape, v_shape, size, expected_size):
         test_shape([3, 0], [0, 2, 4, 0], [0, 0, 0, 2, 4, 0], [0, 0, 0, 2, 4, 0])
         test_shape([3, 0], [0, 2, 4, 0], [1, 2, 3, 2, 4, 0], [1, 2, 3, 2, 4, 0])
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_factory_dense_dim(self, device, dtype):
         indices = self.index_tensor([[0]], device=device)
@@ -2668,6 +2707,7 @@ def _test_resize_shape(self, x_i, x_v, x_size, y_i, y_v, y_size, dtype, device):
         self.assertEqual(x.to_dense().view(-1)[0:x_v_numel].view(x_v),
                          x_dense.view(-1)[0:x_v_numel].view(x_v))
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_resize(self, device, dtype):
         # 1. Expand the size of some dense dimensions [Supported]
@@ -2733,6 +2773,7 @@ def test_resize(self, device, dtype):
                                      [1, 1], [1, 2, 0], [2, 2, 0],
                                      dtype=dtype, device=device)
 
+    @skipMeta
     def test_is_nonzero(self, device):
         self.assertTrue(torch.sparse_coo_tensor(([0],), 1., (1,), device=device).is_nonzero())
         self.assertFalse(torch.sparse_coo_tensor(([0],), 0., (1,), device=device).is_nonzero())
@@ -2788,6 +2829,7 @@ def do_test(t):
         do_test(self.sparse_empty([3, 0], device=device).data)
         do_test(self.sparse_empty([3, 0], device=device).detach())
 
+    @skipMeta
     @dtypes(torch.double, torch.cdouble)
     def test_change_tensor_metadata(self, device, dtype):
         i = self.index_tensor([[0], [1]], device=device)
@@ -2879,6 +2921,7 @@ def test_isnan(self, device):
         t_nan = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([False, True]), device=device)
         self.assertEqual(torch.isnan(t).int(), t_nan.int())
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.float32, torch.float64)
     def test_div_rounding_mode(self, device, dtype, coalesced):
@@ -2916,6 +2959,7 @@ def test_sparse_to_numpy(self, device):
         t = torch.sparse_coo_tensor(torch.tensor(([0, 0], [2, 0])), torch.tensor([1, 4]))
         self.assertRaises(TypeError, lambda: t.numpy())
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_softmax(self, device, dtype, coalesced):
@@ -3211,6 +3255,7 @@ def sparse_log(x):
 
     # TODO: Check after why ROCm's cusparseXcsrgemm2Nnz function doesn't return the same nnz value as CUDA
     @skipIfRocm
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     @dtypesIfCPU(torch.double, torch.cdouble)
@@ -3370,10 +3415,11 @@ def different_dtypes():
         test_sparse_matmul(2, 0, [0, 10], [10, 0])
         test_error_cases()
 
+    @skipMeta
     @coalescedonoff
     @dtypes(torch.double)
     def test_assign(self, device, dtype, coalesced):
-        def assign_to(a):
+        def assign_to():
             a, i_a, v_a = self._gen_sparse(2, 5, [2, 3], dtype, device, coalesced)
             a[0] = 100
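
For orientation only (this note and the snippet below are not part of the patch, and the example values are made up): the overloads being ported, `min.dim` and `min.dim_min`, are the ones reachable from Python as `torch.min(input, dim)`. The meta function above fixes the output shapes and declares the indices output as `kLong`; the shared impl then dispatches to `min_stub`/`max_stub`. A minimal usage sketch:

import torch

x = torch.tensor([[4.0, 1.0, 3.0],
                  [2.0, 5.0, 0.0]])

# min.dim: reduce over `dim`, returning a (values, indices) pair.
values, indices = torch.min(x, dim=1)
# values  -> tensor([1., 0.])
# indices -> tensor([1, 2])

# min.dim_min: the out= variant writes into preallocated tensors; the indices
# tensor must be int64, matching the kLong output declared in the meta function.
out_values = torch.empty(2)
out_indices = torch.empty(2, dtype=torch.long)
torch.min(x, dim=1, keepdim=False, out=(out_values, out_indices))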