# Merge Sort Lab 🧠🧰
In this notebook, we will:

1. Implement merge sort (recursive) and the merge step.
2. Count *comparisons* made during sorting.
3. Experimentally check the book's big claim: merge sort uses **O(n log n)** comparisons.

### What the book gives us
- A recursive picture of merge sort splitting and merging a list (see the diagram on page 3).
- **Algorithm 9**: recursive merge sort (page 4).
- **Algorithm 10**: merging two sorted lists (page 5).
- A key bound: merge sort comparisons are **O(n log n)** (page 6).

We'll turn those ideas into runnable code and evidence.


## Quick intuition
Merge sort does two repeated actions:

### 1) Split
Keep splitting the list into halves until each piece has size 1.
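
To make the splitting concrete, here is a minimal sketch that performs only the split phase and records the pieces at each depth. The helper name `split_levels` is an assumption introduced for illustration; it does not come from the book.

```python
def split_levels(arr):
    """Return a list of levels; each level is the list of pieces at that depth."""
    levels = [[list(arr)]]
    # Keep halving until every piece has at most one element.
    while any(len(piece) > 1 for piece in levels[-1]):
        next_level = []
        for piece in levels[-1]:
            if len(piece) > 1:
                mid = len(piece) // 2
                next_level.append(piece[:mid])
                next_level.append(piece[mid:])
            else:
                # Size-1 pieces are carried down unchanged.
                next_level.append(piece)
        levels.append(next_level)
    return levels

for depth, pieces in enumerate(split_levels([8, 2, 4, 6, 9])):
    print(depth, pieces)
```

For this 5-element list the pieces reach size 1 after three rounds of halving, which hints at why the number of levels grows like log2(n).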

### 2) Merge
Merge sorted halves back together.
The merge step is where comparisons happen.

A key fact from the text:
If you merge two sorted lists with lengths `m` and `n`,
the merge needs at most **m + n − 1** comparisons.
(That's the engine behind the O(n log n) result.)
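
A small sketch of that bound, using a stand-alone helper (`count_merge_comparisons` is a name made up for this example). With `m = n = 3`, interleaved inputs hit the maximum of `m + n − 1 = 5` comparisons, while inputs where one list is entirely smaller need only 3.

```python
def count_merge_comparisons(left, right):
    """Merge two sorted lists; return (merged, number_of_comparisons)."""
    merged = []
    comparisons = 0
    i = j = 0
    while i < len(left) and j < len(right):
        comparisons += 1
        if left[i] <= right[j]:
            merged.append(left[i])
            i += 1
        else:
            merged.append(right[j])
            j += 1
    # Leftovers are copied without any further comparisons.
    merged.extend(left[i:])
    merged.extend(right[j:])
    return merged, comparisons

# Interleaved values: every element except the last needs a comparison.
_, worst = count_merge_comparisons([1, 3, 5], [2, 4, 6])
# One list entirely smaller: the loop stops as soon as it is exhausted.
_, best = count_merge_comparisons([1, 2, 3], [4, 5, 6])
print(worst, best)  # 5 3
```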


In [None]:
from dataclasses import dataclass
import random
import math
import time

@dataclass
class SortStats:
    comparisons: int = 0
    writes: int = 0   # how many items we append into merged output




In [None]:
def merge_sorted_lists(left, right, stats: SortStats):
    """
    Merge two already-sorted lists into one sorted list.
    Counts comparisons and writes (append operations).

    Mirrors the book's "merge two lists" idea (Algorithm 10).
    """
    merged = []
    i = j = 0

    while i < len(left) and j < len(right):
        stats.comparisons += 1
        if left[i] <= right[j]:
            merged.append(left[i])
            stats.writes += 1
            i += 1
        else:
            merged.append(right[j])
            stats.writes += 1
            j += 1

    # Append leftovers (no comparisons needed here)
    if i < len(left):
        merged.extend(left[i:])
        stats.writes += (len(left) - i)
    if j < len(right):
        merged.extend(right[j:])
        stats.writes += (len(right) - j)

    return merged




In [None]:
def merge_sort(arr, stats: SortStats):
    """
    Recursive merge sort.
    Splits list, sorts halves recursively, then merges.
    """
    if len(arr) <= 1:
        return list(arr)  # copy, so the result never aliases the caller's list

    mid = len(arr) // 2
    left_sorted = merge_sort(arr[:mid], stats)
    right_sorted = merge_sort(arr[mid:], stats)

    return merge_sorted_lists(left_sorted, right_sorted, stats)


In [None]:
data = [8, 2, 4, 6, 9, 7, 10, 1, 5, 3]  # similar to the page-3 example list
stats = SortStats()
sorted_data = merge_sort(data, stats)

sorted_data, stats


## What should we expect?
If `n = 10`, merge sort splits into halves until size 1, then merges back.

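The book's storyline:
- each merge is bounded: merging lists of lengths m and n takes ≤ m + n − 1 comparisons
- the merges at one "level" cover all n elements between them, so a level costs fewer than n comparisons
- there are about log2(n) levels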
- so total comparisons scale like **n log2(n)**
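
The storyline can be turned into a concrete upper bound. As a sketch, assume every merge uses its maximum of m + n − 1 comparisons; the worst case then satisfies W(n) = W(⌊n/2⌋) + W(⌈n/2⌉) + n − 1 with W(1) = 0. The function name `worst_case_comparisons` is introduced here for illustration only.

```python
import math
from functools import lru_cache

@lru_cache(maxsize=None)
def worst_case_comparisons(n):
    """Upper bound on merge sort comparisons for a list of length n."""
    if n <= 1:
        return 0
    half = n // 2
    # Sort both halves, then pay at most (n - 1) comparisons for the merge.
    return (worst_case_comparisons(half)
            + worst_case_comparisons(n - half)
            + (n - 1))

for n in [2, 8, 10, 1024]:
    print(n, worst_case_comparisons(n), round(n * math.log2(n), 1))
```

For `n = 10` the bound is 25 comparisons, below `10 * log2(10) ≈ 33.2`. For powers of two the recurrence works out to `n log2(n) − n + 1`, so the bound always sits under the `n log2(n)` reference.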

Next: we'll measure comparisons for different `n` and compare to `n log2(n)`.


In [None]:
def run_one_trial(n, seed=None):
    if seed is not None:
        random.seed(seed)
    arr = [random.randint(0, 10**9) for _ in range(n)]
    stats = SortStats()
    out = merge_sort(arr, stats)
    assert out == sorted(arr), "Sort failed!"
    return stats.comparisons

def run_experiment(ns, trials=30):
    results = {}
    for n in ns:
        comps = [run_one_trial(n) for _ in range(trials)]
        results[n] = {
            "avg_comparisons": sum(comps) / len(comps),
            "min_comparisons": min(comps),
            "max_comparisons": max(comps),
            "n_log2_n": n * math.log2(n) if n > 1 else 0
        }
    return results

ns = [8, 16, 32, 64, 128, 256, 512, 1024]
results = run_experiment(ns, trials=40)
results


## Reading the results
We computed:
- average comparisons actually used
- the value `n log2(n)` as a reference scale

We do **not** expect comparisons to equal `n log2(n)` exactly.
We *do* expect the comparisons to grow proportionally to it.


In [None]:
def print_table(results):
    print(f"{'n':>6} | {'avg comps':>12} | {'n log2 n':>12} | {'avg/(n log2 n)':>14}")
    print("-" * 55)
    for n in sorted(results.keys()):
        avg_c = results[n]["avg_comparisons"]
        ref = results[n]["n_log2_n"]
        ratio = (avg_c / ref) if ref else float('nan')
        print(f"{n:>6} | {avg_c:>12.2f} | {ref:>12.2f} | {ratio:>14.4f}")

print_table(results)


### What does the ratio mean?
If merge sort's comparison count is proportional to n log n, then:

`avg_comparisons ≈ C * (n log2 n)`

So the ratio `avg_comparisons / (n log2 n)` should hover around a constant C
as n grows (it might wiggle a bit, but it shouldn't explode).

If the ratio grows without bound, the count is growing faster than n log n and we'd be in trouble.
If it stabilizes, that's empirical evidence for O(n log n).
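
For contrast, here is what an exploding ratio looks like. This sketch assumes a hypothetical algorithm that compares every pair of elements once, n(n − 1)/2 comparisons in total (the count a simple selection sort makes), and divides by the same `n log2 n` reference. The helper name `ratio_to_n_log_n` is made up for this example.

```python
import math

def ratio_to_n_log_n(count, n):
    """Ratio of a comparison count to the n*log2(n) reference scale."""
    return count / (n * math.log2(n))

for n in [8, 64, 512, 4096]:
    quadratic = n * (n - 1) // 2  # every pair compared exactly once
    print(n, round(ratio_to_n_log_n(quadratic, n), 2))
```

The printed ratio keeps climbing as n grows, which is the signature of a count that is not O(n log n). A stable ratio in the merge sort table is the opposite pattern.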
