Skip to content
24 changes: 12 additions & 12 deletions torchvision/ops/focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,28 @@ def sigmoid_focal_loss(
alpha: float = 0.25,
gamma: float = 2,
reduction: str = "none",
):
) -> torch.Tensor:
"""
Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py .
Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

Args:
inputs: A float tensor of arbitrary shape.
inputs (Tensor): A float tensor of arbitrary shape.
The predictions for each example.
targets: A float tensor with the same shape as inputs. Stores the binary
targets (Tensor): A float tensor with the same shape as inputs. Stores the binary
classification label for each element in inputs
(0 for the negative class and 1 for the positive class).
alpha: (optional) Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default = 0.25
gamma: Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples.
reduction: 'none' | 'mean' | 'sum'
'none': No reduction will be applied to the output.
'mean': The output will be averaged.
'sum': The output will be summed.
alpha (float): Weighting factor in range (0,1) to balance
positive vs negative examples or -1 for ignore. Default: ``0.25``.
gamma (float): Exponent of the modulating factor (1 - p_t) to
balance easy vs hard examples. Default: ``2``.
reduction (string): ``'none'`` | ``'mean'`` | ``'sum'``
``'none'``: No reduction will be applied to the output.
``'mean'``: The output will be averaged.
``'sum'``: The output will be summed. Default: ``'none'``.
Returns:
Loss tensor with the reduction option applied.
"""
# Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(sigmoid_focal_loss)
p = torch.sigmoid(inputs)
Expand Down