25 commits
- 7b2cf2c: [skip ci] compilation fails with error: expected unqualified-id befor… (sanchitintel, Jun 25, 2023)
- 7544939: Duplicate macros for nullptr dispatch (sanchitintel, Jun 26, 2023)
- b856ebe: Disable more AVX512 kernels (sanchitintel, Jun 26, 2023)
- 697173f: Disable more AVX512 kernels (sanchitintel, Jun 27, 2023)
- 2aaf056: Merge branch 'main' of https://github.com/pytorch/pytorch into sanchi… (sanchitintel, Jun 27, 2023)
- 621804d: Enable AVX512 by default (sanchitintel, Jun 27, 2023)
- 4e691bd: Disable more AVX512 kernels (sanchitintel, Jun 28, 2023)
- 218386b: Reduce changes by using a new macro ALSO_REGISTER_AVX512_DISPATCH (sanchitintel, Jul 20, 2023)
- c8c266d: Merge branch 'pytorch:main' into sanchitintel/reenable_avx512 (sanchitintel, Jul 20, 2023)
- 27bab64: Delete activation_benchmark.py as it should be in another PR (sanchitintel, Jul 20, 2023)
- d2ea43d: Disable AVX512 dispatch for fmod & remainder (sanchitintel, Jul 21, 2023)
- ea34462: Some Distributions' kernels have not been vectorized with Vec (sanchitintel, Aug 1, 2023)
- 974e4bc: Merge branch 'pytorch:main' into sanchitintel/reenable_avx512 (sanchitintel, Aug 1, 2023)
- 9dfab67: Remove AVX512 dispatch for some kernels (sanchitintel, Aug 3, 2023)
- 2e3704f: Fix lint & refactor to reduce no. of lines (sanchitintel, Aug 3, 2023)
- 15def40: Merge branch 'pytorch:main' into sanchitintel/reenable_avx512 (sanchitintel, Aug 9, 2023)
- 29c1d4a: Being more conservative in enabling AVX512 dispatch (sanchitintel, Aug 9, 2023)
- f20d940: Disable some AVX512 kernels (sanchitintel, Aug 13, 2023)
- 5e89355: Remove AVX512 dispatch for nan_to_num (sanchitintel, Aug 14, 2023)
- ce922b6: [skip ci] Disable AVX512 dispatch for unvectorized unary ops' kernels (sanchitintel, Aug 21, 2023)
- 8d5d5ed: Disable AVX512 dispatch for unvectorized binary ops' kernels (sanchitintel, Aug 21, 2023)
- a4ffdca: Merge branch 'pytorch:main' into sanchitintel/reenable_avx512 (sanchitintel, Aug 21, 2023)
- 27f35d5: Revise comments (sanchitintel, Aug 21, 2023)
- e9eed90: Fix typo in comment (sanchitintel, Aug 21, 2023)
- d5002f4: Add AVX512 dispatch for FlashAttentionKernel (sanchitintel, Aug 21, 2023)
4 changes: 1 addition & 3 deletions aten/src/ATen/native/DispatchStub.cpp
@@ -41,9 +41,7 @@ static CPUCapability compute_cpu_capability() {

#if !defined(__powerpc__) && !defined(__s390x__)
if (cpuinfo_initialize()) {
// AVX512 can be slower then AVX2, so lets keep it as opt-in
// see https://github.com/pytorch/pytorch/issues/80252
#if defined(HAVE_AVX512_CPU_DEFINITION) && false
#if defined(HAVE_AVX512_CPU_DEFINITION)
// GCC supports some AVX512 intrinsics such as _mm512_set_epi16 only in
// versions 9 & beyond. So, we want to ensure that only releases built with
// supported compilers on supported hardware return CPU Capability AVX512,
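For context, the capability computed here can also be overridden at runtime through the `ATEN_CPU_CAPABILITY` environment variable, which the README changes later in this PR rely on for benchmarking. The following is a simplified sketch of the selection logic after this change; the enum values and the exact cpuinfo checks are illustrative rather than the verbatim contents of DispatchStub.cpp:

```cpp
#include <cpuinfo.h>
#include <cstdlib>
#include <cstring>

enum class CPUCapability { DEFAULT = 0, AVX2 = 1, AVX512 = 2 };

static CPUCapability compute_cpu_capability() {
  // An explicit override such as ATEN_CPU_CAPABILITY=avx2 wins over detection.
  if (const char* env = std::getenv("ATEN_CPU_CAPABILITY")) {
    if (std::strcmp(env, "avx512") == 0) return CPUCapability::AVX512;
    if (std::strcmp(env, "avx2") == 0) return CPUCapability::AVX2;
    if (std::strcmp(env, "default") == 0) return CPUCapability::DEFAULT;
  }
  // Otherwise detect the host CPU. With the `&& false` removed above, builds
  // with HAVE_AVX512_CPU_DEFINITION can now actually report AVX512.
  if (cpuinfo_initialize()) {
    if (cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512vl() &&
        cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq()) {
      return CPUCapability::AVX512;
    }
    if (cpuinfo_has_x86_avx2()) {
      return CPUCapability::AVX2;
    }
  }
  return CPUCapability::DEFAULT;
}
```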
11 changes: 7 additions & 4 deletions aten/src/ATen/native/DispatchStub.h
@@ -301,12 +301,15 @@ struct RegisterPRIVATEUSE1Dispatch {
// NB: this macro must be used from a 'mm' file in order to dispatch a MPS kernel
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// REGISTER_DISPATCH now registers nullptr for a kernel's AVX512 dispatch while still
// registering the kernel for every other CPU capability, so AVX512 machines fall back
// to the AVX2 kernel. Use ALSO_REGISTER_AVX512_DISPATCH to register the AVX512 kernel as well.
#ifdef CPU_CAPABILITY_AVX512
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, nullptr)
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define REGISTER_NO_AVX512_DISPATCH(name) \
REGISTER_AVX512_DISPATCH(name, nullptr)
#endif


#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
} // namespace at::native

C10_CLANG_DIAGNOSTIC_POP()
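To make the new macro's effect concrete, here is a minimal sketch of how a kernel translation unit uses the two registration macros after this change; `my_op_stub`, `my_other_op_stub`, and the kernel functions are hypothetical names, but the pattern matches the registrations updated throughout this PR:

```cpp
// Inside a file under aten/src/ATen/native/cpu/, which is compiled once per
// CPU capability (DEFAULT, AVX2, AVX512, VSX, ...):

// Registers my_op_kernel for every capability except AVX512; the AVX512 slot
// receives nullptr, so dispatch on AVX512 machines falls back to the AVX2
// kernel instead.
REGISTER_DISPATCH(my_op_stub, &my_op_kernel);

// Registers my_other_op_kernel for every capability, including AVX512; use
// this once a kernel is known not to regress with AVX512.
ALSO_REGISTER_AVX512_DISPATCH(my_other_op_stub, &my_other_op_kernel);
```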
46 changes: 24 additions & 22 deletions aten/src/ATen/native/cpu/Activation.cpp
@@ -1430,33 +1430,35 @@ void prelu_backward_kernel(TensorIterator& iter) {

} // namespace

REGISTER_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
REGISTER_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_cpu_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(elu_stub, &elu_kernel);
REGISTER_DISPATCH(elu_backward_stub, &elu_backward_kernel);
REGISTER_DISPATCH(GeluKernel, &GeluKernelImpl);
REGISTER_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);

REGISTER_DISPATCH(hardsigmoid_stub, &hardsigmoid_kernel);
REGISTER_DISPATCH(hardsigmoid_backward_stub, &hardsigmoid_backward_kernel);
REGISTER_DISPATCH(hardswish_stub, &hardswish_kernel);
REGISTER_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);
REGISTER_DISPATCH(threshold_stub, &threshold_kernel);
REGISTER_DISPATCH(leaky_relu_stub, &leaky_relu_kernel);
REGISTER_DISPATCH(leaky_relu_backward_stub, &leaky_relu_backward_kernel);
REGISTER_DISPATCH(softplus_stub, &softplus_kernel);
REGISTER_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
REGISTER_DISPATCH(glu_stub, &glu_kernel);
REGISTER_DISPATCH(glu_backward_stub, &glu_backward_kernel);
REGISTER_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
REGISTER_DISPATCH(silu_stub, &silu_kernel);
REGISTER_DISPATCH(silu_backward_stub, &silu_backward_kernel);
REGISTER_DISPATCH(mish_stub, &mish_kernel);
REGISTER_DISPATCH(mish_backward_stub, &mish_backward_kernel);
REGISTER_DISPATCH(prelu_stub, &prelu_kernel);
REGISTER_DISPATCH(prelu_backward_stub, &prelu_backward_kernel);
REGISTER_DISPATCH(hardtanh_backward_stub, &hardtanh_backward_kernel);
REGISTER_DISPATCH(hardshrink_stub, &hardshrink_kernel);
REGISTER_DISPATCH(softshrink_stub, &softshrink_kernel);
REGISTER_DISPATCH(shrink_backward_stub, &shrink_backward_kernel);

ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_cpu_stub, &log_sigmoid_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(log_sigmoid_backward_stub, &log_sigmoid_backward_cpu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_stub, &glu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_backward_stub, &glu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(glu_jvp_stub, &glu_jvp_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_stub, &elu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(elu_backward_stub, &elu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(GeluKernel, &GeluKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(GeluBackwardKernel, &GeluBackwardKernelImpl);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_stub, &hardswish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(hardswish_backward_stub, &hardswish_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_stub, &softplus_kernel);
ALSO_REGISTER_AVX512_DISPATCH(softplus_backward_stub, &softplus_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_stub, &silu_kernel);
ALSO_REGISTER_AVX512_DISPATCH(silu_backward_stub, &silu_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_stub, &mish_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mish_backward_stub, &mish_backward_kernel);

} // namespace at::native
40 changes: 21 additions & 19 deletions aten/src/ATen/native/cpu/BinaryOpsKernel.cpp
@@ -1296,8 +1296,6 @@ REGISTER_DISPATCH(mul_stub, &mul_kernel);
REGISTER_DISPATCH(div_true_stub, &div_true_kernel);
REGISTER_DISPATCH(div_trunc_stub, &div_trunc_kernel);
REGISTER_DISPATCH(div_floor_stub, &div_floor_kernel);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
REGISTER_DISPATCH(atan2_stub, &atan2_kernel);
REGISTER_DISPATCH(bitwise_and_stub, &bitwise_and_kernel);
REGISTER_DISPATCH(bitwise_or_stub, &bitwise_or_kernel);
REGISTER_DISPATCH(bitwise_xor_stub, &bitwise_xor_kernel);
@@ -1316,37 +1314,41 @@ REGISTER_DISPATCH(maximum_stub, &maximum_kernel);
REGISTER_DISPATCH(minimum_stub, &minimum_kernel);
REGISTER_DISPATCH(fmax_stub, &fmax_kernel);
REGISTER_DISPATCH(fmin_stub, &fmin_kernel);
REGISTER_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
REGISTER_DISPATCH(huber_stub, &huber_kernel);
REGISTER_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
REGISTER_DISPATCH(logit_backward_stub, &logit_backward_kernel);
REGISTER_DISPATCH(tanh_backward_stub, &tanh_backward_kernel);
REGISTER_DISPATCH(mse_stub, &mse_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
REGISTER_DISPATCH(remainder_stub, &remainder_kernel);
REGISTER_DISPATCH(fmod_stub, &fmod_kernel);
REGISTER_DISPATCH(logaddexp_stub, &logaddexp_kernel);
REGISTER_DISPATCH(logaddexp2_stub, &logaddexp2_kernel);
REGISTER_DISPATCH(gcd_stub, &gcd_kernel);
REGISTER_DISPATCH(lcm_stub, &lcm_kernel);
REGISTER_DISPATCH(hypot_stub, &hypot_kernel);
REGISTER_DISPATCH(igamma_stub, &igamma_kernel);
REGISTER_DISPATCH(igammac_stub, &igammac_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(copysign_stub, &copysign_kernel);
REGISTER_DISPATCH(xlogy_stub, &xlogy_kernel);
REGISTER_DISPATCH(xlog1py_stub, &xlog1py_kernel);
REGISTER_DISPATCH(zeta_stub, &zeta_kernel);
REGISTER_DISPATCH(nextafter_stub, &nextafter_kernel);
REGISTER_DISPATCH(heaviside_stub, &heaviside_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_t_stub, &chebyshev_polynomial_t_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_v_stub, &chebyshev_polynomial_v_kernel);
REGISTER_DISPATCH(chebyshev_polynomial_w_stub, &chebyshev_polynomial_w_kernel);
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel);
REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel);
REGISTER_DISPATCH(laguerre_polynomial_l_stub, &laguerre_polynomial_l_kernel);
REGISTER_DISPATCH(legendre_polynomial_p_stub, &legendre_polynomial_p_kernel);
REGISTER_DISPATCH(shifted_chebyshev_polynomial_t_stub, &shifted_chebyshev_polynomial_t_kernel);
REGISTER_DISPATCH(shifted_chebyshev_polynomial_u_stub, &shifted_chebyshev_polynomial_u_kernel);
REGISTER_DISPATCH(shifted_chebyshev_polynomial_v_stub, &shifted_chebyshev_polynomial_v_kernel);
REGISTER_DISPATCH(shifted_chebyshev_polynomial_w_stub, &shifted_chebyshev_polynomial_w_kernel);
// AVX512 dispatch for the kernels below might be enabled once they gain explicit vectorization.
REGISTER_DISPATCH(chebyshev_polynomial_u_stub, &chebyshev_polynomial_u_kernel);
REGISTER_DISPATCH(hermite_polynomial_h_stub, &hermite_polynomial_h_kernel);
REGISTER_DISPATCH(hermite_polynomial_he_stub, &hermite_polynomial_he_kernel);

ALSO_REGISTER_AVX512_DISPATCH(atan2_stub, &atan2_kernel);
ALSO_REGISTER_AVX512_DISPATCH(smooth_l1_stub, &smooth_l1_kernel);
ALSO_REGISTER_AVX512_DISPATCH(huber_stub, &huber_kernel);
ALSO_REGISTER_AVX512_DISPATCH(sigmoid_backward_stub, &sigmoid_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logit_backward_stub, &logit_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(tanh_backward_stub, &tanh_backward_kernel);
ALSO_REGISTER_AVX512_DISPATCH(mse_stub, &mse_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logaddexp_stub, &logaddexp_kernel);
ALSO_REGISTER_AVX512_DISPATCH(logaddexp2_stub, &logaddexp2_kernel);
ALSO_REGISTER_AVX512_DISPATCH(hypot_stub, &hypot_kernel);
ALSO_REGISTER_AVX512_DISPATCH(igamma_stub, &igamma_kernel);
ALSO_REGISTER_AVX512_DISPATCH(igammac_stub, &igammac_kernel);

} // namespace at::native
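Several commit messages above ("unvectorized binary ops' kernels", "have not been vectorized with Vec") and the comment about explicit vectorization refer to kernels written against `at::vec::Vectorized`. As a hedged illustration (the stub, file, and function names are hypothetical), an explicitly vectorized binary kernel follows this pattern, which is what would let the polynomial kernels above graduate to `ALSO_REGISTER_AVX512_DISPATCH`:

```cpp
// In a hypothetical aten/src/ATen/native/cpu/MyBinaryOps.cpp, compiled once
// per CPU capability:
#include <ATen/Dispatch.h>
#include <ATen/cpu/vec/vec.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/cpu/Loops.h>
#include <cmath>

namespace at::native {

using binary_fn = void (*)(TensorIteratorBase&);
DECLARE_DISPATCH(binary_fn, my_atan2_stub);
DEFINE_DISPATCH(my_atan2_stub);

namespace {

// Providing both a scalar functor and a Vectorized functor lets each
// CPU_CAPABILITY build (AVX2, AVX512, ...) emit wide SIMD code for the
// inner loop instead of falling back to scalar auto-vectorization.
void my_atan2_kernel(TensorIteratorBase& iter) {
  AT_DISPATCH_FLOATING_TYPES(iter.common_dtype(), "my_atan2_cpu", [&]() {
    cpu_kernel_vec(
        iter,
        [](scalar_t a, scalar_t b) -> scalar_t { return std::atan2(a, b); },
        [](vec::Vectorized<scalar_t> a, vec::Vectorized<scalar_t> b) {
          return a.atan2(b);
        });
  });
}

} // namespace

// Hypothetical: vectorized and verified not to regress, so register it
// for AVX512 as well.
ALSO_REGISTER_AVX512_DISPATCH(my_atan2_stub, &my_atan2_kernel);

} // namespace at::native
```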
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/ComplexKernel.cpp
@@ -26,6 +26,6 @@ void polar_kernel(TensorIterator& iter) {
} // anonymous namespace

REGISTER_DISPATCH(complex_stub, &complex_kernel);
REGISTER_DISPATCH(polar_stub, &polar_kernel);
ALSO_REGISTER_AVX512_DISPATCH(polar_stub, &polar_kernel);

} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cpu/DepthwiseConvKernel.cpp
@@ -310,6 +310,6 @@ Tensor _convolution_depthwise3x3_winograd(

} // namespace

REGISTER_DISPATCH(convolution_depthwise3x3_winograd_stub, &_convolution_depthwise3x3_winograd);
ALSO_REGISTER_AVX512_DISPATCH(convolution_depthwise3x3_winograd_stub, &_convolution_depthwise3x3_winograd);

} // namespace at::native
6 changes: 0 additions & 6 deletions aten/src/ATen/native/cpu/DistributionKernels.cpp
@@ -240,13 +240,7 @@ REGISTER_DISPATCH(cauchy_stub, &cauchy_kernel);
REGISTER_DISPATCH(exponential_stub, &exponential_kernel);
REGISTER_DISPATCH(geometric_stub, &geometric_kernel);
REGISTER_DISPATCH(log_normal_stub, &log_normal_kernel);
#ifdef CPU_CAPABILITY_AVX512
// normal_stub isn't being dispatched to AVX512 because it exposes
// flakiness in test_sgd of test/optim/test_optim.py
REGISTER_NO_AVX512_DISPATCH(normal_stub);
#else
REGISTER_DISPATCH(normal_stub, &normal_kernel);
#endif
REGISTER_DISPATCH(uniform_stub, &uniform_kernel);
REGISTER_DISPATCH(random_from_to_stub, &random_from_to_kernel);
REGISTER_DISPATCH(random_full_64_bits_range_stub, &random_full_64_bits_range_kernel);
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/FlashAttentionKernel.cpp
@@ -624,7 +624,7 @@ void flash_attention_backward_kernel_impl(

} // anonymous namespace

REGISTER_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
REGISTER_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_kernel, &flash_attention_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(flash_attention_backward_kernel, &flash_attention_backward_kernel_impl);

} // at::native
4 changes: 2 additions & 2 deletions aten/src/ATen/native/cpu/PowKernel.cpp
@@ -144,7 +144,7 @@ static void pow_tensor_scalar_kernel(

} // anonymous namespace

REGISTER_DISPATCH(pow_tensor_tensor_stub, &CPU_CAPABILITY::pow_tensor_tensor_kernel);
REGISTER_DISPATCH(pow_tensor_scalar_stub, &CPU_CAPABILITY::pow_tensor_scalar_kernel);
ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_tensor_stub, &CPU_CAPABILITY::pow_tensor_tensor_kernel);
ALSO_REGISTER_AVX512_DISPATCH(pow_tensor_scalar_stub, &CPU_CAPABILITY::pow_tensor_scalar_kernel);

} // namespace at::native
7 changes: 6 additions & 1 deletion aten/src/ATen/native/cpu/README.md
@@ -40,7 +40,12 @@ three steps:

4. Write your actual kernel (e.g., `your_kernel`) in the
cpu directory, and register it to
the dispatch using `REGISTER_DISPATCH(fnNameImpl, &your_kernel)`.
the dispatch using `REGISTER_DISPATCH(fnNameImpl, &your_kernel)` if
it does not perform as well with AVX512 as it does with AVX2.
Otherwise, if it performs well with AVX512, register it with `ALSO_REGISTER_AVX512_DISPATCH(fnNameImpl, &your_kernel)`.
Compute-intensive kernels tend to perform better with AVX512 than with AVX2.
To compare the AVX2 and AVX512 variants of a kernel, register it with `ALSO_REGISTER_AVX512_DISPATCH(fnNameImpl, &your_kernel)`, build from source, and benchmark the kernel with a benchmarking script run once with `ATEN_CPU_CAPABILITY=avx2` and once with `ATEN_CPU_CAPABILITY=avx512` in the environment (a sketch of such a harness follows this section).
tcmalloc or jemalloc can be preloaded to minimize run-to-run variation.

There are plenty of existing examples, look at them for more details.

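As a concrete version of the benchmarking recipe above, a minimal harness using the ATen C++ API might look like the following sketch. The op, tensor shape, and iteration counts are illustrative assumptions, and a Python script using `torch.utils.benchmark` works just as well, since `ATEN_CPU_CAPABILITY` is read once per process either way. `at::gelu` exercises `GeluKernel`, one of the stubs moved to `ALSO_REGISTER_AVX512_DISPATCH` in this PR:

```cpp
// bench_gelu.cpp (hypothetical). Build against libtorch, then run twice:
//   ATEN_CPU_CAPABILITY=avx2   ./bench_gelu
//   ATEN_CPU_CAPABILITY=avx512 ./bench_gelu
// Optionally preload an allocator to reduce run-to-run variation
// (library path varies by system):
//   LD_PRELOAD=/usr/lib/libtcmalloc.so ATEN_CPU_CAPABILITY=avx512 ./bench_gelu
#include <ATen/ATen.h>
#include <chrono>
#include <cstdio>

int main() {
  at::Tensor x = at::randn({1024, 1024});
  // Warm-up so first-use dispatch and allocation costs are excluded.
  for (int i = 0; i < 20; ++i) {
    (void)at::gelu(x);
  }
  constexpr int kIters = 1000;
  auto start = std::chrono::steady_clock::now();
  for (int i = 0; i < kIters; ++i) {
    (void)at::gelu(x);
  }
  auto end = std::chrono::steady_clock::now();
  double us_per_iter =
      std::chrono::duration<double, std::micro>(end - start).count() / kIters;
  std::printf("gelu: %.2f us/iter\n", us_per_iter);
  return 0;
}
```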
1 change: 1 addition & 0 deletions aten/src/ATen/native/cpu/ReduceOpsKernel.cpp
@@ -458,6 +458,7 @@ REGISTER_DISPATCH(min_values_stub, &min_values_kernel_impl);
REGISTER_DISPATCH(max_values_stub, &max_values_kernel_impl);
REGISTER_DISPATCH(argmax_stub, &argmax_kernel_impl);
REGISTER_DISPATCH(argmin_stub, &argmin_kernel_impl);

REGISTER_DISPATCH(cumprod_stub, &cumprod_cpu_kernel);
REGISTER_DISPATCH(cumsum_stub, &cumsum_cpu_kernel);
REGISTER_DISPATCH(logcumsumexp_stub, &logcumsumexp_cpu_kernel);
16 changes: 8 additions & 8 deletions aten/src/ATen/native/cpu/SoftMaxKernel.cpp
@@ -1279,19 +1279,19 @@ static void log_softmax_backward_kernel_impl(

} // anonymous namespace

REGISTER_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
REGISTER_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl);
REGISTER_DISPATCH(
ALSO_REGISTER_AVX512_DISPATCH(softmax_lastdim_kernel, &softmax_lastdim_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_lastdim_kernel, &log_softmax_lastdim_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(
softmax_backward_lastdim_kernel,
&softmax_backward_lastdim_kernel_impl);
REGISTER_DISPATCH(
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_lastdim_kernel,
&log_softmax_backward_lastdim_kernel_impl);

REGISTER_DISPATCH(softmax_kernel, &softmax_kernel_impl);
REGISTER_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl);
REGISTER_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl);
REGISTER_DISPATCH(
ALSO_REGISTER_AVX512_DISPATCH(softmax_kernel, &softmax_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(log_softmax_kernel, &log_softmax_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(softmax_backward_kernel, &softmax_backward_kernel_impl);
ALSO_REGISTER_AVX512_DISPATCH(
log_softmax_backward_kernel,
&log_softmax_backward_kernel_impl);
} // namespace at::native
8 changes: 2 additions & 6 deletions aten/src/ATen/native/cpu/SumKernel.cpp
@@ -629,14 +629,10 @@ void nansum_kernel_impl(TensorIterator &iter) {

} // namespace (anonymous)

REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);

// nansum on Float16 has poor accuracy with AVX2, and more so with AVX512.
// So until it's fixed, it won't be dispatched with AVX512. GH issue 59415.
#ifndef CPU_CAPABILITY_AVX512
// Besides, these kernels are slower with AVX512 than with AVX2.
REGISTER_DISPATCH(nansum_stub, &nansum_kernel_impl);
#else
REGISTER_NO_AVX512_DISPATCH(nansum_stub);
#endif
REGISTER_DISPATCH(sum_stub, &sum_kernel_impl);

} // namespace at::native