
BCELoss with weights for labels (like weighted_cross_entropy_with_logits in TF) #5660

@velikodniy

Description


I've implemented an analog of weighted_cross_entropy_with_logits in my current project. It's useful for working with imbalanced datasets. I'd like to add it to PyTorch, but I'm not sure whether others really need it.

For example, here is my implementation:

import torch
from torch.autograd import Variable


def weighted_binary_cross_entropy_with_logits(logits, targets, pos_weight, weight=None, size_average=True, reduce=True):
    if not (targets.size() == logits.size()):
        raise ValueError("Target size ({}) must be the same as input size ({})".format(targets.size(), logits.size()))

    # Numerically stable evaluation of
    #   (1 - t) * x + (1 + (pos_weight - 1) * t) * log(1 + exp(-x))
    # using the log-sum-exp trick with max_val = max(-x, 0).
    max_val = (-logits).clamp(min=0)
    log_weight = 1 + (pos_weight - 1) * targets
    loss = (1 - targets) * logits + log_weight * (((-max_val).exp() + (-logits - max_val).exp()).log() + max_val)

    # Optional element-wise rescaling weight (same semantics as `weight` in BCEWithLogitsLoss)
    if weight is not None:
        loss = loss * weight

    if not reduce:
        return loss
    elif size_average:
        return loss.mean()
    else:
        return loss.sum()

class WeightedBCEWithLogitsLoss(torch.nn.Module):
    def __init__(self, pos_weight, weight=None, size_average=True, reduce=True):
        super().__init__()
        # Register the weights as buffers so they follow the module across devices and into state_dict
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', pos_weight)
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, input, target):
        # Buffers are stored as tensors; wrap them in Variables before use (pre-0.4 API)
        pos_weight = Variable(self.pos_weight) if not isinstance(self.pos_weight, Variable) else self.pos_weight
        weight = None
        if self.weight is not None:
            weight = Variable(self.weight) if not isinstance(self.weight, Variable) else self.weight
        return weighted_binary_cross_entropy_with_logits(input, target,
                                                         pos_weight,
                                                         weight=weight,
                                                         size_average=self.size_average,
                                                         reduce=self.reduce)

(Of course, tests and WeightedBCELoss should be written too.)
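For reference, with x the logit, z the target and q = pos_weight (the same convention as TF's weighted_cross_entropy_with_logits), the per-element loss computed above is

    \ell(x, z) = -\,q\,z \log \sigma(x) - (1 - z) \log\bigl(1 - \sigma(x)\bigr)
               = (1 - z)\,x + \bigl(1 + (q - 1)\,z\bigr) \log\bigl(1 + e^{-x}\bigr),

and the max_val term in the code is the usual log-sum-exp trick for evaluating log(1 + e^{-x}) without overflow.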
