In [58]:
from time import time

def time_elapsed(t):
    """Format the wall-clock time elapsed since the ``time()`` timestamp ``t``.

    The result looks like ``"0h 0m 12.982290s"``: hours and minutes with no
    decimals, then seconds with six decimal places.
    """
    whole_minutes, secs = divmod(time() - t, 60)
    hrs, mins = divmod(whole_minutes, 60)
    return "{:.0f}h {:.0f}m {:f}s".format(hrs, mins, secs)

# Project Euler problem 196 - Prime triplets

[Link to problem on Project Euler homepage](https://projecteuler.net/problem=196)

## Description

Build a triangle from all positive integers in the following way:

````
 1
 2  3
 4  5  6
 7  8  9 10
11 12 13 14 15
16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31 32 33 34 35 36
37 38 39 40 41 42 43 44 45
46 47 48 49 50 51 52 53 54 55
56 57 58 59 60 61 62 63 64 65 66
. . .
````

Each positive integer has up to eight neighbours in the triangle.

A set of three primes is called a prime triplet if one of the three primes has the other two as neighbours in the triangle.

For example, in the second row, the prime numbers 2 and 3 are elements of some prime triplet.

If row 8 is considered, it contains two primes which are elements of some prime triplet, i.e. 29 and 31.
If row 9 is considered, it contains only one prime which is an element of some prime triplet: 37.

Define S(n) as the sum of the primes in row n which are elements of any prime triplet.
Then S(8)=60 and S(9)=37.

You are given that S(10000)=950007619.

Find S(5678027) + S(7208785).

In [59]:
from math import sqrt, ceil

def triangular(n):
    """Return the n-th triangular number T(n) = 0 + 1 + ... + n."""
    return (n + 1) * n // 2

def first(row):
    """Return the smallest integer in 0-based ``row`` of the triangle (T(row) + 1)."""
    return 1 + triangular(row)

def last(row):
    """Return the largest integer in 0-based ``row`` of the triangle (T(row + 1))."""
    next_row = row + 1
    return triangular(next_row)

def row_index(n):
    """Return the 0-based row of the triangle that contains the positive integer ``n``.

    Row r holds the integers T(r)+1 .. T(r+1), where T(k) = k*(k+1)//2, so the
    answer is the largest r with T(r) < n, i.e. r = floor((sqrt(8n - 7) - 1) / 2).

    Raises:
        ValueError: if ``n`` < 1 (no row contains it).
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got {!r}".format(n))
    # Floating-point estimate first. For large n, 8n - 7 is not exactly
    # representable as a float, so the estimate may be off by one; the integer
    # corrections below make the result exact.
    r = int((sqrt(8*n - 7) - 1) / 2)
    while r*(r + 1)//2 >= n:
        r -= 1
    while (r + 1)*(r + 2)//2 < n:
        r += 1
    return r

def column_index(n, row=None):
    """Return the 0-based column of ``n`` within its row of the triangle.

    ``row`` may be supplied when the caller already knows it, to avoid
    recomputing ``row_index(n)``.
    """
    if row is None:
        row = row_index(n)
    return n - first(row)

def row_column_index(n):
    """Return the 0-based ``(row, column)`` position of ``n`` in the triangle."""
    r = row_index(n)
    c = column_index(n, row=r)
    return r, c

def number(row, column):
    """Return the integer at 0-based position ``(row, column)`` of the triangle."""
    return column + first(row)

def neighbour_indices(row, column):
    """Return the positions adjacent to ``(row, column)`` in the triangle.

    Row r holds columns 0..r. Adjacency is the 8-neighbourhood (horizontal,
    vertical and diagonal), clipped to positions that exist in the triangle.
    The result is ordered row by row (above, same, below), then by column.
    """
    result = []
    for r in (row - 1, row, row + 1):
        if r < 0:
            continue
        for c in (column - 1, column, column + 1):
            # Skip the cell itself and any column outside 0..r.
            if (r, c) != (row, column) and 0 <= c <= r:
                result.append((r, c))
    return result

def neighbour_numbers(n):
    """Return the integers adjacent to ``n`` in the triangle, in ``neighbour_indices`` order."""
    position = row_column_index(n)
    neighbours = []
    for r, c in neighbour_indices(*position):
        neighbours.append(number(r, c))
    return neighbours

In [60]:
# Sanity checks for the 0-based row/column helpers.
assert [first(r) for r in range(6)] == [1, 2, 4, 7, 11, 16]
assert [last(r) for r in range(6)] == [1, 3, 6, 10, 15, 21]
assert [row_index(num) for num in (1, 2, 3, 29, 59)] == [0, 1, 1, 7, 10]
assert column_index(10, row=3) == 3
assert number(3, 3) == 10

# Show the neighbourhoods across row 3 (the numbers 7..10).
for col in range(4):
    print(neighbour_indices(3, col))

for num in (7, 8, 9, 10):
    print(neighbour_numbers(num))

[(2, 0), (2, 1), (3, 1), (4, 0), (4, 1)]
[(2, 0), (2, 1), (2, 2), (3, 0), (3, 2), (4, 0), (4, 1), (4, 2)]
[(2, 1), (2, 2), (3, 1), (3, 3), (4, 1), (4, 2), (4, 3)]
[(2, 2), (3, 2), (4, 2), (4, 3), (4, 4)]
[4, 5, 8, 11, 12]
[4, 5, 6, 7, 9, 11, 12, 13]
[5, 6, 8, 10, 12, 13, 14]
[6, 9, 13, 14, 15]


In [75]:
from sympy import primerange, isprime

def prime_neighbours(n):
    """Return the neighbours of ``n`` in the triangle that are prime, in neighbour order."""
    return list(filter(isprime, neighbour_numbers(n)))

def S(row_number):
    """Sum the primes in (1-based) row ``row_number`` that belong to some prime triplet.

    A prime p qualifies when either
      * p is the centre of a triplet (it has at least two prime neighbours), or
      * its single prime neighbour q has at least two prime neighbours; adjacency
        is symmetric, so one of those is p and q is the centre of a triplet
        containing p.
    """
    # The helper functions use 0-based rows. Use a name that does not shadow the
    # module-level row_index() function.
    row = row_number - 1

    total = 0
    for prime in primerange(first(row), last(row) + 1):
        neighbours = prime_neighbours(prime)
        if len(neighbours) >= 2:
            total += prime
        elif len(neighbours) == 1 and len(prime_neighbours(neighbours[0])) >= 2:
            total += prime
    return total

# Reproduce the values stated in the problem description, and fail loudly if any differ.
for row_number, expected in ((8, 60), (9, 37), (10000, 950007619)):
    value = S(row_number)
    print(value)
    assert value == expected, "S({}) = {}, expected {}".format(row_number, value, expected)

60
37
950007619


In [77]:
# Start of the wall-clock timer; the elapsed times printed below are cumulative.
t0 = time()

# Evaluate S for each requested row, reporting the running elapsed time after each one.
sums = []
for row_number in (5678027, 7208785):
    sums.append(S(row_number))
    print(sums[-1], time_elapsed(t0))
t1, t2 = sums
total = sum(sums)

print("Result: {}".format(total))
print("Time elapsed: {}".format(time_elapsed(t0)))

79697256800321526 0h 0m 12.982290s
242605983970758409 0h 0m 30.244469s
Result: 322303240771079935
Time elapsed: 0h 0m 30.244639s
