In [None]:
#default_exp partition
# https://www.reddit.com/r/dailyprogrammer/comments/jfcuz5/20201021_challenge_386_intermediate_partition/

# Challenge

- Compute `p(666)` (a 26-digit number whose digits sum to 127)
- `p(66)=2323520`

# Sequence formula

- `seq1 = alternate({1,2,3,4,5,...}, {3,5,7,9,11,...}) = {1,3,2,5,3,7,4,9,...}`
- `seq2`
  - `seq2[0] = 1`
  - `seq2[i+1] = seq2[i] + seq1[i]`
- `seq3 = {1, 1, -1, -1, 1, 1, -1, -1, ...}`

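- `p(n) = \sum_{i=0}^{n-1} p(n-seq2[i]) * seq3[i]` (Euler's pentagonal number recurrence; `seq2` is the sequence of generalized pentagonal numbers `1, 2, 5, 7, 12, 15, ...`)
  - `p(0) = 1`
  - `p(n) = 0, n < 0`

`p(n) = 1, 1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, ...` for `n = 0, 1, 2, ...`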

In [None]:
#export
import functools
import numpy as np

In [None]:
#export
def partition_seq1():
    """Yield the gaps between consecutive generalized pentagonal numbers.

    Interleaves the naturals 1, 2, 3, 4, ... with the odd numbers
    3, 5, 7, 9, ..., producing 1, 3, 2, 5, 3, 7, 4, 9, ... (``seq1`` in the
    notes above). The generator never terminates.
    """
    k = 1
    while True:
        yield k          # k-th natural number: 1, 2, 3, ...
        yield 2 * k + 1  # k-th odd number starting from 3: 3, 5, 7, ...
        k += 1

def partition_seq2():
    """Yield the generalized pentagonal numbers 1, 2, 5, 7, 12, 15, 22, ...

    Starts from 1 and repeatedly adds the next gap produced by
    ``partition_seq1`` (``seq2`` in the notes above). The generator never
    terminates.
    """
    gaps = partition_seq1()
    current = 1
    while True:
        yield current
        current += next(gaps)

In [None]:
# Check the first 13 terms of each generator against hand-computed values.
import itertools  # only needed by the checks in this cell
assert list(itertools.islice(partition_seq1(), 13)) == [1, 3, 2, 5, 3, 7, 4, 9, 5, 11, 6, 13, 7]
assert list(itertools.islice(partition_seq2(), 13)) == [1, 2, 5, 7, 12, 15, 22, 26, 35, 40, 51, 57, 70]

In [None]:
#export
@functools.lru_cache(maxsize=None, typed=True)
def partition(n:int):
    """Return p(n), the number of integer partitions of ``n``.

    Implements Euler's pentagonal number recurrence

        p(n) = p(n-1) + p(n-2) - p(n-5) - p(n-7) + p(n-12) + p(n-15) - ...

    with p(0) = 1 and p(n) = 0 for n < 0, memoised with ``lru_cache``.
    The cache is unbounded (``maxsize=None``) because every p(n) reuses
    p(k) for many k < n. A bounded LRU cache would start evicting and
    recomputing those values once the number of distinct arguments grows
    past its size.

    On a cold cache the first call recurses one level per unit of ``n``
    (p(n) -> p(n-1) -> ...), so a large ``n`` can exceed Python's recursion
    limit. Calling the function for increasing ``n`` keeps each call shallow.
    """
    if n < 0: return 0
    elif n == 0: return 1

    # Generalized pentagonal offsets 1, 2, 5, 7, 12, 15, ... (seq2).
    delta_gen = partition_seq2()
    total = 0

    while True:
        # Offsets come in groups of four with signs +, +, -, - (seq3).
        d = [next(delta_gen) for _ in range(4)]

        total += partition(n - d[0]) + partition(n - d[1])
        total -= partition(n - d[2]) + partition(n - d[3])

        # Offsets only increase; once one exceeds n, every later term is
        # p(negative) = 0, so the sum is complete.
        if d[3] > n: break

    return total

In [None]:
# Check the first ten values and the known p(66), then time p(666) once
# (-n1 -r1: a single loop, a single repeat) and print it.
assert [partition(i) for i in range(10)] == [1, 1, 2, 3, 5, 7, 11, 15, 22, 30]
assert partition(66) == 2323520
# NOTE: partition is lru_cached, so values up to 66 are already cached by the
# asserts above; the timing only covers the remaining work.
%timeit -n1 -r1 partition(666)
print(partition(666))  # typically served from the cache filled by the %timeit run

In [None]:
#export
def sum_digits(n):
    """Return the sum of the decimal digits of the integer ``n``.

    The sign is ignored, so ``sum_digits(-123) == 6``, and
    ``sum_digits(0) == 0``. Works for arbitrarily large Python ints, such
    as the 26-digit value of p(666).
    """
    # Use the magnitude; otherwise a negative input would skip the loop and
    # silently report a digit sum of 0.
    n = abs(n)
    total = 0
    while n:
        n, digit = divmod(n, 10)
        total += digit
    return total

In [None]:
assert sum_digits(partition(666)) == 127

In [None]:
# export 
def partition_iterative(n:int) -> int:
    """Return p(n), the number of integer partitions of ``n``, bottom-up.

    Uses the same pentagonal-number recurrence as the recursive version
    (offsets 1, 2, 5, 7, 12, 15, ... with signs +, +, -, -). It fills the
    table p[0..n] iteratively, so it needs neither recursion nor a cache.

    Returns 1 for ``n == 0`` and 0 for negative ``n``.
    """
    if n < 0:
        # p(n) = 0 for n < 0. This guard also keeps the argument of the
        # sqrt below from going negative, which would produce NaN.
        return 0

    # Number of (k(3k-1)/2, k(3k+1)/2) offset pairs to generate. The
    # "+ 1 + 2" pads the count so that the sentinel lookup deltas[j+4] in
    # the loop below always finds a value > i before reaching the end of
    # the list.
    n_pairs = int(np.ceil((np.sqrt(24*n+81)-9)/6)) + 1 + 2

    # Interleave the step sizes 1, 1, 3, 2, 5, 3, 7, 4, ... and take the
    # running sum to get the offsets 1, 2, 5, 7, 12, 15, 22, 26, ...
    # .tolist() converts them to Python ints for the indexing below.
    steps = np.zeros(2*n_pairs, dtype=np.int64)
    steps[1::2] = np.arange(1, n_pairs+1)
    steps[0::2] = np.arange(1, 2*n_pairs, 2)
    deltas = steps.cumsum().tolist()

    # p holds Python ints, so values far beyond int64 (p(666) has 26
    # digits) are exact.
    p = [1]
    for i in range(1, n+1):
        total = 0

        # Offsets are consumed in groups of four with signs +, +, -, -.
        for j in range(0, 2*n_pairs, 4):
            if deltas[j+4] > i:
                # Final, partial group: only offsets <= i contribute.
                if deltas[j+0] <= i: total += p[i - deltas[j+0]]
                if deltas[j+1] <= i: total += p[i - deltas[j+1]]
                if deltas[j+2] <= i: total -= p[i - deltas[j+2]]
                if deltas[j+3] <= i: total -= p[i - deltas[j+3]]
                break

            total += \
                p[i - deltas[j+0]] + \
                p[i - deltas[j+1]] - \
                p[i - deltas[j+2]] - \
                p[i - deltas[j+3]]

        p.append(total)

    return p[-1]

In [None]:
# Testing timeit output
# Quick smoke run on a tiny input: a single loop, a single repeat (-n1 -r1).
%timeit -n1 -r1 partition_iterative(6)

In [None]:
# Test code for 666 to ensure correctness and check overflow
# p(666) has 26 digits, well beyond the int64 range, so a correct match also
# shows the result was not truncated to a fixed-width integer.
solution_666 = 11956824258286445517629485
# The value is computed twice here (once timed, once asserted); at this size
# that costs only milliseconds.
%timeit -n1 -r1 partition_iterative(666)
assert partition_iterative(666) == solution_666

11956824258286445517629485
6.81 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [None]:
#slow
solution_666666 = 829882047250572684700899902613243782763602762816201701722599315815312910790359761069230836156205082863018110775536469855308986200073144248662915902110787189076754874498156375781560473819383193570267234294008114407862435374738896137895011798602712056367666855560838392848713564675054729329398073507378373208972509842880751022273604950120819819461244250221006793015786720300981470607613047369007107554702116361432490562419340585594835559930063181308823544907938556335147860188606415089685992917539117106588219848248270148792532079530603636993578091236835691954161244027792120896238596848636567612717269000784250428006924746617450033567240084513811817484845287957454044679781070379504435782073968802016327182672402147816498886658350521297949309218478570934795197523632953503835428280916586305632528116828229355871664575877094278615695592039186556142662054263695788587794970386821424021653983372333685780633
assert partition_iterative(666666) == solution_666666
%timeit -n1 -r1 print(partition_iterative(666666))