diff --git a/torchvision/ops/focal_loss.py b/torchvision/ops/focal_loss.py index f38855ea459..c8cc9a8ac14 100644 --- a/torchvision/ops/focal_loss.py +++ b/torchvision/ops/focal_loss.py @@ -10,28 +10,28 @@ def sigmoid_focal_loss( alpha: float = 0.25, gamma: float = 2, reduction: str = "none", -): +) -> torch.Tensor: """ - Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py . Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. Args: - inputs: A float tensor of arbitrary shape. + inputs (Tensor): A float tensor of arbitrary shape. The predictions for each example. - targets: A float tensor with the same shape as inputs. Stores the binary + targets (Tensor): A float tensor with the same shape as inputs. Stores the binary classification label for each element in inputs (0 for the negative class and 1 for the positive class). - alpha: (optional) Weighting factor in range (0,1) to balance - positive vs negative examples or -1 for ignore. Default = 0.25 - gamma: Exponent of the modulating factor (1 - p_t) to - balance easy vs hard examples. - reduction: 'none' | 'mean' | 'sum' - 'none': No reduction will be applied to the output. - 'mean': The output will be averaged. - 'sum': The output will be summed. + alpha (float): Weighting factor in range (0,1) to balance + positive vs negative examples or -1 for ignore. Default: ``0.25``. + gamma (float): Exponent of the modulating factor (1 - p_t) to + balance easy vs hard examples. Default: ``2``. + reduction (string): ``'none'`` | ``'mean'`` | ``'sum'`` + ``'none'``: No reduction will be applied to the output. + ``'mean'``: The output will be averaged. + ``'sum'``: The output will be summed. Default: ``'none'``. Returns: Loss tensor with the reduction option applied. """ + # Original implementation from https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/focal_loss.py if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(sigmoid_focal_loss) p = torch.sigmoid(inputs)