Commit

Update on "Add log_sigmoid_backward forward-AD"
[ghstack-poisoned]
qqaatw committed Apr 17, 2023
1 parent 4f3def1 commit c41e998
Showing 3 changed files with 4 additions and 0 deletions.
1 change: 1 addition & 0 deletions test/functorch/test_ops.py
@@ -1372,6 +1372,7 @@ def get_vjp(cotangents, *primals):
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('ormqr', ''), # NYI: forward AD for ormqr
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for soft_margin_loss_backward
xfail('nn.functional.ctc_loss', ''), # NYI: forward-AD for _ctc_loss
xfail('nn.functional.pdist', ''), # NYI: forward-AD with _pdist_forward
skip('nn.functional.scaled_dot_product_attention', device_type='cuda'),
1 change: 1 addition & 0 deletions tools/autograd/derivatives.yaml
@@ -1998,6 +1998,7 @@
- name: log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)
self: log_sigmoid_backward(grad, self, buffer)
output: log_sigmoid_backward(self_t.conj(), self_p, buffer).conj()
output_differentiability: [True, False]

- name: _log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor
self: _log_softmax_backward_data(grad, result, dim, self.scalar_type())
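
For context on the entry above: `output_differentiability: [True, False]` marks only `output` (not the auxiliary `buffer`) as differentiable, and the `output:` line defines the forward-mode tangent as `log_sigmoid_backward` applied to the input tangent. Mathematically, d/dx log(sigmoid(x)) = 1 - sigmoid(x) = sigmoid(-x), so the JVP along tangent t is t * sigmoid(-x). A minimal sketch of that identity through the public forward-AD API (not part of the commit; it assumes a build where this stack has landed, and the shapes and dtype are arbitrary):

import torch
import torch.autograd.forward_ad as fwAD

# Check: the JVP of logsigmoid at x along tangent t should equal t * sigmoid(-x).
x = torch.randn(8, dtype=torch.double)
t = torch.randn(8, dtype=torch.double)

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)          # primal x carrying tangent t
    out = torch.nn.functional.logsigmoid(dual_x)
    jvp = fwAD.unpack_dual(out).tangent    # tangent of the output

torch.testing.assert_close(jvp, t * torch.sigmoid(-x))
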
2 changes: 2 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -12526,6 +12526,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_multilabel_soft_margin_loss,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
decorators=(
DecorateInfo(
toleranceOverride({torch.float32: tol(atol=1e-4, rtol=1e-4)}),
@@ -13132,6 +13133,7 @@ def reference_flatten(input, start_dim=0, end_dim=-1):
supports_autograd=True,
assert_autodiffed=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
supports_gradgrad=True,
# autodiff_nonfusible_nodes=["aten::log_sigmoid"],
decorators=[
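
The two new `supports_fwgrad_bwgrad=True` flags tell the OpInfo test suite that forward-over-reverse second derivatives now work for `nn.functional.logsigmoid` and for `nn.functional.multilabel_soft_margin_loss`, whose backward pass routes through `log_sigmoid_backward`. A rough illustration of that composition, written against the public `torch.func` API rather than the internal test harness (the closed-form expression at the end is derived by hand and is not taken from the commit):

import torch
from torch.func import grad, jvp


def f(x):
    return torch.nn.functional.logsigmoid(x).sum()


# Forward-over-reverse: push a forward-mode tangent t through the
# reverse-mode gradient of f. This needs a forward-AD rule for
# log_sigmoid_backward, which is what this stack adds.
x = torch.randn(8, dtype=torch.double)
t = torch.randn(8, dtype=torch.double)

g, jvp_of_g = jvp(grad(f), (x,), (t,))

# grad(f)(x) = sigmoid(-x) elementwise, so its JVP along t is
# -t * sigmoid(-x) * sigmoid(x).
torch.testing.assert_close(jvp_of_g, -t * torch.sigmoid(-x) * torch.sigmoid(x))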
