# Pairs of sums of two cubes
Posted by Zak in the problem solving group
>Design an algorithm and write code to find all solutions to the equation a^3 + b^3 = c^3 + d^3 where a, b, c, and d are positive integers less than 1000. 
Ignore solutions of the form x^3 + y^3 = x^3 + y^3 and solutions that are simple permutations of other solutions (swapping left and right hand sides, swapping a and b, swapping c and d). 
For example, if you were printing all solutions less than 20, you could choose to print only 2^3 + 16^3 = 9^3 + 15^3 and 1^3 + 12^3 = 9^3 + 10^3.
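
As a quick sanity check, both example solutions really do balance. The sketch below is just a hand-written list of the two examples; the first common sum, 1729, is the famous Hardy-Ramanujan "taxicab" number.

```python
# The two example solutions from the problem statement, as ((a, b), (c, d))
examples = [((2, 16), (9, 15)), ((1, 12), (9, 10))]

for (a, b), (c, d) in examples:
    lhs = a**3 + b**3
    rhs = c**3 + d**3
    print(f'{a}^3 + {b}^3 = {lhs}, {c}^3 + {d}^3 = {rhs}')
    assert lhs == rhs
# 2^3 + 16^3 = 4104, 9^3 + 15^3 = 4104
# 1^3 + 12^3 = 1729, 9^3 + 10^3 = 1729
```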

## Naive solution
So this is an interesting problem, and there's a fairly obvious solution: check every possible combination of a, b, c and d.

It sounds inefficient, but let's code it anyway and run it on some small values of n so that we can see what some solutions look like... this might help us find patterns in the solutions to optimise the code.

In [1]:
def naive_sum_of_cubes(n):
    solutions = []  
    for a in range(1, n):
        for b in range(1, n):
            for c in range(1, n):
                for d in range(1, n):
                    if a**3 + b**3 != c**3 + d**3:
                        # doesn't add up
                        continue
                    solution = (a, b, c, d)
                    solutions.append(solution)
    return solutions

In [2]:
# Check that the examples from the problem work
solutions = naive_sum_of_cubes(20)
assert (2, 16, 9, 15) in solutions  # 2^3 + 16^3 = 9^3 + 15^3
assert (1, 12, 9, 10) in solutions  # 1^3 + 12^3 = 9^3 + 10^3

# Let's see some solutions
for solution in solutions[:10]:
    print(solution)
print(f'Found {len(solutions)} solutions')

(1, 1, 1, 1)
(1, 2, 1, 2)
(1, 2, 2, 1)
(1, 3, 1, 3)
(1, 3, 3, 1)
(1, 4, 1, 4)
(1, 4, 4, 1)
(1, 5, 1, 5)
(1, 5, 5, 1)
(1, 6, 1, 6)
Found 719 solutions


We have a lot of silly solutions:
- **Boring**: `(1, 2, 1, 2)` or `(1, 2, 2, 1)` don't tell us anything interesting
- **Permutations**: `(2, 16, 9, 15)` is the same as `(2, 16, 15, 9)` and many others.

So let's filter them out:

In [3]:
def boring(solution):
    a, b, c, d = solution
    identity = a == c and b == d           # (1, 2, 1, 2)
    flipped_identity = a == d and b == c   # (1, 2, 2, 1)
    return identity or flipped_identity

def permuted(solution):
    a, b, c, d = solution
    return {
        (a, b, c, d),
        (a, b, d, c),
        (b, a, c, d),
        (b, a, d, c),
        (c, d, a, b),
        (c, d, b, a),
        (d, c, a, b),
        (d, c, b, a),
    }

def contains(solutions, solution):
    '''Check whether this solution (or any of its permutations) is in solutions'''
    return any(permutation in solutions for permutation in permuted(solution))

def unique(solutions):
    # sets are very fast for checking membership because they are
    # implemented using a hash
    unique_solutions = set()
    for solution in solutions:
        if not contains(unique_solutions, solution):
            unique_solutions.add(solution)
    return unique_solutions

def interesting(solutions):
    return unique(solution for solution in solutions if not boring(solution))
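
For comparison, a different deduplication technique that gives the same grouping as `unique`: map each solution to one canonical representative (each pair sorted largest-first, then the two pairs sorted) and let a set do the rest. The names `canonical` and `canonical_unique` are hypothetical, not part of the code above. Like `unique`, this does not filter boring solutions, so it would still be applied after `boring`.

```python
def canonical(solution):
    '''Hypothetical helper: one fixed representative for all 8 permutations.'''
    a, b, c, d = solution
    left = tuple(sorted((a, b), reverse=True))    # e.g. (2, 16) -> (16, 2)
    right = tuple(sorted((c, d), reverse=True))   # e.g. (9, 15) -> (15, 9)
    # Order the two sides as well, so swapping left and right gives the same key
    first, second = sorted((left, right), reverse=True)
    return first + second

def canonical_unique(solutions):
    '''Hypothetical alternative to unique(): dedupe via canonical keys.'''
    return {canonical(solution) for solution in solutions}

print(canonical((2, 16, 9, 15)))   # (16, 2, 15, 9)
print(canonical((15, 9, 16, 2)))   # (16, 2, 15, 9)
```

Each solution becomes its key once, rather than being checked against the set under up to eight permutations.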


In [4]:
# Check that the examples from the problem work
solutions = interesting(naive_sum_of_cubes(20))
assert contains(solutions, (2, 16, 9, 15))   # 2^3 + 16^3 = 9^3 + 15^3
assert contains(solutions, (1, 12, 9, 10))   # 1^3 + 12^3 = 9^3 + 10^3
assert len(solutions) == 2

Nice.
But let's see how long it takes...

In [5]:
n = 50
%time solutions = interesting(naive_sum_of_cubes(n))
print(f'Found {len(solutions)} solutions')

n = 100
%time solutions = interesting(naive_sum_of_cubes(n))
print(f'Found {len(solutions)} solutions')

CPU times: user 6.09 s, sys: 4.52 ms, total: 6.09 s
Wall time: 6.09 s
Found 12 solutions
CPU times: user 1min 42s, sys: 70 µs, total: 1min 42s
Wall time: 1min 42s
Found 45 solutions


6 secs for n=50  
1 min 40 secs for n=100

I'm not sure that we're going to make it to 1000!
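
Those two timings match what an O(n^4) algorithm predicts: doubling n should multiply the work by 2^4 = 16. A back-of-the-envelope extrapolation from the measured n=100 time, assuming the n^4 scaling simply continues and ignoring the comparatively cheap `interesting` filtering, shows how far away 1000 is:

```python
# Wall times measured above for interesting(naive_sum_of_cubes(n)), in seconds
measured = {50: 6.09, 100: 102.0}

# O(n^4) predicts that doubling n multiplies the runtime by 2**4 = 16
ratio = measured[100] / measured[50]
print(f'n=50 -> n=100 slowdown: {ratio:.1f}x (O(n^4) predicts 16x)')   # about 16.7x

# Assumption: the n^4 scaling continues from n=100 all the way to n=1000
estimate = measured[100] * (1000 / 100) ** 4
print(f'Estimated n=1000 runtime: {estimate / 86400:.1f} days')       # about 11.8 days
```

So the naive approach would need well over a week at n=1000.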

## Ideas for a better approach
When I did this, I had some ideas which I thought might help:
1. As an over-arching strategy, I want to reduce the number of operations I have to do. I'm thinking that maybe there could be some sort of symmetry or constraints on what a, b, c, d can be so that we can consider far fewer solutions.
2. If I have checked `(a, b, c, d)`, I don't need to check `(b, a, c, d)`. This could maybe translate in code to
```python
for a in range(1, n):
    for b in range(1, a):
        # do stuff
```
This point also holds for c and d.  
Maybe all the stuff that we filtered out using `unique` above, we could just never look at it in the first place.
3. In the solution above, we recalculate a^3 + b^3 many times. So let's precompute all the possible values and store them in a list.
4. I noticed that all the solutions are a _sandwich_, i.e. a < c, d < b (both c and d lie strictly between a and b). Does this somehow reduce our search space?
5. Maybe doing it as a difference of cubes could help? Maybe there's a maths trick in here somewhere with factorisation?
```math
\begin{aligned}
a^3 + b^3 &= c^3 + d^3 \\
a^3 - c^3 + b^3 - d^3 &= 0 \\
(a - c)(a^2 + ac + c^2) + (b - d)(b^2 + bd + d^2) &= 0
\end{aligned}
```
Hmmmm... nothing obvious to me
6. I haven't seen any solutions of the form `a == b`. Is it coincidence? If not, is this some kind of property that could be useful for reducing the sample space? Not sure...
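
The sandwich in idea 4 is not a fluke, and a short argument (a sketch, not part of the original reasoning) shows why. Write each pair with its smaller element first, so a ≤ b and c ≤ d, and label the two sides so that a ≤ c. If a = c then b = d, which is exactly the boring case, so for an interesting solution a < c and:

```math
\begin{aligned}
a < c &\implies a^3 < c^3 \\
      &\implies b^3 = c^3 + d^3 - a^3 > d^3 \\
      &\implies b > d
\end{aligned}
```

Combined with c ≤ d, this gives a < c ≤ d < b: the pair (c, d) always sits strictly between a and b.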

## Good enough for now...
Approaches 1-3 above are useful. I think approach 4 could be useful for reducing the search space, but I can't figure out how to implement it. In the end I decided that approach 5 wouldn't help, and I think 6 might just be a coincidence.
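
Idea 6 can at least be checked empirically up to the problem's limit. The sketch below (`equal_pair_solutions` is a hypothetical name) searches for 2a^3 = c^3 + d^3 with c > d, using a dictionary from cubes back to their roots so each candidate d is a single lookup. It finds nothing below 1000, which is consistent with the classical result that x^3 + y^3 = 2z^3 has no positive integer solutions other than x = y = z. If so, loops that skip a == b lose nothing.

```python
def equal_pair_solutions(n):
    '''Hypothetical helper: all (a, c, d) with a^3 + a^3 == c^3 + d^3, c > d, all < n.'''
    cube_root = {i**3: i for i in range(1, n)}   # cube -> its root
    found = []
    for a in range(1, n):
        target = 2 * a**3
        for c in range(1, n):
            # d must satisfy d^3 == target - c^3; get() returns None if that isn't a cube below n
            d = cube_root.get(target - c**3)
            if d is not None and c > d:          # c > d rules out the trivial c == d == a
                found.append((a, c, d))
    return found

print(equal_pair_solutions(1000))   # []
```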

I didn't have any other ideas, so I started implementing approach 2. Originally I thought of approach 3, but thought the speedup would be minimal, so I didn't bother to implement it until Ian was getting competitive about timing! Turns out it helps!

Here are 2 and 3:

In [6]:
def improved_sum_of_cubes(n):
    cube = [i**3 for i in range(n)]
    
    # pair_to_sum is a map of
    # (a, b) -> a^3 + b^3
    # so I can just look up the pair to find its cube sum
    # instead of calculating the cubes and their sum every time.
    pair_to_sum = {
        (a, b): cube[a] + cube[b]
        for a in range(n)
        for b in range(n)
    }
    # It's kind of interesting that I only make this once, but then I can
    # use it for both (a, b) AND for (c, d).
    
    # let's make solutions a set:
    # adding and membership testing are faster than with a list,
    # order is not important, uniqueness is.
    solutions = set()
    for a in range(1, n):
        for b in range(1, a):
            for c in range(1, n):
                for d in range(1, c):
                    # This double loop (1 -> n, 1 -> a) has the 
                    # cool property that we always know a > b. 
                    # Could this be used to sort solutions 
                    # so we don't get duplicates?
                    # Same for c, d.
                    
                    if pair_to_sum[(a, b)] != pair_to_sum[(c, d)]:
                        # doesn't add up
                        continue

                    # b < a and d < c, so the a/b and c/d swaps are never
                    # generated; the only boring case left is (a, b, a, b)
                    solution = (a, b, c, d)
                    if boring(solution):
                        continue
                    solutions.add(solution)

    # (a, b, c, d) and (c, d, a, b) are both still found, so filter those out
    return unique(solutions)
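
The comment inside the loops asks whether the a > b ordering could be used to avoid duplicates. One possible answer, sketched as a hypothetical `ordered_sum_of_cubes`: also require c < a. In an interesting solution the two pairs can't share their larger element (a == c forces b == d, the boring case), so exactly one of (a, b, c, d) and (c, d, a, b) passes the check. That removes the need for both `boring` and `unique`, although the loops are still O(n^4).

```python
def ordered_sum_of_cubes(n):
    '''Hypothetical variant of improved_sum_of_cubes that finds each solution once.'''
    cube = [i**3 for i in range(n)]
    solutions = set()
    for a in range(1, n):
        for b in range(1, a):            # b < a
            for c in range(1, a):        # c < a: the pair with the larger max goes first
                for d in range(1, c):    # d < c
                    if cube[a] + cube[b] == cube[c] + cube[d]:
                        solutions.add((a, b, c, d))
    return solutions

print(sorted(ordered_sum_of_cubes(20)))   # [(12, 1, 10, 9), (16, 2, 15, 9)]
```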

In [7]:
# Check that the examples from the problem work
solutions = improved_sum_of_cubes(20)
assert contains(solutions, (2, 16, 9, 15))   # 2^3 + 16^3 = 9^3 + 15^3
assert contains(solutions, (1, 12, 9, 10))   # 1^3 + 12^3 = 9^3 + 10^3
assert len(solutions) == 2
print(f'Found {len(solutions)} solutions')

Found 2 solutions


In [8]:
n = 100
%time solutions = improved_sum_of_cubes(n)
print(f'Found {len(solutions)} solutions')

n = 200
%time solutions = improved_sum_of_cubes(n)
print(f'Found {len(solutions)} solutions')

CPU times: user 4.76 s, sys: 0 ns, total: 4.76 s
Wall time: 4.76 s
Found 45 solutions
CPU times: user 1min 28s, sys: 104 ms, total: 1min 28s
Wall time: 1min 28s
Found 135 solutions


## But not good enough
Wow! Okay so for n=100 we've gone from 1m40s to 5s! 20x speedup!

But n=200 still takes 1m30s... so we still need to do better to get to 1000.

The problem is that the approaches mentioned above, even though they reduce the search space by a substantial _constant_ factor, do not reduce the runtime complexity from O(n^4).
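
To put numbers on that: the sketch below counts how often the innermost body runs in each version (the two helpers are hypothetical closed-form loop counts) and estimates the growth exponent from the two `improved_sum_of_cubes` timings above. The ordered loops do roughly 4x less work, but the runtime still grows like n^4. The rest of the observed ~20x speedup presumably comes from swapping four `**` evaluations per iteration for two dictionary lookups.

```python
import math

def naive_iterations(n):
    # four independent loops over range(1, n)
    return (n - 1) ** 4

def improved_iterations(n):
    # (a, b) with 1 <= b < a < n, and the same again for (c, d)
    pairs = (n - 1) * (n - 2) // 2
    return pairs ** 2

for n in (100, 200, 1000):
    ratio = naive_iterations(n) / improved_iterations(n)
    print(f'n={n}: {improved_iterations(n):,} iterations, {ratio:.2f}x fewer than naive')

# Measured wall times of improved_sum_of_cubes above: 4.76 s (n=100) and 1min 28s (n=200)
exponent = math.log2(88.0 / 4.76)
print(f'Doubling n from 100 to 200 behaves like n^{exponent:.1f}')
```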


## The leap from O(n^4) to O(n^2)

So it's kind of interesting that the `pair_to_sum` map which we created in the previous approach
is created only once, but then it is used for both `a, b` and ALSO for `c, d` in exactly
the same way.

So the "trick" here is to realise that we don't really need to compare `a, b` against `c, d`.
We could just compare the `a, b` pairs against each other.

And what we really care about is **whether the sum of their cubes match**. 
So if, instead of looking up from pairs to cube sums, we could somehow go the other way, then we could just check whether any cube sum maps to more than one pair. If it does, we have a solution!

So instead of mapping from  
`pair` -> `cube_sum`  
as we did in `improved_sum_of_cubes()` let's try to map from  
`cube_sum` -> `pairs`

In [9]:
from collections import defaultdict
from itertools import combinations

def unique_pairs_and_sums(n):
    '''
    Generate unique pairs and their cube sum
    
    I've chosen to factor out this double loop because:
    1) it lets me name the iterable for better self-documenting code
    2) it keeps the main logic less nested
    3) it lets me prepare the iterated values exactly how I want them
    Bonus: Less memory
    '''
    cube = [i**3 for i in range(n)]   # starts at 0 so index lookups work
    for a in range(1, n):
        for b in range(1, a):
            yield (a, b), cube[a] + cube[b]

def map_cubes_to_pairs(n):
    '''
    Build up the cube_to_pairs map,
    which is a mapping from
    a^3 + b^3 -> {(a1, b1), (a2, b2), ...}
    and return just its values (the groups of pairs sharing a sum).

    Pairs are generated with a > b, so for our n=20 examples:
        1^3 + 12^3 = 9^3 + 10^3
        2^3 + 16^3 = 9^3 + 15^3
    we would have:
        cube_to_pairs = {
            9: {(2, 1)},                    If there's only a single pair, it's not a solution
            28: {(3, 1)},
            ...
            1729: {(12, 1), (10, 9)},       Whenever there are multiple pairs, all their
            4104: {(16, 2), (15, 9)},       pairwise combinations are solutions.
        }
    '''
    cube_to_pairs = defaultdict(set)
    for pair, cube_sum in unique_pairs_and_sums(n):
        cube_to_pairs[cube_sum].add(pair)
    return cube_to_pairs.values()

def pairwise_combinations(same_sum_pairs):
    '''
    Iterate over the same_sum_pairs
    and whenever there is more than 1 pair,
    the pairwise combinations of those pairs 
    are the solutions
    '''
    return {
        (a, b, c, d)
        for pairs in same_sum_pairs
        for (a, b), (c, d) in combinations(pairs, 2)
    }
    
def sum_of_cubes(n):
    '''
    Phase 1: Figure out all the pairs of numbers that add to the same cube_sum
    Phase 2: Get the pairwise combinations of them --> These are the solutions
    '''
    same_sum_pairs = map_cubes_to_pairs(n)         # Phase 1
    return pairwise_combinations(same_sum_pairs)   # Phase 2

In [10]:
# Check that the examples from the problem work
solutions = sum_of_cubes(20)
assert contains(solutions, (2, 16, 9, 15))   # 2^3 + 16^3 = 9^3 + 15^3
assert contains(solutions, (1, 12, 9, 10))   # 1^3 + 12^3 = 9^3 + 10^3
assert len(solutions) == 2
print(f'Found {len(solutions)} solutions')

Found 2 solutions
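
`sum_of_cubes` returns a set, so the solutions come out in no particular order. To print them the way the problem statement does, here is a compact, self-contained variant of the same cube_sum -> pairs idea that yields them sorted by their common sum. `taxicab_solutions` is a hypothetical name; it uses plain lists instead of sets because each pair is generated exactly once.

```python
from collections import defaultdict
from itertools import combinations

def taxicab_solutions(n):
    '''Hypothetical variant: yield (total, (a, b), (c, d)) in increasing order of total.'''
    sums = defaultdict(list)
    for a in range(1, n):
        for b in range(1, a):                  # b < a, as in unique_pairs_and_sums
            sums[a**3 + b**3].append((a, b))
    for total in sorted(sums):
        # groups with a single pair produce no combinations
        for (a, b), (c, d) in combinations(sums[total], 2):
            yield total, (a, b), (c, d)

for total, (a, b), (c, d) in taxicab_solutions(20):
    print(f'{b}^3 + {a}^3 = {d}^3 + {c}^3 = {total}')
# 9^3 + 10^3 = 1^3 + 12^3 = 1729
# 9^3 + 15^3 = 2^3 + 16^3 = 4104
```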


In [11]:
n = 100
%time solution = sum_of_cubes(n)
print(f'Found {len(solution)} solutions')

n = 200
%time solution = sum_of_cubes(n)
print(f'Found {len(solution)} solutions')

n = 1000
%time solution = sum_of_cubes(n)
print(f'Found {len(solution)} solutions')

CPU times: user 11.6 ms, sys: 0 ns, total: 11.6 ms
Wall time: 12.2 ms
Found 45 solutions
CPU times: user 56.1 ms, sys: 4 µs, total: 56.1 ms
Wall time: 56.7 ms
Found 135 solutions
CPU times: user 1.23 s, sys: 88 ms, total: 1.32 s
Wall time: 1.33 s
Found 1598 solutions


We're roughly 8,000x faster than `interesting(naive_sum_of_cubes(n))` at `n=100` (1m42s vs 12 ms)  
and roughly 1,500x faster than `improved_sum_of_cubes(n)` at `n=200` (1m28s vs 57 ms)

n=1000 in a second?  
That'll do.
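
One detail worth noticing at n=1000: `pairwise_combinations` uses `itertools.combinations` rather than assuming each group holds exactly two pairs, and that matters. Below 1000 some cube sums have three representations; the smallest is 87539319, the third taxicab number, and a group of three pairs yields three solutions. A quick check, using a hypothetical `pairs_by_sum` helper that rebuilds the same mapping with lists:

```python
from collections import defaultdict

def pairs_by_sum(n):
    '''Hypothetical helper: cube sum -> list of (a, b) pairs with n > a > b >= 1.'''
    groups = defaultdict(list)
    for a in range(1, n):
        for b in range(1, a):
            groups[a**3 + b**3].append((a, b))
    return groups

groups = pairs_by_sum(1000)
triples = {s: pairs for s, pairs in groups.items() if len(pairs) >= 3}
smallest = min(triples)
print(smallest, groups[smallest])   # 87539319 [(414, 255), (423, 228), (436, 167)]
```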