Skip to content

Commit

Permalink
Added log1p for complex in c10 (#89214)
Browse files Browse the repository at this point in the history
One PR towards #89205.
The content is mostly from PR #38465, but slightly changed the expression to make it faster.

Here are some benchmarking code:
```c++
#include <complex>
#include <iostream>
#include <chrono>

// main.cc

template<typename T> inline std::complex<T> log1p_v0(const std::complex<T> &z) {
    // this PR
    T x = z.real();
    T y = z.imag();
    T theta = std::atan2(y, x + T(1));
    T r = x * (x + T(2)) + y * y;
    return {T(0.5) * std::log1p(r), theta};
}

template<typename T> inline std::complex<T> log1p_v1(const std::complex<T> &z) {
    // PR #38465
    T x = z.real();
    T y = z.imag();
    std::complex<T> p1 = z + T(1);
    T r = std::abs(p1);
    T a = std::arg(p1);
    T rm1 = (x * x + y * y + x * T(2)) / (r + 1);
    return {std::log1p(rm1), a};
}

template<typename T>
inline std::complex<T> log1p_v2(const std::complex<T> &z) {
    // naive, but numerically inaccurate
    return std::log(T(1) + z);
}

int main() {
    int n = 1000000;
    std::complex<float> res(0.0, 0.0);
    std::complex<float> input(0.5, 2.0);
    auto start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v0(input);
    }
    auto end = std::chrono::system_clock::now();
    auto elapsed = end - start;
    std::cout << "time for v0: " << elapsed.count() << '\n';

    start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v1(input);
    }
    end = std::chrono::system_clock::now();
    elapsed = end - start;
    std::cout << "time for v1: " << elapsed.count() << '\n';

    start = std::chrono::system_clock::now();
    for (int i = 0; i < n; i++) {
        res += log1p_v2(input);
    }
    end = std::chrono::system_clock::now();
    elapsed = end - start;
    std::cout << "time for v2: " << elapsed.count() << '\n';
    std::cout << res << '\n';
}
```

Compiling the script with command `g++ main.cc` produces the following results:
```
time for v0: 237812271
time for v1: 414524941
time for v2: 360585994
```

Pull Request resolved: #89214
Approved by: https://github.com/lezcano
  • Loading branch information
mfkasim1 authored and pytorchmergebot committed Nov 24, 2022
1 parent 4f5c4c0 commit 1588ea0
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 0 deletions.
128 changes: 128 additions & 0 deletions c10/test/util/complex_math_test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,134 @@ C10_DEFINE_TEST(TestLog2, Rev) {
}
}

C10_DEFINE_TEST(TestLog1p, Normal) {
// log1p(x) = log(1 + x)
{
c10::complex<float> x(0.1, 1.2);
c10::complex<float> l1 = std::log1p(x);
c10::complex<float> l2 = std::log(1.0f + x);
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
{
c10::complex<double> x(0.1, 1.2);
c10::complex<double> l1 = std::log1p(x);
c10::complex<double> l2 = std::log(1.0 + x);
C10_ASSERT_NEAR(l1.real(), l2.real(), tol);
C10_ASSERT_NEAR(l1.imag(), l2.imag(), tol);
}
}

C10_DEFINE_TEST(TestLog1p, Small) {
// log(1 + x) ~ x for |x| << 1
{
c10::complex<float> x(1e-9, 2e-9);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real() / x.real(), 1, tol);
C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol);
}
{
c10::complex<double> x(1e-100, 2e-100);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real() / x.real(), 1, tol);
C10_ASSERT_NEAR(l.imag() / x.imag(), 1, tol);
}
}

C10_DEFINE_TEST(TestLog1p, Extreme) {
// log(1 + x) ~ x for |x| << 1 and in the brink of overflow / underflow
{
c10::complex<float> x(-1, 1e-30);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), -69.07755278982137, tol);
C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
}
{
c10::complex<float> x(-1, 1e30);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol);
C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
}
{
c10::complex<float> x(1e30, 1);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 69.07755278982137, tol);
C10_ASSERT_NEAR(l.imag(), 1e-30, tol);
}
{
c10::complex<float> x(1e-30, 1);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol);
C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
}
{
c10::complex<float> x(1e30, 1e30);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 69.42412638010134, tol);
C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
}
{
c10::complex<float> x(1e-38, 1e-38);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 1e-38, tol);
C10_ASSERT_NEAR(l.imag(), 1e-38, tol);
}
{
c10::complex<float> x(1e-38, 2e-30);
c10::complex<float> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 1e-30, tol);
C10_ASSERT_NEAR(l.imag(), 2e-30, tol);
}
{
c10::complex<double> x(-1, 1e-250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), -575.6462732485114, tol);
C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
}
{
c10::complex<double> x(-1, 1e250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol);
C10_ASSERT_NEAR(l.imag(), 1.5707963267948966, tol);
}
{
c10::complex<double> x(1e250, 1);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 575.6462732485114, tol);
C10_ASSERT_NEAR(l.imag(), 1e-250, tol);
}
{
c10::complex<double> x(1e-250, 1);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 0.34657359027997264, tol);
C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
}
{
c10::complex<double> x(1e250, 1e250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 575.9928468387914, tol);
C10_ASSERT_NEAR(l.imag(), 0.7853981633974483, tol);
}
{
c10::complex<double> x(1e-250, 1e-250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 1e-250, tol);
C10_ASSERT_NEAR(l.imag(), 1e-250, tol);
}
{
c10::complex<double> x(1e-250, 2e-250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 1e-250, tol);
C10_ASSERT_NEAR(l.imag(), 2e-250, tol);
}
{
c10::complex<double> x(2e-308, 1.5e-250);
c10::complex<double> l = std::log1p(x);
C10_ASSERT_NEAR(l.real(), 2e-308, tol);
C10_ASSERT_NEAR(l.imag(), 1.5e-308, tol);
}
}

// Power functions

C10_DEFINE_TEST(TestPowSqrt, Equal) {
Expand Down
31 changes: 31 additions & 0 deletions c10/util/complex_math.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,35 @@ C10_HOST_DEVICE inline c10::complex<T> atanh(const c10::complex<T>& x) {
#endif
}

template <typename T>
C10_HOST_DEVICE inline c10::complex<T> log1p(const c10::complex<T>& z) {
// log1p(z) = log(1 + z)
// Let's define 1 + z = r * e ^ (i * a), then we have
// log(r * e ^ (i * a)) = log(r) + i * a
// With z = x + iy, the term r can be written as
// r = ((1 + x) ^ 2 + y ^ 2) ^ 0.5
// = (1 + x ^ 2 + 2 * x + y ^ 2) ^ 0.5
// So, log(r) is
// log(r) = 0.5 * log(1 + x ^ 2 + 2 * x + y ^ 2)
// = 0.5 * log1p(x * (x + 2) + y ^ 2)
// we need to use the expression only on certain condition to avoid overflow
// and underflow from `(x * (x + 2) + y ^ 2)`
T x = z.real();
T y = z.imag();
T zabs = std::abs(z);
T theta = std::atan2(y, x + T(1));
if (zabs < 0.5) {
T r = x * (T(2) + x) + y * y;
if (r == 0) { // handle underflow
return {x, theta};
}
return {T(0.5) * std::log1p(r), theta};
} else {
T z0 = std::hypot(x + 1, y);
return {std::log(z0), theta};
}
}

} // namespace c10_complex_math

using c10_complex_math::acos;
Expand All @@ -304,6 +333,7 @@ using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
using c10_complex_math::log2;
using c10_complex_math::pow;
using c10_complex_math::sin;
Expand All @@ -325,6 +355,7 @@ using c10_complex_math::cosh;
using c10_complex_math::exp;
using c10_complex_math::log;
using c10_complex_math::log10;
using c10_complex_math::log1p;
using c10_complex_math::log2;
using c10_complex_math::pow;
using c10_complex_math::sin;
Expand Down

0 comments on commit 1588ea0

Please sign in to comment.