In [280]:
import numpy as np

from math import gcd, sqrt
from scipy.special import comb
from tqdm import trange

First step is to check if $n = a^{b}$. We do it as a brute force search by power $(b)$ and binary search by base $(a)$. Time complexity is $\mathcal{O}^{\sim}(log^{3}n)$.

In [281]:
def checkpower(i, n):
    left = 2
    right = n
    flag = False

    while left <= right and not flag:
        mid = int((left + right) / 2)
        temp = mid ** i
        if temp == n:
            flag = True
        else:
            if n < temp:
                right = int((left + right) / 2) - 1
            else:
                left = int((left + right) / 2) + 1

    return flag

In [282]:
def checkperfectpower(n):
    for i in np.arange(2, np.log2(n) + 1):
        if checkpower(i, n):
            return True
    return False

Second step. Time complexity is $\mathcal{O}^{\sim}(log^{7}n)$.

In [283]:
def getord(r, n, threshold):
    if gcd(r, n) > 1:
        return False
    for i in np.arange(1, threshold):
        if (n ** int(i)) % r == 1:
            return False
    return True

In [284]:
def smallestr(n):    
    rmax = max(3, int(np.ceil(np.log2(n) ** 5)))
    threshold = np.log2(n) ** 2
    for r in np.arange(2, rmax + 1):
        if getord(r, n, threshold):
            return r

Third step. Time complexity is $\mathcal{O}(log^{6}n)$.

In [285]:
def elimination(r, n):
    for a in np.arange(2, r + 1):
        if 1 < gcd(n, a) < n:
            return True
    return False

Forth step. Time complexity is $\mathcal{O}(log\ n)$.

Fifth step. Time complexity is $\mathcal{O}(log^{\frac{21}{2}}n)$.

In [286]:
def euler(n):
    amount = 0
    for k in range(1, n + 1):
        if gcd(n, k) == 1:
            amount += 1
    return amount

In [287]:
def polynomial_coef(n, a):
    ex = []
    for i in range(n + 1):
        ex.append(comb(n, n - i, exact=True) * (a ** (n - i)))
    return ex[::-1]

In [288]:
def reduce(polynomial, r):
    for i in range(len(polynomial) - r):
        k = polynomial[i]
        polynomial[i] = 0
        polynomial[i + r] -= k
    return polynomial

In [293]:
def modul(n, r):
    for a in np.arange(1, np.floor(sqrt(euler(n)) * np.log2(n)) + 1):
        coef = polynomial_coef(n, int(a))
        coef[0] -= 1
        coef[-1] -= int(a)
        divided = reduce(coef, int(r))
        if not all(x%n == 0 for x in divided):
            return False
    return True

To sum up:

In [303]:
def AKS(n):
    if checkperfectpower(n):
        return False
    else:
        r = smallestr(n)
        if elimination(r, n):
            return False
        else:
            if n <= r:
                return True
            else:
                if modul(n, r):
                    return True
                else:
                    return False

In [291]:
def is_prime(a):
    return all(a % i for i in range(2, a))

Correctness

In [314]:
for i in trange(1, 1001):
    if AKS(i) != is_prime(i):
        print(i)

100%|██████████| 1000/1000 [14:21<00:00,  1.16it/s]


In [350]:
%%time
AKS(10000000000000000001)

CPU times: user 8 ms, sys: 0 ns, total: 8 ms
Wall time: 7.41 ms




False

In [351]:
%%time
is_prime(10000000000000000001)

CPU times: user 0 ns, sys: 0 ns, total: 0 ns
Wall time: 14.5 µs


False

In [352]:
%%time
AKS(98764321261)



KeyboardInterrupt: 

In [354]:
%%time
AKS(23456789101112)



CPU times: user 33 s, sys: 1.74 s, total: 34.7 s
Wall time: 34.7 s


False

**TODO**
* Change numpy to smth else (because it doesn't work well with big numbers)