torch.to_dense backward ignores unspecified elements in sparse inputs #95550
Personally, I'm in favor of more explicit solutions (such as …).
I am not sure we have to enforce any masked semantics, be it autograd (looking at …):
In [1]: import torch
In [2]: x_data = torch.rand(3, 3)
In [3]: y_data = torch.rand(3, 3)
In [4]: x1 = x_data.mul(x_data < 0.5).to_sparse().requires_grad_(True)
In [5]: x2 = x_data.mul(x_data < 0.5).to_sparse().requires_grad_(True)
In [6]: y1 = y_data.mul(y_data < 0.5).to_sparse().requires_grad_(True)
In [7]: y2 = y_data.mul(y_data < 0.5).to_sparse().requires_grad_(True)
In [8]: def custom_sparse_mm(x, y):
...: x = x.sparse_mask(x)
...: y = y.sparse_mask(y)
...: res = x @ y
...: res = res.sparse_mask(res)
...: return res
...:
In [9]: torch.autograd.grad(custom_sparse_mm(y1, x1), (y1, x1), torch.ones(3, 3).to_sparse())
<ipython-input-8-4d9400ee8449>:4: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at /home/nik/git/Quansight/pytorch/aten/src/ATen/SparseCsrTensorImpl.cpp:54.)
res = x @ y
Out[9]:
(tensor(indices=tensor([[0, 1, 2, 2, 2],
[1, 0, 0, 1, 2]]),
values=tensor([0.8071, 0.6476, 0.6476, 0.8071, 0.2095]),
size=(3, 3), nnz=5, layout=torch.sparse_coo),
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2],
[0, 1, 2, 0, 1, 2, 2]]),
values=tensor([0.1705, 0.1705, 0.1705, 0.6261, 0.6261, 0.6261, 0.3517]),
size=(3, 3), nnz=7, layout=torch.sparse_coo))
In [10]: torch.autograd.grad(torch.sparse.mm(y2, x2), (y2, x2), torch.ones(3, 3).to_sparse())
Out[10]:
(tensor(indices=tensor([[0, 1, 2, 2, 2],
[1, 0, 0, 1, 2]]),
values=tensor([0.8071, 0.6476, 0.6476, 0.8071, 0.2095]),
size=(3, 3), nnz=5, layout=torch.sparse_coo),
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2],
[0, 1, 2, 0, 1, 2, 2]]),
values=tensor([0.1705, 0.1705, 0.1705, 0.6261, 0.6261, 0.6261, 0.3517]),
size=(3, 3), nnz=7, layout=torch.sparse_coo))
In [11]: x1
Out[11]:
tensor(indices=tensor([[0, 0, 0, 1, 1, 1, 2],
[0, 1, 2, 0, 1, 2, 2]]),
values=tensor([0.0543, 0.1362, 0.4572, 0.0271, 0.4451, 0.3349, 0.2095]),
size=(3, 3), nnz=7, layout=torch.sparse_coo, requires_grad=True)
In [12]: y1
Out[12]:
tensor(indices=tensor([[0, 1, 2, 2, 2],
[1, 0, 0, 1, 2]]),
values=tensor([0.3481, 0.0818, 0.0887, 0.2780, 0.3517]),
size=(3, 3), nnz=5, layout=torch.sparse_coo, requires_grad=True)

In general, any composition of operations … This approach:
>>> a = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_()
>>> def f(a):
... a = a.sparse_mask(a)
... res = torch.Tensor.t(a)
... res = res.sparse_mask(res)
... return res.to_dense()
...
>>> torch.autograd.gradcheck(lambda a: f(a), a, check_sparse_nnz=True)
True

Alternatively, without explicit …:

>>> a = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).requires_grad_(True)
>>> mask = a.detach().to_sparse()
>>> def f(a, mask):
... res = a.sparse_mask(mask)
... res = torch.Tensor.t(res)
... res = res.sparse_mask(res)
... return res.to_dense()
...
>>> torch.autograd.gradcheck(lambda a: f(a, mask), (a,))
True
@nikitaved Do you agree or disagree that …?
@pearu, I agree with that statement. But I also propose an alternative solution to the ones you posited.
Great! IIUC, your solution is to not have …
With the above, … So, how would you test …?
@pearu, like this. Based on our PRs:
>>> x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_()
>>> mask = x.detach()
>>> torch.autograd.gradcheck(lambda t: t.sparse_mask(mask).to_dense(), (x,), masked=False)
True
>>> torch.autograd.gradcheck(lambda t: torch.Tensor.t(t.sparse_mask(mask)).to_dense(), (x,), masked=False)
True
>>> def f(x, mask):
... x = x.sparse_mask(mask)
... x = torch.sparse.mm(x, x)
... return x.to_dense()
>>> torch.autograd.gradcheck(lambda x: f(x, mask), (x,), masked=False)
True
As in the title. The `masked_grad` kw argument is required for `to_dense` backward to distinguish the expected semantics of sparse tensors. `masked_grad=True` means that the `to_dense` backward will apply a mask to the returned gradient, where the mask is defined by the input indices. The default semantics implies `masked_grad=True` for BC, but see the [comment](https://github.com/pytorch/pytorch/pull/96095/files#diff-d4df180433a09071e891d552426911c227b30ae9b8a8e56da31046e7ecb1afbeR501-R513) in `to_dense_backward`.

As a consequence, existing code that is run through the autograd engine must replace `.to_dense()` calls with `.to_dense(masked_grad=False)`. For example,

```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense())
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense())
```

(recall, `gradcheck` has `masked=False` as default) must be updated to

```python
torch.autograd.gradcheck(lambda x: torch.sum(x, [0]).to_dense(masked_grad=False))
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]).to_dense(masked_grad=True), masked=True)
```

Fixes pytorch/pytorch#95550

cc alexsamardzic nikitaved cpuhrsch amjames bhosmer ezyang albanD zou3519 gqchen soulitzer Lezcano Varal7 jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10

Pull Request resolved: pytorch/pytorch#96095
Approved by: https://github.com/cpuhrsch
Issue description
For historical reasons, torch.to_dense backward on sparse inputs implements masked semantics that contradicts the current interpretation of sparse tensors, namely that sparse tensors are semantically equivalent to strided tensors. The use of a sparse format is considered merely a memory optimization that does not define a mask for operations on sparse tensors.
The masked semantics of tensors is currently implemented in `torch.masked` (the future) and `torch.sparse` (to be deprecated).

This issue breaks autograd on sparse tensors because the use of the `to_dense` method is required for operations resulting in sparse tensors, for example:

Code example
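The code block from the original issue is not preserved in this capture. The following is a minimal sketch of the situation it described, assuming an operation with a sparse output; the choice of `torch.sparse.sum` follows the examples in the PR description above:

```python
import torch

# A sparse input whose dense source contains a zero, so the element at
# [0, 0] is unspecified (nnz=3 out of 4 elements).
x = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_()

# torch.sparse.sum over a subset of dims returns a sparse tensor, so
# gradcheck raises "Sparse output is not supported at gradcheck yet"
# and recommends calling to_dense() on the output of fn.
torch.autograd.gradcheck(lambda x: torch.sparse.sum(x, [0]), (x,))
```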
When using the recommendation from the above example, gradcheck using non-masked semantics fails:
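The failing check itself is not shown in this capture; a sketch, reusing `x` from the snippet above:

```python
# to_dense backward masks the returned gradient by the input's indices,
# so the analytical gradient is zero at the unspecified element [0, 0]
# while the numerical gradient is not, and gradcheck reports a mismatch.
torch.autograd.gradcheck(
    lambda x: torch.sparse.sum(x, [0]).to_dense(), (x,), masked=False
)
```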
but succeeds under masked semantics:
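The succeeding check is likewise not preserved; a sketch under the same assumptions:

```python
# Under masked semantics, unspecified elements are excluded from
# differentiation, so the analytical and numerical gradients agree.
torch.autograd.gradcheck(
    lambda x: torch.sparse.sum(x, [0]).to_dense(), (x,), masked=True
)
```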
Possible solutions
1. A solution is implemented in #94728 (Add sparse semantics context manager) that introduces a global Context flag that defines how the operations and their backward implementations should interpret the unspecified elements of sparse tensors.
2. Introduce a `masked` kw argument (default `False`) to `to_dense` that enables explicit control of semantics when using the `to_dense` method. For example, checks like the ones sketched below should succeed.
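The examples themselves are not preserved in this capture. A sketch of how the proposed keyword would be used, mirroring the before/after pair in the PR description above (note that the argument was later renamed `masked_grad` in the merged PR):

```python
# Proposed API: semantics is chosen explicitly at the to_dense call site
# and matched by gradcheck's masked argument.
torch.autograd.gradcheck(
    lambda x: torch.sparse.sum(x, [0]).to_dense(masked=True), (x,), masked=True
)
torch.autograd.gradcheck(
    lambda x: torch.sparse.sum(x, [0]).to_dense(masked=False), (x,), masked=False
)
```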
Discussion

Both solutions have pros and cons (see discussions in #94728), and both will be BC-breaking because the default semantics of operations in the `torch` namespace need to switch to non-masked semantics. This cannot be avoided when pursuing the idea of considering sparse tensors semantically equivalent to strided tensors. Fortunately, most tensor operations (from the `torch` namespace) on sparse tensors already implement non-masked semantics. For example, a gradcheck of a composition of transposing and indexing fails under masked semantics; that failure is expected because both transposing and indexing use non-masked semantics (the use of the indexing op circumvents the "Sparse output is not supported at gradcheck yet" exception and the current issue with `to_dense`).

From the future perspective, when the masked and non-masked semantics are well separated between the `torch.masked` and `torch` namespaces, the use of the `masked` kw argument in both `gradcheck` and `to_dense` becomes unnecessary because the semantics will be defined by the operations themselves (whether they come from the `torch.masked` or `torch` namespace). Ironically, this sounds very similar to the original plan of having the `torch.sparse` and `torch` namespaces for different semantics. I guess the main difference is that the memory-optimization and masked-semantics features of using sparse tensors become uncoupled.

System Info
#95405 implements `masked` kw argument support in `gradcheck`.

cc @alexsamardzic @nikitaved @cpuhrsch @amjames @bhosmer @ezyang @gchanan @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7