In [153]:
import os
import subprocess
import time
from collections import defaultdict
from math import ceil
from typing import List, Union, Iterable

import numpy as np
from numba import jit
from tqdm.notebook import tqdm

# Single-process version

**Recursive, in-place version**

In [112]:
# All operations are done in place: not very Pythonic, but no new arrays are allocated at any
# level of the recursion.

def _largest_power_of_two_below(n: int) -> int:
    """Return the largest power of two strictly less than ``n`` (requires ``n >= 2``)."""
    return 1 << ((n - 1).bit_length() - 1)


def merge_bitonic_parts(array: Union[List[float], np.ndarray], start_index: int, num_elements_to_sort: int, ascending: bool = True):
    """Merge ``array[start_index:start_index + num_elements_to_sort]`` in place into sorted order.

    The slice must be bitonic. For power-of-two lengths any bitonic sequence works, and this is
    the classic recursion comparing element ``i`` with ``i + n/2``. For other lengths the slice
    must be descending-then-ascending when ``ascending`` is True (ascending-then-descending
    otherwise), which is what ``do_bitonic_sort`` produces. The compare distance is then the
    largest power of two below ``n`` (H. W. Lang's construction for arbitrary ``n``).

    Args:
        array: mutable indexable sequence (list or 1-D numpy array), modified in place.
        start_index: first index of the slice to merge.
        num_elements_to_sort: length of the slice; lengths <= 1 are a no-op.
        ascending: target order of the merged slice.
    """
    if num_elements_to_sort <= 1:
        return

    # Equals num_elements_to_sort // 2 when the length is a power of two.
    distance = _largest_power_of_two_below(num_elements_to_sort)

    for index_left in range(start_index, start_index + num_elements_to_sort - distance):
        index_right = index_left + distance
        # Compare directly instead of multiplying by +-1: this works for any ordered type and
        # cannot overflow fixed-width integers (e.g. -1 * int64 min).
        if ascending:
            out_of_order = array[index_left] > array[index_right]
        else:
            out_of_order = array[index_left] < array[index_right]
        if out_of_order:
            array[index_left], array[index_right] = array[index_right], array[index_left]

    merge_bitonic_parts(array, start_index, distance, ascending)
    merge_bitonic_parts(array, start_index + distance, num_elements_to_sort - distance, ascending)


def do_bitonic_sort(array: Union[List[float], np.ndarray], start_index: int = 0, num_elements_to_sort: int = -1, ascending: bool = True):
    """Sort ``array[start_index:start_index + num_elements_to_sort]`` in place with bitonic sort.

    Works for any length, not only powers of two. Elements outside the slice are untouched.

    Args:
        array: mutable indexable sequence (list or 1-D numpy array), modified in place.
        start_index: first index of the slice to sort.
        num_elements_to_sort: length of the slice; -1 (default) means ``len(array)``.
        ascending: sort in ascending order if True, descending otherwise.
    """
    if num_elements_to_sort == -1:
        num_elements_to_sort = len(array)

    if num_elements_to_sort <= 1:
        return

    mid_index = num_elements_to_sort // 2
    # Sort the first half in the opposite direction and the second half (which gets the extra
    # element for odd lengths) in the target direction. The resulting "valley"-shaped bitonic
    # sequence is what the merge needs for non-power-of-two lengths.
    do_bitonic_sort(array, start_index, mid_index, not ascending)
    do_bitonic_sort(array, start_index + mid_index, num_elements_to_sort - mid_index, ascending)

    merge_bitonic_parts(array, start_index, num_elements_to_sort, ascending)

In [122]:
sizes = [10**5, 10**6]

# Fixed seed so the benchmark input is identical across kernel restarts.
rng = np.random.default_rng(0)

times_single = []

for size in tqdm(sizes):
    # Round the length up to a power of two; the MPI runs below receive the same exponent n.
    n = ceil(np.log2(size))

    arr = rng.standard_normal(2**n)
    arr_copy = arr.copy()

    # perf_counter is monotonic and high-resolution, so it is the right clock for benchmarks
    # (time.time() can jump if the wall clock is adjusted).
    start = time.perf_counter()
    do_bitonic_sort(arr_copy)
    times_single.append(time.perf_counter() - start)

    # A sort must reproduce the reference order exactly; allclose would accept near-equal
    # values that ended up in the wrong positions.
    assert np.array_equal(arr_copy, np.sort(arr))

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=2.0), HTML(value='')))




# Multi-process (MPI) version

In [154]:
# Start from an empty results file.
# NOTE(review): mpi_bitonic_sort.py presumably appends one "<size> <n_processes> <seconds>" line
# per run when --return_time is set; confirm in that script.
if os.path.exists('times'):
    os.remove('times')

sizes = [10**5, 10**6]
for size in sizes:
    n = ceil(np.log2(size))
    for n_processes in tqdm([1, 2, 4, 8, 16], desc=f'size={size}'):
        # Argument list instead of a shell string. check=True makes a failed MPI run raise
        # instead of silently leaving 'times' incomplete (os.system's exit code was ignored).
        subprocess.run(
            ['mpirun', '-n', str(n_processes), 'python', 'mpi_bitonic_sort.py',
             '--return_time', f'-n={n}'],
            check=True,
        )

In [None]:
# Timings keyed as proc_times[size][n_processes] -> seconds.
# NOTE(review): assumes each line of 'times' is "<size> <n_processes> <seconds>" as written by
# mpi_bitonic_sort.py; "size" may be the log2 exponent passed via -n rather than the raw size.
# Confirm this in the script.
proc_times = defaultdict(dict)

with open('times', 'r') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline
        size, proc, times = line.split()
        proc_times[int(size)][int(proc)] = float(times)