Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add vmap support for torch.nn.functional.smooth_l1_loss #98357

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 25 additions & 4 deletions aten/src/ATen/functorch/BatchRulesLoss.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@ at::Tensor flatten_logical(const Tensor& tensor, optional<int64_t> bdim) {
}
}

std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction) {
// Useful for many loss functions
static std::tuple<at::Tensor,optional<int64_t>>
loss_batch_rule_helper(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction,
std::function<at::Tensor(const at::Tensor&, const at::Tensor&, int64_t)> loss_fn) {
yhl48 marked this conversation as resolved.
Show resolved Hide resolved
auto self_ = flatten_logical(self, self_bdim);
auto target_ = flatten_logical(target, target_bdim);
auto result = at::mse_loss(self_, target_, Reduction::None);
auto result = loss_fn(self_, target_, Reduction::None);
if (result.dim() == 1) {
return std::make_tuple(result, 0);
} else if (reduction == Reduction::None) {
Expand All @@ -46,6 +48,24 @@ mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const a
TORCH_INTERNAL_ASSERT(false);
};

std::tuple<at::Tensor,optional<int64_t>>
mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
optional<int64_t> target_bdim, int64_t reduction) {
return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
reduction, [](const at::Tensor& a, const at::Tensor& b, int64_t red) {
yhl48 marked this conversation as resolved.
Show resolved Hide resolved
return at::mse_loss(a, b, red);
});
};

std::tuple<at::Tensor,optional<int64_t>>
smooth_l1_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
kshitij12345 marked this conversation as resolved.
Show resolved Hide resolved
optional<int64_t> target_bdim, int64_t reduction, double beta) {
return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
reduction, [beta](const at::Tensor& a, const at::Tensor& b, int64_t red) {
return at::smooth_l1_loss(a, b, red, beta);
});
};

static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
if (reduction == at::Reduction::Mean) {
return unreduced.mean();
Expand Down Expand Up @@ -283,6 +303,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
// mse_loss_backwards uses a decomposition for its batch rule
VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
}
Expand Down
1 change: 0 additions & 1 deletion test/functorch/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3684,7 +3684,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
xfail('nn.functional.triplet_margin_loss', ''),
xfail('nn.functional.pdist', ''),
xfail('scatter_reduce', 'sum'),
xfail('nn.functional.smooth_l1_loss', ''),
xfail('scatter_reduce', 'amax'),
xfail('nn.functional.max_unpool1d', 'grad'),
xfail('nn.functional.multi_margin_loss', ''),
Expand Down