9 changes: 9 additions & 0 deletions test/test_sparse.py
@@ -3621,6 +3621,13 @@ def check_empty(sparse_shape, nnz, dense_shape, coalesce):
                x = self._gen_sparse(sparse_dim, nnz_val, empty_sparse_shape, dtype, device, coalesce)[0]
                check(self, x, x)

        def check_autograd(x, y):
            if dtype in {torch.double, torch.cdouble}:
                xa = x.detach().clone().requires_grad_(True)
                ya = y.detach().clone().requires_grad_(True)
                gradcheck(lambda a, b: (a * b).to_dense(), (xa, ya), check_sparse_nnz=True)
                gradcheck(lambda a, b: (a * b).to_dense(), (ya, xa), check_sparse_nnz=True)

        for dim in range(len(shape) + 1):
            sub_shape = shape[dim:]
            sparse_dim = len(sub_shape) // 2
@@ -3630,12 +3637,14 @@ def check_empty(sparse_shape, nnz, dense_shape, coalesce):
            x = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
            y = self._gen_sparse(sparse_dim, nnz, sub_shape, dtype, device, coalesced)[0]
            check(self, x, y)
            check_autograd(x, y)

            # check broadcasting in dense dims
            for d in range(sparse_dim, len(sub_shape)):
                new_shape = sub_shape[:d] + (1,) + sub_shape[d + 1:]
                y = self._gen_sparse(sparse_dim, nnz, new_shape, dtype, device, coalesced)[0]
                check(self, x, y)
                check_autograd(x, y)

    @coalescedonoff
    @dtypes(*all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16))
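For context on what the new check_autograd covers, here is a hedged, standalone sketch of the same gradcheck pattern applied to a broadcast-in-dense-dims case; the shapes, the use of to_sparse(1), and the sparse_mul_to_dense helper are illustrative assumptions, while the test itself builds inputs with self._gen_sparse.

import torch
from torch.autograd import gradcheck

def sparse_mul_to_dense(a, b):
    # Densify so gradcheck can compare against a dense numerical Jacobian.
    return (a * b).to_dense()

# One sparse dim, two dense dims; y broadcasts in its last dense dim,
# which is exactly the case the dense-dim reduction in this PR targets.
x = torch.randn(3, 4, 5, dtype=torch.double).to_sparse(1).requires_grad_(True)
y = torch.randn(3, 4, 1, dtype=torch.double).to_sparse(1).requires_grad_(True)

gradcheck(sparse_mul_to_dense, (x, y), check_sparse_nnz=True)
gradcheck(sparse_mul_to_dense, (y, x), check_sparse_nnz=True)

The second gradcheck call swaps the argument order so that both the self and other branches of mul_tensor_backward are exercised, mirroring the two calls inside check_autograd.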
6 changes: 3 additions & 3 deletions tools/autograd/derivatives.yaml
@@ -1124,12 +1124,12 @@
  values: gather_with_keepdimed_indices(self_t, dim, indices, keepdim)

- name: mul.Tensor(Tensor self, Tensor other) -> Tensor
  self: mul_tensor_backward(grad, other, self.scalar_type())
  other: mul_tensor_backward(grad, self, other.scalar_type())
  self: mul_tensor_backward(grad, other, self, self.scalar_type())
  other: mul_tensor_backward(grad, self, other, other.scalar_type())

Contributor (review comment): We should be careful not to save an extra tensor for backward!

  result: other_t * self_p + self_t * other_p

- name: mul.Scalar(Tensor self, Scalar other) -> Tensor
  self: mul_tensor_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self.scalar_type())
  self: mul_tensor_backward(grad, at::lift_fresh(at::scalar_to_tensor(other)), self, self.scalar_type())
  result: self_t * other

- name: mv(Tensor self, Tensor vec) -> Tensor
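As a reading aid (not part of the diff): the two backward formulas that mul_tensor_backward implements for z = self * other are grad_self = grad * other.conj() and grad_other = grad * self.conj(), each cast back to the corresponding input's scalar type by handle_r_to_c, while the result line is the forward-mode product rule other_t * self_p + self_t * other_p. A minimal sketch with real-valued tensors (illustrative names, not taken from the PR):

import torch

a = torch.randn(3, dtype=torch.double, requires_grad=True)
b = torch.randn(3, dtype=torch.double, requires_grad=True)
g = torch.randn(3, dtype=torch.double)  # incoming gradient

(a * b).backward(g)
# vjp of mul; conj() is a no-op for real tensors, but for complex inputs it is
# what makes these match grad * other.conj() / grad * self.conj() above.
assert torch.allclose(a.grad, g * b.conj())
assert torch.allclose(b.grad, g * a.conj())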
59 changes: 56 additions & 3 deletions torch/csrc/autograd/FunctionsManual.cpp
@@ -505,9 +505,62 @@ Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
      : grad.masked_select(mask).sum();
}

Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
  auto out = grad * other.conj();
  return handle_r_to_c(self_st, out);
Tensor mul_tensor_backward(
    const Tensor& grad,
    const Tensor& other,
    const Tensor& self,
    ScalarType self_st) {
  // If a.is_sparse and b.is_sparse, then mul(a, b) expects:
  // 1. a.dim() == b.dim(),
  // 2. a.shape[:a.sparse_dim] == b.shape[:b.sparse_dim].
  // Note, however, that mul(a, b) does handle broadcasting in dense dims.
  // Autograd would then try running a.sum(dim=...) or b.sum(dim=...) to
  // reduce the sparse grads back to the inputs' shapes, but sum(dim=...) is
  // not implemented for sparse tensors. So, instead, we explicitly reduce
  // over the grad's dense dims and create a new sparse gradient tensor that
  // matches the shape of the input.
  const auto handle_sparse_sparse_case = [](Tensor& self_grad,
                                            const Tensor& self) -> void {
    // No broadcasting in dense dims, no need to modify anything.
    if (self_grad.sizes() == self.sizes()) {
      return;
    }
    // Here self's dense dims broadcast over self_grad's dense dims.
    // This means we need to reduce over the broadcast dims in self_grad.values
    // such that self_grad.values.sum(...).shape == self.values.shape,
    // otherwise autograd would try running self_grad.sum(dim=...), and
    // sum(dim=...) is not implemented for sparse tensors.
    const auto self_sparse_dim = self.sparse_dim();
    const auto self_grad_indices = self_grad._indices();
    const auto self_grad_values = self_grad._values();
    auto values_reduction_dims = at::DimVector();
    // Find which dense dims broadcast.
    for (const auto d : c10::irange(self_sparse_dim, self.dim())) {
      // If d broadcasts...
      if (self.sizes()[d] == 1) {
        // ... then map d to a dim relative to values to reduce over.
        values_reduction_dims.push_back(d - self_sparse_dim + 1);
      }
    }
    // Produce a "reduced" grad now.
    self_grad = at::_sparse_coo_tensor_unsafe(
        self_grad_indices,
        // Need to specify dtype because sum can promote.
        self_grad_values.sum(
            values_reduction_dims,
            /*keepdim=*/true,
            self_grad_values.scalar_type()),
        self.sizes());
Comment on lines +547 to +554

Collaborator: Could this be related to the failing autograd tests in the CI? If so, we could look into ways to change _values() in-place (via self_grad_values.set_(self_grad_values.sum(...))) as well as updating the dense part of self_grad's shape in-place.

Collaborator Author (@nikitaved, Oct 4, 2022): The failing tests are because self should not be inlined into the backward function in combination with the copy_ used in the failing test.

Collaborator: mul_tensor_backward uses only self.sizes() and self.sparse_dim(). Does the above mean that self cannot be an argument to mul_tensor_backward, but self.options() (which includes layout and scalar type) and self's shape information can be?

Collaborator Author (@nikitaved, Oct 4, 2022): No, it cannot be used. Layout is not inlined yet, so I had to modify that part. Also, sizes are not implemented for nested tensors, hence the separate dispatch branches in derivatives.yaml. Unfortunately, those changes are not in this PR; they are in the mul(sparse, sparse) PR, where we decided to revert all grad changes :(

Collaborator: OK, so this PR is blocked at the moment, right?

Collaborator Author (@nikitaved, Oct 4, 2022): Yes, because it is probably best to first upgrade autograd to handle shape reductions and the propagation of grads whose layouts differ from the inputs'.
  };

  auto self_grad = grad * other.conj();
  // Handle the case where both inputs are sparse.
  if (other.is_sparse() && self.is_sparse()) {
    // NOTE: modifies self_grad in-place.
    handle_sparse_sparse_case(self_grad, self);
  }
  return handle_r_to_c(self_st, self_grad);
}

Tensor div_tensor_self_backward(
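To make handle_sparse_sparse_case easier to follow, here is a hedged Python sketch of the same reduction; the function name and the use of the public torch.sparse_coo_tensor constructor are illustrative (the C++ above uses at::_sparse_coo_tensor_unsafe):

import torch

def reduce_sparse_grad_to_input_shape(grad, inp):
    # grad and inp are sparse COO tensors with the same sparse_dim; inp may
    # have size-1 (broadcast) dense dims that grad does not.
    if grad.shape == inp.shape:
        return grad
    sparse_dim = inp.sparse_dim()
    # Dense dim d of the tensor maps to dim (d - sparse_dim + 1) of values(),
    # since dim 0 of values() enumerates the nnz entries.
    reduce_dims = [d - sparse_dim + 1
                   for d in range(sparse_dim, inp.dim())
                   if inp.shape[d] == 1]
    values = grad._values().sum(reduce_dims, keepdim=True, dtype=grad.dtype)
    return torch.sparse_coo_tensor(grad._indices(), values, inp.shape)

With keepdim=True, the reduced values keep a size-1 dim wherever the input broadcasts, so the rebuilt COO tensor has exactly inp.shape and autograd never has to call sum(dim=...) on a sparse gradient.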
6 changes: 5 additions & 1 deletion torch/csrc/autograd/FunctionsManual.h
@@ -125,7 +125,11 @@ at::Tensor pow_backward_exponent(
    const at::Tensor& exponent,
    at::Tensor result);
at::Tensor angle_backward(at::Tensor grad, const at::Tensor& self);
at::Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st);
at::Tensor mul_tensor_backward(
    const Tensor& grad,
    const Tensor& other,
    const Tensor& self,
    ScalarType self_st);
at::Tensor div_tensor_self_backward(
    Tensor grad,
    Tensor other,