In [None]:
from itertools import accumulate, count, cycle, islice, takewhile, tee
import math

from heapq import heappush, heappop, heapreplace
def heapmin(heap):
    """Peek at the smallest item of a heapq-ordered list without popping it.

    Raises IndexError when the heap is empty.
    """
    smallest = heap[0]
    return smallest

In [None]:
# imports not used in article
from itertools import zip_longest
from tqdm.notebook import tqdm
from IPython.display import Image

# The Easy, Slow Way

In [None]:
def is_prime(n):
    """Check if n is prime by trial division by every integer up to sqrt(n).

    The bound uses the exact integer square root (math.isqrt). The float
    version, floor(sqrt(n)), raises OverflowError for very large ints. For
    big n it can also round below the true root and skip the last
    candidate divisor.

    Raises AssertionError for n <= 0.
    """
    assert n > 0, 'only integers > 0'
    if n == 1:
        return False
    for i in range(2, math.isqrt(n) + 1):
        if n % i == 0:
            return False
    return True

In [None]:
def primes_brute1(n):
    """Lazily produce every prime p with 2 <= p < n (n itself is excluded).

    Each candidate is checked independently with is_prime.
    """
    candidates = range(2, n)
    return (c for c in candidates if is_prime(c))

primes100 = list(primes_brute1(100))
print(primes100)

In [None]:
def primes_brute2():
    """Lazily produce the unbounded sequence of primes 2, 3, 5, 7, ...

    Every positive integer from 1 upwards is tested with is_prime. The
    generator never ends, so callers must bound it (e.g. with takewhile or
    islice).
    """
    naturals = count(1)
    return (k for k in naturals if is_prime(k))

In [None]:
assert list(takewhile(lambda x: x<100, primes_brute2())) == primes100

In [None]:
def last(s):
    """Get the last element of an iterable by exhausting it.

    Works on any iterable, including one-shot iterators such as
    takewhile(...) over an infinite generator.

    Raises ValueError if s yields nothing. Previously an empty input
    surfaced as a confusing UnboundLocalError.
    """
    missing = object()  # sentinel: distinguishes "empty" from a None element
    item = missing
    for item in s:
        pass
    if item is missing:
        raise ValueError('last() arg is an empty iterable')
    return item

In [None]:
%%time
last(takewhile(lambda x: x<int(1e7), primes_brute2()))

# Primes Sieve

In [None]:
def primes_sieve1(n):
    """Primes using the sieve of Eratosthenes.

    Generate all primes less than n, as a sorted list.

    n may be a float such as 1e7; it is truncated with int(). Any n < 3,
    including 0 and negative values, gives an empty list.
    """
    n = int(n)
    if n < 3:
        # no primes below 2; also keeps math.isqrt away from negative input
        return []
    prime_sieve = [True] * n
    prime_sieve[0] = prime_sieve[1] = False
    # Only primes i with i*i < n cross anything off. A composite i's
    # multiples were already removed by its smallest prime factor.
    for i in range(2, math.isqrt(n - 1) + 1):
        if prime_sieve[i]:
            for j in range(i * i, n, i):
                prime_sieve[j] = False
    return [idx for idx, flag in enumerate(prime_sieve) if flag]
print(primes_sieve1(100))

In [None]:
assert primes_sieve1(100) == primes100

In [None]:
%%time
primes_sieve1(1e7)
None

# Incremental Sieve

In [None]:
def primes_sieve2():
    """Incremental (unbounded) sieve of Eratosthenes.

    Yields 2, 3, 5, 7, ... forever. pqueue is a heap of
    (next_multiple, prime) pairs. For each prime found so far it holds the
    next multiple of that prime still to be crossed off, starting at
    prime*prime. A candidate i is composite exactly when it equals the
    smallest pending multiple.
    """
    yield 2
    pqueue = [(4, 2)]
    for i in count(3):
        # advance every pending multiple that has fallen behind the candidate
        while i > heapmin(pqueue)[0]:
            multiple, p = heapmin(pqueue)
            # heapreplace = heappop + heappush in a single sift
            heapreplace(pqueue, (multiple + p, p))
        if i != heapmin(pqueue)[0]:
            yield i
            # smaller multiples of i are already covered by smaller primes
            heappush(pqueue, (i * i, i))

In [None]:
assert list(takewhile(lambda x: x<100, primes_sieve2()))==primes100

In [None]:
%%time
last(takewhile(lambda x: x<int(1e7), primes_sieve2()))

# Incremental Sieve Plot

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.collections as coll
from pathlib import Path

# https://stackoverflow.com/questions/44417945/creating-a-grid-of-squares-patches-in-matplotlib
class NumberGrid:
    """Matplotlib drawing of the numbers 1..120 as a grid of labelled squares.

    Numbers run left to right, top to bottom, 10 per row: 1 is at the top
    left and 120 at the bottom right, and the square for 1 is hidden.
    - set_color recolours a square.
    - set_highlight outlines one square.
    - annotate shows text beside the grid.
    - save() writes the figure as the next frame image000.png,
      image001.png, ... in the working directory.
    - animate() stitches the frames into a GIF with ffmpeg.
    """

    def __init__(self):
        self.squares = []  # after construction, squares[k] is number k + 1
        self.saves = 0     # frames written so far, i.e. the next frame index

        self.fig = plt.figure(figsize=(8, 8))
        self.ax = plt.Axes(self.fig, [0., 0., 1., 1.])

        self.text = None  # side-panel annotation, created on first annotate()

        w, h = 1, 1
        nrows, ncols = 12, 10
        inbetween = 0.1

        self.ax.set_axis_off()
        self.fig.add_axes(self.ax)

        xx = np.arange(0, ncols+1, (w+inbetween))
        yy = np.arange(0, nrows+1, (h+inbetween))

        # Squares are created bottom row first and right to left within a
        # row, so the label counts down from nrows*ncols. The list is
        # reversed afterwards so index k holds number k + 1.
        idx = nrows * ncols
        for yi in yy:
            for xi in xx[::-1]:
                sq = patches.Rectangle((xi, yi), w, h, fill=True, color='lightgrey')
                self.ax.add_patch(sq)
                self.squares.append(sq)
                x, y = xi + w / 2, yi + h / 2
                if idx != 1:
                    self.ax.annotate(idx, (x, y), fontsize=16, weight='bold', va='center', ha='center')
                idx -= 1
        self.squares.reverse()

        self.ax.autoscale_view()

        # the square for 1 is not drawn (it takes no part in the sieve)
        self.squares[0].set_visible(False)

        self.highlight = patches.Rectangle((0, 0), 1, 1, fill=False, color='black', linewidth=2.5)
        self.highlight.set_visible(False)
        self.ax.add_patch(self.highlight)

    def set_color(self, idx, color):
        """Fill the square for number idx with color.

        Numbers outside 1..len(self.squares) are silently ignored, so callers
        can pass multiples that run off the grid. Without the lower bound,
        idx <= 0 would index from the end of the list and recolour the wrong
        square; idx=0 would hit number 120.
        """
        if 1 <= idx <= len(self.squares):
            self.squares[idx - 1].set_facecolor(color)

    def annotate(self, s):
        """Show text s to the right of the grid, replacing any previous text."""
        if self.text is None:
            self.text = self.ax.annotate(s, (11, 13), fontsize=16, va='top', family='monospace')
        else:
            self.text.set_text(s)

    def save(self):
        """Write this grid's figure as the next numbered frame imageNNN.png."""
        # Save our own figure, not pyplot's "current" one, which may belong to
        # another NumberGrid. `cmap` is not a savefig option (newer Matplotlib
        # rejects unknown keywords), so it is not passed.
        self.fig.savefig(f'image{self.saves:03d}.png', dpi=90, bbox_inches='tight')
        self.saves += 1

    def rm(self):
        """Delete all image*.png frames in the working directory and restart numbering."""
        for frame in Path('.').glob('image*.png'):
            os.remove(frame)
        # the next save() writes image000.png again, matching the
        # image%03d.png pattern that animate() feeds to ffmpeg
        self.saves = 0

    def set_highlight(self, i):
        """Move the black outline onto the square for number i and show it."""
        col, row = (i - 1) % 10, 11 - ((i - 1) // 10)
        # 1.1 is the square pitch (side 1 plus 0.1 gap) used in __init__
        self.highlight.set_visible(True)
        self.highlight.set_xy((col * 1.1, row * 1.1))

    def animate(self, fname):
        """Combine the saved imageNNN.png frames into an animated GIF at fname.

        Runs ffmpeg twice. The first call builds a colour palette
        (palette.png). The second renders the frames at an input rate of
        2 frames per second using that palette. Raises AssertionError if
        either command exits non-zero.
        """
        # NOTE: fname is interpolated into a shell command; pass only
        # trusted file names without spaces.
        cmd1 = 'ffmpeg -i image%03d.png -vf palettegen palette.png -y'
        cmd2 = f'ffmpeg -r 2/1 -i image%03d.png  -i palette.png -lavfi paletteuse {fname} -y'
        assert os.system(cmd1) == 0
        assert os.system(cmd2) == 0

In [None]:
grid = NumberGrid()
grid.set_highlight(15)
grid.annotate('testing')

In [None]:
def grouper(iterable, n, fillvalue=None):
    """Split iterable into tuples of length n, padding the final tuple.

    grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx

    (itertools recipe) The same iterator is passed to zip_longest n times,
    so each output tuple takes n consecutive items.
    """
    shared = iter(iterable)
    return zip_longest(*([shared] * n), fillvalue=fillvalue)

def fmt_primes(primes, columns=4, rows=8):
    """Render primes as the text shown in the NumberGrid side panel.

    The result is 'primes:' on its own line, followed by a table with
    `columns` entries per line. Each entry is str()-ed, padded to width 4
    and separated by single spaces. A short last line is padded with blank
    entries, and an empty list renders as one blank row.

    NOTE(review): `rows` is never read; the table is not truncated.
    """
    items = [' '] if primes == [] else primes
    lines = []
    for chunk in grouper(items, columns, ''):
        cells = [str(value).ljust(4) for value in chunk]
        lines.append(' '.join(cells))
    return 'primes:\n' + '\n'.join(lines)

print(fmt_primes([1, 2, 3, 4, 5]))

In [None]:
colors = {
    2: 'lightcoral',
    3: 'lightgreen',
    5: 'lightblue',
    7: 'lightyellow'
}

colors1 = {
    2: 'red',
    3: 'green',
    5: 'blue',
    7: 'yellow'
}

idx_color = 'white'

In [None]:
grid = NumberGrid()
# delete frames left over from a previous run
grid.rm()

# Frame-by-frame replay of the primes_sieve2 algorithm on 1..120.
# queue is the same heap of (next pending multiple, prime) pairs.
queue = [(4, 2)]
primes = [2]

# Frame: 2 highlighted, nothing coloured yet.
grid.set_highlight(2)
grid.annotate(fmt_primes(primes))
grid.save()

# Frame: 2 coloured as a prime, its first pending multiple (4) shaded.
grid.set_color(2, colors1.get(2))
grid.set_color(4, colors.get(2))

grid.annotate(fmt_primes(primes))
grid.save()
# NOTE(review): repeats the set_color(2, ...) call above; no visible effect.
grid.set_color(2, colors1.get(2))

for i in tqdm(range(3, 121)):
    # one frame showing the candidate before it is processed
    grid.set_highlight(i)
    grid.annotate(fmt_primes(primes))
    grid.save()
    # Advance each pending multiple that is behind i, one frame per step.
    while i > heapmin(queue)[0]:
        np_, p = heappop(queue)
        heappush(queue, (np_ + p, p))
        # Recolour the vacated cell np_. If the smallest pending multiple is
        # still behind i, use that multiple's prime colour; otherwise revert
        # to grey.
        # NOTE(review): colors.get(...) has no default here, so a prime with
        # no entry in `colors` passes None to set_color — confirm intended.
        if i > heapmin(queue)[0]:
            grid.set_color(np_, colors.get(heapmin(queue)[1]))
        else:
            grid.set_color(np_, 'lightgrey')
        # shade the prime's new pending multiple
        grid.set_color(np_ + p, colors.get(p, 'purple'))
        grid.annotate(fmt_primes(primes))
        grid.save()
    if i != heapmin(queue)[0]:
        # i is not a pending multiple of any known prime -> new prime
        primes.append(i)
        heappush(queue, (i*i, i))
        grid.annotate(fmt_primes(primes))
        if i*i <= 121:
            grid.set_color(i*i, colors.get(i, 'purple'))
        grid.set_color(i, colors1.get(i, 'purple'))
        grid.annotate(fmt_primes(primes))
        grid.save()
    # NOTE(review): this branch is reached only when i == heapmin(queue)[0].
    # So i is always in the queue and the final `else` is unreachable. The
    # comprehension's `i` has its own scope (it does not clobber the loop
    # variable), but it is easy to misread.
    elif i in [i[0] for i in queue]:
        grid.set_color(i, colors.get(heapmin(queue)[1], 'black'))
    else:
        grid.set_color(i, 'lightgrey')

In [None]:
grid.animate('animated.gif')
Image(filename="animated.gif")

# Wheels

In [None]:
def wheel23():
    """Wheel generating the positive integers coprime to 2 and 3.

    Yields 1, 5, 7, 11, 13, 17, 19, 23, 25, ... Starting at 1, the gaps
    alternate 4, 2 (1 -> 5 -> 7 -> 11 ...). With the gaps in the other
    order (2, 4) the sequence would be 1, 3, 7, 9, 13, ..., which contains
    multiples of 3.
    """
    yield from accumulate(cycle([4, 2]), initial=1)

list(islice(wheel23(), 10))

In [None]:
# https://docs.python.org/3/library/itertools.html
def pairwise(iterable):
    """Return overlapping consecutive pairs: s -> (s0,s1), (s1,s2), (s2, s3), ...

    Recipe from the itertools documentation (itertools.pairwise in 3.10+).
    """
    current, ahead = tee(iterable, 2)
    # shift the second copy one element forward; no-op on an empty input
    next(ahead, None)
    return zip(current, ahead)

In [None]:
def coprime(a, b):
    """Return True when a and b have no common factor other than 1."""
    shared_factor = math.gcd(a, b)
    return shared_factor == 1

assert coprime(7, 22)
assert not coprime(10, 5)
assert not coprime(3, 6)

In [None]:
def wheel(n):
    """Generate a wheel of integers coprime to the first n primes.

    Returns (initial_primes, spokes):
    - initial_primes is a list of the first n primes.
    - spokes is an infinite iterator, in increasing order, over the
      integers > 1 that share no factor with any of them. For n=2 that is
      5, 7, 11, 13, 17, 19, 23, 25, ...

    n must be >= 0. n=0 gives the trivial wheel 2, 3, 4, ... and n=1 gives
    3, 5, 7, 9, ... The wrap-around gap is computed as
    first + circumference - last. The old (first - last) % circumference
    evaluates to 0 for a one-residue wheel (n=1), which is why n had to be
    > 1 before.
    """
    assert n >= 0
    initial_primes = list(islice(primes_brute2(), n))
    circumference = math.prod(initial_primes)
    # residues in 1..circumference coprime to every initial prime
    coprimes = [i for i in range(1, circumference + 1) if coprime(i, circumference)]
    gaps = [
        (b - a)
        for a, b
        in pairwise(coprimes)
    ]
    # gap from the last residue of one turn to the first residue of the next
    gaps.append(coprimes[0] + circumference - coprimes[-1])
    spokes = accumulate(cycle(gaps), initial=1)
    # skip 1: it is coprime to everything but is not prime
    next(spokes)
    return (
        initial_primes,
        spokes
    )

In [None]:
first_primes, wheel23 = wheel(2)
print(first_primes)
seq = [5, 7, 11, 13, 17, 19, 23, 25, 29, 31]
assert list(islice(wheel23, 10)) == seq

In [None]:
first_primes, wheel235 = wheel(3)
print(first_primes)
first_rotation = [7, 11, 13, 17, 19, 23, 29, 31, 37, 41]
print(first_rotation)
assert list(islice(wheel235, 10)) == first_rotation

In [None]:
import numpy as np
import pandas as pd

In [None]:
def wheel_stats(n):
    """Summarise the wheel built from the first n primes.

    Returns a 3-tuple:
    - the number of primes used;
    - the fraction of 1..prod(primes) coprime to that product, i.e. the
      share of integers the wheel keeps;
    - the length of that range.
    """
    first = list(islice(primes_brute2(), n))
    period = math.prod(first)
    keep = np.array([coprime(k, period) for k in range(1, period + 1)]) * 1
    return len(first), keep.mean(), keep.size

wheel_stats(1)

In [None]:
%%time
res = []
for i in range(1, 10):
    res.append(wheel_stats(i))

In [None]:
# Each wheel_stats row is (primes used, fraction of integers kept, range
# length). A wheel's speed-up is the reciprocal of the fraction it keeps.
# A new name is used so re-running this cell does not clobber the `res`
# list it reads.
wheel_table = pd.DataFrame(res, columns=['primes', 'kept fraction', 'array size'])
wheel_table['array size'] = wheel_table['array size'].apply(lambda x: '{:,}'.format(x))
# previously read a non-existent 'saving' column (KeyError)
wheel_table['speed up'] = (1 / wheel_table['kept fraction']).apply(lambda x: f'{x:.2f}x')
print(wheel_table[['primes', 'speed up', 'array size']].to_html(index=False))

# Incremental Sieve + Wheel

In [None]:
def primes_sieve3(wheel_size):
    """Incremental sieve of Eratosthenes that only visits wheel spokes.

    First yields the wheel's own primes (the first `wheel_size` primes).
    Then it sieves the spokes of wheel(wheel_size), so candidates sharing a
    factor with the wheel primes are never examined. pqueue is a heap of
    (next_multiple, prime) pairs, as in the plain incremental sieve.
    """
    first_primes, wheel_ = wheel(wheel_size)
    yield from first_primes

    # The first spoke after 1 has no factor among the wheel primes and is
    # smaller than the square of any larger prime, so it is prime.
    p = next(wheel_)
    yield p
    pqueue = [(p * p, p)]

    for i in wheel_:
        # advance every pending multiple that has fallen behind the candidate
        while i > heapmin(pqueue)[0]:
            multiple, q = heapmin(pqueue)
            # heapreplace = heappop + heappush in a single sift
            heapreplace(pqueue, (multiple + q, q))
        if i != heapmin(pqueue)[0]:
            yield i
            heappush(pqueue, (i * i, i))

In [None]:
assert list(takewhile(lambda x: x<100, primes_sieve3(3))) == primes100

In [None]:
%%time
last(takewhile(lambda x: x<int(1e7), primes_sieve3(7)))

In [None]:
grid = NumberGrid()
grid.rm()

# Animate primes_sieve3(3). The wheel primes 2, 3, 5 are known up front and
# only the wheel's spokes (7, 11, 13, 17, ...) are sieved.
first_primes, wheel_ = wheel(3)
# The wheel is infinite; keep only the spokes that fit on the 120-square grid.
spokes = list(takewhile(lambda x: x <= 120, wheel_))

primes = list(first_primes)
grid.annotate(fmt_primes(primes))
grid.save()
# Grey out every number the wheel never visits (1 is hidden on the grid).
for i in range(2, 121):
    if i not in spokes and i not in first_primes:
        grid.set_color(i, 'darkgrey')
grid.save()

# Fresh heap of (next pending multiple, prime); the `queue` left over from
# the previous animation cell must not be reused.
queue = []
for i in spokes:
    grid.set_highlight(i)
    grid.save()
    while queue and i > heapmin(queue)[0]:
        np_, p = heappop(queue)
        heappush(queue, (np_ + p, p))
        grid.set_color(np_ + p, colors.get(p, 'black'))
    if not queue or i != heapmin(queue)[0]:
        primes.append(i)
        heappush(queue, (i*i, i))
        grid.annotate(fmt_primes(primes))
        grid.set_color(i, colors1.get(i, 'purple'))
        grid.set_color(i*i, colors.get(i, 'black'))
# the loop saves only at the start of each step; capture the final state
grid.save()

In [None]:
# Render the frames and display the GIF that was just written.
# The file name was previously mistyped as '...gif.gif'.
wheel_gif = 'prime-sieve-wheel.gif'
grid.animate(wheel_gif)
Image(filename=wheel_gif)