-
-
Notifications
You must be signed in to change notification settings - Fork 5.1k
/
_backend.py
191 lines (151 loc) · 6.24 KB
/
_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import scipy._lib.uarray as ua
from . import _fftlog
from . import _pocketfft
class _ScipyBackend:
"""The default backend for fft calculations
Notes
-----
We use the domain ``numpy.scipy`` rather than ``scipy`` because ``uarray``
treats the domain as a hierarchy. This means the user can install a single
backend for ``numpy`` and have it implement ``numpy.scipy.fft`` as well.
"""
__ua_domain__ = "numpy.scipy.fft"
@staticmethod
def __ua_function__(method, args, kwargs):
fn = getattr(_pocketfft, method.__name__, None)
if fn is None:
fn = getattr(_fftlog, method.__name__, None)
if fn is None:
return NotImplemented
return fn(*args, **kwargs)
_named_backends = {
'scipy': _ScipyBackend,
}
def _backend_from_arg(backend):
"""Maps strings to known backends and validates the backend"""
if isinstance(backend, str):
try:
backend = _named_backends[backend]
except KeyError as e:
raise ValueError(f'Unknown backend {backend}') from e
if backend.__ua_domain__ != 'numpy.scipy.fft':
raise ValueError('Backend does not implement "numpy.scipy.fft"')
return backend
def set_global_backend(backend, coerce=False, only=False, try_last=False):
"""Sets the global fft backend
This utility method replaces the default backend for permanent use. It
will be tried in the list of backends automatically, unless the
``only`` flag is set on a backend. This will be the first tried
backend outside the :obj:`set_backend` context manager.
Parameters
----------
backend : {object, 'scipy'}
The backend to use.
Can either be a ``str`` containing the name of a known backend
{'scipy'} or an object that implements the uarray protocol.
coerce : bool
Whether to coerce input types when trying this backend.
only : bool
If ``True``, no more backends will be tried if this fails.
Implied by ``coerce=True``.
try_last : bool
If ``True``, the global backend is tried after registered backends.
Raises
------
ValueError: If the backend does not implement ``numpy.scipy.fft``.
Notes
-----
This will overwrite the previously set global backend, which, by default, is
the SciPy implementation.
Examples
--------
We can set the global fft backend:
>>> from scipy.fft import fft, set_global_backend
>>> set_global_backend("scipy") # Sets global backend. "scipy" is the default backend.
>>> fft([1]) # Calls the global backend
array([1.+0.j])
"""
backend = _backend_from_arg(backend)
ua.set_global_backend(backend, coerce=coerce, only=only, try_last=try_last)
def register_backend(backend):
"""
Register a backend for permanent use.
Registered backends have the lowest priority and will be tried after the
global backend.
Parameters
----------
backend : {object, 'scipy'}
The backend to use.
Can either be a ``str`` containing the name of a known backend
{'scipy'} or an object that implements the uarray protocol.
Raises
------
ValueError: If the backend does not implement ``numpy.scipy.fft``.
Examples
--------
We can register a new fft backend:
>>> from scipy.fft import fft, register_backend, set_global_backend
>>> class NoopBackend: # Define an invalid Backend
... __ua_domain__ = "numpy.scipy.fft"
... def __ua_function__(self, func, args, kwargs):
... return NotImplemented
>>> set_global_backend(NoopBackend()) # Set the invalid backend as global
>>> register_backend("scipy") # Register a new backend
>>> fft([1]) # The registered backend is called because the global backend returns `NotImplemented`
array([1.+0.j])
>>> set_global_backend("scipy") # Restore global backend to default
"""
backend = _backend_from_arg(backend)
ua.register_backend(backend)
def set_backend(backend, coerce=False, only=False):
"""Context manager to set the backend within a fixed scope.
Upon entering the ``with`` statement, the given backend will be added to
the list of available backends with the highest priority. Upon exit, the
backend is reset to the state before entering the scope.
Parameters
----------
backend : {object, 'scipy'}
The backend to use.
Can either be a ``str`` containing the name of a known backend
{'scipy'} or an object that implements the uarray protocol.
coerce : bool, optional
Whether to allow expensive conversions for the ``x`` parameter. e.g.,
copying a NumPy array to the GPU for a CuPy backend. Implies ``only``.
only : bool, optional
If only is ``True`` and this backend returns ``NotImplemented``, then a
BackendNotImplemented error will be raised immediately. Ignoring any
lower priority backends.
Examples
--------
>>> import scipy.fft as fft
>>> with fft.set_backend('scipy', only=True):
... fft.fft([1]) # Always calls the scipy implementation
array([1.+0.j])
"""
backend = _backend_from_arg(backend)
return ua.set_backend(backend, coerce=coerce, only=only)
def skip_backend(backend):
"""Context manager to skip a backend within a fixed scope.
Within the context of a ``with`` statement, the given backend will not be
called. This covers backends registered both locally and globally. Upon
exit, the backend will again be considered.
Parameters
----------
backend : {object, 'scipy'}
The backend to skip.
Can either be a ``str`` containing the name of a known backend
{'scipy'} or an object that implements the uarray protocol.
Examples
--------
>>> import scipy.fft as fft
>>> fft.fft([1]) # Calls default SciPy backend
array([1.+0.j])
>>> with fft.skip_backend('scipy'): # We explicitly skip the SciPy backend
... fft.fft([1]) # leaving no implementation available
Traceback (most recent call last):
...
BackendNotImplementedError: No selected backends had an implementation ...
"""
backend = _backend_from_arg(backend)
return ua.skip_backend(backend)
set_global_backend('scipy', try_last=True)