[inductor] Added smooth_l1_loss refs
vfdev-5 committed May 23, 2023
1 parent 88b6a45 commit 9a728ab
Showing 6 changed files with 49 additions and 1 deletion.
1 change: 1 addition & 0 deletions test/test_ops.py
@@ -1823,6 +1823,7 @@ class TestRefsOpsInfo(TestCase):
      '_refs.movedim',
      '_refs.narrow',
      '_refs.nn.functional.l1_loss',
+     '_refs.nn.functional.smooth_l1_loss',
      '_refs.nn.functional.log_softmax',
      '_refs.nn.functional.poisson_nll_loss',
      '_refs.nn.functional.softmax',
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
@@ -299,6 +299,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
      aten.silu_backward,
      aten.sinc,
      aten.slice_backward,
+     aten.smooth_l1_loss,
      aten.smooth_l1_loss_backward,
      aten.soft_margin_loss,
      aten.soft_margin_loss_backward,
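
Note: once registered here, the decomposition should be discoverable from the
core ATen decomposition table. A minimal sketch of how one might verify that,
assuming a build that contains this commit:

    import torch
    from torch._decomp import core_aten_decompositions

    decomps = core_aten_decompositions()
    # With this commit applied, the default overload is expected to map to the
    # Python decomposition added in torch/_decomp/decompositions.py below.
    print(torch.ops.aten.smooth_l1_loss.default in decomps)  # expected: True
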
13 changes: 13 additions & 0 deletions torch/_decomp/decompositions.py
@@ -362,6 +362,19 @@ def mse_loss_backward(
      return norm * (input - target) * grad_output


+ @register_decomposition(aten.smooth_l1_loss)
+ @pw_cast_for_opmath
+ def smooth_l1_loss(
+     self: Tensor,
+     target: Tensor,
+     reduction: int = Reduction.MEAN.value,
+     beta: float = 1.0,
+ ):
+     loss = (self - target).abs()
+     loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
+     return apply_loss_reduction(loss, reduction)
+
+
  @register_decomposition(aten.smooth_l1_loss_backward.default)
  @pw_cast_for_opmath
  def smooth_l1_loss_backward(
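
Note: the decomposition implements the standard piecewise definition: with
d = |self - target|, the elementwise loss is 0.5 * d**2 / beta where d < beta
and d - 0.5 * beta elsewhere, followed by the requested reduction. A quick
numerical sanity check against the eager op (a sketch, not part of the commit):

    import torch
    import torch.nn.functional as F

    x, y, beta = torch.randn(128), torch.randn(128), 0.7
    d = (x - y).abs()
    # Same piecewise formula as the decomposition above, with a mean reduction.
    manual = torch.where(d < beta, 0.5 * d**2 / beta, d - 0.5 * beta).mean()
    torch.testing.assert_close(manual, F.smooth_l1_loss(x, y, beta=beta))
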
1 change: 0 additions & 1 deletion torch/_inductor/lowering.py
@@ -1494,7 +1494,6 @@ def apply_constraint(arg, fx_arg):
  make_fallback(aten.resize_as)
  make_fallback(aten.resize_as_)
  make_fallback(aten.searchsorted)
- make_fallback(aten.smooth_l1_loss)
  make_fallback(aten.special_airy_ai)
  make_fallback(aten.special_bessel_j0, warn=False)
  make_fallback(aten.special_bessel_j1, warn=False)
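
Note: with the fallback removed, inductor no longer calls the eager ATen kernel
for this op; it traces through the decomposition above and can emit a fused
pointwise-plus-reduction kernel. A sketch of exercising that path (assuming a
build that contains this commit):

    import torch

    def loss_fn(x, y):
        return torch.nn.functional.smooth_l1_loss(x, y, beta=0.5)

    compiled = torch.compile(loss_fn)  # the default backend is inductor
    x, y = torch.randn(1024), torch.randn(1024)
    torch.testing.assert_close(compiled(x, y), loss_fn(x, y))
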
29 changes: 29 additions & 0 deletions torch/_refs/nn/functional/__init__.py
@@ -32,6 +32,7 @@
"hinge_embedding_loss",
"huber_loss",
"l1_loss",
"smmooth_l1_loss",
"log_softmax",
"margin_ranking_loss",
"mish",
@@ -542,6 +543,34 @@ def l1_loss(
      return _apply_loss_reduction(loss, reduction)


+ def smooth_l1_loss(
+     input: TensorLikeType,
+     target: TensorLikeType,
+     size_average: Optional[bool] = None,
+     reduce: Optional[bool] = None,
+     reduction: str = "mean",
+     beta: float = 1.0,
+ ) -> TensorLikeType:
+     """
+     Reference implementation of torch.nn.functional.smooth_l1_loss
+     """
+     if size_average is not None or reduce is not None:
+         # TODO: Raise exception instead of converting value. This is only for
+         # primTorch since it can drop support for deprecated arguments.
+         # msg = "size_average and reduce args are deprecated, please use reduction argument."
+         reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
+     _check_reduction_value(reduction)
+
+     if beta == 0.0:
+         return l1_loss(
+             input, target, size_average=size_average, reduce=reduce, reduction=reduction
+         )
+     else:
+         loss = torch.abs(input - target)
+         loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
+         return _apply_loss_reduction(loss, reduction)
+
+
  # Forwarding alias: the functional variant doesn't support the out kwarg
  # CompositeImplicitAutograd - don't register decomp
  def log_softmax(
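
Note: the new ref mirrors the eager functional signature, including the
deprecated size_average/reduce arguments and the beta == 0.0 shortcut to
l1_loss. A usage sketch comparing it against eager (assuming this commit is
applied):

    import torch
    import torch._refs.nn.functional as refs_F

    x, y = torch.randn(16), torch.randn(16)
    ref_out = refs_F.smooth_l1_loss(x, y, reduction="sum", beta=2.0)
    eager_out = torch.nn.functional.smooth_l1_loss(x, y, reduction="sum", beta=2.0)
    torch.testing.assert_close(ref_out, eager_out)
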
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -18969,6 +18969,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
          torch_opinfo_name="nn.functional.mse_loss",
          supports_nvfuser=False,
      ),
+     PythonRefInfo(
+         "_refs.nn.functional.smooth_l1_loss",
+         torch_opinfo_name="nn.functional.smooth_l1_loss",
+         supports_nvfuser=False,
+     ),
      PythonRefInfo(
          "_refs.nn.functional.hinge_embedding_loss",
          torch_opinfo_name="nn.functional.hinge_embedding_loss",
