Closed
Labels: todo (Not as important as medium or high priority tasks, but we will work on these.)
Description
I've implemented an analog of TensorFlow's `weighted_cross_entropy_with_logits` in my current project. It's useful for imbalanced datasets, because `pos_weight` lets the loss up-weight the (rarer) positive examples. I'd like to add it to PyTorch, but I'm not sure whether others actually need it.
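For reference, a short derivation of the loss being described, written with logits $x$, targets $z \in [0, 1]$ and positive-class weight $q$ (the `pos_weight` argument). Expanding $\sigma(x) = 1/(1 + e^{-x})$ in the positive-weighted binary cross entropy gives the expression the implementation evaluates, where $l = 1 + (q - 1)z$ corresponds to `log_weight` and $m$ to `max_val`:

```latex
\begin{aligned}
\ell(x, z) &= -\bigl[\, q\, z \log \sigma(x) + (1 - z) \log\bigl(1 - \sigma(x)\bigr) \bigr] \\
&= q\, z \log\bigl(1 + e^{-x}\bigr) + (1 - z)\bigl(x + \log\bigl(1 + e^{-x}\bigr)\bigr) \\
&= (1 - z)\, x + \bigl(1 + (q - 1) z\bigr) \log\bigl(1 + e^{-x}\bigr), \\
\log\bigl(1 + e^{-x}\bigr) &= m + \log\bigl(e^{-m} + e^{-x - m}\bigr), \qquad m = \max(-x, 0).
\end{aligned}
```

The last identity keeps both exponents $-m$ and $-x - m$ non-positive, so neither exponential can overflow.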
Here is my implementation:
```python
import torch


def weighted_binary_cross_entropy_with_logits(logits, targets, pos_weight, weight=None,
                                              size_average=True, reduce=True):
    if targets.size() != logits.size():
        raise ValueError("Target size ({}) must be the same as input size ({})".format(
            targets.size(), logits.size()))
    # max(-x, 0): shifting by it keeps the exponentials below from overflowing
    max_val = (-logits).clamp(min=0)
    # l = 1 + (pos_weight - 1) * z scales the positive-class term
    log_weight = 1 + (pos_weight - 1) * targets
    # (1 - z) * x + l * log(1 + exp(-x)), with log(1 + exp(-x)) computed stably
    loss = (1 - targets) * logits + log_weight * (
        ((-max_val).exp() + (-logits - max_val).exp()).log() + max_val)
    if weight is not None:
        loss = loss * weight
    if not reduce:
        return loss
    elif size_average:
        return loss.mean()
    else:
        return loss.sum()
```
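To show why the function shifts by `max_val` instead of evaluating `log(1 + exp(-x))` directly, here is a minimal scalar sketch in plain Python. The names `naive_loss` and `stable_loss` are hypothetical and only illustrate the same formula: both agree for moderate logits, but the naive form overflows for large negative logits, while the stable form keeps every exponent non-positive.

```python
import math


def naive_loss(x, z, q):
    # Direct formula: (1 - z) * x + l * log(1 + exp(-x)); exp(-x) overflows for x << 0
    l = 1 + (q - 1) * z
    return (1 - z) * x + l * math.log(1 + math.exp(-x))


def stable_loss(x, z, q):
    # Same formula, using log(1 + exp(-x)) = m + log(exp(-m) + exp(-x - m)), m = max(-x, 0)
    m = max(-x, 0.0)
    l = 1 + (q - 1) * z
    return (1 - z) * x + l * (math.log(math.exp(-m) + math.exp(-x - m)) + m)


for x in (-3.0, 0.0, 2.5):
    print(x, naive_loss(x, 1.0, 3.0), stable_loss(x, 1.0, 3.0))

# exp(1000) does not fit in a float, so the naive form raises OverflowError
try:
    naive_loss(-1000.0, 1.0, 3.0)
except OverflowError:
    print("naive form overflowed")
print(stable_loss(-1000.0, 1.0, 3.0))  # 3.0 * 1000 -> 3000.0
```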
```python
import torch


class WeightedBCEWithLogitsLoss(torch.nn.Module):
    def __init__(self, pos_weight, weight=None, size_average=True, reduce=True):
        super().__init__()
        # Buffers follow the module across .to(device) calls but are not trained;
        # a buffer must be a tensor or None, so convert a plain number first
        self.register_buffer('weight', weight)
        self.register_buffer('pos_weight', torch.as_tensor(pos_weight))
        self.size_average = size_average
        self.reduce = reduce

    def forward(self, input, target):
        # Since PyTorch 0.4 plain tensors carry autograd state, so no Variable
        # wrapping is needed, and passing weight=None covers the unweighted case
        return weighted_binary_cross_entropy_with_logits(input, target,
                                                         self.pos_weight,
                                                         weight=self.weight,
                                                         size_average=self.size_average,
                                                         reduce=self.reduce)
```
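Recent PyTorch releases expose the same weighting through the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss` and `torch.nn.functional.binary_cross_entropy_with_logits`. The sketch below assumes such a release; it re-implements the formula as a hypothetical element-wise helper, `weighted_bce_with_logits`, and checks its mean against both built-ins.

```python
import torch
import torch.nn.functional as F


def weighted_bce_with_logits(logits, targets, pos_weight):
    # Hypothetical helper: element-wise version of the formula from this issue
    max_val = (-logits).clamp(min=0)
    log_weight = 1 + (pos_weight - 1) * targets
    return (1 - targets) * logits + log_weight * (
        ((-max_val).exp() + (-logits - max_val).exp()).log() + max_val)


torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.randint(0, 2, (8, 3)).float()
pos_weight = torch.tensor([1.0, 2.0, 5.0])  # one weight per class, broadcast over the last dim

ours = weighted_bce_with_logits(logits, targets, pos_weight).mean()
builtin = F.binary_cross_entropy_with_logits(logits, targets, pos_weight=pos_weight)
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)

print(ours.item(), builtin.item(), criterion(logits, targets).item())
```

A per-class `pos_weight` is what makes this work for multi-label targets. A common heuristic is to set each entry to that class's ratio of negative to positive examples.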
(Tests and a `WeightedBCELoss` counterpart that takes probabilities instead of logits would, of course, need to be written as well.)
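A `WeightedBCELoss` would take probabilities rather than logits. Below is a minimal sketch under stated assumptions: inputs are already in (0, 1), the output is always mean-reduced, and a hypothetical `eps` clamp guards against `log(0)`. It is checked against the built-in logits loss via `torch.sigmoid`, which is one form the requested tests could take.

```python
import torch
import torch.nn.functional as F


class WeightedBCELoss(torch.nn.Module):
    # Hypothetical module: positive-weighted BCE on probabilities, mean-reduced
    def __init__(self, pos_weight, eps=1e-7):
        super().__init__()
        self.register_buffer('pos_weight', torch.as_tensor(pos_weight, dtype=torch.float))
        self.eps = eps

    def forward(self, probs, target):
        if probs.size() != target.size():
            raise ValueError("Target size ({}) must be the same as input size ({})".format(
                target.size(), probs.size()))
        # Clamp so log() never sees exactly 0 or 1
        probs = probs.clamp(self.eps, 1 - self.eps)
        loss = -(self.pos_weight * target * probs.log()
                 + (1 - target) * (1 - probs).log())
        return loss.mean()


torch.manual_seed(0)
logits = torch.randn(16, 4)
targets = torch.randint(0, 2, (16, 4)).float()
criterion = WeightedBCELoss(pos_weight=3.0)
print(criterion(torch.sigmoid(logits), targets))
```

Working from logits remains preferable when they are available, since `sigmoid` followed by `log` loses precision for large-magnitude inputs and the clamp then biases the loss.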