From 1588ea0dbf16f37ce14cfc8764666985c16ccbf9 Mon Sep 17 00:00:00 2001 From: mfkasim1 Date: Thu, 24 Nov 2022 11:11:51 +0000 Subject: [PATCH] Added log1p for complex in c10 (#89214) One PR towards #89205. The content is mostly from PR #38465, but slightly changed the expression to make it faster. Here are some benchmarking code: ```c++ #include #include #include // main.cc template inline std::complex log1p_v0(const std::complex &z) { // this PR T x = z.real(); T y = z.imag(); T theta = std::atan2(y, x + T(1)); T r = x * (x + T(2)) + y * y; return {T(0.5) * std::log1p(r), theta}; } template inline std::complex log1p_v1(const std::complex &z) { // PR #38465 T x = z.real(); T y = z.imag(); std::complex p1 = z + T(1); T r = std::abs(p1); T a = std::arg(p1); T rm1 = (x * x + y * y + x * T(2)) / (r + 1); return {std::log1p(rm1), a}; } template inline std::complex log1p_v2(const std::complex &z) { // naive, but numerically inaccurate return std::log(T(1) + z); } int main() { int n = 1000000; std::complex res(0.0, 0.0); std::complex input(0.5, 2.0); auto start = std::chrono::system_clock::now(); for (int i = 0; i < n; i++) { res += log1p_v0(input); } auto end = std::chrono::system_clock::now(); auto elapsed = end - start; std::cout << "time for v0: " << elapsed.count() << '\n'; start = std::chrono::system_clock::now(); for (int i = 0; i < n; i++) { res += log1p_v1(input); } end = std::chrono::system_clock::now(); elapsed = end - start; std::cout << "time for v1: " << elapsed.count() << '\n'; start = std::chrono::system_clock::now(); for (int i = 0; i < n; i++) { res += log1p_v2(input); } end = std::chrono::system_clock::now(); elapsed = end - start; std::cout << "time for v2: " << elapsed.count() << '\n'; std::cout << res << '\n'; } ``` Compiling the script with command `g++ main.cc` produces the following results: ``` time for v0: 237812271 time for v1: 414524941 time for v2: 360585994 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/89214 Approved by: https://github.com/lezcano --- c10/test/util/complex_math_test_common.h | 128 +++++++++++++++++++++++ c10/util/complex_math.h | 31 ++++++ 2 files changed, 159 insertions(+) diff --git a/c10/test/util/complex_math_test_common.h b/c10/test/util/complex_math_test_common.h index 15addf687856f..ce1be7b38d84d 100644 --- a/c10/test/util/complex_math_test_common.h +++ b/c10/test/util/complex_math_test_common.h @@ -166,6 +166,134 @@ C10_DEFINE_TEST(TestLog2, Rev) { } } +C10_DEFINE_TEST(TestLog1p, Normal) { + // log1p(x) = log(1 + x) + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0f + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } + { + c10::complex x(0.1, 1.2); + c10::complex l1 = std::log1p(x); + c10::complex l2 = std::log(1.0 + x); + C10_ASSERT_NEAR(l1.real(), l2.real(), tol); + C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Small) { + // log(1 + x) ~ x for |x| << 1 + { + c10::complex x(1e-9, 2e-9); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } + { + c10::complex x(1e-100, 2e-100); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real() / x.real(), 1, tol); + C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol); + } +} + +C10_DEFINE_TEST(TestLog1p, Extreme) { + // log(1 + x) ~ x for |x| << 1 and in the brink of overflow / underflow + { + c10::complex x(-1, 1e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol); + C10_ASSERT_NEAR(l.imag(), 1e-30, tol); + } + { + c10::complex x(1e-30, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e30, 1e30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 69.42412638010134, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-38, 1e-38); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-38, tol); + C10_ASSERT_NEAR(l.imag(), 1e-38, tol); + } + { + c10::complex x(1e-38, 2e-30); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-30, tol); + C10_ASSERT_NEAR(l.imag(), 2e-30, tol); + } + { + c10::complex x(-1, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), -575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(-1, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol); + } + { + c10::complex x(1e250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 1); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e250, 1e250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 575.9928468387914, tol); + C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol); + } + { + c10::complex x(1e-250, 1e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 1e-250, tol); + } + { + c10::complex x(1e-250, 2e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 1e-250, tol); + C10_ASSERT_NEAR(l.imag(), 2e-250, tol); + } + { + c10::complex x(2e-308, 1.5e-250); + c10::complex l = std::log1p(x); + C10_ASSERT_NEAR(l.real(), 2e-308, tol); + C10_ASSERT_NEAR(l.imag(), 1.5e-308, tol); + } +} + // Power functions C10_DEFINE_TEST(TestPowSqrt, Equal) { diff --git a/c10/util/complex_math.h b/c10/util/complex_math.h index ecfd0442b751b..8709fe4a0eb55 100644 --- a/c10/util/complex_math.h +++ b/c10/util/complex_math.h @@ -291,6 +291,35 @@ C10_HOST_DEVICE inline c10::complex atanh(const c10::complex& x) { #endif } +template +C10_HOST_DEVICE inline c10::complex log1p(const c10::complex& z) { + // log1p(z) = log(1 + z) + // Let's define 1 + z = r * e ^ (i * a), then we have + // log(r * e ^ (i * a)) = log(r) + i * a + // With z = x + iy, the term r can be written as + // r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5 + // = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5 + // So, log(r) is + // log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2) + // = 0.5 * log1p(x * (x + 2) + y ^ 2) + // we need to use the expression only on certain condition to avoid overflow + // and underflow from `(x * (x + 2) + y ^ 2)` + T x = z.real(); + T y = z.imag(); + T zabs = std::abs(z); + T theta = std::atan2(y, x + T(1)); + if (zabs < 0.5) { + T r = x * (T(2) + x) + y * y; + if (r == 0) { // handle underflow + return {x, theta}; + } + return {T(0.5) * std::log1p(r), theta}; + } else { + T z0 = std::hypot(x + 1, y); + return {std::log(z0), theta}; + } +} + } // namespace c10_complex_math using c10_complex_math::acos; @@ -304,6 +333,7 @@ using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; +using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin; @@ -325,6 +355,7 @@ using c10_complex_math::cosh; using c10_complex_math::exp; using c10_complex_math::log; using c10_complex_math::log10; +using c10_complex_math::log1p; using c10_complex_math::log2; using c10_complex_math::pow; using c10_complex_math::sin;