
[feature request] Support soft target distribution in cross entropy loss #11959

Closed
ssnl opened this issue Sep 21, 2018 · 22 comments

Labels: function request, module: loss, triaged

Comments

@ssnl
Collaborator

ssnl commented Sep 21, 2018

Currently our cross entropy loss (i.e., nn.CrossEntropyLoss) only supports a hard target class, i.e., maximizing the output (log) probability of a particular class. But in many cases, training w.r.t. a soft target distribution (i.e., making the output match a particular distribution) is quite useful too, e.g., for preventing overfitting.

Math

Cross entropy loss operates on logits via log softmax.

Denote the input vector as x. Log softmax computes a vector y of the same length as x, where y_i = x_i - log( \sum_j exp(x_j) ), representing the log probability of each class.

  • In the hard target case, if the target class is c, the loss is simply the negative log likelihood -y_c.

  • In the soft target case, let the target distribution vector be p (i.e., p_i is the target probability for predicting class i). The loss is the KL divergence

    D( p || softmax(x) ) = \sum_i p_i log( p_i / softmax(x)_i ) = -\sum_i p_i y_i + constant
    

    The constant is independent of x and thus discarded. Our loss formula is just -\sum_i p_i y_i.

    When p_c = 1 for some class c, this reduces to the hard target case.

    The formula for gradient computation can be easily derived from this:

    
    d l / d y_i = -p_i
    
    d y_i / d x_i = 1 - exp(x_i) / \sum_j exp(x_j) = 1 - exp(y_i)
    
    # suppose k != i
    d y_k / d x_i = -exp(x_i) / \sum_j exp(x_j) = -exp(y_i)
    
    # so, using \sum_k p_k = 1
    d l / d x_i = exp(y_i) (\sum_k p_k) - p_i = exp(y_i) - p_i (= softmax(x)_i - p_i)
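A quick way to sanity-check the loss and its gradient is to compare against autograd (a minimal sketch, not part of any proposed API):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(7, requires_grad=True)        # logits
p = torch.softmax(torch.randn(7), dim=0)      # soft target distribution (sums to 1)

y = F.log_softmax(x, dim=0)                   # y_i = x_i - log(sum_j exp(x_j))
loss = -(p * y).sum()                         # proposed soft-target loss -sum_i p_i y_i
loss.backward()

# derived gradient: d loss / d x_i = softmax(x)_i - p_i
assert torch.allclose(x.grad, torch.softmax(x.detach(), dim=0) - p, atol=1e-6)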
    
    

Possible Implementation

Currently our cross entropy loss implementation takes a batched input x of shape (N, C) with floating point dtype (N is the batch size and C is the number of classes), and a batched target class index vector target of shape (N,) with long (integral) dtype, where target[i] is the index of the desired output class.

Since we also want it to accept a soft target distribution, we can allow target to be a batched distribution of shape (N, C) and detect whether a soft or hard target is intended based on shape and dtype.
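A rough sketch of that dispatch at the Python level (hypothetical helper name; the real implementation would live at the C++ level):

import torch
import torch.nn.functional as F

def cross_entropy_any_target(input, target, reduction='mean'):
    # Hard targets: integral class indices of shape (N,); defer to the existing loss.
    if target.dtype == torch.long:
        return F.cross_entropy(input, target, reduction=reduction)
    # Soft targets: floating-point distributions of shape (N, C) matching input.
    loss = -(target * F.log_softmax(input, dim=1)).sum(dim=1)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss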

cc @gchanan

@gchanan self-assigned this Sep 24, 2018
@Naman-ntc
Contributor

Naman-ntc commented Nov 18, 2018

Hi @ssnl @gchanan,
I was looking into this a few days back for a part of my project and ended up here.
I would love to try my hand at implementing it, if you'd like.

@Naman-ntc
Contributor

Hi @gchanan @ssnl, any updates?

@tueboesen

Is there any progress on this feature? I'm currently trying to implement something like this manually, but I'm having trouble.

@pietern added the module: operators, module: loss, triaged, and enhancement labels Oct 22, 2019
@gchanan
Contributor

gchanan commented Nov 5, 2019

No progress as far as I know.

@tranhungnghiep

tranhungnghiep commented May 5, 2020

Is performance the reason why soft targets are not officially supported? If so, we could provide a separate loss method for soft targets. Currently I have to implement it manually like this:

    import torch

    def softmax_cross_entropy_with_softtarget(input, target, reduction='mean'):
        """
        :param input: (batch, *) unnormalized logits
        :param target: (batch, *) same shape as input, each item must be a valid distribution: target[i, :].sum() == 1.
        :param reduction: 'none' | 'mean' | 'sum'
        """
        logprobs = torch.nn.functional.log_softmax(input.view(input.shape[0], -1), dim=1)
        batchloss = - torch.sum(target.view(target.shape[0], -1) * logprobs, dim=1)
        if reduction == 'none':
            return batchloss
        elif reduction == 'mean':
            return torch.mean(batchloss)
        elif reduction == 'sum':
            return torch.sum(batchloss)
        else:
            raise NotImplementedError('Unsupported reduction mode.')
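For example, with a batch of logits and row-normalized targets (a quick usage sketch of the function above):

logits = torch.randn(4, 10)
target = torch.softmax(torch.randn(4, 10), dim=1)  # each row is a valid distribution
loss = softmax_cross_entropy_with_softtarget(logits, target)  # scalar (reduction='mean')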

@mruberry added the function request label and removed the enhancement and module: operators (deprecated) labels Oct 8, 2020
@jbschlosser
Contributor

jbschlosser commented Jun 3, 2021

TL;DR: For a soft-target CrossEntropyLoss, you can use KLDivLoss with log-prob inputs (e.g. KLDivLoss()(F.log_softmax(input, dim=-1), target)).


Note that soft targets are supported already in PyTorch through KLDivLoss, which accepts floating-point inputs and targets of shape (N, C) (as well as arbitrary dims).

At a high level, CrossEntropyLoss does LogSoftmax followed by NLLLoss. For a CrossEntropyLoss with soft targets, the analogue would be LogSoftmax followed by KLDivLoss:

  • CrossEntropyLoss(input, target) = NLLLoss(LogSoftmax(input), target)
  • SoftCrossEntropyLoss(input, target) = KLDivLoss(LogSoftmax(input), target)

KLDivLoss for input of shape (N, C) with one-hot labels is essentially equivalent to NLLLoss with unit weights and no ignore_index. Formulas from the docs:

KLDivLoss: l_n = y_n (log y_n - x_n) = 0 if y_n == 0 else -x_n
NLLLoss: l_n = -w_{y_n} * x_{n, y_n} = -x_{n, y_n}

The equivalence is demonstrated below; computed losses are fully equivalent for reduction='sum' and differ by a constant factor of C for reduction='mean' (since the unreduced NLLLoss includes only a single value per batch item):

import torch
import torch.nn.functional as F

torch.manual_seed(1)
N = 5
C = 7
input = torch.randn(N, C)
target = torch.tensor([1, 2, 3, 2, 1], dtype=torch.long)
target_one_hot = F.one_hot(target, num_classes=C).to(torch.float32)
log_probs = F.log_softmax(input, dim=1)

loss_nll = torch.nn.NLLLoss(reduction='none')
output_nll = loss_nll(log_probs, target)
loss_kl = torch.nn.KLDivLoss(reduction='none')
output_kl = loss_kl(log_probs, target_one_hot)

print('NLL:', output_nll)
print('KL:', output_kl)
print('NLL (sum):', output_nll.sum())
print('KL (sum):', output_kl.sum())
print('NLL (mean):', output_nll.mean())
print('KL (mean):', output_kl.mean())
NLL: tensor([1.9238, 1.9462, 3.1566, 2.9595, 4.4378])
KL: tensor([[0.0000, 1.9238, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.9462, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 3.1566, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 2.9595, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 4.4378, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000]])
NLL (sum): tensor(14.4241)
KL (sum): tensor(14.4241)
NLL (mean): tensor(2.8848)
KL (mean): tensor(0.4121)

KLDivLoss can be thought of as a generalization of NLLLoss for an arbitrary target distribution, unrestricted to a single class label per training example.
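To make the genuinely-soft case concrete, here is a small check (a sketch with made-up random targets) that KLDivLoss on log-probabilities equals the plain cross entropy -sum(p * log q) minus the target entropy H(p):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, C = 5, 7
input = torch.randn(N, C)
target = torch.softmax(torch.randn(N, C), dim=1)   # soft targets, rows sum to 1

log_probs = F.log_softmax(input, dim=1)
kl = F.kl_div(log_probs, target, reduction='sum')

cross_entropy = -(target * log_probs).sum()        # H(p, q)
target_entropy = -(target * target.log()).sum()    # H(p)

# KL(p || q) = H(p, q) - H(p)
assert torch.allclose(kl, cross_entropy - target_entropy, atol=1e-5)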

@datumbox
Contributor

datumbox commented Jun 4, 2021

@jbschlosser First of all thanks for taking the time to write a comprehensive explanation of the situation.

Unfortunately there is much confusion around this (see #7455 (comment), #7455 (comment), etc.), so hopefully your detailed analysis will serve as proof that KLDivLoss can indeed support label smoothing. For those who would like more info about the relationship between label smoothing and Kullback-Leibler divergence, here are some references:

Having said that, I wonder if the amount of confusion in the community hints that we have a UX problem. There are numerous tickets with several followers, forum posts, and discussions around this. Most solutions require a significant amount of boilerplate code and careful implementation. I wonder if that justifies providing a more user-friendly wrapper that simplifies the code and provides the requested functionality.

As you have highlighted in the past, there are various potential implementations, each offering a different trade-off between flexibility and performance.

You could have a simple wrapper such as the following, which accepts a hard target value and reuses standard building blocks from PyTorch. It's a middle-ground solution with reasonable performance: it does not have to convert the target to a one-hot encoded vector, and at the same time it does not introduce more complex parameters into the low-level C++ code:

import torch.nn.functional as F

from torch import Tensor
from torch.nn.modules.loss import _Loss


class LabelSmoothingCrossEntropy(_Loss):
    def __init__(self, eps: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        log_input = F.log_softmax(input, dim=-1)
        # per-sample loss: (1 - eps) * NLL + eps * uniform cross entropy
        loss = (1 - self.eps) * F.nll_loss(log_input, target, reduction='none') - self.eps * log_input.mean(dim=-1)
        if self.reduction == "none":
            ret = loss
        elif self.reduction == "mean":
            ret = loss.mean()
        elif self.reduction == "sum":
            ret = loss.sum()
        else:
            raise ValueError(self.reduction + " is not valid")
        return ret

The above covers the vast majority of applications, but unfortunately it won't do for Computer Vision applications that use data augmentation techniques such as mixup and cutmix. These are SOTA primitives that we would like to add to TorchVision (see pytorch/vision#3911). To achieve this we will need a modified loss which accepts an already-smoothed target. The smoothed target can be the result of mixup/cutmix, or it can be produced by a default smooth_labels method similar to what you described here:

import torch
import torch.nn.functional as F

from torch import Tensor
from torch.nn.modules.loss import _Loss


def smooth_labels(target: Tensor, num_classes: int, eps: float = 0.1):
    N = target.shape[0]
    device = target.device
    v = torch.full(size=(N, 1), fill_value=1 - eps, device=device)
    return torch.full(size=(N, num_classes), fill_value=eps / num_classes, device=device).scatter_add_(1, target.unsqueeze(1), v)


class LabelSmoothingCrossEntropy2(_Loss):
    def __init__(self, eps: float = 0.1, size_average=None, reduce=None, reduction: str = 'mean'):
        super().__init__(size_average, reduce, reduction)
        self.eps = eps

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        log_input = F.log_softmax(input, dim=-1)
        loss = (- target * log_input).sum(-1)
        if self.reduction == "none":
            ret = loss
        elif self.reduction == "mean":
            ret = loss.mean()
        elif self.reduction == "sum":
            ret = loss.sum()
        else:
            raise ValueError(self.reduction + " is not valid")
        return ret

The two produce identical results:

N = 5
C = 7
input = torch.randn(N, C)
target = torch.tensor([1, 2, 3, 2, 1], dtype=torch.long)

loss1 = LabelSmoothingCrossEntropy()
loss2 = LabelSmoothingCrossEntropy2()

v1 = loss1(input, target)
v2 = loss2(input, smooth_labels(target, C))
assert (v1-v2).abs() < 0.0001, f"{v1} != {v2}"
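As a quick illustration of the mixup/cutmix case mentioned above (continuing the snippet; lam and the label pairing are made up), the second variant consumes blended targets directly:

lam = 0.7
target_a = torch.tensor([1, 2, 3, 2, 1])
target_b = torch.tensor([0, 4, 4, 6, 5])
mixed = lam * F.one_hot(target_a, C).float() + (1 - lam) * F.one_hot(target_b, C).float()

v3 = loss2(input, mixed)   # LabelSmoothingCrossEntropy2 accepts any soft target of shape (N, C)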

We can obviously combine the two approaches and automatically handle both hard target values and smoothed targets. This is something that we could in theory put directly in TorchVision to serve our needs, but we don't think it really belongs there, as the method can be useful across multiple domain libraries. Thoughts?

@fmassa
Member

fmassa commented Jun 7, 2021

I would like to quickly chime in on @jbschlosser's proposal to add a

SoftCrossEntropyLoss(input, target) = KLDivLoss(LogSoftmax(input), target)

I think it fits very naturally with what we already have in PyTorch for CrossEntropyLoss and NLLLoss, and it is an extension that will benefit many users (even though it might look like all the needed building blocks are already there).

For some history, Torch7 used to have only NLLLoss, and a common user error was to forget the log_softmax in the model. It was only after several years, in 2015, that CrossEntropyLoss was added, while NLLLoss had been there since before 2012.

@vadimkantorov
Contributor

With all the loss variants, I guess we need a good summary page for finding the right function name given a formula :)

@jbschlosser
Contributor

Linked a PR implementing the soft cross entropy loss. There's some discussion now on whether the constant term \sum_i p_i log p_i should be dropped (as originally described by @ssnl). It'd be more efficient to drop it and it shouldn't matter from an optimization perspective, but the loss would no longer be equivalent to LogSoftmax + KLDivLoss and the different loss values vs. CrossEntropyLoss may surprise some users. @fmassa / @datumbox / anyone else - opinions on this?

@ssnl
Collaborator Author

ssnl commented Jun 15, 2021

I don't think it should be dropped by default.

@zou3519
Contributor

zou3519 commented Jun 16, 2021

I don't think it should be dropped by default.

@ssnl out of curiosity, what is your reasoning?

@ssnl
Collaborator Author

ssnl commented Jun 16, 2021

@zou3519 It would be confusing because (1) it is not the mathematical definition and thus may be unexpected, (2) the hard label version has an optimum of 0, and (3) many people naturally expect 0 to be the optimum.

@jbschlosser
Contributor

@ssnl quick clarification - unless I'm mistaken, the mathematical definition of cross entropy H(p, q), with p = target distribution and q = predicted distribution, is:

H(p, q) = H(p) + D_{KL}(p||q)
        = -\sum_i p(x_i) \log p(x_i) + \sum_i p(x_i) \log \frac{p(x_i)}{q(x_i)}
        = -\sum_i p(x_i) \log p(x_i) + \sum_i p(x_i) (\log p(x_i) - \log q(x_i))
        = -\sum_i p(x_i) \log p(x_i) + \sum_i p(x_i) \log p(x_i) - \sum_i p(x_i) \log q(x_i)
        = -\sum_i p(x_i) \log q(x_i)

So to compute true cross entropy according to the mathematical definition, we'd need to drop the constant from the KL-divergence computation.

Cross entropy and KL-divergence are equivalent when the labels are one-hot, since a one-hot distribution has an entropy of 0, but they differ by a constant (the target entropy) for other label distributions. Technically, minimizing KL-divergence, AKA "relative entropy", has the nice optimum-0 property, but minimizing cross entropy does not.
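A tiny numerical check of that identity (sketch only, with arbitrary distributions):

import torch

torch.manual_seed(0)
p = torch.softmax(torch.randn(7), dim=0)   # target distribution
q = torch.softmax(torch.randn(7), dim=0)   # predicted distribution

cross_entropy = -(p * q.log()).sum()            # H(p, q)
entropy = -(p * p.log()).sum()                  # H(p)
kl = (p * (p.log() - q.log())).sum()            # D_KL(p || q)

assert torch.allclose(cross_entropy, entropy + kl, atol=1e-6)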

I wonder if the terminology of CrossEntropyLossWithSoftLabels will be confusing if we are actually calculating KL-divergence.

@ssnl
Collaborator Author

ssnl commented Jun 22, 2021

@jbschlosser Interesting! Thanks for pointing this out. I didn't realize that the formal definition differs from KL. I am slightly inclined towards aligning with KL here because of how it is used in ML. What do others think, @fmassa @vadimkantorov ?

Additionally, just out of curiosity, what is the reason for using a new class CrossEntropyLossWithSoftLabels rather than reusing the old class and deciding based on shape/dtype?

@jbschlosser
Contributor

jbschlosser commented Jun 22, 2021

I am slightly inclined towards aligning with KL here because of how it is used in ML.

I agree that KL would be preferred in practice here- just wanted to bring up a possible incongruity between the definition and naming. As the new class is a combination of LogSoftmax + KLDivLoss, I guess it's not too surprising that KL-div is what is being computed, the only difference being that we expect logit inputs for the new loss vs. log-prob inputs for KLDivLoss. If there was naming precedent for this sort of thing within PyTorch, I'd probably suggest KLDivLossWithLogits as a more accurate name than CrossEntropyLossWithSoftLabels, but that does make the relationship with the original CrossEntropyLoss less clear.

We could also support both KL-div and true CE with a flag in the future, if we find that there are use cases for both.

Additionally, just out of curiosity, what is the reason for using a new class CrossEntropyLossWithSoftLabels rather than reusing the old class and deciding based on shape/dtype?

There are a few reasons for creating a new loss instead of expanding the existing one:

  • Shape / dtype specialization does not play well with FX
  • It may be surprising to users to get worse performance simply based on the dtype of the target
  • The existing CrossEntropyLoss has an ignore_index arg that doesn't make semantic sense with soft labels
  • With a new loss, it's more straightforward to support arbitrary dims beyond just shape (N, C)

@jbschlosser
Contributor

If we want to calculate actual cross entropy, I'd probably reconsider the idea of creating a new loss. Responding to my own "reasons" for not doing this above:

  • Shape / dtype specialization does not play well with FX - NLLLoss already does shape checks, so this surely isn't a valid reason? Need to ping someone more familiar with the details here
  • It may be surprising to users to get worse performance simply based on the dtype of the target - Won't happen now that I think about it more; we can switch on target vs. input shape. If the shapes are the same, go down the "soft label" route and expect floating-point targets. Otherwise, go down the current route
  • The existing CrossEntropyLoss has an ignore_index arg that doesn't make semantic sense with soft labels - This is true and non-ideal, but the docs could indicate that the arg only has meaning for non-soft labels
  • With a new loss, it's more straightforward to support arbitrary dims beyond just shape (N, C) - Doesn't actually matter if we want to remain consistent with current CE input shapes

@jbschlosser
Contributor

jbschlosser commented Jul 6, 2021

To summarize a bit, there are 2 options for where to put the soft-label support (in the existing loss or in a new loss) and 2 options for what to compute (true cross entropy with soft probability labels, or KL-divergence), giving 4 options total:

  1. Compute true cross entropy with soft labels within existing CrossEntropyLoss when input shape == target shape (shown in Support for target with class probs in CrossEntropyLoss #61044)
  • Pros: No need to know about new loss, name matches computation, matches what Keras and FLAX provide
  • Cons: No optima 0 property, ignore_index needs to be documented as unsupported with soft labels
  2. Compute log_softmax() + KL-divergence within existing CrossEntropyLoss when input shape == target shape
  • Pros: No need to know about new loss, has nice optima 0 property
  • Cons: Confusing, since KL-divergence is computed instead of true cross entropy
  3. Compute log_softmax() + KL-divergence within a new loss (shown in Implementation of nn.CrossEntropyLossWithSoftLabels #59824)
  • Pros: Has nice optima 0 property
  • Cons: Currently, the name CrossEntropyLossWithSoftLabels doesn't match the computation, but the name could change to e.g. KLDivLossWithLogits (if the name changes, it could be less discoverable for those looking for some sort of soft CE support)
  4. Compute true cross entropy with soft labels within a new loss
  • Pros: Name and computation match, can leave out the ignore_index arg from the new loss, edit: does match Keras with its categorical_cross_entropy / sparse_categorical_cross_entropy ops
  • Cons: Need to know about new loss that only differs from existing loss by supported target type, no optima 0 property, doesn't match FLAX

I'm partial to option 1 because the name and computation are consistent and it matches what Keras and FLAX provide. See #61044 for the details on how the things I brought up in my previous comments have been addressed, including doc updates to describe the new support.

If someone wants KL-divergence with the optima 0 property, imo they should use KLDivLoss. To avoid needing a log_softmax() call beforehand, perhaps a flag can be added there to support input in the form of logits, or a KLDivLossWithLogits loss can be added as well.
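For reference, the functional form of that idea would just be a thin wrapper (hypothetical name, sketch only):

import torch.nn.functional as F

def kl_div_with_logits(input, target, reduction='batchmean'):
    # Same as KLDivLoss, but accepts raw logits instead of log-probabilities.
    return F.kl_div(F.log_softmax(input, dim=-1), target, reduction=reduction)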

@ssnl @zou3519 @datumbox @fmassa @vadimkantorov thoughts / opinions?

@zou3519
Contributor

zou3519 commented Jul 6, 2021

My ranked-choice vote is 4 > 1 > 2, 3.

Option 1 feels like it overloads CrossEntropyLoss to handle both soft and hard labels; the computation involved is different (one is a gather-style operation, the other is an elementwise multiplication). As an extension writer, if I do something like write a vmap rule for cross_entropy_loss, I don't want to worry about whether it sometimes does a gather and sometimes does an elementwise multiplication. (It's not the end of the world if I have to worry about it, though, and we should go with what our users think.)

Compute true cross entropy with soft labels within a new loss
Pros: Name and computation match, can leave out the ignore_index arg from the new loss
Cons: Need to know about new loss that only differs from existing loss by supported target type, no optima 0 property, doesn't match Keras / FLAX

I'd argue that 4 is closer to the spirit of Keras: Keras has separate CategoricalCrossentropy and SparseCategoricalCrossentropy losses. Having a lot of wordy losses may not be a good thing though.

Re: "Need to know about new loss" -- a cross reference from the nn.CrossEntropyLoss docs could alleviate that concern

@vadimkantorov
Contributor

In every case, I propose to also have a single summary table mapping formula -> function call. It's often hard to find the right function name even when you know the formula.
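As a rough sketch of such a mapping for the losses discussed in this thread (per-example, ignoring weights and reductions; x = input, y = class index, p = target distribution):

  • NLLLoss: -x_{y}, with x = log-probabilities
  • CrossEntropyLoss: -log softmax(x)_{y}, with x = logits
  • KLDivLoss: \sum_i p_i (log p_i - x_i), with x = log-probabilities
  • Soft-target cross entropy (as proposed here): -\sum_i p_i log softmax(x)_i, with x = logits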

@ssnl
Collaborator Author

ssnl commented Jul 8, 2021

I vote for 1. To me, as an end user, cross entropy is cross entropy. It is soundly defined for both hard and soft labels, and creating a new loss would be hard to discover and confusing (since the old loss is not CrossEntropyLossWithHardLabels).
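For what it's worth, if option 1 lands as in #61044, usage would presumably look like the existing loss with probability targets (a sketch assuming that PR's semantics):

import torch

N, C = 5, 7
input = torch.randn(N, C)
hard_target = torch.tensor([1, 2, 3, 2, 1])              # class indices, as today
soft_target = torch.softmax(torch.randn(N, C), dim=1)    # class probabilities

loss = torch.nn.CrossEntropyLoss()
out_hard = loss(input, hard_target)
out_soft = loss(input, soft_target)   # dispatched on target matching input's shape/dtype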
