Merge pull request #34 from shoyer/beam

Add Apache Beam executor

rabernat committed Jul 29, 2020
2 parents 3434bf6 + 183c611 commit 22e6ca0
Showing 6 changed files with 165 additions and 16 deletions.
4 changes: 4 additions & 0 deletions rechunker/api.py
@@ -152,6 +152,10 @@ def _get_executor(name: str) -> Executor:
         from rechunker.executors.dask import DaskExecutor
 
         return DaskExecutor()
+    elif name.lower() == "beam":
+        from rechunker.executors.beam import BeamExecutor
+
+        return BeamExecutor()
     elif name.lower() == "python":
         from rechunker.executors.python import PythonExecutor
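With this hunk, "beam" becomes a valid value for the executor argument in the public API. A minimal, hypothetical sketch of selecting it (the store paths, shapes, and chunk sizes below are illustrative placeholders, mirroring the call pattern in the tests at the end of this diff):

    # Hypothetical sketch: only the executor="beam" dispatch is what this
    # hunk adds; everything else here is placeholder setup.
    import zarr
    from rechunker import api

    source = zarr.ones((8000, 8000), chunks=(200, 8000), store="source.zarr")
    rechunked = api.rechunk(
        source,
        target_chunks=(8000, 200),
        max_mem="25.6MB",
        target_store="target.zarr",
        temp_store="temp.zarr",
        executor="beam",  # resolved to BeamExecutor by _get_executor
    )
    # `rechunked` is an api.Rechunked plan object (see the tests below)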
103 changes: 103 additions & 0 deletions rechunker/executors/beam.py
@@ -0,0 +1,103 @@
import uuid
from typing import Iterable, Iterator, Optional, Mapping, Tuple

import apache_beam as beam

from rechunker.executors.util import chunk_keys
from rechunker.types import (
    CopySpec,
    StagedCopySpec,
    Executor,
    ReadableArray,
    WriteableArray,
)


class BeamExecutor(Executor[beam.PTransform]):
    """An execution engine based on Apache Beam.

    Supports copying between any arrays that implement ``__getitem__`` and
    ``__setitem__`` for tuples of ``slice`` objects. Arrays must also be
    serializable by Beam (i.e., with pickle).

    Execution plans for BeamExecutor are beam.PTransform objects.
    """

    # TODO: explore adding an option to do rechunking with Beam groupby
    # operations instead of explicitly writing intermediate arrays to disk.
    # This would offer a cleaner API and would perhaps be faster, too.

    def prepare_plan(self, specs: Iterable[StagedCopySpec]) -> beam.PTransform:
        return "Rechunker" >> _Rechunker(specs)

    def execute_plan(self, plan: beam.PTransform, **kwargs):
        with beam.Pipeline(**kwargs) as pipeline:
            pipeline | plan


class _Rechunker(beam.PTransform):
    def __init__(self, specs: Iterable[StagedCopySpec]):
        super().__init__()
        self.specs = tuple(specs)

    def expand(self, pcoll):
        max_depth = max(len(spec.stages) for spec in self.specs)
        specs_map = {uuid.uuid1().hex: spec for spec in self.specs}

        # we explicitly thread target_id through each stage to ensure that
        # they are executed in order
        # TODO: consider refactoring to use Beam's ``Source`` API for improved
        # performance:
        # https://beam.apache.org/documentation/io/developing-io-overview/
        pcoll = pcoll | "Create" >> beam.Create(specs_map.keys())
        for stage in range(max_depth):
            specs_by_target = {
                k: v.stages[stage] if stage < len(v.stages) else None
                for k, v in specs_map.items()
            }
            pcoll = pcoll | f"Stage{stage}" >> _CopyStage(specs_by_target)
        return pcoll


class _CopyStage(beam.PTransform):
    def __init__(self, specs_by_target: Mapping[str, CopySpec]):
        super().__init__()
        self.specs_by_target = specs_by_target

    def expand(self, pcoll):
        return (
            pcoll
            | "Start" >> beam.FlatMap(_start_stage, self.specs_by_target)
            | "CreateTasks" >> beam.FlatMapTuple(_copy_tasks)
            # prevent undesirable fusion
            # https://stackoverflow.com/a/54131856/809705
            | "Reshuffle" >> beam.Reshuffle()
            | "CopyChunks" >> beam.MapTuple(_copy_chunk)
            # prepare inputs for the next stage (if any)
            | "Finish" >> beam.Distinct()
        )


def _start_stage(
    target_id: str, specs_by_target: Mapping[str, Optional[CopySpec]],
) -> Iterator[Tuple[str, CopySpec]]:
    spec = specs_by_target[target_id]
    if spec is not None:
        yield target_id, spec


def _copy_tasks(
    target_id: str, spec: CopySpec
) -> Iterator[Tuple[str, Tuple[slice, ...], ReadableArray, WriteableArray]]:
    for key in chunk_keys(spec.source.shape, spec.chunks):
        yield target_id, key, spec.source, spec.target


def _copy_chunk(
    target_id: str,
    key: Tuple[slice, ...],
    source: ReadableArray,
    target: WriteableArray,
) -> str:
    target[key] = source[key]
    return target_id
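
Standalone, the new executor follows the two-step contract of rechunker's Executor type: build a plan, then run it. A minimal sketch of driving it directly, assuming ``specs`` is an iterable of StagedCopySpec produced by rechunker's planning step:

    # Hypothetical sketch: `specs` is assumed to come from rechunker's
    # planning step; keyword arguments to execute_plan are forwarded to
    # beam.Pipeline (e.g. a specific runner).
    from rechunker.executors.beam import BeamExecutor

    executor = BeamExecutor()
    plan = executor.prepare_plan(specs)  # a beam.PTransform
    executor.execute_plan(plan, runner="DirectRunner")  # runs `pipeline | plan`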
21 changes: 9 additions & 12 deletions rechunker/executors/python.py
@@ -1,10 +1,9 @@
-import itertools
 from functools import partial
-import math
 
-from typing import Callable, Iterable
+from typing import Callable, Iterable, Tuple
 
-from rechunker.types import CopySpec, StagedCopySpec, Executor
+from rechunker.executors.util import chunk_keys
+from rechunker.types import StagedCopySpec, Executor, ReadableArray, WriteableArray
 
 
 # PythonExecutor represents delayed execution tasks as functions that require
@@ -25,21 +24,19 @@ def prepare_plan(self, specs: Iterable[StagedCopySpec]) -> Task:
         tasks = []
         for staged_copy_spec in specs:
             for copy_spec in staged_copy_spec.stages:
-                tasks.append(partial(_direct_copy_array, copy_spec))
+                tasks.append(partial(_direct_copy_array, *copy_spec))
         return partial(_execute_all, tasks)
 
     def execute_plan(self, plan: Task):
         plan()
 
 
-def _direct_copy_array(copy_spec: CopySpec) -> None:
+def _direct_copy_array(
+    source: ReadableArray, target: WriteableArray, chunks: Tuple[int, ...]
+) -> None:
     """Direct copy between zarr arrays."""
-    source_array, target_array, chunks = copy_spec
-    shape = source_array.shape
-    ranges = [range(math.ceil(s / c)) for s, c in zip(shape, chunks)]
-    for indices in itertools.product(*ranges):
-        key = tuple(slice(c * i, c * (i + 1)) for i, c in zip(indices, chunks))
-        target_array[key] = source_array[key]
+    for key in chunk_keys(source.shape, chunks):
+        target[key] = source[key]
 
 
 def _execute_all(tasks: Iterable[Task]) -> None:
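Because the rewritten loop relies only on ``__getitem__``/``__setitem__`` with tuples of slices, any arrays satisfying that protocol will do. A quick in-memory check of the (module-private) helper, using NumPy purely for illustration:

    # Illustrative only: NumPy arrays satisfy the slice-based protocol
    # that _direct_copy_array expects, so the copy can be verified in memory.
    import numpy as np
    from rechunker.executors.python import _direct_copy_array

    source = np.arange(20).reshape(4, 5)
    target = np.empty_like(source)
    _direct_copy_array(source, target, (2, 2))  # copies in 2x2 blocks
    assert (target == source).all()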
21 changes: 21 additions & 0 deletions rechunker/executors/util.py
@@ -0,0 +1,21 @@
import itertools
import math

from typing import Iterator, Tuple


def chunk_keys(
    shape: Tuple[int, ...], chunks: Tuple[int, ...]
) -> Iterator[Tuple[slice, ...]]:
    """Iterate over array indexing keys of the desired chunk size.

    The union of all keys indexes every element of an array of shape ``shape``
    exactly once. Each array resulting from indexing is of shape ``chunks``,
    except possibly for the last arrays along each dimension (if ``chunks``
    does not evenly divide ``shape``).
    """
    ranges = [range(math.ceil(s / c)) for s, c in zip(shape, chunks)]
    for indices in itertools.product(*ranges):
        yield tuple(
            slice(c * i, min(c * (i + 1), s)) for i, s, c in zip(indices, shape, chunks)
        )
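
The ``min`` clamp is what keeps the final slice along each dimension from overrunning the array when ``chunks`` does not evenly divide ``shape``, for example:

    >>> from rechunker.executors.util import chunk_keys
    >>> list(chunk_keys(shape=(5,), chunks=(2,)))
    [(slice(0, 2),), (slice(2, 4),), (slice(4, 5),)]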
2 changes: 1 addition & 1 deletion setup.py
@@ -21,7 +21,7 @@
 ]
 
 extras_require = {
-    "complete": ["dask[array]", "zarr", "pyyaml", "fsspec"],
+    "complete": ["apache_beam", "dask[array]", "zarr", "pyyaml", "fsspec"],
     "docs": doc_requires,
 }
 extras_require["dev"] = extras_require["complete"] + [
30 changes: 27 additions & 3 deletions tests/test_rechunk.py
@@ -1,3 +1,5 @@
+from functools import partial
+import importlib
 import pytest
 
 import zarr
@@ -11,6 +13,20 @@
 _DIMENSION_KEY = "_ARRAY_DIMENSIONS"
 
 
+def requires_import(module, *args):
+    try:
+        importlib.import_module(module)
+    except ImportError:
+        skip = True
+    else:
+        skip = False
+    mark = pytest.mark.skipif(skip, reason=f"requires {module}")
+    return pytest.param(*args, marks=mark)
+
+
+requires_beam = partial(requires_import, "apache_beam")
+
+
 @pytest.fixture(params=[(8000, 200), {"y": 8000, "x": 200}])
 def target_chunks(request):
     return request.param
@@ -21,7 +37,7 @@ def target_chunks(request):
 @pytest.mark.parametrize("dtype", ["f4"])
 @pytest.mark.parametrize("max_mem", [25600000, "25.6MB"])
 @pytest.mark.parametrize(
-    "executor", ["dask", "python"],
+    "executor", ["dask", "python", requires_beam("beam")],
 )
 @pytest.mark.parametrize(
     "dims,target_chunks",
@@ -113,7 +129,10 @@ def test_rechunk_dask_array(
     assert dsa.equal(a_tar, 1).all().compute()
 
 
-def test_rechunk_group(tmp_path):
+@pytest.mark.parametrize(
+    "executor", ["dask", "python", requires_beam("beam")],
+)
+def test_rechunk_group(tmp_path, executor):
     store_source = str(tmp_path / "source.zarr")
     group = zarr.group(store_source)
     group.attrs["foo"] = "bar"
@@ -130,7 +149,12 @@ def test_rechunk_group(tmp_path):
     target_chunks = {"a": (5, 10, 4), "b": (20,)}
 
     rechunked = api.rechunk(
-        group, target_chunks, max_mem, target_store, temp_store=temp_store
+        group,
+        target_chunks,
+        max_mem,
+        target_store,
+        temp_store=temp_store,
+        executor=executor,
     )
     assert isinstance(rechunked, api.Rechunked)
 