Add log_sigmoid_backward forward-AD #99288

Closed
wants to merge 2 commits
6 changes: 1 addition & 5 deletions test/functorch/test_ops.py
@@ -1368,13 +1368,11 @@ def get_vjp(cotangents, *primals):
xfail('grid_sampler_2d', ''), # NYI: forward AD for grid_sampler_2d
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
- xfail('nn.functional.logsigmoid', ''), # not differentiable w.r.t. buffer
xfail('NumpyCubeNotComposableAutogradFunction'), # not composable
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('ormqr', ''), # NYI: forward AD for ormqr
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
- xfail('nn.functional.multilabel_soft_margin_loss', ''), # NYI: log_sigmoid_backward
- xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for log_sigmoid_backward
+ xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for soft_margin_loss_backward
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
skip('nn.functional.scaled_dot_product_attention', device_type='cuda'),
@@ -1518,14 +1516,12 @@ def reference(primals, cotangents, primals_tangents, cotangents_tangents):
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail('nn.functional.instance_norm'),
- xfail('nn.functional.logsigmoid'), # Forward AD not implemented and no decomposition
# NYI: Tensor.clone(memory_format) inside vmap is only supported with
# memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
xfail('nn.functional.max_unpool2d'),
xfail('nn.functional.max_unpool2d', 'grad'),
xfail('nn.functional.multi_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.multilabel_margin_loss'), # Forward AD not implemented and no decomposition
- xfail('nn.functional.multilabel_soft_margin_loss'), # Forward AD not implemented and no decomposition
xfail('nn.functional.pdist'), # Forward AD not implemented and no decomposition
xfail('nn.functional.rrelu'), # vmap: we do not yet support aten::rrelu_with_noise.
xfail('nn.functional.soft_margin_loss'), # Forward AD not implemented and no decomposition
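
The xfails removed above (`nn.functional.logsigmoid` and `nn.functional.multilabel_soft_margin_loss`) cover composing forward-mode over reverse-mode AD through `log_sigmoid_backward`. A minimal sketch of that composition with `torch.func`, assuming this PR is applied; the shapes and variable names below are illustrative, not taken from the test suite:

```python
import torch
import torch.nn.functional as F
from torch.func import jvp, vjp

x = torch.randn(4)      # primal input
x_t = torch.randn(4)    # tangent for the primal
cot = torch.randn(4)    # cotangent fed to the vjp
cot_t = torch.randn(4)  # tangent for the cotangent

def get_vjp(cotangent, primal):
    # Reverse-mode through F.logsigmoid; its backward calls log_sigmoid_backward.
    _, vjp_fn = vjp(F.logsigmoid, primal)
    return vjp_fn(cotangent)

# Forward-mode over the vjp needs a forward derivative for log_sigmoid_backward,
# which the derivatives.yaml change below provides.
result, tangent = jvp(get_vjp, (cot, x), (cot_t, x_t))
```
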
2 changes: 2 additions & 0 deletions tools/autograd/derivatives.yaml
@@ -1998,6 +1998,7 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj()
+ output_differentiability: [True, False]

- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
@@ -2326,6 +2327,7 @@
- name: log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor
grad_output: log_sigmoid_backward(grad, self, buffer)
self: log_sigmoid_double_backward(grad * grad_output, self)
+ result: log_sigmoid_backward(grad_output_t, self_p, buffer) + log_sigmoid_double_backward(self_t * grad_output_p, self_p)

- name: _log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor
grad_output: grad.to(output.dtype()) - (grad.to(output.dtype()) * output.exp()).sum(dim, true)
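
As a sanity check of the new `result:` formula, here is a quick numeric sketch (not part of the PR). It expands `log_sigmoid_double_backward` by hand as `grad * (sigmoid(x) - 1) * sigmoid(x)`, which is an assumption about that helper's closed form, and compares the yaml tangent against a plain chain-rule computation of d/dt[grad_output * sigmoid(-x)]:

```python
import torch

torch.manual_seed(0)
x = torch.randn(8, dtype=torch.double)  # self_p
x_t = torch.randn_like(x)               # self_t
g = torch.randn_like(x)                 # grad_output_p
g_t = torch.randn_like(x)               # grad_output_t

# buffer comes from log_sigmoid_forward, matching the yaml signature above.
_, buffer = torch.ops.aten.log_sigmoid_forward(x)

# Tangent as written in derivatives.yaml, with the double-backward helper
# expanded by hand under the assumption stated in the lead-in.
yaml_tangent = (
    torch.ops.aten.log_sigmoid_backward(g_t, x, buffer)
    + (x_t * g) * (torch.sigmoid(x) - 1) * torch.sigmoid(x)
)

# Chain rule on log_sigmoid_backward(g, x) = g * sigmoid(-x):
#   d/dt = g_t * sigmoid(-x) - g * x_t * sigmoid(-x) * sigmoid(x)
manual = g_t * torch.sigmoid(-x) - g * x_t * torch.sigmoid(-x) * torch.sigmoid(x)

torch.testing.assert_close(yaml_tangent, manual)
```
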
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -12526,6 +12526,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
@@ -13132,6 +13133,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_autograd=True,
assert_autodiffed=False,
supports_forward_ad=True,
+ supports_fwgrad_bwgrad=True,
supports_gradgrad=True,
# autodiff_nonfusible_nodes=["aten::log_sigmoid"],
decorators=[
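
The `supports_fwgrad_bwgrad=True` flags tell the OpInfo machinery to test forward-over-reverse AD for these ops. A rough sketch of what that composition looks like outside the test harness, assuming this PR is applied; the tensor names are illustrative:

```python
import torch
import torch.nn.functional as F
import torch.autograd.forward_ad as fwAD

x = torch.randn(5, requires_grad=True)

with fwAD.dual_level():
    y = F.logsigmoid(x)
    # Seed the backward pass with a dual cotangent: the reverse-mode graph is
    # then evaluated under forward-mode AD (forward-over-reverse), which hits
    # the new forward derivative of log_sigmoid_backward.
    cotangent = fwAD.make_dual(torch.randn(5), torch.randn(5))
    (grad_x,) = torch.autograd.grad(y, x, grad_outputs=cotangent)
    grad_primal, grad_tangent = fwAD.unpack_dual(grad_x)
```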