
Merge pull request #430 from ungarj/dask_executor_fixes
Dask executor fixes
ungarj committed Feb 23, 2022
2 parents 2830dbf + d7a13ea commit 6b1cb37
Showing 5 changed files with 77 additions and 70 deletions.
7 changes: 5 additions & 2 deletions mapchete/_executor.py
@@ -152,12 +152,16 @@ def cancel(self):
         # reset so futures won't linger here for next call
         self.running_futures = set()
 
-    def wait(self):
+    def wait(self, raise_exc=False):
         logger.debug("wait for running futures to finish...")
         try:  # pragma: no cover
             self._wait()
         except CancelledError:  # pragma: no cover
             pass
+        except Exception as exc:  # pragma: no cover
+            logger.error("exception caught when waiting for futures: %s", str(exc))
+            if raise_exc:
+                raise exc
 
     def close(self):  # pragma: no cover
         self.__exit__(None, None, None)
@@ -326,7 +330,6 @@ def as_completed(
                 item, skip, skip_info = item
                 if skip:
                     yield SkippedFuture(item, skip_info=skip_info)
-                    self._submitted -= 1
                     continue
 
                 # add processing item to chunk
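Note on the wait() change above: waiting on running futures can now surface errors to the caller instead of only logging them. A minimal usage sketch — the Executor import path, the "threads" concurrency value and the process_item function are illustrative assumptions, not taken from this diff:

    # Sketch only: import path, concurrency value and task function are assumed.
    from mapchete._executor import Executor

    def process_item(i):
        return i * i

    with Executor(concurrency="threads") as executor:
        for future in executor.as_completed(func=process_item, iterable=range(10)):
            print(future.result())
        # new in this PR: re-raise exceptions caught while waiting on
        # still-running futures instead of only logging them
        executor.wait(raise_exc=True)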
81 changes: 42 additions & 39 deletions mapchete/_processing.py
@@ -5,6 +5,8 @@
 from itertools import chain
 import logging
 import multiprocessing
+from shapely.geometry import mapping
+from tilematrix._funcs import Bounds
 from traceback import format_exc
 from typing import Generator
 
@@ -57,6 +59,7 @@ def __init__(
         preprocessing_tasks: int = None,
         executor_concurrency: str = "processes",
         executor_kwargs: dict = None,
+        process_area=None,
     ):
         self.func = func
         self.fargs = fargs or ()
@@ -69,19 +72,31 @@ def __init__(
         self.preprocessing_tasks = preprocessing_tasks or 0
         self._total = self.preprocessing_tasks + self.tiles_tasks
         self._as_iterator = as_iterator
+        self._process_area = process_area
+        self.bounds = Bounds(*process_area.bounds) if process_area is not None else None
 
         if not as_iterator:
             self._results = list(self._run())
 
+    @property
+    def __geo_interface__(self):  # pragma: no cover
+        if self._process_area is not None:
+            return mapping(self._process_area)
+        else:
+            raise AttributeError(f"{self} has no geo information assigned")
+
     def _run(self):
         if self._total == 0:
             return
+        logger.debug("opening executor for job %s", repr(self))
         with Executor(
             concurrency=self.executor_concurrency, **self.executor_kwargs
         ) as self.executor:
             self.status = "running"
+            logger.debug("change of job status: %s", self)
             yield from self.func(*self.fargs, executor=self.executor, **self.fkwargs)
             self.status = "finished"
+            logger.debug("change of job status: %s", self)
 
     def set_executor_kwargs(self, executor_kwargs):
         """
@@ -122,6 +137,7 @@ def __repr__(self):  # pragma: no cover
 
 
 def task_batches(process, zoom=None, tile=None, skip_output_check=False):
+    """Create task batches for each processing stage."""
     with Timer() as duration:
         # preprocessing tasks
         yield TaskBatch(
@@ -134,7 +150,8 @@ def task_batches(process, zoom=None, tile=None, skip_output_check=False):
     with Timer() as duration:
         if tile:
             zoom_levels = [tile.zoom]
-            tiles = {tile.zoom: [tile]}
+            skip_output_check = True
+            tiles = {tile.zoom: [(tile, False)]}
         else:
             zoom_levels = list(
                 reversed(process.config.zoom_levels)
@@ -143,8 +160,8 @@ def task_batches(process, zoom=None, tile=None, skip_output_check=False):
             )
             tiles = {
                 zoom: (
-                    tile
-                    for tile, skip, process_msg in _filter_skipable(
+                    (tile, skip)
+                    for tile, skip, _ in _filter_skipable(
                         process=process,
                         tiles_batches=process.get_process_tiles(zoom, batch_by="row"),
                         target_set=None,
@@ -169,13 +186,13 @@ def task_batches(process, zoom=None, tile=None, skip_output_check=False):
                 tile=tile,
                 config=process.config,
                 skip=(
-                    process.mode == "continue"
+                    process.config.mode == "continue"
                     and process.config.output_reader.tiles_exist(tile)
                 )
                 if skip_output_check
-                else False,
+                else skip,
             )
-            for tile in tiles[zoom]
+            for tile, skip in tiles[zoom]
         ),
         func=func,
         fkwargs=fkwargs,
@@ -529,13 +546,14 @@ def _compute_task_graph(
             zoom=zoom_levels, tile=tile, skip_output_check=skip_output_check
         )
     )
-    logger.debug("%s dask collection generated in %s", len(coll), t)
+    logger.debug("dask collection with %s tasks generated in %s", len(coll), t)
 
     # send to scheduler
     with Timer() as t:
         futures = executor._executor.compute(coll, optimize_graph=True, traverse=True)
-        logger.debug("sent to scheduler in %s", t)
+        logger.debug("%s tasks sent to scheduler in %s", len(futures), t)
 
+    logger.debug("wait for tasks to finish...")
     for future in as_completed(
         futures, with_results=with_results, raise_errors=raise_errors
     ):
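The compute-then-as_completed pattern above corresponds to plain dask.distributed usage. A simplified analogue with a local client and toy tasks (not mapchete's actual executor wiring):

    # Simplified analogue of the pattern above, using dask.distributed directly.
    from dask import delayed
    from dask.distributed import Client, as_completed

    def increment(x):
        return x + 1

    client = Client(processes=False)  # small local cluster for illustration
    coll = [delayed(increment)(i) for i in range(5)]
    futures = client.compute(coll, optimize_graph=True, traverse=True)
    for future, result in as_completed(futures, with_results=True):
        print(result)  # results arrive in completion order
    client.close()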
@@ -561,29 +579,19 @@ def _compute_tasks(
     skip_output_check=False,
     **kwargs,
 ):
-    num_processed = 0
     if not process.config.preprocessing_tasks_finished:
         tasks = process.config.preprocessing_tasks()
         logger.info(
             "run preprocessing on %s tasks using %s workers", len(tasks), workers
         )
         # process all remaining tiles using todo list from before
-        for i, future in enumerate(
-            executor.as_completed(
-                func=_preprocess_task_wrapper,
-                iterable=tasks.values(),
-                fkwargs=dict(append_data=True),
-                **kwargs,
-            ),
-            1,
+        for future in executor.as_completed(
+            func=_preprocess_task_wrapper,
+            iterable=tasks.values(),
+            fkwargs=dict(append_data=True),
+            **kwargs,
         ):
             result = future.result()
-            logger.debug(
-                "preprocessing task %s/%s %s processed successfully",
-                i,
-                len(tasks),
-                result.task_key,
-            )
             process.config.set_preprocessing_task_result(result.task_key, result.data)
             yield future
 
@@ -627,23 +635,16 @@ def _compute_tasks(
     else:
         _process_batches = _run_multi_no_overviews
 
-    for num_processed, future in enumerate(
-        _process_batches(
-            zoom_levels=zoom_levels,
-            executor=executor,
-            func=func,
-            process=process,
-            skip_output_check=skip_output_check,
-            fkwargs=fkwargs,
-            write_in_parent_process=write_in_parent_process,
-            **kwargs,
-        ),
-        1,
+    for future in _process_batches(
+        zoom_levels=zoom_levels,
+        executor=executor,
+        func=func,
+        process=process,
+        skip_output_check=skip_output_check,
+        fkwargs=fkwargs,
+        write_in_parent_process=write_in_parent_process,
+        **kwargs,
     ):
-        logger.debug(
-            "task %s finished",
-            num_processed,
-        )
         yield future
 
 
@@ -664,6 +665,7 @@ def _run_multi_overviews(
 
     for i, zoom in enumerate(zoom_levels):
 
+        logger.debug("sending tasks to executor %s...", executor)
         # get generator list of tiles, whether they are to be skipped and skip_info
         # from _filter_skipable and pass on to executor
         for future in executor.as_completed(
@@ -747,6 +749,7 @@ def _run_multi_no_overviews(
     dask_max_submitted_tasks=None,
     write_in_parent_process=None,
 ):
+    logger.debug("sending tasks to executor %s...", executor)
     # get generator list of tiles, whether they are to be skipped and skip_info
     # from _filter_skipable and pass on to executor
     for future in executor.as_completed(
41 changes: 20 additions & 21 deletions mapchete/_tasks.py
@@ -406,27 +406,26 @@ def to_dask_collection(batches):
     from dask.delayed import delayed
 
     tasks = {}
-    with Timer() as t:
-        previous_batch = None
-        for batch in batches:
-            if previous_batch:
-                logger.debug("previous batch had %s tasks", len(previous_batch))
-            for task in batch.values():
-                if previous_batch:
-                    dependencies = {
-                        child.id: tasks[child]
-                        for child in previous_batch.intersection(task)
-                    }
-                    logger.debug(
-                        "found %s dependencies from last batch for task %s",
-                        len(dependencies),
-                        task,
-                    )
-                else:
-                    dependencies = {}
-                tasks[task] = delayed(batch.func)(
-                    task, dependencies=dependencies, **batch.fkwargs
-                )
-            previous_batch = batch
-        logger.debug("%s tasks generated in %s", len(tasks), t)
+    previous_batch = None
+    for batch in batches:
+        logger.debug("converting batch %s", batch)
+        if previous_batch:
+            logger.debug("previous batch had %s tasks", len(previous_batch))
+        for task in batch.values():
+            if previous_batch:
+                dependencies = {
+                    child.id: tasks[child]
+                    for child in previous_batch.intersection(task)
+                }
+                logger.debug(
+                    "found %s dependencies from last batch for task %s",
+                    len(dependencies),
+                    task,
+                )
+            else:
+                dependencies = {}
+            tasks[task] = delayed(batch.func)(
+                task, dependencies=dependencies, **batch.fkwargs
+            )
+        previous_batch = batch
     return list(tasks.values())
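The mechanism behind to_dask_collection: dask traverses the dependencies dict, so Delayed objects passed there become graph edges, and each task only runs after the intersecting tasks of the previous batch. A toy sketch of that mechanism (batch and task names are made up):

    # Toy sketch: Delayed values inside the "dependencies" dict become graph edges.
    import dask
    from dask import delayed

    def run_task(name, dependencies=None):
        # by execution time, dask has replaced the Delayed values with results
        return f"{name} after {sorted((dependencies or {}).keys())}"

    tasks = {}
    for name in ("a0", "a1"):  # first batch: no parents
        tasks[name] = delayed(run_task)(name, dependencies={})
    for name in ("b0",):  # second batch: depends on all of the first
        deps = {parent: tasks[parent] for parent in ("a0", "a1")}
        tasks[name] = delayed(run_task)(name, dependencies=deps)

    print(dask.compute(*tasks.values()))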
15 changes: 8 additions & 7 deletions mapchete/cli/default/execute.py
@@ -48,15 +48,16 @@ def execute(
     for mapchete_file in mapchete_files:
         tqdm.tqdm.write(f"preparing to process {mapchete_file}")
         with mapchete.Timer() as t:
+            job = commands.execute(
+                mapchete_file,
+                *args,
+                as_iterator=True,
+                msg_callback=tqdm.tqdm.write if verbose else None,
+                **kwargs,
+            )
             list(
                 tqdm.tqdm(
-                    commands.execute(
-                        mapchete_file,
-                        *args,
-                        as_iterator=True,
-                        msg_callback=tqdm.tqdm.write if verbose else None,
-                        **kwargs,
-                    ),
+                    job,
                     unit="task",
                     disable=debug or no_pbar,
                 )
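Binding the commands.execute() result to job before wrapping it in tqdm keeps a usable reference to the Job object. If the Job reports a length (its task total; not shown in this diff, so an assumption here), tqdm can render a proper progress bar. A toy sketch:

    # FakeJob stands in for the Job returned by commands.execute().
    import tqdm

    class FakeJob:
        def __init__(self, total):
            self._total = total

        def __len__(self):
            return self._total

        def __iter__(self):
            return iter(range(self._total))

    job = FakeJob(100)
    list(tqdm.tqdm(job, unit="task"))  # total inferred from len(job)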
3 changes: 2 additions & 1 deletion mapchete/commands/_execute.py
@@ -148,7 +148,7 @@ def _empty_callback(_):
     tiles_tasks = 1 if tile else mp.count_tiles()
     total_tasks = preprocessing_tasks + tiles_tasks
     msg_callback(
-        f"processing {preprocessing_tasks} preprocessing tasks and {tiles_tasks} on {workers} worker(s)"
+        f"processing {preprocessing_tasks} preprocessing tasks and {tiles_tasks} tile tasks on {workers} worker(s)"
     )
     # automatically use dask Executor if dask scheduler is defined
     if dask_scheduler or dask_client:  # pragma: no cover
@@ -183,6 +183,7 @@ def _empty_callback(_):
             as_iterator=as_iterator,
             preprocessing_tasks=preprocessing_tasks,
             tiles_tasks=tiles_tasks,
+            process_area=mp.config.init_area,
         )
     # explicitly exit the mp object on failure
     except Exception as exc:  # pragma: no cover
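With process_area wired through to the Job, callers can inspect a job's footprint before consuming its tasks. A hedged usage sketch — the config path is a placeholder, and iterating the Job assumes the as_iterator behavior used by the CLI above:

    # Placeholder .mapchete path; Job iteration and geo interface as per this PR.
    from shapely.geometry import shape
    from mapchete import commands

    job = commands.execute("example.mapchete", as_iterator=True)
    print(shape(job).bounds)  # bounds of mp.config.init_area
    for future in job:  # tasks only run while the iterator is consumed
        pass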
