Enable cadd_sparse for BFloat16 on CPU
itaraban committed Mar 30, 2023
1 parent 4e1060c commit 22ba0c9
Showing 2 changed files with 2 additions and 5 deletions.
2 changes: 1 addition & 1 deletion aten/src/ATen/native/sparse/SparseTensorMath.cpp
@@ -482,7 +482,7 @@ SparseTensor& add_out_sparse_contiguous(SparseTensor& r, const SparseTensor& t,
     auto r_indices_accessor = r_indices.accessor<int64_t, 2>();
     auto src_indices_accessor = src_indices.accessor<int64_t, 2>();
 
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX(
+    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16,
         commonDtype, "cadd_sparse", [&] {
           scalar_t* t_values_ptr = t_values.data_ptr<scalar_t>();
           scalar_t* s_values_ptr = s_values.data_ptr<scalar_t>();
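
This macro swap is the whole kernel-side fix: AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(kBFloat16, ...) instantiates the "cadd_sparse" lambda for BFloat16 on top of the standard and complex types, so sparse addition on CPU no longer fails dtype dispatch for bfloat16 inputs. A minimal sketch of the user-visible effect (illustration only, not part of the commit):

    import torch

    # Coalesced sparse COO tensors in bfloat16 on CPU; both inputs must be
    # coalesced for the add_out_sparse_contiguous kernel patched above to run.
    i = torch.tensor([[0, 1, 1], [2, 0, 2]])
    v = torch.tensor([3.0, 4.0, 5.0], dtype=torch.bfloat16)
    x = torch.sparse_coo_tensor(i, v, (2, 3)).coalesce()
    y = torch.sparse_coo_tensor(i, 2 * v, (2, 3)).coalesce()

    # Before this commit, this raised a dispatch error along the lines of
    #   RuntimeError: "cadd_sparse" not implemented for 'BFloat16'
    res = torch.add(x, y)
    print(res.to_dense())
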
5 changes: 1 addition & 4 deletions test/test_sparse.py
@@ -1719,15 +1719,12 @@ def _test_spadd_hybrid():
         _test_spadd()
         _test_spadd_hybrid()
 
-    @onlyCUDA
     @coalescedonoff
-    @dtypes(torch.double, torch.cdouble)
+    @dtypes(torch.float)
     def test_sparse_add_out_bfloat16(self, device, dtype, coalesced):
         # fp32
         x, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
         y, _, _ = self._gen_sparse(3, 5, 10, dtype, device, coalesced)
-        x = x.float().cuda()
-        y = y.float().cuda()
         res_fp32 = torch.add(x, y)
 
         # bfloat16
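
The test is now runnable on any device instead of CUDA only: dropping @onlyCUDA and the .float().cuda() casts means the fp32 reference is computed directly on the test's device, which lets the instantiated CPU variant exercise the newly enabled kernel. The bfloat16 half of the test is truncated above; a hypothetical sketch of the usual pattern for such a check (assumed names, not the committed lines) redoes the addition in bfloat16 and compares against the fp32 reference with a loose tolerance, since bfloat16 keeps only an 8-bit significand:

    # Hypothetical continuation, for illustration only:
    x_bf16, y_bf16 = x.bfloat16(), y.bfloat16()
    res_bf16 = torch.add(x_bf16, y_bf16)
    # Compare in fp32; bfloat16's short mantissa needs a wide tolerance.
    self.assertEqual(res_fp32.to_dense(), res_bf16.float().to_dense(),
                     atol=1e-2, rtol=0)
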
