a few more comments on dispatch key computation methods
ghstack-source-id: e31d96a3e80cfa0b9dca64bc0e158ec2d16f4f08
Pull Request resolved: #46128
bhosmer committed Oct 10, 2020
1 parent e33d455 commit 5807657
Showing 1 changed file with 34 additions and 15 deletions.
49 changes: 34 additions & 15 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -175,7 +175,6 @@ c10::optional<const AnnotatedKernel*> OperatorEntry::getKernelForDispatchKey(Dis
}

std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
// [Note] DispatchTable computation
// dispatchTable contains entries for runtime dispatch keys.
// For any dispatch key, it'll pick a kernel using the following order:
@@ -215,7 +214,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// TODO: Update alias key precedence after we add new alias keys such as AutogradDispatchCPUOrCUDA.
// TODO: we can remove (2.4) and (4) after TypeDefault registrations are moved from catchAll to Math,
// so that Math kernels can populate the Autograd backend keys before fallback kernels are considered.

// 1. Operator registration
if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
return {*direct_registration.value(), "kernel"};
@@ -232,6 +231,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// non-backend keys (e.g. AutogradXXX, Batched, etc.) due to (2.1).
bool has_backend_kernel =
hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend));
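// (getBackendKeySetFromAutograd maps an autograd key back to its backend key(s),
// e.g. AutogradCPU -> CPU; DefaultBackend is added so that a DefaultBackend
// registration also counts as a backend kernel here.)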

// 2.2. Use the Math kernel if available. For autograd keys, we only use the Math kernel
// when there's no direct registration to the corresponding backend key or to DefaultBackend.
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's a registration
@@ -252,38 +252,47 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
return {*autograd_registration.value(), "autograd kernel"};
}
}

// 2.4. For autograd backend keys, we use the kernel from catchAll if there's no direct
// registration to the backend key or DefaultBackend. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.3 entirely.
if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)
&& !has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
// 2.4. For autograd dispatch keys, we use the kernel from catchAll if there's no direct
// registration to the backend key or DefaultBackend. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.4 entirely.
if (!has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}
}

// 3. Backend fallback
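// (A backend fallback is a kernel registered with the dispatcher for an entire
// dispatch key rather than for a single operator; it applies to every operator
// that has no more specific registration for that key.)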
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
}

// 4. Catch all
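// (catchAllKernel_ holds kernels that were registered without any dispatch key.)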
} else if (!catchAllKernel_.empty()) {
if (!catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}

// 5. Default to error
} else {
return {missingKernel_, "missing"};
}
return {missingKernel_, "missing"};
}
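
To make the precedence above easier to follow end to end, here is a minimal, self-contained C++17 sketch of the same fall-through order: direct kernel, then alias-key kernels, then backend fallback, then catch-all, then a "missing" sentinel. The types and flags are hypothetical simplifications (not the real OperatorEntry/Dispatcher API), and the DefaultBackend and AutogradOther special cases are omitted.

// Hypothetical, simplified model of the precedence order documented above.
// Not the real OperatorEntry/Dispatcher types.
#include <iostream>
#include <optional>
#include <string>
#include <utility>

struct ToyOperatorEntry {
  // Each field stands in for "a kernel is registered in this slot".
  std::optional<std::string> direct_kernel;     // (1) registration to the queried key
  std::optional<std::string> math_kernel;       // (2.2) Math alias key
  std::optional<std::string> autograd_kernel;   // (2.3) Autograd alias key
  std::optional<std::string> backend_fallback;  // (3) dispatcher-wide fallback
  std::optional<std::string> catch_all;         // (4) catch-all registration

  // Returns {kernel, reason}, loosely mirroring the {AnnotatedKernel&, debug}
  // pair returned by computeDispatchTableEntryWithDebug.
  std::pair<std::string, std::string> compute(bool in_math_alias,
                                              bool in_autograd_alias,
                                              bool has_backend_kernel) const {
    // 1. Direct registration wins.
    if (direct_kernel) return {*direct_kernel, "kernel"};
    // 2.2. Math kernel, unless an autograd key is shadowed by a backend kernel.
    if (in_math_alias && math_kernel && !(in_autograd_alias && has_backend_kernel)) {
      return {*math_kernel, "math kernel"};
    }
    // 2.3. Autograd kernel for autograd keys.
    if (in_autograd_alias && autograd_kernel) {
      return {*autograd_kernel, "autograd kernel"};
    }
    // 2.4. Catch-all stands in for Math for autograd keys (see the TODOs above).
    if (in_autograd_alias && !has_backend_kernel && catch_all) {
      return {*catch_all, "catch all"};
    }
    // 3. Backend fallback registered for the whole dispatch key.
    if (backend_fallback) return {*backend_fallback, "backend fallback"};
    // 4. Catch-all.
    if (catch_all) return {*catch_all, "catch all"};
    // 5. Nothing registered.
    return {"<none>", "missing"};
  }
};

int main() {
  ToyOperatorEntry e;
  e.math_kernel = "math_impl";
  // Querying a hypothetical AutogradCPU-like key with no backend kernel registered:
  auto [kernel, why] = e.compute(/*in_math_alias=*/true,
                                 /*in_autograd_alias=*/true,
                                 /*has_backend_kernel=*/false);
  std::cout << kernel << " (" << why << ")\n";  // prints: math_impl (math kernel)
}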

// Synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// Note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys).
// This function should be considered a private helper for updateDispatchTable_().
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

// Synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel().
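// (For example, an update for DispatchKey::CPU also refreshes DispatchKey::AutogradCPU,
// via the associated autograd_key update below.)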
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
// See Note [Undefined in dispatchTable_]
@@ -300,6 +309,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
updateDispatchTableEntry_(dispatcher, autograd_key);
}

// Does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate: the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism than with multiple variations that assume different
// invariants.
//
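// In practice this amounts to calling updateDispatchTable_() once per runtime
// dispatch key, Undefined included (see Note [Undefined in dispatchTable_] below).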
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
// Note [Undefined in dispatchTable_]
// (1) it gives people place to specify functionality that should run when there are no dispatch keys,
