Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[dispatcher] avoid autograd fixup step on non-backend keys #46135

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 14 additions & 11 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA .
// TODO: we can remove (2.4) and (4) after TypeDefault registrations are moved from catchAll to Math
// so that Math can populate to Autograd backend keys before fallback kernels.

// 1. Operator registration
if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
return {*direct_registration.value(), "kernel"};
Expand All @@ -231,7 +231,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// non backend keys (e.g AutogradXXX, Batched etc) due to (2.1).
bool has_backend_kernel =
hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend));

// 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math
// when there's no direct registration to its corresponding backend key or DefaultBackend.
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
Expand Down Expand Up @@ -280,7 +280,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys).
// dispatch keys (e.g. runtime keys and their associated autograd keys,
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
Expand All @@ -289,9 +290,9 @@ void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher,
}

// synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel()
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
Expand All @@ -305,16 +306,18 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
}
// Note [Refresh Runtime Autograd entries in dispatchTable_]
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
if (c10::isBackendDispatchKey(dispatch_key)) {
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
}
}

// does a complete update of the dispatch table, synchronizing all
// does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism, than with multiple variations that assume different
// invariants.
Expand Down
7 changes: 7 additions & 0 deletions c10/core/DispatchKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
return str << toString(rhs);
}

// for a given backend key, return the associated autograd key.
// for non-backend keys, return AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
//
DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
switch (t) {
case DispatchKey::CPU:
Expand Down
6 changes: 6 additions & 0 deletions c10/core/DispatchKeySet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | Disp
DispatchKey::PrivateUse3,
});

bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t);
}

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
// Alias key DispatchKey::Math maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
Expand All @@ -31,6 +35,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
}
}

// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
// for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
switch (t) {
case DispatchKey::AutogradCPU:
Expand Down
3 changes: 3 additions & 0 deletions c10/core/DispatchKeySet.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
DispatchKey::SparseHIP,
});

// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

Expand Down