Skip to content

Commit

Permalink
Add vmap support for torch.nn.functional.smooth_l1_loss (#98357)
Browse files Browse the repository at this point in the history
Partially addresses #97246 and #97558 by adding a vmap batch rule for `torch.nn.functional.smooth_l1_loss`.

Pull Request resolved: #98357
Approved by: https://github.com/kshitij12345
  • Loading branch information
yhl48 authored and pytorchmergebot committed Apr 14, 2023
1 parent 1e78a2e commit 298cc5c
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
30 changes: 26 additions & 4 deletions aten/src/ATen/functorch/BatchRulesLoss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@ at::Tensor flatten_logical(const Tensor& tensor, optional<int64_t> bdim) {
}
}

std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction) {
// Useful for many loss functions
template <typename Func>
static std::tuple<at::Tensor,optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction,
Func loss_fn) {
auto self_ = flatten_logical(self, self_bdim);
auto target_ = flatten_logical(target, target_bdim);
auto result = at::mse_loss(self_, target_, Reduction::None);
auto result = loss_fn(self_, target_, Reduction::None);
if (result.dim() == 1) {
return std::make_tuple(result, 0);
} else if (reduction == Reduction::None) {
Expand All @@ -46,6 +49,24 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
TORCH_INTERNAL_ASSERT(false);
};

// Batch rule for mse_loss under vmap.
//
// Delegates to loss_batch_rule_helper, which flattens the logical (non-batch)
// dims of `self`/`target`, computes the unreduced elementwise loss via the
// supplied callable, and then applies `reduction` over the logical dims only
// (so the batch dim is preserved in the output).
//
// Returns the (possibly reduced) loss tensor together with its batch-dim
// index (or nullopt if the result has no batch dim).
std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
          optional<int64_t> target_bdim, int64_t reduction) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::mse_loss(self, target, reduction);
                                });
}  // NOTE: no trailing ';' — a semicolon after a function definition is an
   // empty declaration and triggers -Wextra-semi.

// Batch rule for smooth_l1_loss under vmap.
//
// Mirrors mse_loss_batch_rule: loss_batch_rule_helper flattens the logical
// dims, evaluates the unreduced loss through the lambda (which captures the
// extra `beta` threshold parameter by value), and applies `reduction` over
// the logical dims while keeping the batch dim intact.
//
// Returns the (possibly reduced) loss tensor together with its batch-dim
// index (or nullopt if the result has no batch dim).
std::tuple<at::Tensor,optional<int64_t>>
smooth_l1_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
          optional<int64_t> target_bdim, int64_t reduction, double beta) {
  return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
                                reduction, [beta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
                                  return at::smooth_l1_loss(self, target, reduction, beta);
                                });
}  // NOTE: no trailing ';' — a semicolon after a function definition is an
   // empty declaration and triggers -Wextra-semi.

static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
if (reduction == at::Reduction::Mean) {
return unreduced.mean();
Expand Down Expand Up @@ -283,6 +304,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
// mse_loss_backwards uses a decomposition for its batch rule
VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3695,7 +3695,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('nn.functional.triplet_margin_loss', ''),
xfail('nn.functional.pdist', ''),
xfail('scatter_reduce', 'sum'),
xfail('nn.functional.smooth_l1_loss', ''),
xfail('scatter_reduce', 'amax'),
xfail('nn.functional.max_unpool1d', 'grad'),
xfail('nn.functional.multi_margin_loss', ''),
Expand Down

0 comments on commit 298cc5c

Please sign in to comment.