From ca4999b91e34fbdd09746a9cddb674727ddf7f58 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Sat, 30 Jul 2016 20:53:42 -0700 Subject: [PATCH] Pickle xarray.ufunc functions Fixes GH901 --- doc/whats-new.rst | 3 +++ xarray/test/test_ufuncs.py | 7 +++++++ xarray/ufuncs.py | 22 ++++++++++++++-------- 3 files changed, 24 insertions(+), 8 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ea910b03637..0052f37b67d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -106,6 +106,9 @@ Bug fixes use to numpy functions instead of dask.array functions (:issue:`876`). By `Stephan Hoyer `_. +- Support for pickling functions from ``xarray.ufuncs`` (:issue:`901`). By + `Stephan Hoyer `_. + - ``Variable.copy(deep=True)`` no longer converts MultiIndex into a base Index (:issue:`769`). By `Benoit Bovy `_. diff --git a/xarray/test/test_ufuncs.py b/xarray/test/test_ufuncs.py index d0584765f00..6120b045fda 100644 --- a/xarray/test/test_ufuncs.py +++ b/xarray/test/test_ufuncs.py @@ -1,3 +1,5 @@ +import pickle + import numpy as np import xarray.ufuncs as xu @@ -56,3 +58,8 @@ def test_groupby(self): with self.assertRaisesRegexp(TypeError, 'only support binary ops'): xu.maximum(ds.a.variable, ds_grouped) + + def test_pickle(self): + a = 1.0 + cos_pickled = pickle.loads(pickle.dumps(xu.cos)) + self.assertIdentical(cos_pickled(a), xu.cos(a)) diff --git a/xarray/ufuncs.py b/xarray/ufuncs.py index a63c2a98bcc..489d0e4611e 100644 --- a/xarray/ufuncs.py +++ b/xarray/ufuncs.py @@ -35,32 +35,38 @@ def _dispatch_priority(obj): return -1 -def _create_op(name): +class _UFuncDispatcher(object): + """Wrapper for dispatching ufuncs.""" + def __init__(self, name): + self._name = name - def func(*args, **kwargs): + def __call__(self, *args, **kwargs): new_args = args - f = _dask_or_eager_func(name, n_array_args=len(args)) + f = _dask_or_eager_func(self._name, n_array_args=len(args)) if len(args) > 2 or len(args) == 0: raise TypeError('cannot handle %s arguments for %r' % - (len(args), name)) + (len(args), self._name)) elif len(args) == 1: if isinstance(args[0], _xarray_types): - f = args[0]._unary_op(func) + f = args[0]._unary_op(self) else: # len(args) = 2 p1, p2 = map(_dispatch_priority, args) if p1 >= p2: if isinstance(args[0], _xarray_types): - f = args[0]._binary_op(func) + f = args[0]._binary_op(self) else: if isinstance(args[1], _xarray_types): - f = args[1]._binary_op(func, reflexive=True) + f = args[1]._binary_op(self, reflexive=True) new_args = tuple(reversed(args)) res = f(*new_args, **kwargs) if res is NotImplemented: raise TypeError('%r not implemented for types (%r, %r)' - % (name, type(args[0]), type(args[1]))) + % (self._name, type(args[0]), type(args[1]))) return res + +def _create_op(name): + func = _UFuncDispatcher(name) func.__name__ = name doc = getattr(_np, name).__doc__ func.__doc__ = ('xarray specific variant of numpy.%s. Handles '