CUDA BF16 sparse (#48807)
Summary:
Fixes #{issue number}

Pull Request resolved: #48807

Reviewed By: mruberry

Differential Revision: D25526752

Pulled By: ngimel

fbshipit-source-id: 9ff8e637486cfd67d46daf0c05142bbe611e08ec
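
The diff below removes the AT_SKIP_BFLOAT16_IF_NOT_ROCM guards from the sparse CUDA coalesce and add kernels and un-skips the corresponding test, so these ops now dispatch to BFloat16 on CUDA builds rather than only on ROCm. A minimal sketch of what this enables, assuming a CUDA build with a device that accepts bfloat16 (shapes and values below are illustrative, not taken from the commit):

import torch

# Illustrative sketch: BFloat16 sparse tensors on a CUDA device, exercising the
# coalesce and add paths touched by this commit.
i = torch.tensor([[0, 1, 1], [2, 0, 2]])
v = torch.tensor([3.0, 4.0, 5.0], dtype=torch.bfloat16)
s = torch.sparse_coo_tensor(i, v, (2, 3), device="cuda")

s = s.coalesce()                                              # coalesce_sparse_cuda
d = torch.zeros(2, 3, dtype=torch.bfloat16, device="cuda")
r_dense = torch.add(d, s)                                     # add_out_dense_sparse_cuda
r_sparse = torch.add(s, s)                                    # add_out_sparse_cuda
print(r_dense, r_sparse.to_dense())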
zasdfgbnm authored and facebook-github-bot committed Dec 14, 2020
1 parent 690eaf9 commit 87636c0
Showing 3 changed files with 27 additions and 38 deletions.
22 changes: 10 additions & 12 deletions aten/src/ATen/native/sparse/cuda/SparseCUDATensor.cu
@@ -96,18 +96,16 @@ SparseTensor coalesce_sparse_cuda(const SparseTensor& self) {
     dim3 block(C10_WARP_SIZE, SZ);
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16, values.scalar_type(), "coalesce_sparse_cuda", [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "coalesce_sparse_cuda", [&] {
-          using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
-          apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
-            uniqueOffsets.data_ptr<int64_t>(),
-            origIndices.data_ptr<int64_t>(),
-            values.data_ptr<scalar_t>(),
-            newValues.data_ptr<scalar_t>(),
-            nnz,
-            newNnz,
-            stride
-          );
-        });
+        using cuda_accscalar_t = acc_type<scalar_t, /* is_cuda */ true>;
+        apply::coalesceValuesKernel<scalar_t, cuda_accscalar_t><<<grid, block, 0, stream>>>(
+          uniqueOffsets.data_ptr<int64_t>(),
+          origIndices.data_ptr<int64_t>(),
+          values.data_ptr<scalar_t>(),
+          newValues.data_ptr<scalar_t>(),
+          nnz,
+          newNnz,
+          stride
+        );
       });
   }

40 changes: 16 additions & 24 deletions aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu
@@ -340,13 +340,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
 
       AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-            apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
-              <<<grid, block, 0, stream>>>(
-                TensorCAddOp<scalar_t>(value.to<scalar_t>()),
-                V_INFO(r), I_INFO(indices), V_INFO(values),
-                static_cast<uint64_t>(nnz));
-          });
+          apply::sparseElementwiseKernelScalar<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
+            <<<grid, block, 0, stream>>>(
+              TensorCAddOp<scalar_t>(value.to<scalar_t>()),
+              V_INFO(r), I_INFO(indices), V_INFO(values),
+              static_cast<uint64_t>(nnz));
         });
     } else {
       TORCH_CHECK(cuda::getApplyGrid(nnz * block.x, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions");
@@ -356,13 +354,11 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
 
       AT_DISPATCH_ALL_TYPES_AND2(
         at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-          AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-            apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
-              <<<grid, block, 0, stream>>>(
-                TensorCAddOp<scalar_t>(value.to<scalar_t>()),
-                V_INFO(r), I_INFO(indices), V_INFO(values),
-                static_cast<uint64_t>(nnz));
-          });
+          apply::sparseElementwiseKernel<TensorCAddOp<scalar_t>, uint64_t, scalar_t>
+            <<<grid, block, 0, stream>>>(
+              TensorCAddOp<scalar_t>(value.to<scalar_t>()),
+              V_INFO(r), I_INFO(indices), V_INFO(values),
+              static_cast<uint64_t>(nnz));
         });
     }
   } else {
@@ -373,11 +369,9 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT
     // NB: Purposely not inplace!
     AT_DISPATCH_ALL_TYPES_AND2(
       at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_dense_sparse_cuda", [&] {
-        AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_dense_sparse_cuda", [&] {
-          if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
-            values = values.mul(value);
-          }
-        });
+        if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+          values = values.mul(value);
+        }
       });
 
     int64_t view_rows = 1;
@@ -445,11 +439,9 @@ SparseTensor& add_out_sparse_cuda(SparseTensor& r_, const SparseTensor& t, const
 
   AT_DISPATCH_ALL_TYPES_AND2(
     at::ScalarType::Half, at::ScalarType::BFloat16, commonDtype, "add_out_sparse_cuda", [&] {
-      AT_SKIP_BFLOAT16_IF_NOT_ROCM(scalar_t, "add_out_sparse_cuda", [&] {
-        if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
-          s_values_ = s_values_.mul(value);
-        }
-      });
+      if (value.to<scalar_t>() != static_cast<scalar_t>(1)) {
+        s_values_ = s_values_.mul(value);
+      }
     });
   LongTensor r_indices_ = at::cat({t_indices_, s_indices_}, 1);
   Tensor r_values_ = at::cat({t_values_, s_values_}, 0);
3 changes: 1 addition & 2 deletions test/test_sparse.py
@@ -10,7 +10,7 @@
 import random
 import unittest
 from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
-    do_test_empty_full, load_tests, TEST_NUMPY, TEST_WITH_ROCM, IS_WINDOWS
+    do_test_empty_full, load_tests, TEST_NUMPY, IS_WINDOWS
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from numbers import Number
 from torch.autograd.gradcheck import gradcheck
@@ -1301,7 +1301,6 @@ def test_spadd_hybrid(self):
         self._test_spadd_shape(10, [50, 30, 20], [2, 0])
 
     @cuda_only
-    @unittest.skipIf(not TEST_WITH_ROCM, "runs only on ROCm")
     def test_sparse_add_out_bfloat16(self):
         # fp32
         x, _, _ = self._gen_sparse(3, 5, 10)
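For reference, a hypothetical standalone check in the same spirit as test_sparse_add_out_bfloat16 (the test body is truncated above, so the shapes, tolerances, and variable names here are assumptions, not the actual test): add two sparse tensors in bfloat16 on CUDA and compare against a float32 reference.

import torch

# Hypothetical bfloat16-vs-fp32 comparison; the real test builds its inputs
# with self._gen_sparse(3, 5, 10) inside the test class.
i = torch.tensor([[0, 2, 4], [1, 3, 0], [2, 0, 1]])
v = torch.randn(3)
x = torch.sparse_coo_tensor(i, v, (5, 10, 10), device="cuda")
y = torch.sparse_coo_tensor(i, v + 1, (5, 10, 10), device="cuda")

res_fp32 = torch.add(x, y).to_dense()
res_bf16 = torch.add(x.to(torch.bfloat16), y.to(torch.bfloat16)).to_dense().float()
assert torch.allclose(res_fp32, res_bf16, atol=1e-2, rtol=1e-2)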
