In [66]:
import numpy as np
from time import time

In [213]:
def test_sort(func):
    t = time()

    a = np.array([4, 3, 2, 1])
    a = np.array(func(a))
    assert (np.array(a - np.array([1, 2, 3, 4])) ** 2).sum() == 0

    a = np.random.permutation(10)
    b = a.copy()
    a = np.array(func(a))
    assert (np.array(a - sorted(b)) ** 2).sum() == 0

    for _ in range(100):
        a = np.random.permutation(1000)
        func(a)

    return time() - t

#### naive (selection) sort $O(n^2)$

In [214]:
def naivesort(a):
    """Selection sort, O(n^2): repeatedly remove the smallest remaining item.

    The input is copied into a list first, so it is never modified.
    Returns a new list in ascending order. On ties, the earliest
    occurrence is taken first.
    """
    remaining = list(a)
    result = []
    while remaining:
        # min() over indices returns the first index holding the minimum.
        idx_min = min(range(len(remaining)), key=remaining.__getitem__)
        result.append(remaining.pop(idx_min))
    return result


In [215]:
# Elapsed seconds reported by test_sort for the naive sort.
test_sort(func=naivesort)

1.7066919803619385

#### mergesort $O(n \log n)$

In [216]:
def mergesort(a, l=None, r=None):
    """Sort the half-open range a[l:r] in place with top-down merge sort.

    Parameters
    ----------
    a : mutable sequence
        Must support slicing, ``.copy()`` on slices, and item assignment
        (e.g. a list or a 1-D numpy array). Modified in place.
    l : int, optional
        First index of the range to sort (inclusive). Defaults to 0.
    r : int, optional
        End of the range (exclusive). Defaults to ``len(a)``.

    Returns
    -------
    The same object ``a``, now sorted over [l, r). It is returned for
    convenience for every input size, including empty and single-element
    ranges. The sort is stable.
    """
    l = 0 if l is None else l
    r = len(a) if r is None else r

    if r - l >= 2:  # ranges with 0 or 1 elements are already sorted
        mid = (l + r) // 2
        mergesort(a, l, mid)
        mergesort(a, mid, r)

        # Merge the sorted halves back into a[l:r]. Copy them first,
        # because the writes below overwrite a[l:r].
        left, right = a[l:mid].copy(), a[mid:r].copy()
        i = j = 0
        n, m = mid - l, r - mid
        for k in range(l, r):
            # `<=` takes from the left half on ties, which keeps the sort stable.
            if j == m or (i < n and left[i] <= right[j]):
                a[k] = left[i]
                i += 1
            else:
                a[k] = right[j]
                j += 1

    # Returned outside the `if` so that short inputs also get `a` back, not None.
    return a

In [217]:
# Run the benchmark five times to show the run-to-run variation.
for _run in range(5):
    elapsed = test_sort(mergesort)
    print(elapsed)

0.2880868911743164
0.2540709972381592
0.25176405906677246
0.2529878616333008
0.25121307373046875


#### quicksort (random pivot, expected $O(n \log n)$ on distinct keys)

In [218]:
def quicksort(a):
    """Return a new ascending list of a's elements (randomized quicksort).

    The input is not modified. Partitioning is three-way (less / equal /
    greater than the pivot). All keys equal to the pivot are therefore
    placed in one step rather than recursed on. This keeps inputs with many
    duplicates from degrading to O(n^2) time and O(n) recursion depth.
    """
    items = list(a)  # always work on, and return, a fresh list
    if len(items) <= 1:
        return items

    pivot = items[np.random.choice(len(items))]

    less, equal, greater = [], [], []
    for x in items:
        if x < pivot:
            less.append(x)
        elif pivot < x:
            greater.append(x)
        else:
            # Incomparable values (e.g. NaN) also land here, so no element
            # is dropped. The pivot itself is always here too, so both
            # recursive calls get strictly smaller inputs.
            equal.append(x)

    return quicksort(less) + equal + quicksort(greater)


In [220]:
# Run the benchmark five times to show the run-to-run variation.
for _run in range(5):
    elapsed = test_sort(quicksort)
    print(elapsed)

0.305372953414917
0.2873969078063965
0.290391206741333
0.3024420738220215
0.29720401763916016
