In [6]:
import heapq
from multiprocessing import Manager, Pool

from modules.helpers import generate_random_int_list
from modules.sorting import merge_sort, merge_sort_multiple
from modules.timing import function_timing

In [7]:
def merge_higher_dimensions_lists(lst):
    """Merge several individually sorted runs into one sorted list.

    Parameters
    ----------
    lst : iterable of sorted iterables
        The sorted runs to merge, e.g. a list of lists or a
        ``multiprocessing`` manager list proxy. Each run is copied once up
        front, so ``lst`` is NOT modified (the previous implementation emptied
        it as a side effect). Empty runs are allowed.

    Returns
    -------
    list
        Every element of every run in ascending order. Elements that compare
        equal come out in the order of the runs they belong to (earlier run
        first), as with the previous linear-scan implementation.
    """
    # Snapshot each run into a plain list. With a manager proxy, every indexing
    # operation is a round trip to the manager process returning a fresh copy
    # of the sublist, so indexing element-by-element is extremely slow.
    runs = [list(run) for run in lst]
    # heapq.merge does a stable k-way merge in O(n log k) and handles empty
    # runs, which previously raised IndexError.
    return list(heapq.merge(*runs))

In [8]:
def threaded_merge_sort(array, threads):
    """Sort ``array`` by merge-sorting interleaved slices in worker processes.

    Despite the parameter name, the work runs in a ``multiprocessing.Pool``
    of *processes*, not threads.

    Parameters
    ----------
    array : sequence
        Values to sort; must support extended slicing (``array[i::n]``).
    threads : int
        Number of worker processes, and the number of interleaved slices the
        input is split into.

    Returns
    -------
    list
        A new sorted list containing every element of ``array``.

    Raises
    ------
    ValueError
        If ``threads`` < 1 (raised by ``multiprocessing.Pool``).
    Exception
        Whatever ``merge_sort_multiple`` raises in a worker is re-raised here
        instead of being silently dropped.
    """
    # Chunk i holds array[i], array[i + threads], ... . Empty chunks (when
    # threads > len(array)) are skipped so the merge never sees an empty run.
    chunks = [
        chunk
        for chunk in (array[i::threads] for i in range(threads))
        if len(chunk) > 0
    ]

    # Context managers guarantee the manager server process and the pool are
    # shut down even if a worker fails.
    with Manager() as manager:
        # Shared list the workers fill with sorted chunks. Presumably
        # merge_sort_multiple appends its sorted input to it — TODO confirm
        # in modules.sorting.
        sorted_chunks = manager.list()
        with Pool(threads) as pool:
            pending = [
                pool.apply_async(merge_sort_multiple, (sorted_chunks, chunk))
                for chunk in chunks
            ]
            # .get() blocks until each task finishes and re-raises any worker
            # exception; discarding the AsyncResults would hide failures and
            # silently lose those chunks.
            for task in pending:
                task.get()
            pool.close()
            pool.join()
        # Copy the runs out in one round trip before the manager shuts down.
        runs = sorted_chunks[:]

    # Merge the independently sorted runs into a single sorted list.
    return merge_higher_dimensions_lists(runs)

In [9]:
# Correctness check against Python's built-in sorted(). Several worker counts
# are used because with a single worker there is only one sorted chunk, so the
# k-way merge step is never actually exercised.
array = generate_random_int_list(20)
expected = sorted(array)
all_match = all(
    threaded_merge_sort(array, n_workers) == expected for n_workers in (1, 2, 4)
)
# Fail loudly under Restart & Run All instead of just printing False.
assert all_match, "threaded_merge_sort disagrees with sorted()"
print(all_match)

True


In [10]:
# Benchmark: single-process merge_sort vs. the multiprocess version with
# 2, 4 and 6 workers.
# NOTE(review): the `1` passed to function_timing presumably means a single
# timed run — confirm in modules.timing; one run on 100 items is noisy, and
# the multiprocess timings likely include pool/manager start-up cost.
array = generate_random_int_list(100)

baseline_time = function_timing(merge_sort, 1, array)
print("Normal time: " + str(baseline_time))

for i in range(2, 8, 2):
    elapsed = function_timing(threaded_merge_sort, 1, array, i)
    print(str(i) + " threads: " + str(elapsed))

Normal time: 0.0003662000000019816
2 threads: 0.4285101000000111
4 threads: 0.4606066999999996
6 threads: 0.5138164999999901
