#1161 WIP to vectorize isel_points (#1162)
* First draft of dask vindex enabled isel_points

WIP to use dask vindex for point-based selection

* remove old sel points

* completely re-worked logic

Work to get tests passing: completely reworked the logic to allow for a
wide range of Dataset formats and normal behaviour

* improve support for using array dim

* clean up comments

* Still trying to get semantics of dim and coords indexing right

* Tests for sel_points, isel_points working

* more tweaks, fixed tests after merge

* revert to original test, now working

* clean up coordinate construction

* further simplify coords construction

* Update tests to require generation of selection axis coordinate array

* revert change to test to reflect latest dataset behaviour

* fix coord generation to match latest dataset behaviour

Datasets no longer require/generate an index coordinate for every dimension

* cleanup

* Simplified code by just looping once over variables

* tidy up

* #1162 further improvements

use _replace_vars_and_dims, simplify new dataset creation, preserve
attributes, clarify dim vs dim_name (don’t re-use variable name to
reduce confusion)

* Formatting

* remove impl detail from docstr

* Add performance improvements section and short descr
mangecoeur authored and shoyer committed Jan 23, 2017
1 parent d5f4af5 commit 4bb630f
Showing 3 changed files with 89 additions and 29 deletions.
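
For context (an editor's sketch, not part of the commit): the speedup comes from pointwise indexing. In numpy, indexing with several integer arrays at once already selects one element per position; dask arrays expose the same semantics through their .vindex attribute, which is what the new code uses. A minimal illustration, assuming numpy and dask are installed:

import numpy as np
import dask.array as da

arr = np.arange(36).reshape(6, 6)
rows = np.array([1, 2, 3])
cols = np.array([4, 5, 1])

# numpy 'fancy' indexing is already pointwise: one element per
# (row, col) pair, giving a 1-D result of length 3
points_np = arr[rows, cols]  # array([10, 17, 19])

# dask arrays provide the same pointwise semantics via .vindex
darr = da.from_array(arr, chunks=(3, 3))
points_dask = darr.vindex[rows, cols].compute()

assert (points_np == points_dask).all()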
9 changes: 9 additions & 0 deletions doc/whats-new.rst
@@ -270,6 +270,15 @@ Bug fixes
- Fix to make ``.copy()`` actually copy dask arrays, which will be relevant for
future releases of dask in which dask arrays will be mutable (:issue:`1180`).

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~

- :py:meth:`~xarray.Dataset.isel_points` and
:py:meth:`~xarray.Dataset.sel_points` now use vectorised indexing in numpy
and dask (:issue:`1161`), which can give a speedup of several orders of
magnitude.
By `Jonathan Chambers <https://github.com/mangecoeur>`_.

.. _whats-new.0.8.2:

v0.8.2 (18 August 2016)
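A usage sketch of the entry above, mirroring a test case later in this diff: isel_points pulls one value per position across several dimensions at once, adding a new points dimension.

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': (('x', 'y'), np.arange(9).reshape(3, 3))})

# Pick the three diagonal elements in one vectorised call; the result
# gains a new 'points' dimension of length 3 in place of 'x' and 'y'
diag = ds.isel_points(x=[0, 1, 2], y=[0, 1, 2])
print(diag['foo'].values)  # [0 4 8]
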
101 changes: 76 additions & 25 deletions xarray/core/dataset.py
Expand Up @@ -29,6 +29,7 @@
from .pycompat import (iteritems, basestring, OrderedDict,
dask_array_type, range)
from .combine import concat
from .formatting import ensure_valid_repr
from .options import OPTIONS

# list of attributes of pd.DatetimeIndex that are ndarrays of time info
@@ -117,11 +118,11 @@ def calculate_dimensions(variables):


def merge_indexes(
indexes, # type: Dict[Any, Union[Any, List[Any]]]
variables, # type: Dict[Any, Variable]
coord_names, # type: Set
indexes, # type: Dict[Any, Union[Any, List[Any]]]
variables, # type: Dict[Any, Variable]
coord_names, # type: Set
append=False, # type: bool
):
):
# type: (...) -> Tuple[OrderedDict[Any, Variable], Set]
"""Merge variables into multi-indexes.
@@ -154,7 +155,7 @@ def merge_indexes(
names.append(n)
var = variables[n]
if (current_index_variable is not None and
var.dims != current_index_variable.dims):
var.dims != current_index_variable.dims):
raise ValueError(
"dimension mismatch between %r %s and %r %s"
% (dim, current_index_variable.dims, n, var.dims))
@@ -178,11 +179,11 @@

def split_indexes(
dims_or_levels, # type: Union[Any, List[Any]]
variables, # type: Dict[Any, Variable]
coord_names, # type: Set
level_coords, # type: Dict[Any, Any]
drop=False, # type: bool
):
variables, # type: Dict[Any, Variable]
coord_names, # type: Set
level_coords, # type: Dict[Any, Any]
drop=False, # type: bool
):
# type: (...) -> Tuple[OrderedDict[Any, Variable], Set]
"""Extract (multi-)indexes (levels) as variables.
@@ -712,10 +713,12 @@ def __delitem__(self, key):

def _all_compat(self, other, compat_str):
"""Helper function for equals and identical"""

# some stores (e.g., scipy) do not seem to preserve order, so don't
# require matching order for equality
def compat(x, y):
return getattr(x, compat_str)(y)

return (self._coord_names == other._coord_names and
utils.dict_equiv(self._variables, other._variables,
compat=compat))
@@ -1160,6 +1163,7 @@ def sel(self, method=None, tolerance=None, drop=False, **indexers):
return result._replace_indexes(new_indexes)

def isel_points(self, dim='points', **indexers):
# type: (...) -> Dataset
"""Returns a new dataset with each array indexed pointwise along the
specified dimension(s).
@@ -1200,16 +1204,27 @@ def isel_points(self, dim='points', **indexers):

indexer_dims = set(indexers)

def take(variable, slices):
# Note: remove this helper function once numpy
# supports vindex https://github.com/numpy/numpy/pull/6075
if hasattr(variable.data, 'vindex'):
# Special case for dask-backed arrays: use vectorised list indexing
sel = variable.data.vindex[slices]
else:
# Otherwise assume the backend is a numpy array with 'fancy' indexing
sel = variable.data[slices]
return sel

def relevant_keys(mapping):
return [k for k, v in mapping.items()
if any(d in indexer_dims for d in v.dims)]

data_vars = relevant_keys(self.data_vars)
coords = relevant_keys(self.coords)

# all the indexers should be iterables
keys = indexers.keys()
indexers = [(k, np.asarray(v)) for k, v in iteritems(indexers)]
indexers_dict = dict(indexers)
non_indexed_dims = set(self.dims) - indexer_dims

# All the indexers should be iterables
# Check that indexers are valid dims, integers, and 1D
for k, v in indexers:
if k not in self.dims:
@@ -1230,22 +1245,54 @@ def relevant_keys(mapping):
# dim is an invalid string
raise ValueError('Existing dimension names are not valid '
'choices for the dim argument in sel_points')

elif hasattr(dim, 'dims'):
# dim is a DataArray or Coordinate
if dim.name in self.dims:
# dim already exists
raise ValueError('Existing dimensions are not valid choices '
'for the dim argument in sel_points')

if not utils.is_scalar(dim) and not isinstance(dim, DataArray):
dim = as_variable(dim, name='points')
# Set the new dim_name, and optionally the new dim coordinate
# dim is either an array-like or a string
if not utils.is_scalar(dim):
# dim is array-like: take its name, or default to 'points', and coerce to a variable
dim_name = 'points' if not hasattr(dim, 'name') else dim.name
dim_coord = as_variable(dim, name=dim_name)
else:
# dim is a string
dim_name = dim
dim_coord = None

# TODO: This would be sped up with vectorized indexing. This will
# require dask to support pointwise indexing as well.
return concat([self.isel(**d) for d in
[dict(zip(keys, inds)) for inds in
zip(*[v for k, v in indexers])]],
dim=dim, coords=coords, data_vars=data_vars)
reordered = self.transpose(*(list(indexer_dims) + list(non_indexed_dims)))

variables = OrderedDict()

for name, var in reordered.variables.items():
if name in indexers_dict or any(d in indexer_dims for d in var.dims):
# slice if var is an indexer or depends on an indexed dim
slc = [indexers_dict[k]
if k in indexers_dict
else slice(None) for k in var.dims]

var_dims = [dim_name] + [d for d in var.dims
if d in non_indexed_dims]
selection = take(var, tuple(slc))
var_subset = type(var)(var_dims, selection, var.attrs)
variables[name] = var_subset
else:
# If not indexed just add it back to variables or coordinates
variables[name] = var

coord_names = set(coords) & set(variables)

dset = self._replace_vars_and_dims(variables, coord_names=coord_names)
# Add the dim coord to the new dset. Must be done after creation
# because _replace_vars_and_dims can only access existing coords,
# not add new ones
if dim_coord is not None:
dset.coords[dim_name] = dim_coord
return dset

def sel_points(self, dim='points', method=None, tolerance=None,
**indexers):
@@ -2004,8 +2051,8 @@ def reduce(self, func, dim=None, keep_attrs=False, numeric_only=False,
if reduce_dims or not var.dims:
if name not in self.coords:
if (not numeric_only or
np.issubdtype(var.dtype, np.number) or
var.dtype == np.bool_):
np.issubdtype(var.dtype, np.number) or
var.dtype == np.bool_):
if len(reduce_dims) == 1:
# unpack dimensions for the benefit of functions
# like np.argmin which can't handle tuple arguments
@@ -2284,6 +2331,7 @@ def func(self, *args, **kwargs):
if keep_attrs:
ds._attrs = self._attrs
return ds

return func

@staticmethod
@@ -2299,6 +2347,7 @@ def func(self, other):
ds = self._calculate_binary_op(g, other, join=align_type,
fillna=fillna)
return ds

return func

@staticmethod
@@ -2317,6 +2366,7 @@ def func(self, other):
self._replace_vars_and_dims(ds._variables, ds._coord_names,
attrs=ds._attrs, inplace=True)
return self

return func

def _calculate_binary_op(self, f, other, join='inner',
@@ -2721,8 +2771,9 @@ def filter_by_attrs(self, **kwargs):
for attr_name, pattern in kwargs.items():
attr_value = variable.attrs.get(attr_name)
if ((callable(pattern) and pattern(attr_value))
or attr_value == pattern):
or attr_value == pattern):
selection.append(var_name)
return self[selection]


ops.inject_all_ops_and_reduce_methods(Dataset, array_only=False)
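
To summarise the heart of the new isel_points: each variable is indexed with the integer arrays on its indexed dimensions and slice(None) on the rest, so non-indexed dimensions survive intact. A condensed sketch with a hypothetical helper name (the real method additionally transposes so that indexed dimensions come first, routes dask arrays through .vindex, and rebuilds coordinates):

import numpy as np

def pointwise_select(data, dims, indexers):
    # indexed dims get the integer arrays, all other dims stay whole
    slices = tuple(indexers.get(d, slice(None)) for d in dims)
    return data[slices]

data = np.arange(24).reshape(2, 3, 4)  # dims ('x', 'y', 'time')
indexers = {'x': np.array([0, 1]), 'y': np.array([2, 0])}
out = pointwise_select(data, ('x', 'y', 'time'), indexers)
print(out.shape)  # (2, 4): a points dim of length 2, 'time' kept whole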
8 changes: 4 additions & 4 deletions xarray/test/test_dataset.py
@@ -910,7 +910,6 @@ def test_isel_points(self):
pdim1 = [1, 2, 3]
pdim2 = [4, 5, 1]
pdim3 = [1, 2, 3]

actual = data.isel_points(dim1=pdim1, dim2=pdim2, dim3=pdim3,
dim='test_coord')
assert 'test_coord' in actual.dims
@@ -984,7 +983,6 @@ def test_sel_points(self):

# add in a range() index
data['dim1'] = data.dim1
print(data)

pdim1 = [1, 2, 3]
pdim2 = [4, 5, 1]
@@ -996,13 +994,15 @@
self.assertDatasetIdentical(expected, actual)

data = Dataset({'foo': (('x', 'y'), np.arange(9).reshape(3, 3))})
expected = Dataset({'foo': ('points', [0, 4, 8])})
expected = Dataset({'foo': ('points', [0, 4, 8])}
)
actual = data.sel_points(x=[0, 1, 2], y=[0, 1, 2])
self.assertDatasetIdentical(expected, actual)

data.coords.update({'x': [0, 1, 2], 'y': [0, 1, 2]})
expected.coords.update({'x': ('points', [0, 1, 2]),
'y': ('points', [0, 1, 2])})
'y': ('points', [0, 1, 2])
})
actual = data.sel_points(x=[0.1, 1.1, 2.5], y=[0, 1.2, 2.0],
method='pad')
self.assertDatasetIdentical(expected, actual)
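A usage sketch mirroring the final assertion above: label-based pointwise selection with inexact matches, where method='pad' rounds each requested label down to the nearest existing one.

import numpy as np
import xarray as xr

ds = xr.Dataset({'foo': (('x', 'y'), np.arange(9).reshape(3, 3))},
                coords={'x': [0, 1, 2], 'y': [0, 1, 2]})

# The off-grid labels resolve to the grid points (0, 0), (1, 1), (2, 2)
picked = ds.sel_points(x=[0.1, 1.1, 2.5], y=[0, 1.2, 2.0], method='pad')
print(picked['foo'].values)  # [0 4 8]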
