Skip to content

Commit

Permalink
Merge pull request #40 from tomwhite/pywren
Browse files Browse the repository at this point in the history
Pywren executor
  • Loading branch information
rabernat committed Aug 17, 2020
2 parents 001c01a + 9a74d1b commit a7a75c9
Show file tree
Hide file tree
Showing 4 changed files with 150 additions and 3 deletions.
4 changes: 4 additions & 0 deletions rechunker/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ def _get_executor(name: str) -> Executor:
from rechunker.executors.python import PythonExecutor

return PythonExecutor()
elif name.lower() == "pywren":
from rechunker.executors.pywren import PywrenExecutor

return PywrenExecutor()
else:
raise ValueError(f"unrecognized executor {name}")

Expand Down
78 changes: 78 additions & 0 deletions rechunker/executors/pywren.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from functools import partial

from typing import Callable, Iterable, Tuple

from rechunker.executors.util import chunk_keys, split_into_direct_copies
from rechunker.types import CopySpec, Executor, ReadableArray, WriteableArray

import pywren_ibm_cloud as pywren
from pywren_ibm_cloud.executor import FunctionExecutor

# PywrenExecutor represents delayed execution tasks as functions that require
# a FunctionExecutor.
Task = Callable[[FunctionExecutor], None]


class PywrenExecutor(Executor[Task]):
    """Execution engine that runs rechunking copies through Pywren.

    Supports zarr arrays as inputs. Outputs must be zarr arrays.

    A Pywren ``FunctionExecutor`` may be handed to the constructor. When
    none is given, ``execute_plan`` obtains one from
    ``pywren_local_function_executor`` and uses it as a context manager
    around the plan.

    Execution plans for PywrenExecutor are functions that accept a single
    argument: the ``FunctionExecutor`` on which to run.
    """

    def __init__(self, pywren_function_executor: FunctionExecutor = None):
        # None means "create a local executor on demand in execute_plan".
        self.pywren_function_executor = pywren_function_executor

    def prepare_plan(self, specs: Iterable[CopySpec]) -> Task:
        """Build a single task that performs every copy described by ``specs``.

        Each spec is split into direct copies which must run one after the
        other; the per-spec tasks are then chained in series as well.
        """
        # TODO: execute tasks for different specs in parallel
        per_spec_plans = [
            partial(
                _execute_in_series,
                [
                    partial(_direct_array_copy, *direct_spec)
                    for direct_spec in split_into_direct_copies(spec)
                ],
            )
            for spec in specs
        ]
        return partial(_execute_in_series, per_spec_plans)

    def execute_plan(self, plan: Task):
        """Run ``plan`` on the configured executor, or on a temporary local one."""
        if self.pywren_function_executor is not None:
            plan(self.pywren_function_executor)
            return

        # No Pywren function executor specified, so use a local one, and
        # shutdown after use
        with pywren_local_function_executor() as local_executor:
            plan(local_executor)


def pywren_local_function_executor():
    """Return a Pywren local executor built from a minimal inline config."""
    # Minimal config needed to avoid Pywren error if ~/.pywren_config is missing
    minimal_config = {"pywren": {"storage_bucket": "unused"}}
    return pywren.local_executor(config=minimal_config)


def _direct_array_copy(
    source: ReadableArray,
    target: WriteableArray,
    chunks: Tuple[int, ...],
    pywren_function_executor: FunctionExecutor,
) -> None:
    """Copy ``source`` into ``target`` chunk by chunk, using Pywren for parallelism.

    One work item is created per key produced by ``chunk_keys`` for the
    source shape and the requested chunks; the items are mapped over the
    given executor and the call waits on the returned futures.
    """

    def direct_copy(iterdata):
        # Each work item is a (source, target, key) triple.
        src, dst, region = iterdata
        dst[region] = src[region]

    work_items = [(source, target, key) for key in chunk_keys(source.shape, chunks)]
    pending = pywren_function_executor.map(direct_copy, work_items)
    pywren_function_executor.get_result(pending)


def _execute_in_series(
tasks: Iterable[Task], pywren_function_executor: FunctionExecutor
) -> None:
for task in tasks:
task(pywren_function_executor)
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
]

extras_require = {
"complete": install_requires + ["apache_beam", "pyyaml", "fsspec", "prefect"],
"complete": install_requires
+ ["apache_beam", "pyyaml", "fsspec", "prefect", "pywren_ibm_cloud"],
"docs": doc_requires,
}
extras_require["dev"] = extras_require["complete"] + [
Expand Down
68 changes: 66 additions & 2 deletions tests/test_rechunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def requires_import(module, *args):

requires_beam = partial(requires_import, "apache_beam")
requires_prefect = partial(requires_import, "prefect")
requires_pywren = partial(requires_import, "pywren_ibm_cloud")


@pytest.fixture(params=[(8000, 200), {"y": 8000, "x": 200}])
Expand All @@ -38,7 +39,14 @@ def target_chunks(request):
@pytest.mark.parametrize("dtype", ["f4"])
@pytest.mark.parametrize("max_mem", [25600000, "25.6MB"])
@pytest.mark.parametrize(
"executor", ["dask", "python", requires_beam("beam"), requires_prefect("prefect")],
"executor",
[
"dask",
"python",
requires_beam("beam"),
requires_prefect("prefect"),
requires_pywren("pywren"),
],
)
@pytest.mark.parametrize(
"dims,target_chunks",
Expand Down Expand Up @@ -131,7 +139,14 @@ def test_rechunk_dask_array(


@pytest.mark.parametrize(
"executor", ["dask", "python", requires_beam("beam"), requires_prefect("prefect")],
"executor",
[
"dask",
"python",
requires_beam("beam"),
requires_prefect("prefect"),
requires_pywren("pywren"),
],
)
def test_rechunk_group(tmp_path, executor):
store_source = str(tmp_path / "source.zarr")
Expand Down Expand Up @@ -256,3 +271,52 @@ def test_no_intermediate_fused(tmp_path):

num_tasks = len([v for v in rechunked.plan.dask.values() if dask.core.istask(v)])
assert num_tasks < 20 # less than if no fuse


def test_pywren_function_executor(tmp_path):
    """Rechunk an array of ones using a caller-managed Pywren function executor."""
    pytest.importorskip("pywren_ibm_cloud")
    from rechunker.executors.pywren import (
        PywrenExecutor,
        pywren_local_function_executor,
    )

    # Array layout and memory budget for the rechunk.
    shape = (8000, 8000)
    source_chunks = (200, 8000)
    target_chunks = (400, 8000)
    dtype = "f4"
    max_mem = 25600000

    # Store locations for the source, the target and the intermediate array.
    store_source = str(tmp_path / "source.zarr")
    target_store = str(tmp_path / "target.zarr")
    temp_store = str(tmp_path / "temp.zarr")

    # Create a Pywren function executor that we manage ourselves
    # and pass in to rechunker's PywrenExecutor
    with pywren_local_function_executor() as managed_executor:
        rechunk_executor = PywrenExecutor(managed_executor)

        ones_source = zarr.ones(
            shape, chunks=source_chunks, dtype=dtype, store=store_source
        )

        rechunked = api.rechunk(
            ones_source,
            target_chunks,
            max_mem,
            target_store,
            temp_store=temp_store,
            executor=rechunk_executor,
        )
        assert isinstance(rechunked, api.Rechunked)

        # The target is opened and its chunking checked before the plan runs.
        target_array = zarr.open(target_store)
        assert target_array.chunks == tuple(target_chunks)

        result = rechunked.execute()
        assert isinstance(result, zarr.Array)

        # Every element of the rechunked target must still equal one.
        a_tar = dsa.from_zarr(target_array)
        assert dsa.equal(a_tar, 1).all().compute()

0 comments on commit a7a75c9

Please sign in to comment.