mul(sparse, sparse): extend autograd support to cases with broadcasted dense dims. #84929
Changes from all commits
@@ -505,9 +505,62 @@ Tensor masked_fill_backward(const Tensor& grad, const Tensor& mask) {
       : grad.masked_select(mask).sum();
 }
 
-Tensor mul_tensor_backward(Tensor grad, Tensor other, ScalarType self_st) {
-  auto out = grad * other.conj();
-  return handle_r_to_c(self_st, out);
+Tensor mul_tensor_backward(
+    const Tensor& grad,
+    const Tensor& other,
+    const Tensor& self,
+    ScalarType self_st) {
+  // If a.is_sparse and b.is_sparse, then mul(a, b) expects:
+  // 1. a.dim() == b.dim(),
+  // 2. a.shape[:a.sparse_dim] == b.shape[:b.sparse_dim].
+  // Note, however, that mul(a, b) will handle broadcasting in dense dims.
+  // Autograd, however, will have issues with that and will try running
+  // either a.sum(d=...) or b.sum(d=...) to propagate sparse grads of the
+  // same shape as the inputs'. However, sum(d=...) is not implemented
+  // for sparse tensors. So, instead, we explicitly reduce over grads'
+  // dense dims and create new sparse gradient tensors that now match
+  // the shape of the inputs.
+  const auto handle_sparse_sparse_case = [](Tensor& self_grad,
+                                            const Tensor& self) -> void {
+    // No broadcasting in dense dims, no need to modify anything.
+    if (self_grad.sizes() == self.sizes()) {
+      return;
+    }
+    // Here self dense dims broadcast over self_grad dense dims.
+    // This means we need to reduce over broadcasted dims in self_grad.values
+    // such that self_grad.values.sum(...).shape == self.values.shape,
+    // otherwise autograd would try running self_grad.sum(dim=...), and
+    // sum(dim=...) is not implemented for sparse tensors.
+    const auto self_sparse_dim = self.sparse_dim();
+    const auto self_grad_indices = self_grad._indices();
+    const auto self_grad_values = self_grad._values();
+    auto values_reduction_dims = at::DimVector();
+    // Find which dense dims broadcast
+    for (const auto d : c10::irange(self_sparse_dim, self.dim())) {
+      // If d broadcasts...
+      if (self.sizes()[d] == 1) {
+        // ... then map d to a dim relative to values to reduce over.
+        values_reduction_dims.push_back(d - self_sparse_dim + 1);
+      }
+    }
+    // Produce a "reduced" grad now.
+    self_grad = at::_sparse_coo_tensor_unsafe(
+        self_grad_indices,
+        // Need to specify dtype because sum can promote.
+        self_grad_values.sum(
+            values_reduction_dims,
+            /*keepdim=*/true,
+            self_grad_values.scalar_type()),
+        self.sizes());
Comment on lines +547 to +554

Could this be related to the failing autograd tests in the CI? If so, we could look into ways to change …

The failing tests are because …

No, it cannot be used. Layout is not inlined yet, so I had to modify that part. Also, sizes are not implemented for nested tensors, hence the separate dispatch branches in derivatives.yaml. The changes are not in this PR, unfortunately; they are in the mul(sparse, sparse) PR, where we decided to revert all grad changes :(

OK, so this PR is blocked atm, right?

Yes, because it is probably best to upgrade autograd to handle shape reductions and the propagation of grads whose layouts differ from the inputs'.
+  };
+
+  auto self_grad = grad * other.conj();
+  // Handle sparse only case
+  if (other.is_sparse() && self.is_sparse()) {
+    // NOTE: Modifies self_grad
+    handle_sparse_sparse_case(self_grad, self);
+  }
+  return handle_r_to_c(self_st, self_grad);
+}
 
 Tensor div_tensor_self_backward(
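To make the intent of handle_sparse_sparse_case concrete, here is a minimal Python sketch of the same reduction, written against the public sparse COO API (torch.sparse_coo_tensor rather than the internal at::_sparse_coo_tensor_unsafe). The helper name reduce_sparse_grad and the concrete shapes are invented for illustration; they are not part of this PR.

import torch

def reduce_sparse_grad(inp_grad, inp):
    # Mirrors handle_sparse_sparse_case: `inp_grad` has the same sparse dims
    # as `inp`, but its dense dims carry the broadcasted sizes.
    if inp_grad.shape == inp.shape:
        # No broadcasting in dense dims, nothing to do.
        return inp_grad
    sparse_dim = inp.sparse_dim()
    grad_indices = inp_grad._indices()
    grad_values = inp_grad._values()
    # Dense dims of `inp` with size 1 broadcast; map them to dims of the
    # values tensor (dim 0 of values is the nnz dim, hence the "+ 1").
    reduction_dims = [
        d - sparse_dim + 1
        for d in range(sparse_dim, inp.dim())
        if inp.shape[d] == 1
    ]
    # Pass dtype explicitly, mirroring the diff's note that sum can promote.
    reduced_values = grad_values.sum(
        reduction_dims, keepdim=True, dtype=grad_values.dtype)
    return torch.sparse_coo_tensor(grad_indices, reduced_values, inp.shape)

# `inp`: sparse dims (2, 3), dense dims (1, 5); its dense part broadcasts
# against dense dims (4, 5) during mul(sparse, sparse).
indices = torch.tensor([[0, 1], [2, 0]])
inp = torch.sparse_coo_tensor(indices, torch.randn(2, 1, 5), (2, 3, 1, 5))

# A gradient carrying the broadcasted dense shape at the same indices.
inp_grad = torch.sparse_coo_tensor(indices, torch.randn(2, 4, 5), (2, 3, 4, 5))

reduced = reduce_sparse_grad(inp_grad, inp)
print(reduced.shape)            # torch.Size([2, 3, 1, 5])
print(reduced._values().shape)  # torch.Size([2, 1, 5])

The keepdim=True is what makes this work: the reduced values keep a size-1 dim wherever inp broadcasts, so the rebuilt COO tensor has exactly inp's shape and can be handed back to autograd as its gradient.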
We should be careful not to save an extra tensor for backward!
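The concern here is that the new signature also receives self, so autograd would have to keep an extra tensor alive for the backward pass even in cases that do not need it. As a rough, hypothetical illustration of that trade-off (not how torch.mul's derivative is actually declared in derivatives.yaml), a custom autograd Function can make the saving conditional:

import torch

class MulExample(torch.autograd.Function):
    # Toy illustration only: save a tensor for backward just when the
    # backward formula will actually use it.

    @staticmethod
    def forward(ctx, a, b):
        ctx.b_needs_grad = b.requires_grad
        if ctx.b_needs_grad:
            ctx.save_for_backward(a, b)   # `a` is only needed for grad_b
        else:
            ctx.save_for_backward(b)      # avoid keeping `a` alive
        return a * b

    @staticmethod
    def backward(ctx, grad_out):
        if ctx.b_needs_grad:
            a, b = ctx.saved_tensors
            return grad_out * b, grad_out * a
        (b,) = ctx.saved_tensors
        return grad_out * b, None

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)  # requires_grad=False, so `a` is never saved
MulExample.apply(a, b).sum().backward()
print(a.grad)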