Possible optimizations for math.comb() #81476
The implementation of math.comb() is nice in that it limits the size of intermediate values as much as possible and it avoids unnecessary steps when min(k,n-k) is small with respect to n. There are some ways it can be made better or faster:
For example comb(100, 55) is currently computed as: Instead, it could be computed as: This gives a nice speed-up in exchange for a little memory in an array of constants (for n=75, we would need an array of length 75//2 after exploiting symmetry). Almost all cases should show some benefit, and in favorable cases like comb(76, 20) the speed-up would be nearly 75x.
$ python3.8 -m timeit -r 11 -s 'from math import comb, factorial as fact' -s 'n=100_000' -s 'k = n//2' 'comb(n, k)'
1 loop, best of 11: 1.52 sec per loop
$ python3.8 -m timeit -r 11 -s 'from math import comb, factorial as fact' -s 'n=100_000' -s 'k = n//2' 'fact(n) // (fact(k) * fact(n-k))'
1 loop, best of 11: 518 msec per loop
|
Optimizations 2 and 3 look something like this:

from math import comb, factorial

bc75 = [comb(75, r) for r in range(75//2+1)]
bc150 = [comb(150, r) for r in range(150//2+1)]
bc225 = [comb(225, r) for r in range(225//2+1)]

def comb(n, k):
    if n < 0 or k < 0: raise ValueError
    if k > n: return 0
    k = min(k, n-k)
    # Optimization 3: factorial method for large k in a bounded range of n.
    if k > n // 3 and n < 100_000:
        return factorial(n) // (factorial(k) * factorial(n-k))
    # Optimization 2: start from a precomputed row and walk up the diagonal,
    # using C(N+1, r+1) = C(N, r) * (N+1) // (r+1).
    if 75 <= n <= 75 + k:
        c, num, den, times = bc75[75-(n-k)], 75+1, 75-(n-k)+1, n-75
    elif 150 <= n <= 150 + k:
        c, num, den, times = bc150[150-(n-k)], 150+1, 150-(n-k)+1, n-150
    elif 225 <= n <= 225 + k:
        c, num, den, times = bc225[225-(n-k)], 225+1, 225-(n-k)+1, n-225
    else:
        c, num, den, times = 1, n-k+1, 1, k
    for i in range(times):
        c = c * num // den
        num += 1
        den += 1
    return c
|
(1), (4) and (5) sound good to me. For (1), it might make sense to ignore the 32-bit vs. 64-bit distinction and use uint64_t. (2) feels like too much extra complication to me, but that would become clearer with an implementation. For (3), I somewhat agree that the factorial method should be avoided. For (4), I don't see how/when the GIL could be released: doesn't the algorithm involve lots of memory allocations/deallocations and reference count adjustments? Can the suggested performance improvements go into 3.8, or should they wait for 3.9? It's not clear to me whether a performance improvement after feature freeze is okay or not. |
Performance improvements are what a beta build exists for in the first place. |
FWIW, here's a rough implementation of (3). It is limited to a specific range to avoid excess memory usage. For comb(50_000, 25_000), it gives a three-fold speedup:

if (k_long > 20 && k_long > n_long / 3 && n_long <= 100000) {
    PyObject *den = math_factorial(module, k);            /* den = k! */
    temp = PyNumber_Subtract(n, k);
    Py_SETREF(temp, math_factorial(module, temp));
    Py_SETREF(den, PyNumber_Multiply(den, temp));         /* den *= (n - k)! */
    Py_DECREF(temp);
    Py_SETREF(result, math_factorial(module, n));         /* result = n! */
    Py_SETREF(result, PyNumber_FloorDivide(result, den)); /* result //= k! * (n - k)! */
    Py_DECREF(den);
    return result;
}
Historically, we've used the beta phase for optimizations, tweaking APIs, and improving docs. However, I'm in no rush and this can easily wait for 3.9. My only concern is that the other math functions, except for factorial(), have bounded running times and memory usage, so performance is more of a concern for this function, which could end up being an unexpected bottleneck or a vector for a DoS attack. That said, we haven't had any negative reports regarding factorial(), so this may be a low priority. |
In real life, I expect 99.999%+ of calls will be made with small arguments, so (1) is worth it. I like Mark's suggestion to use uint64_t so the acceptable range doesn't depend on platform. At least in the world I live in, 32-bit boxes are all but extinct anyway. I honestly wouldn't bother with more than that. It's fun to optimize giant-integer algorithms with an ever-ballooning number of different clever approaches, but Python is an odd place for that. People looking for blazing fast giant-int facilities generally want lots & lots of them, so are better steered toward, e.g., gmp. That's its reason for existing. For example, their implementation of binomial coefficients uses special division algorithms exploiting that the quotient is exact (no remainder): https://gmplib.org/manual/Exact-Division.html#Exact-Division There's just no end to potential speedups. But in Python, I expect a vast majority of users will be happy if they get the right answer for the number of possible poker hands ;-) Oh ya - some smart kid will file a bug report about the inability to interrupt the calculation of a billion-bit result, so (4) is inevitable. Me, I'd wait for them to complain, and encourage _them_ to learn something useful by writing a patch to fix it :-) |
+1! |
Here is a more optimized PR inspired by PR 29020. It would take too long to explain how PR 29020 can be improved, so I wrote a new PR. Basically it implements Raymond's idea #1, but supports n>62 for smaller k.

How to calculate limits:

import math
n = m = 2**64
k = 1
while True:
    nmax = int(math.ceil((m * math.factorial(k-1)) ** (1/k) + (k-1)/2)) + 100
    n = min(n, nmax)
    while math.comb(n, k) * k >= m:
        n -= 1
    if n < 2*k: break
    print(k, n)
    k += 1
|
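The comb(n, k) * k >= m test in the limit search can be read as follows, assuming the 64-bit path uses the plain multiply-then-divide loop (a sketch, not necessarily the PR's exact loop; comb_loop_max is a hypothetical name). Step i multiplies C(n-k+i-1, i-1) by n-k+i, which gives C(n-k+i, i) * i before the exact division by i. That product grows with i, so the largest intermediate is exactly comb(n, k) * k, reached on the last step:

```python
import math

def comb_loop_max(n, k):
    """Return (comb(n, k), largest intermediate) for the multiply-then-divide loop.

    Hypothetical helper: step i turns C(n-k+i-1, i-1) into C(n-k+i, i)
    by multiplying by (n-k+i) and then dividing exactly by i.
    """
    k = min(k, n - k)
    result = 1
    largest = 1
    for i in range(1, k + 1):
        product = result * (n - k + i)   # equals C(n-k+i, i) * i
        largest = max(largest, product)
        result = product // i            # exact: no remainder
    return result, largest
```

For example, comb(62, 31) * 31 is below 2**64 while comb(63, 31) * 31 is not, so every intermediate of comb(62, 31) fits in a uint64_t but comb(63, 31) does not.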
Microbenchmarks:

$ ./python -m pyperf timeit -s 'from math import comb' '[comb(n, k) for n in range(63) for k in range(n+1)]'
Mean +- std dev: 1.57 ms +- 0.07 ms -> 209 us +- 11 us: 7.53x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(62, 31)'
Mean +- std dev: 2.95 us +- 0.14 us -> 296 ns +- 11 ns: 9.99x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(110, 15)'
Mean +- std dev: 1.33 us +- 0.06 us -> 95.8 ns +- 3.1 ns: 13.86x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(1449, 7)'
Mean +- std dev: 689 ns +- 33 ns -> 59.0 ns +- 3.2 ns: 11.69x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(3329022, 3)'
Mean +- std dev: 308 ns +- 19 ns -> 57.2 ns +- 4.2 ns: 5.39x faster

Now I want to try to optimize for larger arguments. Perhaps using the recursive formula C(n, k) = C(n, j)*C(n-j, k-j)//C(k, j) where j=k//2 could help. |
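A minimal pure-Python sketch of that recurrence (comb_dc is a hypothetical name; the k == 0 and k == 1 base cases stand in for the 64-bit fast path a C version would try first). The division is exact because C(n, j) * C(n-j, k-j) and C(n, k) * C(k, j) both equal n! / (j! (k-j)! (n-k)!), and halving k at each level keeps the recursion depth proportional to log k:

```python
import math

def comb_dc(n, k):
    """Divide-and-conquer binomial coefficient (sketch, assumes 0 <= k <= n).

    C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j) with j = k // 2.
    """
    k = min(k, n - k)
    if k == 0:
        return 1
    if k == 1:
        return n
    j = k // 2
    # C(n, j) * C(n-j, k-j) == C(n, k) * C(k, j), so the division is exact.
    return comb_dc(n, j) * comb_dc(n - j, k - j) // comb_dc(k, j)
```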
The divide-and-conquer approach works pretty well for larger n. For results slightly out of the 64-bit range:

$ ./python -m pyperf timeit -s 'from math import comb' 'comb(63, 31)'
Mean +- std dev: 2.80 us +- 0.14 us -> 388 ns +- 19 ns: 7.22x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(111, 15)'
Mean +- std dev: 1.24 us +- 0.06 us -> 215 ns +- 18 ns: 5.76x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(1450, 7)'
Mean +- std dev: 654 ns +- 45 ns -> 178 ns +- 13 ns: 3.67x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(3329023, 3)'
Mean +- std dev: 276 ns +- 15 ns -> 175 ns +- 11 ns: 1.58x faster

For very large n:

$ ./python -m pyperf timeit -s 'from math import comb' 'comb(2**100, 2**10)'
Mean +- std dev: 26.2 ms +- 1.7 ms -> 3.21 ms +- 0.20 ms: 8.16x faster
$ ./python -m pyperf timeit -s 'from math import comb' 'comb(2**1000, 2**10)'
Mean +- std dev: 704 ms +- 15 ms -> 103 ms +- 5 ms: 6.85x faster

And it is faster than using factorial:

$ ./python -m pyperf timeit -s 'from math import comb' 'comb(100_000, 50_000)'
Mean +- std dev: 1.61 sec +- 0.02 sec -> 177 ms +- 9 ms: 9.12x faster
$ ./python -m pyperf timeit -s 'from math import factorial as fact' 'fact(100_000) // (fact(50_000)*fact(50_000))'
Mean +- std dev: 507 ms +- 20 ms

math.perm() can benefit from reusing the same code:

$ ./python -m pyperf timeit -s 'from math import perm' 'perm(63, 31)'
Mean +- std dev: 1.35 us +- 0.07 us -> 1.18 us +- 0.06 us: 1.15x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(111, 15)'
Mean +- std dev: 601 ns +- 35 ns -> 563 ns +- 28 ns: 1.07x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(2**100, 2**10)'
Mean +- std dev: 5.96 ms +- 0.29 ms -> 2.32 ms +- 0.12 ms: 2.57x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(2**1000, 2**10)'
Mean +- std dev: 486 ms +- 14 ms -> 95.7 ms +- 4.2 ms: 5.08x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(100_000, 50_000)'
Mean +- std dev: 639 ms +- 23 ms -> 66.6 ms +- 3.2 ms: 9.60x faster

Even in the worst cases it is almost as fast as factorial:

$ ./python -m pyperf timeit -s 'from math import perm' 'perm(100_000, 100_000)'
Mean +- std dev: 2.55 sec +- 0.02 sec -> 187 ms +- 8 ms: 13.66x faster
$ ./python -m pyperf timeit -s 'from math import factorial' 'factorial(100_000)'
Mean +- std dev: 142 ms +- 7 ms |
And with optimization of math.perm() for small arguments:

$ ./python -m pyperf timeit -s 'from math import perm' 'perm(30, 14)'
Mean +- std dev: 524 ns +- 43 ns -> 66.7 ns +- 4.6 ns: 7.85x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(31, 14)'
Mean +- std dev: 522 ns +- 26 ns -> 127 ns +- 6 ns: 4.09x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(568, 7)'
Mean +- std dev: 318 ns +- 19 ns -> 62.9 ns +- 3.7 ns: 5.05x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(569, 7)'
Mean +- std dev: 311 ns +- 14 ns -> 114 ns +- 7 ns: 2.73x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(63, 31)'
Mean +- std dev: 1.36 us +- 0.08 us -> 263 ns +- 14 ns: 5.17x faster
$ ./python -m pyperf timeit -s 'from math import perm' 'perm(111, 15)'
Mean +- std dev: 595 ns +- 27 ns -> 126 ns +- 7 ns: 4.71x faster |
These speedups all seem to be significant and worth doing. |
I wrote a Python solution ("mycomb") that computes comb(100_000, 50_000) faster, maybe of interest:

1510.4 ms math.comb(n, k)

The idea:

comb(13, 6) = (13 * 12 * 11 * 10 * 9 * 8) / (6 * 5 * 4 * 3 * 2 * 1) = 13 * 1 * 11 * 1 * 3 * 4

It lists the numerator factors, then divides the denominator factors out of them (using primes), then just multiplies. Preparing the factors for the final multiplication took most of the time, about 23.1 ms. That part only needs numbers <= n, so it could be done with C ints and be much faster. If it's ten times faster, then mycomb in C would take 23.1/10 + (27.5-23.1) = 6.71 ms. See the comb_with_primes.py file. |
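The same idea can be sketched in a few lines (this is not the attached comb_with_primes.py; primes_upto and comb_cancel are hypothetical names). For each prime p <= k, Legendre's formula gives the number of factors of p in k!, and that many are divided out of the multiples of p among the numerator factors; because comb(n, k) is an integer, the numerator always has enough of them. Which multiples absorb the divisions can differ from the 13 * 1 * 11 * 1 * 3 * 4 example, but the product is the same:

```python
import math

def primes_upto(m):
    """Sieve of Eratosthenes: all primes p <= m."""
    if m < 2:
        return []
    sieve = bytearray([1]) * (m + 1)
    sieve[0] = sieve[1] = 0
    for p in range(2, math.isqrt(m) + 1):
        if sieve[p]:
            sieve[p * p :: p] = bytes(len(range(p * p, m + 1, p)))
    return [p for p in range(2, m + 1) if sieve[p]]

def comb_cancel(n, k):
    """Sketch: cancel the primes of k! out of the numerator factors, then multiply."""
    k = min(k, n - k)
    if k < 0:
        return 0
    low = n - k + 1
    factors = list(range(low, n + 1))   # numerator: (n-k+1) * ... * n
    for p in primes_upto(k):
        # Exponent of p in k! (Legendre's formula).
        e, q = 0, p
        while q <= k:
            e += k // q
            q *= p
        # Only multiples of p among the numerator factors can absorb factors of p.
        for i in range((-low) % p, k, p):
            while e and factors[i] % p == 0:
                factors[i] //= p
                e -= 1
            if not e:
                break
        assert e == 0  # guaranteed because comb(n, k) is an integer
    return math.prod(f for f in factors if f > 1)
```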
And for Raymond's case 4), about running very long and not responding to SIGINT, with n=1_000_000 and k=500_000:

150.91 seconds math.comb(n, k)

And for n=10_000_000 and k=5_000_000:

~4 hours *estimation* for math.comb(n, k)
Turns out for n=100_000, k=50_000, about 87% of my factors are 1, so they don't even need to be turned into Python ints for multiplication, improving the multiplication part to 3.05 ms. And a C++ version to produce the factors took 0.85 ms. Updated estimation: 1510.4 ms math.comb(n, k) |
For the small cases (say n < 50), we can get faster code by using a small (400-byte) table of factorial logarithms:

double lf[50] = [log2(factorial(n)) for n in range(50)];

Then the comb() and perm() functions can be computed quickly and in constant time using the C99 math functions:

result = PyLong_FromDouble(round(exp2(lf[n] - (lf[r] + lf[n-r])))); |
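A Python sketch of the log-table idea (LF, TABLE_SIZE and comb_from_logs are hypothetical names; Python's 2.0 ** x stands in for C's exp2, which the math module only gained in 3.11). Each table entry carries a rounding error, and the absolute error of the final value grows with the size of C(n, k), so this sketch is only checked against math.comb for small n; whether rounding stays exact over the full n < 50 range is not established here:

```python
import math

TABLE_SIZE = 50  # hypothetical: table of log2(n!) for n < 50
LF = [math.log2(math.factorial(n)) for n in range(TABLE_SIZE)]

def comb_from_logs(n, k):
    """Sketch: round(2 ** (lf[n] - (lf[k] + lf[n-k]))) in double precision."""
    if not 0 <= k <= n < TABLE_SIZE:
        raise ValueError("outside the precomputed table")
    return round(2.0 ** (LF[n] - (LF[k] + LF[n - k])))
```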
I ran some timings for comb(k, 67) on my macOS / Intel MacBook Pro, using timeit to time calls to a function that looked like this:

def f(comb):
    for k in range(68):
        for _ in range(256):
            comb(k, 67)
            comb(k, 67)
            ...  # 64 repetitions of comb(k, 67) in all

Based on 200 timings of this script with each of the popcount approach and the uint8_t-table-of-trailing-zero-counts approach (interleaved), the popcount approach won, but just barely, at around 1.3% faster. The result was statistically significant (SciPy gave me a result of Ttest_indResult(statistic=19.929941828072433, pvalue=8.570975609117687e-62)).

Interestingly, the default build on macOS/Intel is _not_ using the dedicated POPCNT instruction that arrived with the Nehalem architecture, presumably because it wants to produce builds that will still be usable on pre-Nehalem machines. It uses Clang's __builtin_popcount, but that gets translated to the same SIMD-within-a-register approach that we have already in pycore_bitutils.h. If I recompile with -msse4.2, then the POPCNT instruction *is* used, and I get an even more marginal improvement: a 1.7% speedup over the lookup-table-based version. |
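Both variants compute the same shift: the exponent of 2 in C(n, k), which is S[n] - S[k] - S[n-k] when S[m] is the exponent of 2 in m!. A sketch (all names hypothetical, assuming the uint8_t table holds S[m] as a running sum of trailing-zero counts) of why a popcount can replace the table: S[m] == m - popcount(m), so the m terms cancel and only popcounts remain:

```python
import math

def trailing_zeros(x):
    """Number of trailing zero bits of x > 0."""
    return (x & -x).bit_length() - 1

def popcount(x):
    return bin(x).count("1")

# Hypothetical table: TZ_FACT[m] = exponent of 2 in m!, i.e. the running sum of
# trailing_zeros(1..m). For m <= 127 every entry fits in a uint8_t.
TZ_FACT = [0]
for m in range(1, 128):
    TZ_FACT.append(TZ_FACT[-1] + trailing_zeros(m))

def shift_from_table(n, k):
    return TZ_FACT[n] - TZ_FACT[k] - TZ_FACT[n - k]

def shift_from_popcount(n, k):
    # (n - pc(n)) - (k - pc(k)) - ((n-k) - pc(n-k)) == pc(k) + pc(n-k) - pc(n)
    return popcount(k) + popcount(n - k) - popcount(n)
```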
[Mark]
I'm assuming you meant to write comb(67, k) instead, since the comb(k, 67) given is 0 at all tested k values except for k=67, and almost never executes any of the code in question. It's surprising to me that even the long-winded popcount code was faster! The other way needs to read up 3 1-byte values from a trailing zero table, but the long-winded popcount emulation needs to read up 4 4-byte mask constants (or are they embedded in the instruction stream?), in addition to doing many more bit-fiddling operations (4 shifts, 4 "&" masks, 3 add/subtract, and a multiply - compared to just 2 add/subtract). So if the results are right, Intel timings make no sense to me at all ;-) |
Aargh! That is of course what I meant, but not in fact what I timed. :-( I'll redo the timings. Please disregard the previous message. |
!!! Even more baffling then. Seems like the code posted got out of math_comb_impl() early here:

if (overflow || ki > ni) {
    result = PyLong_FromLong(0);
    goto done;
}

67 out of every 68 times comb() was called, before any actual ;-) computation was even tried. Yet one way was significantly faster than the other overall, despite that they were so rarely executed at all? Something ... seems off here ;-) |
PR 30305 applies Mark's algorithm for larger n (up to 127) depending on k, as was suggested by Raymond. Note that it uses a different table for limits, which maps k to the maximal n. |
Thanks Tim for spotting the stupid mistake. The reworked timings are a bit more ... plausible. tl;dr: On my machine, Raymond's suggestion gives a 2.2% speedup in the case where POPCNT is not available, and a 0.45% slowdown in the case that it _is_ available. Given that, and the fact that a single-instruction population count is not as readily available as I thought it was, I'd be happy to change the implementation to use the trailing zero counts as suggested. I'll attach the scripts I used for timing and analysis. There are two of them: "timecomb.py" produces a single timing. "driver.py" repeatedly switches branches, re-runs make, runs "timecomb.py", then assembles the results. I ran the driver.py script twice: once with a regular ./configure, and once with CFLAGS="-march=haswell".

Output from the script for the normal ./configure:
Output for CFLAGS="-march=haswell":
|
Thanks. I think that is a portability win and will make the code a lot easier to explain. |
Done in #74498 (though this will conflict with Serhiy's PR). |
I've been experimenting with a modification of Serhiy's recurrence relation, using a different value of j rather than j=k//2. The current design splits off three ways when it recurses, so the number of calls grows quickly. For C(200,100), C(225,112), and C(250,125), the underlying 64-bit modular arithmetic routine is called 115 times, 150 times, and 193 times respectively. But with another 2kb of precomputed values, it drops to 3, 16, and 26 calls. The main idea is to precompute one diagonal of Pascal's triangle, starting where the 64-bit mod arithmetic version leaves off and going through a limit as high as we want, depending on our tolerance for table size. A table for C(n, 20) where 67 < n <= 225 takes 2101 bytes. The new routine adds one line and modifies one line from the current code:

def C(n, k):
    k = min(k, n - k)
    if k == 0: return 1
    if k == 1: return n
    if k < len(k2n) and n <= k2n[k]: return ModArith64bit(n, k)
    if k == FixedJ and n <= Jlim: return lookup_known(n)  # New line
    j = min(k // 2, FixedJ)                               # Modified
    return C(n, j) * C(n-j, k-j) // C(k, j)

The benefit of pinning j to match the precomputed diagonal is that two of the three split-offs are to known values where no further work is necessary. Given a table for C(n, 20), we get:
A proof of concept is attached. To make it easy to experiment with, the precomputed diagonal is stored in a dictionary. At the bottom, I show an equivalent function to be used in a C version. It looks promising at this point, but I haven't run timings, so I am not sure this is a net win. |
One fixup:

+ j = FixedJ if k > FixedJ else k // 2

With that fix, the number of 64-bit mod arithmetic calls drops to 3, 4, and 20 for C(200,100), C(225,112), and C(250,125). This compares to 115, 150, and 193 calls in the current code. |
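A runnable sketch of the pinned-j recurrence with the fixup applied, under stated assumptions: FIXED_J = 20 and J_LIMIT = 250 are hypothetical values, the diagonal is a dict built with math.comb as a stand-in for precomputed constants, the k == 0 and k == 1 cases stand in for the 64-bit fast path, and j is only pinned while n is inside the table so that C(n, j) stays a lookup:

```python
import math

FIXED_J = 20    # hypothetical: the precomputed diagonal is C(n, 20)
J_LIMIT = 250   # hypothetical: the diagonal covers 2*FIXED_J <= n <= 250
# Stand-in for a table of precomputed constants.
KNOWN = {n: math.comb(n, FIXED_J) for n in range(2 * FIXED_J, J_LIMIT + 1)}

def C(n, k):
    """Binomial coefficient via the pinned-j recurrence (sketch, 0 <= k <= n)."""
    k = min(k, n - k)
    if k == 0:
        return 1
    if k == 1:
        return n
    if k == FIXED_J and n <= J_LIMIT:
        return KNOWN[n]                  # a value on the precomputed diagonal
    # Pin j to the diagonal only while C(n, j) is a table lookup; else halve k.
    j = FIXED_J if k > FIXED_J and n <= J_LIMIT else k // 2
    # C(n, j) * C(n-j, k-j) == C(n, k) * C(k, j), so the division is exact.
    return C(n, j) * C(n - j, k - j) // C(k, j)
```

With j pinned, C(n, 20) and (for k >= 40) C(k, 20) are lookups, so only the middle term recurses; its second argument drops by 20 per level, which is why the pinning has to be bounded.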
Just noting that comb_pole.py requires a development version of Python to run (under all released versions, a byteorder argument is required for int.to_bytes() and int.from_bytes() calls). |
Just posted an update that runs on 3.8 or later. |
Ran some timings on the pure python version with and without the precomputed diagonal: [C(n, k) for k in range(n+1)]

n=85 156 usec 160 usec

As expected, the best improvement comes in the range of the precomputed combinations: C(129, 20) through C(250, 20). To get a better idea of where the benefit is coming from, here is a trace from the recursive part of the new code:

>>> C(240, 112)
C(240, 112) = C(240, 20) * C(220, 92) // C(112, 20)
C(220, 92) = C(220, 20) * C(200, 72) // C(92, 20)
C(200, 72) = C(200, 20) * C(180, 52) // C(72, 20)
C(180, 52) = C(180, 20) * C(160, 32) // C(52, 20)
C(160, 32) = C(160, 20) * C(140, 12) // C(32, 20)
C(140, 12) = C(140, 6) * C(134, 6) // C(12, 6)
C(140, 6) = C(140, 3) * C(137, 3) // C(6, 3)
C(140, 3) = C(140, 1) * C(139, 2) // C(3, 1)
C(139, 2) = C(139, 1) * C(138, 1) // C(2, 1)
C(137, 3) = C(137, 1) * C(136, 2) // C(3, 1)
C(136, 2) = C(136, 1) * C(135, 1) // C(2, 1)
C(134, 6) = C(134, 3) * C(131, 3) // C(6, 3)
C(134, 3) = C(134, 1) * C(133, 2) // C(3, 1)
C(133, 2) = C(133, 1) * C(132, 1) // C(2, 1)
C(131, 3) = C(131, 1) * C(130, 2) // C(3, 1)
C(130, 2) = C(130, 1) * C(129, 1) // C(2, 1)
53425668586973006090333976236624193248148570813984466591034254245235155

Here is a trace for the old code without the precomputed diagonal:

>>> C(240, 112)
C(240, 112) = C(240, 56) * C(184, 56) // C(112, 56)
C(240, 56) = C(240, 28) * C(212, 28) // C(56, 28)
C(240, 28) = C(240, 14) * C(226, 14) // C(28, 14)
C(240, 14) = C(240, 7) * C(233, 7) // C(14, 7)
C(240, 7) = C(240, 3) * C(237, 4) // C(7, 3)
C(240, 3) = C(240, 1) * C(239, 2) // C(3, 1)
C(239, 2) = C(239, 1) * C(238, 1) // C(2, 1)
C(237, 4) = C(237, 2) * C(235, 2) // C(4, 2)
C(237, 2) = C(237, 1) * C(236, 1) // C(2, 1)
C(235, 2) = C(235, 1) * C(234, 1) // C(2, 1)
C(233, 7) = C(233, 3) * C(230, 4) // C(7, 3)
C(233, 3) = C(233, 1) * C(232, 2) // C(3, 1)
C(232, 2) = C(232, 1) * C(231, 1) // C(2, 1)
C(230, 4) = C(230, 2) * C(228, 2) // C(4, 2)
C(230, 2) = C(230, 1) * C(229, 1) // C(2, 1)
C(228, 2) = C(228, 1) * C(227, 1) // C(2, 1)
C(226, 14) = C(226, 7) * C(219, 7) // C(14, 7)
C(226, 7) = C(226, 3) * C(223, 4) // C(7, 3)
C(226, 3) = C(226, 1) * C(225, 2) // C(3, 1)
C(225, 2) = C(225, 1) * C(224, 1) // C(2, 1)
C(223, 4) = C(223, 2) * C(221, 2) // C(4, 2)
C(223, 2) = C(223, 1) * C(222, 1) // C(2, 1)
C(221, 2) = C(221, 1) * C(220, 1) // C(2, 1)
C(219, 7) = C(219, 3) * C(216, 4) // C(7, 3)
C(219, 3) = C(219, 1) * C(218, 2) // C(3, 1)
C(218, 2) = C(218, 1) * C(217, 1) // C(2, 1)
C(216, 4) = C(216, 2) * C(214, 2) // C(4, 2)
C(216, 2) = C(216, 1) * C(215, 1) // C(2, 1)
C(214, 2) = C(214, 1) * C(213, 1) // C(2, 1)
C(212, 28) = C(212, 14) * C(198, 14) // C(28, 14)
C(212, 14) = C(212, 7) * C(205, 7) // C(14, 7)
C(212, 7) = C(212, 3) * C(209, 4) // C(7, 3)
C(212, 3) = C(212, 1) * C(211, 2) // C(3, 1)
C(211, 2) = C(211, 1) * C(210, 1) // C(2, 1)
C(209, 4) = C(209, 2) * C(207, 2) // C(4, 2)
C(209, 2) = C(209, 1) * C(208, 1) // C(2, 1)
C(207, 2) = C(207, 1) * C(206, 1) // C(2, 1)
C(205, 7) = C(205, 3) * C(202, 4) // C(7, 3)
C(205, 3) = C(205, 1) * C(204, 2) // C(3, 1)
C(204, 2) = C(204, 1) * C(203, 1) // C(2, 1)
C(202, 4) = C(202, 2) * C(200, 2) // C(4, 2)
C(202, 2) = C(202, 1) * C(201, 1) // C(2, 1)
C(200, 2) = C(200, 1) * C(199, 1) // C(2, 1)
C(198, 14) = C(198, 7) * C(191, 7) // C(14, 7)
C(198, 7) = C(198, 3) * C(195, 4) // C(7, 3)
C(198, 3) = C(198, 1) * C(197, 2) // C(3, 1)
C(197, 2) = C(197, 1) * C(196, 1) // C(2, 1)
C(195, 4) = C(195, 2) * C(193, 2) // C(4, 2)
C(195, 2) = C(195, 1) * C(194, 1) // C(2, 1)
C(193, 2) = C(193, 1) * C(192, 1) // C(2, 1)
C(191, 7) = C(191, 3) * C(188, 4) // C(7, 3)
C(191, 3) = C(191, 1) * C(190, 2) // C(3, 1)
C(190, 2) = C(190, 1) * C(189, 1) // C(2, 1)
C(188, 4) = C(188, 2) * C(186, 2) // C(4, 2)
C(188, 2) = C(188, 1) * C(187, 1) // C(2, 1)
C(186, 2) = C(186, 1) * C(185, 1) // C(2, 1)
C(184, 56) = C(184, 28) * C(156, 28) // C(56, 28)
C(184, 28) = C(184, 14) * C(170, 14) // C(28, 14)
C(184, 14) = C(184, 7) * C(177, 7) // C(14, 7)
C(184, 7) = C(184, 3) * C(181, 4) // C(7, 3)
C(184, 3) = C(184, 1) * C(183, 2) // C(3, 1)
C(183, 2) = C(183, 1) * C(182, 1) // C(2, 1)
C(181, 4) = C(181, 2) * C(179, 2) // C(4, 2)
C(181, 2) = C(181, 1) * C(180, 1) // C(2, 1)
C(179, 2) = C(179, 1) * C(178, 1) // C(2, 1)
C(177, 7) = C(177, 3) * C(174, 4) // C(7, 3)
C(177, 3) = C(177, 1) * C(176, 2) // C(3, 1)
C(176, 2) = C(176, 1) * C(175, 1) // C(2, 1)
C(174, 4) = C(174, 2) * C(172, 2) // C(4, 2)
C(174, 2) = C(174, 1) * C(173, 1) // C(2, 1)
C(172, 2) = C(172, 1) * C(171, 1) // C(2, 1)
C(170, 14) = C(170, 7) * C(163, 7) // C(14, 7)
C(170, 7) = C(170, 3) * C(167, 4) // C(7, 3)
C(170, 3) = C(170, 1) * C(169, 2) // C(3, 1)
C(169, 2) = C(169, 1) * C(168, 1) // C(2, 1)
C(167, 4) = C(167, 2) * C(165, 2) // C(4, 2)
C(167, 2) = C(167, 1) * C(166, 1) // C(2, 1)
C(165, 2) = C(165, 1) * C(164, 1) // C(2, 1)
C(163, 7) = C(163, 3) * C(160, 4) // C(7, 3)
C(163, 3) = C(163, 1) * C(162, 2) // C(3, 1)
C(162, 2) = C(162, 1) * C(161, 1) // C(2, 1)
C(160, 4) = C(160, 2) * C(158, 2) // C(4, 2)
C(160, 2) = C(160, 1) * C(159, 1) // C(2, 1)
C(158, 2) = C(158, 1) * C(157, 1) // C(2, 1)
C(156, 28) = C(156, 14) * C(142, 14) // C(28, 14)
C(156, 14) = C(156, 7) * C(149, 7) // C(14, 7)
C(156, 7) = C(156, 3) * C(153, 4) // C(7, 3)
C(156, 3) = C(156, 1) * C(155, 2) // C(3, 1)
C(155, 2) = C(155, 1) * C(154, 1) // C(2, 1)
C(153, 4) = C(153, 2) * C(151, 2) // C(4, 2)
C(153, 2) = C(153, 1) * C(152, 1) // C(2, 1)
C(151, 2) = C(151, 1) * C(150, 1) // C(2, 1)
C(149, 7) = C(149, 3) * C(146, 4) // C(7, 3)
C(149, 3) = C(149, 1) * C(148, 2) // C(3, 1)
C(148, 2) = C(148, 1) * C(147, 1) // C(2, 1)
C(146, 4) = C(146, 2) * C(144, 2) // C(4, 2)
C(146, 2) = C(146, 1) * C(145, 1) // C(2, 1)
C(144, 2) = C(144, 1) * C(143, 1) // C(2, 1)
C(142, 14) = C(142, 7) * C(135, 7) // C(14, 7)
C(142, 7) = C(142, 3) * C(139, 4) // C(7, 3)
C(142, 3) = C(142, 1) * C(141, 2) // C(3, 1)
C(141, 2) = C(141, 1) * C(140, 1) // C(2, 1)
C(139, 4) = C(139, 2) * C(137, 2) // C(4, 2)
C(139, 2) = C(139, 1) * C(138, 1) // C(2, 1)
C(137, 2) = C(137, 1) * C(136, 1) // C(2, 1)
C(135, 7) = C(135, 3) * C(132, 4) // C(7, 3)
C(135, 3) = C(135, 1) * C(134, 2) // C(3, 1)
C(134, 2) = C(134, 1) * C(133, 1) // C(2, 1)
C(132, 4) = C(132, 2) * C(130, 2) // C(4, 2)
C(132, 2) = C(132, 1) * C(131, 1) // C(2, 1)
C(130, 2) = C(130, 1) * C(129, 1) // C(2, 1)
C(112, 56) = C(112, 28) * C(84, 28) // C(56, 28)
C(112, 28) = C(112, 14) * C(98, 14) // C(28, 14)
C(84, 28) = C(84, 14) * C(70, 14) // C(28, 14)
53425668586973006090333976236624193248148570813984466591034254245235155 |
A feature of the current code is that, while the recursion tree can be very wide, it's not tall, with max depth proportional to the log of k. But it's proportional to k in the proposal (the C(n-j, k-j) term's second argument goes down by at most 20 per recursion level). So, e.g., C(1000000, 500000) dies with RecursionError in Python; in C, whatever platform-specific weird things can happen when the C stack is blown. The width of the recursion tree could be slashed in any case by "just dealing with" (say) k <= 20 directly, no matter how large n is. Do the obvious loop with k-1 multiplies, and k-1 divides by the tiny ints in range(2, k+1). Note that "almost all" the calls in your "trace for the old code" listing are due to recurring for k <= 20. Or use Stefan's method: if limited to k <= 20, it only requires a tiny precomputed table of the 8 primes <= 20, and a stack array to hold range(n, n-k, -1); that can be arranged to keep Karatsuba in play if n is large. An irony is that the primary point of the recurrence is to get Karatsuba multiplication into play, but the ints involved in computing C(240, 112) are nowhere near big enough to trigger that. To limit recursion depth, I think you have to change your approach to decide in advance the deepest you're willing to risk letting it go, and keep the current j = k // 2 whenever repeatedly subtracting 20 could exceed that. |
Okay, will set a cap on the n where a fixedj is used. Also, making a direct computation for k<20 is promising. |

def comb64(n, k):
    'comb(n, k) in multiplicative group modulo 64-bits'
    return (F[n] * Finv[k] * Finv[n-k] & (2**64-1)) << (S[n] - S[k] - S[n - k])

def comb_iterative(n, k):
    'Straight multiply and divide when k is small.'
    result = 1
    for r in range(1, k+1):
        result *= n - r + 1
        result //= r
    return result

def C(n, k):
    k = min(k, n - k)
    if k == 0: return 1
    if k == 1: return n
    if k < len(k2n) and n <= k2n[k]: return comb64(n, k)  # 64-bit fast case
    if k == FixedJ and n <= Jlim: return KnownComb[n]    # Precomputed diagonal
    if k < 10: return comb_iterative(n, k)                # Non-recursive for small k
    j = FixedJ if k > FixedJ and n <= Jlim else k // 2
    return C(n, j) * C(n-j, k-j) // C(k, j)               # Recursive case
|
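The comb64() above relies on tables F, Finv and S that are not shown. One way to build them, as a sketch (NMAX and MASK64 are assumed names): write m! = 2**S[m] * odd(m!), keep odd(m!) mod 2**64 in F and its inverse mod 2**64 in Finv (it exists because the value is odd). Then the odd part of C(n, k) is F[n] * Finv[k] * Finv[n-k] mod 2**64, which is exact whenever C(n, k) < 2**64:

```python
import math

MASK64 = (1 << 64) - 1
NMAX = 128  # hypothetical table size; the thread mentions n up to 127

# F[m]    : odd part of m!, reduced mod 2**64
# Finv[m] : inverse of F[m] mod 2**64 (exists because F[m] is odd)
# S[m]    : exponent of 2 in m!
F, Finv, S = [1], [1], [0]
for m in range(1, NMAX):
    tz = (m & -m).bit_length() - 1          # trailing zero bits of m
    F.append(F[-1] * (m >> tz) & MASK64)    # multiply in the odd part of m
    Finv.append(pow(F[-1], -1, 1 << 64))    # modular inverse, Python 3.8+
    S.append(S[-1] + tz)

def comb64(n, k):
    """Exact only when comb(n, k) < 2**64 and n < NMAX."""
    odd = F[n] * Finv[k] * Finv[n - k] & MASK64   # odd part of comb(n, k)
    return odd << (S[n] - S[k] - S[n - k])
```

Every C(n, k) with n <= 67 fits in 64 bits, so comb64 is exact for all of those; for larger n it is only valid for the k allowed by a limits table such as k2n.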
I was thinking about comb(1000000, 500000). The simple "* --n / ++k" loop does 499,999 each of multiplication and division, and in all instances the second operand is a single Python digit. Cheap as can be. In contrast, despite that it short-circuits all "small k" cases, comb_pole2.py executes
7,299,598 times. Far more top-level arithmetic operations, including extra-expensive long division with both operands multi-digit. At the very top of the tree, "// C(500000, 250000)" is dividing out nearly half a million bits. But a tiny Python function coding the simple loop takes about 150 seconds, while Raymond's C() about 30 (under the released 3.10.1). The benefits from provoking Karatsuba are major. Just for the heck of it, I coded a complication of the simple loop that just tries to provoke Karatsuba on the numerator (prod(range(n, n-k, -1))), and dividing that just once at the end, by factorial(k). That dropped the time to about 17.5 seconds. Over 80% of that time is spent computing the "return" expression, where len(p) is 40 and only 7 entries pass the x>1 test:
That is, with really big results, almost everything is mostly noise compared to the time needed to do the relative handful of * and // on the very largest intermediate results. Stefan's code runs much faster still, but requires storage proportional to k. The crude hack I used needs a fixed (independent of k and n) and small amount of temp storage, basically grouping inputs into buckets roughly based on the log _of_ their log, and keeping only one int per bucket. Here's the code:

def pcomb2(n, k):
    from math import prod, factorial

    PLEN = 40       # number of buckets (len(p) in the case above)
    p = [1] * PLEN  # one int per bucket; 1 means "empty"

    def fold_into_p(x):
        if x == 1:
            return
        while True:
            i = max(x.bit_length().bit_length() - 5, 0)
            if p[i] == 1:
                p[i] = x
                break
            x *= p[i]
            p[i] = 1

    def showp():
        for i in range(PLEN):
            pi = p[i]
            if pi > 1:
                print(i, pi.bit_length())

    for i in range(k):
        fold_into_p(n)
        n -= 1
    showp()
    return prod(x for x in p if x > 1) // factorial(k)

I'm not sure it's practical. For example, while the list p[] can be kept quite small, factorial(k) can require a substantial amount of temp storage of its own - factorial(500000) in the case at hand is an int approaching 9 million bits. Note: again in the case at hand, computing it via
takes about 10% longer than Raymond's C() function. |
Another trick, building on the last one: computing factorial(k) isn't cheap, in time or space, and neither is dividing by it. But we know it will entirely cancel out. Indeed, for each outer loop iteration, prod(p) is divisible by the current k. But, unlike in Stefan's code, which materializes range(n, n-k, -1) as an explicit list, we have no way to calculate "in advance" which elements of p[] are divisible by what. What we _can_ do is march over all of p[], and do a gcd of each element with the current k. If greater than 1, it can be divided out of both that element of p[], and the current k. Later, rinse, repeat - the current k must eventually be driven to 1 then. But that slows things down: gcd() is also expensive. But there's a standard trick to speed that too: as in serious implementations of Pollard's rho factorization method, "chunk it". That is, don't do it on every outer loop iteration, but instead accumulate the running product of several denominators first, then do the expensive gcd pass on that product. Here's a replacement for "the main loop" of the last code that delays doing gcds until the running product is at least 2000 bits:

    from math import gcd  # needed in addition to prod

    fold_into_p(n)
    kk = 1
    for k in range(2, k+1):
        n -= 1
        # Merge into p[].
        fold_into_p(n)
        # Divide by k.
        kk *= k
        if kk.bit_length() < 2000:
            continue
        for i, pi in enumerate(p):
            if pi > 1:
                g = gcd(pi, kk)
                if g > 1:
                    p[i] = pi // g
                    kk //= g
                    if kk == 1:
                        break
        assert kk == 1
    showp()
    return prod(x for x in p if x > 1) // kk

That runs in under half the time (for n=1000000, k=500000), down to under 7.5 seconds. And, of course, the largest denominator consumes only about 2000 bits instead of 500000!'s 8,744,448 bits. Raising the kk bit limit from 2000 to 10000 cuts another 2.5 seconds off, down to about 5 seconds. Much above that, it starts getting slower again. Seems too hard to out-think! And highly dubious to fine-tune it based on a single input case ;-) Curious: at a cutoff of 10000 bits, we're beyond the point where Karatsuba would have paid off for computing denominator partial products too. |
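For anyone who wants to try this, here is a sketch that assembles pcomb2 and the gcd-chunked loop into one self-contained function (pcomb_gcd and the chunk_bits parameter are hypothetical names; 40 buckets follows the len(p) mentioned earlier). After each gcd pass the leftover kk is 1: every p[i] the pass touches ends up coprime to what remains of kk, while prod(p) stays divisible by kk, so nothing else is possible:

```python
import math

def pcomb_gcd(n, k, chunk_bits=2000):
    """comb(n, k) from a bucketed falling product, cancelling k! by chunked gcds.

    Sketch only: assumes n >= 0 and k >= 0.
    """
    k = min(k, n - k)
    if k < 0:
        return 0
    if k == 0:
        return 1
    p = [1] * 40  # one int per bucket; 1 marks an empty bucket

    def fold_into_p(x):
        # Merge x into the bucket chosen by the bit length of its bit length,
        # carrying the product upward until an empty bucket is found.
        if x == 1:
            return
        while True:
            i = max(x.bit_length().bit_length() - 5, 0)
            if p[i] == 1:
                p[i] = x
                return
            x *= p[i]
            p[i] = 1

    fold_into_p(n)
    kk = 1  # product of denominators not yet cancelled
    for d in range(2, k + 1):
        n -= 1
        fold_into_p(n)
        kk *= d
        if kk.bit_length() < chunk_bits:
            continue
        # One gcd pass over the buckets cancels all of kk.
        for i, pi in enumerate(p):
            if pi > 1:
                g = math.gcd(pi, kk)
                if g > 1:
                    p[i] = pi // g
                    kk //= g
                    if kk == 1:
                        break
        assert kk == 1
    return math.prod(x for x in p if x > 1) // kk
```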
All this should be tested with the C implementation because the relative cost of operations is different in C and Python. I have tested Raymond's idea about using an iterative algorithm for small k.

$ ./python -m pyperf timeit -s 'from math import comb' "comb(3329023, 3)"
recursive: Mean +- std dev: 173 ns +- 9 ns
iterative: Mean +- std dev: 257 ns +- 13 ns
$ ./python -m pyperf timeit -s 'from math import comb' "comb(102571, 4)"
recursive: Mean +- std dev: 184 ns +- 10 ns
iterative: Mean +- std dev: 390 ns +- 29 ns
$ ./python -m pyperf timeit -s 'from math import comb' "comb(747, 8)"
recursive: Mean +- std dev: 187 ns +- 10 ns
iterative: Mean +- std dev: 708 ns +- 39 ns

The recursive algorithm is always faster than the iterative one for k>2 (they are equal for k=1 and k=2). And it is not only because of division, because for perm() we have the same difference.

$ ./python -m pyperf timeit -s 'from math import perm' "perm(2642247, 3)"
recursive: Mean +- std dev: 118 ns +- 7 ns
iterative: Mean +- std dev: 147 ns +- 8 ns
$ ./python -m pyperf timeit -s 'from math import perm' "perm(65538, 4)"
recursive: Mean +- std dev: 130 ns +- 9 ns
iterative: Mean +- std dev: 203 ns +- 13 ns
$ ./python -m pyperf timeit -s 'from math import perm' "perm(260, 8)"
recursive: Mean +- std dev: 131 ns +- 10 ns
iterative: Mean +- std dev: 324 ns +- 16 ns

As for the idea about using a table for fixed k=20, note that comb(87, 20) exceeds 64 bits, so we will need to use a table of 128-bit integers. And I am not sure if this algorithm will be faster than the recursive one. We may achieve better results at lower cost if we extend Mark's algorithm to use 128-bit integers. I am not sure whether it is worth it; the current code is good enough and covers a wide range of cases. Additional optimizations will likely have a lower benefit/effort ratio. |
comb(n, k) can be computed as perm(n, k) // factorial(k).

$ ./python -m timeit -r1 -n1 -s 'from math import comb' "comb(1000000, 500000)"
recursive: 1 loop, best of 1: 9.16 sec per loop
iterative: 1 loop, best of 1: 164 sec per loop
$ ./python -m timeit -r1 -n1 -s 'from math import perm, factorial' "perm(1000000, 500000) // factorial(500000)"
recursive: 1 loop, best of 1: 19.8 sec per loop
iterative: 1 loop, best of 1: 137 sec per loop

It is slightly faster than dividing on every step if we use the iterative algorithm, but still much slower than the recursive algorithm. And the latter is faster because it performs many small divisions and keeps intermediate results smaller. |
Ya, I don't expect anyone will check in a change without doing comparative timings in C first. Not worried about that. I'd be happy to declare victory and move on at this point ;-) But that's me. Near the start of this, I noted that we just won't compete with GMP's vast array of tricks. I noted that they use a special routine for division when it's known in advance that the remainder is 0 (as it is, e.g., in every division performed by "our" recursion (which GMP also uses, in some cases)). But I didn't let on just how much that can buy them. Under 3.10.1 on my box, the final division alone in math.factorial(1000000) // math.factorial(500000)**2 takes over 20 seconds. But a pure Python implementation of what I assume (don't know for sure) is the key idea(*) in their exact-division algorithm does the same thing in under 0.4 seconds. Huge difference - and while the pure Python version only ever wants the lower N bits of an NxN product, there's no real way to do that in pure Python except via throwing away the higher N bits of a double-width int product. In C, of course, the high half of the bits wouldn't be computed to begin with. (*) Modular arithmetic again. Given n and d such that it's known n = q*d for some integer q, shift n and d right until d is odd. q is unaffected. A good upper bound on the bit length of q is then n.bit_length() - d.bit_length() + 1. Do the remaining work modulo 2 raised to that power. Call that base B. We "merely" need to solve for q in the equation n = q*d (mod B). Because d is odd, I = pow(d, -1, B) exists. Just multiply both sides by I to get n * I = q (mod B). No divisions of any kind are needed. More, there's also a very efficient, division-free algorithm for finding an inverse modulo a power of 2. To start with, every odd int is its own inverse mod 8, so we start with 3 good bits. A modular Newton-like iteration can double the number of correct bits on each iteration. 
But I won't post code (unless someone asks) because I don't want to encourage anyone :-) |
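The footnote's idea can also be sketched directly with Python's built-in modular inverse, pow(d, -1, B) (available since Python 3.8), in place of the Newton iteration; exact_div is a hypothetical name, and the preconditions are d > 0, n > 0 and n % d == 0:

```python
def exact_div(n, d):
    """Sketch of exact division n // d, assuming d > 0, n > 0 and n % d == 0.

    No division is performed: q is recovered as n * d**-1 modulo a power of 2.
    """
    # Shift out the factors of 2 in d (n has at least as many); q is unchanged.
    tz = (d & -d).bit_length() - 1
    n >>= tz
    d >>= tz
    bits = n.bit_length() - d.bit_length() + 1   # q has at most this many bits
    B = 1 << bits
    # n == q * d (mod B) and d is odd, so q == n * d**-1 (mod B), and q < B.
    return (n % B) * pow(d, -1, B) % B
```

With those preconditions, exact_div(n, d) == n // d; the only work is one modular inverse and one multiplication modulo B.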
Okay, I'll ask. |
OK, here's the last version I had. Preconditions are that d > 0, n > 0, and n % d == 0. This version tries to use the narrowest possible integers on each step. The lowermost

Taking out all the modular stuff, the body of the loop boils down to just

dinv = dinv * (2 - dinv * d)

For insight, if

dinv * d = 1 + k * 2**i

for some k and i (IOW, if dinv * d = 1 modulo 2**i), then

2 - dinv * d = 1 - k * 2**i

and so d * dinv times that equals (1 + k * 2**i) * (1 - k * 2**i) = 1 - k**2 * 2**(2*i). Or, IOW, the next value of dinv is such that d * dinv = 1 modulo 2**(2*i) - it's good to twice as many bits.

def ediv(n, d):
    assert d
    def makemask(n):
        return (1 << n) - 1
    if d & 1 == 0:
        ntz = (d & -d).bit_length() - 1
        n >>= ntz
        d >>= ntz
    bits_needed = n.bit_length() - d.bit_length() + 1
    good_bits = 3
    dinv = d & 7
    while good_bits < bits_needed:
        twice = min(2 * good_bits, bits_needed)
        twomask = makemask(twice)
        fac2 = dinv * (d & twomask)
        fac2 &= twomask
        fac2 = (2 - fac2) & twomask
        dinv = (dinv * fac2) & twomask
        good_bits = twice
    goodmask = makemask(bits_needed)
    return ((dinv & goodmask) * (n & goodmask)) & goodmask
|