From 43a5e0bf4b3f8444805168d9e07b76d0599fef98 Mon Sep 17 00:00:00 2001 From: tompng Date: Sat, 4 Oct 2025 18:53:34 +0900 Subject: [PATCH] Improve performance of x**y when y is a huge value When y.exponent is several thousand or more, x**y was slow because exponentiation by squaring requires several thousands of multiplications. Use exp and log in such case. Needed to calaculate (1+1/n).power(n, prec) --- lib/bigdecimal.rb | 31 ++++++++++++++++-------------- test/bigdecimal/test_bigdecimal.rb | 5 +++++ 2 files changed, 22 insertions(+), 14 deletions(-) diff --git a/lib/bigdecimal.rb b/lib/bigdecimal.rb index d0a318d8..cf2dfaae 100644 --- a/lib/bigdecimal.rb +++ b/lib/bigdecimal.rb @@ -165,22 +165,25 @@ def power(y, prec = nil) return BigDecimal(1).div(inv, prec) end - int_part = y.fix.to_i prec2 = prec + BigDecimal.double_fig - pow_prec = prec2 + (int_part > 0 ? y.exponent : 0) - ans = BigDecimal(1) - n = 1 - xn = x - while true - ans = ans.mult(xn, pow_prec) if int_part.allbits?(n) - n <<= 1 - break if n > int_part - xn = xn.mult(xn, pow_prec) - end - unless frac_part.zero? - ans = ans.mult(BigMath.exp(BigMath.log(x, prec2).mult(frac_part, prec2), prec2), prec2) + + if frac_part.zero? && y.exponent < Math.log(prec) * 5 + 20 + # Use exponentiation by squaring if y is an integer and not too large + pow_prec = prec2 + y.exponent + n = 1 + xn = x + ans = BigDecimal(1) + int_part = y.fix.to_i + while true + ans = ans.mult(xn, pow_prec) if int_part.allbits?(n) + n <<= 1 + break if n > int_part + xn = xn.mult(xn, pow_prec) + end + ans.mult(1, prec) + else + BigMath.exp(BigMath.log(x, prec2).mult(y, prec2), prec) end - ans.mult(1, prec) end # Returns the square root of the value. diff --git a/test/bigdecimal/test_bigdecimal.rb b/test/bigdecimal/test_bigdecimal.rb index c794ed0d..26cc5f02 100644 --- a/test/bigdecimal/test_bigdecimal.rb +++ b/test/bigdecimal/test_bigdecimal.rb @@ -1991,6 +1991,11 @@ def test_power_with_rational assert_in_epsilon(z2, x2 ** y, 1e-99) end + def test_power_with_huge_value + n = BigDecimal('7e+10000') + assert_equal(BigMath.exp(1, 100), (1 + BigDecimal(1).div(n, 120)).power(n, 100)) + end + def test_power_precision x = BigDecimal("1.41421356237309504880168872420969807856967187537695") y = BigDecimal("3.14159265358979323846264338327950288419716939937511")