58 changes: 44 additions & 14 deletions aten/src/ATen/native/TensorConversions.cpp
@@ -498,23 +498,52 @@ std::vector<Tensor> _to_cpu(TensorList tensors) {
return cpu_tensors;
}

Tensor to_dense_backward(const Tensor& grad, const Tensor& input_) {
Tensor to_dense_backward(const Tensor& grad, const Tensor& input_, c10::optional<bool> masked_grad_) {
/*
For historical reasons, to_dense backward implements masked
semantics for sparse tensors, that is, gradients with respect to
unspecified elements are ignored. The masked_grad kw argument of
to_dense is introduced to allow to_dense to be used in a
non-masked semantics context. However, for BC reasons, the default
value of the masked_grad kw argument is set to True for now.
Eventually, we should eliminate the masked_grad kw argument and
let to_dense backward behave according to non-masked
semantics. Masked semantics of tensors is implemented in the
framework of masked tensors.
*/
const auto input_layout = input_.layout();
const bool masked_grad = masked_grad_.value_or(true);
switch (input_layout) {
case kStrided:
return grad.to_dense();
// TODO: return grad as it is
return grad.to_dense(input_.scalar_type(), masked_grad_);
case kSparse:
// Autograd operates on the coalesced assumption, i.e. no duplicate values.
return grad.sparse_mask(input_.coalesce());
if (masked_grad) {
return grad.sparse_mask(input_.coalesce());
} else {
// TODO: return grad as it is
return grad.to_sparse(input_.sparse_dim());
}
case kSparseCsr:
case kSparseCsc:
// TODO: add efficient CSR/CSC support for sparse_mask
return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout);
if (masked_grad) {
return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout);
} else {
// TODO: return grad as it is
return grad.to_sparse(input_layout, /*blocksize=*/c10::nullopt, /*dense_dim=*/input_.dense_dim());
}
case kSparseBsr:
case kSparseBsc: {
// TODO: add efficient BSR/BSC support for sparse_mask
const auto blocksize = at::DimVector(input_.values().sizes().slice(1, 2));
return grad.sparse_mask(input_.to_sparse()).to_sparse(input_layout, blocksize);
const auto blocksize = at::sparse_csr::getBlockSize(input_);
if (masked_grad) {
return grad.sparse_mask(input_.to_sparse(input_.sparse_dim())).to_sparse(input_layout, blocksize);
} else {
// TODO: return grad as it is
return grad.to_sparse(input_layout, blocksize, input_.dense_dim());
}
}
case kMkldnn:
return grad.to_mkldnn(input_.scalar_type());
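A sketch of how the two gradient modes are exercised (not part of the diff; it mirrors the test changes further below, and the gradcheck flags follow the usage in this PR's tests, which may differ in other versions):

    import torch
    from torch.autograd import gradcheck

    t = torch.randn(3, 3, dtype=torch.float64).to_sparse().requires_grad_()

    # masked_grad=True (the BC default): backward applies
    # grad.sparse_mask(input), ignoring gradients w.r.t. unspecified
    # elements, so the check must run under masked semantics.
    gradcheck(lambda x: x.to_dense(masked_grad=True), (t,),
              check_sparse_nnz=True, masked=True)

    # masked_grad=False: backward converts grad back to the input's
    # sparse layout without masking, i.e. standard semantics.
    gradcheck(lambda x: x.to_dense(masked_grad=False), (t,), masked=False)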
@@ -529,18 +558,18 @@ Tensor to_mkldnn_backward(const Tensor& grad, const Tensor& input_) {
return grad.to_dense(input_.scalar_type());
}

Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype, c10::optional<bool> masked_grad) {
if (tensor.layout() == c10::kSparse) {
return tensor._to_dense(dtype);
return tensor._to_dense(dtype, masked_grad);
}
if (tensor.layout() == c10::kSparseCsr ||
tensor.layout() == c10::kSparseCsc ||
tensor.layout() == c10::kSparseBsr ||
tensor.layout() == c10::kSparseBsc) {
return tensor._to_dense(dtype);
return tensor._to_dense(dtype, masked_grad);
}
if (tensor.layout() == c10::kMkldnn) {
return tensor._to_dense(dtype);
return tensor._to_dense(dtype, masked_grad);
}
TORCH_CHECK(
tensor.layout() == c10::kStrided,
@@ -552,7 +581,7 @@ Tensor to_dense(const Tensor& tensor, c10::optional<c10::ScalarType> dtype) {
return tensor;
}

Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {
Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype, c10::optional<bool> masked) {
TORCH_CHECK(
!dtype.has_value(), "dtype argument is not supported by sparse_to_dense");
Tensor dst = at::zeros(self.sizes(), self.options().layout(kStrided));
@@ -561,7 +590,8 @@ Tensor sparse_to_dense(const Tensor& self, c10::optional<ScalarType> dtype) {

Tensor sparse_compressed_to_dense(
const Tensor& self,
c10::optional<ScalarType> dtype) {
c10::optional<ScalarType> dtype,
c10::optional<bool> masked_grad) {
TORCH_CHECK(
!dtype.has_value(),
"dtype argument is not supported by sparse_csr_to_dense");
@@ -1756,7 +1786,7 @@ Tensor sparse_compressed_to_sparse(const Tensor& self, c10::optional<c10::Layout
}
switch (layout_) {
case kStrided:
return sparse_compressed_to_dense(self);
return sparse_compressed_to_dense(self, /*dtype=*/c10::nullopt, /*masked_grad=*/c10::nullopt);
Comment on lines -1759 to +1789:

nikitaved (Collaborator) commented on Mar 15, 2023:

Applies to the whole PR: wouldn't it be clearer to pass an explicit true/false for masked_grad instead, while leaving nullopt to indicate that the user (not the dev) did not provide any value? Or is it done on purpose to ease the shift to the right semantics?

Author (Collaborator) replied:

No. I am using nullopt to ensure the default behavior. In addition, sparse_compressed_to_dense does not use the masked_grad value anyway.

Hmm, it is not clear to me how the specification of dtype was not required before...

Author (Collaborator) replied:

Actually, the right thing to do here would be to add a masked_grad kw argument to the to_sparse method as well, so that the correct value could be propagated to the to_dense call when the specified layout is torch.strided. However, changing the to_sparse API is off-topic here, and I suggest we leave the code as it is for masked_grad visibility.

case kSparse:
return sparse_compressed_to_sparse(self, 2);
case kSparseCsr:
Expand Down Expand Up @@ -1801,7 +1831,7 @@ Tensor sparse_coo_to_sparse(const Tensor& self, c10::optional<c10::Layout> layou
}
switch (layout_) {
case kStrided:
return self.to_dense();
return self.to_dense(c10::nullopt, c10::nullopt);
case kSparse:
return self;
case kSparseCsr:
5 changes: 3 additions & 2 deletions aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp
@@ -23,7 +23,7 @@ namespace at { namespace native {

#if AT_MKLDNN_ENABLED()

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
TORCH_CHECK(mkldnn_tensor.scalar_type() == ScalarType::Float ||
mkldnn_tensor.scalar_type() == ScalarType::BFloat16,
"mkldnn_to_dense expects float or bfloat16 tensor input");
@@ -269,7 +269,7 @@ TORCH_LIBRARY_IMPL(mkldnn, MkldnnCPU, m) {

#else

Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype) {
Tensor mkldnn_to_dense(const Tensor& mkldnn_tensor, c10::optional<ScalarType> dtype, c10::optional<bool> masked_grad) {
TORCH_CHECK(false, "MKL-DNN build is disabled");
}

@@ -342,4 +342,5 @@ TORCH_LIBRARY_IMPL(mkl, MkldnnCPU, m) {
}

#endif // AT_MKL_ENABLED && AT_MKLDNN_ENABLED

}}
6 changes: 3 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -6702,19 +6702,19 @@
- func: _to_cpu(Tensor[] tensors) -> Tensor[]
variants: function

- func: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
- func: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
variants: method

# Special case of to_dense with custom derivative
- func: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
- func: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_to_dense
SparseCsrCPU, SparseCsrCUDA: sparse_compressed_to_dense
MkldnnCPU: mkldnn_to_dense
autogen: _to_dense.out

- func: to_dense_backward(Tensor grad, Tensor input) -> Tensor
- func: to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor

- func: sparse_dim(Tensor self) -> int
variants: method
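Note the "*" in the public schema: masked_grad is keyword-only on to_dense, while the private _to_dense takes it positionally. A usage sketch (assuming the Python bindings follow the schema, as usual):

    s = torch.eye(3).to_sparse()
    s.to_dense(masked_grad=False)  # OK: keyword-only on the public method
    s._to_dense(None, False)       # OK: positional on the private op
    # s.to_dense(None, False)      # TypeError: masked_grad must be passed by keyword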
3 changes: 2 additions & 1 deletion aten/src/ATen/native/sparse/SparseTensor.cpp
@@ -524,7 +524,8 @@ const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTens
SparseTensor dense_to_sparse(const Tensor& self, c10::optional<c10::Layout> layout, OptionalIntArrayRef blocksize, c10::optional<int64_t> dense_dim_opt) {
if (layout.has_value()) {
if (blocksize.has_value() && !(*layout == kSparseBsr || *layout == kSparseBsc)) {
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout, " conversion does not use specified blocksize");
AT_ERROR("to_sparse for ", self.layout(), " to ", *layout,
" conversion does not use the specified blocksize ", blocksize.value(), ".");
}
if (self.layout() == *layout) {
return self;
40 changes: 8 additions & 32 deletions test/test_sparse.py
@@ -10,7 +10,7 @@
from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm, do_test_dtypes, \
load_tests, TEST_NUMPY, TEST_SCIPY, IS_WINDOWS, gradcheck, coalescedonoff, \
DeterministicGuard, first_sample, TEST_WITH_CROSSREF, TEST_WITH_ROCM, skipIfTorchDynamo, \
parametrize, subtest, is_coalesced_indices, suppress_warnings, is_slow_gradcheck_env
parametrize, subtest, is_coalesced_indices, suppress_warnings
from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
from numbers import Number
from typing import Dict, Any
@@ -413,8 +413,6 @@ def test_ctor_size_checks(self, device, dtype):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics()
def test_to_dense_with_gradcheck(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests triple to_dense for memory corruption
@@ -432,7 +430,7 @@ def test_tensor(x, res):
return

def fn(x):
return x.to_dense()
return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True)
gradcheck(fn, (x,), check_sparse_nnz=True)

@@ -550,8 +548,6 @@ def test_shared(self, device, dtype):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics()
def test_to_dense_hybrid(self, device, dtype, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')

def test_tensor(x, res):
x.to_dense() # Tests double to_dense for memory corruption
@@ -561,7 +557,7 @@ def test_tensor(x, res):
self.assertEqual(res, self.safeToDense(x))

def fn(x):
return x.to_dense()
return x.to_dense(masked_grad=gradcheck.masked)
x.requires_grad_(True)
gradcheck(fn, (x,), check_sparse_nnz=True)

@@ -908,8 +904,6 @@ def test_shape(sparse_dims, nnz, with_size):
@unittest.skipIf(TEST_WITH_CROSSREF, "generator unsupport triggers assertion error")
@gradcheck_semantics()
def test_permute(self, device, dtype, coalesced, gradcheck):
if not gradcheck.masked and is_slow_gradcheck_env():
self.skipTest('FIXME: to_dense_backward supports masked semantics only')
# trivial checks
s = torch.rand(3, 3, 3, device=device, dtype=dtype).to_sparse()
with self.assertRaisesRegex(RuntimeError, "does not match the length"):
@@ -941,7 +935,7 @@ def test_shape(sparse_dims, nnz, with_size):
else:
self.assertFalse(s_permuted.is_coalesced())

gradcheck(lambda t: t.permute(dims).to_dense(), s.requires_grad_(True), check_sparse_nnz=True)
gradcheck(lambda t: t.permute(dims).to_dense(masked_grad=gradcheck.masked), s.requires_grad_())
else:
# otherwise check if exception is thrown
fail_message = "transpositions between sparse and dense dimensions are not allowed"
@@ -1778,10 +1772,7 @@ def run_tests(S, td=None):
self.assertEqual(S_sum.item(), D_sum.item())

def fn(S):
res = torch.sparse.sum(S)
if res.is_sparse:
res = res.to_dense()
return res
return torch.sparse.sum(S)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)
else:
S_sum = torch.sparse.sum(S, td)
@@ -1790,9 +1781,7 @@ def fn(S):

def fn(S):
res = torch.sparse.sum(S, td)
if res.is_sparse:
res = res.to_dense()
return res
return res.to_dense(masked_grad=True)
gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)

nnz = 10
@@ -4012,7 +4001,7 @@ def test_cuda_sparse_cpu_dense_add(self):

def _sparse_to_dense(tensor):
if tensor.dtype != torch.bool:
return tensor.to_dense()
return tensor.to_dense(masked_grad=True)

# to_dense uses coalesce which isn't implemented for bool
return tensor.to(torch.int8).to_dense().to(torch.bool)
@@ -4423,22 +4412,9 @@ def test_gradcheck_to_dense(self, from_layout, device, dtype, index_dtype, gradc
# TODO: implement batch support in _convert_indices_from_csr_to_coo
continue
t = t.clone().detach().requires_grad_(True)
if is_slow_gradcheck_env() and not gradcheck.masked:
# TODO: remove this if-block when TODO items below are resolved
try:
gradcheck(torch.Tensor.to_dense, t)
except RuntimeError as msg:
# TODO: implement non-masked semantics support in to_dense_backward
with self.assertRaisesRegex(RuntimeError, "Jacobian mismatch"):
gradcheck(torch.Tensor.to_dense, t)
self.skipTest('non-masked semantics not supported')
r = gradcheck(torch.Tensor.to_dense, t)
r = gradcheck(lambda x: torch.Tensor.to_dense(x, masked_grad=gradcheck.masked), t)
self.assertTrue(r)

# when the following assert fails, it means that the if-block
# above and the assertFalse test below can be safely removed
self.assertFalse(is_slow_gradcheck_env() and not gradcheck.masked)

@all_sparse_layouts('from_layout', include_strided=True)
@all_sparse_layouts('to_layout', include_strided=False)
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
6 changes: 3 additions & 3 deletions tools/autograd/derivatives.yaml
@@ -1676,10 +1676,10 @@

# DO NOT define a backward for to_dense
# See [Note: Sometimes view derivatives]
# - name: to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
# - name: to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor
#
- name: _to_dense(Tensor self, ScalarType? dtype=None) -> Tensor
self: to_dense_backward(grad, self)
- name: _to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor
self: to_dense_backward(grad, self, masked_grad)

- name: to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor
self: to_sparse_backward(grad, self.layout(), self.sym_blocksize())
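Conceptually, the generated autograd node for _to_dense now behaves like the following Python sketch (a simplification for illustration, not the actual generated C++; ToDenseSketch and its argument names are hypothetical):

    import torch

    class ToDenseSketch(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp, dtype=None, masked_grad=None):
            ctx.save_for_backward(inp)
            ctx.masked_grad = masked_grad
            return inp._to_dense(dtype, masked_grad)

        @staticmethod
        def backward(ctx, grad):
            inp, = ctx.saved_tensors
            # masked_grad is threaded through to to_dense_backward,
            # matching the derivatives.yaml entry above.
            grad_inp = torch.ops.aten.to_dense_backward(grad, inp, ctx.masked_grad)
            return grad_inp, None, None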
8 changes: 7 additions & 1 deletion torch/_tensor_docs.py
@@ -5360,10 +5360,16 @@ def callable(a, b) -> number
add_docstr_all(
"to_dense",
r"""
to_dense() -> Tensor
to_dense(dtype=None, *, masked_grad=True) -> Tensor

Creates a strided copy of :attr:`self` if :attr:`self` is not a strided tensor, otherwise returns :attr:`self`.

Keyword args:
{dtype}
masked_grad (bool, optional): If set to ``True`` (default) and
:attr:`self` has a sparse layout then the backward of
:meth:`to_dense` returns ``grad.sparse_mask(self)``.

Example::

>>> s = torch.sparse_coo_tensor(
4 changes: 2 additions & 2 deletions torch/autograd/gradcheck.py
@@ -770,10 +770,10 @@ def _check_outputs(outputs) -> None:
# it is easier to call to_dense() on the sparse output than
# to modify analytical jacobian
raise ValueError('Sparse output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')
if any(t.layout == torch._mkldnn for t in outputs if isinstance(t, torch.Tensor)): # type: ignore[attr-defined]
raise ValueError('MKLDNN output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
'Please call to_dense(masked_grad=...) on the output of fn for gradcheck.')


def _check_no_differentiable_outputs(func, inputs, func_out, eps, *, is_forward_ad) -> bool:
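Per the updated messages, a function with a sparse output is densified before gradcheck with an explicit choice of semantics. A minimal sketch mirroring the test changes above (fn and S are illustrative):

    import torch
    from torch.autograd import gradcheck

    S = torch.randn(3, 3, dtype=torch.float64).to_sparse().requires_grad_()

    def fn(S):
        # torch.sparse.sum over a dim returns a sparse tensor; densify
        # it explicitly so gradcheck can build the Jacobian.
        return torch.sparse.sum(S, 0).to_dense(masked_grad=True)

    gradcheck(fn, (S,), check_sparse_nnz=True, masked=True)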
10 changes: 5 additions & 5 deletions torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -737,11 +737,11 @@ void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
if (!v->type()->cast<TensorType>()) {
continue;
}
auto from_mkldnn =
graph
->create(
c10::Symbol::fromQualString("aten::to_dense"), {v, none_value})
->insertAfter(subgraph_node);
auto from_mkldnn = graph
->create(
c10::Symbol::fromQualString("aten::to_dense"),
{v, none_value, none_value})
->insertAfter(subgraph_node);
Comment on lines +740 to +744:

Author (Collaborator) commented:

This changeset fixes the following failure:

2023-03-06T20:19:12.8348868Z RuntimeError: 0 INTERNAL ASSERT FAILED at "/var/lib/jenkins/workspace/torch/csrc/jit/ir/alias_analysis.cpp":621, please report a bug to PyTorch. We don't have an op for aten::to_dense but it isn't a special case.  Argument types: Tensor, NoneType,

2023-03-06T20:19:12.8348950Z Candidates:
	aten::to_dense(Tensor self, ScalarType? dtype=None, bool? masked=False) -> Tensor
v->replaceAllUsesAfterNodeWith(from_mkldnn, from_mkldnn->output());
}

4 changes: 2 additions & 2 deletions torch/overrides.py
@@ -1335,8 +1335,8 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
Tensor.sum_to_size: lambda self, size: -1,
Tensor.tile: lambda self, *reps: -1,
Tensor.to: lambda self, dtype, non_blocking=False, copy=False, memory_format=torch.preserve_format: -1,
Tensor.to_dense: lambda self, dtype=None: -1,
Tensor._to_dense: lambda self, dtype=None: -1,
Tensor.to_dense: lambda self, dtype=None, *, masked_grad=None: -1,
Tensor._to_dense: lambda self, dtype=None, masked_grad=None: -1,
Tensor.to_sparse: lambda self: -1,
Tensor.tolist: lambda self: -1,
Tensor.to_mkldnn: lambda self: -1,