In [1]:
import functools
import time

from dask.distributed import Client, LocalCluster, wait, as_completed
from tqdm import tqdm_notebook as tqdm

In [1]:
# We wrap the clients in a class so that calling init() more than once
# does not start duplicate clusters.
class MyClient:
    """Class-level holder for the two dask clients used by the decorators.

    All state lives on the class itself; it is never instantiated.
    ``mp`` is a client whose workers are separate processes, ``mt`` one whose
    workers run as threads in this process (``processes=False``).
    """
    mp = None        # multiprocess Client, set by init()
    mt = None        # multithread Client, set by init()
    modules = None
    started = False  # True once init() has created both clients

    # Stats, filled in by init()
    p_workers = None  # number of workers in the multiprocess cluster
    p_cores = None    # threads per worker in the multiprocess cluster
    t_workers = None  # number of workers in the multithread cluster
    t_cores = None    # threads per worker in the multithread cluster

    @staticmethod
    def restart():
        """Restart the workers of both clusters (used to clear cached results).

        Raises:
            RuntimeError: if init() has not been called yet.
        """
        if not MyClient.started:
            raise RuntimeError("MyClient.init() must be called before restart()")
        MyClient.mp.restart()
        MyClient.mt.restart()

    @staticmethod
    def _threads_per_worker(client):
        """Per-worker thread count from the client's cluster kwargs, or None."""
        kwargs = getattr(client.cluster, 'worker_kwargs', None) or {}
        # distributed < 2.0 names this 'ncores'; later releases use 'nthreads'.
        return kwargs.get('ncores', kwargs.get('nthreads'))

    @staticmethod
    def init():
        """Start both clusters on the first call; later calls only print a notice."""
        if not MyClient.started:
            # Start process cluster (scheduler_port=0 lets the OS pick a free port)
            print("Starting multiprocess cluster...")
            MyClient.mp = Client(scheduler_port=0, dashboard_address=None)
            MyClient.p_workers = len(MyClient.mp.cluster.workers)
            MyClient.p_cores = MyClient._threads_per_worker(MyClient.mp)
            print("Multiprocess cluster started w/ {} workers ({} cores each)".format(MyClient.p_workers, MyClient.p_cores))

            # Start thread cluster
            print("Starting multithread cluster...")
            MyClient.mt = Client(scheduler_port=0, dashboard_address=None, processes=False)
            MyClient.t_workers = len(MyClient.mt.cluster.workers)
            MyClient.t_cores = MyClient._threads_per_worker(MyClient.mt)
            print("Multithread cluster started w/ {} workers ({} cores each)".format(MyClient.t_workers, MyClient.t_cores))

            # Toggle
            MyClient.started = True
        else:
            print("Clusters already started")
            
def init_parallel(module_names=None):
    """Start the shared clusters (once) and upload local files to their workers.

    Args:
        module_names: optional file path, or iterable of file paths, each
            uploaded to both clusters with ``Client.upload_file``.
    """
    MyClient.init()
    if module_names is not None:
        # A bare string would otherwise be iterated character by character.
        if isinstance(module_names, str):
            module_names = [module_names]
        for name in module_names:
            MyClient.mp.upload_file(name)
            MyClient.mt.upload_file(name)
    print(
        """
        Parallel Plugin Loaded. You can now decorate functions with @profile(profile_array) 
        and @parallel(map=True, threads=True, background=False). MyClient and get_results(futures)
        have also been loaded into your namespace.
        """)


In [1]:
# User decorators
def profile(array_input):
    """Decorator factory that benchmarks a function on the shared clusters.

    Calling the decorated function does NOT run it on the call's own
    arguments. Instead it times ``f`` applied to every element of
    ``array_input`` three ways and prints a report:

    * sequentially, as one loop submitted to the multiprocess client;
    * mapped over the multiprocess cluster;
    * mapped over the multithread cluster.

    Both clusters are restarted before the sequential/multiprocess timings
    and again before the multithread timing, so cached results are not reused.

    Args:
        array_input: non-empty, indexable sequence of single arguments for ``f``.

    Returns:
        A decorator replacing ``f`` with the profiling wrapper (which returns None).

    Raises:
        ValueError: (when the wrapper is called) if ``array_input`` is empty.
    """
    def profile_decorator(f):
        @functools.wraps(f)
        def profile_wrapper(*args, **kwargs):
            # Call-site arguments are accepted but ignored: the benchmark
            # always runs over ``array_input``.
            if len(array_input) == 0:
                # Checked before restarting clusters; array_input[0] below needs an element.
                raise ValueError("profile() needs a non-empty array_input")

            MyClient.restart()  # clear cache

            print("---------Parallel vs Sequential Profile--------")
            print("Function: {}".format(f.__name__))
            print("Input length: {}".format(len(array_input)))
            print("Input type: {}".format(type(array_input[0])))
            print("Timing...")
            seq = MyClient.mp.submit(_time_sequential_map, f, array_input).result()
            par_processes = _time_parallel_map(f, array_input, threads=False)
            MyClient.restart()  # clear cache
            # threads=True selects MyClient.mt, the multithread cluster.
            par_threads = _time_parallel_map(f, array_input, threads=True)
            print("========================")
            print("Sequential time: {:.4f}s".format(seq))
            print("Multiprocessing time: {:.4f}s (~{:.4f}x speedup, {} workers, {} cores/worker)".format(par_processes, seq/par_processes, MyClient.p_workers, MyClient.p_cores))
            print("Multithreading time: {:.4f}s (~{:.4f}x speedup, {} workers, {} cores/worker)".format(par_threads, seq/par_threads, MyClient.t_workers, MyClient.t_cores))
        return profile_wrapper
    return profile_decorator

def parallel(map=False, threads=False, background=True):
    """Decorator factory that runs a function on one of the shared dask clients.

    Args:
        map: if True, the call is dispatched with ``client.map`` (one task per
            element of the positional iterables); otherwise with ``client.submit``.
        threads: use the multithread client (``MyClient.mt``) if True, else the
            multiprocess client (``MyClient.mp``).
        background: if True, return the Future (submit) or list of Futures (map)
            immediately. If False, block and return the result (submit) or the
            list of results in input order (map), with a tqdm progress bar for map.

    Returns:
        A decorator wrapping the target function.
    """
    def concurrent_decorator(func):
        @functools.wraps(func)
        def wrapper_concurrent(*args, **kwargs):
            client = MyClient.mt if threads else MyClient.mp
            if map:
                futures = client.map(func, *args, **kwargs)
            else:
                futures = client.submit(func, *args, **kwargs)

            if background:
                return futures
            if not map:
                # A single Future is neither iterable nor sized; just block on it.
                return futures.result()

            # Advance the progress bar as tasks finish, then gather so the
            # results line up with the inputs (as_completed yields in finish order).
            for _ in tqdm(as_completed(futures), total=len(futures)):
                pass
            return client.gather(futures)
        return wrapper_concurrent
    return concurrent_decorator

def get_results(futures):
    """Block until the given futures finish and return their results.

    Args:
        futures: a single Future, or an iterable of Futures (e.g. what
            ``@parallel(background=True)`` returns).

    Returns:
        For a single Future, its result. For an iterable, a list of results
        in the same order as ``futures`` (not completion order).

    Raises:
        The exception of the first failed future, in input order.
    """
    # A lone Future (from a non-map submit) is not iterable.
    if hasattr(futures, "result"):
        return futures.result()
    return [fut.result() for fut in futures]

In [5]:
# Helpers
def _time_sequential_map(f, array):
    ts = time.time()
    for a in array:
        result = f(a)
    te = time.time()
    time_ms = (te - ts)
    return time_ms

def _time_parallel_map(f, array, threads=False):
    """Seconds to map ``f`` over ``array`` on one of the shared clusters.

    Args:
        f: function applied to each element.
        array: iterable of single arguments for ``f``.
        threads: use the multithread client (``MyClient.mt``) if True, else the
            multiprocess client (``MyClient.mp``).

    Returns:
        Elapsed seconds (float) from submission until every task has finished.

    Raises:
        Whatever ``f`` raised on a worker, so a failing run is not reported
        as a (misleadingly fast) timing.
    """
    client = MyClient.mt if threads else MyClient.mp
    start = time.perf_counter()
    futures = client.map(f, array)
    wait(futures)
    elapsed_s = time.perf_counter() - start
    # wait() returns normally even when tasks errored; surface the failure
    # (after the timer has stopped).
    for fut in futures:
        if fut.status == 'error':
            fut.result()
    return elapsed_s