Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Dask-proxy #22

Merged
merged 3 commits into from
Nov 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ install_requires =
lxml
numpy
typing-extensions
wrapt
python_requires = >=3.7,<3.10
include_package_data = True
package_dir =
Expand Down
89 changes: 89 additions & 0 deletions src/nd2/_dask_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
"""
experimental Dask array proxy for file IO
"""
from __future__ import annotations

from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Callable

import dask.array as da
import numpy as np
from wrapt import ObjectProxy

if TYPE_CHECKING:
    from typing import Protocol

    class CheckableContext(Protocol):
        """Structural type: a context manager that also exposes ``closed``."""

        @property
        def closed(self) -> bool:
            """Whether the context is currently closed."""
            ...

        def __enter__(self):
            """Enter the context."""
            ...

        def __exit__(self, *a):
            """Leave the context."""
            ...


class DaskArrayProxy(ObjectProxy):
    """Dask array wrapper that provides a 'file open' context when computing.

    This is necessary when the dask array contains a delayed underlying reader
    function that requires the file to be open.  We don't want to open/close the
    file on every single chunk, so this wraps the `compute` and `__array__`
    methods in an open-file context manager.  For `__getattr__`, callable
    attributes are returned as an `_ArrayMethodProxy` that again wraps the
    resulting object in a `DaskArrayProxy` if it is a dask array.

    Experimental!
    The state held by the `file_ctx` may be problematic for dask distributed.

    Parameters
    ----------
    wrapped : da.Array
        the dask array that requires some file
    file_ctx : ContextManager
        A context in which the file is open.
        IMPORTANT: the context must be reusable, and preferably re-entrant:
        https://docs.python.org/3/library/contextlib.html#reentrant-context-managers
    """

    __wrapped__: da.Array

    def __init__(self, wrapped: da.Array, file_ctx: CheckableContext) -> None:
        super().__init__(wrapped)
        # wrapt forwards ordinary attribute assignment to the *wrapped* object;
        # only names prefixed with `_self_` are stored on the proxy itself.
        # Keeping the context on the proxy avoids mutating the dask array and
        # avoids routing every context lookup through `__getattr__` below.
        self._self_file_ctx = file_ctx

    @property
    def _file_ctx(self) -> CheckableContext:
        """Read-only alias of the file context stored on this proxy."""
        return self._self_file_ctx

    def __getitem__(self, key: Any) -> DaskArrayProxy:
        """Index the wrapped array and re-wrap the result with the same context."""
        return DaskArrayProxy(self.__wrapped__.__getitem__(key), self._self_file_ctx)

    def __getattr__(self, key: Any) -> Any:
        """Look up `key` on the wrapped array.

        Callable attributes are returned wrapped in an `_ArrayMethodProxy`;
        everything else is returned unchanged.
        """
        attr = getattr(self.__wrapped__, key)
        if callable(attr):
            return _ArrayMethodProxy(attr, self._self_file_ctx)
        return attr

    def __repr__(self) -> str:
        """Return the repr of the wrapped dask array."""
        return repr(self.__wrapped__)

    def compute(self, **kwargs: Any) -> np.ndarray:
        """Compute the wrapped array, entering the file context if it is closed."""
        ctx = self._self_file_ctx
        # only enter the context when it reports closed; otherwise leave the
        # already-open file alone so that we do not close it on exit.
        with ctx if ctx.closed else nullcontext():
            return self.__wrapped__.compute(**kwargs)

    def __array__(self, dtype: Any = None, **kwargs: Any) -> np.ndarray:
        """Convert to a numpy array, entering the file context if it is closed."""
        ctx = self._self_file_ctx
        with ctx if ctx.closed else nullcontext():
            return self.__wrapped__.__array__(dtype, **kwargs)


class _ArrayMethodProxy:
    """Callable wrapper around a single dask-array method.

    Calling it invokes the wrapped method (entering the file context first when
    that context reports closed) and re-wraps the result in a `DaskArrayProxy`
    when the result is itself a dask array.
    """

    def __init__(self, method: Callable, file_ctx: CheckableContext) -> None:
        self._file_ctx = file_ctx
        self.method = method

    def __repr__(self) -> str:
        # mirror the repr of the underlying method
        return repr(self.method)

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        ctx = self._file_ctx
        guard = ctx if ctx.closed else nullcontext()
        with guard:
            result = self.method(*args, **kwds)

        if not isinstance(result, da.Array):
            return result
        return DaskArrayProxy(result, ctx)
2 changes: 2 additions & 0 deletions src/nd2/_sdk/latest.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,8 @@ cdef class ND2Reader:
"""Read a chunk directly without using SDK"""
if index > self._max_safe:
raise IndexError(f"Frame out of range: {index}")
if not self._is_open:
raise ValueError("Attempt to read from closed nd2 file")
offset = self._frame_map["safe"].get(index, None)
if offset is None:
return self._missing_frame(index)
Expand Down
7 changes: 6 additions & 1 deletion src/nd2/nd2file.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,14 @@ def to_dask(self, copy=True) -> da.Array:
"""
from dask.array import map_blocks

from ._dask_proxy import DaskArrayProxy

chunks = [(1,) * x for x in self._coord_shape]
chunks += [(x,) for x in self._frame_shape]
return map_blocks(self._dask_block, copy, chunks=chunks, dtype=self.dtype)
dask_arr = map_blocks(self._dask_block, copy, chunks=chunks, dtype=self.dtype)
# this proxy allows the dask array to re-open the underlying
# nd2 file on compute.
return DaskArrayProxy(dask_arr, self)

_NO_IDX = -1

Expand Down
5 changes: 5 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@
NEW.append(x) if is_new_format(str(x)) else OLD.append(x)


@pytest.fixture
def single_nd2():
    """Provide the path of one fixed nd2 sample file under ``DATA``."""
    sample_name = "dims_rgb_t3p2c2z3x64y64.nd2"
    return DATA / sample_name


@pytest.fixture(params=ALL, ids=lambda item: item.name)
def any_nd2(request):
    """Parametrized fixture returning each entry of ``ALL``, labelled by its ``name``."""
    current = request.param
    return current
Expand Down
115 changes: 115 additions & 0 deletions tests/test_dask_proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
import warnings
from typing import Any, Tuple

import dask.array as da
import numpy as np
import pytest
from nd2._dask_proxy import DaskArrayProxy


# a *re-entrant* file context manager
class FileContext:
    """Fake file handle used by the tests.

    Tracks whether it is currently "open" and how many times it went from
    closed to open.
    """

    FILE_OPEN = False
    OPEN_COUNT = 0

    @property
    def closed(self):
        """True whenever the fake file is not open."""
        return not self.FILE_OPEN

    def __enter__(self) -> "FileContext":
        # only a closed -> open transition counts as an "open"
        if self.closed:
            self.OPEN_COUNT = self.OPEN_COUNT + 1
        self.FILE_OPEN = True
        return self

    def __exit__(self, *args: Any) -> None:
        # leaving the context always marks the fake file as closed
        self.FILE_OPEN = False


@pytest.fixture
def dask_arr() -> da.Array:
    """Build a chunked dask array whose chunk loader warns if the file is not open.

    The returned array carries two extra attributes: ``called`` (a one-element
    list counting chunk loads) and ``ctx`` (the `FileContext` the loader checks).
    """
    called = [0]
    ctx = FileContext()

    def get_chunk(block_id: Tuple[int, ...]) -> np.ndarray:
        if ctx.closed:
            warnings.warn("You didn't open the file!")

        # calls that receive an ndarray as block_id are not counted
        if not isinstance(block_id, np.ndarray):
            called[0] += 1
        frame = np.arange(100).reshape(10, 10)
        return frame[np.newaxis, np.newaxis]

    chunk_layout = ((1,) * 10, (1,) * 10, 10, 10)
    arr = da.map_blocks(get_chunk, chunks=chunk_layout, dtype=float)
    arr.called = called
    arr.ctx = ctx
    return arr


@pytest.fixture
def proxy(dask_arr: da.Array) -> DaskArrayProxy:
    """Wrap the ``dask_arr`` fixture in a `DaskArrayProxy` using its own ``ctx``."""
    file_ctx = dask_arr.ctx
    return DaskArrayProxy(dask_arr, file_ctx)


def test_array(dask_arr: da.Array) -> None:
    """The bare dask array warns unless computed inside the file context."""
    with pytest.warns(UserWarning):
        dask_arr.compute()

    # computing without the context never opened the file
    assert dask_arr.ctx.OPEN_COUNT == 0

    with dask_arr.ctx:
        computed = dask_arr.compute()
        assert computed.shape == (10, 10, 10, 10)

    assert dask_arr.ctx.OPEN_COUNT == 1


def test_proxy_compute(proxy: DaskArrayProxy) -> None:
    """`proxy.compute()` opens the context exactly once for all 100 chunks."""
    assert proxy.ctx.OPEN_COUNT == 0

    result = proxy.compute()
    assert isinstance(result, np.ndarray)
    assert result.shape == (10, 10, 10, 10)

    assert proxy.ctx.OPEN_COUNT == 1
    chunk_loads = proxy.__wrapped__.called[0]
    assert chunk_loads == 100


def test_proxy_asarray(proxy: DaskArrayProxy) -> None:
    """`np.asarray(proxy)` opens the context exactly once for all 100 chunks."""
    assert proxy.ctx.OPEN_COUNT == 0

    result = np.asarray(proxy)
    assert isinstance(result, np.ndarray)
    assert result.shape == (10, 10, 10, 10)

    assert proxy.ctx.OPEN_COUNT == 1
    chunk_loads = proxy.__wrapped__.called[0]
    assert chunk_loads == 100


def test_proxy_getitem(dask_arr: da.Array, proxy: DaskArrayProxy) -> None:
    """Indexing the proxy yields another proxy holding the same data."""
    # mark the fake file as open so computing the plain array does not warn
    dask_arr.ctx.FILE_OPEN = True

    index = (0, slice(1, 3))
    plain = dask_arr[index]
    wrapped = proxy[index]

    assert isinstance(plain, da.Array)
    assert isinstance(wrapped, DaskArrayProxy)
    np.testing.assert_array_equal(plain.compute(), wrapped.compute())


def test_proxy_methods(dask_arr: da.Array, proxy: DaskArrayProxy) -> None:
    """Array-returning methods give proxies; other methods give plain results."""
    proxy_mean = proxy.mean()
    assert isinstance(proxy_mean, DaskArrayProxy)
    assert isinstance(proxy_mean.compute(), float)

    # the plain array's compute is what is expected to emit the warning here
    with pytest.warns(UserWarning):
        via_proxy = proxy_mean.compute()
        via_array = dask_arr.mean().compute()
        assert via_proxy == via_array

    # non array-returning methods don't return proxies
    svg = proxy.to_svg()
    assert isinstance(svg, str)


def test_proxy_ufunc(dask_arr: da.Array, proxy: DaskArrayProxy) -> None:
    """`np.mean` on the proxy returns a proxy with the same computed value."""
    array_mean = np.mean(dask_arr)
    proxied_mean = np.mean(proxy)

    assert isinstance(array_mean, da.Array)
    assert isinstance(proxied_mean, DaskArrayProxy)

    # mark the fake file as open before computing both sides
    dask_arr.ctx.FILE_OPEN = True
    assert array_mean.compute() == proxied_mean.compute()


def test_proxy_repr(dask_arr: da.Array, proxy: DaskArrayProxy) -> None:
    """The proxy and its method proxies repr like the objects they wrap."""
    array_repr = repr(dask_arr)
    assert array_repr == repr(proxy)

    method_repr = repr(dask_arr.mean)
    assert method_repr == repr(proxy.mean)
6 changes: 6 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def test_dask(new_nd2):
assert arr.shape == nd.shape[-2:]


def test_dask_closed(single_nd2):
    """A dask array created from an ND2File can be computed to a numpy array."""
    with ND2File(single_nd2) as nd:
        dsk = nd.to_dask()
    # NOTE(review): indentation was lost in the source; the test name suggests
    # the compute is meant to run after the `with` block has exited - confirm.
    computed = dsk.compute()
    assert isinstance(computed, np.ndarray)


@pytest.mark.skipif(bool(os.getenv("CIBUILDWHEEL")), reason="slow")
def test_full_read(new_nd2):
with ND2File(new_nd2) as nd:
Expand Down