diff --git a/aten/src/ATen/native/Lerp.h b/aten/src/ATen/native/Lerp.h
index f24032f5e38d8..c1784ae16f319 100644
--- a/aten/src/ATen/native/Lerp.h
+++ b/aten/src/ATen/native/Lerp.h
@@ -1,12 +1,39 @@
 #pragma once
 
 #include <ATen/native/DispatchStub.h>
+#include <ATen/OpMathType.h>
 #include <ATen/TensorIterator.h>
 #include <c10/core/Scalar.h>
 
 namespace at {
 namespace native {
 
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
+  return std::abs(weight) < scalar_t(0.5);
+}
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
+  // Avoid the sqrt in abs(weight)
+  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
+}
+
+template <typename scalar_t, typename weight_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using opmath_weight_t = at::opmath_type<weight_t>;
+
+  opmath_t self = self_;
+  opmath_t end = end_;
+  opmath_weight_t weight = weight_;
+
+  // Conditional for better numeric. This has been discussed in
+  // https://github.com/pytorch/pytorch/pull/18871
+  return is_lerp_weight_small(weight)
+      ? self + weight * (end - self)
+      : end - (end - self) * (opmath_t(1) - weight);
+}
+
 using lerp_fn_scalar = void (*)(
     at::TensorIteratorBase& iter,
     const Scalar& weight);
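Note on the branch inside `lerp()`: the two forms are algebraically equivalent, but `self + weight * (end - self)` is not guaranteed to land exactly on `end` at `weight == 1` (this is what the linked PR #18871 discussion covers), so the helper uses it only for small `|weight|` and switches to the mirrored `end - (end - self) * (1 - weight)` otherwise. A minimal standalone sketch of the failure mode (plain C++, independent of the ATen header; `lerp_naive` and `lerp_branched` are illustrative names, not part of this change):

```cpp
#include <cmath>
#include <cstdio>

// Single-formula lerp: can miss `end` exactly at weight == 1.
template <typename T>
T lerp_naive(T self, T end, T weight) {
  return self + weight * (end - self);
}

// Same selection rule as is_lerp_weight_small() for real weights.
template <typename T>
T lerp_branched(T self, T end, T weight) {
  return std::abs(weight) < T(0.5)
      ? self + weight * (end - self)
      : end - (end - self) * (T(1) - weight);
}

int main() {
  float a = 1e20f, b = 1.0f;
  std::printf("naive:    %g\n", lerp_naive(a, b, 1.0f));     // prints 0, not 1
  std::printf("branched: %g\n", lerp_branched(a, b, 1.0f));  // prints exactly 1
}
```

With the naive form, `(1.0f - 1e20f)` rounds to `-1e20f`, so adding it back to `self` cancels to zero; the branched form multiplies by `(1 - weight) == 0` and returns `end` exactly.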
diff --git a/aten/src/ATen/native/cpu/LerpKernel.cpp b/aten/src/ATen/native/cpu/LerpKernel.cpp
index 04e3a23721a74..4569e1450f05b 100644
--- a/aten/src/ATen/native/cpu/LerpKernel.cpp
+++ b/aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -19,9 +19,7 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
     at::native::cpu_kernel_vec(
       iter,
       [weight_val](BFloat16 self_val, BFloat16 end_val) -> BFloat16 {
-        return (weight_val < 0.5)
-            ? float(self_val) + weight_val * (float(end_val) - float(self_val))
-            : float(end_val) - (float(end_val) - float(self_val)) * (float(1) - weight_val);
+        return lerp(self_val, end_val, weight_val);
       },
       [=](bVec self_vec, bVec end_vec) -> bVec {
         fVec self_vec0, self_vec1, end_vec0, end_vec1;
@@ -39,15 +37,12 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
       });
   } else {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
-      using value_t = typename c10::scalar_value_type<scalar_t>::type;
-      scalar_t weight_val = weight.to<scalar_t>();
-      at::native::cpu_kernel(
-          iter,
-          [weight_val](scalar_t self_val, scalar_t end_val) {
-            return (zabs<scalar_t, value_t>(weight_val) < 0.5)
-                ? self_val + weight_val * (end_val - self_val)
-                : end_val - (end_val - self_val) * (scalar_t(1) - weight_val);
-          });
+      auto weight_val = weight.to<scalar_t>();
+      at::native::cpu_kernel(
+          iter,
+          [weight_val](scalar_t self_val, scalar_t end_val) {
+            return lerp(self_val, end_val, weight_val);
+          });
    });
  }
}
@@ -61,9 +56,7 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
     at::native::cpu_kernel_vec(
       iter,
       [=](BFloat16 self_val, BFloat16 end_val, BFloat16 weight_val) -> BFloat16 {
-        return (weight_val < 0.5)
-            ? float(self_val) + weight_val * (float(end_val) - float(self_val))
-            : float(end_val) - (float(end_val) - float(self_val)) * (float(1) - weight_val);
+        return lerp(self_val, end_val, weight_val);
       },
       [=](bVec self_vec, bVec end_vec, bVec weight_vec) -> bVec {
         fVec self_vec0, self_vec1, end_vec0, end_vec1, weight_vec0, weight_vec1;
@@ -82,15 +75,12 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
       });
   } else {
     AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
-      using value_t = typename c10::scalar_value_type<scalar_t>::type;
-      at::native::cpu_kernel(
-          iter,
-          [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
-            return (zabs<scalar_t, value_t>(weight_val) < 0.5)
-                ? self_val + weight_val * (end_val - self_val)
-                : end_val - (end_val - self_val) * (scalar_t(1) - weight_val);
-          });
-      });
+      at::native::cpu_kernel(
+          iter,
+          [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
+            return lerp(self_val, end_val, weight_val);
+          });
+    });
   }
 }
diff --git a/aten/src/ATen/native/cuda/Lerp.cu b/aten/src/ATen/native/cuda/Lerp.cu
index ac1f2ba379b53..c1adb5b6fc030 100644
--- a/aten/src/ATen/native/cuda/Lerp.cu
+++ b/aten/src/ATen/native/cuda/Lerp.cu
@@ -14,23 +14,13 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.common_dtype(), "lerp_cuda",
       [&] {
-        using opmath_t = at::opmath_type<scalar_t>;
         at::native::gpu_kernel(
             iter,
             [] GPU_LAMBDA(
                 scalar_t self_val,
                 scalar_t end_val,
                 scalar_t weight_val) -> scalar_t {
-              opmath_t self_val_f = self_val;
-              opmath_t end_val_f = end_val;
-              opmath_t weight_val_f = weight_val;
-              // Conditional for better numeric. This has been discussed in
-              // https://github.com/pytorch/pytorch/pull/18871
-              return (std::abs(weight_val_f) < 0.5)
-                  ? self_val_f + weight_val_f * (end_val_f - self_val_f)
-                  : end_val_f -
-                      (end_val_f - self_val_f) *
-                          (opmath_t{1} - weight_val_f);
+              return lerp(self_val, end_val, weight_val);
             });
       });
 }
@@ -44,14 +34,7 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight) {
         auto weight_val = weight.to<opmath_t>();
         at::native::gpu_kernel(
             iter, [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) {
-              opmath_t self_val_f = self_val;
-              opmath_t end_val_f = end_val;
-              // Conditional for better numeric. This has been discussed in
-              // https://github.com/pytorch/pytorch/pull/18871
-              return (std::abs(weight_val) < 0.5)
-                  ? self_val_f + weight_val * (end_val_f - self_val_f)
-                  : end_val_f -
-                      (end_val_f - self_val_f) * (opmath_t{1} - weight_val);
+              return lerp(self_val, end_val, weight_val);
             });
       });
 }
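One more note on the complex overload: `is_lerp_weight_small()` compares the squared magnitude against 0.25 instead of calling `abs()`, which would cost a square root per element. The two predicates agree because |w| < 0.5 iff |w|² < 0.25 (sqrt is monotone, and 0.5 and 0.25 are exact in binary). A small sketch checking this (illustrative only; `small_via_abs` and `small_via_norm` are made-up names, and the grid step is 1/16 so the arithmetic below is exact in float):

```cpp
#include <cassert>
#include <complex>
#include <cstdio>

// Sqrt-based predicate: what a naive implementation would do.
static bool small_via_abs(std::complex<float> w) {
  return std::abs(w) < 0.5f;  // abs() computes a square root
}

// Squared-magnitude predicate, mirroring the c10::complex overload of
// is_lerp_weight_small(); std::norm(w) is re*re + im*im.
static bool small_via_norm(std::complex<float> w) {
  return std::norm(w) < 0.25f;
}

int main() {
  // Scan a grid of weights in [-1.25, 1.25]^2 and check the predicates agree,
  // including points exactly on the |w| == 0.5 boundary such as (0.5, 0).
  for (int re = -20; re <= 20; ++re) {
    for (int im = -20; im <= 20; ++im) {
      std::complex<float> w(re * 0.0625f, im * 0.0625f);
      assert(small_via_abs(w) == small_via_norm(w));
    }
  }
  std::printf("predicates agree on the whole grid\n");
}
```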