Skip to content

Commit

Permalink
WIP, ENH: signal array API support
Browse files Browse the repository at this point in the history
* related to scipygh-18286

- this only drafts array API support for `welch()`, not all of `signal`, and I think we may not have been keen on partial submodule support; nonetheless, initial feedback may be helpful before I expand more broadly (if that's the desired approach)
- this has been adjusted a bit from the initial draft that was used for the `welch()` benchmarks in the array API conference paper, because decisions and internal array API infrastructure changed in the last few weeks, so should double check that those performance improvements remain
- there are some pretty tricky shims in a few places, that are no doubt going to require some discussion

I did check locally that these pass:
- `SCIPY_DEVICE=cuda python dev.py test -b numpy -b pytorch -b cupy -- -k "TestWelch"`
- `SCIPY_DEVICE=cpu python dev.py test -b numpy -b pytorch -b cupy -- -k "TestWelch"`
- `python dev.py test -j 32`

[skip cirrus] [skip circle]
  • Loading branch information
tylerjereddy committed Mar 27, 2024
1 parent 25a256b commit 650b0a6
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 194 deletions.
49 changes: 32 additions & 17 deletions scipy/signal/_signaltools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
import timeit
import warnings

from scipy._lib._array_api import (
array_namespace, size,
)
from scipy.spatial import cKDTree
from . import _sigtools
from ._ltisys import dlti
Expand Down Expand Up @@ -1956,8 +1959,8 @@ def medfilt2d(input, kernel_size=3):
if kernel_size.shape == ():
kernel_size = np.repeat(kernel_size.item(), 2)

for size in kernel_size:
if (size % 2) != 1:
for ksize in kernel_size:
if (ksize % 2) != 1:
raise ValueError("Each element of kernel_size should be odd.")

return _sigtools._medfilt2d(image, kernel_size)
Expand Down Expand Up @@ -3560,20 +3563,28 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
0.06 # random
"""
xp = array_namespace(data)
if type not in ['linear', 'l', 'constant', 'c']:
raise ValueError("Trend type must be 'linear' or 'constant'.")
data = np.asarray(data)
dtype = data.dtype.char
if dtype not in 'dfDF':
dtype = 'd'
data = xp.asarray(data)
dtype = data.dtype
if data.dtype not in [xp.float32, xp.float64, xp.complex128, xp.complex64]:
dtype = xp.float64
if type in ['constant', 'c']:
ret = data - np.mean(data, axis, keepdims=True)
ret = data - xp.mean(xp.asarray(data, dtype=dtype), axis=axis, keepdims=True)
return ret
else:
dshape = data.shape
N = dshape[axis]
bp = np.sort(np.unique(np.concatenate(np.atleast_1d(0, bp, N))))
if np.any(bp > N):
if isinstance(bp, int):
bp = xp.sort(xp.unique(xp.asarray([0, bp, N])))
else:
bp = xp.asarray(bp)
new_bp = xp.empty(size(bp) + 1)
new_bp[:size(bp)] = bp
new_bp[size(bp)] = N
bp = xp.sort(xp.unique(xp.asarray([0] + new_bp)))
if xp.any(bp > N):
raise ValueError("Breakpoints must be less than length "
"of data along given axis.")

Expand All @@ -3582,28 +3593,32 @@ def detrend(data, axis=-1, type='linear', bp=0, overwrite_data=False):
rnk = len(dshape)
if axis < 0:
axis = axis + rnk
newdata = np.moveaxis(data, axis, 0)
newdata = xp.moveaxis(data, axis, 0)
newdata_shape = newdata.shape
newdata = newdata.reshape(N, -1)

if not overwrite_data:
newdata = newdata.copy() # make sure we have a copy
if newdata.dtype.char not in 'dfDF':
newdata = xp.asarray(newdata, copy=True) # make sure we have a copy
if newdata.dtype not in [xp.float64, xp.float32, xp.complex128, xp.complex64]:
newdata = newdata.astype(dtype)

# Nreg = len(bp) - 1
# Find leastsq fit and remove it for each piece
for m in range(len(bp) - 1):
Npts = bp[m + 1] - bp[m]
A = np.ones((Npts, 2), dtype)
A[:, 0] = np.arange(1, Npts + 1, dtype=dtype) / Npts
sl = slice(bp[m], bp[m + 1])
coef, resids, rank, s = linalg.lstsq(A, newdata[sl])
A = xp.ones((int(Npts), 2), dtype=dtype)
A[:, 0] = xp.arange(1, Npts + 1, dtype=dtype) / Npts
sl = slice(int(bp[m]), int(bp[m + 1]))
# NOTE: lstsq isn't in the array API standard
if "cupy" in xp.__name__ or "torch" in xp.__name__:
coef, resids, rank, s = xp.linalg.lstsq(A, newdata[sl], rcond=None)
else:
coef, resids, rank, s = linalg.lstsq(A, newdata[sl])
newdata[sl] = newdata[sl] - A @ coef

# Put data back in original shape.
newdata = newdata.reshape(newdata_shape)
ret = np.moveaxis(newdata, 0, axis)
ret = xp.moveaxis(newdata, 0, axis)
return ret


Expand Down
118 changes: 74 additions & 44 deletions scipy/signal/_spectral_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
"""
import numpy as np
from scipy import fft as sp_fft
from scipy._lib._array_api import (
array_namespace, size,
)
from . import _signaltools
from .windows import get_window
from ._spectral import _lombscargle
Expand Down Expand Up @@ -597,28 +600,38 @@ def csd(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None, nfft=None,
>>> plt.show()
"""
xp = array_namespace(x)
freqs, _, Pxy = _spectral_helper(x, y, fs, window, nperseg, noverlap,
nfft, detrend, return_onesided, scaling,
axis, mode='psd')

# Average over windows.
if len(Pxy.shape) >= 2 and Pxy.size > 0:
if len(Pxy.shape) >= 2 and size(Pxy) > 0:
if Pxy.shape[-1] > 1:
if average == 'median':
# np.median must be passed real arrays for the desired result
# xp.median must be passed real arrays for the desired result
bias = _median_bias(Pxy.shape[-1])
if np.iscomplexobj(Pxy):
Pxy = (np.median(np.real(Pxy), axis=-1)
+ 1j * np.median(np.imag(Pxy), axis=-1))
if Pxy.dtype in [xp.complex64, xp.complex128]:
Pxy = (xp.median(xp.real(Pxy), axis=-1)
+ 1j * xp.median(xp.imag(Pxy), axis=-1))
else:
Pxy = np.median(Pxy, axis=-1)
Pxy /= bias
Pxy = xp.median(Pxy, axis=-1)
# for PyTorch, Pxy is torch.return_types.median
# which is super confusing...
try:
device_pxy = xp.device(Pxy)
except AttributeError:
Pxy = Pxy.values
device_pxy = xp.device(Pxy)
bias = xp.asarray(bias)
bias = xp.to_device(bias, device_pxy)
Pxy = Pxy / bias
elif average == 'mean':
Pxy = Pxy.mean(axis=-1)
else:
raise ValueError(f'average must be "median" or "mean", got {average}')
else:
Pxy = np.reshape(Pxy, Pxy.shape[:-1])
Pxy = xp.reshape(Pxy, Pxy.shape[:-1])

return freqs, Pxy

Expand Down Expand Up @@ -1760,6 +1773,7 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
.. versionadded:: 0.16.0
"""
xp = array_namespace(x)
if mode not in ['psd', 'stft']:
raise ValueError("Unknown value for mode %s, must be one of: "
"{'psd', 'stft'}" % mode)
Expand All @@ -1782,13 +1796,15 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,

axis = int(axis)

# Ensure we have np.arrays, get outdtype
x = np.asarray(x)
# Ensure we have xp.arrays, get outdtype
x = xp.asarray(x)
# https://github.com/data-apis/array-api-compat/issues/43
tmp = xp.asarray([0], dtype=xp.complex64)
if not same_data:
y = np.asarray(y)
outdtype = np.result_type(x, y, np.complex64)
y = xp.asarray(y)
outdtype = xp.result_type(x, y, xp.complex64)
else:
outdtype = np.result_type(x, np.complex64)
outdtype = xp.result_type(x, tmp)

if not same_data:
# Check if we can broadcast the outer axes together
Expand All @@ -1797,36 +1813,36 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
xouter.pop(axis)
youter.pop(axis)
try:
outershape = np.broadcast(np.empty(xouter), np.empty(youter)).shape
outershape = xp.broadcast(xp.empty(xouter), xp.empty(youter)).shape
except ValueError as e:
raise ValueError('x and y cannot be broadcast together.') from e

if same_data:
if x.size == 0:
return np.empty(x.shape), np.empty(x.shape), np.empty(x.shape)
if size(x) == 0:
return xp.empty(x.shape), xp.empty(x.shape), xp.empty(x.shape)
else:
if x.size == 0 or y.size == 0:
if size(x) == 0 or size(y) == 0:
outshape = outershape + (min([x.shape[axis], y.shape[axis]]),)
emptyout = np.moveaxis(np.empty(outshape), -1, axis)
emptyout = xp.moveaxis(xp.empty(outshape), -1, axis)
return emptyout, emptyout, emptyout

if x.ndim > 1:
if axis != -1:
x = np.moveaxis(x, axis, -1)
x = xp.moveaxis(x, axis, -1)
if not same_data and y.ndim > 1:
y = np.moveaxis(y, axis, -1)
y = xp.moveaxis(y, axis, -1)

# Check if x and y are the same length, zero-pad if necessary
if not same_data:
if x.shape[-1] != y.shape[-1]:
if x.shape[-1] < y.shape[-1]:
pad_shape = list(x.shape)
pad_shape[-1] = y.shape[-1] - x.shape[-1]
x = np.concatenate((x, np.zeros(pad_shape)), -1)
x = xp.concatenate((x, xp.zeros(pad_shape)), -1)
else:
pad_shape = list(y.shape)
pad_shape[-1] = x.shape[-1] - y.shape[-1]
y = np.concatenate((y, np.zeros(pad_shape)), -1)
y = xp.concatenate((y, xp.zeros(pad_shape)), -1)

if nperseg is not None: # if specified by user
nperseg = int(nperseg)
Expand All @@ -1835,6 +1851,7 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,

# parse window; if array like, then set nperseg = win.shape
win, nperseg = _triage_segments(window, nperseg, input_length=x.shape[-1])
win = xp.asarray(win)

if nfft is None:
nfft = nperseg
Expand Down Expand Up @@ -1868,10 +1885,10 @@ def _spectral_helper(x, y, fs=1.0, window='hann', nperseg=None, noverlap=None,
# I.e make x.shape[-1] = nperseg + (nseg-1)*nstep, with integer nseg
nadd = (-(x.shape[-1]-nperseg) % nstep) % nperseg
zeros_shape = list(x.shape[:-1]) + [nadd]
x = np.concatenate((x, np.zeros(zeros_shape)), axis=-1)
x = xp.concatenate((x, xp.zeros(zeros_shape)), axis=-1)
if not same_data:
zeros_shape = list(y.shape[:-1]) + [nadd]
y = np.concatenate((y, np.zeros(zeros_shape)), axis=-1)
y = xp.concatenate((y, xp.zeros(zeros_shape)), axis=-1)

# Handle detrending and window functions
if not detrend:
Expand All @@ -1884,14 +1901,14 @@ def detrend_func(d):
# Wrap this function so that it receives a shape that it could
# reasonably expect to receive.
def detrend_func(d):
d = np.moveaxis(d, -1, axis)
d = xp.moveaxis(d, -1, axis)
d = detrend(d)
return np.moveaxis(d, axis, -1)
return xp.moveaxis(d, axis, -1)
else:
detrend_func = detrend

if np.result_type(win, np.complex64) != outdtype:
win = win.astype(outdtype)
if xp.result_type(win, xp.complex64) != outdtype:
win = xp.astype(win, outdtype)

if scaling == 'density':
scale = 1.0 / (fs * (win*win).sum())
Expand All @@ -1901,17 +1918,23 @@ def detrend_func(d):
raise ValueError('Unknown scaling: %r' % scaling)

if mode == 'stft':
scale = np.sqrt(scale)
scale = xp.sqrt(scale)

if return_onesided:
if np.iscomplexobj(x):
try:
is_complex = xp.iscomplexobj(x)
except AttributeError:
# torch shim
is_complex = xp.is_complex(x)

if is_complex:
sides = 'twosided'
warnings.warn('Input data is complex, switching to return_onesided=False',
stacklevel=3)
else:
sides = 'onesided'
if not same_data:
if np.iscomplexobj(y):
if xp.iscomplexobj(y):
sides = 'twosided'
warnings.warn('Input data is complex, switching to '
'return_onesided=False',
Expand All @@ -1920,9 +1943,9 @@ def detrend_func(d):
sides = 'twosided'

if sides == 'twosided':
freqs = sp_fft.fftfreq(nfft, 1/fs)
freqs = xp.fft.fftfreq(nfft, 1/fs)
elif sides == 'onesided':
freqs = sp_fft.rfftfreq(nfft, 1/fs)
freqs = xp.fft.rfftfreq(nfft, 1/fs)

# Perform the windowed FFTs
result = _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides)
Expand All @@ -1931,9 +1954,9 @@ def detrend_func(d):
# All the same operations on the y data
result_y = _fft_helper(y, win, detrend_func, nperseg, noverlap, nfft,
sides)
result = np.conjugate(result) * result_y
result = xp.conj(result) * result_y
elif mode == 'psd':
result = np.conjugate(result) * result
result = xp.conj(result) * result

result *= scale
if sides == 'onesided' and mode == 'psd':
Expand All @@ -1943,12 +1966,12 @@ def detrend_func(d):
# Last point is unpaired Nyquist freq point, don't double
result[..., 1:-1] *= 2

time = np.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1,
time = xp.arange(nperseg/2, x.shape[-1] - nperseg/2 + 1,
nperseg - noverlap)/float(fs)
if boundary is not None:
time -= (nperseg/2) / fs

result = result.astype(outdtype)
result = xp.astype(result, outdtype)

# All imaginary parts are zero anyways
if same_data and mode != 'stft':
Expand All @@ -1960,7 +1983,7 @@ def detrend_func(d):
axis -= 1

# Roll frequency axis back to axis where the data came from
result = np.moveaxis(result, -1, axis)
result = xp.moveaxis(result, -1, axis)

return freqs, time, result

Expand All @@ -1987,12 +2010,14 @@ def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides):
.. versionadded:: 0.16.0
"""
# Created sliding window view of array
xp = array_namespace(x)
x = xp.asarray(x)
# Created strided array of data segments
if nperseg == 1 and noverlap == 0:
result = x[..., np.newaxis]
result = x[..., xp.newaxis]
else:
step = nperseg - noverlap
result = np.lib.stride_tricks.sliding_window_view(
result = xp.lib.stride_tricks.sliding_window_view(
x, window_shape=nperseg, axis=-1, writeable=True
)
result = result[..., 0::step, :]
Expand All @@ -2001,14 +2026,18 @@ def _fft_helper(x, win, detrend_func, nperseg, noverlap, nfft, sides):
result = detrend_func(result)

# Apply window by multiplication
# NOTE: torch device shim -- needs
# deeper analysis
result_device = xp.device(result)
win = xp.to_device(win, result_device)
result = win * result

# Perform the fft. Acts on last axis by default. Zero-pads automatically
if sides == 'twosided':
func = sp_fft.fft
func = xp.fft.fft
else:
result = result.real
func = sp_fft.rfft
func = xp.fft.rfft
result = func(result, n=nfft)

return result
Expand Down Expand Up @@ -2059,7 +2088,8 @@ def _triage_segments(window, nperseg, input_length):
nperseg = input_length
win = get_window(window, nperseg)
else:
win = np.asarray(window)
xp = array_namespace(window)
win = xp.asarray(window)
if len(win.shape) != 1:
raise ValueError('window must be 1-D')
if input_length < win.shape[-1]:
Expand Down

0 comments on commit 650b0a6

Please sign in to comment.