
Implementation of nn.CrossEntropyLossWithSoftLabels #59824

Closed
wants to merge 7 commits

Conversation

jbschlosser
Contributor

@jbschlosser jbschlosser commented Jun 10, 2021

Fixes #11959. The implementation is a combination of log_softmax + kl_div (see #11959 (comment) for details).

Btw, this specific name was chosen for a few reasons:

  1. To distinguish the criterion from SoftMarginLoss, where "soft" refers to the margin (a soft margin as opposed to a hard one)
  2. To fit with the "With" naming adopted by BCEWithLogitsLoss
  3. To fit with the "Labels" naming adopted by MultiLabelMarginLoss and MultiLabelSoftMarginLoss

This ruled out:

  • SoftCrossEntropyLoss
  • SoftLabelCrossEntropyLoss
  • CrossEntropyLossWithSoftTargets
  • SoftTargetCrossEntropyLoss
  • etc.
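
For readers skimming the thread, here is a minimal functional sketch of the log_softmax + kl_div composition described above, built only from existing ops; the helper name and the choice of 'batchmean' reduction are illustrative, not this PR's exact API:

import torch
import torch.nn.functional as F

# Minimal sketch of the proposed behavior using existing primitives.
# The helper name and the 'batchmean' reduction are illustrative choices,
# not necessarily what the PR implements.
def cross_entropy_with_soft_labels(input, target):
    # input: (*, C) raw logits; target: (*, C) probabilities summing to 1 over C
    log_p = F.log_softmax(input, dim=-1)
    # kl_div expects log-probabilities; 'batchmean' sums over classes and
    # divides by the batch size
    return F.kl_div(log_p, target, reduction='batchmean')

logits = torch.randn(5, 3, requires_grad=True)
soft_targets = torch.softmax(torch.randn(5, 3), dim=-1)
cross_entropy_with_soft_labels(logits, soft_targets).backward()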

@facebook-github-bot
Contributor

facebook-github-bot commented Jun 10, 2021

💊 CI failures summary and remediations

As of commit 9b02944 (more details on the Dr. CI page and at hud.pytorch.org/pr/59824):



🕵️ 1 new failure recognized by patterns

The following CI failures do not appear to be due to upstream breakages:

See CircleCI build pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build (1/1)

Step: "Build" (full log | diagnosis details | 🔁 rerun)

Jun 24 16:01:15 rm: cannot remove '/var/lib/jenkins/sccache_error.log': No such file or directory
Jun 24 16:01:15 ++++ extract_trap_cmd
Jun 24 16:01:15 ++++ printf '%s\n' ''
Jun 24 16:01:15 +++ printf '%s\n' cleanup
Jun 24 16:01:15 ++ trap -- '
Jun 24 16:01:15 cleanup' EXIT
Jun 24 16:01:15 ++ [[ pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build != *pytorch-win-* ]]
Jun 24 16:01:15 ++ which sccache
Jun 24 16:01:15 ++ sccache --stop-server
Jun 24 16:01:15 ++ true
Jun 24 16:01:15 ++ rm /var/lib/jenkins/sccache_error.log
Jun 24 16:01:15 rm: cannot remove '/var/lib/jenkins/sccache_error.log': No such file or directory
Jun 24 16:01:15 ++ true
Jun 24 16:01:15 ++ [[ -n '' ]]
Jun 24 16:01:15 ++ [[ pytorch-linux-xenial-cuda11.1-cudnn8-py3-gcc7-build == *rocm* ]]
Jun 24 16:01:15 ++ SCCACHE_ERROR_LOG=/var/lib/jenkins/sccache_error.log
Jun 24 16:01:15 ++ SCCACHE_IDLE_TIMEOUT=1200
Jun 24 16:01:15 ++ RUST_LOG=sccache::server=error
Jun 24 16:01:15 ++ sccache --start-server
Jun 24 16:01:15 sccache: Starting the server...
Jun 24 16:01:15 ++ sccache --zero-stats
Jun 24 16:01:15 Compile requests                      0

1 job timed out:

  • pytorch_linux_xenial_cuda11_1_cudnn8_py3_gcc7_build

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Preview docs built from this PR

This comment was automatically generated by Dr. CI.

@jbschlosser jbschlosser force-pushed the soft_label_ce branch 4 times, most recently from 5c7a3b4 to d674167 on June 11, 2021 15:15
@jbschlosser jbschlosser changed the title from "[WIP] Implementation of nn.CrossEntropyLossWithSoftLabels" to "Implementation of nn.CrossEntropyLossWithSoftLabels" on Jun 11, 2021
@datumbox datumbox self-requested a review June 11, 2021 17:12
@jbschlosser
Contributor Author

@zou3519 Are you cool with reviewing this? I think you have the most familiarity with this issue within the frontend group

@zou3519
Contributor

zou3519 commented Jun 15, 2021

Yup, can do!

@@ -1121,6 +1125,91 @@ def forward(self, input: Tensor, target: Tensor) -> Tensor:
ignore_index=self.ignore_index, reduction=self.reduction)


class CrossEntropyLossWithSoftLabels(_Loss):
r"""This criterion combines :class:`~torch.nn.LogSoftmax` and :class:`~torch.nn.KLDivLoss` in

.. math::
    L = l_{n,c} = y_{n,c} \cdot \left( \log y_{n,c} - \log\left(\frac{\exp(x_{n,c})}{\sum_{i=1}^C \exp(x_{n,i})}\right) \right)
Contributor

I noticed that in #11959, @ssnl dropped the constant term. In elegy, a JAX NN library, the constant term is also dropped. I haven't been able to figure out how to read the TF source code for categorical_crossentropy yet, so I don't know what they're doing there.

Should we drop the constant term as well? This would change the implementation of this loss function from being a composition of LogSoftmax and KLDivLoss to being something else that is (1) more efficient to compute and (2) optimizes the same way (assuming your targets don't require grad...) but (3) isn't exactly the quantity users may expect.
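To make the trade-off concrete, here's a small sketch (illustrative, not taken from this PR) showing that the two forms differ only by a term that is constant in the input, so their gradients match:

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, requires_grad=True)       # logits
y = torch.softmax(torch.randn(4, 3), dim=-1)    # soft targets, constant w.r.t. x

log_p = F.log_softmax(x, dim=-1)
kl = (torch.xlogy(y, y) - y * log_p).sum()      # LogSoftmax + KLDivLoss form
ce = -(y * log_p).sum()                         # constant-dropped (plain cross entropy) form

# The forms differ by sum(y * log y), which does not depend on x,
# so the gradients w.r.t. x are identical.
g_kl, = torch.autograd.grad(kl, x, retain_graph=True)
g_ce, = torch.autograd.grad(ce, x)
print(torch.allclose(g_kl, g_ce, atol=1e-6))    # True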

Contributor

I think this is the TF source code for categorical_crossentropy and it looks like they drop the constant term as well: https://github.com/keras-team/keras/blob/0eb1954092e2110868e8ef06381cebac891e01e9/keras/backend.py#L4872

Contributor Author

Great question! It'd definitely be more efficient to drop the constant from an optimization perspective. I'm not sure what the impact is for (3) (user experience), since the one-hot case will still generally be equivalent between soft- and hard-label cross entropy losses whether or not the constant is dropped (the constant is 0 in that case).

FWIW, SsnL doesn't think the constant should be dropped by default: #11959 (comment). I wonder if it makes sense for this PR to leave as-is and possibly add a flag for dropping the constant later on?
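As a quick illustration of the one-hot point above (a sketch only, not this PR's tests): when the target is one-hot, the constant term is 0 and the soft-label quantity reduces to the existing hard-label cross entropy.

import torch
import torch.nn.functional as F

N, C = 5, 3
logits = torch.randn(N, C)
hard_labels = torch.randint(0, C, (N,))
one_hot = F.one_hot(hard_labels, num_classes=C).float()

# For one-hot y, sum(y * log y) == 0, so soft and hard losses agree.
soft_ce = -(one_hot * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
hard_ce = F.cross_entropy(logits, hard_labels)   # mean over the batch by default
print(torch.allclose(soft_ce, hard_ce, atol=1e-6))  # True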

Contributor

Seems reasonable. Let's go ahead with this unless we get feedback from the other folks in the thread that they do want the constant dropped.

Contributor Author

I did some more digging and I believe the constant is dropped in TF because that corresponds with the actual mathematical definition of cross entropy (see #11959 (comment)).

Also, the TF docs say that they assume one-hot labels, and the constant is always 0 for that case.
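For reference, the identity behind the constant-term discussion, with q = softmax(x) and y the target distribution:

H(y, q) = -\sum_{c} y_c \log q_c = \mathrm{KL}(y \,\|\, q) + H(y),
\qquad \text{where } H(y) = -\sum_{c} y_c \log y_c

The dropped constant is the target entropy H(y), which is exactly 0 when y is one-hot, so under the TF one-hot assumption the two definitions coincide.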

Contributor

@zou3519 zou3519 left a comment

I had some comments on the docs and the testing (I'm not sure what's being tested). Other than that, the implementations look correct to me.

aten/src/ATen/native/Loss.cpp (outdated, resolved)
torch/nn/modules/loss.py (outdated, resolved)
torch/nn/modules/loss.py (outdated, resolved)
torch/nn/modules/loss.py (resolved)
aten/src/ATen/native/Loss.cpp (outdated, resolved)
test/cpp/api/modules.cpp (outdated, resolved)
@@ -110,6 +110,7 @@ torch::nn::PairwiseDistance|Yes|No
torch::nn::L1Loss|Yes|No
torch::nn::MSELoss|Yes|No
torch::nn::CrossEntropyLoss|Yes|No
torch::nn::CrossEntropyLossWithSoftLabels|Yes|No
Contributor

It's been a long time since I've seen the parity tracker. Is there a comment or documentation somewhere about what tests this auto-generates?

Contributor Author

No docs that I'm aware of :( My understanding is that it does this:

  1. Instantiates the module in both C++ and Python using the constructor_args / cpp_constructor_args entries in the criterion_tests dict entry corresponding to the module. If those aren't specified, it defaults to no args, which is okay for this loss.
  2. Sends input / target through the C++ and Python modules
  3. Compares the resulting outputs

I noticed that the parity tracker checks functional forms as well, so I added that.
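For concreteness, a hypothetical criterion_tests-style entry might look roughly like the following; the field names here are assumptions about the common_nn.py schema, not copied from this PR:

import torch

# Hypothetical entry; field names are assumptions, not the actual schema.
soft_labels_criterion_test = dict(
    module_name='CrossEntropyLossWithSoftLabels',
    input_fn=lambda: torch.randn(5, 3),
    target_fn=lambda: torch.softmax(torch.randn(5, 3), dim=-1),
    # reference implementation the generated parity tests compare against
    reference_fn=lambda i, t, m: (t * (t.log() - i.log_softmax(-1))).sum(-1).mean(),
)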

torch/testing/_internal/common_nn.py (resolved)
torch/testing/_internal/common_nn.py (outdated, resolved)
torch/testing/_internal/common_nn.py (outdated, resolved)
Comment on lines +1173 to +1176
- Input: :math:`(*, C)` where :math:`*` represents any number of dimensions (including none) and
`C = number of classes`
- Target: :math:`(*, C)`, same shape as the input
- Output: If :attr:`reduction` is ``'none'``, then :math:`(*, C)`, otherwise scalar
Contributor

PyTorch CrossEntropyLoss accepts input of shape (N, C, *) and target of shape (N, *) where * is any number of dimensions. Do the shapes here make CrossEntropyLoss and CrossEntropyLossWithSoftLabels inconsistent?

Contributor

In TF, images are NHWC by default (instead of NCHW as in PyTorch), so they don't have this problem...

Contributor Author

You're right that they're inconsistent, but there are a few things to consider that made this choice hard:

  • As you pointed out below, target shape == input shape for CrossEntropyLossWithSoftLabels, so that part can't be consistent with CrossEntropyLoss
  • We don't want to unnecessarily bake in the need for N, since that will have to be undone for the no-batch-dim support work. To be consistent with CrossEntropyLoss, this would mean supporting shapes (C), (N, C), (N, C, d1), (N, C, d1, d2), etc. (i.e. sometimes C is at dim 0 and sometimes it's at dim 1). I find this pretty confusing and honestly don't understand why the original CE did it this way, when d1, d2, etc. are semantically just more batch dimensions.

An interesting thing to consider could be supporting a configurable location for the C dim (i.e. the dim that log-softmax is taken over). But we couldn't easily make the default match CE's behavior, since the location of the C dim changes as the number of dims changes (location is 0 or -1 for shape (C), 1 or -1 for shape (N, C), and 1 for shapes (N, C, d1, ...)).

Also, admittedly confusingly, the C in NHWC / NCHW refers to "channels" (e.g. red, green, blue) while the C here refers to "number of classes".
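To make the shape difference concrete, here is an illustrative adapter (not part of this PR) that maps a CrossEntropyLoss-style (N, C, d1, d2) input onto the (*, C) convention by moving the class dim last:

import torch
import torch.nn.functional as F

def soft_label_ce_channels_first(logits, soft_targets):
    # logits, soft_targets: (N, C, d1, d2) with classes at dim 1
    log_p = F.log_softmax(logits.movedim(1, -1), dim=-1)    # -> (N, d1, d2, C)
    y = soft_targets.movedim(1, -1)
    return (torch.xlogy(y, y) - y * log_p).sum(dim=-1).mean()

loss = soft_label_ce_channels_first(
    torch.randn(2, 3, 4, 5),                        # segmentation-style logits
    torch.softmax(torch.randn(2, 3, 4, 5), dim=1),  # soft targets over C
)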

Comment on lines +16395 to +16396
input = torch.randn(N, C)
target = torch.tensor([1, 2, 0, 1, 2], dtype=torch.long)
Contributor

We should test correctness on an example with more than 2 dimensions

@@ -16387,6 +16389,50 @@ def cosine_distance(x, y):
self.assertEqual(functional, modular, atol=1e-6, rtol=1e-6)
self.assertEqual(traced, modular, atol=1e-6, rtol=1e-6)

@onlyOnCPUAndCUDA
def test_cross_entropy_loss_with_soft_labels(self, device):
Contributor

nit: It would be nice to test correctness for some cases where target is not one-hot-encoded

Contributor Author

Note that the tests generated from the criterion_tests entry check against a reference_fn for the non-one-hot encoded target case.
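A rough sketch of what such a non-one-hot reference check looks like (illustrative, not the actual reference_fn in this PR):

import torch
import torch.nn.functional as F

logits = torch.randn(5, 3)
soft = torch.softmax(torch.randn(5, 3), dim=-1)      # strictly positive, non-one-hot targets

log_p = F.log_softmax(logits, dim=-1)
composed = F.kl_div(log_p, soft, reduction='sum')    # LogSoftmax + KLDivLoss composition
by_hand = (soft * (soft.log() - log_p)).sum()        # the same quantity written out
print(torch.allclose(composed, by_hand, atol=1e-6))  # True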

Contributor

@zou3519 zou3519 left a comment

Implementation looks correct, and I have some minor comments on the testing. I had two main concerns:

  1. Naming. Is it really OK for us to call this nn.CrossEntropyLossWithSoftLabels if it doesn't exactly compute "cross entropy"? Should we provide an option to actually compute cross entropy and have that be the default? The implementation of that would not be very difficult (it's just -(target * F.log_softmax(input)).sum()), and we could optimize it with CPU/CUDA kernels in the future if necessary.
  2. Multidimensional shape behavior: CrossEntropyLoss takes inputs and targets of shape (N, C, *) and (N, *) respectively. To be consistent with that, I would have expected us to accept inputs and targets of shape (N, C, *) and (N, C, *), but I might be missing something here.

@jbschlosser
Contributor Author

Closing since #61044 landed.

@jbschlosser jbschlosser closed this Aug 4, 2021

Successfully merging this pull request may close these issues.

[feature request] Support soft target distribution in cross entropy loss