Skip to content

Commit

Permalink
GH-100485: Add math.sumprod() (GH-100677)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhettinger committed Jan 7, 2023
1 parent deaf090 commit 47b9f83
Show file tree
Hide file tree
Showing 6 changed files with 548 additions and 10 deletions.
11 changes: 2 additions & 9 deletions Doc/library/itertools.rst
Expand Up @@ -33,7 +33,7 @@ by combining :func:`map` and :func:`count` to form ``map(f, count())``.
These tools and their built-in counterparts also work well with the high-speed
functions in the :mod:`operator` module. For example, the multiplication
operator can be mapped across two vectors to form an efficient dot-product:
``sum(map(operator.mul, vector1, vector2))``.
``sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))``.


**Infinite iterators:**
Expand Down Expand Up @@ -838,10 +838,6 @@ which incur interpreter overhead.
"Returns the sequence elements n times"
return chain.from_iterable(repeat(tuple(iterable), n))

def dotproduct(vec1, vec2):
"Compute a sum of products."
return sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))

def convolve(signal, kernel):
# See: https://betterexplained.com/articles/intuitive-convolution/
# convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur)
Expand All @@ -852,7 +848,7 @@ which incur interpreter overhead.
window = collections.deque([0], maxlen=n) * n
for x in chain(signal, repeat(0, n-1)):
window.append(x)
yield dotproduct(kernel, window)
yield math.sumprod(kernel, window)

def polynomial_from_roots(roots):
"""Compute a polynomial's coefficients from its roots.
Expand Down Expand Up @@ -1211,9 +1207,6 @@ which incur interpreter overhead.
>>> list(ncycles('abc', 3))
['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c']

>>> dotproduct([1,2,3], [4,5,6])
32

>>> data = [20, 40, 24, 32, 20, 28, 16]
>>> list(convolve(data, [0.25, 0.25, 0.25, 0.25]))
[5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0]
Expand Down
16 changes: 16 additions & 0 deletions Doc/library/math.rst
Expand Up @@ -291,6 +291,22 @@ Number-theoretic and representation functions
.. versionadded:: 3.7


.. function:: sumprod(p, q)

Return the sum of products of values from two iterables *p* and *q*.

Raises :exc:`ValueError` if the inputs do not have the same length.

Roughly equivalent to::

sum(itertools.starmap(operator.mul, zip(p, q, strict=true)))

For float and mixed int/float inputs, the intermediate products
and sums are computed with extended precision.

.. versionadded:: 3.12


.. function:: trunc(x)

Return *x* with the fractional part
Expand Down
166 changes: 166 additions & 0 deletions Lib/test/test_math.py
Expand Up @@ -4,6 +4,7 @@
from test.support import verbose, requires_IEEE_754
from test import support
import unittest
import fractions
import itertools
import decimal
import math
Expand Down Expand Up @@ -1202,6 +1203,171 @@ def testLog10(self):
self.assertEqual(math.log(INF), INF)
self.assertTrue(math.isnan(math.log10(NAN)))

def testSumProd(self):
sumprod = math.sumprod
Decimal = decimal.Decimal
Fraction = fractions.Fraction

# Core functionality
self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140)
self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5)
self.assertEqual(sumprod([], []), 0)

# Type preservation and coercion
for v in [
(10, 20, 30),
(1.5, -2.5),
(Fraction(3, 5), Fraction(4, 5)),
(Decimal(3.5), Decimal(4.5)),
(2.5, 10), # float/int
(2.5, Fraction(3, 5)), # float/fraction
(25, Fraction(3, 5)), # int/fraction
(25, Decimal(4.5)), # int/decimal
]:
for p, q in [(v, v), (v, v[::-1])]:
with self.subTest(p=p, q=q):
expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True))
actual = sumprod(p, q)
self.assertEqual(expected, actual)
self.assertEqual(type(expected), type(actual))

# Bad arguments
self.assertRaises(TypeError, sumprod) # No args
self.assertRaises(TypeError, sumprod, []) # One arg
self.assertRaises(TypeError, sumprod, [], [], []) # Three args
self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable
self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable

# Uneven lengths
self.assertRaises(ValueError, sumprod, [10, 20], [30])
self.assertRaises(ValueError, sumprod, [10], [20, 30])

# Error in iterator
def raise_after(n):
for i in range(n):
yield i
raise RuntimeError
with self.assertRaises(RuntimeError):
sumprod(range(10), raise_after(5))
with self.assertRaises(RuntimeError):
sumprod(raise_after(5), range(10))

# Error in multiplication
class BadMultiply:
def __mul__(self, other):
raise RuntimeError
def __rmul__(self, other):
raise RuntimeError
with self.assertRaises(RuntimeError):
sumprod([10, BadMultiply(), 30], [1, 2, 3])
with self.assertRaises(RuntimeError):
sumprod([1, 2, 3], [10, BadMultiply(), 30])

# Error in addition
with self.assertRaises(TypeError):
sumprod(['abc', 3], [5, 10])
with self.assertRaises(TypeError):
sumprod([5, 10], ['abc', 3])

# Special values should give the same as the pure python recipe
self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf)
self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf)
self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf)
self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf)
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf])))
self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3])))
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3])))
self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan])))

# Error cases that arose during development
args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952))
self.assertEqual(sumprod(*args), 0.0)


@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"sumprod() accuracy not guaranteed on machines with double rounding")
@support.cpython_only # Other implementations may choose a different algorithm
def test_sumprod_accuracy(self):
sumprod = math.sumprod
self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0)
self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0)
self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0)

def test_sumprod_stress(self):
sumprod = math.sumprod
product = itertools.product
Decimal = decimal.Decimal
Fraction = fractions.Fraction

class Int(int):
def __add__(self, other):
return Int(int(self) + int(other))
def __mul__(self, other):
return Int(int(self) * int(other))
__radd__ = __add__
__rmul__ = __mul__
def __repr__(self):
return f'Int({int(self)})'

class Flt(float):
def __add__(self, other):
return Int(int(self) + int(other))
def __mul__(self, other):
return Int(int(self) * int(other))
__radd__ = __add__
__rmul__ = __mul__
def __repr__(self):
return f'Flt({int(self)})'

def baseline_sumprod(p, q):
"""This defines the target behavior including expections and special values.
However, it is subject to rounding errors, so float inputs should be exactly
representable with only a few bits.
"""
total = 0
for p_i, q_i in zip(p, q, strict=True):
total += p_i * q_i
return total

def run(func, *args):
"Make comparing functions easier. Returns error status, type, and result."
try:
result = func(*args)
except (AssertionError, NameError):
raise
except Exception as e:
return type(e), None, 'None'
return None, type(result), repr(result)

pools = [
(-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)),
(5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125),
(-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333,
5.25, -3.25, -3.0*2**(-333), 3, 2**513),
(3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14,
9, 3+4j, Flt(13), 0.0),
(13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8),
Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)),
(Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0),
Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5),
(-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538,
2*2**-513),
(-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25),
(11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)),
]

for pool in pools:
for size in range(4):
for args1 in product(pool, repeat=size):
for args2 in product(pool, repeat=size):
args = (args1, args2)
self.assertEqual(
run(baseline_sumprod, *args),
run(sumprod, *args),
args,
)

def testModf(self):
self.assertRaises(TypeError, math.modf)

Expand Down
@@ -0,0 +1 @@
Add math.sumprod() to compute the sum of products.
39 changes: 38 additions & 1 deletion Modules/clinic/mathmodule.c.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 47b9f83

Please sign in to comment.