Ensure that call before redispatch work well for PythonTLSSnapshot #73045
```python
@@ -562,6 +562,46 @@ def test_autograd_in_attr(self):
        self.assertIsNone(t.grad)
        self.assertIsNotNone(t.elem.grad)

    def test_multiple_ops_subclass(self):
        # This is a direct subclass, don't do that!
        class MySubclass(torch.Tensor):
            @staticmethod
            def __new__(cls, elem):
                r = torch.Tensor._make_subclass(cls, elem)
                return r

            __torch_function__ = torch._C._disabled_torch_function_impl

            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with no_dispatch():
                    return func(*args, **kwargs)

        x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
        y = x.conj()
        # Details of the bug that this tests for:
        # Here, y's dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
        # There are a few calls to the dispatcher that are going to happen here:
        # - call_exp: user calling exp on y
        #   - PythonTLSSnapshot: records the TLS on entry and redispatches
        #   - AutogradCPU: no input requires grad, so does nothing and redispatches
        #   - Conjugate: no special implementation for exp: use the fallback that
        #     first clones the Tensor (to materialize the conj) then redispatches
        #     - call_clone: conjugate fallback calling clone on y
```
The review thread below is anchored on this line of the conjugate fallback:

```cpp
auto resolved_tensor = at::clone(tensor);
```
(I think that's because the conjugate fallback ends up calling at::clone(), which enters the dispatcher through Dispatcher::call(). PythonTLSSnapshot is set up to always run first every time you call Dispatcher::call())
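This call/redispatch distinction can be sketched with a toy model (an assumed, heavily simplified sketch, NOT the real c10 dispatcher; only the key names come from the thread): `Dispatcher::call` always starts from the highest-priority key, so the nested `at::clone` issued by the conjugate fallback hits PythonTLSSnapshot again, whereas `redispatch` continues strictly below the current key.

```python
# Toy model of Dispatcher::call vs redispatch (NOT the real c10 dispatcher;
# the key set is trimmed and the handlers are simplified for illustration).
KEYS = ["PythonTLSSnapshot", "Conjugate", "CPU"]
trace = []

def call(op):
    # like Dispatcher::call: always starts at the highest-priority key
    return dispatch(op, 0)

def redispatch(op, idx):
    # like redispatch: continues strictly below the current key
    return dispatch(op, idx + 1)

def dispatch(op, idx):
    key = KEYS[idx]
    trace.append((op, key))
    if key == "PythonTLSSnapshot":
        # records the Python TLS on entry, then redispatches
        return redispatch(op, idx)
    if key == "Conjugate":
        if op != "clone":
            # the fallback materializes the conj by cloning; at::clone
            # re-enters the dispatcher through a fresh Dispatcher::call,
            # so PythonTLSSnapshot runs again for the nested op
            call("clone")
        return redispatch(op, idx)
    return f"{op}-kernel"  # "CPU": the actual kernel

call("exp")
```

Running this, the trace shows `("clone", "PythonTLSSnapshot")` appearing between the outer `exp` entries, which is exactly the re-entry described above.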
Is the PythonTLSSnapshot key in the local exclude key set at this point? I would expect the at::clone call to skip past PythonTLSSnapshot, AutogradCPU, and Conjugate.
Oh, I see, you don't use ExcludeDispatchKeyGuards to do the redispatch...
Redispatch here means actually using redispatch (as opposed to clone, which goes through a plain call).
ExcludeDispatchKeyGuard is not used anymore, except by autograd.
Ugh, does this mean that on the at::clone call, AutogradCPU doesn't get hit? AutogradCPU is in the local exclude set (because VariableType::blah both uses an exclude key guard AND at::redispatch).
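The asymmetry being pointed out here can be sketched as a toy model (assumed and simplified, NOT the real c10 dispatcher): the autograd kernel puts AutogradCPU into a thread-local exclude set before redispatching, so the nested at::clone skips AutogradCPU; PythonTLSSnapshot, however, is never excluded, so it still runs first on every fresh call.

```python
# Toy model of the thread-local exclude key set (a stand-in for
# c10::impl::ExcludeDispatchKeyGuard; key names from the thread above,
# everything else simplified for illustration).
from contextlib import contextmanager

KEYS = ["PythonTLSSnapshot", "AutogradCPU", "Conjugate", "CPU"]
tls_exclude = set()  # stand-in for the thread-local exclude key set
trace = []

@contextmanager
def exclude_guard(key):
    # excluded keys are skipped by every dispatch on this thread
    tls_exclude.add(key)
    try:
        yield
    finally:
        tls_exclude.discard(key)

def call(op, start=0):
    idx = start
    while KEYS[idx] in tls_exclude:  # skip past excluded keys
        idx += 1
    key = KEYS[idx]
    trace.append((op, key))
    if key == "PythonTLSSnapshot":
        return call(op, idx + 1)  # record the TLS, then redispatch below
    if key == "AutogradCPU":
        # the autograd kernel both sets an exclude guard AND redispatches
        with exclude_guard("AutogradCPU"):
            return call(op, idx + 1)
    if key == "Conjugate":
        if op == "exp":
            call("clone")  # fallback re-enters via a fresh Dispatcher::call
        return call(op, idx + 1)
    return f"{op}-kernel"  # CPU kernel

call("exp")
```

In the resulting trace, the nested clone hits PythonTLSSnapshot but never AutogradCPU, matching the behavior discussed above.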
That is correct! My bad!
Thanks for the detailed comment and explanations, I understand what's going on now. I agree that this test is testing the above code now.
On an orthogonal note, I'm a bit concerned that AutogradCPU doesn't get hit, but maybe that's not a problem because that's how things work even without the Python keys
That's what happens when your key is "below" autograd: autograd is already handled above, so you don't need to care about it here :)
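A tiny sketch of "your key is below autograd" (assumed ordering, not the real dispatcher): by the time a handler sitting at the Python key runs, the autograd kernel above it has already done its bookkeeping and redispatched, so the lower handler never has to deal with autograd itself.

```python
# Toy sketch: handlers below AutogradCPU can ignore autograd entirely,
# because it has already run above them (simplified, not the real dispatcher).
KEYS = ["AutogradCPU", "Python", "CPU"]

def run(op):
    def dispatch(idx, handled_above):
        key = KEYS[idx]
        if key == "Python":
            # everything above this key (here: autograd) has already run
            return {"handled_above": handled_above,
                    "result": dispatch(idx + 1, handled_above)}
        if key == "CPU":
            return f"{op}-kernel"
        # keys above Python do their work, then redispatch downward
        return dispatch(idx + 1, handled_above + [key])
    return dispatch(0, [])

out = run("exp")
```

Here `out["handled_above"]` records that AutogradCPU already ran before the Python-key handler was reached.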