Cleaned up code and renamed arguments
yhl48 committed Apr 14, 2023
1 parent b6a93fd commit 03570dd
Showing 1 changed file with 4 additions and 86 deletions.
90 changes: 4 additions & 86 deletions aten/src/ATen/functorch/BatchRulesLoss.cpp
@@ -53,41 +53,20 @@ std::tuple<at::Tensor,optional<int64_t>>
 mse_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
                     optional<int64_t> target_bdim, int64_t reduction) {
   return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
-                                reduction, [](const at::Tensor& a, const at::Tensor& b, int64_t red) {
-                                  return at::mse_loss(a, b, red);
+                                reduction, [](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
+                                  return at::mse_loss(self, target, reduction);
                                 });
 };
 
 std::tuple<at::Tensor,optional<int64_t>>
 smooth_l1_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
                           optional<int64_t> target_bdim, int64_t reduction, double beta) {
   return loss_batch_rule_helper(self, self_bdim, target, target_bdim,
-                                reduction, [beta](const at::Tensor& a, const at::Tensor& b, int64_t red) {
-                                  return at::smooth_l1_loss(a, b, red, beta);
+                                reduction, [beta](const at::Tensor& self, const at::Tensor& target, int64_t reduction) {
+                                  return at::smooth_l1_loss(self, target, reduction, beta);
                                 });
 };
 
-// std::tuple<at::Tensor,optional<int64_t>>
-// smooth_l1_loss_batch_rule(const at::Tensor& self, optional<int64_t> self_bdim, const at::Tensor& target,
-//                           optional<int64_t> target_bdim, int64_t reduction, double beta) {
-//   auto self_ = flatten_logical(self, self_bdim);
-//   auto target_ = flatten_logical(target, target_bdim);
-//   auto result = at::smooth_l1_loss(self_, target_, Reduction::None, beta);
-//   if (result.dim() == 1) {
-//     return std::make_tuple(result, 0);
-//   } else if (reduction == Reduction::None) {
-//     DimVector end_shape;
-//     const auto batched_elem = self_bdim.has_value() ?
-//       moveBatchDimToFront(self, self_bdim) : moveBatchDimToFront(target, target_bdim);
-//     return std::make_tuple(result.reshape(batched_elem.sizes()), 0);
-//   } else if (reduction == Reduction::Sum) {
-//     return std::make_tuple(result.sum(-1), 0);
-//   } else if (reduction == Reduction::Mean) {
-//     return std::make_tuple(result.mean(-1), 0);
-//   }
-//   TORCH_INTERNAL_ASSERT(false);
-// };
-
 static Tensor apply_loss_reduction(const at::Tensor& unreduced, int64_t reduction) {
   if (reduction == at::Reduction::Mean) {
     return unreduced.mean();
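
For illustration only, not part of this commit: a minimal Python sketch of the behaviour the two batch rules above provide, assuming a PyTorch build (>= 2.0) where they are registered and the torch.func API is available; the tensor shapes and variable names are arbitrary. When vmap peels off the batch dimension, the requested reduction is applied per batch element, so the default 'mean' reduction yields one scalar loss per sample.

# Sketch, not from the commit: exercising the mse_loss / smooth_l1_loss
# batch rules from Python. Assumes PyTorch >= 2.0 with torch.func.
import torch
import torch.nn.functional as F
from torch.func import vmap

pred = torch.randn(8, 5, 3)    # a batch of 8 predictions
target = torch.randn(8, 5, 3)  # a batch of 8 targets

# vmap maps over dim 0, so the default 'mean' reduction is applied per
# sample and each call returns one loss value per batch element.
per_sample_mse = vmap(F.mse_loss)(pred, target)
per_sample_sl1 = vmap(lambda p, t: F.smooth_l1_loss(p, t, beta=0.5))(pred, target)

print(per_sample_mse.shape)  # torch.Size([8])
print(per_sample_sl1.shape)  # torch.Size([8])
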
@@ -318,74 +297,13 @@ at::Tensor nll_loss_backward_decomposition(
   return grad_input * grad_output_;
 }
 
-// at::Tensor smooth_l1_loss_backward_decomposition(
-//     const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target,
-//     int64_t reduction, double beta) {
-
-//   auto diff = self - target;
-//   auto abs_diff = diff.abs();
-
-//   auto grad = at::where(abs_diff < beta, diff / beta, diff.sign());
-
-//   if (reduction == Reduction::Mean) {
-//     grad = grad / (grad.numel() / grad_output.numel());
-//   } else if (reduction == Reduction::Sum) {
-//     grad = grad * grad_output.view_as(grad);
-//   } else {
-//     grad = grad;
-//   }
-
-//   return grad;
-// }
-
-// at::Tensor smooth_l1_loss_backward_decomposition(
-//     const at::Tensor& grad_output, const at::Tensor& self, const at::Tensor& target,
-//     int64_t reduction, double beta) {
-
-//   auto diff = self - target;
-//   auto abs_diff = diff.abs();
-
-//   auto grad = at::where(abs_diff < beta, diff / beta, diff.sign());
-
-//   if (self.dim() > 0) {
-//     auto grad_output_shape = self.sizes().vec();
-//     for (size_t i = 0; i < grad_output_shape.size(); i++) {
-//       if (i != 0) {
-//         grad_output_shape[i] = 1;
-//       }
-//     }
-
-//     auto grad_output_broadcasted = grad_output.view(grad_output_shape);
-//     grad_output_broadcasted = grad_output_broadcasted.expand_as(self);
-
-//     if (reduction == Reduction::Mean) {
-//       grad = grad / (grad.numel() / grad_output.numel());
-//     } else if (reduction == Reduction::Sum) {
-//       grad = grad * grad_output_broadcasted;
-//     } else {
-//       grad = grad;
-//     }
-//   } else {
-//     if (reduction == Reduction::Mean) {
-//       grad = grad / grad.numel();
-//     } else if (reduction == Reduction::Sum) {
-//       grad = grad * grad_output.item().toDouble();
-//     } else {
-//       grad = grad;
-//     }
-//   }
-
-//   return grad;
-// }
-
 TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
   m.impl("nll_loss_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss2d_forward", nll_loss_forward_decomposition);
   m.impl("nll_loss_backward", nll_loss_backward_decomposition);
   m.impl("nll_loss2d_backward", nll_loss_backward_decomposition);
   VMAP_SUPPORT(mse_loss, mse_loss_batch_rule);
   // mse_loss_backwards uses a decomposition for its batch rule
-  // m.impl("smooth_l1_loss_backward", smooth_l1_loss_backward_decomposition);
   VMAP_SUPPORT(smooth_l1_loss, smooth_l1_loss_batch_rule);
   m.impl("binary_cross_entropy", binary_cross_entropy_plumbing);
   m.impl("binary_cross_entropy_backward", binary_cross_entropy_backward_plumbing);
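
A second illustrative sketch, again not part of the commit and assuming the same torch.func API: composing vmap with grad exercises the smooth_l1_loss batch rule registered above together with its backward under vmap, producing per-sample gradients. Where an op has no dedicated backward batch rule, torch.func typically falls back to a per-example loop (with a performance warning), so the result below is still expected to be correct.

# Sketch, not from the commit: per-sample gradients through smooth_l1_loss
# under vmap. Assumes PyTorch >= 2.0 with torch.func.
import torch
import torch.nn.functional as F
from torch.func import grad, vmap

def loss_fn(pred, target):
    # Scalar loss per sample once vmap peels off the batch dimension.
    return F.smooth_l1_loss(pred, target, beta=1.0)

pred = torch.randn(8, 5)
target = torch.randn(8, 5)

# Gradient of each sample's loss w.r.t. its own prediction.
per_sample_grads = vmap(grad(loss_fn))(pred, target)
print(per_sample_grads.shape)  # torch.Size([8, 5])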
