Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions aten/src/ATen/core/PythonFallbackKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,27 @@
#include <ATen/core/dispatch/Dispatcher.h>
#include <ATen/core/PythonModeTLS.h>

#include <stack>

namespace {

// TLS saving the state of the include/exclude sets on entry to the dispatcher
// This is set in the pythonTLSSnapshot fallback and used by the Python fallback.
thread_local c10::optional<c10::impl::LocalDispatchKeySet> tls_on_entry;
thread_local std::stack<c10::impl::LocalDispatchKeySet> tls_on_entry;

struct C10_API StashTLSStateGuard {
public:
StashTLSStateGuard(const c10::impl::LocalDispatchKeySet& key_set) {
tls_on_entry.push(key_set);
}
~StashTLSStateGuard() {
tls_on_entry.pop();
}
};

void pythonFallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.value());
TORCH_INTERNAL_ASSERT(tls_on_entry.size() > 0);
c10::impl::ForceDispatchKeyGuard guard(tls_on_entry.top());

// If Python Mode is active, use its PyInterpreter for dispatch
const auto& maybe_python_mode_state = at::impl::PythonModeTLS::get_state();
Expand Down Expand Up @@ -54,11 +66,9 @@ void pythonTLSSnapshotFallback(const c10::OperatorHandle& op, c10::DispatchKeySe
// A CompositeImplicitAutograd function may have been called just before this and so the tls here were never cleared
// This is also why we don't need an RAII to ensure the tls is reset when exceptions happen

tls_on_entry = c10::impl::tls_local_dispatch_key_set();
StashTLSStateGuard guard(c10::impl::tls_local_dispatch_key_set());

op.redispatchBoxed(dispatch_keys & c10::DispatchKeySet(c10::DispatchKeySet::FULL_AFTER, c10::DispatchKey::PythonTLSSnapshot), stack);

tls_on_entry = c10::nullopt;
}


Expand Down
40 changes: 40 additions & 0 deletions test/test_python_dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,46 @@ def test_autograd_in_attr(self):
self.assertIsNone(t.grad)
self.assertIsNotNone(t.elem.grad)

def test_multiple_ops_subclass(self):
# This is a Direct Subclass, don't do that!
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a "Direct Subclass" and who shouldn't do that?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a direct subclass in the sense that the Tensor that we send to the backend and the subclass are the same.
It is ok to do it, but not composable and tricky in general. Since a lot of people use these tests as example, I prefered to warn here that this is not a good example to copy.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

^ you should probably put that into the comment in a code, as a user I would be more even confused haha

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

related question: does the bug with the non-stack-based version only show up when you do a "Direct Subclass" (and not show up when you do a wrapper subclass)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The direct subclass is the simplest repro I could come up with.
@Chillee managed to hit the same problem with some of the functorch tests. In that case, it was while going down a clamp_min_ op, an extra clone was done before executing the function itself. So I guess functionalization was involved? But I don't know how to enable that easily in a small repro. Hence the use of direct subclass and conjugate fallback.

class MySubclass(torch.Tensor):
@staticmethod
def __new__(cls, elem):
r = torch.Tensor._make_subclass(cls, elem)
return r

__torch_function__ = torch._C._disabled_torch_function_impl

@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
with no_dispatch():
return func(*args, **kwargs)

x = MySubclass(torch.rand(2, 2, dtype=torch.complex64))
y = x.conj()
# Details of the bug that this tests for:
# Here, y dispatch keys are: {PythonTLSSnapshot, AutogradCPU, Conjugate, Python, CPU}
# There are a few calls to the dispatcher that are going to happen here:
# - call_exp: User calling exp on y
# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - AutogradCPU: no input requires grad, so does nothing and redispatch
# - Conjugate: no special implementation for exp: use the fallback that
# first clone the Tensor (to materialize the conj) then redispatch
# - call_clone: conjugate fallback calling clone on y
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potentially a stupid question, but, why does call_clone go to PythonTLSSnapshot? I would have expected the redispatch from Conjugate to go into the Python dispatch key?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not a redispatch, this is a call. Namely:

auto resolved_tensor = at::clone(tensor);

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I think that's because the conjugate fallback ends up calling at::clone(), which enters the dispatcher through Dispatcher::call(). PythonTLSSnapshot is set up to always run first every time you call Dispatcher::call())

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the PythonTLSSnapshot key in the local exclude key set at this point? So I would expect the at::clone call to skip past PythonTLSSnapshot, AutogradCPU, and Conjugate

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I see, you don't use ExcludeDispatchKeyGuards to do the redispatch...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

redispatch here means actually using redispatch vs clone.
The ExcludeDispatchKeyGuards is not used anymore except by autograd.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh does this mean on the at::clone call, AutogradCPU doesn't get hit? AutogradCPU is in the local exclude set (because VariableType::blah both uses an exclude key guard AND at::redispatch)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is correct! My bad!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the detailed comment and explanations, I understand what's going on now. I agree that this test is testing the above code now.

On an orthogonal note, I'm a bit concerned that AutogradCPU doesn't get hit, but maybe that's not a problem because that's how things work even without the Python keys

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's what happens when your key is "below" autograd: autogad is already handled above so you don't need to care about it here :)

# - PythonTLSSnapshot: records the TLS on entry and redispatch
# - (AutogradCPU: skipped as autograd added itself to the exclude set above)
# - Conjugate: special implementation for clone: just skip this key
# - Python: Reset the TLS based on the snapshot above and call the user implementation (this
# actually calls into the dispatcher again but since we disable both our keys
# before, not detailed here)
# - exit Python: restore the TLS and exit
# - exit Conjugate: nothing was inplace so just exit
# - exit PythonTLSSnapshot: done with this call, reset the saved TLS to empty
# - Python: Reset the TLS again based on the snapshot. <- this used to fail
# - More steps....
y.exp()



if __name__ == '__main__':
run_tests()