
[primTorch] Minor improvements to doc and impl of gaussian_nll_loss #85612

Closed
nkaretnikov wants to merge 9 commits

Conversation

@nkaretnikov (Collaborator) commented Sep 25, 2022

@pytorch-bot (bot) commented Sep 25, 2022

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/85612

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit bd18a12:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

nkaretnikov added a commit that referenced this pull request Sep 25, 2022
ghstack-source-id: e6ae96ccb06d9f070d87c55daf3a9b7b5bf3e140
Pull Request resolved: #85612
nkaretnikov added a commit that referenced this pull request Sep 25, 2022
ghstack-source-id: a3dad8e4d7bd1d8e86a2dff56bd60468eb507e65
Pull Request resolved: #85612
@nkaretnikov (Collaborator, Author) commented Sep 25, 2022

Note: see #53964 and #56469 for context on the size checks and clamping before the loss calculation. The ref code is almost a 1-to-1 copy of the Python implementation.
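For context, a rough sketch of the size checks and clamping under discussion, modeled on the public torch.nn.functional.gaussian_nll_loss behavior. This is a simplified illustration, not the exact code from this PR; the helper name and the shapes in the comments are assumptions.

```python
import math
import torch

def gaussian_nll_loss_sketch(input, target, var, full=False, eps=1e-6, reduction="mean"):
    # Size check: var must match input, or match it except for a trailing
    # dimension of size 1, or be missing the trailing dimension entirely
    # (the homoscedastic cases discussed later in this thread).
    if var.size() != input.size():
        if input.size()[:-1] == var.size():
            var = torch.unsqueeze(var, -1)  # e.g. input (N, D), var (N,)
        elif input.size()[:-1] == var.size()[:-1] and var.size(-1) == 1:
            pass                            # e.g. input (N, D), var (N, 1)
        else:
            raise ValueError("var is of incorrect size")

    if torch.any(var < 0):
        raise ValueError("var has negative entry/entries")

    # Clamp for stability; exactly how to do this (with or without no_grad)
    # is the main point of the review discussion below.
    var = var.clamp(min=eps)

    loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
    if full:
        loss = loss + 0.5 * math.log(2 * math.pi)
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
```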

nkaretnikov added a commit that referenced this pull request Sep 25, 2022
ghstack-source-id: 149d24524a2f38c2195efaeb6bbf6e74fff729b1
Pull Request resolved: #85612
nkaretnikov added a commit that referenced this pull request Sep 25, 2022
ghstack-source-id: bb6522bc64188cc25ec32173ee8a670d2d193218
Pull Request resolved: #85612
@nkaretnikov nkaretnikov marked this pull request as ready for review September 26, 2022 08:49
@nkaretnikov nkaretnikov requested review from lezcano and removed request for albanD, ngimel, mruberry and jbschlosser September 26, 2022 08:49
@lezcano (Collaborator) left a comment

This is an interesting one. This function is already implemented in Python in core, so it's not clear to me whether we want to re-implement it in PrimTorch. The only benefits I see in doing this are that we could implement some promotion rules for it, and that this way we would have all our implementations in the same place... WDYT @mruberry?

(4 resolved review comments on torch/_refs/nn/functional/__init__.py, now outdated)
@mruberry mruberry self-requested a review September 26, 2022 13:22
@@ -2777,8 +2777,10 @@ def gaussian_nll_loss(
Args:
input: expectation of the Gaussian distribution.
target: sample from the Gaussian distribution.
var: tensor of positive variance(s), one for each of the expectations
in the input (heteroscedastic), or a single one (homoscedastic).
var: same shape as the input, or same shape as the input but with the
@mruberry (Collaborator) commented Sep 27, 2022

It's cool this PR is updating the documentation for the function.

When describing the parameters it's important to start with what's most important about them. In this case, I don't think it's the shape of var that's most important, but that var is a tensor describing the variances of either a multivariate normal distribution or multiple independent distributions (see question above).

This documentation also seems a little odd to me because input and target refer to "the Gaussian distribution", but this seems wrong because

  1. the loss can be used for multiple Gaussian distributions simultaneously (because it supports batches)
  2. I believe the correct semantic interpretation for this loss is that it works on multivariate normal distributions OR multiple normal distributions, and not just one?

So there may be more we can do here
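To make the batching point concrete, a minimal usage sketch against the existing torch.nn.functional.gaussian_nll_loss API (the shapes are illustrative assumptions, not taken from the PR):

```python
import torch
import torch.nn.functional as F

# A batch of N independent D-dimensional Gaussians, each with its own
# per-element variance (the heteroscedastic case).
N, D = 8, 3
input = torch.randn(N, D)      # predicted means
target = torch.randn(N, D)     # observed samples
var = torch.rand(N, D) + 0.1   # positive variances, one per element

loss = F.gaussian_nll_loss(input, target, var, reduction="mean")
print(loss)
```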

@nkaretnikov (Collaborator, Author) replied:

I just removed the doc from the functional part. See the above comment WRT the Gaussian bit.

@nkaretnikov (Collaborator, Author) commented:

reminder to close this one once we fix the docs: #53392

@nkaretnikov nkaretnikov changed the title [primTorch] Add ref for gaussian_nll_loss, add error inputs [primTorch] Minor improvements to doc and impl of gaussian_nll_loss Sep 28, 2022
nkaretnikov added a commit that referenced this pull request Sep 28, 2022
ghstack-source-id: 86473881086564d9b4f502b9325dadac7bc83643
Pull Request resolved: #85612
nkaretnikov added a commit that referenced this pull request Sep 29, 2022
Fixes #53392.

ghstack-source-id: 351e85882f2035d50cdf09572d0374a5dcd7f39a
Pull Request resolved: #85612
@@ -2817,14 +2817,15 @@ def gaussian_nll_loss(
raise ValueError(reduction + " is not a valid value for reduction")

# Clamp for stability
var = var.clone()
@nkaretnikov (Collaborator, Author) commented:

Not sure what the purpose of the clone here was, since the same variable is used later anyway.

A collaborator replied:

clone + in-place was a worse version of doing the clamp out of place, I think.

@nkaretnikov (Collaborator, Author) replied:

Now I know why. It's either the original behavior (with no_grad and an in-place clamp) or just this (without any no_grad context): var = var.clamp(min=eps). I couldn't find any other code that wouldn't break the following tests:

python -m pytest test/test_modules.py -k GaussianNLLLoss -vvv
python -m pytest test/test_ops_gradients.py -k gaussian_nll_loss -vvv

This claims that doing it without no_grad will "cause divergence," but the tests pass locally.
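For reference, the two variants under discussion look roughly like this (a sketch based on the quoted hunk, not the exact diff):

```python
import torch

eps = 1e-6
var = torch.rand(4, 3, requires_grad=True)

# Original behavior: clone, then clamp in place under no_grad, so the
# clamp itself is invisible to autograd.
var_a = var.clone()
with torch.no_grad():
    var_a.clamp_(min=eps)

# Proposed simplification: a functional clamp, which *is* recorded by
# autograd and therefore changes gradient behavior wherever var < eps.
var_b = var.clamp(min=eps)
```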

A collaborator replied:

any thoughts @albanD ? Clamping without no_grad LGTM, but I know we've historically done it with the no_grad...

@albanD (Collaborator) commented Oct 3, 2022:

clamping without no_grad will zero out a bunch of gradients. We definitely don't want that!
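A small demonstration of this point, assuming the two simplified variants sketched above and a toy loss (not the real one): the functional clamp zeroes the gradient for entries of var below eps, while the no_grad version lets the gradient flow as if no clamp had happened.

```python
import torch

eps = 1e-6

def toy_loss_nograd_clamp(var):
    v = var.clone()
    with torch.no_grad():
        v.clamp_(min=eps)       # clamp is hidden from autograd
    return torch.log(v).sum()

def toy_loss_functional_clamp(var):
    v = var.clamp(min=eps)      # clamp is part of the autograd graph
    return torch.log(v).sum()

var1 = torch.tensor([1e-8, 0.5], requires_grad=True)
var2 = torch.tensor([1e-8, 0.5], requires_grad=True)

toy_loss_nograd_clamp(var1).backward()
toy_loss_functional_clamp(var2).backward()

print(var1.grad)  # both entries get a nonzero gradient (1/eps and 1/0.5)
print(var2.grad)  # the gradient for the entry below eps is zeroed
```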

A collaborator replied:

What about figuring out what value we want this function to take at var = 0 and returning that? Although we would need to somehow deal with the gradients as well...

@nkaretnikov (Collaborator, Author) replied:

Note: @lezcano told me offline there are plans (IIUC) to have a "framework" that would help with numerical issues like this one, so postponing for now.

A collaborator replied:

My point is that, at the moment, we don't care about gradients in PrimTorch (just yet). But we do care about gradients in PyTorch. As such, given that this is a function exposed in PyTorch, it should be correct. In particular, this change makes the gradients of this function incorrect, so it should be reverted.

Then, we should revisit at some point how to approach the whole question of gradients in PrimTorch.

(Resolved review comment on functorch/test/test_ops.py)

nkaretnikov added a commit that referenced this pull request Sep 30, 2022
Fixes #53392.

ghstack-source-id: b624b8533e25af07e503849acae0c5c7d9865e1f
Pull Request resolved: #85612
nkaretnikov added a commit that referenced this pull request Sep 30, 2022
Fixes #53392.

ghstack-source-id: c0f85f80b031f5c287440c98748ad82c7768851e
Pull Request resolved: #85612
@nkaretnikov (Collaborator, Author) commented Oct 1, 2022

@lezcano @mruberry PTAL.

Summary:

  • removed no_grad from the clamp and changed it to a functional clamp
  • removed a call to .item, which resulted in better test coverage
  • improved the docs to document the allowed var inputs
  • didn't add a ref (based on our discussion in another PR) because it's a pure Python impl
  • added error inputs.

@facebook-github-bot (Contributor) commented:

/easycla

As part of the transition to the PyTorch Foundation, this project now requires contributions be covered under the new CLA. See #85559 for additional details.

This comment will trigger a new check of this PR. If you are already covered, you will simply see a new "EasyCLA" check that passes. If you are not covered, a bot will leave a new comment with a link to sign.

@linux-foundation-easycla: CLA Not Signed

@@ -2762,6 +2762,7 @@ def poisson_nll_loss(
return ret


# TODO: Pure Python impl - don't add a primTorch ref
A collaborator commented:

I would remove this comment -- it's not really a TODO since there's nothing to do

# If var is the same shape as input, it's the heteroscedastic case.
# If var is *not* the same shape as input, it's the homoscedastic case.
#
# To support broadcasting, the following sub-cases are allowed in the
A collaborator commented:

It's not really broadcasting, though. We should be clear about what, exactly, it does, and not use a similar concept, which could confuse the reader

@@ -314,13 +314,13 @@ class GaussianNLLLoss(_Loss):
where :attr:`eps` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
of 1 or have one fewer dimension (when comparing from the outermost dimension, with all other
sizes being the same) for correct later broadcasting.
A collaborator commented:

Let's not use the term "broadcasting" here because this is not the same thing

@@ -314,13 +314,13 @@ class GaussianNLLLoss(_Loss):
where :attr:`eps` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
A collaborator commented:

I don't think the phrase "final dimension" is commonly understood

@@ -314,13 +314,13 @@ class GaussianNLLLoss(_Loss):
where :attr:`eps` is used for stability. By default, the constant term of
the loss function is omitted unless :attr:`full` is ``True``. If ``var`` is not the same
size as ``input`` (due to a homoscedastic assumption), it must either have a final dimension
of 1 or have one fewer dimension (with all other sizes being the same) for correct broadcasting.
of 1 or have one fewer dimension (when comparing from the outermost dimension, with all other
A collaborator commented:

I would just break this into cases. I'm not sure the order of comparison makes it that clear. Either var has the same shape, has the same shape except its innermost dimension is 1, or has the same shape except it's "missing" the innermost dimension. In the last two cases the variance is assumed to be the same for each distribution (the distributions are assumed to be homoscedastic)
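To illustrate the three cases listed above, a short usage sketch against the current torch.nn.functional.gaussian_nll_loss (the shapes are assumptions chosen for illustration):

```python
import torch
import torch.nn.functional as F

input = torch.randn(10, 2, 3)   # predicted means
target = torch.randn(10, 2, 3)  # observed samples

# Case 1: var has the same shape as input (heteroscedastic).
var_same = torch.rand(10, 2, 3) + 0.1
F.gaussian_nll_loss(input, target, var_same)

# Case 2: same shape except the innermost dimension is 1 (homoscedastic).
var_inner_one = torch.rand(10, 2, 1) + 0.1
F.gaussian_nll_loss(input, target, var_inner_one)

# Case 3: same shape except the innermost dimension is "missing" (homoscedastic).
var_missing = torch.rand(10, 2) + 0.1
F.gaussian_nll_loss(input, target, var_missing)
```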

@mruberry (Collaborator) left a comment:

The functional and test changes look good, but I think we should take the time to be more diligent with the docs. It could save us a lot of headaches in the future.

One option would be to separate the doc changes into another PR.

@@ -2773,19 +2774,6 @@ def gaussian_nll_loss(
r"""Gaussian negative log likelihood loss.

See :class:`~torch.nn.GaussianNLLLoss` for details.

A collaborator commented:

I would keep this.

While we do like to reduce document redundancy, in a more perfect world modules would just wrap functions, and the functions would be documented.

@nkaretnikov nkaretnikov marked this pull request as draft November 11, 2022 19:30
@github-actions (bot) commented:
Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions bot added the Stale label Jan 10, 2023
@github-actions github-actions bot closed this Feb 9, 2023
@facebook-github-bot facebook-github-bot deleted the gh/nkaretnikov/6/head branch June 8, 2023 18:13