Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[inductor] Added smooth_l1_loss refs #102077

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 0 additions & 2 deletions test/expect/HasDecompTest.test_has_decomposition.expect
Original file line number Diff line number Diff line change
Expand Up @@ -1161,8 +1161,6 @@ aten::slow_conv_transpose2d
aten::slow_conv_transpose2d.out
aten::slow_conv_transpose3d
aten::slow_conv_transpose3d.out
aten::smooth_l1_loss
aten::smooth_l1_loss.out
aten::softmax.int_out
aten::sort
aten::sort.stable
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_aotdispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -2830,7 +2830,6 @@ def forward(self, x):
xfail('nn.functional.pixel_shuffle', ''), # aten.pixel_shuffle.default - couldn't find symbolic meta fun...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta...
xfail('nn.functional.rrelu', ''), # aten.rrelu_with_noise.default - couldn't find symbolic meta function...
xfail('nn.functional.smooth_l1_loss', ''), # could not find kernel
xfail('normal', 'number_mean'), # Cannot call sizes() on tensor with symbolic sizes/strides
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomp...
Expand Down
1 change: 1 addition & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1823,6 +1823,7 @@ class TestRefsOpsInfo(TestCase):
'_refs.movedim',
'_refs.narrow',
'_refs.nn.functional.l1_loss',
'_refs.nn.functional.smooth_l1_loss',
'_refs.nn.functional.log_softmax',
'_refs.nn.functional.poisson_nll_loss',
'_refs.nn.functional.softmax',
Expand Down
1 change: 0 additions & 1 deletion test/test_proxy_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,6 @@ def f(a, b, c, d, e):
xfail('nn.functional.pad', 'replicate'), # aten.replication_pad1d.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.pdist', ''), # Could not run 'aten::_pdist_forward' with arguments from the 'Meta' backend...
xfail('nn.functional.pixel_unshuffle', ''), # aten.pixel_unshuffle.default - couldn't find symbolic meta function/deco...
xfail('nn.functional.smooth_l1_loss', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('normal', 'number_mean'), # aten.normal.float_Tensor - couldn't find symbolic meta function/decomposition
xfail('ormqr', ''), # aten.ormqr.default - couldn't find symbolic meta function/decomposition
xfail('pinverse', ''), # aten.linalg_pinv.atol_rtol_tensor - couldn't find symbolic meta function/decomposition
Expand Down
1 change: 1 addition & 0 deletions torch/_decomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,7 @@ def core_aten_decompositions() -> Dict[OpOverload, Callable]:
aten.silu_backward,
aten.sinc,
aten.slice_backward,
aten.smooth_l1_loss,
aten.smooth_l1_loss_backward,
aten.soft_margin_loss,
aten.soft_margin_loss_backward,
Expand Down
13 changes: 13 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,19 @@ def mse_loss_backward(
return norm * (input - target) * grad_output


@register_decomposition(aten.smooth_l1_loss)
@pw_cast_for_opmath
def smooth_l1_loss(
self: Tensor,
target: Tensor,
reduction: int = Reduction.MEAN.value,
beta: float = 1.0,
):
loss = (self - target).abs()
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return apply_loss_reduction(loss, reduction)


@register_decomposition(aten.smooth_l1_loss_backward.default)
@pw_cast_for_opmath
def smooth_l1_loss_backward(
Expand Down
1 change: 0 additions & 1 deletion torch/_inductor/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,6 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten.resize_as)
make_fallback(aten.resize_as_)
make_fallback(aten.searchsorted)
make_fallback(aten.smooth_l1_loss)
make_fallback(aten.special_airy_ai)
make_fallback(aten.special_bessel_j0, warn=False)
make_fallback(aten.special_bessel_j1, warn=False)
Expand Down
33 changes: 33 additions & 0 deletions torch/_refs/nn/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"hinge_embedding_loss",
"huber_loss",
"l1_loss",
"smooth_l1_loss",
"log_softmax",
"margin_ranking_loss",
"mish",
Expand Down Expand Up @@ -542,6 +543,38 @@ def l1_loss(
return _apply_loss_reduction(loss, reduction)


@elementwise_type_promotion_wrapper(
type_promoting_args=("input", "target"),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT,
)
def smooth_l1_loss(
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
input: TensorLikeType,
target: TensorLikeType,
size_average: Optional[bool] = None,
reduce: Optional[bool] = None,
reduction: str = "mean",
beta: float = 1.0,
) -> TensorLikeType:
"""
Reference implementation of torch.nn.functional.smooth_l1_loss
"""
if size_average is not None or reduce is not None:
# TODO: Raise exception instead of converting value. This is only for
# primTorch since it can drop support for deprecated arguments.
# msg = "size_average and reduce args are deprecated, please use reduction argument."
reduction = _get_string_reduction_arg(size_average=size_average, reduce=reduce)
_check_reduction_value(reduction)

if beta == 0.0:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need this conditional?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Python nn functional API does the same:

pytorch/torch/nn/functional.py

Lines 3242 to 3245 in 76af221

if beta == 0.0:
return torch._C._nn.l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))
else:
return torch._C._nn.smooth_l1_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction), beta)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we have enough logic as to optimise this example out in practice, as it would mean that we need to prove that (input-target).abs() is not negative. The conditional is alright for now.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wdym? .abs() is non-negative. Functional API does this due to some numeric discrepancies in backward, this doesn't apply here.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, my point simply comes from a perf perspective, wheere we would be computing both branches of the where and just using one, but probably LLVM should be able to catch the < 0 after .abs() and optimise it out.

That being said, I still think that keeping this closer to core is better, as we could think of eventually registering this operation and simply differentiating through it to get its backward. This beta==0 specialisation would make sure that this works in that case, as it does in master.

return torch.nn.functional.l1_loss(
input, target, size_average=size_average, reduce=reduce, reduction=reduction
)
else:
loss = torch.abs(input - target)
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return _apply_loss_reduction(loss, reduction)


# Forwarding alias: the functional variant doesn't support the out kwarg
# CompositeImplicitAutograd - don't register decomp
def log_softmax(
Expand Down
5 changes: 5 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18975,6 +18975,11 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
torch_opinfo_name="nn.functional.mse_loss",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.nn.functional.smooth_l1_loss",
torch_opinfo_name="nn.functional.smooth_l1_loss",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.nn.functional.hinge_embedding_loss",
torch_opinfo_name="nn.functional.hinge_embedding_loss",
Expand Down