
Possible optimizations for math.comb() #81476

Open
rhettinger opened this issue Jun 15, 2019 · 73 comments
Labels
3.11 · 3.12 · performance (Performance or resource usage) · stdlib (Python modules in the Lib dir)

Comments

@rhettinger
Contributor

rhettinger commented Jun 15, 2019

BPO 37295
Nosy @tim-one, @rhettinger, @mdickinson, @serhiy-storchaka, @PedanticHacker, @mcognetta, @pochmann
PRs
  • bpo-37295: Add fast path for small n in math.comb #29020
  • bpo-37295: Optimize math.comb() for small arguments #29030
  • bpo-37295: Optimize math.comb() and math.perm() #29090
  • bpo-37295: Speed up math.comb(n, k) for 0 <= k <= n <= 67 #30275
  • bpo-37295: Use constant-time comb() for larger n depending on k #30305
  • bpo-37295: More direct computation of power-of-two factor in math.comb #30313
Files
  • comb_with_primes.py
  • comb_with_side_limits.py: Compute limits for comb_small
  • timecomb.py
  • driver.py
  • comb_pole.py: Experiment with precomputed diagonal -- runs on 3.8 or later
  • comb_pole2.py: Precomputed k diagonal and handling for small k
  • Note: these values reflect the state of the issue at the time it was migrated and might not reflect the current state.


    GitHub fields:

    assignee = None
    closed_at = None
    created_at = <Date 2019-06-15.19:17:40.265>
    labels = ['library', 'performance']
    title = 'Possible optimizations for math.comb()'
    updated_at = <Date 2022-01-23.22:42:15.922>
    user = 'https://github.com/rhettinger'

    bugs.python.org fields:

    activity = <Date 2022-01-23.22:42:15.922>
    actor = 'tim.peters'
    assignee = 'none'
    closed = False
    closed_date = None
    closer = None
    components = ['Library (Lib)']
    creation = <Date 2019-06-15.19:17:40.265>
    creator = 'rhettinger'
    dependencies = []
    files = ['50439', '50526', '50530', '50531', '50557', '50559']
    hgrepos = []
    issue_num = 37295
    keywords = ['patch']
    message_count = 73.0
    messages = ['345711', '345731', '345761', '345784', '345876', '345952', '345956', '404184', '404208', '404429', '404472', '406300', '406337', '406338', '406341', '407736', '408867', '408952', '408954', '408962', '408970', '408984', '408989', '408994', '409000', '409001', '409002', '409010', '409024', '409028', '409062', '409095', '409110', '409116', '409154', '409254', '409268', '409269', '409270', '409274', '409277', '409298', '409315', '409317', '409320', '409321', '409331', '409346', '409360', '409373', '409376', '409377', '409382', '409383', '409393', '409417', '409432', '410145', '410339', '410347', '410382', '410383', '410387', '410420', '410448', '410460', '410543', '410587', '410925', '410931', '411343', '411427', '411430']
    nosy_count = 7.0
    nosy_names = ['tim.peters', 'rhettinger', 'mark.dickinson', 'serhiy.storchaka', 'PedanticHacker', 'mcognetta', 'Stefan Pochmann']
    pr_nums = ['29020', '29030', '29090', '30275', '30305', '30313']
    priority = 'normal'
    resolution = None
    stage = 'patch review'
    status = 'open'
    superseder = None
    type = 'performance'
    url = 'https://bugs.python.org/issue37295'
    versions = []

    @rhettinger
    Contributor Author

    rhettinger commented Jun 15, 2019

    The implementation of math.comb() is nice in that it limits the size of the intermediate values as much as possible, and it avoids unnecessary steps when min(k, n-k) is small with respect to n.

    There are some ways it can be made better or faster:

    1. For small values of n, there is no need to use PyLong objects at every step. Plain C integers would suffice and we would no longer need to make repeated allocations for all the intermediate values. For 32-bit builds, this could be done for n<=30. For 64-bit builds, it applies to n<=62. Adding this fast path would likely give a 10x to 50x speed-up.

    2. For slightly larger values of n, the computation could be sped-up by precomputing one or more rows of binomial coefficients (perhaps for n=75, n=150, and n=225). The calculation could then start from that row instead of from higher rows on Pascal's triangle.

    For example comb(100, 55) is currently computed as:
    comb(55,1) * 56 // 2 * 57 // 3 ... 100 // 45 <== 45 steps

    Instead, it could be computed as:
    comb(75,20) * 76 // 21 * 77 // 22 ... 100 // 45 <== 25 steps
    ^-- found by table lookup

    This gives a nice speed-up in exchange for a little memory in an array of constants (for n=75, we would need an array of length 75//2 after exploiting symmetry). Almost all cases should show some benefit, and in favorable cases like comb(76, 20) the speed-up would be nearly 75x.

    3. When k is close to n/2, the current algorithm is slower than just computing (n!) / (k! * (n-k)!). However, the factorial method comes at the cost of more memory usage for large n. The factorial method consumes memory proportional to n*log2(n) while the current early-cancellation method uses memory proportional to n+log2(n). Above some threshold for memory pain, the current method should always be preferred. I'm not sure the factorial method should be used at all, but it is embarrassing that factorial calls sometimes beat the current C implementation:
        $ python3.8 -m timeit -r 11 -s 'from math import comb, factorial as fact' -s 'n=100_000' -s 'k = n//2' 'comb(n, k)'
        1 loop, best of 11: 1.52 sec per loop
        $ python3.8 -m timeit -r 11 -s 'from math import comb, factorial as fact' -s 'n=100_000' -s 'k = n//2' 'fact(n) // (fact(k) * fact(n-k))'
        1 loop, best of 11: 518 msec per loop
    4. For values such as n=1_000_000 and k=500_000, the running time is very long and the routine doesn't respond to SIGINT. We could add checks for keyboard interrupts for large n. Also consider releasing the GIL.

    5. The inner loop currently does a pure Python subtraction that could in many cases be done with plain C integers. When n is smaller than maxsize, we could have a code path that replaces "PyNumber_Subtract(factor, _PyLong_One)" with something like "PyLong_FromUnsignedLongLong((unsigned long long)n - i)".
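In outline, the small-n fast path of idea 1 could look like the following pure-Python model (a hypothetical sketch only, not CPython's actual C code; a real implementation would work in uint64_t and pre-check n against a limit rather than testing after the fact):

```python
from math import comb

U64_MAX = 2**64 - 1

def comb_small(n, k):
    """One-at-a-time multiplicative loop with a 64-bit cap -- a pure-Python
    model of what a C fast path might do (hypothetical sketch)."""
    if k > n:
        return 0
    k = min(k, n - k)
    c = 1
    for i in range(1, k + 1):
        c = c * (n - k + i) // i      # the division is exact at every step
        if c > U64_MAX:               # a C version would pre-check n instead
            raise OverflowError("result exceeds 64 bits")
    return c

# Every C(n, k) with n <= 67 fits in an unsigned 64-bit word:
assert all(comb_small(n, k) == comb(n, k)
           for n in range(68) for k in range(n + 1))
```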

    @rhettinger rhettinger added 3.8 3.9 stdlib Python modules in the Lib dir performance Performance or resource usage labels Jun 15, 2019
    @rhettinger
    Contributor Author

    rhettinger commented Jun 16, 2019

    Optimizations 2 and 3 look something like this:

        from math import comb, factorial   # the tables below use math.comb

        bc75 = [comb(75, r) for r in range(75//2+1)]
        bc150 = [comb(150, r) for r in range(150//2+1)]
        bc225 = [comb(225, r) for r in range(225//2+1)]
    
        def comb(n, k):
            if n < 0 or k < 0: raise ValueError
            if k > n: return 0
            k = min(k, n-k)
            if k > n // 3 and n < 100_000:
                return factorial(n) // (factorial(k) * factorial(n-k))
            if 75 <= n <= 75 + k:
                c, num, den, times = bc75[75-(n-k)], 75+1, 75-(n-k)+1, n-75
            elif 150 <= n <= 150 + k:
                c, num, den, times = bc150[150-(n-k)], 150+1, 150-(n-k)+1, n-150
            elif 225 <= n <= 225 + k:
                c, num, den, times = bc225[225-(n-k)], 225+1, 225-(n-k)+1, n-225
            else:
                c, num, den, times = 1, n-k+1, 1, k
            for i in range(times):
                c = c * num // den
                num += 1
                den += 1
            return c

    @mdickinson
    Member

    mdickinson commented Jun 16, 2019

    (1), (4) and (5) sound good to me.

    For (1), it might make sense to ignore the 32-bit vs. 64-bit distinction and use uint64_t for the internal computations. Then we can do up to n = 62 regardless of platform.

    (2) feels like too much extra complication to me, but that would become clearer with an implementation.

    For (3), I somewhat agree that the factorial method should be avoided.

    For (4), I don't see how/when the GIL could be released: doesn't the algorithm involve lots of memory allocations/deallocations and reference count adjustments?

    Can the suggested performance improvements go into 3.8, or should they wait for 3.9? It's not clear to me whether a performance improvement after feature freeze is okay or not.

    @PedanticHacker
    Mannequin

    PedanticHacker mannequin commented Jun 16, 2019

    Performance improvements are what a beta build exists for in the first place.

    @rhettinger
    Contributor Author

    rhettinger commented Jun 17, 2019

    FWIW, here's a rough implementation of (3). It is limited to a specific range to avoid excess memory usage. For comb(50_000, 25_000), it gives a three-fold speedup:

        if (k_long > 20 && k_long > n_long / 3 && n_long <= 100000) {
            PyObject *den = math_factorial(module, k);                 /* den = k! */
            temp = PyNumber_Subtract(n, k);
            Py_SETREF(temp, math_factorial(module, temp));
            Py_SETREF(den, PyNumber_Multiply(den, temp));              /* den *= (n - k)! */
            Py_DECREF(temp);
            Py_SETREF(result, math_factorial(module, n));              /* result = n! */
            Py_SETREF(result, PyNumber_FloorDivide(result, den));      /* result //= k! * (n-k)! */
            Py_DECREF(den);
            return result;
        }

    Can the suggested performance improvements go into 3.8, or should they wait for 3.9?
    It's not clear to me whether a performance improvement after feature freeze is okay or not.

    Historically, we've used the beta phase for optimizations, tweaking APIs, and improving docs. However, I'm in no rush and this can easily wait for 3.9.

    My only concern is that the other math functions, except for factorial(), have bounded running times and memory usage, so performance is more of a concern for this function, which could end up being an unexpected bottleneck or a vector for a DoS attack. That said, we haven't had any negative reports regarding factorial(), so this may be a low priority.

    @tim-one
    Member

    tim-one commented Jun 18, 2019

    In real life, I expect 99.999%+ of calls will be made with small arguments, so (1) is worth it. I like Mark's suggestion to use uint64_t so the acceptable range doesn't depend on platform. At least in the world I live in, 32-bit boxes are all but extinct anyway.

    I honestly wouldn't bother with more than that. It's fun to optimize giant-integer algorithms with an ever-ballooning number of different clever approaches, but Python is an odd place for that. People looking for blazing fast giant-int facilities generally want lots & lots of them, so are better steered toward, e.g., gmp. That's its reason for existing.

    For example, their implementation of binomial coefficients uses special division algorithms exploiting that the quotient is exact (no remainder):

    https://gmplib.org/manual/Exact-Division.html#Exact-Division
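The core trick behind exact division can be modeled in a few lines of Python (a rough sketch of the idea only; GMP's actual algorithms are considerably more elaborate): when d is odd and divides a exactly, the quotient can be recovered with a single multiplication modulo the word size, with no hardware division at all.

```python
def exact_div_u64(a, d):
    """Exact division a // d for odd d dividing a, modeled at 64 bits.
    Since d is odd, it is invertible mod 2**64, and a * inv(d) mod 2**64
    equals the true quotient whenever that quotient fits in 64 bits."""
    assert d % 2 == 1 and a % d == 0 and a < 2**64
    inv = pow(d, -1, 2**64)        # modular inverse; Python 3.8+
    return (a * inv) % 2**64

assert exact_div_u64(2598960 * 9, 9) == 2598960
```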

    There's just no end to potential speedups. But in Python, I expect a vast majority of users will be happy if they get the right answer for the number of possible poker hands ;-)

    Oh ya - some smart kid will file a bug report about the inability to interrupt the calculation of a billion-bit result, so (4) is inevitable. Me, I'd wait for them to complain, and encourage _them_ to learn something useful by writing a patch to fix it :-)

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Jun 18, 2019

    Me, I'd wait for them to complain, and encourage _them_ to learn something useful by writing a patch to fix it :-)

    +1!

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Oct 18, 2021

    Here is a more optimized PR, inspired by PR 29020. It would take too long to explain how PR 29020 could be improved, so I wrote a new PR.

    Basically it implements Raymond's idea #1, but supports n>62 for smaller k.

    How to calculate the limits (for each k, the largest n such that comb(n, k) * k still fits below 2**64):

    import math
    n = m = 2**64
    k = 1
    while True:
        nmax = int(math.ceil((m * math.factorial(k-1)) ** (1/k) + (k-1)/2)) + 100
        n = min(n, nmax)
        while math.comb(n, k) * k >= m:
            n -= 1
        if n < 2*k: break
        print(k, n)
        k += 1

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Oct 18, 2021

    Microbenchmarks:

    $ ./python -m pyperf timeit -s 'from math import comb' '[comb(n, k) for n in range(63) for k in range(n+1)]'
    Mean +- std dev: 1.57 ms +- 0.07 ms -> 209 us +- 11 us: 7.53x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(62, 31)'
    Mean +- std dev: 2.95 us +- 0.14 us -> 296 ns +- 11 ns: 9.99x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(110, 15)'
    Mean +- std dev: 1.33 us +- 0.06 us -> 95.8 ns +- 3.1 ns: 13.86x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(1449, 7)'
    Mean +- std dev: 689 ns +- 33 ns -> 59.0 ns +- 3.2 ns: 11.69x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(3329022, 3)'
    Mean +- std dev: 308 ns +- 19 ns -> 57.2 ns +- 4.2 ns: 5.39x faster

    Now I want to try to optimize for larger arguments. Perhaps using the recursive formula C(n, k) = C(n, j)*C(n-j, k-j)//C(k, j) where j=k//2 could help.
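The recurrence mentioned above can be sketched in pure Python (a hypothetical model for illustration, not the C implementation); the floor division is exact because C(n,j) * C(n-j,k-j) equals C(n,k) * C(k,j):

```python
from math import comb

def comb_dc(n, k):
    """Divide-and-conquer binomial coefficient via
    C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j), with j = k//2."""
    if k < 0 or k > n:
        return 0
    k = min(k, n - k)
    if k == 0:
        return 1
    if k == 1:
        return n
    j = k // 2
    return comb_dc(n, j) * comb_dc(n - j, k - j) // comb_dc(k, j)

assert comb_dc(240, 112) == comb(240, 112)
```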

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Oct 20, 2021

    The divide-and-conquer approach works pretty well for larger n.

    For results slightly out of the 64-bit range:

    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(63, 31)'
    Mean +- std dev: 2.80 us +- 0.14 us -> 388 ns +- 19 ns: 7.22x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(111, 15)'
    Mean +- std dev: 1.24 us +- 0.06 us -> 215 ns +- 18 ns: 5.76x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(1450, 7)'
    Mean +- std dev: 654 ns +- 45 ns -> 178 ns +- 13 ns: 3.67x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(3329023, 3)'
    Mean +- std dev: 276 ns +- 15 ns -> 175 ns +- 11 ns: 1.58x faster

    For very large n:

    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(2**100, 2**10)'
    Mean +- std dev: 26.2 ms +- 1.7 ms -> 3.21 ms +- 0.20 ms: 8.16x faster
    
    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(2**1000, 2**10)'
    Mean +- std dev: 704 ms +- 15 ms -> 103 ms +- 5 ms: 6.85x faster

    And it is faster than using factorial:

    $ ./python -m pyperf timeit -s 'from math import comb' 'comb(100_000, 50_000)'
    Mean +- std dev: 1.61 sec +- 0.02 sec -> 177 ms +- 9 ms: 9.12x faster
    
    $ ./python -m pyperf timeit -s 'from math import factorial as fact' 'fact(100_000) // (fact(50_000)*fact(50_000))'
    Mean +- std dev: 507 ms +- 20 ms

    math.perm() can benefit from reusing the same code:

    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(63, 31)'
    Mean +- std dev: 1.35 us +- 0.07 us -> 1.18 us +- 0.06 us: 1.15x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(111, 15)'
    Mean +- std dev: 601 ns +- 35 ns -> 563 ns +- 28 ns: 1.07x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(2**100, 2**10)'
    Mean +- std dev: 5.96 ms +- 0.29 ms -> 2.32 ms +- 0.12 ms: 2.57x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(2**1000, 2**10)'
    Mean +- std dev: 486 ms +- 14 ms -> 95.7 ms +- 4.2 ms: 5.08x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(100_000, 50_000)'
    Mean +- std dev: 639 ms +- 23 ms -> 66.6 ms +- 3.2 ms: 9.60x faster

    Even in worst cases it is almost as fast as factorial:

    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(100_000, 100_000)'
    Mean +- std dev: 2.55 sec +- 0.02 sec -> 187 ms +- 8 ms: 13.66x faster
    
    $ ./python -m pyperf timeit -s 'from math import factorial' 'factorial(100_000)'
    Mean +- std dev: 142 ms +- 7 ms

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Oct 20, 2021

    And with optimization of math.perm() for small arguments:

    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(30, 14)'
    Mean +- std dev: 524 ns +- 43 ns -> 66.7 ns +- 4.6 ns: 7.85x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(31, 14)'
    Mean +- std dev: 522 ns +- 26 ns -> 127 ns +- 6 ns: 4.09x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(568, 7)'
    Mean +- std dev: 318 ns +- 19 ns -> 62.9 ns +- 3.7 ns: 5.05x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(569, 7)'
    Mean +- std dev: 311 ns +- 14 ns -> 114 ns +- 7 ns: 2.73x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(63, 31)'
    Mean +- std dev: 1.36 us +- 0.08 us -> 263 ns +- 14 ns: 5.17x faster
    
    $ ./python -m pyperf timeit -s 'from math import perm' 'perm(111, 15)'
    Mean +- std dev: 595 ns +- 27 ns -> 126 ns +- 7 ns: 4.71x faster

    @rhettinger
    Contributor Author

    rhettinger commented Nov 14, 2021

    These speedups all appear to be significant and worth doing.

    @pochmann
    Mannequin

    pochmann mannequin commented Nov 15, 2021

    I wrote a Python solution ("mycomb") that computes comb(100_000, 50_000) faster, maybe of interest:

    1510.4 ms math.comb(n, k)
    460.8 ms factorial(n) // (factorial(k) * factorial(n-k))
    27.5 ms mycomb(n, k)
    6.7 ms *estimation* for mycomb if written in C

    The idea:

                      13 * 12 * 11 * 10 * 9 * 8
        comb(13, 6) = ------------------------- = 13 * 1 * 11 * 1 * 3 * 4
                       1 * 2 * 3 * 4 * 5 * 6

    It lists the numerator factors, then divides the denominator factors out of them (using primes), then just multiplies.

    Preparing the factors for the final multiplication took most of the time, about 23.1 ms. That part only needs numbers <= n, so it could be done with C ints and be much faster. If it's ten times faster, then mycomb in C would take 23.1/10 + (27.5-23.1) = 6.71 ms.

    See the comb_with_primes.py file.
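The idea can be sketched in pure Python along these lines (a hypothetical reconstruction using gcd-based cancellation rather than the prime bookkeeping of the actual comb_with_primes.py): list the k numerator factors, cancel each denominator factor into them, then multiply what is left.

```python
from math import comb, gcd, prod

def mycomb_sketch(n, k):
    """Cancel the denominator 1*2*...*k out of the numerator factors
    (n-k+1)..n, then multiply the reduced factors.  Hypothetical sketch."""
    k = min(k, n - k)
    factors = list(range(n - k + 1, n + 1))     # numerator factors
    for d in range(2, k + 1):                   # divide out each 2..k
        for i, f in enumerate(factors):
            g = gcd(f, d)
            if g > 1:
                factors[i] = f // g
                d //= g
                if d == 1:
                    break
    # After cancellation many factors are 1 and need no multiplication.
    return prod(factors)

assert mycomb_sketch(100_000, 50_000) == comb(100_000, 50_000)
```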

    @pochmann
    Mannequin

    pochmann mannequin commented Nov 15, 2021

    And for Raymond's case 4), about running very long and not responding to SIGINT, with n=1_000_000 and k=500_000:

    150.91 seconds math.comb(n, k)
    39.11 seconds factorial(n) // (factorial(k) * factorial(n-k))
    0.40 seconds mycomb(n, k)
    0.14 seconds *estimation* for mycomb if written in C

    And for n=10_000_000 and k=5_000_000:

    ~4 hours *estimation* for math.comb(n, k)
    ~1 hour *estimation* for factorials solution
    8.3 seconds mycomb(n, k)
    4.5 seconds *estimation* for mycomb if written in C

    @pochmann
    Mannequin

    pochmann mannequin commented Nov 15, 2021

    Turns out for n=100_000, k=50_000, about 87% of my factors are 1, so they don't even need to be turned into Python ints for multiplication, improving the multiplication part to 3.05 ms. And a C++ version to produce the factors took 0.85 ms. Updated estimation:

    1510.4 ms math.comb(n, k)
    460.8 ms factorial(n) // (factorial(k) * factorial(n-k))
    3.9 ms *estimation* for mycomb if written in C

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Dec 5, 2021

    New changeset 60c320c by Serhiy Storchaka in branch 'main':
    bpo-37295: Optimize math.comb() and math.perm() (GH-29090)
    60c320c

    @rhettinger
    Contributor Author

    rhettinger commented Dec 18, 2021

    For the small cases (say n < 50), we can get faster code by using a small (3Kb) table of factorial logarithms:

       double lf[50] = [log2(factorial(n)) for n in range(50)];

    Then the comb() and perm() functions can be computed quickly and in constant time using the C99 math functions:

       result = PyLong_FromDouble(round(exp2(lf[n] - (lf[r] + lf[n-r]))));
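A quick pure-Python check of this idea is easy to run; note that whether the rounding stays exact all the way up to n = 50 depends on the accumulated floating-point error, so the check below deliberately stops short of that bound:

```python
from math import comb, factorial, log2

# Table of log2(n!) for n < 50, as proposed above.
lf = [log2(factorial(i)) for i in range(50)]

def comb_from_logs(n, k):
    """Constant-time comb() via the log-factorial table: round(2**(lf[n]
    - lf[k] - lf[n-k])).  Exact only while FP error stays under 0.5."""
    return round(2.0 ** (lf[n] - (lf[k] + lf[n - k])))

# Verified exact for n < 38; larger n would need a careful error analysis.
assert all(comb_from_logs(n, k) == comb(n, k)
           for n in range(38) for k in range(n + 1))
```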

    @tim-one
    Member

    tim-one commented Dec 30, 2021

    [Tim]

    That's [Mark's code] cheaper than almost every case handled by
    perm_comb_small()'s current ... "iscomb" loop.

    Although I should clarify they're aimed at different things, and don't overlap all that much. Mark's code, & even more so Raymond's extension, picks on small "n" and then looks for the largest "k" such that comb(n, k) can be done with supernatural speed.

    But the existing perm_comb_small() picks on small "k" and then looks for the largest "n" such that "the traditional" one-at-a-time loop can complete without ever overflowing a C uint64 along the way.

    The latter is doubtless more valuable for perm_comb_small(), since its recursive calls cut k roughly in half, and the first such call doesn't reduce n at all.

    But where they do overlap (e.g., comb(50, 15)), Mark's approach is much faster, so that should be checked first.

    @mdickinson
    Member

    mdickinson commented Dec 30, 2021

    So which of xor-popcount and add-up-trailing-zero-counts is faster may well depend on platform.

    I ran some timings for comb(k, 67) on my macOS / Intel MacBook Pro, using timeit to time calls to a function that looked like this:

    def f(comb):
        for k in range(68):
            for _ in range(256):
                comb(k, 67)
                comb(k, 67)
                ... # 64 repetitions of comb(k, 67) in all

    Based on 200 timings of this script with each of the popcount approach and the uint8_t-table-of-trailing-zero-counts approach (interleaved), the popcount approach won, but just barely, at around 1.3% faster. The result was statistically significant (SciPy gave me a result of Ttest_indResult(statistic=19.929941828072433, pvalue=8.570975609117687e-62)).

    Interestingly, the default build on macOS/Intel is _not_ using the dedicated POPCNT instruction that arrived with the Nehalem architecture, presumably because it wants to produce builds that will still be useable on pre-Nehalem machines. It uses Clang's __builtin_popcount, but that gets translated to the same SIMD-within-a-register approach that we have already in pycore_bitutils.h.

    If I recompile with -msse4.2, then the POPCNT instruction *is* used, and I get an even more marginal improvement: a 1.7% speedup over the lookup-table-based version.
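Both approaches compute the power-of-two factor of C(n, k); the underlying identity, which follows from Legendre's formula for the 2-adic valuation of factorials, can be checked directly (a Python rendering for illustration, not the C code in question):

```python
from math import comb

def v2(x):
    """Number of trailing zero bits of x > 0."""
    return (x & -x).bit_length() - 1

def comb_v2(n, k):
    """Power of two dividing C(n, k), via the popcount identity
    s2(k) + s2(n-k) - s2(n), where s2 counts one-bits."""
    return bin(k).count("1") + bin(n - k).count("1") - bin(n).count("1")

# The identity matches the actual 2-adic valuation of every coefficient:
assert all(comb_v2(n, k) == v2(comb(n, k))
           for n in range(68) for k in range(n + 1))
```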

    @tim-one
    Member

    tim-one commented Dec 30, 2021

    [Mark]

    I ran some timings for comb(k, 67) on my macOS / Intel MacBook Pro,
    using timeit to time calls to a function that looked like this:

    def f(comb):
        for k in range(68):
            for _ in range(256):
                comb(k, 67)
                comb(k, 67)
                ... # 64 repetitions of comb(k, 67) in all

    I'm assuming you meant to write comb(67, k) instead, since the comb(k, 67) given is 0 at all tested k values except for k=67, and almost never executes any of the code in question.

    It's surprising to me that even the long-winded popcount code was faster! The other way needs to read up 3 1-byte values from a trailing zero table, but the long-winded popcount emulation needs to read up 4 4-byte mask constants (or are they embedded in the instruction stream?), in addition to doing many more bit-fiddling operations (4 shifts, 4 "&" masks, 3 add/subtract, and a multiply - compared to just 2 add/subtract).

    So if the results are right, Intel timings make no sense to me at all ;-)

    @mdickinson
    Member

    mdickinson commented Dec 30, 2021

    I'm assuming you meant to write comb(67, k) instead

    Aargh! That is of course what I meant, but not in fact what I timed. :-(

    I'll redo the timings. Please disregard the previous message.

    @tim-one
    Member

    tim-one commented Dec 30, 2021

    Aargh! That is of course what I meant, but not in fact
    what I timed. :-(

    !!! Even more baffling then. Seems like the code posted got out of math_comb_impl() early here:

            if (overflow || ki > ni) {
                result = PyLong_FromLong(0);
                goto done;
            }

    67 out of every 68 times comb() was called, before any actual ;-) computation was even tried. Yet one way was significantly faster than the other overall, despite that they were so rarely executed at all?

    Something ... seems off here ;-)

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Dec 30, 2021

    PR 30305 applies Mark's algorithm for larger n (up to 127) depending on k, as was suggested by Raymond. Note that it uses a different table for the limits, which maps k to the maximal n.

    @mdickinson
    Member

    mdickinson commented Dec 30, 2021

    Thanks Tim for spotting the stupid mistake. The reworked timings are a bit more ... plausible.

    tl;dr: On my machine, Raymond's suggestion gives a 2.2% speedup in the case where POPCNT is not available, and a 0.45% slowdown in the case that it _is_ available. Given that, and the fact that a single-instruction population count is not as readily available as I thought it was, I'd be happy to change the implementation to use the trailing zero counts as suggested.

    I'll attach the scripts I used for timing and analysis. There are two of them: "timecomb.py" produces a single timing. "driver.py" repeatedly switches branches, re-runs make, runs "timecomb.py", then assembles the results.

    I ran the driver.py script twice: once with a regular ./configure step, and once with ./configure CFLAGS="-march=haswell". Below, "base" refers to the code currently in master; "alt" is the branch with Raymond's suggested change on it.

    Output from the script for the normal ./configure

    Mean time for base: 40.130ns
    Mean for alt: 39.268ns
    Speedup: 2.19%
    Ttest_indResult(statistic=7.9929245698581415, pvalue=1.4418376402220854e-14)
    

    Output for CFLAGS="-march=haswell":

    Mean time for base: 39.612ns
    Mean for alt: 39.791ns
    Speedup: -0.45%
    Ttest_indResult(statistic=-6.75385578636895, pvalue=5.119724894191512e-11)
    

    @rhettinger
    Contributor Author

    rhettinger commented Dec 30, 2021

    I'd be happy to change the implementation to use the trailing
    zero counts as suggested.

    Thanks. I think that is a portability win and will make the code a lot easier to explain.

    @mdickinson
    Member

    mdickinson commented Dec 31, 2021

    I'd be happy to change the implementation to use the trailing zero counts as suggested.

    Done in #74498 (though this will conflict with Serhiy's PR).

    @mdickinson
    Member

    mdickinson commented Dec 31, 2021

    New changeset 0b58bac by Mark Dickinson in branch 'main':
    bpo-37295: More direct computation of power-of-two factor in math.comb (GH-30313)
    0b58bac

    @serhiy-storchaka
    Member

    serhiy-storchaka commented Jan 9, 2022

    New changeset 2d78797 by Serhiy Storchaka in branch 'main':
    bpo-37295: Use constant-time comb() and perm() for larger n depending on k (GH-30305)
    2d78797

    @rhettinger
    Contributor Author

    rhettinger commented Jan 11, 2022

    I've been experimenting with a modification of Serhiy's recurrence relation, using a different value of j rather than j=k//2.

    The current design splits off three ways when it recurses, so the number of calls grows quickly. For C(200,100), C(225,112), and C(250,125), the underlying 64-bit modular arithmetic routine is called 115 times, 150 times, and 193 times respectively.

    But with another 2kb of precomputed values, it drops to 3, 16, and 26 calls.

    The main idea is to precompute one diagonal of Pascal's triangle, starting where the 64-bit mod arithmetic version leaves off and going through a limit as high as we want, depending on our tolerance for table size. A table for C(n, 20) where 67 < n <= 225 takes 2101 bytes.

    The new routine adds one line and modifies one line from the current code:

      def C(n, k):
        k = min(k, n - k)
        if k == 0: return 1
        if k == 1: return n
        if k < len(k2n) and n <= k2n[k]: return ModArith64bit(n, k)
        if k == FixedJ and n <= Jlim:  return lookup_known(n)  # New line
        j = min(k // 2, FixedJ)                                # Modified
        return C(n, j) * C(n-j, k-j) // C(k, j)

    The benefit of pinning j to match the precomputed diagonal is that two of the three split-offs are to known values where no further work is necessary. Given a table for C(n, 20), we get:

    C(200, 100) = C(200, 20) * C(180, 80) // C(100, 20)
                  \_known_/   \_recurse_/    \_known_/
    

    A proof of concept is attached. To make it easy to experiment with, the precomputed diagonal is stored in a dictionary. At the bottom, I show an equivalent function to be used in a C version.

    It looks promising at this point, but I haven't run timings, so I am not sure this is a net win.

    @rhettinger
    Contributor Author

    rhettinger commented Jan 11, 2022

    One fixup:

    - j = min(k // 2, FixedJ)
    + j = FixedJ if k > FixedJ else k // 2

    With that fix, the number of 64-bit mod arithmetic calls drops to 3, 4, and 20 for C(200,100), C(225,112), and C(250,125). That compares to 115, 150, and 193 calls in the current code.

    @tim-one
    Member

    tim-one commented Jan 12, 2022

    Just noting that comb_pole.py requires a development version of Python to run (under all released versions, a byteorder argument is required for int.{to,from}_bytes() calls).

    @rhettinger
    Contributor Author

    rhettinger commented Jan 12, 2022

    Just posted an update that runs on 3.8 or later.

    @rhettinger
    Contributor Author

    rhettinger commented Jan 12, 2022

    Ran some timings on the pure Python version with and without the precomputed diagonal: [C(n, k) for k in range(n+1)]

                New Code      Baseline
               ---------     ---------
    n=85        156 usec      160 usec
    n=137       632 usec     1.82 msec
    n=156      1.01 msec     4.01 msec
    n=183      1.58 msec     7.06 msec
    n=212      2.29 msec     10.4 msec
    n=251      7.07 msec     15.6 msec
    n=311      19.6 msec     25.5 msec
    n=1001      314 msec      471 msec

    As expected, the best improvement comes in the range of the precomputed combinations: C(129, 20) through C(250, 20).

    To get a better idea of where the benefit is coming from, here is a trace from the recursive part of the new code:

    >>> C(240, 112)
    C(240, 112) = C(240, 20) * C(220, 92) // C(112, 20)
    C(220, 92) = C(220, 20) * C(200, 72) // C(92, 20)
    C(200, 72) = C(200, 20) * C(180, 52) // C(72, 20)
    C(180, 52) = C(180, 20) * C(160, 32) // C(52, 20)
    C(160, 32) = C(160, 20) * C(140, 12) // C(32, 20)
    C(140, 12) = C(140, 6) * C(134, 6) // C(12, 6)
    C(140, 6) = C(140, 3) * C(137, 3) // C(6, 3)
    C(140, 3) = C(140, 1) * C(139, 2) // C(3, 1)
    C(139, 2) = C(139, 1) * C(138, 1) // C(2, 1)
    C(137, 3) = C(137, 1) * C(136, 2) // C(3, 1)
    C(136, 2) = C(136, 1) * C(135, 1) // C(2, 1)
    C(134, 6) = C(134, 3) * C(131, 3) // C(6, 3)
    C(134, 3) = C(134, 1) * C(133, 2) // C(3, 1)
    C(133, 2) = C(133, 1) * C(132, 1) // C(2, 1)
    C(131, 3) = C(131, 1) * C(130, 2) // C(3, 1)
    C(130, 2) = C(130, 1) * C(129, 1) // C(2, 1)
    53425668586973006090333976236624193248148570813984466591034254245235155

    Here is a trace for the old code without the precomputed diagonal:

    >>> C(240, 112)
    C(240, 112) = C(240, 56) * C(184, 56) // C(112, 56)
    C(240, 56) = C(240, 28) * C(212, 28) // C(56, 28)
    C(240, 28) = C(240, 14) * C(226, 14) // C(28, 14)
    C(240, 14) = C(240, 7) * C(233, 7) // C(14, 7)
    C(240, 7) = C(240, 3) * C(237, 4) // C(7, 3)
    C(240, 3) = C(240, 1) * C(239, 2) // C(3, 1)
    C(239, 2) = C(239, 1) * C(238, 1) // C(2, 1)
    C(237, 4) = C(237, 2) * C(235, 2) // C(4, 2)
    C(237, 2) = C(237, 1) * C(236, 1) // C(2, 1)
    C(235, 2) = C(235, 1) * C(234, 1) // C(2, 1)
    C(233, 7) = C(233, 3) * C(230, 4) // C(7, 3)
    C(233, 3) = C(233, 1) * C(232, 2) // C(3, 1)
    C(232, 2) = C(232, 1) * C(231, 1) // C(2, 1)
    C(230, 4) = C(230, 2) * C(228, 2) // C(4, 2)
    C(230, 2) = C(230, 1) * C(229, 1) // C(2, 1)
    C(228, 2) = C(228, 1) * C(227, 1) // C(2, 1)
    C(226, 14) = C(226, 7) * C(219, 7) // C(14, 7)
    C(226, 7) = C(226, 3) * C(223, 4) // C(7, 3)
    C(226, 3) = C(226, 1) * C(225, 2) // C(3, 1)
    C(225, 2) = C(225, 1) * C(224, 1) // C(2, 1)
    C(223, 4) = C(223, 2) * C(221, 2) // C(4, 2)
    C(223, 2) = C(223, 1) * C(222, 1) // C(2, 1)
    C(221, 2) = C(221, 1) * C(220, 1) // C(2, 1)
    C(219, 7) = C(219, 3) * C(216, 4) // C(7, 3)
    C(219, 3) = C(219, 1) * C(218, 2) // C(3, 1)
    C(218, 2) = C(218, 1) * C(217, 1) // C(2, 1)
    C(216, 4) = C(216, 2) * C(214, 2) // C(4, 2)
    C(216, 2) = C(216, 1) * C(215, 1) // C(2, 1)
    C(214, 2) = C(214, 1) * C(213, 1) // C(2, 1)
    C(212, 28) = C(212, 14) * C(198, 14) // C(28, 14)
    C(212, 14) = C(212, 7) * C(205, 7) // C(14, 7)
    C(212, 7) = C(212, 3) * C(209, 4) // C(7, 3)
    C(212, 3) = C(212, 1) * C(211, 2) // C(3, 1)
    C(211, 2) = C(211, 1) * C(210, 1) // C(2, 1)
    C(209, 4) = C(209, 2) * C(207, 2) // C(4, 2)
    C(209, 2) = C(209, 1) * C(208, 1) // C(2, 1)
    C(207, 2) = C(207, 1) * C(206, 1) // C(2, 1)
    C(205, 7) = C(205, 3) * C(202, 4) // C(7, 3)
    C(205, 3) = C(205, 1) * C(204, 2) // C(3, 1)
    C(204, 2) = C(204, 1) * C(203, 1) // C(2, 1)
    C(202, 4) = C(202, 2) * C(200, 2) // C(4, 2)
    C(202, 2) = C(202, 1) * C(201, 1) // C(2, 1)
    C(200, 2) = C(200, 1) * C(199, 1) // C(2, 1)
    C(198, 14) = C(198, 7) * C(191, 7) // C(14, 7)
    C(198, 7) = C(198, 3) * C(195, 4) // C(7, 3)
    C(198, 3) = C(198, 1) * C(197, 2) // C(3, 1)
    C(197, 2) = C(197, 1) * C(196, 1) // C(2, 1)
    C(195, 4) = C(195, 2) * C(193, 2) // C(4, 2)
    C(195, 2) = C(195, 1) * C(194, 1) // C(2, 1)
    C(193, 2) = C(193, 1) * C(192, 1) // C(2, 1)
    C(191, 7) = C(191, 3) * C(188, 4) // C(7, 3)
    C(191, 3) = C(191, 1) * C(190, 2) // C(3, 1)
    C(190, 2) = C(190, 1) * C(189, 1) // C(2, 1)
    C(188, 4) = C(188, 2) * C(186, 2) // C(4, 2)
    C(188, 2) = C(188, 1) * C(187, 1) // C(2, 1)
    C(186, 2) = C(186, 1) * C(185, 1) // C(2, 1)
    C(184, 56) = C(184, 28) * C(156, 28) // C(56, 28)
    C(184, 28) = C(184, 14) * C(170, 14) // C(28, 14)
    C(184, 14) = C(184, 7) * C(177, 7) // C(14, 7)
    C(184, 7) = C(184, 3) * C(181, 4) // C(7, 3)
    C(184, 3) = C(184, 1) * C(183, 2) // C(3, 1)
    C(183, 2) = C(183, 1) * C(182, 1) // C(2, 1)
    C(181, 4) = C(181, 2) * C(179, 2) // C(4, 2)
    C(181, 2) = C(181, 1) * C(180, 1) // C(2, 1)
    C(179, 2) = C(179, 1) * C(178, 1) // C(2, 1)
    C(177, 7) = C(177, 3) * C(174, 4) // C(7, 3)
    C(177, 3) = C(177, 1) * C(176, 2) // C(3, 1)
    C(176, 2) = C(176, 1) * C(175, 1) // C(2, 1)
    C(174, 4) = C(174, 2) * C(172, 2) // C(4, 2)
    C(174, 2) = C(174, 1) * C(173, 1) // C(2, 1)
    C(172, 2) = C(172, 1) * C(171, 1) // C(2, 1)
    C(170, 14) = C(170, 7) * C(163, 7) // C(14, 7)
    C(170, 7) = C(170, 3) * C(167, 4) // C(7, 3)
    C(170, 3) = C(170, 1) * C(169, 2) // C(3, 1)
    C(169, 2) = C(169, 1) * C(168, 1) // C(2, 1)
    C(167, 4) = C(167, 2) * C(165, 2) // C(4, 2)
    C(167, 2) = C(167, 1) * C(166, 1) // C(2, 1)
    C(165, 2) = C(165, 1) * C(164, 1) // C(2, 1)
    C(163, 7) = C(163, 3) * C(160, 4) // C(7, 3)
    C(163, 3) = C(163, 1) * C(162, 2) // C(3, 1)
    C(162, 2) = C(162, 1) * C(161, 1) // C(2, 1)
    C(160, 4) = C(160, 2) * C(158, 2) // C(4, 2)
    C(160, 2) = C(160, 1) * C(159, 1) // C(2, 1)
    C(158, 2) = C(158, 1) * C(157, 1) // C(2, 1)
    C(156, 28) = C(156, 14) * C(142, 14) // C(28, 14)
    C(156, 14) = C(156, 7) * C(149, 7) // C(14, 7)
    C(156, 7) = C(156, 3) * C(153, 4) // C(7, 3)
    C(156, 3) = C(156, 1) * C(155, 2) // C(3, 1)
    C(155, 2) = C(155, 1) * C(154, 1) // C(2, 1)
    C(153, 4) = C(153, 2) * C(151, 2) // C(4, 2)
    C(153, 2) = C(153, 1) * C(152, 1) // C(2, 1)
    C(151, 2) = C(151, 1) * C(150, 1) // C(2, 1)
    C(149, 7) = C(149, 3) * C(146, 4) // C(7, 3)
    C(149, 3) = C(149, 1) * C(148, 2) // C(3, 1)
    C(148, 2) = C(148, 1) * C(147, 1) // C(2, 1)
    C(146, 4) = C(146, 2) * C(144, 2) // C(4, 2)
    C(146, 2) = C(146, 1) * C(145, 1) // C(2, 1)
    C(144, 2) = C(144, 1) * C(143, 1) // C(2, 1)
    C(142, 14) = C(142, 7) * C(135, 7) // C(14, 7)
    C(142, 7) = C(142, 3) * C(139, 4) // C(7, 3)
    C(142, 3) = C(142, 1) * C(141, 2) // C(3, 1)
    C(141, 2) = C(141, 1) * C(140, 1) // C(2, 1)
    C(139, 4) = C(139, 2) * C(137, 2) // C(4, 2)
    C(139, 2) = C(139, 1) * C(138, 1) // C(2, 1)
    C(137, 2) = C(137, 1) * C(136, 1) // C(2, 1)
    C(135, 7) = C(135, 3) * C(132, 4) // C(7, 3)
    C(135, 3) = C(135, 1) * C(134, 2) // C(3, 1)
    C(134, 2) = C(134, 1) * C(133, 1) // C(2, 1)
    C(132, 4) = C(132, 2) * C(130, 2) // C(4, 2)
    C(132, 2) = C(132, 1) * C(131, 1) // C(2, 1)
    C(130, 2) = C(130, 1) * C(129, 1) // C(2, 1)
    C(112, 56) = C(112, 28) * C(84, 28) // C(56, 28)
    C(112, 28) = C(112, 14) * C(98, 14) // C(28, 14)
    C(84, 28) = C(84, 14) * C(70, 14) // C(28, 14)
    53425668586973006090333976236624193248148570813984466591034254245235155
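    The splitting identity that both traces rely on is easy to verify directly against math.comb. A quick standalone sanity check (not part of any patch):

    ```python
    from math import comb

    def split_comb(n, k, j):
        # C(n, k) = C(n, j) * C(n-j, k-j) // C(k, j); the division is exact
        # because the product being divided equals C(n, k) * C(k, j).
        return comb(n, j) * comb(n - j, k - j) // comb(k, j)

    # Matches the traced example for both the fixed j=20 split and j=k//2.
    assert split_comb(240, 112, 20) == comb(240, 112)
    assert split_comb(240, 112, 56) == comb(240, 112)
    ```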

    @tim-one
    Copy link
    Member

    tim-one commented Jan 12, 2022

    A feature of the current code is that, while the recursion tree can be very wide, it's not tall, with max depth proportional to the log of k. But it's proportional to k in the proposal (the C(n-j, k-j) term's second argument goes down by at most 20 per recursion level).

    So, e.g., C(1000000, 500000) dies with RecursionError in Python; in C, whatever platform-specific weird things can happen when the C stack is blown.

    The width of the recursion tree could be slashed in any case by "just dealing with" (say) k <= 20 directly, no matter how large n is. Do the obvious loop with k-1 multiplies, and k-1 divides by the tiny ints in range(2, k+1). Note that "almost all" the calls in your "trace for the old code" listing are due to recurring for k <= 20. Or use Stefan's method: if limited to k <= 20, it only requires a tiny precomputed table of the 8 primes <= 20, and a stack array to hold range(n, n-k, -1); that can be arranged to keep Karatsuba in play if n is large.

    An irony is that the primary point of the recurrence is to get Karatsuba multiplication into play, but the ints involved in computing C(240, 112) are nowhere near big enough to trigger that.

    To limit recursion depth, I think you have to change your approach to decide in advance the deepest you're willing to risk letting it go, and keep the current j = k // 2 whenever repeatedly subtracting 20 could exceed that.

    @rhettinger
    Copy link
    Contributor Author

    rhettinger commented Jan 13, 2022

    Okay, will set a cap on the n where a fixed j is used. Also, making a direct computation for k < 20 is promising.

    @rhettinger
    Copy link
    Contributor Author

    rhettinger commented Jan 13, 2022

    def comb64(n, k):
        'comb(n, k) in multiplicative group modulo 64-bits'
        return (F[n] * Finv[k] * Finv[n-k] & (2**64-1)) << (S[n] - S[k] - S[n - k])
    
    def comb_iterative(n, k):
        'Straight multiply and divide when k is small.'
        result = 1
        for r in range(1, k+1):
            result *= n - r + 1
            result //= r
        return result
    
    def C(n, k):
        k = min(k, n - k)
        if k == 0: return 1
        if k == 1: return n
        if k < len(k2n) and n <= k2n[k]:  return comb64(n, k) # 64-bit fast case
        if k == FixedJ and n <= Jlim:  return KnownComb[n]    # Precomputed diagonal
        if k < 10: return comb_iterative(n, k)                # Non-recursive for small k  
        j = FixedJ if k > FixedJ and n <= Jlim else k // 2
        return C(n, j) * C(n-j, k-j) // C(k, j)               # Recursive case
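    Since comb_iterative above is self-contained, it can be checked directly against math.comb; the interleaved floor division is always exact because the running product equals C(n, r) after r steps:

    ```python
    from math import comb

    def comb_iterative(n, k):
        'Straight multiply and divide when k is small.'
        result = 1
        for r in range(1, k + 1):
            result *= n - r + 1
            result //= r   # exact: result was C(n, r-1), and C(n, r-1)*(n-r+1) == C(n, r)*r
        return result

    assert all(comb_iterative(50, k) == comb(50, k) for k in range(51))
    ```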

    @tim-one
    Copy link
    Member

    tim-one commented Jan 14, 2022

    I was thinking about

    comb(1000000, 500000)

    The simple "* --n / ++k" loop does 499,999 each of multiplication and division, and in all instances the second operand is a single Python digit. Cheap as can be.

    In contrast, despite that it short-circuits all "small k" cases, comb_pole2.py executes

    return C(n, j) * C(n-j, k-j) // C(k, j)               # Recursive case
    

    7,299,598 times. Far more top-level arithmetic operations, including extra-expensive long division with both operands multi-digit. At the very top of the tree, "// C(500000, 250000)" is dividing out nearly half a million bits.

    But a tiny Python function coding the simple loop takes about 150 seconds, while Raymond's C() about 30 (under the released 3.10.1). The benefits from provoking Karatsuba are major.

    Just for the heck of it, I coded a complication of the simple loop that just tries to provoke Karatsuba on the numerator (prod(range(n, n-k, -1))), and dividing that just once at the end, by factorial(k). That dropped the time to about 17.5 seconds. Over 80% of that time is spent computing the "return" expression, where len(p) is 40 and only 7 entries pass the x>1 test:

    return prod(x for x in p if x > 1) // factorial(k)
    

    That is, with really big results, almost everything is mostly noise compared to the time needed to do the relative handful of * and // on the very largest intermediate results.
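    Stripped of the bucketing, the underlying "one division at the end" idea is just this (a plain sketch for reference; it doesn't deliberately provoke Karatsuba the way the bucketed version does):

    ```python
    from math import prod, factorial, comb

    def comb_onedivide(n, k):
        # Multiply the whole numerator range(n, n-k, -1), divide by k! once.
        k = min(k, n - k)
        return prod(range(n, n - k, -1)) // factorial(k)

    assert comb_onedivide(240, 112) == comb(240, 112)
    ```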

    Stefan's code runs much faster still, but requires storage proportional to k. The crude hack I used needs a fixed (independent of k and n) and small amount of temp storage, basically grouping inputs into buckets roughly based on the log _of_ their log, and keeping only one int per bucket.

    Here's the code:

    def pcomb2(n, k):
        from math import prod, factorial
        assert 0 <= k <= n
        k = min(k, n-k)
        PLEN = 40 # good enough for ints with over a trillion bits
        p = [1] * PLEN

        def fold_into_p(x):
            if x == 1:
                return
            while True:
                i = max(x.bit_length().bit_length() - 5, 0)
                if p[i] == 1:
                    p[i] = x
                    break
                x *= p[i]
                p[i] = 1

        def showp():
            for i in range(PLEN):
                pi = p[i]
                if pi > 1:
                    print(i, pi.bit_length())

        for i in range(k):
            fold_into_p(n)
            n -= 1
        showp()
        return prod(x for x in p if x > 1) // factorial(k)

    I'm not sure it's practical. For example, while the list p[] can be kept quite small, factorial(k) can require a substantial amount of temp storage of its own - factorial(500000) in the case at hand is an int approaching 9 million bits.

    Note: again in the case at hand, computing it via

    factorial(1000000) // factorial(500000)**2
    

    takes about 10% longer than Raymond's C() function.

    @tim-one
    Copy link
    Member

    tim-one commented Jan 14, 2022

    Another trick, building on the last one: computing factorial(k) isn't cheap, in time or space, and neither is dividing by it. But we know it will entirely cancel out. Indeed, for each outer loop iteration, prod(p) is divisible by the current k. But, unlike as in Stefan's code, which materializes range(n, n-k, -1) as an explicit list, we have no way to calculate "in advance" which elements of p[] are divisible by what.

    What we _can_ do is march over all of p[], and do a gcd of each element with the current k. If greater than 1, it can be divided out of both that element of p[], and the current k. Later, rinse, repeat - the current k must eventually be driven to 1 then.

    But that slows things down: gcd() is also expensive.

    But there's a standard trick to speed that too: as in serious implementations of Pollard's rho factorization method, "chunk it". That is, don't do it on every outer loop iteration, but instead accumulate the running product of several denominators first, then do the expensive gcd pass on that product.

    Here's a replacement for "the main loop" of the last code that delays doing gcds until the running product is at least 2000 bits:

        fold_into_p(n)

        kk = 1
        for k in range(2, k+1):
            n -= 1
            # Merge into p[].
            fold_into_p(n)
            # Divide by k.
            kk *= k
            if kk.bit_length() < 2000:
                continue
            for i, pi in enumerate(p):
                if pi > 1:
                    g = gcd(pi, kk)
                    if g > 1:
                        p[i] = pi // g
                        kk //= g
                        if kk == 1:
                            break
            assert kk == 1
        showp()
        return prod(x for x in p if x > 1) // kk

    That runs in under half the time (for n=1000000, k=500000), down to under 7.5 seconds. And, of course, the largest denominator consumes only about 2000 bits instead of 500000!'s 8,744,448 bits.

    Raising the kk bit limit from 2000 to 10000 cuts another 2.5 seconds off, down to about 5 seconds.

    Much above that, it starts getting slower again.

    Seems to hard to out-think! And highly dubious to fine-tune it based on a single input case ;-)

    Curious: at a cutoff of 10000 bits, we're beyond the point where Karatsuba would have paid off for computing denominator partial products too.

    @tim-one tim-one removed the 3.11 label Jan 14, 2022
    @serhiy-storchaka
    Copy link
    Member

    serhiy-storchaka commented Jan 19, 2022

    All this should be tested with the C implementation because relative cost of operations is different in C and Python.

    I have tested Raymond's idea about using iterative algorithm for small k.

    $ ./python -m pyperf timeit -s 'from math import comb' "comb(3329023, 3)"
    recursive: Mean +- std dev: 173 ns +- 9 ns
    iterative: Mean +- std dev: 257 ns +- 13 ns
    
    $ ./python -m pyperf timeit -s 'from math import comb' "comb(102571, 4)"
    recursive: Mean +- std dev: 184 ns +- 10 ns
    iterative: Mean +- std dev: 390 ns +- 29 ns
    
    $ ./python -m pyperf timeit -s 'from math import comb' "comb(747, 8)"
    recursive: Mean +- std dev: 187 ns +- 10 ns
    iterative: Mean +- std dev: 708 ns +- 39 ns

    The recursive algorithm is always faster than the iterative one for k > 2 (they are equal for k = 1 and k = 2).

    And it is not only because of division, because for perm() we have the same difference.

    $ ./python -m pyperf timeit -s 'from math import perm' "perm(2642247, 3)"
    recursive: Mean +- std dev: 118 ns +- 7 ns
    iterative: Mean +- std dev: 147 ns +- 8 ns
    
    $ ./python -m pyperf timeit -s 'from math import perm' "perm(65538, 4)"
    recursive: Mean +- std dev: 130 ns +- 9 ns
    iterative: Mean +- std dev: 203 ns +- 13 ns
    
    $ ./python -m pyperf timeit -s 'from math import perm' "perm(260, 8)"
    recursive: Mean +- std dev: 131 ns +- 10 ns
    iterative: Mean +- std dev: 324 ns +- 16 ns

    As for the idea about using a table for fixed k=20, note that comb(87, 20) exceeds 64 bits, so we will need to use a table of 128-bit integers. And I am not sure if this algorithm will be faster than the recursive one.

    We may achieve better results for lesser cost if we extend Mark's algorithm to use 128-bit integers. I am not sure whether it is worth it; the current code is good enough and covers a wide range of cases. Additional optimizations will likely have a lesser effort/benefit ratio.

    @serhiy-storchaka
    Copy link
    Member

    serhiy-storchaka commented Jan 19, 2022

    comb(n, k) can be computed as perm(n, k) // factorial(k).

    $ ./python -m timeit -r1 -n1 -s 'from math import comb' "comb(1000000, 500000)"
    recursive: 1 loop, best of 1: 9.16 sec per loop
    iterative: 1 loop, best of 1: 164 sec per loop
    
    $ ./python -m timeit -r1 -n1 -s 'from math import perm, factorial' "perm(1000000, 500000) // factorial(500000)"
    recursive: 1 loop, best of 1: 19.8 sec per loop
    iterative: 1 loop, best of 1: 137 sec per loop

    It is slightly faster than dividing on every step if the iterative algorithm is used, but still much slower than the recursive algorithm. And the latter is faster because it performs many small divisions and keeps intermediate results smaller.
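    The identity behind these timings can be checked exhaustively for small arguments:

    ```python
    from math import perm, factorial, comb

    # comb(n, k) == perm(n, k) // factorial(k), since perm(n, k) = n! / (n-k)!.
    for n in range(60):
        for k in range(n + 1):
            assert perm(n, k) // factorial(k) == comb(n, k)
    ```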

    @tim-one
    Copy link
    Member

    tim-one commented Jan 23, 2022

    Ya, I don't expect anyone will check in a change without doing comparative timings in C first. Not worried about that.

    I'd be happy to declare victory and move on at this point ;-) But that's me. Near the start of this, I noted that we just won't compete with GMP's vast array of tricks.

    I noted that they use a special routine for division when it's known in advance that the remainder is 0 (as it is, e.g., in every division performed by "our" recursion (which GMP also uses, in some cases)).

    But I didn't let on just how much that can buy them. Under 3.10.1 on my box, the final division alone in math.factorial(1000000) // math.factorial(500000)**2 takes over 20 seconds. But a pure Python implementation of what I assume (don't know for sure) is the key idea(*) in their exact-division algorithm does the same thing in under 0.4 seconds. Huge difference - and while the pure Python version only ever wants the lower N bits of an NxN product, there's no real way to do that in pure Python except via throwing away the higher N bits of a double-width int product. In C, of course, the high half of the bits wouldn't be computed to begin with.

    (*) Modular arithmetic again. Given n and d such that it's known n = q*d for some integer q, shift n and d right until d is odd. q is unaffected. A good upper bound on the bit length of q is then n.bit_length() - d.bit_length() + 1. Do the remaining work modulo 2 raised to that power. Call that base B.

    We "merely" need to solve for q in the equation n = q*d (mod B). Because d is odd, I = pow(d, -1, B) exists. Just multiply both sides by I to get n * I = q (mod B).

    No divisions of any kind are needed. More, there's also a very efficient, division-free algorithm for finding an inverse modulo a power of 2. To start with, every odd int is its own inverse mod 8, so we start with 3 good bits. A modular Newton-like iteration can double the number of correct bits on each iteration.

    But I won't post code (unless someone asks) because I don't want to encourage anyone :-)

    @rhettinger
    Copy link
    Contributor Author

    rhettinger commented Jan 23, 2022

    But I won't post code (unless someone asks)

    Okay, I'll ask.

    @tim-one
    Copy link
    Member

    tim-one commented Jan 23, 2022

    OK, here's the last version I had. Preconditions are that d > 0, n > 0, and n % d == 0.

    This version tries to use the narrowest possible integers on each step. The lowermost good_bits of dinv at the start of the loop are correct already.

    Taking out all the modular stuff, the body of the loop boils down to just

    dinv *= 2 - dinv * d
    

    For insight, if

    dinv * d = 1 + k*2**i
    

    for some k and i (IOW, if dinv * d = 1 modulo 2**i), then

    2 - dinv * d = 1 - k*2**i
    

    and so dinv times that equals 1 - k**2 * 2**(2*i). Or, IOW, the next value of dinv is such that d * dinv = 1 modulo 2**(2*i) - it's good to twice as many bits.

        def ediv(n, d):
            assert d
    
            def makemask(n):
                return (1 << n) - 1
    
            if d & 1 == 0:
                ntz = (d & -d).bit_length() - 1
                n >>= ntz
                d >>= ntz
            bits_needed = n.bit_length() - d.bit_length() + 1
            good_bits = 3
            dinv = d & 7
            while good_bits < bits_needed:
                twice = min(2 * good_bits, bits_needed)
                twomask = makemask(twice)
                fac2 = dinv * (d & twomask)
                fac2 &= twomask
                fac2 = (2 - fac2) & twomask
                dinv = (dinv * fac2) & twomask
                good_bits = twice
            goodmask = makemask(bits_needed)
            return ((dinv & goodmask) * (n & goodmask)) & goodmask
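    For reference, here is a stripped-down standalone sketch of the same idea (hypothetical helper names, simplified from the code above; the divisor is assumed odd and the division exact):

    ```python
    def inv_pow2(d, bits):
        'Inverse of odd d modulo 2**bits, computed without any division.'
        assert d & 1
        x = d & 7                          # every odd int is its own inverse mod 8
        good = 3
        while good < bits:
            good = min(2 * good, bits)
            mask = (1 << good) - 1
            x = (x * (2 - x * d)) & mask   # Newton step: doubles the correct bits
        return x & ((1 << bits) - 1)

    def exact_div(n, d):
        'n // d for odd d when n % d == 0 is known, using only multiplication.'
        bits = n.bit_length() - d.bit_length() + 1   # upper bound on bits in quotient
        mask = (1 << bits) - 1
        return ((n & mask) * inv_pow2(d, bits)) & mask

    q, d = 987654321987654321, 123456789   # d is odd; quotient known to be exact
    assert exact_div(q * d, d) == q
    ```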

    @ezio-melotti ezio-melotti transferred this issue from another repository Apr 10, 2022