Implementation of nn.CrossEntropyLossWithSoftLabels #59824
Conversation
💊 CI failures summary and remediations
As of commit 9b02944 (more details on the Dr. CI page and at hud.pytorch.org/pr/59824):
🕵️ 1 new failure recognized by patterns. The following CI failure does not appear to be due to upstream breakages: pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build (1/1), Step: "Build".
@zou3519 Are you cool with reviewing this? I think you have the most familiarity with this issue within the frontend group.
Yup, can do!
@@ -1121,6 +1125,91 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
ignore_index=self.ignore_index, reduction=self.reduction)

class CrossEntropyLossWithSoftLabels(_Loss):
    r"""This criterion combines :class:`~torch.nn.LogSoftmax` and :class:`~torch.nn.KLDivLoss` in
If anyone else is interested in reading the rendered docs: https://14108352-65600975-gh.circle-artifacts.com/0/docs/generated/torch.nn.CrossEntropyLossWithSoftLabels.html?highlight=crossentropy#torch.nn.CrossEntropyLossWithSoftLabels
torch/nn/modules/loss.py
Outdated
.. math::
    L = l_{n,c} = y_{n,c} \cdot \left( \log y_{n,c} -
        \log\left(\frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})} \right) \right)
I noticed that in #11959, @ssnl dropped the constant term. In elegy, a JAX nn library, the constant term is also dropped. I haven't been able to figure out how to read the TF source code for categorical_crossentropy yet, so I don't know what they're doing there.
Should we drop the constant term as well? This would change the implementation of this loss function from being a composition of LogSoftmax and KLDivLoss to being something else that is (1) more efficient to compute and (2) optimizes the same way (assuming your targets don't require grad...) but (3) isn't exactly the quantity users may expect.
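For concreteness, here is a small sketch (not code from this PR) of the relationship being discussed: the LogSoftmax + KLDivLoss composition differs from plain cross entropy only by the target-entropy term, which is constant with respect to the input.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 5)                      # logits
y = torch.softmax(torch.randn(4, 5), -1)   # soft targets (rows sum to 1)

log_p = F.log_softmax(x, dim=-1)

# KL-divergence form (what composing LogSoftmax + KLDivLoss computes):
kl = (y * (y.log() - log_p)).sum()

# Plain cross entropy (the "constant dropped" form):
ce = -(y * log_p).sum()

# They differ by the target entropy, which does not depend on x:
target_entropy = -(y * y.log()).sum()
assert torch.allclose(kl, ce - target_entropy, atol=1e-6)
```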
I think this is the TF source code for categorical_crossentropy and it looks like they drop the constant term as well: https://github.com/keras-team/keras/blob/0eb1954092e2110868e8ef06381cebac891e01e9/keras/backend.py#L4872
Great question. It's definitely true that it'd be more efficient to drop the constant from an optimization perspective. I'm not sure what the impact is for (3) (user experience), since the one-hot case will still generally be equivalent between the soft- and hard-label cross-entropy losses whether the constant is dropped or not (the constant is 0 in that case).
FWIW, SsnL doesn't think the constant should be dropped by default: #11959 (comment). I wonder if it makes sense for this PR to leave it as-is and possibly add a flag for dropping the constant later on?
Seems reasonable. Let's go ahead with this unless we get feedback from the other folks in the thread that they do want the constant dropped.
I did some more digging and I believe the constant is dropped in TF because that corresponds with the actual mathematical definition of cross entropy (see #11959 (comment)).
Also, the TF docs say that they assume one-hot labels, and the constant is always 0 for that case.
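A quick numeric illustration of that point (my own sketch, not from the PR): with a one-hot target the y * log(y) term vanishes, so the KL form, the plain cross-entropy form, and the existing hard-label loss all agree.

```python
import torch
import torch.nn.functional as F

x = torch.randn(4, 5)                           # logits
labels = torch.tensor([1, 2, 0, 3])
y = F.one_hot(labels, num_classes=5).float()    # one-hot targets

log_p = F.log_softmax(x, dim=-1)

# Cross entropy with the one-hot target...
ce = -(y * log_p).sum(dim=-1)
# ...equals the KL form, because y * log(y) is identically 0 here
# (xlogy treats 0 * log(0) as 0):
kl = (torch.xlogy(y, y) - y * log_p).sum(dim=-1)
assert torch.allclose(ce, kl)

# Both also match the existing hard-label loss:
assert torch.allclose(ce, F.cross_entropy(x, labels, reduction='none'), atol=1e-6)
```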
I had some comments on the docs and the testing (I'm not sure what's being tested). Other than that, the implementations look correct to me
@@ -110,6 +110,7 @@ torch::nn::PairwiseDistance|Yes|No
torch::nn::L1Loss|Yes|No
torch::nn::MSELoss|Yes|No
torch::nn::CrossEntropyLoss|Yes|No
torch::nn::CrossEntropyLossWithSoftLabels|Yes|No
It's been a long time since I've seen the parity tracker. Is there a comment or documentation somewhere about what tests this auto-generates?
No docs that I'm aware of :( My understanding is that it does this:
- Instantiates a module in both C++ / Python using the `constructor_args` / `cpp_constructor_args` entries in the `criterion_tests` dict entry corresponding to the module. If those aren't specified, it defaults to no args, which is okay for this loss.
- It sends `input` / `target` through the C++ / Python modules.
- It compares the resulting outputs.
I noticed that the parity tracker checks functional forms as well, so I added that.
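For reference, here is a rough sketch of the kind of `criterion_tests` entry being discussed. The field names are reconstructed from memory of `test/common_nn.py` and may not match the actual entry in this PR, so treat it as illustrative only:

```python
import torch

# Illustrative sketch only; the real entry lives in test/common_nn.py and its
# exact fields (and reference helper) may differ from what is shown here.
soft_labels_ce_test = dict(
    module_name='CrossEntropyLossWithSoftLabels',
    input_fn=lambda: torch.randn(5, 10),                    # logits
    target_fn=lambda: torch.randn(5, 10).softmax(dim=-1),   # soft labels
    # Reference: sum_c target * (log target - log_softmax(input)),
    # i.e. the KL(target || softmax(input)) form from the docstring.
    reference_fn=lambda i, t, m:
        (t * (t.log() - i.log_softmax(dim=-1))).sum(),
)
```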
- Input: :math:`(*, C)` where :math:`*` represents any number of dimensions (including none) and `C = number of classes`
- Target: :math:`(*, C)`, same shape as the input
- Output: If :attr:`reduction` is ``'none'``, then :math:`(*, C)`, otherwise scalar
PyTorch CrossEntropyLoss accepts input of shape (N, C, *) and target of shape (N, *) where * is any number of dimensions. Do the shapes here make CrossEntropyLoss and CrossEntropyLossWithSoftLabels inconsistent?
In TF images are NHWC by default (instead of NCHW in PyTorch) so that's why they don't have this problem...
You're right that they're inconsistent, but there are a few things to consider that made this choice hard:
- As you pointed out below, target shape == input shape for `CrossEntropyLossWithSoftLabels`, so that part can't be consistent with `CrossEntropyLoss`.
- We don't want to unnecessarily bake in the need for `N`, since that will have to be undone for the no-batch-dim support work. To be consistent with `CrossEntropyLoss`, this would mean supporting shapes `(C)`, `(N, C)`, `(N, C, d1)`, `(N, C, d1, d2)`, etc. (i.e. sometimes `C` is at dim 0 and sometimes it's at dim 1). I find this to be pretty confusing and honestly don't understand why the original CE did it this way, when `d1`, `d2`, etc. are semantically just more batch dimensions.

An interesting thing to consider could be supporting a configurable location for the `C` dim (i.e. the dim that log-softmax is taken over). But we couldn't easily make the default match CE's behavior, since the location of the `C` dim changes as the number of dims changes (it is 0 or -1 for shape `(C)`, 1 or -1 for shape `(N, C)`, and 1 for shapes `(N, C, d1, ...)`).

Also, admittedly confusingly, the `C` in `NHWC` / `NCHW` refers to "channels" (e.g. red, green, blue), while the `C` here refers to "number of classes".
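To illustrate the `(*, C)` convention being argued for here (my own sketch, not code from the PR): with the class dim fixed at -1, one expression covers any number of leading batch dims, whereas `CrossEntropyLoss` has to move the class dim depending on whether a batch dim is present.

```python
import torch
import torch.nn.functional as F

def soft_label_ce(input, target):
    # (*, C) convention: log-softmax is always taken over the last dim,
    # no matter how many leading "batch" dims there are.
    return -(target * F.log_softmax(input, dim=-1)).sum(dim=-1)

for shape in [(7,), (4, 7), (2, 3, 7), (2, 3, 5, 7)]:   # C = 7 is always last
    x = torch.randn(shape)
    y = torch.randn(shape).softmax(dim=-1)
    print(shape, '->', tuple(soft_label_ce(x, y).shape))  # leading dims preserved

# By contrast, nn.CrossEntropyLoss puts C at dim 1 as soon as a batch dim
# exists, so the softmax dim depends on the input's dimensionality.
```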
input = torch.randn(N, C)
target = torch.tensor([1, 2, 0, 1, 2], dtype=torch.long)
We should test correctness on an example with more than 2 dimensions
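A hypothetical sketch of such a check (not the test code that was actually added), using a 3-D input and a hand-written reference:

```python
import torch
import torch.nn.functional as F

# Hypothetical 3-D example of shape (N, d1, C):
x = torch.randn(2, 3, 5)                     # logits
y = torch.randn(2, 3, 5).softmax(dim=-1)     # soft labels

# Hand-written reference: sum_c y * (log y - log_softmax(x)) over the class dim.
ref = (y * (y.log() - F.log_softmax(x, dim=-1))).sum(dim=-1)

# The LogSoftmax + KLDivLoss composition the loss is built on:
out = F.kl_div(F.log_softmax(x, dim=-1), y, reduction='none').sum(dim=-1)
assert torch.allclose(out, ref, atol=1e-6)
```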
@@ -16387,6 +16389,50 @@ def cosine_distance(x, y):
self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)

@onlyOnCPUAndCUDA
def test_cross_entropy_loss_with_soft_labels(self, device):
nit: It would be nice to test correctness for some cases where `target` is not one-hot-encoded.
Note that the tests generated from the `criterion_tests` entry check against a `reference_fn` for the non-one-hot-encoded target case.
Implementation looks correct; some minor comments on the testing. I had two main concerns:
- Naming. Is it really OK for us to call this nn.CrossEntropyLossWithSoftLabels if it doesn't exactly compute "cross entropy"? Should we provide an option to actually compute cross entropy and have that be the default? The implementation of that would not be very difficult (it's just `-(target * F.log_softmax(input)).sum()`), and we could optimize that with some CPU/CUDA kernels if necessary in the future.
and we could optimize that with some cpu/cuda kernels if necessary in the future - Multidimensional shape behavior: CrossEntropyLoss takes in inputs and targets of shape (N, C, *) and (N, *) respectively. To be consistent with that I would have expected us to accept inputs and targets of shape (N, C, *) and (N, C, *) but I might be missing something here
Closing since #61044 landed.
Fixes #11959. The implementation is a combination of `log_softmax` + `kl_div` (see #11959 (comment) for details).

Btw this specific name was chosen for a few reasons, with these existing losses in mind:
- `SoftMarginLoss`, where "soft" has a different meaning applying to the margin (in contrast to a hard margin)
- `BCELossWithLogits`
- `MultiLabelMarginLoss` and `MultiLabelSoftMarginLoss`

This ruled out:
- `SoftCrossEntropyLoss`
- `SoftLabelCrossEntropyLoss`
- `CrossEntropyLossWithSoftTargets`
- `SoftTargetCrossEntropyLoss`
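As a rough sketch of the composition described above (my own reconstruction, not the PR's actual code; the `reduction` argument shown here is an assumption), the module boils down to something like:

```python
import torch
from torch import nn
import torch.nn.functional as F

class CrossEntropyLossWithSoftLabels(nn.Module):
    """Minimal sketch: KL(target || softmax(input)), i.e. LogSoftmax + KLDivLoss.

    Reconstructed from the discussion above; reduction handling and the
    supported shapes in the real implementation may differ.
    """
    def __init__(self, reduction: str = 'mean'):
        super().__init__()
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # kl_div expects log-probabilities for the input argument.
        return F.kl_div(F.log_softmax(input, dim=-1), target,
                        reduction=self.reduction)

# Usage:
loss_fn = CrossEntropyLossWithSoftLabels()
logits = torch.randn(4, 10)
soft_labels = torch.randn(4, 10).softmax(dim=-1)
loss = loss_fn(logits, soft_labels)
```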