Support for target with class probs in CrossEntropyLoss #61044
Conversation
Force-pushed from e95f9ae to abd8c30.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
This looks pretty good; some suggestions and comments below.
aten/src/ATen/native/LossNLL.cpp (Outdated)
// Compute weighted mean
ret = ret.sum() / (target * weight_).sum();
This... seems like a good argument for why a user would expect reduction=mean to return (-(input * target * weight_).mean()). I'm having a hard time coming up with a use case where someone wants probabilities and wants to do a weighted mean over the probabilities and weights.
At any rate, we should probably be consistent with our hard-target cross_entropy function...
Yeah, I agree it doesn't make sense to do a weighted mean over probabilities and weights. I did it this way here to maintain consistency with the hard-target cross-entropy loss: with one-hot targets, the results are equivalent between soft and hard if done like this :/
Also, to be fully precise: I think a mean computation that fits user intuitions would be -(input * target * weight_).sum(1).mean(). As in the non-weighted calculation, sum(1) should be taken first, before the mean, to be correct.
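For illustration only (not the PR's code): a minimal PyTorch sketch contrasting the two reduction conventions discussed above, assuming `input` holds log-probabilities and `target` holds class probabilities, both of shape (N, C), with per-class weights `weight_` of shape (C,).

```python
import torch

N, C = 4, 3
input = torch.log_softmax(torch.randn(N, C), dim=1)   # log-probabilities
target = torch.softmax(torch.randn(N, C), dim=1)      # soft targets; each row sums to 1
weight_ = torch.rand(C)                                # per-class weights

# Weighted-mean convention used in the PR (reduces to the hard-target behavior
# for one-hot targets): normalize by the total target-weighted mass.
loss_weighted_mean = -(input * target * weight_).sum() / (target * weight_).sum()

# Alternative "intuitive" mean: sum over classes first, then average over the batch.
loss_per_sample_mean = -(input * target * weight_).sum(1).mean()

print(loss_weighted_mean.item(), loss_per_sample_mean.item())
```

With unit class weights the two coincide (both reduce to the plain mean over the batch); with non-uniform weights they generally differ, which is the point of the discussion above.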
Force-pushed from 0ba16ea to 7e7cace.
Some minor comments but otherwise this LGTM!
weight,
reduction,
ignore_index);
Tensor ret;
NRVO yes, but I'd expect the compiler to do RVO. Not sure how to test for this though; feel free to leave the code as-is.
.. math::
    \ell(x, y) = \begin{cases}
        \frac{\sum_{n=1}^N l_n}{N}, &
        \text{if reduction} = \text{`mean';}\\
        \sum_{n=1}^N l_n, &
        \text{if reduction} = \text{`sum'.}
    \end{cases}
The mean case is only true if the input and target are of size (N, C). Otherwise, we divide by a factor that isn't the batch size -- for a tensor of shape (N, C, d1, d2, ..., dk) we end up dividing by a factor of tensor.numel() / C, right?
Maybe this is OK because we can view data of shape (N, C, d1, d2, ..., dk) as being a "batch" of (N, d1, d2, ..., dk) distributions.
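For a quick check (a sketch using the existing hard-target API, not this PR's soft-target path): with K-dimensional input of shape (N, C, d1, d2), no class weights, and reduction='mean', the summed loss is divided by N * d1 * d2, i.e. numel() / C rather than the batch size N.

```python
import torch
import torch.nn.functional as F

N, C, d1, d2 = 2, 3, 4, 5
input = torch.randn(N, C, d1, d2)             # unnormalized scores
target = torch.randint(0, C, (N, d1, d2))     # hard class indices

loss_sum = F.cross_entropy(input, target, reduction='sum')
loss_mean = F.cross_entropy(input, target, reduction='mean')

# 'mean' divides by the number of target elements: N * d1 * d2 == input.numel() / C
assert torch.allclose(loss_mean, loss_sum / (N * d1 * d2))
```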
Yeah, N is doing a lot of work implicitly here. I do think that data of shape (N, C, d1, d2, ..., dk) is conceptually a batch of (N, d1, d2, ..., dk) distributions (and as mentioned before, I think d1, ..., dk should have been added to the left of C for this to be clearer, but that ship has sailed).
While this was carried over to some extent from the old docs, each item in the formula is now more explicitly defined, so I think it needs to be more precise. Specifically, "N is the batch size" should change. Borrowing some terminology from KLDivLoss, we could do something like:
"N spans the minibatch dimension as well as dimensions d1, ..., dk in the case of K-dimensional loss"
wdyt?
N spans the minibatch dimension as well as dimensions d1, ..., dk in the case of K-dimensional loss
That sounds good
Force-pushed from 5137b89 to 8dab918.
@jbschlosser has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@jbschlosser merged this pull request in a42345a.
the commit message in a42345a (from the first post of this PR) says
@VitamintK That's right; it was only added to
Fixes #11959
Alternative approach to creating a new CrossEntropyLossWithSoftLabels class. This PR simply adds support for "soft targets" AKA class probabilities to the existing CrossEntropyLoss and NLLLoss classes. Implementation is dumb and simple right now, but future work can add higher-performance kernels for this case.
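For reference, a minimal usage sketch of the feature described above, assuming a PyTorch build that includes this change: CrossEntropyLoss accepts either class indices (hard targets) or class probabilities of the same shape as the input (soft targets).

```python
import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()
input = torch.randn(4, 5)                      # (N, C) unnormalized logits

# Hard targets: class indices of shape (N,)
hard_target = torch.randint(0, 5, (4,))
print(loss_fn(input, hard_target))

# Soft targets: class probabilities of shape (N, C), each row summing to 1
soft_target = torch.softmax(torch.randn(4, 5), dim=1)
print(loss_fn(input, soft_target))
```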