In [3]:
import pandas as pd
import numpy as np

import random
from collections import Counter
from itertools import combinations, product
from functools import reduce

The idea is to narrow the search so that the Miller–Rabin primality test only has to run on a small set of candidate numbers. Miller–Rabin is a fast primality test, so each individual candidate can be checked very quickly.

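Here $M(n, d)$ is the largest number of times the digit $d$ appears in an $n$-digit prime, $N(n, d)$ is how many such primes exist, and $S(n, d)$ is their sum. For each $d$ we start from $M(10, d) = 9$ by default. It cannot be $10$: a number made of ten copies of one digit is a multiple of $1111111111 = 11 \times 101010101$, so it is not prime. We then count down. If some prime has $d$ repeated $9$ times, we are done with that $d$; if not, we try $8$, and so on. The important point is that as soon as a single prime is found for a given repeat count, we can stop searching for that $d$, which saves a lot of time.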

In the second half we search a much smaller set of numbers. Since $M(10, d)$ is now known for each $d$, we generate every number that contains exactly $M(10, d)$ copies of $d$. That means every way of placing those copies, with the remaining positions filled by other digits. The primes among them are counted in $N(10, d)$ and summed in $S(10, d)$.

This algorithm is quite fast. In the author's runs it handled even $50$-digit numbers in well under a minute, and the $10$-digit case is nearly instant.

In [12]:
def miller_rabin(n, k = 10):
    """Return True if `n` is prime, False otherwise (including all n < 2).

    Every n is first tested against the fixed bases 2, 3, 5, ..., 37. For
    n < 318665857834031151167461 (about 3.2e23, so every integer with at most
    23 digits) that base set is known to make Miller-Rabin deterministic, so
    the answer is exact and reproducible. Larger n additionally get `k`
    random bases; a composite then survives with probability at most 4**-k.

    Parameters
    ----------
    n : int
        The integer to test.
    k : int, default 10
        Number of extra random Miller-Rabin rounds. These rounds are only used
        when `n` is above the deterministic bound.
    """
    if n <= 3:
        return n == 2 or n == 3

    if n % 2 == 0:
        return False

    # Write n - 1 = 2**r * s with s odd.
    r, s = 0, n - 1
    while s % 2 == 0:
        r += 1
        s //= 2

    def is_witness(a):
        """Return True if base `a` proves that n is composite."""
        x = pow(a, s, n)
        if x == 1 or x == n - 1:
            return False
        for _ in range(r - 1):
            x = pow(x, 2, n)
            if x == n - 1:
                return False
        return True

    small_primes = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
    # Trial-divide by the bases themselves. Afterwards n > 37, so every base
    # lies strictly between 1 and n - 1, as Miller-Rabin requires.
    for p in small_primes:
        if n == p:
            return True
        if n % p == 0:
            return False

    if any(is_witness(a) for a in small_primes):
        return False
    if n < 318665857834031151167461:
        # Deterministic for this range: no random rounds needed.
        return True

    # Beyond the deterministic bound, fall back to random bases.
    return not any(is_witness(random.randrange(2, n - 1)) for _ in range(k))

In [14]:
# Problem size: primes with this many digits are examined.
number_of_digits = 10
digs = list(range(10))


def numbers_with_repeats(d, repeats, num_digits):
    """Yield every `num_digits`-digit integer containing digit `d` exactly `repeats` times.

    Candidates with a leading zero are skipped, because they are really
    shorter numbers. Every other position is filled with every combination
    of the nine digits other than `d`, so each qualifying integer is yielded
    exactly once. Iteration order is: filler tuple outermost, then placement
    of the `d`'s.
    """
    other_digs = [o for o in range(10) if o != d]
    # Placing d = 0 in the first position could never give a valid number.
    first_pos = 0 if d > 0 else 1
    for fill in product(other_digs, repeat=num_digits - repeats):
        for positions in combinations(range(first_pos, num_digits), repeats):
            chosen = set(positions)
            fill_iter = iter(fill)
            digits = [d if i in chosen else next(fill_iter) for i in range(num_digits)]
            # The filler digits can include 0, so re-check the leading digit.
            if digits[0] == 0:
                continue
            yield reduce(lambda acc, dig: acc * 10 + dig, digits, 0)


# m[d] = maximum times that d is repeated (fallback: the largest possible run)
m = [number_of_digits - 1 for _ in digs]
# n[d] = number of primes with m[d] d's
n = [0 for _ in digs]
# s[d] = sum of primes with m[d] d's
s = [0 for _ in digs]

# Phase 1: counting the repeat count down, the first one that admits a
# prime is M(number_of_digits, d). Print that witness prime and its count.
for d in digs:
    for x in range(number_of_digits - 1, -1, -1):
        witness = next(
            (num for num in numbers_with_repeats(d, x, number_of_digits) if miller_rabin(num)),
            None,
        )
        if witness is not None:
            print(witness, x)
            m[d] = x
            break
print()

# Phase 2: with M known, count and sum every prime that attains it.
for d in digs:
    for num in numbers_with_repeats(d, m[d], number_of_digits):
        if miller_rabin(num):
            n[d] += 1
            s[d] += num

print('sum of sums:', sum(s))
pd.DataFrame({'m': m, 'n': n, 's': s})

1000000007 8
1111111121 9
2022222221 8
3333133333 9
4444444447 9
5555555557 9
6666666661 9
7777717777 9
8888880881 8
9199999999 9

sum of sums: 612407567715


Unnamed: 0,m,n,s
0,8,8,38000000042
1,9,11,12882626601
2,8,39,97447914665
3,9,7,23234122821
4,9,1,4444444447
5,9,1,5555555557
6,9,1,6666666661
7,9,9,59950904793
8,8,32,285769942206
9,9,8,78455389922
