Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asyncio for Mapper instead of threads. #755

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
73 changes: 21 additions & 52 deletions runhouse/resources/functionals/mapper.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import concurrent.futures
import contextvars
import asyncio
import logging
from typing import Callable, List, Optional, Union

from runhouse.utils import sync_function

try:
from tqdm import tqdm
except ImportError:
Expand Down Expand Up @@ -158,60 +159,28 @@ def starmap(
else:
kwargs["stream_logs"] = kwargs.get("stream_logs", False)

retry_list = []

def call_method_on_replica(job, retry=True):
replica, method_name, context, argies, kwargies = job
# reset context
for var, value in context.items():
var.set(value)

async def call_method_on_replica(replica, argies):
try:
return getattr(replica, method_name)(*argies, **kwargies)
return await getattr(replica, method or self.method)(
*argies, run_async=True, **kwargs
)
except Exception as e:
logger.error(f"Error running {method_name} on {replica.name}: {e}")
if retry:
retry_list.append(job)
else:
return e

context = contextvars.copy_context()
jobs = [
(
self.replicas[self.increment_counter()],
method or self.method,
context,
args,
kwargs,
)
for args in arg_list
]

results = []
max_threads = round(self.concurrency * self.num_replicas)
with concurrent.futures.ThreadPoolExecutor(max_workers=max_threads) as executor:
futs = [
executor.submit(call_method_on_replica, job, retries > 0)
for job in jobs
]
for fut in tqdm(concurrent.futures.as_completed(futs), total=len(jobs)):
results.extend([fut.result()])
for i in range(retries):
if len(retry_list) == 0:
break
logger.info(f"Retry {i}: {len(retry_list)} failed jobs")
jobs, retry_list = retry_list, []
retry = i != retries - 1
results.append(
list(
tqdm(
executor.map(call_method_on_replica, jobs, retry),
total=len(jobs),
)
)
logger.error(
f"Error running {method or self.method} on {replica.name}: {e}"
)
return e

async def batch_jobs_with_async():
return await asyncio.gather(
*[
call_method_on_replica(
self.replicas[self.increment_counter()], args
)
for args in arg_list
]
)

return results
return sync_function(batch_jobs_with_async)()

# TODO should we add an async version of this for when we're on the cluster?
# async def call_method_on_args(argies):
Expand Down