# Merge sort

Implement merge sort. As an extension, minimise the extra space used. It is possible to keep part of the array sorted and use the rest as working space, but that approach is quite complicated. Using one additional copy of the array is a good goal.


The idea is to divide and conquer: halve the array, sort each half recursively, and merge the two sorted halves back together. We need a merge function that takes two sorted arrays and combines them into one sorted array.


In [34]:
import random
import timeit

In [41]:
def merge(left: list[int], right: list[int]) -> list[int]:
    combined = []
    l, r = 0, 0

    while l < len(left) and r < len(right):
        if left[l] <= right[r]:  # <= takes from the left on ties, keeping the merge stable
            m = left[l]
            l += 1
        else:
            m = right[r]
            r += 1

        combined.append(m)

    # At most one of these slices is non-empty: append whatever is left over.
    combined.extend(left[l:])
    combined.extend(right[r:])

    return combined

In [42]:
assert merge([1, 3, 5], [2, 4, 6]) == [1, 2, 3, 4, 5, 6]
assert merge([1, 3, 5, 7, 9], [2, 4, 6]) == [1, 2, 3, 4, 5, 6, 7, 9]
assert merge([1, 3, 5], [0, 2, 4, 6, 8]) == [0, 1, 2, 3, 4, 5, 6, 8]

Merge sort is then a recursive wrapper around this.


In [43]:
def mergesort(arr: list[int]) -> list[int]:
    if len(arr) <= 1:
        return arr

    mid = len(arr) // 2

    left = mergesort(arr[:mid])
    right = mergesort(arr[mid:])

    return merge(left, right)

In [44]:
assert mergesort([5, 4, 3, 2, 1]) == [1, 2, 3, 4, 5], mergesort([5, 4, 3, 2, 1])
assert mergesort([8, 4, 2, 1]) == [1, 2, 4, 8], mergesort([8, 4, 2, 1])

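Now, to optimise memory usage. How do we avoid making heaps of copies of the array? It helps to draw a diagram of the recursion. The version below passes index bounds into one shared list instead of slicing it at every level.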
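
As a rough stand-in for that diagram, the recursion can be described purely by half-open index ranges into a single list. The helper `trace_intervals` below is hypothetical (it is not part of the original notebook) and only records which ranges each call would cover, using the same midpoint rule as the code that follows.

```python
def trace_intervals(l: int, r: int, depth: int = 0) -> list[str]:
    """Return one line per recursive call on [l, r), indented by depth.

    Hypothetical helper for illustration only; it sorts nothing.
    """
    lines = [f"{'  ' * depth}[{l}, {r})"]
    if r - l > 1:
        mid = l + (r - l) // 2
        lines += trace_intervals(l, mid, depth + 1)
        lines += trace_intervals(mid, r, depth + 1)
    return lines


print("\n".join(trace_intervals(0, 5)))
```

For a five-element list this prints `[0, 5)` at the top, then `[0, 2)` and `[2, 5)` one level down, and so on until every range holds a single element. No list is copied to describe any of these calls, which is the property the index-based version relies on.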


In [67]:
def _merge(arr: list[int], l: int, mid: int, r: int):
    """
    Merge the sorted intervals [l, mid) and [mid, r) of arr,
    writing the result back into arr[l:r].
    """
    ans = []

    i = l
    j = mid

    while i < mid and j < r:
        if arr[i] <= arr[j]:  # <= takes from the left on ties, keeping the merge stable
            ans.append(arr[i])
            i += 1
        else:
            ans.append(arr[j])
            j += 1

    if i < mid:
        ans += arr[i:mid]
    if j < r:
        ans += arr[j:r]

    arr[l:r] = ans


def _mergesort(arr: list[int], l: int, r: int):
    if r - l <= 1:
        return

    mid = l + (r - l) // 2

    _mergesort(arr, l, mid)
    _mergesort(arr, mid, r)

    _merge(arr, l, mid, r)


def mergesort2(arr: list[int]) -> list[int]:
    _mergesort(arr, 0, len(arr))
    return arr
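
The version above still builds a fresh `ans` list on every merge. A possible next step towards the "one additional copy" goal, sketched here under the assumption that a single scratch list of the same length is acceptable, is to allocate that buffer once and reuse it for every merge. The names `mergesort3`, `_mergesort_with_buffer` and `_merge_with_buffer` are made up for this sketch, and it has not been benchmarked against the cells above.

```python
def _merge_with_buffer(arr: list[int], buf: list[int], l: int, mid: int, r: int) -> None:
    """Merge sorted ranges [l, mid) and [mid, r) of arr, using buf as scratch space."""
    # Copy the range into the shared buffer element by element
    # (a slice assignment would allocate a temporary list).
    for k in range(l, r):
        buf[k] = arr[k]

    i, j = l, mid
    for k in range(l, r):
        # Take from the left run if the right run is exhausted,
        # or if the left element is not larger (<= keeps the merge stable).
        if j >= r or (i < mid and buf[i] <= buf[j]):
            arr[k] = buf[i]
            i += 1
        else:
            arr[k] = buf[j]
            j += 1


def _mergesort_with_buffer(arr: list[int], buf: list[int], l: int, r: int) -> None:
    if r - l <= 1:
        return

    mid = l + (r - l) // 2
    _mergesort_with_buffer(arr, buf, l, mid)
    _mergesort_with_buffer(arr, buf, mid, r)
    _merge_with_buffer(arr, buf, l, mid, r)


def mergesort3(arr: list[int]) -> list[int]:
    """Sort arr in place with one extra buffer of len(arr); returns arr."""
    buf = [0] * len(arr)  # the single additional copy, allocated once
    _mergesort_with_buffer(arr, buf, 0, len(arr))
    return arr
```

Extra memory here is one list of `len(arr)` elements plus the recursion stack. In CPython the explicit index loops may well be slower than the slice-based merges above, so this trades speed for a tighter bound on allocations.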

In [69]:
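def test(sort):
    arr = [random.randint(-100, 100) for _ in range(10000)]
    # Compute the expected result first: mergesort2 sorts arr in place and
    # returns the same list, so comparing against arr afterwards always passes.
    expected = sorted(arr)
    got = sort(arr)
    assert got == expected, f"{expected} vs {got}"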


naive = timeit.timeit(lambda: test(mergesort), number=100)
print(f"naive merge sort took {naive*1000:.2f}ms")

optimised = timeit.timeit(lambda: test(mergesort2), number=100)
print(f"optimised merge sort took {optimised*1000:.2f}ms")

naive merge sort took 2324.81ms
optimised merge sort took 1858.66ms
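
Note that these figures also include the time to generate the random data, run the built-in sort, and compare the results. A minimal sketch of a measurement that times only the call to the sort under test might look like the following; `time_sort` is a hypothetical helper, and the built-in `sorted` is used as the example so that the block runs on its own.

```python
import random
import timeit


def time_sort(sort, size: int = 10_000, repeats: int = 20) -> float:
    """Return the total seconds spent inside `sort` across `repeats` runs."""
    total = 0.0
    for _ in range(repeats):
        data = [random.randint(-100, 100) for _ in range(size)]
        expected = sorted(data)  # computed before `sort` can mutate data

        start = timeit.default_timer()
        got = sort(data)
        total += timeit.default_timer() - start

        assert got == expected  # checked outside the timed region
    return total


print(f"built-in sorted took {time_sort(sorted) * 1000:.2f}ms")
```

Passing `mergesort` or `mergesort2` in place of `sorted` would give numbers that are more directly comparable with each other, since the setup and verification cost is excluded from both.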
