Skip to content

Commit

Permalink
ENH: Improve fftn performance
Browse files Browse the repository at this point in the history
This function is a bad use case for numpy arrays because the sizes are normally
very small. Moving over to list comprehensions gives a nice speed improvement.
  • Loading branch information
peterbell10 authored and larsoner committed Jul 19, 2019
1 parent 7f0f667 commit 2f34eaf
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 43 deletions.
4 changes: 3 additions & 1 deletion scipy/fft/_pocketfft/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
import numpy as np
import functools
from scipy.fft._pocketfft import pypocketfft as pfft
from scipy.fftpack.helper import _init_nd_shape_and_axes
from scipy.fftpack.helper import (_init_nd_shape_and_axes_impl
as _init_nd_shape_and_axes)


# TODO: Build with OpenMp and add configuration support
_default_workers = 1


def _asfarray(x):
"""
Convert to array with floating or complex dtype.
Expand Down
12 changes: 6 additions & 6 deletions scipy/fft/_pocketfft/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,12 +747,12 @@ def test_shape_argument_more(self):
def test_invalid_sizes(self):
with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[1 0\]\) specified"):
r" \(\[1, 0\]\) specified"):
fftn([[]])

with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[ 4 -3\]\) specified"):
r" \(\[4, -3\]\) specified"):
fftn([[1, 1], [2, 2]], (4, -3))

def test_no_axes(self):
Expand Down Expand Up @@ -794,12 +794,12 @@ def test_random_complex(self, maxnlp, size):
def test_invalid_sizes(self):
with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[1 0\]\) specified"):
r" \(\[1, 0\]\) specified"):
ifftn([[]])

with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[ 4 -3\]\) specified"):
r" \(\[4, -3\]\) specified"):
ifftn([[1, 1], [2, 2]], (4, -3))

def test_no_axes(self):
Expand Down Expand Up @@ -839,12 +839,12 @@ def test_random(self, size):
def test_invalid_sizes(self, func):
with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[1 0\]\) specified"):
r" \(\[1, 0\]\) specified"):
func([[]])

with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[ 4 -3\]\) specified"):
r" \(\[4, -3\]\) specified"):
func([[1, 1], [2, 2]], (4, -3))

@pytest.mark.parametrize('func', [rfftn, irfftn])
Expand Down
68 changes: 36 additions & 32 deletions scipy/fftpack/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from numpy import (arange, array, asarray, atleast_1d, intc, integer,
isscalar, issubdtype, take, unique, where)
from numpy.fft.helper import fftshift, ifftshift, fftfreq
import numpy as np
from bisect import bisect_left

__all__ = ['fftshift', 'ifftshift', 'fftfreq', 'rfftfreq', 'next_fast_len']
Expand Down Expand Up @@ -184,53 +185,56 @@ def _init_nd_shape_and_axes(x, shape, axes):
The shape of the result. It is a 1D integer array.
"""
x = asarray(x)
s, a = _init_nd_shape_and_axes_impl(np.asarray(x), shape, axes)
return np.asarray(s), np.asarray(a)


def _init_nd_shape_and_axes_impl(x, shape, axes):
"""Implementation of _init_nd_shape_and_axes"""
noshape = shape is None
noaxes = axes is None

if noaxes:
axes = arange(x.ndim, dtype=intc)
axes = range(x.ndim)
else:
axes = atleast_1d(axes)
axes = np.atleast_1d(axes)

if axes.size == 0:
axes = axes.astype(intc)
if axes.size == 0:
axes = axes.astype(np.intc)

if not axes.ndim == 1:
raise ValueError("when given, axes values must be a scalar or vector")
if not issubdtype(axes.dtype, integer):
raise ValueError("when given, axes values must be integers")
if not axes.ndim == 1:
raise ValueError("when given, axes values must be a scalar or vector")
if not np.issubdtype(axes.dtype, np.integer):
raise ValueError("when given, axes values must be integers")

axes = where(axes < 0, axes + x.ndim, axes)
axes = [a + x.ndim if a < 0 else a for a in axes]

if axes.size != 0 and (axes.max() >= x.ndim or axes.min() < 0):
raise ValueError("axes exceeds dimensionality of input")
if axes.size != 0 and unique(axes).shape != axes.shape:
raise ValueError("all axes must be unique")
if any(a >= x.ndim or a < 0 for a in axes):
raise ValueError("axes exceeds dimensionality of input")
if len(set(axes)) != len(axes):
raise ValueError("all axes must be unique")

if not noshape:
shape = atleast_1d(shape)
elif isscalar(x):
shape = array([], dtype=intc)
elif noaxes:
shape = array(x.shape, dtype=intc)
else:
shape = take(x.shape, axes)
shape = np.atleast_1d(shape)

if shape.size == 0:
shape = shape.astype(intc)
if shape.size == 0:
shape = shape.astype(np.intc)

if shape.ndim != 1:
raise ValueError("when given, shape values must be a scalar or vector")
if not issubdtype(shape.dtype, integer):
raise ValueError("when given, shape values must be integers")
if axes.shape != shape.shape:
raise ValueError("when given, axes and shape arguments"
" have to be of the same length")
if shape.ndim != 1:
raise ValueError("when given, shape values must be a scalar or vector")
if not np.issubdtype(shape.dtype, np.integer):
raise ValueError("when given, shape values must be integers")
if len(axes) != len(shape):
raise ValueError("when given, axes and shape arguments"
" have to be of the same length")

shape = where(shape == -1, array(x.shape)[axes], shape)
shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)]
elif noaxes:
shape = list(x.shape)
else:
shape = [x.shape[a] for a in axes]

if shape.size != 0 and (shape < 1).any():
if any(s < 1 for s in shape):
raise ValueError(
"invalid number of data points ({0}) specified".format(shape))

Expand Down
8 changes: 4 additions & 4 deletions scipy/fftpack/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,12 +748,12 @@ def test_shape_argument_more(self):
def test_invalid_sizes(self):
with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[1 0\]\) specified"):
r" \(\[1, 0\]\) specified"):
fftn([[]])

with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[ 4 -3\]\) specified"):
r" \(\[4, -3\]\) specified"):
fftn([[1, 1], [2, 2]], (4, -3))


Expand Down Expand Up @@ -791,12 +791,12 @@ def test_random_complex(self, maxnlp, size):
def test_invalid_sizes(self):
with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[1 0\]\) specified"):
r" \(\[1, 0\]\) specified"):
ifftn([[]])

with assert_raises(ValueError,
match="invalid number of data points"
r" \(\[ 4 -3\]\) specified"):
r" \(\[4, -3\]\) specified"):
ifftn([[1, 1], [2, 2]], (4, -3))


Expand Down

0 comments on commit 2f34eaf

Please sign in to comment.