From 47b9f83a83db288c652e43567c7b0f74d87a29be Mon Sep 17 00:00:00 2001 From: Raymond Hettinger Date: Sat, 7 Jan 2023 12:46:35 -0600 Subject: [PATCH] GH-100485: Add math.sumprod() (GH-100677) --- Doc/library/itertools.rst | 11 +- Doc/library/math.rst | 16 + Lib/test/test_math.py | 166 +++++++++ ...-01-01-21-54-46.gh-issue-100485.geNrHS.rst | 1 + Modules/clinic/mathmodule.c.h | 39 ++- Modules/mathmodule.c | 325 ++++++++++++++++++ 6 files changed, 548 insertions(+), 10 deletions(-) create mode 100644 Misc/NEWS.d/next/Library/2023-01-01-21-54-46.gh-issue-100485.geNrHS.rst diff --git a/Doc/library/itertools.rst b/Doc/library/itertools.rst index 4ad679dfccdbb4..e7d2e13626fa52 100644 --- a/Doc/library/itertools.rst +++ b/Doc/library/itertools.rst @@ -33,7 +33,7 @@ by combining :func:`map` and :func:`count` to form ``map(f, count())``. These tools and their built-in counterparts also work well with the high-speed functions in the :mod:`operator` module. For example, the multiplication operator can be mapped across two vectors to form an efficient dot-product: -``sum(map(operator.mul, vector1, vector2))``. +``sum(starmap(operator.mul, zip(vec1, vec2, strict=True)))``. **Infinite iterators:** @@ -838,10 +838,6 @@ which incur interpreter overhead. "Returns the sequence elements n times" return chain.from_iterable(repeat(tuple(iterable), n)) - def dotproduct(vec1, vec2): - "Compute a sum of products." - return sum(starmap(operator.mul, zip(vec1, vec2, strict=True))) - def convolve(signal, kernel): # See: https://betterexplained.com/articles/intuitive-convolution/ # convolve(data, [0.25, 0.25, 0.25, 0.25]) --> Moving average (blur) @@ -852,7 +848,7 @@ which incur interpreter overhead. window = collections.deque([0], maxlen=n) * n for x in chain(signal, repeat(0, n-1)): window.append(x) - yield dotproduct(kernel, window) + yield math.sumprod(kernel, window) def polynomial_from_roots(roots): """Compute a polynomial's coefficients from its roots. @@ -1211,9 +1207,6 @@ which incur interpreter overhead. >>> list(ncycles('abc', 3)) ['a', 'b', 'c', 'a', 'b', 'c', 'a', 'b', 'c'] - >>> dotproduct([1,2,3], [4,5,6]) - 32 - >>> data = [20, 40, 24, 32, 20, 28, 16] >>> list(convolve(data, [0.25, 0.25, 0.25, 0.25])) [5.0, 15.0, 21.0, 29.0, 29.0, 26.0, 24.0, 16.0, 11.0, 4.0] diff --git a/Doc/library/math.rst b/Doc/library/math.rst index aeebcaf6ab0864..0e888c4d4e423f 100644 --- a/Doc/library/math.rst +++ b/Doc/library/math.rst @@ -291,6 +291,22 @@ Number-theoretic and representation functions .. versionadded:: 3.7 +.. function:: sumprod(p, q) + + Return the sum of products of values from two iterables *p* and *q*. + + Raises :exc:`ValueError` if the inputs do not have the same length. + + Roughly equivalent to:: + + sum(itertools.starmap(operator.mul, zip(p, q, strict=true))) + + For float and mixed int/float inputs, the intermediate products + and sums are computed with extended precision. + + .. versionadded:: 3.12 + + .. function:: trunc(x) Return *x* with the fractional part diff --git a/Lib/test/test_math.py b/Lib/test/test_math.py index bf0d0a56e6ac8b..65fe169671eae2 100644 --- a/Lib/test/test_math.py +++ b/Lib/test/test_math.py @@ -4,6 +4,7 @@ from test.support import verbose, requires_IEEE_754 from test import support import unittest +import fractions import itertools import decimal import math @@ -1202,6 +1203,171 @@ def testLog10(self): self.assertEqual(math.log(INF), INF) self.assertTrue(math.isnan(math.log10(NAN))) + def testSumProd(self): + sumprod = math.sumprod + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + # Core functionality + self.assertEqual(sumprod(iter([10, 20, 30]), (1, 2, 3)), 140) + self.assertEqual(sumprod([1.5, 2.5], [3.5, 4.5]), 16.5) + self.assertEqual(sumprod([], []), 0) + + # Type preservation and coercion + for v in [ + (10, 20, 30), + (1.5, -2.5), + (Fraction(3, 5), Fraction(4, 5)), + (Decimal(3.5), Decimal(4.5)), + (2.5, 10), # float/int + (2.5, Fraction(3, 5)), # float/fraction + (25, Fraction(3, 5)), # int/fraction + (25, Decimal(4.5)), # int/decimal + ]: + for p, q in [(v, v), (v, v[::-1])]: + with self.subTest(p=p, q=q): + expected = sum(p_i * q_i for p_i, q_i in zip(p, q, strict=True)) + actual = sumprod(p, q) + self.assertEqual(expected, actual) + self.assertEqual(type(expected), type(actual)) + + # Bad arguments + self.assertRaises(TypeError, sumprod) # No args + self.assertRaises(TypeError, sumprod, []) # One arg + self.assertRaises(TypeError, sumprod, [], [], []) # Three args + self.assertRaises(TypeError, sumprod, None, [10]) # Non-iterable + self.assertRaises(TypeError, sumprod, [10], None) # Non-iterable + + # Uneven lengths + self.assertRaises(ValueError, sumprod, [10, 20], [30]) + self.assertRaises(ValueError, sumprod, [10], [20, 30]) + + # Error in iterator + def raise_after(n): + for i in range(n): + yield i + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod(range(10), raise_after(5)) + with self.assertRaises(RuntimeError): + sumprod(raise_after(5), range(10)) + + # Error in multiplication + class BadMultiply: + def __mul__(self, other): + raise RuntimeError + def __rmul__(self, other): + raise RuntimeError + with self.assertRaises(RuntimeError): + sumprod([10, BadMultiply(), 30], [1, 2, 3]) + with self.assertRaises(RuntimeError): + sumprod([1, 2, 3], [10, BadMultiply(), 30]) + + # Error in addition + with self.assertRaises(TypeError): + sumprod(['abc', 3], [5, 10]) + with self.assertRaises(TypeError): + sumprod([5, 10], ['abc', 3]) + + # Special values should give the same as the pure python recipe + self.assertEqual(sumprod([10.1, math.inf], [20.2, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, 30.3]), math.inf) + self.assertEqual(sumprod([10.1, math.inf], [math.inf, math.inf]), math.inf) + self.assertEqual(sumprod([10.1, -math.inf], [20.2, 30.3]), -math.inf) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [-math.inf, math.inf]))) + self.assertTrue(math.isnan(sumprod([10.1, math.nan], [20.2, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [math.nan, 30.3]))) + self.assertTrue(math.isnan(sumprod([10.1, math.inf], [20.3, math.nan]))) + + # Error cases that arose during development + args = ((-5, -5, 10), (1.5, 4611686018427387904, 2305843009213693952)) + self.assertEqual(sumprod(*args), 0.0) + + + @requires_IEEE_754 + @unittest.skipIf(HAVE_DOUBLE_ROUNDING, + "sumprod() accuracy not guaranteed on machines with double rounding") + @support.cpython_only # Other implementations may choose a different algorithm + def test_sumprod_accuracy(self): + sumprod = math.sumprod + self.assertEqual(sumprod([0.1] * 10, [1]*10), 1.0) + self.assertEqual(sumprod([0.1] * 20, [True, False] * 10), 1.0) + self.assertEqual(sumprod([1.0, 10E100, 1.0, -10E100], [1.0]*4), 2.0) + + def test_sumprod_stress(self): + sumprod = math.sumprod + product = itertools.product + Decimal = decimal.Decimal + Fraction = fractions.Fraction + + class Int(int): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Int({int(self)})' + + class Flt(float): + def __add__(self, other): + return Int(int(self) + int(other)) + def __mul__(self, other): + return Int(int(self) * int(other)) + __radd__ = __add__ + __rmul__ = __mul__ + def __repr__(self): + return f'Flt({int(self)})' + + def baseline_sumprod(p, q): + """This defines the target behavior including expections and special values. + However, it is subject to rounding errors, so float inputs should be exactly + representable with only a few bits. + """ + total = 0 + for p_i, q_i in zip(p, q, strict=True): + total += p_i * q_i + return total + + def run(func, *args): + "Make comparing functions easier. Returns error status, type, and result." + try: + result = func(*args) + except (AssertionError, NameError): + raise + except Exception as e: + return type(e), None, 'None' + return None, type(result), repr(result) + + pools = [ + (-5, 10, -2**20, 2**31, 2**40, 2**61, 2**62, 2**80, 1.5, Int(7)), + (5.25, -3.5, 4.75, 11.25, 400.5, 0.046875, 0.25, -1.0, -0.078125), + (-19.0*2**500, 11*2**1000, -3*2**1500, 17*2*333, + 5.25, -3.25, -3.0*2**(-333), 3, 2**513), + (3.75, 2.5, -1.5, float('inf'), -float('inf'), float('NaN'), 14, + 9, 3+4j, Flt(13), 0.0), + (13.25, -4.25, Decimal('10.5'), Decimal('-2.25'), Fraction(13, 8), + Fraction(-11, 16), 4.75 + 0.125j, 97, -41, Int(3)), + (Decimal('6.125'), Decimal('12.375'), Decimal('-2.75'), Decimal(0), + Decimal('Inf'), -Decimal('Inf'), Decimal('NaN'), 12, 13.5), + (-2.0 ** -1000, 11*2**1000, 3, 7, -37*2**32, -2*2**-537, -2*2**-538, + 2*2**-513), + (-7 * 2.0 ** -510, 5 * 2.0 ** -520, 17, -19.0, -6.25), + (11.25, -3.75, -0.625, 23.375, True, False, 7, Int(5)), + ] + + for pool in pools: + for size in range(4): + for args1 in product(pool, repeat=size): + for args2 in product(pool, repeat=size): + args = (args1, args2) + self.assertEqual( + run(baseline_sumprod, *args), + run(sumprod, *args), + args, + ) + def testModf(self): self.assertRaises(TypeError, math.modf) diff --git a/Misc/NEWS.d/next/Library/2023-01-01-21-54-46.gh-issue-100485.geNrHS.rst b/Misc/NEWS.d/next/Library/2023-01-01-21-54-46.gh-issue-100485.geNrHS.rst new file mode 100644 index 00000000000000..9f6e593113bb54 --- /dev/null +++ b/Misc/NEWS.d/next/Library/2023-01-01-21-54-46.gh-issue-100485.geNrHS.rst @@ -0,0 +1 @@ +Add math.sumprod() to compute the sum of products. diff --git a/Modules/clinic/mathmodule.c.h b/Modules/clinic/mathmodule.c.h index 9fac1037e52528..1f9725883b9820 100644 --- a/Modules/clinic/mathmodule.c.h +++ b/Modules/clinic/mathmodule.c.h @@ -333,6 +333,43 @@ math_dist(PyObject *module, PyObject *const *args, Py_ssize_t nargs) return return_value; } +PyDoc_STRVAR(math_sumprod__doc__, +"sumprod($module, p, q, /)\n" +"--\n" +"\n" +"Return the sum of products of values from two iterables p and q.\n" +"\n" +"Roughly equivalent to:\n" +"\n" +" sum(itertools.starmap(operator.mul, zip(p, q, strict=True)))\n" +"\n" +"For float and mixed int/float inputs, the intermediate products\n" +"and sums are computed with extended precision."); + +#define MATH_SUMPROD_METHODDEF \ + {"sumprod", _PyCFunction_CAST(math_sumprod), METH_FASTCALL, math_sumprod__doc__}, + +static PyObject * +math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q); + +static PyObject * +math_sumprod(PyObject *module, PyObject *const *args, Py_ssize_t nargs) +{ + PyObject *return_value = NULL; + PyObject *p; + PyObject *q; + + if (!_PyArg_CheckPositional("sumprod", nargs, 2, 2)) { + goto exit; + } + p = args[0]; + q = args[1]; + return_value = math_sumprod_impl(module, p, q); + +exit: + return return_value; +} + PyDoc_STRVAR(math_pow__doc__, "pow($module, x, y, /)\n" "--\n" @@ -917,4 +954,4 @@ math_ulp(PyObject *module, PyObject *arg) exit: return return_value; } -/*[clinic end generated code: output=c2c2f42452d63734 input=a9049054013a1b77]*/ +/*[clinic end generated code: output=899211ec70e4506c input=a9049054013a1b77]*/ diff --git a/Modules/mathmodule.c b/Modules/mathmodule.c index 49c0293d4f5ce3..0bcb336efbb03a 100644 --- a/Modules/mathmodule.c +++ b/Modules/mathmodule.c @@ -68,6 +68,7 @@ raised for division by zero and mod by zero. #include /* For _Py_log1p with workarounds for buggy handling of zeros. */ #include "_math.h" +#include #include "clinic/mathmodule.c.h" @@ -2819,6 +2820,329 @@ For example, the hypotenuse of a 3/4/5 right triangle is:\n\ 5.0\n\ "); +/** sumprod() ***************************************************************/ + +/* Forward declaration */ +static inline int _check_long_mult_overflow(long a, long b); + +static inline bool +long_add_would_overflow(long a, long b) +{ + return (a > 0) ? (b > LONG_MAX - a) : (b < LONG_MIN - a); +} + +/* +Double length extended precision floating point arithmetic +based on ideas from three sources: + + Improved Kahan–Babuška algorithm by Arnold Neumaier + https://www.mat.univie.ac.at/~neum/scan/01.pdf + + A Floating-Point Technique for Extending the Available Precision + by T. J. Dekker + https://csclub.uwaterloo.ca/~pbarfuss/dekker1971.pdf + + Ultimately Fast Accurate Summation by Siegfried M. Rump + https://www.tuhh.de/ti3/paper/rump/Ru08b.pdf + +The double length routines allow for quite a bit of instruction +level parallelism. On a 3.22 Ghz Apple M1 Max, the incremental +cost of increasing the input vector size by one is 6.25 nsec. + +dl_zero() returns an extended precision zero +dl_split() exactly splits a double into two half precision components. +dl_add() performs compensated summation to keep a running total. +dl_mul() implements lossless multiplication of doubles. +dl_fma() implements an extended precision fused-multiply-add. +dl_to_d() converts from extended precision to double precision. + +*/ + +typedef struct{ double hi; double lo; } DoubleLength; + +static inline DoubleLength +dl_zero() +{ + return (DoubleLength) {0.0, 0.0}; +} +static inline DoubleLength +dl_add(DoubleLength total, double x) +{ + double s = total.hi + x; + double c = total.lo; + if (fabs(total.hi) >= fabs(x)) { + c += (total.hi - s) + x; + } else { + c += (x - s) + total.hi; + } + return (DoubleLength) {s, c}; +} + +static inline DoubleLength +dl_split(double x) { + double t = x * 134217729.0; /* Veltkamp constant = float(0x8000001) */ + double hi = t - (t - x); + double lo = x - hi; + return (DoubleLength) {hi, lo}; +} + +static inline DoubleLength +dl_mul(double x, double y) +{ + /* Dekker mul12(). Section (5.12) */ + DoubleLength xx = dl_split(x); + DoubleLength yy = dl_split(y); + double p = xx.hi * yy.hi; + double q = xx.hi * yy.lo + xx.lo * yy.hi; + double z = p + q; + double zz = p - z + q + xx.lo * yy.lo; + return (DoubleLength) {z, zz}; +} + +static inline DoubleLength +dl_fma(DoubleLength total, double p, double q) +{ + DoubleLength product = dl_mul(p, q); + total = dl_add(total, product.hi); + return dl_add(total, product.lo); +} + +static inline double +dl_to_d(DoubleLength total) +{ + return total.hi + total.lo; +} + +/*[clinic input] +math.sumprod + + p: object + q: object + / + +Return the sum of products of values from two iterables p and q. + +Roughly equivalent to: + + sum(itertools.starmap(operator.mul, zip(p, q, strict=True))) + +For float and mixed int/float inputs, the intermediate products +and sums are computed with extended precision. +[clinic start generated code]*/ + +static PyObject * +math_sumprod_impl(PyObject *module, PyObject *p, PyObject *q) +/*[clinic end generated code: output=6722dbfe60664554 input=82be54fe26f87e30]*/ +{ + PyObject *p_i = NULL, *q_i = NULL, *term_i = NULL, *new_total = NULL; + PyObject *p_it, *q_it, *total; + iternextfunc p_next, q_next; + bool p_stopped = false, q_stopped = false; + bool int_path_enabled = true, int_total_in_use = false; + bool flt_path_enabled = true, flt_total_in_use = false; + long int_total = 0; + DoubleLength flt_total = dl_zero(); + + p_it = PyObject_GetIter(p); + if (p_it == NULL) { + return NULL; + } + q_it = PyObject_GetIter(q); + if (q_it == NULL) { + Py_DECREF(p_it); + return NULL; + } + total = PyLong_FromLong(0); + if (total == NULL) { + Py_DECREF(p_it); + Py_DECREF(q_it); + return NULL; + } + p_next = *Py_TYPE(p_it)->tp_iternext; + q_next = *Py_TYPE(q_it)->tp_iternext; + while (1) { + bool finished; + + assert (p_i == NULL); + assert (q_i == NULL); + assert (term_i == NULL); + assert (new_total == NULL); + + assert (p_it != NULL); + assert (q_it != NULL); + assert (total != NULL); + + p_i = p_next(p_it); + if (p_i == NULL) { + if (PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { + goto err_exit; + } + PyErr_Clear(); + } + p_stopped = true; + } + q_i = q_next(q_it); + if (q_i == NULL) { + if (PyErr_Occurred()) { + if (!PyErr_ExceptionMatches(PyExc_StopIteration)) { + goto err_exit; + } + PyErr_Clear(); + } + q_stopped = true; + } + if (p_stopped != q_stopped) { + PyErr_Format(PyExc_ValueError, "Inputs are not the same length"); + goto err_exit; + } + finished = p_stopped & q_stopped; + + if (int_path_enabled) { + + if (!finished && PyLong_CheckExact(p_i) & PyLong_CheckExact(q_i)) { + int overflow; + long int_p, int_q, int_prod; + + int_p = PyLong_AsLongAndOverflow(p_i, &overflow); + if (overflow) { + goto finalize_int_path; + } + int_q = PyLong_AsLongAndOverflow(q_i, &overflow); + if (overflow) { + goto finalize_int_path; + } + if (_check_long_mult_overflow(int_p, int_q)) { + goto finalize_int_path; + } + int_prod = int_p * int_q; + if (long_add_would_overflow(int_total, int_prod)) { + goto finalize_int_path; + } + int_total += int_prod; + int_total_in_use = true; + Py_CLEAR(p_i); + Py_CLEAR(q_i); + continue; + } + + finalize_int_path: + // # We're finished, overflowed, or have a non-int + int_path_enabled = false; + if (int_total_in_use) { + term_i = PyLong_FromLong(int_total); + if (term_i == NULL) { + goto err_exit; + } + new_total = PyNumber_Add(total, term_i); + if (new_total == NULL) { + goto err_exit; + } + Py_SETREF(total, new_total); + new_total = NULL; + Py_CLEAR(term_i); + int_total = 0; // An ounce of prevention, ... + int_total_in_use = false; + } + } + + if (flt_path_enabled) { + + if (!finished) { + double flt_p, flt_q; + bool p_type_float = PyFloat_CheckExact(p_i); + bool q_type_float = PyFloat_CheckExact(q_i); + if (p_type_float && q_type_float) { + flt_p = PyFloat_AS_DOUBLE(p_i); + flt_q = PyFloat_AS_DOUBLE(q_i); + } else if (p_type_float && (PyLong_CheckExact(q_i) || PyBool_Check(q_i))) { + /* We care about float/int pairs and int/float pairs because + they arise naturally in several use cases such as price + times quantity, measurements with integer weights, or + data selected by a vector of bools. */ + flt_p = PyFloat_AS_DOUBLE(p_i); + flt_q = PyLong_AsDouble(q_i); + if (flt_q == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + goto finalize_flt_path; + } + } else if (q_type_float && (PyLong_CheckExact(p_i) || PyBool_Check(q_i))) { + flt_q = PyFloat_AS_DOUBLE(q_i); + flt_p = PyLong_AsDouble(p_i); + if (flt_p == -1.0 && PyErr_Occurred()) { + PyErr_Clear(); + goto finalize_flt_path; + } + } else { + goto finalize_flt_path; + } + DoubleLength new_flt_total = dl_fma(flt_total, flt_p, flt_q); + if (isfinite(new_flt_total.hi)) { + flt_total = new_flt_total; + flt_total_in_use = true; + Py_CLEAR(p_i); + Py_CLEAR(q_i); + continue; + } + } + + finalize_flt_path: + // We're finished, overflowed, have a non-float, or got a non-finite value + flt_path_enabled = false; + if (flt_total_in_use) { + term_i = PyFloat_FromDouble(dl_to_d(flt_total)); + if (term_i == NULL) { + goto err_exit; + } + new_total = PyNumber_Add(total, term_i); + if (new_total == NULL) { + goto err_exit; + } + Py_SETREF(total, new_total); + new_total = NULL; + Py_CLEAR(term_i); + flt_total = dl_zero(); + flt_total_in_use = false; + } + } + + assert(!int_total_in_use); + assert(!flt_total_in_use); + if (finished) { + goto normal_exit; + } + term_i = PyNumber_Multiply(p_i, q_i); + if (term_i == NULL) { + goto err_exit; + } + new_total = PyNumber_Add(total, term_i); + if (new_total == NULL) { + goto err_exit; + } + Py_SETREF(total, new_total); + new_total = NULL; + Py_CLEAR(p_i); + Py_CLEAR(q_i); + Py_CLEAR(term_i); + } + + normal_exit: + Py_DECREF(p_it); + Py_DECREF(q_it); + return total; + + err_exit: + Py_DECREF(p_it); + Py_DECREF(q_it); + Py_DECREF(total); + Py_XDECREF(p_i); + Py_XDECREF(q_i); + Py_XDECREF(term_i); + Py_XDECREF(new_total); + return NULL; +} + + /* pow can't use math_2, but needs its own wrapper: the problem is that an infinite result can arise either as a result of overflow (in which case OverflowError should be raised) or as a result of @@ -3933,6 +4257,7 @@ static PyMethodDef math_methods[] = { {"sqrt", math_sqrt, METH_O, math_sqrt_doc}, {"tan", math_tan, METH_O, math_tan_doc}, {"tanh", math_tanh, METH_O, math_tanh_doc}, + MATH_SUMPROD_METHODDEF MATH_TRUNC_METHODDEF MATH_PROD_METHODDEF MATH_PERM_METHODDEF