Add op.input_func to OpInfo-based testing #50837
Labels
module: autograd
module: testing
triaged
Some functions with a mathematically constrained type of input require special preprocessing so that finite-difference derivatives match the analytical ones implemented in the backward operation. Examples include Hermitian (or symmetric) linear algebra functions such as `torch.linalg.cholesky` and `torch.linalg.eigh`, and functions taking a triangular matrix input, such as `torch.cholesky_inverse`.

This requirement can be met, for example, with a custom `OpInfo` class that overrides the `get_op` method. But then the `test_variant_consistency_jit` tests would fail with a mismatch error in the computed gradients, because the JIT-scripted function ignores the `get_op` method.

`OpInfo` already has `output_func`, so we should add `input_func` as well.

It should be enough to change these lines https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_jit.py#L127-L131 to something like
to make the `test_variant_consistency_jit` tests pass with a custom `op.input_func`. All other places that call `.get_op()` should also use `op.input_func`.

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @mruberry
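To illustrate why preprocessing is needed, here is a toy, pure-Python sketch of the finite-difference mismatch, in the spirit of the numerical-vs-analytical comparison that `torch.autograd.gradcheck` performs. No PyTorch is involved, and all names are hypothetical. `lower_sum` stands in for an op that reads only one triangle of its input (as a Cholesky-style op may). `symmetrize` is an example preprocessing function. `analytical_grad` is a backward written under the symmetric-input assumption. Finite differences of the raw op disagree with that backward, while finite differences of the op composed with `symmetrize` agree with it.

```python
from typing import Callable, List

Matrix = List[List[float]]


def lower_sum(a: Matrix) -> float:
    # Hypothetical toy op: reads only the lower triangle (diagonal included),
    # the way an op that assumes a symmetric input may ignore the other half.
    n = len(a)
    return sum(a[i][j] for i in range(n) for j in range(i + 1))


def symmetrize(a: Matrix) -> Matrix:
    # Example input preprocessing: (A + A^T) / 2 is always symmetric, and
    # every entry of A now influences the op's output.
    n = len(a)
    return [[(a[i][j] + a[j][i]) / 2 for j in range(n)] for i in range(n)]


def analytical_grad(a: Matrix) -> Matrix:
    # Toy "backward" written under the symmetric-input assumption: the
    # gradient is symmetric, with each off-diagonal contribution split
    # between the (i, j) and (j, i) entries.
    n = len(a)
    return [[1.0 if i == j else 0.5 for j in range(n)] for i in range(n)]


def finite_diff_grad(f: Callable[[Matrix], float], a: Matrix, eps: float = 1e-6) -> Matrix:
    # Central differences, perturbing one matrix entry at a time.
    n = len(a)
    grad = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            plus = [row[:] for row in a]
            minus = [row[:] for row in a]
            plus[i][j] += eps
            minus[i][j] -= eps
            grad[i][j] = (f(plus) - f(minus)) / (2 * eps)
    return grad


def matrices_close(x: Matrix, y: Matrix, tol: float = 1e-6) -> bool:
    # Element-wise comparison with an absolute tolerance.
    return all(abs(p - q) <= tol for rx, ry in zip(x, y) for p, q in zip(rx, ry))


a = [[4.0, 1.0], [2.0, 3.0]]
expected = analytical_grad(a)
raw = finite_diff_grad(lower_sum, a)  # approximately [[1, 0], [1, 1]]
pre = finite_diff_grad(lambda x: lower_sum(symmetrize(x)), a)  # approximately [[1, 0.5], [0.5, 1]]
print(matrices_close(raw, expected), matrices_close(pre, expected))  # False True
```

The choice of `(A + A^T) / 2` is just one way to map an arbitrary test input onto the constrained domain. A triangular-input op would use a different projection, such as zeroing one triangle.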
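The difference between overriding `get_op` and a harness-level `input_func` hook can be sketched in pure Python with hypothetical stand-ins. `ToyOpInfo`, `GetOpOverride`, `fake_script`, and `lower_entry` are not the real `OpInfo` or `common_jit.py` code. The JIT test is modeled only by the property the issue describes: it builds its callable without going through `get_op`. Hiding the preprocessing in `get_op` changes only the eager result. Applying `op.input_func` around every variant gives both paths the same input.

```python
from dataclasses import dataclass
from typing import Callable, List

Matrix = List[List[float]]


def identity(x):
    return x


def symmetrize(a: Matrix) -> Matrix:
    # Example preprocessing for an op that assumes a symmetric input.
    n = len(a)
    return [[(a[i][j] + a[j][i]) / 2 for j in range(n)] for i in range(n)]


def lower_entry(a: Matrix) -> float:
    # Hypothetical toy op that reads only the lower triangle.
    return a[1][0]


@dataclass
class ToyOpInfo:
    # Hypothetical, stripped-down stand-in for an OpInfo entry.
    op: Callable
    input_func: Callable = identity   # the proposed hook
    output_func: Callable = identity  # mirrors the existing output_func

    def get_op(self) -> Callable:
        return self.op


class GetOpOverride(ToyOpInfo):
    # The workaround described in the issue: preprocessing hidden in get_op.
    def get_op(self) -> Callable:
        return lambda x: self.op(symmetrize(x))


def fake_script(info: ToyOpInfo) -> Callable:
    # Stand-in for the JIT test: builds its callable from the op itself,
    # never from get_op(), so a get_op override has no effect here.
    return lambda x: info.op(x)


def eager_result(info: ToyOpInfo, x):
    # Eager path, with both hooks applied by the harness.
    return info.output_func(info.get_op()(info.input_func(x)))


def scripted_result(info: ToyOpInfo, x):
    # Scripted path, with the same hooks applied around the scripted callable.
    return info.output_func(fake_script(info)(info.input_func(x)))
```

With `a = [[1.0, 4.0], [2.0, 3.0]]`, a `GetOpOverride(op=lower_entry)` entry yields `3.0` eagerly but `2.0` through `fake_script`. A `ToyOpInfo(op=lower_entry, input_func=symmetrize)` entry yields `3.0` on both paths, because the harness, not the op wrapper, owns the preprocessing.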