Update on "Implement ravel"
ejguan committed Oct 10, 2020
2 parents d8f202b + 4c87d33 commit dcf8140
Showing 57 changed files with 1,421 additions and 417 deletions.
99 changes: 69 additions & 30 deletions aten/src/ATen/core/dispatch/OperatorEntry.cpp
@@ -156,7 +156,7 @@ const KernelFunction& OperatorEntry::computeDispatchTableEntry(const c10::Dispat
return computeDispatchTableEntryWithDebug(dispatcher, dispatch_key).first.kernel;
}

bool OperatorEntry::hasKernelForDispatchKeySet(DispatchKeySet ks) const {
bool OperatorEntry::hasKernelForAnyDispatchKey(DispatchKeySet ks) const {
TORCH_INTERNAL_ASSERT(kernels_.find(DispatchKey::Undefined) == kernels_.end());
for (auto& kv : kernels_) {
if (ks.has(kv.first)) return true;
@@ -175,95 +175,124 @@ c10::optional<const AnnotatedKernel*> OperatorEntry::getKernelForDispatchKey(Dis
}

std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTableEntryWithDebug(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) const {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
// [Note] DispatchTable computation
// dispatchTable contains entries for runtime dispatch keys.
// For any dispatch key, it'll pick a kernel using the following order:
// (1) Use kernel if it's directly registered to this key
// (2) Handle runtime keys that have kernels available from alias keys
// (2.1) Use kernel from DispatchKey::Math if available.
// (2.1) Use kernel from DispatchKey::DefaultBackend if available.
// This is used to register a kernel that works for all backends in inference. But it requires
// separate registration for Autograd keys to support training.
// (2.2) Use kernel from DispatchKey::Math if available.
// For autograd keys, we only use kernel from Math when there's no direct registration
// to its corresponding backend key.
// to its corresponding backend key or DefaultBackend. See Note [DefaultBackend and Math].
// For AutogradOther, we eagerly return ambiguousAutogradOtherKernel_ if there's registration to any of
// its backends and ask the backend extender to request a dedicated Autograd key for the backend.
// See Note [Ambiguity in AutogradOther kernel] for more details.
// (2.2) Use kernel from DispatchKey::Autograd if available
// (2.3) Special logic to handle catchAll for Autograd keys
// A DefaultBackend kernel prevents a Math kernel from being used for Autograd keys, but it doesn't
// cause confusion for AutogradOther. It's pretty straightforward to use Autograd (if available)
// in this case.
// (2.3) Use kernel from DispatchKey::Autograd if available
// (2.4) Special logic to handle catchAll for Autograd keys
// For autograd backend keys, we use kernel from alias Math key (catchAll will be moved to Math)
// if there's no direct registration to the backend key.
// Tensor factory functions used to have no registration to Autograd key but only to catchAll.
// In the past we called directly into backends (filled with catchAll) after BackendSelect.
// Now that we first call Autograd backend keys after BackendSelect, we should fill those
// with catchAll as well.
// The implementation of (2.1) & (2.3) relies on the invariant that for a given backend,
// The implementation of (2.2) & (2.4) relies on the invariant that for a given backend,
// `computeDispatchTableEntryWithDebug()` will be called for that backend's autograd key after the
// backend key. See Note [Refresh Runtime Autograd entries in dispatchTable_]
// (3) Use fallthrough kernel that are registered as fallback.
// (4) Use catchAll kernel if available
// Alias Key Precedence:
// Math > Autograd
// DefaultBackend > Math > Autograd
// Note [DefaultBackend and Math]
// When there are registrations to both DefaultBackend & Math & Autograd, from (2.2) we know DefaultBackend
// and Autograd kernels will be picked up and Math is overridden.
// This is fine and in practice DefaultBackend and Math shouldn't co-exist for an op.
// TODO: Update alias key precedence after we add new alias keys, e.g. AutogradDispatchCPUOrCUDA.
// TODO: we can remove (2.3) and (4) after TypeDefault registrations are moved from catchAll to Math
// TODO: we can remove (2.4) and (4) after TypeDefault registrations are moved from catchAll to Math
// so that Math can populate to Autograd backend keys before fallback kernels.

// 1. Operator registration
if (auto direct_registration = getKernelForDispatchKey(dispatch_key)) {
return {*direct_registration.value(), "kernel"};
}

bool is_autograd_key_with_backend_kernel =
hasKernelForDispatchKeySet(getBackendKeySetFromAutograd(dispatch_key));
// 2.1. Use Math kernel if available. For autograd keys, we only use kernel from Math
// when there's no direct registration to its corresponding backend key.
// 2.1 Use DefaultBackend kernel if available.
if (isIncludedInAlias(dispatch_key, DispatchKey::DefaultBackend)) {
if (auto default_backend_registration = getKernelForDispatchKey(DispatchKey::DefaultBackend)) {
return {*default_backend_registration.value(), "default backend kernel"};
}
}

// Note when there's direct registration to DefaultBackend, this code path will only be hit by
// non-backend keys (e.g. AutogradXXX, Batched, etc.) due to (2.1).
bool has_backend_kernel =
hasKernelForAnyDispatchKey(getBackendKeySetFromAutograd(dispatch_key).add(DispatchKey::DefaultBackend));

// 2.2. Use Math kernel if available. For autograd keys, we only use kernel from Math
// when there's no direct registration to its corresponding backend key or DefaultBackend.
// For AutogradOther, we return ambiguousAutogradOtherKernel_ if there's registration
// to any of its backends.
if (isIncludedInAlias(dispatch_key, DispatchKey::Math)) {
if (auto math_registration = getKernelForDispatchKey(DispatchKey::Math)) {
if (dispatch_key == DispatchKey::AutogradOther && is_autograd_key_with_backend_kernel) {
if (dispatch_key == DispatchKey::AutogradOther
&& hasKernelForAnyDispatchKey(c10::autogradother_backends)) {
return {ambiguousAutogradOtherKernel_, "ambiguous autogradother"};
} else if (!is_autograd_key_with_backend_kernel) {
} else if (!has_backend_kernel) {
return {*math_registration.value(), "math kernel"};
}
}
}

// 2.2. For autograd backend keys, use kernel from DispatchKey::Autograd if available
// 2.3. For autograd backend keys, use kernel from DispatchKey::Autograd if available
if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)) {
if (auto autograd_registration = getKernelForDispatchKey(DispatchKey::Autograd)) {
return {*autograd_registration.value(), "autograd kernel"};
}
}

// 2.3. For autograd backend keys, we use kernel from catchAll if there's no direct
// registration to the backend key. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.3 entirely.
if (isIncludedInAlias(dispatch_key, DispatchKey::Autograd)
&& !is_autograd_key_with_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
// 2.4. For autograd dispatch keys, we use kernel from catchAll if there's no direct
// registration to the backend key or DefaultBackend. Once CatchAll is moved to Math, this should
// fit 2.1 and we can remove 2.4 entirely.
if (!has_backend_kernel && !catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}
}

// 3. Backend fallback
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
if (dispatcher.backendFallbackKernels_[dispatch_ix].kernel.isValid()) {
return {dispatcher.backendFallbackKernels_[dispatch_ix], "backend fallback"};
}

// 4. Catch all
} else if (!catchAllKernel_.empty()) {
if (!catchAllKernel_.empty()) {
TORCH_INTERNAL_ASSERT(catchAllKernel_.front().kernel.isValid());
return {catchAllKernel_.front(), "catch all"};
}

// 5. Default to error
} else {
return {missingKernel_, "missing"};
}
return {missingKernel_, "missing"};
}

// synchronizes the dispatch table entry for a given dispatch key
// with the current state of kernel registrations in the dispatcher.
// note that this is not a complete update, due to relationships between
// dispatch keys (e.g. runtime keys and their associated autograd keys).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

// synchronizes the dispatch table entries for a given dispatch key *and its
// associated keys* with the current state of kernel registrations in the
// dispatcher.
// After a kernel has been registered to a dispatch key, a call to this
// function will synchronize the dispatcher state. See e.g. registerKernel()
void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
// Handle Undefined separately since it isn't a runtime key but we have an entry in dispatchTable_.
// See Note [Undefined in dispatchTable_]
@@ -280,6 +309,16 @@ void OperatorEntry::updateDispatchTable_(const c10::Dispatcher& dispatcher, Disp
updateDispatchTableEntry_(dispatcher, autograd_key);
}

// does a complete update of the dispatch table, synchronizing all
// runtime dispatch keys with the current state of kernel registrations
// in the dispatcher.
// Note that we use updateDispatchTable_() to perform our per-key updating,
// even though that function is equipped to handle out-of-order updates and
// alias key updates, neither of which we send it. This is deliberate - the
// current design is more tractable with all updates funneled through a single
// per-key update mechanism, than with multiple variations that assume different
// invariants.
//
void OperatorEntry::updateDispatchTableFull_(const c10::Dispatcher& dispatcher) {
// Note [Undefined in dispatchTable_]
// (1) it gives people place to specify functionality that should run when there are no dispatch keys,
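Editor's note: a minimal, self-contained C++ sketch of the alias-key precedence the updated computeDispatchTableEntryWithDebug follows (direct backend kernel > DefaultBackend > Math > missing). This is a toy model for illustration only, not dispatcher code; autograd keys, catchAll, and backend fallbacks are omitted.

#include <iostream>
#include <map>
#include <string>

// Toy dispatch keys; the real dispatcher uses c10::DispatchKey.
enum class Key { CPU, CUDA, DefaultBackend, Math };

// Pick a kernel for a backend key following the precedence described in the note above.
std::string pick(const std::map<Key, std::string>& kernels, Key backend) {
  if (kernels.count(backend)) return kernels.at(backend);                          // (1) direct registration
  if (kernels.count(Key::DefaultBackend)) return kernels.at(Key::DefaultBackend);  // (2.1) DefaultBackend
  if (kernels.count(Key::Math)) return kernels.at(Key::Math);                      // (2.2) Math, only if no backend kernel
  return "missing";                                                                 // (5) no kernel found
}

int main() {
  std::map<Key, std::string> kernels = {{Key::DefaultBackend, "default"}, {Key::CPU, "cpu"}};
  std::cout << pick(kernels, Key::CPU) << "\n";   // prints "cpu": backend registration wins
  std::cout << pick(kernels, Key::CUDA) << "\n";  // prints "default": falls back to DefaultBackend
}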
2 changes: 1 addition & 1 deletion aten/src/ATen/core/dispatch/OperatorEntry.h
@@ -254,7 +254,7 @@ class CAFFE2_API OperatorEntry final {
void updateDispatchTableFull_(const c10::Dispatcher& dispatcher);

// Returns true if kernel_ has entry for any key in ks.
bool hasKernelForDispatchKeySet(DispatchKeySet ks) const;
bool hasKernelForAnyDispatchKey(DispatchKeySet ks) const;
// Retrieves a pointer to AnnotatedKernel at kernels_.at(dispatch_key).front().
c10::optional<const AnnotatedKernel*> getKernelForDispatchKey(DispatchKey dispatch_key) const;
};
146 changes: 146 additions & 0 deletions aten/src/ATen/core/op_registration/op_registration_test.cpp
@@ -1802,6 +1802,152 @@ TEST(NewOperatorRegistrationTest, BackendOverridesMathKernel) {
}
}

TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendKernel) {
bool called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { called = true; return x; }));

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

{
ASSERT_FALSE(called);
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
ASSERT_TRUE(called);
}

{
called = false;
// AutogradCPU is fallthrough, calls CPU kernel
callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
ASSERT_TRUE(called);
}

{
called = false;
callOp(*op, dummyTensor(c10::DispatchKey::XLA));
ASSERT_TRUE(called);
}

{
called = false;
// AutogradXLA is fallthrough, calls XLA kernel
callOp(*op, dummyTensor(c10::DispatchKey::XLA, /*requires_grad=*/true));
ASSERT_TRUE(called);
}

{
called = false;
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
ASSERT_TRUE(called);
}

{
called = false;
// AutogradOther is fallthrough, calls SparseCPU kernel
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
ASSERT_TRUE(called);
}
}

TEST(NewOperatorRegistrationTest, dispatchWithDefaultBackendAndMathKernel) {
bool backend_called = false;
bool math_called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { backend_called = true; return x; }));
m.impl("fn", c10::DispatchKey::Math, [&](const Tensor& x) { math_called = true; return x; });

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

{
backend_called = math_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
ASSERT_TRUE(backend_called);
ASSERT_FALSE(math_called);
}

{
backend_called = math_called = false;
// AutogradCPU is fallthrough, calls CPU kernel
callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
ASSERT_FALSE(math_called);
ASSERT_TRUE(backend_called);
}

{
backend_called = math_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::XLA));
ASSERT_TRUE(backend_called);
ASSERT_FALSE(math_called);
}

{
backend_called = math_called = false;
// AutogradXLA is fallthrough, calls XLA kernel
callOp(*op, dummyTensor(c10::DispatchKey::XLA, /*requires_grad=*/true));
ASSERT_FALSE(math_called);
ASSERT_TRUE(backend_called);
}

{
backend_called = math_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU));
ASSERT_TRUE(backend_called);
ASSERT_FALSE(math_called);
}

{
backend_called = math_called = false;
// AutogradOther is fallthrough, calls SparseCPU kernel
callOp(*op, dummyTensor(c10::DispatchKey::SparseCPU, /*requires_grad=*/true));
ASSERT_FALSE(math_called);
ASSERT_TRUE(backend_called);
}
}

TEST(NewOperatorRegistrationTest, BackendOverridesDefaultBackendKernel) {
bool default_called = false;
bool backend_called = false;
auto m = MAKE_TORCH_LIBRARY(test);
m.def("fn", torch::dispatch(c10::DispatchKey::DefaultBackend, [&](const Tensor& x) { default_called = true; return x; }));
m.impl("fn", c10::DispatchKey::CPU, [&](const Tensor& x) { backend_called = true; return x; });

auto op = Dispatcher::singleton().findSchema({"test::fn", ""});
ASSERT_TRUE(op.has_value());

{
default_called = backend_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::CPU));
ASSERT_TRUE(backend_called);
ASSERT_FALSE(default_called);
}

{
default_called = backend_called = false;
// AutogradCPU is fallthrough, calls CPU kernel
callOp(*op, dummyTensor(c10::DispatchKey::CPU, /*requires_grad=*/true));
ASSERT_TRUE(backend_called);
ASSERT_FALSE(default_called);
}

{
default_called = backend_called = false;
callOp(*op, dummyTensor(c10::DispatchKey::CUDA));
ASSERT_TRUE(default_called);
ASSERT_FALSE(backend_called);
}

{
default_called = backend_called = false;
// AutogradCUDA is fallthrough, calls CUDA kernel
callOp(*op, dummyTensor(c10::DispatchKey::CUDA, /*requires_grad=*/true));
ASSERT_TRUE(default_called);
ASSERT_FALSE(backend_called);
}
}


TEST(NewOperatorRegistrationTest, dispatch) {
bool cpu_called = false;
bool cuda_called = false;
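Editor's note: the tests above register kernels directly on a Library handle via MAKE_TORCH_LIBRARY, torch::dispatch, and m.impl. In regular library code the same registrations are usually expressed with the TORCH_LIBRARY / TORCH_LIBRARY_IMPL macros, roughly as sketched below. The operator name my::fn and both kernels are hypothetical, and passing DefaultBackend to TORCH_LIBRARY_IMPL is assumed to work the same way as other dispatch keys.

#include <torch/library.h>
#include <ATen/ATen.h>

// Hypothetical kernels: one generic inference kernel, one CPU-specific override.
at::Tensor fn_default(const at::Tensor& x) { return x; }      // used by every backend without its own kernel
at::Tensor fn_cpu(const at::Tensor& x) { return x.clone(); }  // overrides DefaultBackend on CPU

TORCH_LIBRARY(my, m) {
  m.def("fn(Tensor self) -> Tensor");
}
TORCH_LIBRARY_IMPL(my, DefaultBackend, m) {
  m.impl("fn", fn_default);
}
TORCH_LIBRARY_IMPL(my, CPU, m) {
  m.impl("fn", fn_cpu);
}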
4 changes: 4 additions & 0 deletions aten/src/ATen/native/native_functions.yaml
@@ -284,9 +284,13 @@
- func: sgn(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
DefaultBackend: sgn

- func: sgn_(Tensor(a!) self) -> Tensor(a!)
variants: method
dispatch:
DefaultBackend: sgn_

- func: sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
5 changes: 3 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qhardsigmoid.cpp
@@ -43,9 +43,10 @@ Tensor qnnpack_hardsigmoid(Tensor input) {
"failed to create QNNPACK Hardsigmoid operator");
Tensor qy = at::_empty_affine_quantized(
input_contig.sizes(),
input_contig.options(),
at::device(kCPU).dtype(input_contig.dtype()),
o_scale,
o_zero_point);
o_zero_point,
input_contig.suggest_memory_format());

const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_hardsigmoid_nc_q8(
hardsigmoid_op,
5 changes: 3 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qsigmoid.cpp
@@ -48,9 +48,10 @@ Tensor qnnpack_sigmoid(
"failed to create QNNPACK sigmoid operator");
qy = at::_empty_affine_quantized(
input_contig.sizes(),
input.options(),
at::device(kCPU).dtype(input_contig.dtype()),
output_scale,
output_zero_point);
output_zero_point,
input_contig.suggest_memory_format());

const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_sigmoid_nc_q8(
sigmoid_op,
5 changes: 3 additions & 2 deletions aten/src/ATen/native/quantized/cpu/qtanh.cpp
@@ -50,9 +50,10 @@ Tensor qnnpack_tanh(Tensor input) {
"failed to create QNNPACK TanH operator");
qy = at::_empty_affine_quantized(
input_contig.sizes(),
input.options(),
at::device(kCPU).dtype(input_contig.dtype()),
output_scale,
output_zero_point);
output_zero_point,
input_contig.suggest_memory_format());

const pytorch_qnnp_status setupStatus = pytorch_qnnp_setup_tanh_nc_q8(
tanh_op,
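Editor's note: the three quantized kernel changes above share one idea: allocate the output on CPU with the input's dtype and its suggested memory format, so a channels-last input yields a channels-last output. A small standalone sketch of what suggest_memory_format() reports (assuming the libtorch C++ API; not part of this diff):

#include <torch/torch.h>
#include <iostream>

int main() {
  // A 4-D tensor laid out channels-last; suggest_memory_format() reports that layout,
  // which the quantized kernels above now forward to at::_empty_affine_quantized.
  auto x = torch::rand({1, 3, 8, 8}).contiguous(torch::MemoryFormat::ChannelsLast);
  std::cout << (x.suggest_memory_format() == torch::MemoryFormat::ChannelsLast) << "\n";  // prints 1
}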
