
Support for target with class probs in CrossEntropyLoss #61044

Closed (12 commits)

Conversation

jbschlosser (Contributor)

Fixes #11959

This is an alternative approach to creating a new CrossEntropyLossWithSoftLabels class. This PR simply adds support for "soft targets", i.e. class probabilities, to the existing CrossEntropyLoss and NLLLoss classes.

The implementation is deliberately simple for now; future work can add higher-performance kernels for this case.
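
For context, a minimal usage sketch of what the soft-target support looks like from the Python side (made-up tensors and shapes, assuming a PyTorch build that includes this change):

```python
import torch
import torch.nn.functional as F

# Made-up example: a batch of 3 samples over 5 classes.
logits = torch.randn(3, 5, requires_grad=True)

# "Hard" targets: one class index per sample (the pre-existing behavior).
hard_target = torch.tensor([1, 0, 4])
loss_hard = F.cross_entropy(logits, hard_target)

# "Soft" targets: a probability distribution over the classes for each sample,
# passed as a floating-point tensor with the same shape as the input.
soft_target = torch.softmax(torch.randn(3, 5), dim=1)
loss_soft = F.cross_entropy(logits, soft_target)

# With one-hot probabilities, the soft-target path matches the hard-target path.
one_hot = F.one_hot(hard_target, num_classes=5).float()
assert torch.allclose(F.cross_entropy(logits, one_hot), loss_hard)
```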

@facebook-github-bot (Contributor) commented Jun 30, 2021

💊 CI failures summary and remediations

As of commit 9ca7ae3 (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



1 job timed out:

  • pytorch_xla_linux_bionic_py3_6_clang9_test

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

This comment was automatically generated by Dr. CI.

Resolved review comments on torch/nn/modules/loss.py.
@jbschlosser changed the title from "Support for target with class probs in NLLLoss / CrossEntropyLoss" to "Support for target with class probs in CrossEntropyLoss" on Jul 22, 2021.
@jbschlosser requested a review from zou3519 on Jul 22, 2021.
@facebook-github-bot (Contributor)

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@zou3519 (Contributor) left a comment

This looks pretty good; some suggestions and comments.

Resolved review comments on aten/src/ATen/native/LossNLL.cpp.

Comment on lines 483 to 484
// Compute weighted mean
ret = ret.sum() / (target * weight_).sum();

zou3519 (Contributor):

This... seems like a good argument for why a user would expect reduction=mean to return -(input * target * weight_).mean(). I'm having a hard time coming up with a use case where someone wants probabilities and also wants a weighted mean over the probabilities and weights.

At any rate, we should probably be consistent with our hard-target cross_entropy function...

@jbschlosser (Contributor, Author) commented Jul 26, 2021

Yeah, I agree it doesn't make sense to do a weighted mean over probabilities and weights. I did it this way to maintain consistency with the hard-target cross-entropy loss: with one-hot targets, the soft and hard results are equivalent when computed like this.

Also, to be fully precise: I think a mean computation that fits user intuition would be -(input * target * weight_).sum(1).mean(). As in the unweighted calculation, sum(1) should be taken before the mean for the result to be correct.
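
For concreteness, a small numeric sketch (made-up shapes and values, not taken from the PR) that simply juxtaposes the two weighted reductions being debated here; it makes no claim about which one the final kernel uses:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up example: batch of 4 samples, 3 classes.
log_probs = F.log_softmax(torch.randn(4, 3), dim=1)   # "input" in the snippet above
target = torch.softmax(torch.randn(4, 3), dim=1)      # soft targets (class probabilities)
weight = torch.tensor([1.0, 2.0, 0.5])                # per-class weights

per_elem = -log_probs * target * weight               # elementwise weighted loss terms

# Reduction as written in the snippet under review: normalize the total loss
# by the total weight assigned to the targets (mirrors how the hard-target
# loss normalizes by the summed weights of the picked classes).
weighted_mean = per_elem.sum() / (target * weight).sum()

# Alternative raised in the review: sum over the class dimension first,
# then take a plain mean over the batch dimension.
per_sample_mean = per_elem.sum(dim=1).mean()

print(weighted_mean.item(), per_sample_mean.item())
```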

Resolved review comments on torch/testing/_internal/common_nn.py, torch/nn/modules/loss.py, and torch/nn/functional.py.

@zou3519 (Contributor) left a comment

Some minor comments, but otherwise this LGTM!

Resolved review comments on aten/src/ATen/native/LossNLL.cpp.

weight,
reduction,
ignore_index);
Tensor ret;

zou3519 (Contributor):

NRVO yes, but I'd expect the compiler to do RVO. Not sure how to test for this though; feel free to leave the code as-is.

Resolved review comment on torch/nn/modules/loss.py.

Comment on lines +1069 to +1077
.. math::
    \ell(x, y) = \begin{cases}
        \frac{\sum_{n=1}^N l_n}{N}, &
        \text{if reduction} = \text{`mean';}\\
        \sum_{n=1}^N l_n, &
        \text{if reduction} = \text{`sum'.}
    \end{cases}

zou3519 (Contributor):

The mean case is only true if the input and target are of size (N, C). Otherwise, we divide by a factor that isn't the batch size -- for a tensor of shape (N, C, d1, d2, ..., dk) we end up dividing by a factor of tensor.numel() / C, right?

Maybe this is OK because we can view data of (N, C, d1, d2, ..., dk) as being a "batch" of (N, d1, d2, ..., dk) distributions.

@jbschlosser (Contributor, Author):

Yeah, N is doing a lot of work implicitly here. I do think that data of shape (N, C, d1, d2, ..., dk) is conceptually a batch of (N, d1, d2, ..., dk) distributions (and as mentioned before, I think d1, ..., dk should have been added to the left of C for this to be clearer, but that ship has sailed).

While this was carried over to some extent from the old docs, each item in the formula is now more explicitly defined, so I think it needs to be more precise. Specifically, "N is the batch size" should change. Borrowing some terminology from KLDivLoss, we could do something like:

N spans the minibatch dimension as well as dimensions d1, ..., dk in the case of K-dimensional loss

wdyt?

zou3519 (Contributor):

> N spans the minibatch dimension as well as dimensions d1, ..., dk in the case of K-dimensional loss

That sounds good.
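
To make the "N spans the minibatch dimension as well as d1, ..., dk" reading concrete, here is a small sketch (made-up shapes, no class weights) that treats an (N, C, d1, d2) input as a batch of N * d1 * d2 distributions over C classes; under this reading the divisor in the mean reduction is input.numel() / C, matching the factor discussed above:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Made-up K-dimensional case: N=2 samples, C=3 classes, spatial dims d1=4, d2=5.
N, C, d1, d2 = 2, 3, 4, 5
log_probs = F.log_softmax(torch.randn(N, C, d1, d2), dim=1)
target = torch.softmax(torch.randn(N, C, d1, d2), dim=1)  # per-location class probabilities

# Sum the cross-entropy terms over the class dimension, leaving one loss value
# per (n, d1, d2) location, then average over all of those locations.
per_location = -(log_probs * target).sum(dim=1)           # shape (N, d1, d2)
mean_over_effective_batch = per_location.mean()

# Equivalently: the total loss divided by numel() / C, the divisor discussed above.
total = -(log_probs * target).sum()
assert torch.allclose(mean_over_effective_batch, total / (log_probs.numel() / C))
```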

@facebook-github-bot (Contributor)

@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor)

@jbschlosser merged this pull request in a42345a.

@VitamintK

The commit message in a42345a (from the first post of this PR) says "This PR simply adds support for 'soft targets' AKA class probabilities to the existing CrossEntropyLoss and NLLLoss classes," but it actually only adds the change for CrossEntropyLoss, right?

@jbschlosser (Contributor, Author) commented Jan 6, 2023

@VitamintK That's right; it was only added to CrossEntropyLoss.
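
A short sketch of that distinction (made-up tensors; the exact error raised by NLLLoss is not asserted): nn.CrossEntropyLoss accepts a floating-point target of class probabilities after this PR, while nn.NLLLoss still expects integer class indices:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(3, 5)
prob_target = torch.softmax(torch.randn(3, 5), dim=1)  # class probabilities

# CrossEntropyLoss: probability targets are supported after this PR.
ce = nn.CrossEntropyLoss()
print(ce(logits, prob_target))

# NLLLoss was not changed: it still expects a LongTensor of class indices,
# so passing probabilities is expected to fail.
nll = nn.NLLLoss()
try:
    nll(F.log_softmax(logits, dim=1), prob_target)
except RuntimeError as e:
    print("NLLLoss rejects probability targets:", e)

# NLLLoss usage remains index-based:
idx_target = torch.tensor([0, 2, 4])
print(nll(F.log_softmax(logits, dim=1), idx_target))
```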

Linked issue: [feature request] Support soft target distribution in cross entropy loss

4 participants