diff --git a/runhouse/resources/functionals/mapper.py b/runhouse/resources/functionals/mapper.py index 9834e7b64..c59ddeb85 100644 --- a/runhouse/resources/functionals/mapper.py +++ b/runhouse/resources/functionals/mapper.py @@ -1,8 +1,9 @@ -import concurrent.futures -import contextvars +import asyncio import logging from typing import Callable, List, Optional, Union +from runhouse.utils import sync_function + try: from tqdm import tqdm except ImportError: @@ -158,60 +159,28 @@ def starmap( else: kwargs["stream_logs"] = kwargs.get("stream_logs", False) - retry_list = [] - - def call_method_on_replica(job, retry=True): - replica, method_name, context, argies, kwargies = job - # reset context - for var, value in context.items(): - var.set(value) - + async def call_method_on_replica(replica, argies): try: - return getattr(replica, method_name)(*argies, **kwargies) + return await getattr(replica, method or self.method)( + *argies, run_async=True, **kwargs + ) except Exception as e: - logger.error(f"Error running {method_name} on {replica.name}: {e}") - if retry: - retry_list.append(job) - else: - return e - - context = contextvars.copy_context() - jobs = [ - ( - self.replicas[self.increment_counter()], - method or self.method, - context, - args, - kwargs, - ) - for args in arg_list - ] - - results = [] - max_threads = round(self.concurrency * self.num_replicas) - with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor: - futs = [ - executor.submit(call_method_on_replica, job, retries > 0) - for job in jobs - ] - for fut in tqdm(concurrent.futures.as_completed(futs), total=len(jobs)): - results.extend([fut.result()]) - for i in range(retries): - if len(retry_list) == 0: - break - logger.info(f"Retry {i}: {len(retry_list)} failed jobs") - jobs, retry_list = retry_list, [] - retry = i != retries - 1 - results.append( - list( - tqdm( - executor.map(call_method_on_replica, jobs, retry), - total=len(jobs), - ) - ) + logger.error( + f"Error running {method or self.method} on {replica.name}: {e}" ) + return e + + async def batch_jobs_with_async(): + return await asyncio.gather( + *[ + call_method_on_replica( + self.replicas[self.increment_counter()], args + ) + for args in arg_list + ] + ) - return results + return sync_function(batch_jobs_with_async)() # TODO should we add an async version of this for when we're on the cluster? # async def call_method_on_args(argies):