[feature request] Support soft target distribution in cross entropy loss #11959
Comments
Is there any progress on this feature? I'm currently trying to implement something like this manually, but I'm having trouble.
No progress as far as I know.
Label smoothing is implemented in fastai: https://github.com/fastai/fastai/blob/8013797e05f0ae0d771d60ecf7cf524da591503c/fastai/layers.py#L300
Is performance the reason why soft targets are not officially supported? If so, we could provide a separate loss method with soft-target support. Currently I have to implement it manually like this:

```python
import torch

def softmax_cross_entropy_with_softtarget(input, target, reduction='mean'):
    """
    :param input: (batch, *)
    :param target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
    """
    logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1)
    batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
    if reduction == 'none':
        return batchloss
    elif reduction == 'mean':
        return torch.mean(batchloss)
    elif reduction == 'sum':
        return torch.sum(batchloss)
    else:
        raise NotImplementedError('Unsupported reduction mode.')
```
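For reference, a minimal usage sketch of the helper above (my own addition; the uniform target is just one example of a valid distribution):

```python
import torch

logits = torch.randn(4, 10)                  # (batch, classes)
soft_target = torch.full((4, 10), 1.0 / 10)  # uniform distribution over classes; rows sum to 1
loss = softmax_cross_entropy_with_softtarget(logits, soft_target)
print(loss)                                  # scalar mean loss over the batch
```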
TL;DR: For a soft-target distribution, use `KLDivLoss` on `log_softmax`-ed inputs. Note that soft targets are supported already in PyTorch through `KLDivLoss`. At a high level, `CrossEntropyLoss` is equivalent to `LogSoftmax` followed by `NLLLoss`, and `NLLLoss` with a hard class index matches `KLDivLoss` with the corresponding one-hot target distribution.
The equivalence is demonstrated below; the computed losses are fully equivalent for `reduction='sum'` (for `'none'`/`'mean'`, the KL output is per-element and would need to be summed over the class dimension first):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(1)

N = 5
C = 7
input = torch.randn(N, C)
target = torch.tensor([1, 2, 3, 2, 1], dtype=torch.long)
target_one_hot = F.one_hot(target, num_classes=C).to(torch.float32)

log_probs = F.log_softmax(input, dim=1)

loss_nll = torch.nn.NLLLoss(reduction='none')
output_nll = loss_nll(log_probs, target)

loss_kl = torch.nn.KLDivLoss(reduction='none')
output_kl = loss_kl(log_probs, target_one_hot)

print('NLL:', output_nll)
print('KL:', output_kl)
print('NLL (sum):', output_nll.sum())
print('KL (sum):', output_kl.sum())
print('NLL (mean):', output_nll.mean())
print('KL (mean):', output_kl.mean())
```
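To make the genuinely soft case concrete, here is a minimal sketch (my own, not from the thread) applying the same recipe to a non-one-hot target distribution; the smoothing value of 0.1 is just an illustrative choice:

```python
import torch
import torch.nn.functional as F

N, C = 5, 7
logits = torch.randn(N, C, requires_grad=True)

# Build a soft target: a label-smoothed version of hard class indices.
hard = torch.tensor([1, 2, 3, 2, 1])
soft_target = torch.full((N, C), 0.1 / C)
soft_target.scatter_(1, hard.unsqueeze(1), 1.0 - 0.1 + 0.1 / C)  # rows still sum to 1

# "Soft cross entropy" up to a constant: KL divergence on log-probabilities.
loss = F.kl_div(F.log_softmax(logits, dim=1), soft_target, reduction='batchmean')
loss.backward()
print(loss)
```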
@jbschlosser First of all, thanks for taking the time to write a comprehensive explanation of the situation. Unfortunately there is much confusion around this (see #7455 (comment), #7455 (comment), etc.), so hopefully your detailed analysis will provide proof that indeed soft targets can already be handled with `KLDivLoss`.
Having said that, I wonder if the amount of confusion in the community hints that we have a UX problem. There are numerous tickets with several followers, forum posts, and discussions around this. Most solutions require a significant amount of boilerplate code and careful implementation. I wonder if that justifies providing a more user-friendly wrapper that simplifies the code and gives the requested functionality. As you highlighted in the past, there are various potential implementations for this, each offering different degrees of flexibility/performance. You can have a simple wrapper such as the following that accepts a class-index target and reuses standard building blocks from PyTorch. It's a middle-ground solution with reasonable performance, since it does not have to convert the target to a one-hot encoded vector, while at the same time it does not introduce more complex parameters in the low-level C++ code:

```python
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.loss import _Loss


class LabelSmoothingCrossEntropy(_Loss):
    def __init__(self, eps: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        log_input = F.log_softmax(input, dim=-1)
        # Per-sample NLL so the reduction below applies uniformly to both terms.
        loss = (1 - self.eps) * F.nll_loss(log_input, target, reduction='none') - self.eps * log_input.mean(dim=-1)
        if self.reduction == "none":
            ret = loss
        elif self.reduction == "mean":
            ret = loss.mean()
        elif self.reduction == "sum":
            ret = loss.sum()
        else:
            raise ValueError(self.reduction + " is not valid")
        return ret
```

The above covers the vast majority of applications, but unfortunately it won't do for Computer Vision applications that use Data Augmentation techniques such as mixup and cutmix. These are SOTA primitives that we would like to add to TorchVision (see pytorch/vision#3911). In order to achieve this we will need a modified loss which accepts an already smoothed target. The smoothed target can be the result of mixup/cutmix, or it can be the result of a default label-smoothing transform such as the `smooth_labels` helper below:

```python
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn.modules.loss import _Loss


def smooth_labels(target: Tensor, num_classes: int, eps: float = 0.1):
    N = target.shape[0]
    device = target.device
    v = torch.full(size=(N, 1), fill_value=1 - eps, device=device)
    return torch.full(size=(N, num_classes), fill_value=eps / num_classes, device=device).scatter_add_(1, target.unsqueeze(1), v)


class LabelSmoothingCrossEntropy2(_Loss):
    def __init__(self, eps: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        log_input = F.log_softmax(input, dim=-1)
        loss = (- target * log_input).sum(-1)
        if self.reduction == "none":
            ret = loss
        elif self.reduction == "mean":
            ret = loss.mean()
        elif self.reduction == "sum":
            ret = loss.sum()
        else:
            raise ValueError(self.reduction + " is not valid")
        return ret
```

The two produce identical results:

```python
N = 5
C = 7
input = torch.randn(N, C)
target = torch.tensor([1, 2, 3, 2, 1], dtype=torch.long)

loss1 = LabelSmoothingCrossEntropy()
loss2 = LabelSmoothingCrossEntropy2()

v1 = loss1(input, target)
v2 = loss2(input, smooth_labels(target, C))
assert (v1 - v2).abs() < 0.0001, f"{v1} != {v2}"
```

We can obviously combine the two approaches and try to automatically handle both class-index targets and smoothed targets. This is something that we could in theory put directly in TorchVision to serve our needs, but we think it does not really belong there, as the method can be useful across multiple domain libraries. Thoughts?
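Below is a minimal sketch of the "combine the two approaches" idea (my own illustration, not code from the thread): dispatch on the target's dtype, smoothing integer class indices on the fly and consuming floating-point distributions directly. The class name `CombinedLabelSmoothingCE` is hypothetical.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


class CombinedLabelSmoothingCE(torch.nn.Module):
    """Hypothetical wrapper: accepts either hard class indices or an already-smoothed distribution."""

    def __init__(self, eps: float = 0.1, reduction: str = 'mean'):
        super().__init__()
        self.eps = eps
        self.reduction = reduction

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        log_probs = F.log_softmax(input, dim=-1)
        if not target.is_floating_point():
            # Hard class indices: build the smoothed distribution on the fly
            # (same math as smooth_labels above).
            num_classes = input.shape[-1]
            index = target.long().unsqueeze(-1)
            soft = torch.full_like(log_probs, self.eps / num_classes)
            soft.scatter_add_(-1, index, torch.full_like(soft.gather(-1, index), 1 - self.eps))
            target = soft
        loss = (-target * log_probs).sum(-1)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
```

Under these assumptions, `CombinedLabelSmoothingCE()(logits, class_indices)` and `CombinedLabelSmoothingCE()(logits, smooth_labels(class_indices, num_classes))` would produce the same value.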
I would like to quickly chime in on @jbschlosser's proposal to add a loss of the form `SoftCrossEntropyLoss(input, target) = KLDivLoss(LogSoftmax(input), target)`. I think it fits very naturally with what we already have in PyTorch, where `CrossEntropyLoss` is itself `LogSoftmax` composed with `NLLLoss`. For some history, Torch7 used to have only the NLL-style criterion that consumed log-probabilities, with the softmax composed separately; the fused cross entropy criterion came later as a convenience.
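As a rough sketch of what that proposed composition could look like as a module (the class name comes from the comment above; everything else here is my illustrative assumption):

```python
import torch
import torch.nn as nn


class SoftCrossEntropyLoss(nn.Module):
    """Sketch of the proposed composition: cross entropy with probability targets,
    up to the constant target-entropy term."""

    def __init__(self, reduction: str = 'batchmean'):
        super().__init__()
        self.log_softmax = nn.LogSoftmax(dim=-1)
        self.kl_div = nn.KLDivLoss(reduction=reduction)

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # KLDivLoss(LogSoftmax(input), target), exactly as written above.
        return self.kl_div(self.log_softmax(input), target)
```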
With all the loss variants, I guess we need a good summary page for finding the right function name, given a formula :)
Linked a PR implementing the soft cross entropy loss. There's some discussion now on whether the constant term (the entropy of the target distribution) should be dropped by default.
I don't think it should be dropped by default.
@ssnl out of curiosity, what is your reasoning?
@zou3519 It would be confusing because (1) it is not the mathematical definition and thus may be unexpected, (2) the hard-label version has an optimum of 0, and (3) many people naturally expect 0 to be the optimum.
@ssnl quick clarification: unless I'm mistaken, the mathematical definition for cross entropy H(p, q), with p = target distribution and q = predicted distribution, is:

`H(p, q) = -\sum_i p_i log q_i = H(p) + KL(p || q)`

So to compute true cross entropy according to the mathematical definition, we'd need to drop the constant `\sum_i p_i log p_i` from the KL-divergence computation. Cross entropy and KL-divergence are equivalent when the labels are one-hot, since a one-hot distribution has an entropy of 0, but they differ by a constant (the target entropy) for other label distributions. Technically, minimizing KL-divergence AKA "relative entropy" has the nice optimum-at-0 property, but minimizing cross entropy AKA "absolute entropy" does not. I wonder if the terminology of a "soft cross entropy" loss would then be misleading if the computation is actually KL-divergence.
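To make the constant-term discussion concrete, here is a small numerical check (my own sketch, not from the thread) that the two quantities differ exactly by the target entropy:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 7
logits = torch.randn(1, C)
q_log = F.log_softmax(logits, dim=1)         # predicted log-probabilities log q

p = torch.softmax(torch.randn(1, C), dim=1)  # an arbitrary soft target distribution

cross_entropy = -(p * q_log).sum()           # H(p, q)
kl = F.kl_div(q_log, p, reduction='sum')     # KL(p || q)
entropy = -(p * p.log()).sum()               # H(p), the constant term

# H(p, q) = KL(p || q) + H(p): the two losses differ exactly by the target entropy.
print(torch.allclose(cross_entropy, kl + entropy))  # True
```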
@JBamberger Interesting! Thanks for pointing this out. I didn't realize that the formal definition differs from KL. I am slightly inclined towards aligning with KL here because of how it is used in ML. What do others think, @fmassa @vadimkantorov? Additionally, just out of curiosity, what is the reason for using a new class rather than extending the existing `CrossEntropyLoss`?
I agree that KL would be preferred in practice here; I just wanted to bring up a possible incongruity between the definition and the naming. As the new class is a combination of `LogSoftmax` and `KLDivLoss`, its name could reflect that computation. We could also support both KL-div and true CE with a flag in the future, if we find that there are use cases for both.
There are a few reasons for creating a new loss instead of expanding the existing one:
If we want to calculate actual cross entropy, I'd probably reconsider the idea of creating a new loss. Responding to my own "reasons" for not doing this above:
To summarize a bit, there are two options for where to put the soft-target support (in the existing loss or in a new loss) and two options for what to compute (true cross entropy with soft probability labels, or KL-divergence), giving four options in total:
I'm partial to option 1 because the name and computation are consistent and it matches what Keras and FLAX provide. See #61044 for the details on how the things I brought up in my previous comments have been addressed, including doc updates to describe the new support. If someone wants KL-divergence with the optimum-at-0 property, imo they should use `KLDivLoss` directly. @ssnl @zou3519 @datumbox @fmassa @vadimkantorov thoughts / opinions?
My ranked-choice vote is 4 > 1 > 2, 3. Option 1 feels like it overloads CrossEntropyLoss to handle both soft and hard labels; the computation involved is different (one is a gather-style operation, the other is an elementwise multiplication). As an extension writer, if I do something like write a vmap rule for cross_entropy_loss, I don't want to worry about whether it sometimes does a gather and sometimes an elementwise multiplication. (It's not the end of the world if I have to worry about it, though, and we should go with what our users think.)
I'd argue that 4 is closer to the spirit of Keras: Keras has separate CategoricalCrossentropy and SparseCategoricalCrossentropy. Having a lot of wordy losses may not be a good thing, though. Re: "Need to know about new loss" -- a cross-reference from the nn.CrossEntropyLoss docs could alleviate that concern.
In every case, I propose to also have a single summary table mapping formula -> function call. It's often hard to find the right function name even when knowing the formula.
I vote for 1. To me, as an end user, cross entropy is cross entropy. It is soundly defined for both hard and soft labels, and creating a new loss would be hard to discover and confusing (since the old loss is not going away).
Currently our cross entropy loss (i.e., `nn.CrossEntropyLoss`) only supports a hard target class, i.e., wanting to maximize the output (log) probability of a particular class. But in many cases, training w.r.t. a soft target distribution (i.e., wanting the output to match a particular distribution) is quite useful too, e.g., for preventing overfitting.
Math
Cross entropy loss operates on raw logits (it applies log softmax internally). Denote the input vector as `x`. Log softmax computes a vector `y` of the same length as `x`, where `y_i = x_i - log( \sum_j exp(x_j) )`, representing the log likelihood of each class.

In the hard target case, if the target class is `c`, the loss is simply the negative log likelihood `-y_c`.

In the soft target case, let the target distribution vector be `p` (i.e., `p_i` is the target probability for predicting class `i`). The loss is the KL divergence `KL(p || softmax(x)) = \sum_i p_i log p_i - \sum_i p_i y_i`. The constant `\sum_i p_i log p_i` is independent of `x` and thus discarded. Our loss formula is just `-\sum_i p_i y_i`.

When `p_c = 1` for some class `c`, this reduces to the hard target case.

The formula for gradient computation can be easily derived from this: `d loss / d x_i = softmax(x)_i - p_i`.
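A short sanity check of the formulas above (my own sketch, not part of the original issue text):

```python
import torch
import torch.nn.functional as F

N, C = 4, 6
x = torch.randn(N, C, requires_grad=True)
y = F.log_softmax(x, dim=1)                  # y_i = x_i - log(sum_j exp(x_j))

# Soft target p: each row is a valid probability distribution.
p = torch.softmax(torch.randn(N, C), dim=1)

loss = -(p * y).sum(dim=1).mean()            # -sum_i p_i y_i, averaged over the batch
loss.backward()

# Gradient check: d loss / d x = (softmax(x) - p) / N for the batch-mean loss.
expected_grad = (torch.softmax(x.detach(), dim=1) - p) / N
print(torch.allclose(x.grad, expected_grad, atol=1e-6))  # True
```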
Possible Implementation
Currently our cross entropy loss implementation takes in a batched `x` of shape `(N, C)` and floating-point dtype (`N` is the batch size and `C` is the number of classes), and a batched target class-index vector `target` of shape `(N)` and dtype `long` (an integral type), where `target[i]` is the index of the desired output class.

Since we want it to also take in a soft target distribution, we can allow `target` to also be a batched distribution of shape `(N, C)`, and detect whether the soft or hard target behavior is wanted based on its shape and dtype.

cc @gchanan
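A minimal functional sketch of the proposed detection rule (an illustration only, not the actual PyTorch implementation; `soft_cross_entropy` is a hypothetical name):

```python
import torch
import torch.nn.functional as F


def soft_cross_entropy(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Dispatch on the target's dtype/shape: class indices of shape (N,) vs. distributions of shape (N, C)."""
    y = F.log_softmax(x, dim=1)
    if target.dtype == torch.long and target.dim() == 1:
        return F.nll_loss(y, target)            # hard target: -y_c, averaged over the batch
    return -(target * y).sum(dim=1).mean()      # soft target: -sum_i p_i y_i, averaged over the batch


N, C = 3, 5
x = torch.randn(N, C)
hard = torch.tensor([0, 2, 4])
soft = F.one_hot(hard, num_classes=C).float()
print(torch.allclose(soft_cross_entropy(x, hard), soft_cross_entropy(x, soft)))  # True
```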