
Implement sparse semantics support in gradcheck (2nd try) #95405

Closed
wants to merge 4 commits

Conversation

@pearu (Collaborator) commented Feb 23, 2023

@pytorch-bot bot commented Feb 23, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/95405

Note: Links to docs will display an error until the docs builds have been completed.

❗ 1 Active SEV

There is 1 currently active SEV. If your PR is affected, please view it below:

❌ 2 Failures

As of commit 9cc8a29:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

pearu added a commit that referenced this pull request Feb 23, 2023
ghstack-source-id: 27a15673993edd8f7d199aa9ff3ed90b066207fd
Pull Request resolved: #95405
@pearu pearu self-assigned this Feb 23, 2023
@pearu pearu added the module: sparse, module: autograd, open source, release notes: sparse, and ciflow/periodic labels Feb 23, 2023
@pearu pearu added this to In progress in Sparse tensors via automation Feb 23, 2023
test/test_sparse.py (outdated review thread, resolved)
Replaces #94714 that was reverted due to #94714 (comment)




cc alexsamardzic nikitaved cpuhrsch amjames bhosmer ezyang albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
pearu added a commit that referenced this pull request Feb 23, 2023
ghstack-source-id: 638094c34df5cb164e7488179126490e85d6e6e4
Pull Request resolved: #95405
pearu added a commit that referenced this pull request Feb 24, 2023
ghstack-source-id: fb5de9d0fa81e42ddd728971073ff4acc6ece20a
Pull Request resolved: #95405
pearu added a commit that referenced this pull request Feb 24, 2023
ghstack-source-id: 513500c246013bd6dce10388f137de84b5e86270
Pull Request resolved: #95405
Sparse tensors automation moved this from In progress to Reviewer approved Feb 24, 2023
@pearu (Collaborator, Author) commented Feb 24, 2023

@pytorchbot merge

@pytorch-bot pytorch-bot bot added the ciflow/trunk Trigger trunk jobs on your pull request label Feb 24, 2023
@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

@nikitaved (Collaborator)

Do we actually benefit from it in the long run? FWIW, #95550 (comment).

@pearu (Collaborator, Author) commented Feb 27, 2023

> Do we actually benefit from it in the long run? FWIW, #95550 (comment).

This PR fixes gradcheck for sparse inputs under non-masked semantics while keeping the masked semantics support.

So, I don't understand the question under the assumption that tensors with sparse layouts are semantically equivalent to strided tensors. Do you challenge this assumption?

@nikitaved (Collaborator) commented Feb 27, 2023

@pearu, I would argue there is a better fix: to_dense should not impose any masking, and if the masking behavior is to be tested, it can be done in the forward function with the use of the differentiable sparse_mask. See the examples in my point 4. Therefore, we already have the tools to fix this without any additional complexity in gradcheck, without masked=True, and without check_sparse_nnz=True altogether.

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@pytorchmergebot (Collaborator)

@pearu (Collaborator, Author) commented Feb 27, 2023

From the discussion in #95550 (comment), the discussion above, and a DM with @nikitaved, there exists another approach to fixing gradcheck without introducing the masked kw argument, as explained below.

First, we'll postulate the following requirements:

  1. Tensors using sparse layouts are semantically equivalent to tensors with strided layout, both for operations and for their autograd results.

  2. The gradcheck analytical path should use the user-provided function and its inputs without any transformations, to mimic the user's application environment (recall that for the numerical path the inputs must be densified so that perturbations can be applied to unspecified elements as well; a small illustration follows at the end of this comment).

  3. gradcheck operates under non-masked semantics only.

The above should be sufficient for gradcheck to always succeed on a user-provided function that uses operations whose backward functions implement non-masked semantics.

Now, if the user-specified function uses operations whose backwards use masked semantics, the user must also specify the corresponding masks. These masks are applied to the inputs within the user-specified function using sparse_mask or a similar tool, or by gradcheck itself (which would require specifying which mask relates to which input, say, via torch.masked.tensor). Applying the masks to the respective inputs within the user-specified function suppresses the perturbations of unspecified elements that come from gradcheck's numerical Jacobian computation.
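As a minimal illustration of the parenthetical in requirement 2 (the tensor values below are arbitrary and chosen only for the example): a perturbation of an unspecified element cannot be represented within the sparse input's existing indices set, which is why the numerical path works on a densified copy.

import torch

s = torch.tensor([[0., 1.], [2., 3.]], dtype=torch.float64).to_sparse()  # (0, 0) is unspecified
d = s.to_dense()
d[0, 0] += 1e-6            # perturb an element that s never specified
perturbed = d.to_sparse()  # the perturbation enlarges the indices set
print(s._nnz(), perturbed._nnz())  # 3 vs 4 specified elements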

@cpuhrsch (Contributor)

@pytorchbot merge -m "Leftover failures are unrelated"

@pytorch-bot bot commented Feb 27, 2023

❌ 🤖 pytorchbot command failed:

@pytorchbot: error: unrecognized arguments: -m Leftover failures are unrelated

usage: @pytorchbot [-h] {merge,revert,rebase,label,drci} ...

Try @pytorchbot --help for more info.

@cpuhrsch (Contributor)

@pytorchbot merge -f "Leftover failures are unrelated"

@pytorchmergebot (Collaborator)

Merge started

Your change will be merged immediately since you used the force (-f) flag, bypassing any CI checks (ETA: 1-5 minutes).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.

@nikitaved (Collaborator) commented Feb 28, 2023

@pearu, I am not sure that providing masks is a good strategy:

  1. Under the strided <-> sparse equivalence, perturbations that result in a different mask might render masked functions non-differentiable. We have functions that are only differentiable in a rank-preserving neighborhood, yet there is no option in gradcheck to preserve the rank, so I believe it is up to the user to provide a function that does what is expected, while preserving the general behavior of gradcheck no matter the space the inputs live in.

  2. We have functions that are not masked in forward but are in backward, for example torch.sparse.mm. There is no way to test backward for these functions inside gradcheck if masked=True is eliminated, unless gradcheck performs a pre-processing of the inputs with x = x.sparse_mask(x_mask), which contradicts the requirement that gradcheck test the function provided with the inputs provided. This supports point 1, in favor of giving all the freedom to the user.

  3. Less complexity for gradcheck without any masks. Why would sparse inputs be treated differently? They are semantically equivalent to dense, so let them be just like dense, so that gradcheck is agnostic to the nature of the inputs. It is up to the user to make sure all the inputs are processed and restricted to satisfy the requirement that the function in question be differentiable. This also keeps us from taking the blame when gradcheck returns a false positive.

  4. For masked semantics, gradcheck has to derive that the corresponding Jacobian entries are always zero for perturbations not in the mask. Not showing this and just restricting to the mask might produce false positives. Again, this could be done in a way that is agnostic to the inputs' nature, with a properly constructed input function.

How do we even define Masked Semantics anyway?
It seems to me a function f satisfying masked semantics has to satisfy f(x + dx) = f((x + dx).sparse_mask(x)) (an alternative definition of Sparse Parametrization for smooth functions?). Under this definition sparse.mm is not even a masked-semantics function. Or do we define it as functions that project the backward gradients onto the non-zero elements in the backward, so that we can say that sparse.mm supports Masked Semantics? So we do not even have a strict definition of what Masked Semantics even means, and I am not sure gradcheck has to take responsibility here...
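A minimal sketch of how the proposed property could be probed numerically for a single perturbation, assuming (as in the examples later in this thread) that sparse_mask accepts a sparse self and that f maps a sparse COO tensor to a strided tensor; this is only a necessary check for one dx, not a proof:

import torch

def probe_masked_property(f, x, dx):
    # x and dx are sparse COO tensors of the same shape; x's indices play the role of the mask
    lhs = f(x + dx)
    rhs = f((x + dx).sparse_mask(x))  # drop the part of the perturbation outside x's pattern
    return bool(torch.allclose(lhs, rhs))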

@pearu (Collaborator, Author) commented Feb 28, 2023

@nikitaved

> Under the strided <-> sparse equivalence, perturbations that result in a different mask might render masked functions non-differentiable

What do you mean by "perturbations that result in a different mask"? When masks are provided as inputs, these are typically bool tensors that are never perturbed. In general, only float or complex input tensors that have the requires_grad flag set to True are perturbed; all other inputs are not.

I agree that the masked kw argument is not sufficient in general. Here's an example that fails for both masked=True and masked=False cases:

>>> def foo(x):
...     y = torch.sparse.mm(x, x)
...     z = torch.mm(y, x)
...     return z
... 
>>> t = torch.tensor([[0, 1], [2, 3]], dtype=torch.float64).to_sparse().requires_grad_()
>>> torch.autograd.gradcheck(lambda x: foo(x).to_dense(), t, masked=False)
<snip>
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 4.0000,  3.0000,  6.0000,  2.0000],
        [ 6.0000, 13.0000,  4.0000, 12.0000],
        [ 3.0000,  1.0000, 13.0000,  6.0000],
        [ 2.0000,  6.0000, 12.0000, 31.0000]], dtype=torch.float64)
analytical:tensor([[ 2.,  0.,  6.,  0.],
        [ 6., 13.,  4., 12.],
        [ 3.,  1., 13.,  6.],
        [ 2.,  6., 12., 31.]], dtype=torch.float64)

>>> torch.autograd.gradcheck(lambda x: foo(x).to_dense(), t, masked=True, check_sparse_nnz=True)
<snip>
torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 6.0000, 13.0000,  4.0000, 12.0000],
        [ 3.0000,  1.0000, 13.0000,  6.0000],
        [ 2.0000,  6.0000, 12.0000, 31.0000]], dtype=torch.float64)
analytical:tensor([[ 2.,  0.,  6.,  0.],
        [ 6., 13.,  4., 12.],
        [ 3.,  1., 13.,  6.],
        [ 2.,  6., 12., 31.]], dtype=torch.float64)

because foo uses functions whose backward implementations assume different semantics.

Clearly, the above example makes sense only in masked semantics (because it uses operations whose backwards are implemented for masked semantics). The numerical Jacobian appears to be correct when the input mask is [[False, True], [True, True]] (this is implied by the fact that the implicit mask for torch.sparse.mm is the sparsity pattern of its input). To fix the analytical Jacobian, the mask must be applied to the input when entering the function:

>>> def foo(x, x_mask):
...     x = x.sparse_mask(x_mask)
...     y = torch.sparse.mm(x, x)
...     z = torch.mm(y, x)
...     return z
... 
>>> t_mask = torch.tensor([[0, 1], [1, 1]], dtype=torch.bool).to_sparse()
>>> torch.autograd.gradcheck(lambda *args: foo(*args).to_dense(), (t, t_mask), masked=True, check_sparse_nnz=True)
True

(Btw, here I am using a PyTorch version that suffers from issue #95550. When applying the patch from #95550 (comment), the example above still works.)

@nikitaved (Collaborator) commented Feb 28, 2023

Oh, I thought the point was to manipulate the mask inside gradcheck. My mistake.
Still, the concern remains: what is Masked semantics anyway? What do we do with it? Do we plan to redefine it or remove it altogether? I am in favor of redefinition, so that sparse.mm as it is right now does not even exist :) I will create an issue about it, along with a potential redefinition.

@pearu (Collaborator, Author) commented Feb 28, 2023

> what is Masked semantics anyway?

Use https://pytorch.org/docs/master/masked.html as a starting point for the answer.

One can also ask "what is a tensor anyway?". A tensor is a set of values that are arranged over a regular grid where the grid nodes are indexed. The index-value pair is called a tensor element. One can define various high-level operations with tensors using operations defined on the tensor values and taking into account the regular arrangement of values on the grid.
A masked tensor is a generalization of the tensor where its values are arranged over an irregular grid. To represent a masked tensor, one can use a pair of regular tensors: values and mask, where mask values are 0 or 1 so that values elements are defined only when the corresponding mask element has value 1.
Clearly, not all high-level operations with tensors can be generalized for masked tensors in an unambiguous way but one can still try by introducing various restrictions. For instance, one could require that adding two masked tensors is defined only when the masked tensors have the same mask. Or more generally, one could define the addition of masked tensors formally using the operations defined on regular tensors: (values1, mask1) + (values2, mask2) = (values1 + values2, mask1 * mask2).
Finally, to answer the original question, Masked semantics is the study that defines the meaning of masked tensors and of operations on masked tensors. Since we have regular tensors, it is convenient to define masked tensors and their operations in terms of regular tensors and operations on regular tensors, as exemplified above.
That said, I don't think we have a clear understanding of what operations with masked tensors are needed in applications and what should be their definitions. For instance, take a matrix product and try to generalize this to masked tensors. There exist several definitions (e.g. (i) consider masked-out elements as 0 value, or (ii) consider masked-out elements as undefined, etc) but it is not clear which one will be useful for practical applications or which even make sense in terms of defining a well-defined algebra on masked tensors.
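As a minimal sketch of the addition rule above, one could represent a masked tensor as a (values, mask) pair of regular strided tensors with a boolean mask (the helper name is illustrative, not an existing API):

import torch

def masked_add(values1, mask1, values2, mask2):
    # (values1, mask1) + (values2, mask2) = (values1 + values2, mask1 * mask2):
    # the sum is defined only where both operands are defined
    mask = mask1 & mask2
    # zero out entries outside the result mask; their values carry no meaning
    values = torch.where(mask, values1 + values2, torch.zeros_like(values1))
    return values, mask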

@nikitaved (Collaborator) commented Feb 28, 2023

> what is Masked semantics anyway?
>
> Use https://pytorch.org/docs/master/masked.html as a starting point for the answer.
>
> One can also ask "what is a tensor anyway?". A tensor is a set of values that are arranged over a regular grid where the grid nodes are indexed.

Understanding "what is a tensor" is important for understanding what a sparse tensor is and what the functions that manipulate them do. Let me just gather some thoughts...

A real tensor, a matrix for simplicity, of size m x n is isomorphic to R^{m x n} and is thus a vector space with the canonical basis {Delta_ij}, where (Delta_ij)_kl = 1 if i == k and j == l, and 0 otherwise. A sparse tensor is hence represented as S = \sum_{i, j} s_ij * Delta_ij, where (i, j) \in S.indices() and s_ij is the corresponding value.

There is a clear bijection between a sparse matrix S and its dense counterpart S.to_dense(). Hence, to_dense and to_sparse define isomorphisms between the vector space of dense matrices and the vector space of sparse matrices. As such, under the isomorphic map to_dense we can use all the operations defined for dense tensors with sparse inputs, and we get the immediate relationship for maps that map dense inputs to dense inputs, namely
f(d_1, ..., d_n) <=> f(d_1.to_sparse(), ..., d_n.to_sparse()) <=> f(d_1.to_sparse().to_dense(), ..., d_n.to_sparse().to_dense()). This is how we get our behavior of sparse being just an optimization layout for functions defined over dense tensors.
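As a quick illustration of this round trip (a trivial check, not part of the argument itself):

import torch

d = torch.randn(3, 4, dtype=torch.float64)
# to_sparse followed by to_dense recovers the original regular tensor exactly
assert torch.equal(d.to_sparse().to_dense(), d)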

Masked Semantics, or perhaps better called Sparse Parametrization(?), is more tricky when translating mappings defined over dense inputs to sparse inputs, because dense mappings map full bases to full bases, not subsets of bases to subsets of bases. Recall the representation of a sparse tensor:
S = \sum_{i, j} s_ij * Delta_ij, where (i, j) \in S.indices() and s_ij is the corresponding value.
Now, any function g(S) that respects Masked Semantics actually maps only the basis {Delta_ij | (i, j) \in S.indices()}, so if we define supp(S) := {Delta_ij | (i, j) \in S.indices()}, then any perturbation dS of S has to be in the span of supp(S).
And this is what gradcheck is doing: it takes a perturbation dS \in span(supp(S)) and computes g(S + dS); all fine here.
Since gradcheck actually tests (to_dense o g)(S), and since to_dense does support Masked Semantics, we actually get that gradcheck tests a function g that maps a "masked tensor" to a "masked tensor", so that
dg(S) \in span(supp(g(S))).

Disabling the masked semantics in to_dense.backward will turn to_dense into a function that preserves the "natural" isomorphism between dense and sparse tensors, and it will make it differentiable everywhere! This implies that we can still test the Masked Semantics grads with gradcheck, but we need to mask not only the inputs but also the outputs. Something along the lines of:

x = ...to_sparse().requires_grad_(True)   # sparse input under test
x_mask = x                                # its own indices set serves as the mask
y = ...to_sparse().requires_grad_(True)
y_mask = y
res_mask = f(x.detach(), y.detach())      # output mask from an unperturbed forward pass

def g(x, y):
    x = x.sparse_mask(x_mask)             # drop perturbations outside the input masks
    y = y.sparse_mask(y_mask)
    res = f(x, y)
    res = res.sparse_mask(res_mask)       # restrict the output to the output mask
    return res

gradcheck(g, [x, y], ...)

The Jacobian of g should show that irrelevant basis vectors get zero contributions, and if something is wrong, g is not really a function that supports Masked Semantics. Considering sparse tensors under the to_dense isomorphism, this implies that such a function should satisfy f(x + dx) == f((x + dx).sparse_mask(x)) for dx in a sufficiently small neighborhood of x, which should be tested properly, i.e. the backward grads are zero for entries not in the mask and the JVP is zero in the entries that do not belong to the output's mask f(x).

My whole point above and here, @pearu, was just to understand this better and to advocate that this is a better way to test Masked Semantics grads than having just masked=True, which can still give us false positives. Do we agree that this could be a nice follow-up altogether, @pearu, I mean fixing both to_dense and setting masked=False as the default behavior? Then we can fix all the tests for which masked=True, and if they are green, we are golden; otherwise we might have issues in how forward/backward is actually handled for "masked" functions.

Although masking the result should not be necessary if the function is assumed to map to "dense" instead, provided that it projects the grads correctly onto the inputs' subspaces... So, I believe it is important to understand what Masked Semantics is, its relationship to dense tensors, and the domains/co-domains of the functions that manipulate them, so as to understand when things are differentiable and where they are not. With the current implementation of to_dense.backward we can say it is not differentiable under the sparse-dense isomorphism, because for a sparse input S and a non-zero perturbation dS not in supp(S) we have that:

to_dense(S + dS)_{supp(dS)} = dS != 0, while (to_dense(S)_grad * dS)_{supp(dS)} = 0,
where _{supp(dS)} denotes restriction to supp(dS).

So, more reasons to fix to_dense, I believe, because this might bite the user and us big time.
With all said and done, we can see that sparse.mm is a map from "masked" tensors to a dense tensor sparse.mm(x, y).to_dense(), while softmax, for example, is a map from a masked tensor to a masked tensor. As such, the result of softmax needs masking in gradcheck (when testing forward/backward) before calling to_dense (once it is fixed) under the natural sparse-dense isomorphism.
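For instance, torch.sparse.softmax keeps its input's sparsity pattern and treats unspecified elements as masked out rather than as zeros, which is the masked-to-masked behavior described above (a small illustration with arbitrary values):

import torch

x = torch.tensor([[0., 1.], [2., 3.]]).to_sparse()  # (0, 0) is unspecified
y = torch.sparse.softmax(x, dim=1)
print(y.to_dense())
# row 0 has a single specified element, so it gets probability 1 at (0, 1);
# a dense softmax over [0., 1.] would instead give [0.2689, 0.7311]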

@pearu (Collaborator, Author) commented Feb 28, 2023

@nikitaved thanks for the formalized discussion! I think we are mostly on the same page but I have a few notes.

  1. Once we have agreed that regular tensors with strided and sparse layouts are isomorphic within the non-masked (regular) semantics of tensors, then in the formal discussion we can just talk about tensors without referring to the storage format used. In fact, a masked tensor (defined as a pair of values and mask tensors) may have its values and mask use a strided layout just as well as any sparse layout.

  2. the ops to_dense and to_sparse are isomorphisms between the space of dense and sparse tensors under the implicit assumption that unspecified elements map to elements with zero value and elements with zero value map to unspecified elements, respectively. We should be very careful with this approach because all conclusions to be drawn will be valid under the assumption that unspecified elements are equivalent to zeros and vice-versa. In fact, an important point in masked semantics is that there exists no relation between unspecified elements (also called masked-out or ignored elements) and particular values because unspecified elements just do not exist (as the adjective "unspecified" also suggests) within masked semantics. If we decide to introduce the so-called fill value (be it 0 or some other finite or infinite value) then we are not in the domain of masked semantics anymore but in the domain of regular semantics.

    To put it differently, I hope we can resist the natural urge to identify sparse tensors with masked tensors, so as to respect the fundamental assumption that sparse tensors (in terms of storage format) are regular tensors, just like strided tensors. If the mask and the indices set of a sparse tensor match, then we just have a great optimization opportunity in implementing the operations and their backwards, but in general the mask and the indices set of a sparse tensor are unrelated (see the short example after this list).

  3. "Do we agree that ... fixing both to_dense and setting masked=False to be the default behavior ?" - I totally agree. In fact, we have torch.to_dense backward ignores unspecified elements in sparse inputs #95550 to fix to_dense, and masked=False is already the default for gradcheck due to Implement sparse semantics support in gradcheck (2nd try) #95405.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 1, 2023
@nikitaved (Collaborator) commented Mar 1, 2023

3. "Do we agree that ... fixing both `to_dense` and setting `masked=False` to be the default behavior ?" - I totally agree. In fact, we have [torch.to_dense backward ignores unspecified elements in sparse inputs #95550](https://github.com/pytorch/pytorch/issues/95550) to fix `to_dense`, and `masked=False` is already the default for `gradcheck` due to [Implement sparse semantics support in gradcheck (2nd try) #95405](https://github.com/pytorch/pytorch/pull/95405).

👍

@pearu, thank you for your comments and this discussion!

Regarding 2: I believe we were careful in defining Sparse Parametrization. Even though, under zero nse, a sparse tensor is represented as S = \sum_{i, j} s_ij * Delta_ij, where (i, j) \in S.indices() and s_ij is the corresponding value, in the mappings I discuss I consider that we map not this representation (which is really just an embedding into dense tensors), but the basis functions that span S, i.e. the Delta_ij, and the values s_ij; these are the parameters. No other things exist outside of the support, supp(S).
Now, when a fill_value is introduced, and if S is represented as
S = \sum_{i, j in S.indices()} s_ij * Delta_ij + \sum_{i, j not in S.indices()} fill_value * Delta_ij, this is again just a representation, an embedding into a dense tensor.
And we have two options here as well: namely, is fill_value a parameter or not?
All in all, this part is not relevant to whether to_sparse or to_dense define an isomorphism. In this case they are just used as an embedding, not as an equivalence relationship. And then the goal is to understand the functions' domains/co-domains and what their parameters are. We use this Sparse Parametrization in the "masked tensors" context, but it seems to me most of these operations could be trivialized with their torch.* + sparse_mask counterpart while making them implicitly differentiable. And the most interesting ops are the ones which map "masked" tensors to "masked" tensors (like softmax, for example), as these, especially in non-linear cases, might have a non-trivial backward which is more complicated than just projecting grads onto the mask.
The point is that we need these tensors to be available on the C++ side to implement backward for functions mapping them, and I believe gradcheck should in general be agnostic to any parametrization; its task is to handle full bases. And then we and the user need to test functions properly, understanding domains and co-domains.

cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 2, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 5, 2023
cyyever pushed a commit to cyyever/pytorch_private that referenced this pull request Mar 27, 2023
pruthvistony added a commit to ROCm/pytorch that referenced this pull request May 2, 2023
@facebook-github-bot facebook-github-bot deleted the gh/pearu/94/head branch June 8, 2023 18:18
jhavukainen pushed a commit to kulinseth/pytorch that referenced this pull request Mar 15, 2024
Labels: ciflow/periodic, ciflow/trunk, Merged, module: autograd, module: sparse, open source, release notes: sparse
Projects: Sparse tensors
Development: Successfully merging this pull request may close these issues. None yet.
5 participants