diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index 0b99b11430956..b9f8f5f6b14c6 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -321,6 +321,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL_CUDA(__VA_ARGS__, lower_precision_fp)
 
   AT_FORALL_LOWER_PRECISION_FP(_KERNEL_CUDA_LOW_PRECISION_FP)
+#undef _KERNEL_CUDA_LOW_PRECISION_FP
   KERNEL_CUDA(cudnn_convolution, lower_precision_fp)
   KERNEL_CUDA(cudnn_convolution_transpose, lower_precision_fp)
 
@@ -328,12 +329,14 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
 #define _KERNEL_CUDA_FP32(...) KERNEL_CUDA(__VA_ARGS__, fp32)
 
   AT_FORALL_FP32(_KERNEL_CUDA_FP32)
+#undef _KERNEL_CUDA_FP32
 
   // fp32_set_opt_dtype
 #define _KERNEL_CUDA_FP32_SET_OPT_DTYPE(...) \
   KERNEL_CUDA(__VA_ARGS__, fp32_set_opt_dtype)
 
   AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_CUDA_FP32_SET_OPT_DTYPE)
+#undef _KERNEL_CUDA_FP32_SET_OPT_DTYPE
   // commenting these out because they accept an explicit (not-optional) dtype, and we shouldn't try to flip that even
   // when autocasting.
   // KERNEL_CUDA(norm, ScalarOpt_dtype, fp32_set_opt_dtype)
@@ -350,9 +353,9 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
 #define _KERNEL_CUDA_PROMOTE(...) KERNEL_CUDA(__VA_ARGS__, promote)
 
   AT_FORALL_PROMOTE(_KERNEL_CUDA_PROMOTE)
+#undef _KERNEL_CUDA_PROMOTE
 
-  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
-         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+  KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
 }
 
 TORCH_LIBRARY_IMPL(_, AutocastCPU, m) {
@@ -507,17 +510,20 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
   KERNEL_XPU(__VA_ARGS__, lower_precision_fp)
 
   AT_FORALL_LOWER_PRECISION_FP(_KERNEL_XPU_LOW_PRECISION_FP)
+#undef _KERNEL_XPU_LOW_PRECISION_FP
 
   // fp32
 #define _KERNEL_XPU_FP32(...) KERNEL_XPU(__VA_ARGS__, fp32)
 
   AT_FORALL_FP32(_KERNEL_XPU_FP32)
+#undef _KERNEL_XPU_FP32
 
   // fp32_set_opt_dtype
 #define _KERNEL_XPU_FP32_SET_OPT_DTYPE(...) \
   KERNEL_XPU(__VA_ARGS__, fp32_set_opt_dtype)
 
   AT_FORALL_FP32_SET_OPT_DTYPE(_KERNEL_XPU_FP32_SET_OPT_DTYPE)
+#undef _KERNEL_XPU_FP32_SET_OPT_DTYPE
 
   // fp32_append_dtype
   // The fp32_append_dtype wrapper overrides implicit promotion behavior.
@@ -529,9 +535,9 @@ TORCH_LIBRARY_IMPL(aten, AutocastXPU, m) {
 #define _KERNEL_XPU_PROMOTE(...) KERNEL_XPU(__VA_ARGS__, promote)
 
   AT_FORALL_PROMOTE(_KERNEL_XPU_PROMOTE)
+#undef _KERNEL_XPU_PROMOTE
 
-  m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"),
-         TORCH_FN((&at::autocast::binary_cross_entropy_banned)));
+  KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)
 }
 
 } // namespace
diff --git a/aten/src/ATen/autocast_mode.h b/aten/src/ATen/autocast_mode.h
index de072617c0f2c..f0e0f33b7f67a 100644
--- a/aten/src/ATen/autocast_mode.h
+++ b/aten/src/ATen/autocast_mode.h
@@ -623,6 +623,9 @@ copy pasted in from VariableTypeEverything.cpp with appropriate substitutions.
 #define _KERNEL_OVERLOAD_NARG(...) \
   C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1))
 
+#define KERNEL_FN(OP, Function) \
+  m.impl(TORCH_SELECTIVE_NAME("aten::" OP), TORCH_FN(Function));
+
 // Common cases where registration signature matches redispatch signature
 // (that's why SIGNATURE is repeated in the WrapFunction instantiation)
 #define KERNEL1(DISPATCHKEY, OP, POLICY) \
diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp
index 7877cc76fbba3..529187fec93db 100644
--- a/c10/test/core/DispatchKeySet_test.cpp
+++ b/c10/test/core/DispatchKeySet_test.cpp
@@ -9,6 +9,20 @@
 
 using namespace c10;
 
+static bool isRealDispatchKey(DispatchKey k) {
+  if (k == DispatchKey::EndOfFunctionalityKeys ||
+      k == DispatchKey::StartOfDenseBackends ||
+      k == DispatchKey::StartOfQuantizedBackends ||
+      k == DispatchKey::StartOfSparseBackends ||
+      k == DispatchKey::StartOfSparseCsrBackends ||
+      k == DispatchKey::StartOfNestedTensorBackends ||
+      k == DispatchKey::StartOfAutogradFunctionalityBackends) {
+    return false;
+  }
+
+  return true;
+}
+
 // This test exists not to be comprehensive, but to more clearly show
 // what the semantics of DispatchKeySet are.
 TEST(DispatchKeySet, ShowSemantics) {
@@ -179,10 +193,7 @@ TEST(DispatchKeySet, SingletonPerBackendFunctionalityKeys) {
        i++) {
     auto tid = static_cast<DispatchKey>(i);
     // Skip these because they aren't real keys.
-    if (tid == DispatchKey::StartOfDenseBackends ||
-        tid == DispatchKey::StartOfSparseBackends ||
-        tid == DispatchKey::StartOfQuantizedBackends ||
-        tid == DispatchKey::StartOfAutogradFunctionalityBackends) {
+    if (!isRealDispatchKey(tid)) {
       continue;
     }
     DispatchKeySet sing(tid);
@@ -221,20 +232,9 @@ TEST(DispatchKeySet, DoubletonPerBackend) {
       auto tid2 = static_cast<DispatchKey>(j);
       // Skip these because they aren't real keys.
-      if (tid1 == DispatchKey::StartOfDenseBackends ||
-          tid1 == DispatchKey::StartOfSparseBackends ||
-          tid1 == DispatchKey::StartOfSparseCsrBackends ||
-          tid1 == DispatchKey::StartOfQuantizedBackends ||
-          tid1 == DispatchKey::StartOfNestedTensorBackends ||
-          tid1 == DispatchKey::StartOfAutogradFunctionalityBackends)
-        continue;
-      if (tid2 == DispatchKey::StartOfDenseBackends ||
-          tid2 == DispatchKey::StartOfSparseBackends ||
-          tid2 == DispatchKey::StartOfSparseCsrBackends ||
-          tid2 == DispatchKey::StartOfQuantizedBackends ||
-          tid2 == DispatchKey::StartOfNestedTensorBackends ||
-          tid2 == DispatchKey::StartOfAutogradFunctionalityBackends)
+      if (!isRealDispatchKey(tid1) || !isRealDispatchKey(tid2)) {
         continue;
+      }
 
       auto backend1 = toBackendComponent(tid1);
       auto backend2 = toBackendComponent(tid2);
 
@@ -421,14 +421,9 @@ TEST(DispatchKeySet, TestFunctionalityDispatchKeyToString) {
     auto k = static_cast<DispatchKey>(i);
     // These synthetic keys never actually get used and don't need
     // to be printed
-    if (k == DispatchKey::EndOfFunctionalityKeys ||
-        k == DispatchKey::StartOfDenseBackends ||
-        k == DispatchKey::StartOfQuantizedBackends ||
-        k == DispatchKey::StartOfSparseBackends ||
-        k == DispatchKey::StartOfSparseCsrBackends ||
-        k == DispatchKey::StartOfNestedTensorBackends ||
-        k == DispatchKey::StartOfAutogradFunctionalityBackends)
+    if (!isRealDispatchKey(k)) {
       continue;
+    }
     auto res = std::string(toString(k));
     ASSERT_TRUE(res.find("Unknown") == std::string::npos)
         << i << " (before is " << toString(static_cast<DispatchKey>(i - 1))
diff --git a/test/test_ops.py b/test/test_ops.py
index 44f503ae9b6ed..452fabcc310ac 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2594,6 +2594,8 @@ def test_fake(self, device, dtype, op):
 
     @ops(op_db, dtypes=OpDTypes.any_one)
     def test_fake_autocast(self, device, dtype, op):
+        # strip the device index (e.g. "cuda:0" -> "cuda") before the skip-list lookup
+        device = device.split(":")[0]
         if op.name in fake_autocast_device_skips[device]:
             self.skipTest("Skip failing test")
         context = (
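Note for reviewers, since the macro plumbing is easy to misread: `KERNEL_FN("binary_cross_entropy", &at::autocast::binary_cross_entropy_banned)` expands to the same `m.impl(TORCH_SELECTIVE_NAME("aten::binary_cross_entropy"), TORCH_FN(...));` registration it replaces, and each new `#undef` simply limits a `_KERNEL_*` helper to its single `AT_FORALL_*` expansion so the name cannot leak into, or collide with, a later redefinition in the same file. Below is a minimal self-contained sketch of that define/expand/`#undef` pattern; the names `AT_FORALL_DEMO`, `_DEMO_KERNEL`, and `reg` are hypothetical stand-ins for illustration, not PyTorch APIs.

```cpp
#include <iostream>

// Hypothetical stand-in for an AT_FORALL_* X-macro op list.
#define AT_FORALL_DEMO(_) _(add) _(mul)

// Hypothetical stand-in for the m.impl(...) registration the real macros emit.
void reg(const char* name) { std::cout << "registered aten::" << name << '\n'; }

int main() {
  // Define the helper, expand the list through it once, then #undef it
  // immediately -- the pattern this patch applies to every _KERNEL_CUDA_* /
  // _KERNEL_XPU_* helper.
#define _DEMO_KERNEL(op) reg(#op);
  AT_FORALL_DEMO(_DEMO_KERNEL)
#undef _DEMO_KERNEL
  return 0;
}
```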