Skip to content

Commit

Permalink
Add log_sigmoid_backward forward-AD (#99288)
Browse files Browse the repository at this point in the history
Fixes #95057
Pull Request resolved: #99288
Approved by: https://github.com/kshitij12345, https://github.com/albanD
  • Loading branch information
qqaatw authored and pytorchmergebot committed Apr 17, 2023
1 parent dede0bb commit e549ad0
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
6 changes: 1 addition & 5 deletions test/functorch/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,13 +1366,11 @@ def get_vjp(cotangents, *primals):
xfail('grid_sampler_2d', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('NumpyCubeNotComposableAutogradFunction'), # not composable
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('ormqr', ''), # NYI: forward AD for ormqr
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
xfail('nn.functional.multilabel_soft_margin_loss', ''), # NYI: log_sigmoid_backward
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for soft_margin_loss_backward
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
skip('nn.functional.scaled_dot_product_attention', device_type='cuda'),
Expand Down Expand Up @@ -1516,14 +1514,12 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail('nn.functional.instance_norm'),
xfail('nn.functional.logsigmoid'), # Forward AD not implemented and no decomposition
# NYI: Tensor.clone(memory_format) inside vmap is only supported with
# memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
xfail('nn.functional.max_unpool2d'),
xfail('nn.functional.max_unpool2d', 'grad'),
xfail('nn.functional.multi_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.multilabel_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.multilabel_soft_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.pdist'), # Forward AD not implemented and no decomposition
xfail('nn.functional.rrelu'), # vmap: we do not yet support aten::rrelu_with_noise.
xfail('nn.functional.soft_margin_loss'), # Forward AD not implemented and no decomposition
Expand Down
2 changes: 2 additions & 0 deletions tools/autograd/derivatives.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1998,6 +1998,7 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj()
output_differentiability: [True, False]

- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
Expand Down Expand Up @@ -2326,6 +2327,7 @@
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
grad_output: log_sigmoid_backward(grad, self, buffer)
self: log_sigmoid_double_backward(grad * grad_output, self)
result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p)

- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
Expand Down
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12526,6 +12526,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
Expand Down Expand Up @@ -13132,6 +13133,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_autograd=True,
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_gradgrad=True,
# autodiff_nonfusible_nodes=["aten::log_sigmoid"],
decorators=[
Expand Down

0 comments on commit e549ad0

Please sign in to comment.