Skip to content

Commit

Permalink
bpo-37295: Optimize math.comb() and math.perm()
Browse files Browse the repository at this point in the history
Use a divide-and-conquer algorithm to benefit from Karatsuba
multiplication of large numbers.

Do calculations for comb() in C unsigned long long instead of Python
integers if possible.
  • Loading branch information
serhiy-storchaka committed Oct 20, 2021
1 parent 70945d5 commit 9fd696a
Show file tree
Hide file tree
Showing 3 changed files with 180 additions and 95 deletions.
5 changes: 5 additions & 0 deletions Doc/whatsnew/3.11.rst
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,11 @@ Optimizations
* Pure ASCII strings are now normalized in constant time by :func:`unicodedata.normalize`.
(Contributed by Dong-hee Na in :issue:`44987`.)

* :mod:`math` functions :func:`~math.comb` and :func:`~math.perm` are now up
to 10 times faster, or more, for large arguments (the speedup is larger for
larger *k*).
(Contributed by Serhiy Storchaka in :issue:`37295`.)


CPython bytecode changes
========================
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Optimize :func:`math.comb` and :func:`math.perm`.
269 changes: 174 additions & 95 deletions Modules/mathmodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -3211,6 +3211,127 @@ math_prod_impl(PyObject *module, PyObject *iterable, PyObject *start)
}


/* Number of permutations and combinations.
* P(n, k) = n! / (n-k)!
* C(n, k) = P(n, k) / k!
*/

/* Calculate C(n, k) for n in the 63-bit range.
 *
 * Both arguments are plain C integers.  The result is returned as a Python
 * integer object, or NULL if any intermediate C-API call returns NULL.
 *
 * Two strategies are used:
 *   - if k indexes into fast_comb_limits[] and n does not exceed the entry
 *     for that k, the value is computed directly in unsigned long long
 *     arithmetic;
 *   - otherwise the problem is halved with
 *         C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j),   j = k / 2
 *     and the pieces are combined as Python integers.
 */
static PyObject *
comb_small(unsigned long long n, unsigned long long k)
{
    /* fast_comb_limits[k] is the upper bound on n for which the pure C
     * computation below is used for that k. */
    static const unsigned long long fast_comb_limits[] = {
#if SIZEOF_LONG_LONG >= 8
        0, ULLONG_MAX, 4294967296ULL, 3329022, 102570, 13467, 3612, 1449, // 0-7
        746, 453, 308, 227, 178, 147, 125, 110, // 8-15
        99, 90, 84, 79, 75, 72, 69, 68, // 16-23
        66, 65, 64, 63, 63, 62, 62, 62, // 24-31
#elif SIZEOF_LONG_LONG >= 4
        0, ULLONG_MAX, 65536, 2049, 402, 161, 92, 63, // 0-7
        49, 42, 37, 34, 33, 31, 31, 30, // 8-15
#endif
    };

    if (k == 0) {
        return PyLong_FromLong(1);
    }

    /* Fast path: everything stays in unsigned long long, so no intermediate
     * PyLong objects are allocated. */
    if (k < Py_ARRAY_LENGTH(fast_comb_limits) && n <= fast_comb_limits[k]) {
        unsigned long long acc = n;
        /* Multiply by the next lower factor of n, then divide by the next
         * denominator 2, 3, ..., k. */
        for (unsigned long long denom = 2; denom <= k; denom++) {
            n -= 1;
            acc *= n;
            acc /= denom;
        }
        return PyLong_FromUnsignedLongLong(acc);
    }

    /* Slow path: split k in two and recurse. */
    unsigned long long half = k / 2;

    PyObject *left = comb_small(n, half);
    if (left == NULL) {
        return NULL;
    }

    PyObject *right = comb_small(n - half, k - half);
    if (right == NULL) {
        Py_DECREF(left);
        return NULL;
    }

    /* product = C(n, half) * C(n - half, k - half) */
    PyObject *product = PyNumber_Multiply(left, right);
    Py_DECREF(left);
    Py_DECREF(right);
    if (product == NULL) {
        return NULL;
    }

    PyObject *divisor = comb_small(k, half);
    if (divisor == NULL) {
        Py_DECREF(product);
        return NULL;
    }

    /* C(n, k) = product // C(k, half); the quotient may be NULL on failure
     * and is handed back to the caller as-is. */
    PyObject *quotient = PyNumber_FloorDivide(product, divisor);
    Py_DECREF(product);
    Py_DECREF(divisor);
    return quotient;
}

/* Calculate P(n, k) or C(n, k) using recursive formulas.
 *
 * n      - Python object used as the upper argument; it is only passed to
 *          PyNumber_Subtract() here or returned with its reference count
 *          incremented.
 * k      - number of factors, as a C integer.
 * iscomb - non-zero to compute C(n, k), zero to compute P(n, k).
 *
 * With j = k / 2 the recurrences are
 *     P(n, k) = P(n, j) * P(n-j, k-j)
 *     C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j)
 * Splitting the product in halves is more efficient than sequential
 * multiplication thanks to Karatsuba multiplication.
 *
 * Returns the result object, or NULL if any intermediate call returns NULL.
 */
static PyObject *
perm_comb(PyObject *n, unsigned long long k, int iscomb)
{
    if (k == 0) {
        return PyLong_FromLong(1);
    }
    if (k == 1) {
        /* P(n, 1) == C(n, 1) == n: hand back n itself. */
        Py_INCREF(n);
        return n;
    }

    unsigned long long half = k / 2;

    /* First half: P(n, half) or C(n, half). */
    PyObject *head = perm_comb(n, half, iscomb);
    if (head == NULL) {
        return NULL;
    }

    /* Build n - half as a Python object for the second half. */
    PyObject *offset = PyLong_FromUnsignedLongLong(half);
    if (offset == NULL) {
        Py_DECREF(head);
        return NULL;
    }
    PyObject *shifted = PyNumber_Subtract(n, offset);
    Py_DECREF(offset);
    if (shifted == NULL) {
        Py_DECREF(head);
        return NULL;
    }

    /* Second half: P(n - half, k - half) or C(n - half, k - half). */
    PyObject *tail = perm_comb(shifted, k - half, iscomb);
    Py_DECREF(shifted);
    if (tail == NULL) {
        Py_DECREF(head);
        return NULL;
    }

    PyObject *product = PyNumber_Multiply(head, tail);
    Py_DECREF(head);
    Py_DECREF(tail);

    /* For permutations the product is the answer; a failed multiplication
     * is likewise returned (as NULL) without further work. */
    if (!iscomb || product == NULL) {
        return product;
    }

    /* For combinations divide out C(k, half). */
    PyObject *divisor = comb_small(k, half);
    if (divisor == NULL) {
        Py_DECREF(product);
        return NULL;
    }
    PyObject *quotient = PyNumber_FloorDivide(product, divisor);
    Py_DECREF(product);
    Py_DECREF(divisor);
    return quotient;
}

/*[clinic input]
math.perm
Expand All @@ -3234,9 +3355,9 @@ static PyObject *
math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=e021a25469653e23 input=5311c5a00f359b53]*/
{
PyObject *result = NULL, *factor = NULL;
PyObject *result = NULL;
int overflow, cmp;
long long i, factors;
long long ki;

if (k == Py_None) {
return math_factorial(module, n);
Expand All @@ -3250,6 +3371,7 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));

if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -3271,57 +3393,29 @@ math_perm_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}

factors = PyLong_AsLongLongAndOverflow(k, &overflow);
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"k must not exceed %lld",
LLONG_MAX);
goto error;
}
else if (factors == -1) {
/* k is nonnegative, so a return value of -1 can only indicate error */
goto error;
}

if (factors == 0) {
result = PyLong_FromLong(1);
goto done;
}
assert(ki >= 0);

result = n;
Py_INCREF(result);
if (factors == 1) {
goto done;
}

factor = Py_NewRef(n);
PyObject *one = _PyLong_GetOne(); // borrowed ref
for (i = 1; i < factors; ++i) {
Py_SETREF(factor, PyNumber_Subtract(factor, one));
if (factor == NULL) {
goto error;
}
Py_SETREF(result, PyNumber_Multiply(result, factor));
if (result == NULL) {
goto error;
}
}
Py_DECREF(factor);
result = perm_comb(n, (unsigned long long)ki, 0);

done:
Py_DECREF(n);
Py_DECREF(k);
return result;

error:
Py_XDECREF(factor);
Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
}


/*[clinic input]
math.comb
Expand All @@ -3347,9 +3441,9 @@ static PyObject *
math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
/*[clinic end generated code: output=bd2cec8d854f3493 input=9a05315af2518709]*/
{
PyObject *result = NULL, *factor = NULL, *temp;
PyObject *result = NULL, *temp;
int overflow, cmp;
long long i, factors;
long long ki, ni;

n = PyNumber_Index(n);
if (n == NULL) {
Expand All @@ -3360,6 +3454,7 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
Py_DECREF(n);
return NULL;
}
assert(PyLong_CheckExact(n) && PyLong_CheckExact(k));

if (Py_SIZE(n) < 0) {
PyErr_SetString(PyExc_ValueError,
Expand All @@ -3372,82 +3467,66 @@ math_comb_impl(PyObject *module, PyObject *n, PyObject *k)
goto error;
}

/* k = min(k, n - k) */
temp = PyNumber_Subtract(n, k);
if (temp == NULL) {
goto error;
}
if (Py_SIZE(temp) < 0) {
Py_DECREF(temp);
result = PyLong_FromLong(0);
goto done;
}
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
if (cmp > 0) {
Py_SETREF(k, temp);
ni = PyLong_AsLongLongAndOverflow(n, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (!overflow) {
assert(ni >= 0);
ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow || ki > ni) {
result = PyLong_FromLong(0);
goto done;
}
assert(ki >= 0);
ki = Py_MIN(ki, ni - ki);
if (ki > 1) {
result = comb_small((unsigned long long)ni,
(unsigned long long)ki);
goto done;
}
/* For k == 1 just return the original n in perm_comb(). */
}
else {
Py_DECREF(temp);
if (cmp < 0) {
/* k = min(k, n - k) */
temp = PyNumber_Subtract(n, k);
if (temp == NULL) {
goto error;
}
}

factors = PyLong_AsLongLongAndOverflow(k, &overflow);
if (overflow > 0) {
PyErr_Format(PyExc_OverflowError,
"min(n - k, k) must not exceed %lld",
LLONG_MAX);
goto error;
}
if (factors == -1) {
/* k is nonnegative, so a return value of -1 can only indicate error */
goto error;
}

if (factors == 0) {
result = PyLong_FromLong(1);
goto done;
}

result = n;
Py_INCREF(result);
if (factors == 1) {
goto done;
}

factor = Py_NewRef(n);
PyObject *one = _PyLong_GetOne(); // borrowed ref
for (i = 1; i < factors; ++i) {
Py_SETREF(factor, PyNumber_Subtract(factor, one));
if (factor == NULL) {
goto error;
if (Py_SIZE(temp) < 0) {
Py_DECREF(temp);
result = PyLong_FromLong(0);
goto done;
}
Py_SETREF(result, PyNumber_Multiply(result, factor));
if (result == NULL) {
goto error;
cmp = PyObject_RichCompareBool(temp, k, Py_LT);
if (cmp > 0) {
Py_SETREF(k, temp);
}

temp = PyLong_FromUnsignedLongLong((unsigned long long)i + 1);
if (temp == NULL) {
goto error;
else {
Py_DECREF(temp);
if (cmp < 0) {
goto error;
}
}
Py_SETREF(result, PyNumber_FloorDivide(result, temp));
Py_DECREF(temp);
if (result == NULL) {

ki = PyLong_AsLongLongAndOverflow(k, &overflow);
assert(overflow >= 0 && !PyErr_Occurred());
if (overflow) {
PyErr_Format(PyExc_OverflowError,
"min(n - k, k) must not exceed %lld",
LLONG_MAX);
goto error;
}
assert(ki >= 0);
}
Py_DECREF(factor);

result = perm_comb(n, (unsigned long long)ki, 1);

done:
Py_DECREF(n);
Py_DECREF(k);
return result;

error:
Py_XDECREF(factor);
Py_XDECREF(result);
Py_DECREF(n);
Py_DECREF(k);
return NULL;
Expand Down

0 comments on commit 9fd696a

Please sign in to comment.