Add op.input_func to OpInfo-based testing #50837

Closed
IvanYashchuk opened this issue Jan 20, 2021 · 2 comments
Assignees
IvanYashchuk

Labels
module: autograd - Related to torch.autograd, and the autograd engine in general
module: testing - Issues related to the torch.testing module (not tests)
triaged - This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

IvanYashchuk commented Jan 20, 2021

Some functions with mathematically constrained inputs require special preprocessing so that finite-difference derivatives match the analytical ones implemented in the backward operation. Examples are linear algebra functions that expect Hermitian (or symmetric) input, such as torch.linalg.cholesky and torch.linalg.eigh, or functions that take a triangular matrix input, such as torch.cholesky_inverse.
This requirement can be met, for example, with a custom OpInfo class that overrides the get_op method. But then the test_variant_consistency_jit tests would fail with a gradient mismatch error, because the JIT-scripted function ignores the get_op override.
OpInfo already has output_func, so we should add input_func as well.
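
For illustration, a minimal sketch of the kind of preprocessing such an input_func could perform for torch.linalg.cholesky; the helper name is made up here, and input_func itself is the attribute this issue proposes, not an existing OpInfo parameter:

import torch

def make_hermitian_positive_definite(a):
    # A @ A^H + n*I is Hermitian and positive definite for any square A, so
    # finite-difference perturbations stay inside the domain where the
    # analytical backward of torch.linalg.cholesky is valid.
    # .conj() is a no-op for real input and covers the complex Hermitian case.
    n = a.shape[-1]
    eye = torch.eye(n, dtype=a.dtype, device=a.device)
    return a @ a.transpose(-2, -1).conj() + n * eye

# gradcheck agrees with the analytical gradient once the raw sample is mapped
# into the constrained domain first.
a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(
    lambda x: torch.linalg.cholesky(make_hermitian_positive_definite(x)), (a,))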

It should be enough to change these lines https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_jit.py#L127-L131
to something like

def runAndSaveRNG(self, func, inputs, kwargs=None, input_func=None):
    kwargs = kwargs if kwargs else {}
    with freeze_rng_state():
        # If the op requires mathematically constrained inputs, let input_func
        # map the raw sample inputs onto that domain before calling func.
        if input_func is not None:
            inputs = [input_func(*inputs, **kwargs)]
        results = func(*inputs, **kwargs)
    return results

so that the test_variant_consistency_jit tests pass with a custom op.input_func. All other call sites that use .get_op() should also apply op.input_func.
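
A self-contained sketch of the intended behaviour, with freeze_rng_state and the rest of the test harness omitted; the helper and names below are illustrative only, not the actual common_jit.py code:

import torch

def run_with_input_func(func, inputs, kwargs=None, input_func=None):
    # Mirrors the proposed runAndSaveRNG change: preprocess the raw inputs
    # with input_func (when given) before calling the function under test.
    kwargs = kwargs if kwargs else {}
    if input_func is not None:
        inputs = [input_func(*inputs)]
    return func(*inputs, **kwargs)

# An arbitrary square matrix is mapped to a symmetric positive definite one
# before torch.linalg.cholesky is called, which is what op.input_func would do.
make_spd = lambda x: x @ x.transpose(-2, -1) + x.shape[-1] * torch.eye(x.shape[-1], dtype=x.dtype)
a = torch.randn(3, 3, dtype=torch.float64)
l = run_with_input_func(torch.linalg.cholesky, [a], input_func=make_spd)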

cc @ezyang @albanD @zou3519 @gqchen @pearu @nikitaved @soulitzer @mruberry

IvanYashchuk added the module: testing label Jan 20, 2021
IvanYashchuk self-assigned this Jan 20, 2021
ngimel added the triaged label Jan 21, 2021
mruberry added the module: autograd label Jan 25, 2021
mruberry commented

Linked to the test suite tracking issue and cc @albanD as an fyi. This seems like the right idea to me.

albanD commented Jan 25, 2021

SGTM!
