
Commit

Improve complex lerp performance
The complex lerp kernel uses `std::abs(z) < 0.5`, which involves
computing a sqrt. Comparing the square against 0.25 instead has much
lower latency and so performs much better overall.

In a simple timeit benchmark I see a more than 10x speedup on CPU for a
4096-element complex lerp, from 84 us to 6.7 us.
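
For reference, the identity being exploited: for complex z, |z| < 0.5 holds
exactly when |z|^2 < 0.25, since both quantities are non-negative and squaring
is monotonic. A minimal standalone sketch of the two predicates (not the
author's benchmark; std::norm stands in for the hand-written squared magnitude
the kernel uses):

#include <cassert>
#include <complex>

// Predicate as written before this commit: std::abs computes a sqrt.
bool is_small_abs(std::complex<float> z) {
  return std::abs(z) < 0.5f;
}

// Predicate after this commit: compare the squared magnitude instead.
bool is_small_norm(std::complex<float> z) {
  return std::norm(z) < 0.25f;  // std::norm = re*re + im*im, no sqrt
}

int main() {
  const std::complex<float> samples[] = {
      {0.3f, 0.3f}, {0.4f, 0.4f}, {0.5f, 0.0f}, {-0.6f, 0.1f}};
  for (auto z : samples) {
    // Both predicates agree on every input.
    assert(is_small_abs(z) == is_small_norm(z));
  }
  return 0;
}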

ghstack-source-id: 3fd059b6e41f541a6a48b26d2c87e67c01fe236d
Pull Request resolved: pytorch#84844
peterbell10 committed Sep 23, 2022
1 parent e9d05f8 commit 6aa03e2
Showing 3 changed files with 33 additions and 28 deletions.
27 changes: 27 additions & 0 deletions aten/src/ATen/native/Lerp.h
@@ -1,12 +1,39 @@
 #pragma once
 
 #include <ATen/native/DispatchStub.h>
+#include <ATen/OpMathType.h>
 #include <ATen/TensorIterator.h>
 #include <c10/core/Scalar.h>
 
 namespace at {
 namespace native {
 
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(scalar_t weight) {
+  return std::abs(weight) < scalar_t(0.5);
+}
+template <typename scalar_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE bool is_lerp_weight_small(c10::complex<scalar_t> weight) {
+  // Avoid the sqrt in abs(weight)
+  return (weight.real() * weight.real() + weight.imag() * weight.imag()) < scalar_t(0.25);
+}
+
+template <typename scalar_t, typename weight_t>
+C10_HOST_DEVICE C10_ALWAYS_INLINE scalar_t lerp(scalar_t self_, scalar_t end_, weight_t weight_) {
+  using opmath_t = at::opmath_type<scalar_t>;
+  using opmath_weight_t = at::opmath_type<weight_t>;
+
+  opmath_t self = self_;
+  opmath_t end = end_;
+  opmath_weight_t weight = weight_;
+
+  // Conditional for better numeric. This has been discussed in
+  // https://github.com/pytorch/pytorch/pull/18871
+  return is_lerp_weight_small(weight)
+      ? self + weight * (end - self)
+      : end - (end - self) * (opmath_t(1) - weight);
+}
+
 using lerp_fn_scalar = void (*)(
     at::TensorIteratorBase& iter,
     const Scalar& weight);
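The two branches of the `lerp` helper above are algebraically identical; the
conditional exists purely for rounding behavior. For weights in [0, 1],
anchoring at the nearer endpoint keeps the multiplier on (end - self) at
magnitude at most 0.5, and weights of exactly 0 or 1 reproduce the
corresponding endpoint exactly. A standalone sketch of the real-valued case
(mirroring the header above, not part of the diff):

#include <cmath>
#include <cstdio>

// Same formula selection as the header, specialized to double for brevity.
double lerp_ref(double self, double end, double weight) {
  return std::abs(weight) < 0.5
      ? self + weight * (end - self)          // anchor at self for small weights
      : end - (end - self) * (1.0 - weight);  // anchor at end otherwise
}

int main() {
  // With weight == 1 the end-anchored branch returns `end` exactly...
  std::printf("%.17g\n", lerp_ref(1.0, 0.3, 1.0));  // prints 0.3
  // ...whereas the self-anchored form would not:
  std::printf("%.17g\n", 1.0 + 1.0 * (0.3 - 1.0));  // prints 0.30000000000000004
  return 0;
}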
13 changes: 4 additions & 9 deletions aten/src/ATen/native/cpu/LerpKernel.cpp
@@ -3,34 +3,29 @@
 #include <ATen/Dispatch.h>
 #include <ATen/TensorIterator.h>
 #include <ATen/native/cpu/Loops.h>
+#include <c10/util/irange.h>
 
 namespace at {
 namespace native {
 namespace {
 
 void lerp_scalar_kernel(at::TensorIteratorBase& iter, const Scalar& weight) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_scalar", [&] {
-    using value_t = typename c10::scalar_value_type<scalar_t>::type;
-    scalar_t weight_val = weight.to<scalar_t>();
+    auto weight_val = weight.to<scalar_t>();
     at::native::cpu_kernel(
         iter,
         [weight_val](scalar_t self_val, scalar_t end_val) {
-          return (zabs<scalar_t, value_t>(weight_val) < 0.5)
-              ? self_val + weight_val * (end_val - self_val)
-              : end_val - (end_val - self_val) * (scalar_t(1) - weight_val);
+          return lerp(self_val, end_val, weight_val);
         });
   });
 }
 
 void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
   AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(iter.common_dtype(), "lerp_kernel_tensor", [&] {
-    using value_t = typename c10::scalar_value_type<scalar_t>::type;
     at::native::cpu_kernel(
         iter,
         [](scalar_t self_val, scalar_t end_val, scalar_t weight_val) {
-          return (zabs<scalar_t, value_t>(weight_val) < 0.5)
-              ? self_val + weight_val * (end_val - self_val)
-              : end_val - (end_val - self_val) * (scalar_t(1) - weight_val);
+          return lerp(self_val, end_val, weight_val);
         });
   });
 }
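Both CPU entry points now funnel into the shared `lerp` from Lerp.h, so the
scalar-weight and tensor-weight paths compute identical results. A hedged
usage sketch through the public ATen API (assumes a libtorch build; the at::
wrappers used here are the public API, not part of this diff):

#include <ATen/ATen.h>
#include <iostream>

int main() {
  auto self = at::rand({4096}, at::kComplexFloat);
  auto end = at::rand({4096}, at::kComplexFloat);
  // Scalar weight -> lerp_scalar_kernel
  auto a = at::lerp(self, end, 0.25);
  // Tensor weight -> lerp_tensor_kernel
  auto b = at::lerp(self, end, at::full_like(self, 0.25));
  std::cout << std::boolalpha << at::allclose(a, b) << "\n";  // true
  return 0;
}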
21 changes: 2 additions & 19 deletions aten/src/ATen/native/cuda/Lerp.cu
@@ -14,23 +14,13 @@ void lerp_tensor_kernel(at::TensorIteratorBase& iter) {
       at::ScalarType::Half, at::ScalarType::BFloat16,
       iter.common_dtype(), "lerp_cuda",
       [&] {
-        using opmath_t = at::opmath_type<scalar_t>;
         at::native::gpu_kernel(
             iter,
             [] GPU_LAMBDA(
                 scalar_t self_val,
                 scalar_t end_val,
                 scalar_t weight_val) -> scalar_t {
-              opmath_t self_val_f = self_val;
-              opmath_t end_val_f = end_val;
-              opmath_t weight_val_f = weight_val;
-              // Conditional for better numeric. This has been discussed in
-              // https://github.com/pytorch/pytorch/pull/18871
-              return (std::abs(weight_val_f) < 0.5)
-                  ? self_val_f + weight_val_f * (end_val_f - self_val_f)
-                  : end_val_f -
-                      (end_val_f - self_val_f) *
-                          (opmath_t{1} - weight_val_f);
+              return lerp(self_val, end_val, weight_val);
             });
       });
 }
@@ -44,14 +34,7 @@ void lerp_scalar_kernel(at::TensorIteratorBase& iter, const c10::Scalar& weight)
     auto weight_val = weight.to<opmath_t>();
     at::native::gpu_kernel(
         iter, [=] GPU_LAMBDA(scalar_t self_val, scalar_t end_val) {
-          opmath_t self_val_f = self_val;
-          opmath_t end_val_f = end_val;
-          // Conditional for better numeric. This has been discussed in
-          // https://github.com/pytorch/pytorch/pull/18871
-          return (std::abs(weight_val) < 0.5)
-              ? self_val_f + weight_val * (end_val_f - self_val_f)
-              : end_val_f -
-                  (end_val_f - self_val_f) * (opmath_t{1} - weight_val);
+          return lerp(self_val, end_val, weight_val);
         });
   });
 }
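With the hand-rolled conversions gone from the CUDA lambdas, the
half/bfloat16-to-float promotion now happens once inside the shared `lerp`
helper via at::opmath_type. A compile-time sketch of the mappings this relies
on (assumed from OpMathType.h, not shown in this diff):

#include <ATen/OpMathType.h>
#include <c10/util/complex.h>
#include <type_traits>

// Reduced-precision types are widened to float for the arithmetic;
// full-precision types map to themselves, so the refactor changes
// nothing for float/double inputs.
static_assert(std::is_same_v<at::opmath_type<at::Half>, float>);
static_assert(std::is_same_v<at::opmath_type<at::BFloat16>, float>);
static_assert(std::is_same_v<at::opmath_type<float>, float>);
static_assert(std::is_same_v<at::opmath_type<c10::complex<float>>,
                             c10::complex<float>>);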
