Skip to content

Commit

Permalink
[dispatcher] avoid autograd fixup step on non-backend keys
Browse files Browse the repository at this point in the history
ghstack-source-id: e5ba3f793a9f885b0df82c7d8259eb2dc089a269
Pull Request resolved: #46135
  • Loading branch information
bhosmer committed Oct 13, 2020
1 parent dac6807 commit 571e8d5
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 11 deletions.
25 changes: 14 additions & 11 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA .
// TODO: we can remove (2.4) and (4) after TypeDefault registrations are moved from catchAll to Math
// so that Math can populate to Autograd backend keys before fallback kernels.

// 1. Operator registration
if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
return {*direct_registration.value(), "kernel"};
Expand All @@ -231,7 +231,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// non backend keys (e.g AutogradXXX, Batched etc) due to (2.1).
bool has_backend_kernel =
hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend));

// 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math
// when there's no direct registration to its corresponding backend key or DefaultBackend.
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
Expand Down Expand Up @@ -280,7 +280,8 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys).
// dispatch keys (e.g. runtime keys and their associated autograd keys,
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
Expand All @@ -289,9 +290,9 @@ void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher,
}

// synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel()
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
Expand All @@ -305,16 +306,18 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
}
// Note [Refresh Runtime Autograd entries in dispatchTable_]
// Registering to backend key might affect computed entry at its Autograd backend key due to (2.1) & (2.3).
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
if (c10::isBackendDispatchKey(dispatch_key)) {
DispatchKey autograd_key = getAutogradKeyFromBackend(dispatch_key);
updateDispatchTableEntry_(dispatcher, autograd_key);
}
}

// does a complete update of the dispatch table, synchronizing all
// does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism, than with multiple variations that assume different
// invariants.
Expand Down
7 changes: 7 additions & 0 deletions c10/core/DispatchKey.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,13 @@ std::ostream& operator<<(std::ostream& str, DispatchKey rhs) {
return str << toString(rhs);
}

// for a given backend key, return the associated autograd key.
// for non-backend keys, return AutogradOther as a default.
// Note: it's convenient and fast to return a default here rather than (say)
// returning an optional<DispatchKey>, or throwing. But it makes callers
// responsible for either a) enforcing the invariant that only backend keys
// be passed as arguments, or b) interpreting our return value carefully.
//
DispatchKey getAutogradKeyFromBackend(DispatchKey t) {
switch (t) {
case DispatchKey::CPU:
Expand Down
6 changes: 6 additions & 0 deletions c10/core/DispatchKeySet.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends | Disp
DispatchKey::PrivateUse3,
});

bool isBackendDispatchKey(DispatchKey t) {
return t != DispatchKey::Undefined && backend_dispatch_keyset.has(t);
}

// math_dispatch_keyset contains all keys in backend_dispatch_keyset and autograd_dispatch_keyset
// Alias key DispatchKey::Math maps to math_dispatch_keyset.
constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset | autograd_dispatch_keyset;
Expand All @@ -31,6 +35,8 @@ DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
}
}

// for a given autograd key, return the (guaranteed nonempty) set of associated backend keys.
// for a non-autograd key, return the empty keyset.
DispatchKeySet getBackendKeySetFromAutograd(DispatchKey t) {
switch (t) {
case DispatchKey::AutogradCPU:
Expand Down
3 changes: 3 additions & 0 deletions c10/core/DispatchKeySet.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,9 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet({
DispatchKey::SparseHIP,
});

// true if t is a backend dispatch key
C10_API bool isBackendDispatchKey(DispatchKey t);

// Resolve alias dispatch key to DispatchKeySet if applicable
C10_API DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t);

Expand Down

0 comments on commit 571e8d5

Please sign in to comment.