Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

a few more comments on dispatch key computation methods #46128

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
49 changes: 34 additions & 15 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
Expand Up @@ -175,7 +175,6 @@ c10::optional<const AnnotatedKernel*> OperatorEntry::getKernelForDispatchKey(Dis
}

std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
// [Note] DispatchTable computation
// dispatchTable contains entries for runtime dispatch keys.
// For any dispatch key, it'll pick a kernel using the following order:
Expand Down Expand Up @@ -215,7 +214,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// TODO: Update alias key precedence after we add new alias keys AutogradDispatchCPUOrCUDA .
// TODO: we can remove (2.4) and (4) after TypeDefault registrations are moved from catchAll to Math
// so that Math can populate to Autograd backend keys before fallback kernels.

// 1. Operator registration
if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
return {*direct_registration.value(), "kernel"};
Expand All @@ -232,6 +231,7 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// non backend keys (e.g AutogradXXX, Batched etc) due to (2.1).
bool has_backend_kernel =
hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend));

// 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math
// when there's no direct registration to its corresponding backend key or DefaultBackend.
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
Expand All @@ -252,38 +252,47 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
return {*autograd_registration.value(), "autograd kernel"};
}
}

// 2.4. For autograd backend keys, we use kernel from catchAll if there's no direct
// registration to the backend key or DefaultBackend. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.3 entirely.
if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)
&& !has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
// 2.4. For autograd dispatch keys, we use kernel from catchAll if there's no direct
// registration to the backend key or DefaultBackend. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.4 entirely.
if (!has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}
}

// 3. Backend fallback
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
}

// 4. Catch all
} else if (!catchAllKernel_.empty()) {
if (!catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}

// 5. Default to error
} else {
return {missingKernel_, "missing"};
}
return {missingKernel_, "missing"};
}

// synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

// synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel()
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
// See Note [Undefined in dispatchTable_]
Expand All @@ -300,6 +309,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
updateDispatchTableEntry_(dispatcher, autograd_key);
}

// does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism, than with multiple variations that assume different
// invariants.
//
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
// Note [Undefined in dispatchTable_]
// (1) it gives people place to specify functionality that should run when there are no dispatch keys,
Expand Down