In [1]:
from dask.distributed import Client, LocalCluster, wait
import time

In [2]:
# The clients are kept on a class (used as a namespace) so that repeated
# calls to init_parallel do not initialize them twice.
class MyClient:
    """Namespace holding the shared dask clients and their start-up flag.

    The class is never instantiated in this file; its attributes are read
    and written directly on the class object.
    """

    # True once init_parallel has created both clients.
    started = False
    # Not assigned anywhere in the visible code.
    modules = None
    # Client created with processes=False by init_parallel.
    mt = None
    # Client created with default settings by init_parallel.
    mp = None
    
def init_parallel(module_names=None):
    """Start the multiprocess and multithread dask clients exactly once.

    On the first call, two ``Client`` objects are created (the second with
    ``processes=False``), stored on ``MyClient.mp`` / ``MyClient.mt``, and
    ``MyClient.started`` is set. Later calls only print a notice and return;
    any ``module_names`` passed to those later calls are ignored.

    Args:
        module_names: Optional iterable of file names passed to
            ``Client.upload_file`` on both clients. A single string is
            treated as one file name.

    Raises:
        Whatever ``Client(...)`` or ``upload_file`` raises. If the second
        client cannot be created, the first one is closed before re-raising.
    """
    if MyClient.started:
        print("Multiprocess and multithread clusters already started")
        return

    # A bare string would otherwise be iterated character by character.
    if isinstance(module_names, str):
        module_names = [module_names]

    print("Starting multiprocess cluster...")
    mp_client = Client(scheduler_port=0, dashboard_address=None)
    print("Starting multithread cluster...")
    try:
        mt_client = Client(scheduler_port=0, dashboard_address=None, processes=False)
    except Exception:
        # Do not leak the first client: `started` is still False, so a retry
        # would otherwise create yet another multiprocess client.
        mp_client.close()
        raise

    MyClient.mp = mp_client
    MyClient.mt = mt_client
    MyClient.started = True

    if module_names is not None:
        for name in module_names:
            mp_client.upload_file(name)
            mt_client.upload_file(name)

In [3]:
# User decorators
def profile(array_input):
    """Build a decorator that benchmarks a function over ``array_input``.

    Calling the decorated function does NOT run the function with the
    caller's arguments (``*args`` / ``**kwargs`` are accepted but ignored).
    Instead it maps the function over ``array_input`` three ways --
    sequentially (inside a task on the multiprocess client), on the
    multiprocess client, and on the multithread client -- then prints the
    timings and speedups. The wrapper returns ``None``.

    Args:
        array_input: Indexable, sized collection of inputs; ``len()`` and
            ``array_input[0]`` are both used, so it must be non-empty.

    Returns:
        A decorator taking the single-argument function to profile.
    """
    # Local import keeps this block self-contained.
    import functools

    def profile_decorator(f):
        @functools.wraps(f)  # keep the profiled function's name and docstring
        def profile_wrapper(*args, **kwargs):
            # Stats
            # NOTE(review): relies on cluster.worker_kwargs['ncores']; newer
            # dask.distributed releases use 'nthreads' -- confirm against the
            # version this project pins.
            p_workers = len(MyClient.mp.cluster.workers)
            p_cores = MyClient.mp.cluster.worker_kwargs['ncores']

            t_workers = len(MyClient.mt.cluster.workers)
            t_cores = MyClient.mt.cluster.worker_kwargs['ncores']

            print("---------Parallel vs Sequential Profile--------")
            print("Function: {}".format(f.__name__))
            print("Input length: {}".format(len(array_input)))
            print("Input type: {}".format(type(array_input[0])))
            print("Timing...")
            # The sequential baseline is submitted as a single task to the
            # multiprocess client and its elapsed time is fetched back.
            seq = MyClient.mp.submit(_time_sequential_map, f, array_input)
            seq = seq.result()
            par_processes = _time_parallel_map(f, array_input)
            # Fixed: this previously passed threads=False, so the
            # "Multithreading" figure was a second multiprocess measurement.
            par_threads = _time_parallel_map(f, array_input, threads=True)
            print("========================")
            print("Sequential time: {:.4f}s".format(seq))
            print("Multiprocessing time: {:.4f}s (~{:.4f}x speedup, {} workers, {} cores/worker)".format(par_processes, seq/par_processes, p_workers, p_cores))
            print("Multithreading time: {:.4f}s (~{:.4f}x speedup, {} workers, {} cores/worker)".format(par_threads, seq/par_threads, t_workers, t_cores))
        return profile_wrapper
    return profile_decorator

def concurrent(map=False, threads=False):
    """Build a decorator that dispatches calls to one of the dask clients.

    Args:
        map: When true, calls go through ``client.map``; otherwise through
            ``client.submit``.
        threads: When true, use ``MyClient.mt``; otherwise ``MyClient.mp``.

    Returns:
        A decorator. The decorated function, when called, forwards the
        function and all positional/keyword arguments to the chosen client
        method and returns that method's result unchanged.
    """
    def concurrent_decorator(func):
        def wrapper_concurrent(*args, **kwargs):
            # The client is looked up at call time, not at decoration time.
            if threads:
                client = MyClient.mt
            else:
                client = MyClient.mp
            dispatch = client.map if map else client.submit
            return dispatch(func, *args, **kwargs)
        return wrapper_concurrent
    return concurrent_decorator

In [5]:
# Helpers
def _time_sequential_map(f, array):
    ts = time.time()
    for a in array:
        result = f(a)
    te = time.time()
    time_ms = (te - ts)
    return time_ms

def _time_parallel_map(f, array, threads=False):
    """Time mapping ``f`` over ``array`` on one of the dask clients.

    Args:
        f: Callable taking one item of ``array``.
        array: Inputs passed to ``client.map``.
        threads: When true, use ``MyClient.mt``; otherwise ``MyClient.mp``.

    Returns:
        Elapsed time in seconds (float) from submitting the map until
        ``wait`` returns for the resulting futures.
    """
    client = MyClient.mt if threads else MyClient.mp
    # perf_counter is monotonic and high-resolution, unlike the wall clock
    # returned by time.time().
    started = time.perf_counter()
    futures = client.map(f, array)
    wait(futures)
    return time.perf_counter() - started