gradcheck produces false positives with sparse inputs when masked=False. #103518

Open
nikitaved opened this issue Jun 13, 2023 · 14 comments

@nikitaved
Collaborator

nikitaved commented Jun 13, 2023

🐛 Describe the bug

As per title. As an example, let's consider the sampled_addmm method which is semantically equivalent to
sampled_addmm(s, m1, m2, alpha, beta) := alpha * (m1 @ m2).sparse_mask(s) + beta * s.
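To make the definition concrete, here is a minimal toy sketch in plain Python. It is not the PyTorch kernel: a sparse matrix is modelled as a dict `{(row, col): value}` whose keys stand in for the specified indices, and the helper names `mm_entry` and `toy_sampled_addmm` are hypothetical.

```python
# Toy model (assumption): a sparse matrix is a dict {(row, col): value};
# the keys play the role of the specified indices.

def mm_entry(m1, m2, i, j):
    """(m1 @ m2)[i, j] for dense matrices given as lists of rows."""
    return sum(m1[i][k] * m2[k][j] for k in range(len(m2)))

def toy_sampled_addmm(s, m1, m2, alpha=1.0, beta=1.0):
    """alpha * (m1 @ m2).sparse_mask(s) + beta * s on the toy representation."""
    return {(i, j): alpha * mm_entry(m1, m2, i, j) + beta * v
            for (i, j), v in s.items()}

m1 = [[2.0, 0.0], [0.0, 2.0]]
m2 = [[3.0, 0.0], [0.0, 3.0]]
s = {(0, 0): 1.0}  # a single specified element
out = toy_sampled_addmm(s, m1, m2)
print(out)  # {(0, 0): 7.0}
```

The product is only evaluated at the indices of `s`, so the output inherits the sparsity pattern of `s`.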

If we inspect the subgradient of sampled_addmm wrt s in derivatives.yaml, we find the following:

- name: sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor
  self: maybe_multiply(grad, beta.conj())

Note that under the assumption of masked semantics this formula is correct, even though it does not account for the (mat1 @ mat2).sparse_mask(self) part. This follows from the sparse semantics, which implies self.indices == (self + perturbation_of_self).indices. Hence we can expect gradcheck to work with masked=True:

In [1]: import torch

In [2]: x = torch.eye(3, dtype=torch.double).to_sparse_csr().requires_grad_(True)

In [3]: y = torch.rand(3, 3, dtype=torch.double)

In [4]: z = torch.rand(3, 3, dtype=torch.double)

In [5]: torch.autograd.gradcheck(lambda x: torch.sparse.sampled_addmm(x, y, z).to_dense(masked_grad=True), (x,), masked=True)
Out[5]: True

However, the situation is reversed for masked=False. In this case the backward formula for self should take alpha * (m1 @ m2).sparse_mask(self) into consideration, so it is expected for gradcheck with masked=False to fail.
This, however, does not happen:

In [6]: torch.autograd.gradcheck(lambda x: torch.sparse.sampled_addmm(x, y, z).to_dense(masked_grad=False), (x,), masked=False)
Out[6]: True

As per @pearu's insight, this happens during the densification process in gradcheck. Namely, it sometimes expands self.indices to full dimensions while producing a new sparse input self_densified. Unfortunately, sampled_addmm(self) and sampled_addmm(self_densified) are not equivalent in backward, because sampled_addmm(self_densified) should pass gradcheck with either masked=True or masked=False since its mask is the whole space.

Versions

Current master.

cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @ezyang @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7

@nikitaved nikitaved added module: sparse Related to torch.sparse module: autograd Related to torch.autograd, and the autograd engine in general labels Jun 13, 2023
@mikaylagawarecki mikaylagawarecki added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Jun 13, 2023
@pearu
Collaborator

pearu commented Jun 13, 2023

gradcheck with masked=False assumes that its input function func is layout-agnostic, that is, func(densify(x)).to_dense() == func(x).to_dense() holds for x with any layout, where densify is a function that materializes all unspecified elements of its input as zeros. Internally, gradcheck uses such a densify function for computing numerical gradients.

torch.sparse.sampled_addmm is not a layout-agnostic function, as its implementation explicitly uses the col/crow indices of its input to define the sparsity pattern used in the mathematical definition of sampled_addmm. In other words, torch.sparse.sampled_addmm is defined for masked semantics where the mask is defined by the input tensor indices.

In this case the backward formula for self should take alpha * (m1 @ m2).sparse_mask(self) into consideration..

I think this is the key point for determining if the raised issue is valid or not.

By definition, torch.sparse.sampled_addmm "performs a matrix multiplication of the dense matrices mat1 and mat2 at the locations specified by the sparsity pattern of input. The matrix input is added to the final result."
Assuming the non-masked semantics, the key question is: do we want the "sparsity pattern" to be variable under perturbation of inputs or do we want it to be fixed?

Recall, in non-masked semantics, the "sparsity pattern" of a tensor (in terms of its indices set) does not define a mask because non-masked semantics is layout-agnostic. The sampled_addmm in non-masked semantics should be defined as

sampled_addmm(s, m1, m2, alpha, beta) := addmm(s, m1, m2, alpha, beta).sparse_mask(mask)

where mask = (s != 0). While mask depends on s, it is non-differentiable, and autograd will ignore it because mask.dtype is non-float/non-complex.
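The two readings of "sparsity pattern" can be contrasted in a toy model. The sketch below is an illustration only, with a sparse matrix modelled as a dict `{(row, col): value}`. All helper names are hypothetical, and `densify` here only mimics the idea of materializing unspecified elements as zeros, not gradcheck's internal helper.

```python
def mm_entry(m1, m2, i, j):
    """(m1 @ m2)[i, j] for dense matrices given as lists of rows."""
    return sum(m1[i][k] * m2[k][j] for k in range(len(m2)))

def densify(s, shape):
    """Materialize every unspecified element as an explicit zero."""
    rows, cols = shape
    return {(i, j): s.get((i, j), 0.0) for i in range(rows) for j in range(cols)}

def to_dense(s, shape):
    rows, cols = shape
    return [[s.get((i, j), 0.0) for j in range(cols)] for i in range(rows)]

def sampled_addmm_index_mask(s, m1, m2, alpha=1.0, beta=1.0):
    """Mask is the set of specified indices, explicit zeros included."""
    return {ij: alpha * mm_entry(m1, m2, *ij) + beta * v for ij, v in s.items()}

def sampled_addmm_value_mask(s, m1, m2, alpha=1.0, beta=1.0):
    """Mask is (s != 0): explicit zeros behave like unspecified elements."""
    return {ij: alpha * mm_entry(m1, m2, *ij) + beta * v
            for ij, v in s.items() if v != 0}

def is_layout_agnostic_at(func, s, shape, *args):
    """Check func(densify(s)).to_dense() == func(s).to_dense() for one input s."""
    lhs = to_dense(func(densify(s, shape), *args), shape)
    rhs = to_dense(func(s, *args), shape)
    return lhs == rhs

m1 = [[2.0, 0.0], [0.0, 2.0]]
m2 = [[3.0, 0.0], [0.0, 3.0]]
s = {(0, 0): 1.0}
shape = (2, 2)

index_ok = is_layout_agnostic_at(sampled_addmm_index_mask, s, shape, m1, m2)
value_ok = is_layout_agnostic_at(sampled_addmm_value_mask, s, shape, m1, m2)
print(index_ok, value_ok)  # False True
```

Only the value-based mask survives densification unchanged in this model; the index-based one picks up the product at every newly materialized position.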

Btw, the title "gradcheck produces false positives.." may be misleading as the reported issue depends on the backward formula of sampled_addmm and the gradcheck numerical path does not use the backward formula at all. Roughly speaking, gradcheck compares the jacobians obtained numerically (using perturbations on function inputs) and analytically (using the backward formula) and if these match, it will return True. If the example in the issue description is expected to fail, then in the case of the false positive result, one should determine if the numerical jacobian is wrong, if the analytical jacobian is wrong, or if gradcheck is misused. Atm, I tend to think it is the latter as masked=False means that the sparsity pattern of the inputs should be unessential while in the case of torch.sparse.sampled_addmm it will be essential if we are going to change its backward formula, or if we would need to apply perturbations under the mask defined by the input indices (that begs using masked=True).

nikitaved added a commit that referenced this issue Jun 13, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with `alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
@nikitaved
Collaborator Author

nikitaved commented Jun 13, 2023

gradcheck with masked=False assumes that its input function func is layout-agnostic, that is, func(densify(x)).to_dense() == func(x).to_dense() holds for x with any layout, where densify is a function that materializes all unspecified elements of its input as zeros. Internally, gradcheck uses such a densify function for computing numerical gradients.

sampled_addmm does satisfy func(densify(x)).to_dense() == func(x).to_dense() for any x.
But it is all related to the question of parametrization that we discussed in #95405. I still believe that having masked=True is a bad design as it might create a false illusion of "correct" gradients, because the user could supply any grad into autograd.grad, and if this gradient is not properly projected, the resulting gradient might be incorrect, and we cannot detect such cases with gradcheck set to masked=True, only with masked=False.
But then, again, masked=False in my opinion should compute the complete Jacobian (i.e. full basis) of the original function, not that of func(densify(x)), which is not the same as just func. gradcheck should be agnostic to any hidden parametrizations which is currently the case with strided inputs, and densify replaces one parametrization with another not necessarily equivalent. Also, then, it seems like densify should not be func agnostic, i.e. for softmax the neutral element is -inf, not 0. masked=False implies that ANY subset of materialized indices with the neutral value does not affect the result of the function as per func(x + subset_of_neutral_elems).to_dense(neutral_val=...) == func(x).to_dense(neutral_val=...) for any x and any possible subset of indices, but densify as it is now produces only a SINGLE parametrization, so it is insufficient for our needs...

By definition, torch.sparse.sampled_addmm "performs a matrix multiplication of the dense matrices mat1 and mat2 at the locations specified by the sparsity pattern of input. The matrix input is added to the final result."
Assuming the non-masked semantics, the key question is: do we want the "sparsity pattern" to be variable under perturbation of inputs or do we want it to be fixed?

Recall, in non-masked semantics, the "sparsity pattern" of a tensor (in terms of its indices set) does not define a mask because non-masked semantics is layout-agnostic. The sampled_addmm in non-masked semantics should be defined as

sampled_addmm(s, m1, m2, alpha, beta) := addmm(s, m1, m2, alpha, beta).sparse_mask(mask)

where mask = (s != 0). While mask depends on s, it is non-differentiable, and autograd will ignore it because mask.dtype is non-float/non-complex.

The function sampled_addmm is parametrized by self and not by a pair (self, mask=self != 0). In this case self != 0, although non-differentiable, has a non-trivial subgradient.

Btw, the title "gradcheck produces false positives.." may be misleading as the reported issue depends on the backward formula of sampled_addmm and the gradcheck numerical path does not use the backward formula at all. Roughly speaking, gradcheck compares the jacobians obtained numerically (using perturbations on function inputs) and analytically (using the backward formula) and if these match, it will return True. If the example in the issue description is expected to fail, then in the case of the false positive result, one should determine if the numerical jacobian is wrong, if the analytical jacobian is wrong, or if gradcheck is misused. Atm, I tend to think it is the latter as masked=False means that the sparsity pattern of the inputs should be unessential while in the case of torch.sparse.sampled_addmm it will be essential if we are going to change its backward formula, or if we would need to apply perturbations under the mask defined by the input indices (that begs using masked=True).

gradcheck produces True, which means that both numeric and analytic Jacobians are incorrect, since analytic formula for grads wrt self is incorrect because it misses the projection step and/or, if not for the sparse semantics, it should include the subgradient generated by the part alpha * (s != 0) * (mat1 @ mat2).

@pearu
Collaborator

pearu commented Jun 14, 2023

sampled_addmm does satisfy func(densify(x)).to_dense() == func(x).to_dense() for any x.

Ah, right. So, the requirements for func are more subtle.

gradcheck produces True, which means that both numeric and analytic Jacobians are incorrect, since analytic formula for grads wrt self is incorrect because it misses the projection step and/or, if not for the sparse semantics, it should include the subgradient generated by the part alpha * (s != 0) * (mat1 @ mat2).

The numerical path in gradcheck(func, x, masked=False) is equivalent to the path in gradcheck(func, densify(x), masked=False), that is, before computing the numerical jacobian, all sparse inputs are mapped to sparse inputs that don't have unspecified elements (these are materialized to zero values). This means that the expression s != 0 is equivalent to 1 (notice that the actual implementation of torch.sparse.sampled_addmm does not use s != 0 nor sparse_mask, it just iterates over all specified elements, which include the zero values due to the densify pre-processing of inputs). Such degradation of the function definition means that gradcheck(torch.sparse.sampled_addmm, x, ..., masked=False) is not testing the autograd-correctness of torch.sparse.sampled_addmm, but that of the torch.addmm equivalent.
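A toy sketch of why the two jacobians can agree once the input is densified. This is an illustration in plain Python under stated assumptions, not gradcheck itself: sparse matrices are dicts `{(row, col): value}`, and `densify`, `toy_sampled_addmm` and `central_difference` are hypothetical helpers. The analytic value assumes a real `beta` and a unit incoming gradient, so the formula `maybe_multiply(grad, beta.conj())` reduces to `beta`.

```python
def mm_entry(m1, m2, i, j):
    """(m1 @ m2)[i, j] for dense matrices given as lists of rows."""
    return sum(m1[i][k] * m2[k][j] for k in range(len(m2)))

def densify(s, shape):
    """Materialize every unspecified element as an explicit zero."""
    rows, cols = shape
    return {(i, j): s.get((i, j), 0.0) for i in range(rows) for j in range(cols)}

def toy_sampled_addmm(s, m1, m2, alpha, beta):
    """Iterates over all specified elements, explicit zeros included."""
    return {ij: alpha * mm_entry(m1, m2, *ij) + beta * v for ij, v in s.items()}

def central_difference(func, s, index, eps=1e-6):
    """d func(s)[index] / d s[index], perturbing the value already stored at index."""
    plus, minus = dict(s), dict(s)
    plus[index] += eps
    minus[index] -= eps
    return (func(plus)[index] - func(minus)[index]) / (2 * eps)

m1 = [[2.0, 0.0], [0.0, 2.0]]
m2 = [[3.0, 0.0], [0.0, 3.0]]
alpha, beta = 1.0, 0.5
s = {(0, 0): 1.0}

def f(t):
    return toy_sampled_addmm(t, m1, m2, alpha, beta)

# After densification, position (1, 1) is a specified (zero) element.
numeric = central_difference(f, densify(s, (2, 2)), (1, 1))
analytic = beta  # diagonal jacobian entry implied by the backward formula for self
print(abs(numeric - analytic) < 1e-6)  # True
```

In this model the `alpha * (m1 @ m2)` term is constant at every position of the densified input, so it drops out of the difference quotient and only `beta` remains.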

Could you propose a fix to the backward formula of torch.sparse.sampled_addmm for the non-masked semantics case? This would provide a reproducer for the degradation defect in gradcheck for input functions whose numerical jacobians may be sensitive to the densify usage approach.

@nikitaved
Collaborator Author

nikitaved commented Jun 14, 2023

Could you propose a fix to the backward formula of torch.sparse.sampled_addmm for the non-masked semantics case? This would provide a reproducer for the degradation defect in gradcheck for input functions whose numerical jacobians may be sensitive to the densify usage approach.

I am not sure that such a fix will cut it. The usage of densify is a problem since I cannot test what I want with masked=False.

Suppose x is a sparse 1D tensor with nnz elements and x.to_dense() is an n-dim vector. When masked=False, the gradcheck should evaluate func at func(x + eps * e_i) where {e_i}_{i=1}^n is a canonical basis of R^n. When masked=True, it evaluates func while applying perturbations only in the directions of specified elements along a corresponding subset of the canonical basis. densify is not needed at all.
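The two perturbation schemes described above can be sketched in a toy model. This is an assumption-laden illustration rather than gradcheck's actual code: a sparse 1D tensor is a dict `{index: value}`, and `perturbation_indices` and `perturb` are hypothetical names.

```python
def perturbation_indices(x, n, masked):
    """Indices i of the basis vectors e_i along which x is perturbed."""
    # masked=True: only the specified elements; masked=False: the whole of R^n.
    return sorted(x) if masked else list(range(n))

def perturb(x, i, eps):
    """Return x + eps * e_i as a new toy sparse vector, without densifying x."""
    y = dict(x)
    y[i] = y.get(i, 0.0) + eps
    return y

n = 4
x = {1: 2.0, 3: -1.0}  # nnz = 2

masked_dirs = perturbation_indices(x, n, masked=True)
full_dirs = perturbation_indices(x, n, masked=False)
print(masked_dirs)  # [1, 3]
print(full_dirs)    # [0, 1, 2, 3]

# Perturbing along an unspecified direction grows the index set.
grown = perturb(x, 0, 1e-6)
print(sorted(grown))  # [0, 1, 3]
```

Note that `perturb` builds a new tensor instead of modifying `x` in place, which is the price of not densifying first in this model.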

Now suppose that masked=False is doing what is described in the paragraph above.

Suppose I have a function f that accepts sparse differentiable inputs and we have backward implemented for it, WHICH IS CORRECT in the domain it is supposed to be used in.
We have the following scenarios:

  • Simplest. f assumes no sparse semantics and just treats sparse inputs as an optimization. Then gradcheck should succeed with either masked=True or masked=False. Then testing masked=True is redundant in such cases.
    The user has to be informed about such properties in the documentation.
  • Simpler. Suppose the developer made a very robust function f that projects all the inputs/gradients onto the right domains. This means that the function does support masked semantics but can tolerate inputs/grads not in the domain in a differentiable manner. In this case, again, both masked=True and masked=False should succeed. If I want to really make sure that all my projections are correct, and that the sparse semantics is respected, I would compute the full Jacobian over the whole space and check that the Jacobian entries not corresponding to nnz indices are all zeroes. Neither masked=True nor masked=False allows me to do that.
    If we provide such a robust function, the user has to be informed about that!
  • Simple. Suppose f.backward assumes only sparse semantics in the backward. Then in the absence of domain/co-domain projections masked=True should succeed and masked=False should fail, unless the sparse input/output spans the whole space.
    This is the problem with densify as it actually re-parametrizes functions in an unexpected way to cover the whole space.
    This case is the hardest for the user because the user should be really careful with the gradients and data manipulations. The users should be informed that it is their responsibility to handle domains/co-domains carefully.

sampled_addmm actually falls into the hardest, Simple category, and the user is not aware of that AT ALL!
While keeping the grads for sampled_addmm as is, and under the assumption of masked=False working as described above, one way to test it with masked=False could be:

def wrapped_sampled_addmm(self, mask, *args, **kwargs):
    # Project the input onto the mask, then project the output onto the same mask.
    return sampled_addmm(self.sparse_mask(mask), *args, **kwargs).sparse_mask(mask)

gradcheck(lambda self, mask: wrapped_sampled_addmm(self, mask, ...), (self, self.detach() != 0), masked=False)

wrapped_sampled_addmm now falls into the Simpler category. It projects the input to handle arbitrary perturbations of gradcheck - check! It projects the output to handle arbitrary input grads - check! Ideally, we still need to inspect the Jacobian to make sure that "irrelevant" entries are all zeros...
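The Jacobian inspection mentioned here can be sketched on a toy analogue of the wrapper. This is plain Python under stated assumptions, not PyTorch: sparse matrices are dicts `{(row, col): value}`, the mask is a fixed set of indices, and `wrapped_toy_sampled_addmm` and `full_jacobian` are hypothetical helpers.

```python
def mm_entry(m1, m2, i, j):
    """(m1 @ m2)[i, j] for dense matrices given as lists of rows."""
    return sum(m1[i][k] * m2[k][j] for k in range(len(m2)))

def wrapped_toy_sampled_addmm(s, mask, m1, m2, alpha=1.0, beta=1.0):
    """Toy analogue of sampled_addmm(s.sparse_mask(mask), ...).sparse_mask(mask)."""
    return {ij: alpha * mm_entry(m1, m2, *ij) + beta * s.get(ij, 0.0) for ij in mask}

def full_jacobian(func, s, shape, eps=1e-6):
    """J[p, q] = d func(s)[p] / d s[q] over the whole dense space (central differences)."""
    rows, cols = shape
    space = [(i, j) for i in range(rows) for j in range(cols)]
    jac = {}
    for q in space:
        plus, minus = dict(s), dict(s)
        plus[q] = plus.get(q, 0.0) + eps    # may create a new specified index
        minus[q] = minus.get(q, 0.0) - eps
        f_plus, f_minus = func(plus), func(minus)
        for p in space:
            jac[p, q] = (f_plus.get(p, 0.0) - f_minus.get(p, 0.0)) / (2 * eps)
    return jac

m1 = [[2.0, 0.0], [0.0, 2.0]]
m2 = [[3.0, 0.0], [0.0, 3.0]]
s = {(0, 0): 1.0}
mask = {(0, 0)}

jac = full_jacobian(lambda t: wrapped_toy_sampled_addmm(t, mask, m1, m2), s, (2, 2))

# Entries whose input or output position lies outside the mask should vanish.
off_mask_zero = all(abs(v) < 1e-9
                    for (p, q), v in jac.items()
                    if p not in mask or q not in mask)
print(off_mask_zero)  # True
```

Because the mask is passed in as a fixed argument, perturbations outside it are projected away before they can change the output pattern.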

nikitaved added a commit that referenced this issue Jun 14, 2023
@pearu
Collaborator

pearu commented Jun 14, 2023

Suppose x is a sparse 1D tensor with nnz elements and x.to_dense() is an n-dim vector. When masked=False, the gradcheck should evaluate func at func(x + eps * e_i) where {e_i}_{i=1}^n is a canonical basis of R^n.

Right. This is exactly what densify enables. Notice that densify is applied to function inputs, not to the function itself. The only drawback is that one loses the original sparsity pattern of inputs that the input function may use (as is the case for sparse.sampled_addmm) but in the context of non-masked semantics, where the layout is considered only a storage optimization, losing the original sparsity of inputs is insignificant.

densify is not needed at all.

One of the fundamental assumptions in gradcheck numerical path is that the inputs are perturbed in-place (for efficiency). The densification of sparse inputs allowed more-or-less transparent addition of sparse tensors support to autograd for non-masked semantics. The alternative to the densification pre-processing step is to perturb the inputs by creating a new input tensor (x + eps * e_i) that will have different indices set from x if i corresponds to a non-specified element in x. This breaks the fundamental assumption in gradcheck and requires considerable changes to the gradcheck numerical path. I think such changes are very likely to be rejected by the autograd team (the complexity of the needed changes is much higher than that of #97825 which was rejected as too complex), so I find the discussion about removing the densify step without the support from the autograd team just too unrealistic.

I am not sure that such a fix will cut it. The usage of densify is a problem since I cannot test what I want with masked=False

I think the fix to the backward formula would be a good way to get some attention to your suggestion as it would provide a reproducer for the possible gradcheck problem. Another approach for testing what you want is to rephrase the problem using strided tensors only (replace a sparse input with a pair of strided values and a mask tensor).

@nikitaved
Collaborator Author

nikitaved commented Jun 14, 2023

Suppose x is a sparse 1D tensor with nnz elements and x.to_dense() is an n-dim vector. When masked=False, the gradcheck should evaluate func at func(x + eps * e_i) where {e_i}_{i=1}^n is a canonical basis of R^n.

Right. This is exactly what densify enables. Notice that densify is applied to function inputs, not to the function itself. The only drawback is that one loses the original sparsity pattern of inputs that the input function may use (as is the case for sparse.sampled_addmm) but in the context of non-masked semantics, where the layout is considered only a storage optimization, losing the original sparsity of inputs is insignificant.

Modification of inputs is wrong in general. gradcheck should treat inputs and functions as black boxes, and there is no guarantee that this re-parametrization is equivalent, no matter whether the inputs are sparse or dense. And, in fact, it breaks on functions like sampled_addmm.

densify is not needed at all.

One of the fundamental assumptions in gradcheck numerical path is that the inputs are perturbed in-place (for efficiency). The densification of sparse inputs allowed more-or-less transparent addition of sparse tensors support to autograd for non-masked semantics. The alternative to the densification pre-processing step is to perturb the inputs by creating a new input tensor (x + eps * e_i) that will have different indices set from x if i corresponds to a non-specified element in x. This breaks the fundamental assumption in gradcheck and requires considerable changes to the gradcheck numerical path. I think such changes are very likely to be rejected by the autograd team (the complexity of the needed changes is much higher than that of #97825 which was rejected as too complex), so I find the discussion about removing the densify step without the support from the autograd team just too unrealistic.

Sorry, what is the fundamental assumption of gradcheck? I always thought that PyTorch follows the principle of correctness first.

I am not sure that such a fix will cut it. The usage of densify is a problem since I cannot test what I want with masked=False

I think the fix to the backward formula would be a good way to get some attention to your suggestion as it would provide a reproducer for the possible gradcheck problem. Another approach for testing what you want is to rephrase the problem using strided tensors only (replace a sparse input with a pair of strided values and a mask tensor).

There is nothing to fix in backward, it is correct. Is densify documented? Is it being applied only to differentiable inputs? Is x -> densify(x) differentiable? How can I turn it off because it does something unexpected to me as a user? Can I have a differentiable function undensify so that I could test g(x):=f(undensify(x)) before feeding into gradcheck?

It is just so much easier for gradcheck to assume that sparse is just an optimization layout, nothing more, nothing less. Masked semantics is the property of the function, and gradcheck should be agnostic to that. We used to have only masked=True before, right? Well, I believe this is incorrect... I have no idea how many of the sparse-supporting functions actually produce wrong gradients. IMHO, of course, but this issue is fundamental. How else would the user trust in sparse if the grads are potentially broken for a lot of functions in the torch.sparse namespace?

@pearu
Collaborator

pearu commented Jun 14, 2023

Modification of inputs is wrong in general, ...
Sorry, what is the fundamental assumption of gradcheck? I always thought that PyTorch follows the principle of correctness first.

Oh, I feel like I should just reference https://github.com/pytorch/pytorch/blob/main/torch/autograd/gradcheck.py instead of trying to explain here how gradcheck works with sparse inputs in detail.

Is densify documented?

def _densify(x):

Is it being applied only to differentiable inputs?

no

Is x -> densify(x) differentiable?

densify is used only in the numerical path of gradcheck where differentiability is irrelevant.

How can I turn it off because it does something unexpected to me as a user?

As a user, you should not be concerned about the densify, it is an internal method. Otherwise, use masked=True to disable it.

Can I have a differentiable function undensify so that I could test g(x):=f(undensify(x)) before feeding into gradcheck?

densify is internal to gradcheck, so there is nothing to undensify outside of the gradcheck.

@nikitaved
Collaborator Author

nikitaved commented Jun 14, 2023

How can I get a full Jacobian of a sparse function?
Back to the issue. Do we agree that gradcheck producing True with masked=False is incorrect? Backward is correct, so how do we fix that?
Just for the context, densify will still test a wrong function in the numeric path, and if a non-modified input is fed into the function in the analytic path, and yet the grads match, something is wrong.

@pearu
Collaborator

pearu commented Jun 14, 2023

Do we agree that gradcheck producing True with masked=False is incorrect?

I am not certain about this. As I see it, with masked=False, the input function behavior is expected not to depend on the layout nor indices of its inputs. gradcheck(f, x, masked=False) and gradcheck(f, x.to_dense(), masked=False) should give the same results.

densify will still test a wrong function in the numeric path,

This does not make sense to me. There is one and only one function used everywhere in gradcheck - it is the input function.

@nikitaved
Collaborator Author

nikitaved commented Jun 14, 2023

Do we agree that gradcheck producing True with masked=False is incorrect?

I am not certain about this. As I see it, with masked=False, the input function behavior is expected not to depend on the layout nor indices of its inputs. gradcheck(f, x, masked=False) and gradcheck(f, x.to_dense(), masked=False) should give the same results.

OK, the backward formula for self does not account for the mm part, right? This only means that this implementation is only correct when mask(x + perturb) == mask(x).

Now, let's assume mat1 = 2 * torch.eye(2, 2), mat2 = 3 * torch.eye(2, 2), alpha == beta == 1.
Suppose
x = torch.zeros(2, 2); x[0, 0] = 1; x = x.to_sparse(), and
p = torch.zeros(2, 2); p[-1, -1] = eps; p = p.to_sparse(), then

f(x)[-1, -1] = 0, f(x + p)[-1, -1] = 6 + eps, so I guess the entry for J[f_11,x_11](x) -> inf.
This is the case when p and x have different sparsity patterns.
That was the numeric path.
In the analytic path, reading from how backward is implemented for self, we would have that
J[f_11, x_11](x) = 1. So, numeric is not finite while analytic is finite.
If p = eps * x, so it shares the sparsity pattern of x, then both paths would be finite and match (=1), this is what happens
when densify is used in the numeric path.

densify will still test a wrong function in the numeric path,

This does not make sense to me. There is one and only one function used everywhere in gradcheck - it is the input function.

But f(x) and g(x):=f(densify(x)) are different functions, right?

@pearu
Collaborator

pearu commented Jun 14, 2023

@nikitaved your example has many typos, but when eliminating these and taking into account that gradcheck uses the 2nd-order central finite differences for computing the numerical jacobian, J[f_11,x_11](x) is actually approaching 1, not inf (f(x) is never evaluated; instead, f(x+p) and f(x-p) are evaluated).
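The difference between the two schemes can be checked on a toy version of the example. This is a sketch in plain Python under stated assumptions, not gradcheck's code: sparse matrices are dicts `{(row, col): value}`, perturbing an unspecified position creates a new specified element, and `toy_sampled_addmm` and `add_at` are hypothetical helpers.

```python
def mm_entry(m1, m2, i, j):
    """(m1 @ m2)[i, j] for dense matrices given as lists of rows."""
    return sum(m1[i][k] * m2[k][j] for k in range(len(m2)))

def toy_sampled_addmm(s, m1, m2, alpha=1.0, beta=1.0):
    """Evaluates alpha * (m1 @ m2) + beta * s at the specified indices of s only."""
    return {ij: alpha * mm_entry(m1, m2, *ij) + beta * v for ij, v in s.items()}

def add_at(s, index, delta):
    """Return s + delta * e_index; index becomes specified if it was not."""
    t = dict(s)
    t[index] = t.get(index, 0.0) + delta
    return t

m1 = [[2.0, 0.0], [0.0, 2.0]]
m2 = [[3.0, 0.0], [0.0, 3.0]]
x = {(0, 0): 1.0}
idx = (1, 1)  # unspecified in x
eps = 1e-6

def f(t):
    # Output entry at idx, read as an implicit zero when unspecified.
    return toy_sampled_addmm(t, m1, m2).get(idx, 0.0)

forward = (f(add_at(x, idx, eps)) - f(x)) / eps
central = (f(add_at(x, idx, eps)) - f(add_at(x, idx, -eps))) / (2 * eps)

print(forward > 1e6)             # True: the quotient blows up as eps -> 0
print(abs(central - 1.0) < 1e-6)  # True: f(x) itself is never evaluated
```

The forward quotient sees the jump from 0 to 6 + eps, while the central quotient compares 6 + eps with 6 - eps and so stays close to 1 in this model.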

But f(x) and g(x):=f(densify(x)) are different functions, right?

No. Two functions are equal when their values are equal for all possible argument combinations. Within non-masked semantics, f(x) and g(x) are equal because unspecified values in x are treated as implicit zeros: f(x).to_dense() == g(x).to_dense() for any x.

In the case of f being sparse.sampled_addmm, f maps unspecified values to unspecified values (read: implicit zero values), explicit zero values to explicit zero values (by the definition of spy), and non-zero values v with index to alpha*(mat1@mat2)[index] + beta*v.

That said, notice that there is a discrepancy in the current implementation of sparse.sampled_addmm and its documentation:

>>> m1 = torch.eye(2, 2) * 2
>>> m2 = torch.eye(2, 2) * 3
>>> x = torch.tensor([[1, 0], [0, 0.]]).to_sparse_csr()
>>> densified_x = (x.to_dense() - 99).to_sparse_csr() + 99. * torch.ones(x.shape).to_sparse_csr()
>>> x
tensor(crow_indices=tensor([0, 1, 1]),
       col_indices=tensor([0]),
       values=tensor([1.]), size=(2, 2), nnz=1, layout=torch.sparse_csr)
>>> densified_x
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([1., 0., 0., 0.]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)
>>> torch.sparse.sampled_addmm(x, m1, m2)
tensor(crow_indices=tensor([0, 1, 1]),
       col_indices=tensor([0]),
       values=tensor([7.]), size=(2, 2), nnz=1, layout=torch.sparse_csr)
>>> torch.sparse.sampled_addmm(densified_x, m1, m2)
tensor(crow_indices=tensor([0, 2, 4]),
       col_indices=tensor([0, 1, 0, 1]),
       values=tensor([7., 0., 0., 6.]), size=(2, 2), nnz=4,
       layout=torch.sparse_csr)

that is, the implementation considers specified elements as non-zeros (when evaluating spy) even when the corresponding values are explicit zeros.

@nikitaved
Collaborator Author

nikitaved commented Jun 14, 2023

The order of the approximation does not matter, really, it only applies to differentiable functions. This function is not differentiable since its right derivative is not finite, which the example above shows.

These functions are different because they operate over different masks, and the mask is a hidden parameter. Specified indices specify parameters, and these are different.

@pearu
Collaborator

pearu commented Jun 14, 2023

The order of the approximation does not matter, really. This function is not differentiable since its right derivative is not finite, which the example above shows.

Right. What matters is the usage of the central finite difference scheme which will never give inf contrary to using forward or backward finite difference schemes.

@pearu
Collaborator

pearu commented Jun 14, 2023

These functions are different because they operate over different masks, and the mask is a hidden parameter.

Such functions are not using non-masked semantics. Hence, using masked=False is invalid usage and the gradcheck result is undefined (there is no way for gradcheck to determine if the input function and the specified masked option are consistent).

nikitaved added a commit that referenced this issue Jun 16, 2023
nikitaved added a commit that referenced this issue Jun 17, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 17, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 19, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 19, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 19, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 19, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 20, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 20, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 22, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 22, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 23, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 23, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 23, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 23, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 26, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 26, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 26, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 26, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 27, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 27, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 27, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 27, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 28, 2023
…ient wrt self"

As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]
nikitaved added a commit that referenced this issue Jun 28, 2023
As per title. Previous gradient was only correct under the Sparse Semantics, i.e. with`alpha * (mat1 @ mat2)` ignored. However, then, it is wrongly parametrized in the backward pass, as we need to project the gradient in a generic case.
Under this parametrization we can expect `gradcheck` to succeed with either `masked=True` or `masked=False` even after #103518 is fixed.

cc pearu .




cc alexsamardzic pearu cpuhrsch amjames bhosmer ezyang gchanan albanD zou3519 gqchen soulitzer Lezcano Varal7

[ghstack-poisoned]