diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
index 5d25138500d7..660862181262 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
@@ -96,18 +96,16 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
   dim3 block(C10_WARP_SIZE, SZ);
   AT_DISPATCH_ALL_TYPES_AND2(
     at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "coalesce_sparse_cuda", [&] {
-        using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
-        apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
-          uniqueOffsets.data_ptr<int64_t>(),
-          origIndices.data_ptr<int64_t>(),
-          values.data_ptr<scalar_t>(),
-          newValues.data_ptr<scalar_t>(),
-          nnz,
-          newNnz,
-          stride
-        );
-      });
+      using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
+      apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
+        uniqueOffsets.data_ptr<int64_t>(),
+        origIndices.data_ptr<int64_t>(),
+        values.data_ptr<scalar_t>(),
+        newValues.data_ptr<scalar_t>(),
+        nnz,
+        newNnz,
+        stride
+      );
     });
   }
 
diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
index 81058ec266f2..d0aafe680efb 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
@@ -340,13 +340,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
 
       AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-            apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
-              <<<grid, block, 0, stream>>>(
-                TensorCAddOp<scalar_t>(value.to<scalar_t>()),
-                V_INFO(r), I_INFO(indices), V_INFO(values),
-                static_cast<uint64_t>(nnz));
-          });
+          apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
+            <<<grid, block, 0, stream>>>(
+              TensorCAddOp<scalar_t>(value.to<scalar_t>()),
+              V_INFO(r), I_INFO(indices), V_INFO(values),
+              static_cast<uint64_t>(nnz));
         });
     } else {
       TORCH_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
@@ -356,13 +354,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
 
       AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-            apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
-              <<<grid, block, 0, stream>>>(
-                TensorCAddOp<scalar_t>(value.to<scalar_t>()),
-                V_INFO(r), I_INFO(indices), V_INFO(values),
-                static_cast<uint64_t>(nnz));
-          });
+          apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
+            <<<grid, block, 0, stream>>>(
+              TensorCAddOp<scalar_t>(value.to<scalar_t>()),
+              V_INFO(r), I_INFO(indices), V_INFO(values),
+              static_cast<uint64_t>(nnz));
         });
     }
   } else {
@@ -373,11 +369,9 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
     // NB: Purposely not inplace!
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-          if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
-            values = values.mul(value);
-          }
-        });
+        if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+          values = values.mul(value);
+        }
       });
 
     int64_t view_rows = 1;
@@ -445,11 +439,9 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
 
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_sparse_cuda", [&] {
-          if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
-            s_values_ = s_values_.mul(value);
-          }
-        });
+        if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+          s_values_ = s_values_.mul(value);
+        }
       });
 
   LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
   Tensor r_values_ = at::cat({t_values_, s_values_}, 0);
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 72a67caa2038..5af630c0acb4 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -10,7 +10,7 @@
 import random
 import unittest
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_WITH_ROCM, IS_WINDOWS
+    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from torch.autograd.gradcheck import gradcheck
@@ -1301,7 +1301,6 @@ def test_spadd_hybrid(self):
         self._test_spadd_shape(10, [50, 30, 20], [2, 0])
 
     @cuda_only
-    @unittest.skipIf(not TEST_WITH_ROCM, "runs only on ROCm")
     def test_sparse_add_out_bfloat16(self):
         # fp32
         x, _, _ = self._gen_sparse(3, 5, 10)