# Merge Sort

[Click here to run this chapter on Colab](https://colab.research.google.com/github/AllenDowney/DSIRP/blob/main/notebooks/mergesort.ipynb)

## Implementing Merge Sort

[Merge sort](https://en.wikipedia.org/wiki/Merge_sort) is a divide and conquer strategy:

1. Divide the sequence into two halves,

2. Sort the halves, and

3. Merge the sorted sub-sequences into a single sequence.

Since step 2 involves sorting, this algorithm is recursive, so we need a base case.
There are two options:

1. If the size falls below some threshold, we can use another sort algorithm.

2. If the size of a sub-sequence is 1, it is already sorted.

[Comparison with other sort algorithms](https://en.wikipedia.org/wiki/Merge_sort#Comparison_with_other_sort_algorithms)

To implement merge sort, I think it's helpful to start with a non-recursive version that uses Python's `sort` method to sort the sub-sequences.

In [1]:
def merge_sort_norec(xs):
    n = len(xs)
    mid = n//2
    left = xs[:mid]
    right = xs[mid:]

    left.sort()
    right.sort()

    return merge(left, right)

**Exercise:** Write a function called `merge` that takes two sorted sequences, `left` and `right`, and returns a sequence that contains all elements from `left` and `right`, in ascending order (or non-decreasing order, to be more precise).

Note: this function is not conceptually difficult, but it is notoriously tricky to get all of the edge cases right without making the function unreadable.
Take it as a challenge to write a version that is correct, concise, and readable.
I found that I could write it more concisely as a generator function.

In [3]:
def merge(left, right):
    i, j = 0, 0
    merged = []

    # Walk through both lists
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1

    # Append the leftovers (at most one of these slices is non-empty)
    merged.extend(left[i:])
    merged.extend(right[j:])

    return merged

def merge_generator(left, right):
    i, j = 0, 0
    while i < len(left) and j < len(right):
        if left[i] <= right[j]:
            yield left[i]
            i += 1
        else:
            yield right[j]
            j += 1

    # yield remaining items (one of these will be empty)
    while i < len(left):
        yield left[i]
        i += 1

    while j < len(right):
        yield right[j]
        j += 1


# Example trace for merge_generator with
#     left = [1, 3, 5]
#     right = [2, 4]
#
# The first while loop yields 1, 2, 3, 4 and ends with i=2, j=2.
# It exits because j equals len(right), so right is exhausted.
# The second while loop yields the remaining element of left: 5.
# The third while loop does nothing because right is already exhausted.


You can use the following example to test your code.

In [4]:
import random

population = range(100)
xs = random.sample(population, k=6)
ys = random.sample(population, k=6)
ys

[21, 24, 10, 18, 95, 16]

In [5]:
xs.sort()
print(xs)
ys.sort()
ys

[33, 39, 43, 50, 77, 93]


[10, 16, 18, 21, 24, 95]

In [6]:
res = list(merge_generator(xs, ys))
res

[10, 16, 18, 21, 24, 33, 39, 43, 50, 77, 93, 95]

In [7]:
sorted(res) == res

True
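
The edge cases that make `merge` tricky are mostly empty inputs, inputs of unequal length, and duplicate values. Here is a minimal sketch of a check for them. `check_merge` is a hypothetical helper, not part of the original notebook, that accepts any two-argument merge function. `heapq.merge` from the standard library stands in as the function under test so the cell runs on its own.

```python
from heapq import merge as heapq_merge

def check_merge(merge_func):
    """Run a two-way merge function on a few edge cases.

    merge_func takes two sorted lists and returns an iterable.
    """
    cases = [
        ([], []),               # both empty
        ([], [1, 2]),           # left empty
        ([1, 2], []),           # right empty
        ([1, 3, 5], [2, 4]),    # unequal lengths
        ([1, 2, 2], [2, 3]),    # duplicates across inputs
    ]
    for left, right in cases:
        result = list(merge_func(left, right))
        assert result == sorted(left + right), (left, right, result)
    return True

# Stand-in merge function so this cell is self-contained
check_merge(lambda left, right: heapq_merge(left, right))
```

To check the functions defined above, call `check_merge(merge)` or `check_merge(merge_generator)`.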

**Exercise:** Starting with `merge_sort_norec`, write a function called `merge_sort_rec` that's fully recursive; that is, instead of using Python's `sort` method to sort the halves, it should use `merge_sort_rec`. Of course, you will need a base case to avoid an infinite recursion.



In [8]:
def merge_sort_rec(xs):
    # Base case: a sequence with 0 or 1 elements is already sorted
    if len(xs) <= 1:
        return xs

    # Recursive case: sort each half, then merge the results
    mid = len(xs) // 2
    left = merge_sort_rec(xs[:mid])
    right = merge_sort_rec(xs[mid:])
    return merge(left, right)


Test your function by running the code in the next cell, then use `test_merge_sort_rec`, below, to check the performance of your function.

In [9]:
xs = random.sample(population, k=12)
xs

[41, 57, 45, 40, 25, 8, 94, 85, 93, 27, 70, 6]

In [10]:
res = list(merge_sort_rec(xs))
res

[6, 8, 25, 27, 40, 41, 45, 57, 70, 85, 93, 94]

In [11]:
sorted(res) == res

True
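
The other base-case option mentioned at the start of the chapter, switching to a different sort algorithm below some size threshold, might look like the following sketch. The name `merge_sort_hybrid` and the default `threshold=8` are assumptions chosen for illustration, the built-in `sorted` plays the role of "another sort algorithm", and `heapq.merge` stands in for the two-way merge so the cell is self-contained.

```python
from heapq import merge as heapq_merge

def merge_sort_hybrid(xs, threshold=8):
    """Merge sort that hands small inputs to the built-in sorted().

    threshold is assumed to be at least 1; otherwise the recursion
    would never reach the base case.
    """
    # Base case: below the threshold, use another sort algorithm
    if len(xs) <= threshold:
        return sorted(xs)

    mid = len(xs) // 2
    left = merge_sort_hybrid(xs[:mid], threshold)
    right = merge_sort_hybrid(xs[mid:], threshold)

    # heapq.merge stands in for the two-way merge written earlier
    return list(heapq_merge(left, right))

data = [41, 57, 45, 40, 25, 8, 94, 85, 93, 27, 70, 6]
merge_sort_hybrid(data, threshold=4) == sorted(data)
```

A larger threshold means fewer recursive calls; which value works best depends on the data and the fallback sort, so it is worth measuring rather than guessing.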

## Heap Merge

Suppose we want to merge more than two sub-sequences.
A convenient way to do that is to use a heap.
For example, here are three sorted sub-sequences.

In [27]:
xs = random.sample(population, k=5)
ys = random.sample(population, k=5)
zs = random.sample(population, k=5)

min(xs), min(ys), min(zs)

(3, 43, 2)

In [28]:
xs.sort()
ys.sort()
zs.sort()

For each sequence, I'll make an iterator and push onto the heap a tuple that contains:

* The first element from the iterator,

* An index that's different for each iterator, and

* The iterator itself.

When the heap compares two of these tuples, it compares the elements first.
If there's a tie, it compares the indices.
Since the indices are unique, there can't be a tie, so we never have to compare iterators (which would be an error).
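
To see why the index matters, here is a small sketch of how Python compares tuples. The variable names are chosen for illustration only.

```python
# Tuples compare element by element, left to right
a = (5, 0, iter([1, 2]))
b = (5, 1, iter([3, 4]))

# The values tie at 5, so the unique indices decide;
# the iterators are never compared
index_breaks_tie = a < b

# Without a unique index, a tie on the value falls through to the iterators,
# which don't support ordering comparisons
try:
    (5, iter([1, 2])) < (5, iter([3, 4]))
    iterator_compare_failed = False
except TypeError:
    iterator_compare_failed = True

print(index_breaks_tie, iterator_compare_failed)
```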

In [29]:
sequences = [xs, ys, zs]

In [20]:
from heapq import heappush, heappop

heap = []
for i, seq in enumerate(sequences):
    iterator = iter(seq)
    first = next(iterator)
    heappush(heap, (first, i, iterator))

When we pop a value from the heap, we get the tuple with the smallest value.

In [21]:
value, i, iterator = heappop(heap)
value

0

If we know that the iterator has more values, we can use `next` to get the next one and then push a tuple back into the heap.

In [None]:
heappush(heap, (next(iterator), i, iterator))

If we repeat this process, we'll get all elements from all sub-sequences in ascending order.

However, we have to deal with the case where the iterator is empty.
In Python, the only way to check is to call `next` and take your chances!
If there are no more elements in the iterator, `next` raises a `StopIteration` exception, which you can handle with a `try` statement, like this:

In [22]:
iterator = iter(xs)

while True:
    try:
        print(next(iterator))
    except StopIteration:
        break

25
39
71
92
93
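
A variation on the same idea: the built-in `next` accepts an optional second argument, which it returns instead of raising `StopIteration` when the iterator is exhausted. The sketch below uses that form. `drain` and the sentinel `_DONE` are hypothetical names introduced here for illustration.

```python
_DONE = object()  # hypothetical sentinel; unique, so it can't collide with real data

def drain(iterator):
    """Collect the remaining items from an iterator using next() with a default."""
    items = []
    while True:
        item = next(iterator, _DONE)
        if item is _DONE:
            break
        items.append(item)
    return items

drain(iter([25, 39, 71]))
```

Either style works; the `try` statement is more common, while the sentinel form avoids exception handling at the cost of an extra comparison.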


**Exercise:** Write a generator function called `heapmerge` that takes a list of sequences and yields the elements from the sequences in increasing order.

In [26]:
from heapq import heappush, heappop

def heapmerge(sequences):
    heap = []
    # Step 1: Initialize the heap with the first element of each sequence
    for i, seq in enumerate(sequences):
        it = iter(seq)              # get an iterator for this sequence
        try:
            first = next(it)        # take the first element
            heappush(heap, (first, i, it))  # push tuple (value, seq_id, iterator)
        except StopIteration:
            pass   # ignore empty sequences

    # Step 2: Extract elements one by one in sorted order
    while heap:
        value, i, it = heappop(heap)   # get the smallest element
        yield value                    # output it

        # Step 3: Advance the same iterator and push the next element (if any)
        try:
            nxt = next(it)
            heappush(heap, (nxt, i, it))
        except StopIteration:
            pass   # this sequence is finished



You can use the following examples to test your function.

In [30]:
seq = list(heapmerge([xs, ys, zs]))
seq

[2, 3, 4, 41, 43, 45, 48, 50, 56, 63, 76, 81, 88, 88, 96]

In [31]:
sorted(seq) == seq

True

The `heapq` module provides a function called `merge` that implements this algorithm.
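
Here is a short usage sketch with made-up input lists. Note that `heapq.merge` takes the sorted inputs as separate arguments rather than as one list of sequences, and it returns a lazy iterator. Importing the module as `heapq`, rather than importing `merge` by name, avoids shadowing the `merge` function defined earlier in this chapter.

```python
import heapq

# Three already-sorted inputs (illustrative values)
a = [3, 41, 56]
b = [43, 45, 88]
c = [2, 4, 88]

# Pass the inputs as separate arguments
merged = list(heapq.merge(a, b, c))

# With a list of sequences, unpack it with *
sequences = [a, b, c]
merged_again = list(heapq.merge(*sequences))

print(merged)
```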

## Comparing sort algorithms

NumPy provides implementations of three sorting algorithms: quicksort, mergesort, and heapsort.

In theory, they are all in `O(n log n)`, at least in the average case.
Let's see what that looks like when we plot runtime versus problem size.


In [None]:
from os.path import basename, exists

def download(url):
    filename = basename(url)
    if not exists(filename):
        from urllib.request import urlretrieve
        local, _ = urlretrieve(url, filename)
        print('Downloaded ' + local)

download('https://github.com/AllenDowney/DSIRP/raw/main/timing.py')

In [None]:
from timing import run_timing_test, plot_timing_test

In [None]:
import numpy as np

def test_quicksort(n):
    xs = np.random.normal(size=n)
    xs.sort(kind='quicksort')

ns, ts = run_timing_test(test_quicksort)
plot_timing_test(ns, ts, 'test_quicksort', exp=1)

Quicksort is hard to distinguish from linear, up to about 10 million elements.

In [None]:
def test_mergesort(n):
    xs = np.random.normal(size=n)
    xs.sort(kind='mergesort')

ns, ts = run_timing_test(test_mergesort)
plot_timing_test(ns, ts, 'test_mergesort', exp=1)

Merge sort is similar, maybe with some upward curvature.

In [None]:
def test_heapsort(n):
    xs = np.random.normal(size=n)
    xs.sort(kind='heapsort')

ns, ts = run_timing_test(test_heapsort)
plot_timing_test(ns, ts, 'test_heapsort', exp=1)

The three methods are effectively linear over this range of problem sizes.

Their run times are about the same, with quicksort being the fastest, despite having the worst asymptotic performance in the worst case.
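
One way to see why `O(n log n)` is hard to tell apart from linear: the `log n` factor grows very slowly. The sketch below computes the slope that a pure `n log n` curve would have between two problem sizes on a log-log scale. It assumes the timing plots use log-log axes, and `loglog_slope` is a hypothetical helper, not part of `timing.py`.

```python
import math

def loglog_slope(n1, n2):
    """Slope of n * log2(n) between n1 and n2 on a log-log plot."""
    t1 = n1 * math.log2(n1)
    t2 = n2 * math.log2(n2)
    return (math.log(t2) - math.log(t1)) / (math.log(n2) - math.log(n1))

# From about a thousand to about a million elements
slope = loglog_slope(2**10, 2**20)
print(slope)
```

Over this range the problem size grows by a factor of 1024 while `log2(n)` only doubles, from 10 to 20, so the slope comes out near 1.1 rather than 1. A difference that small is easily hidden by measurement noise.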

Now let's see how our implementation of merge sort does.

In [None]:
def test_merge_sort_rec(n):
    xs = np.random.normal(size=n)
    res = merge_sort_rec(xs)

ns, ts = run_timing_test(test_merge_sort_rec)
plot_timing_test(ns, ts, 'test_merge_sort_rec', exp=1)

If things go according to plan, our implementation of merge sort should be close to linear, or a little steeper.

*Data Structures and Information Retrieval in Python*

Copyright 2021 Allen Downey

License: [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International](https://creativecommons.org/licenses/by-nc-sa/4.0/)