Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify replacement of fftpack by pyfftw #5295

Merged
merged 2 commits into from Oct 10, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions scipy/fftpack/tests/test_import.py
@@ -0,0 +1,33 @@
"""Test possibility of patching fftpack with pyfftw.

No module source outside of scipy.fftpack should contain an import of
the form `from scipy.fftpack import ...`, so that a simple replacement
of scipy.fftpack by the corresponding fftw interface completely swaps
the two FFT implementations.

Because this simply inspects source files, we only need to run the test
on one version of Python.
"""


import sys
if sys.version_info >= (3, 4):
from pathlib import Path
import re
from numpy.testing import TestCase, assert_, run_module_suite
import scipy

class TestFFTPackImport(TestCase):
def test_fftpack_import(self):
base = Path(scipy.__file__).parent
regexp = r"\s*from.+\.fftpack import .*\n"
for path in base.rglob("*.py"):
if base / "fftpack" in path.parents:
continue
with path.open() as file:
assert_(all(not re.fullmatch(regexp, line)
for line in file),
"{} contains an import from fftpack".format(path))

if __name__ == "__main__":
run_module_suite(argv=sys.argv)
4 changes: 2 additions & 2 deletions scipy/linalg/tests/test_special_matrices.py
Expand Up @@ -10,11 +10,11 @@

from scipy._lib.six import xrange

from scipy import fftpack
from scipy.special import comb
from scipy.linalg import (toeplitz, hankel, circulant, hadamard, leslie,
companion, tri, triu, tril, kron, block_diag,
helmert, hilbert, invhilbert, pascal, invpascal, dft)
from scipy.fftpack import fft
from numpy.linalg import cond


Expand Down Expand Up @@ -576,7 +576,7 @@ def test_dft():
x = array([0, 1, 2, 3, 4, 5, 0, 1])
m = dft(8)
mx = m.dot(x)
fx = fft(x)
fx = fftpack.fft(x)
yield (assert_array_almost_equal, mx, fx)


Expand Down
29 changes: 14 additions & 15 deletions scipy/signal/signaltools.py
Expand Up @@ -9,10 +9,7 @@
from . import sigtools
from scipy._lib.six import callable
from scipy._lib._version import NumpyVersion
from scipy import linalg
from scipy.fftpack import (fft, ifft, ifftshift, fft2, ifft2, fftn,
ifftn, fftfreq)
from numpy.fft import rfftn, irfftn
from scipy import fftpack, linalg
from numpy import (allclose, angle, arange, argsort, array, asarray,
atleast_1d, atleast_2d, cast, dot, exp, expand_dims,
iscomplexobj, mean, ndarray, newaxis, ones, pi,
Expand Down Expand Up @@ -357,8 +354,9 @@ def fftconvolve(in1, in2, mode="full"):
# sure we only call rfftn/irfftn from one thread at a time.
if not complex_result and (_rfft_mt_safe or _rfft_lock.acquire(False)):
try:
ret = irfftn(rfftn(in1, fshape) *
rfftn(in2, fshape), fshape)[fslice].copy()
ret = (np.fft.irfftn(np.fft.rfftn(in1, fshape) *
np.fft.rfftn(in2, fshape), fshape)[fslice].
copy())
finally:
if not _rfft_mt_safe:
_rfft_lock.release()
Expand All @@ -367,7 +365,8 @@ def fftconvolve(in1, in2, mode="full"):
# failed to acquire _rfft_lock (meaning rfftn isn't threadsafe and
# is already in use by another thread). In either case, use the
# (threadsafe but slower) SciPy complex-FFT routines instead.
ret = ifftn(fftn(in1, fshape) * fftn(in2, fshape))[fslice].copy()
ret = fftpack.ifftn(fftpack.fftn(in1, fshape) *
fftpack.fftn(in2, fshape))[fslice].copy()
if not complex_result:
ret = ret.real

Expand Down Expand Up @@ -1181,7 +1180,7 @@ def hilbert(x, N=None, axis=-1):
if N <= 0:
raise ValueError("N must be positive.")

Xf = fft(x, N, axis=axis)
Xf = fftpack.fft(x, N, axis=axis)
h = zeros(N)
if N % 2 == 0:
h[0] = h[N // 2] = 1
Expand All @@ -1194,7 +1193,7 @@ def hilbert(x, N=None, axis=-1):
ind = [newaxis] * x.ndim
ind[axis] = slice(None)
h = h[ind]
x = ifft(Xf * h, axis=axis)
x = fftpack.ifft(Xf * h, axis=axis)
return x


Expand Down Expand Up @@ -1235,7 +1234,7 @@ def hilbert2(x, N=None):
raise ValueError("When given as a tuple, N must hold exactly "
"two positive integers")

Xf = fft2(x, N, axes=(0, 1))
Xf = fftpack.fft2(x, N, axes=(0, 1))
h1 = zeros(N[0], 'd')
h2 = zeros(N[1], 'd')
for p in range(2):
Expand All @@ -1254,7 +1253,7 @@ def hilbert2(x, N=None):
while k > 2:
h = h[:, newaxis]
k -= 1
x = ifft2(Xf * h, axes=(0, 1))
x = fftpack.ifft2(Xf * h, axes=(0, 1))
return x


Expand Down Expand Up @@ -1708,17 +1707,17 @@ def resample(x, num, t=None, axis=0, window=None):
>>> plt.show()
"""
x = asarray(x)
X = fft(x, axis=axis)
X = fftpack.fft(x, axis=axis)
Nx = x.shape[axis]
if window is not None:
if callable(window):
W = window(fftfreq(Nx))
W = window(fftpack.fftfreq(Nx))
elif isinstance(window, ndarray):
if window.shape != (Nx,):
raise ValueError('window must have the same length as data')
W = window
else:
W = ifftshift(get_window(window, Nx))
W = fftpack.ifftshift(get_window(window, Nx))
newshape = [1] * x.ndim
newshape[axis] = len(W)
W.shape = newshape
Expand All @@ -1732,7 +1731,7 @@ def resample(x, num, t=None, axis=0, window=None):
Y[sl] = X[sl]
sl[axis] = slice(-(N - 1) // 2, None)
Y[sl] = X[sl]
y = ifft(Y, axis=axis) * (float(num) / float(Nx))
y = fftpack.ifft(Y, axis=axis) * (float(num) / float(Nx))

if x.dtype.char not in ['F', 'D']:
y = y.real
Expand Down
7 changes: 3 additions & 4 deletions scipy/signal/windows.py
Expand Up @@ -4,8 +4,7 @@
import warnings

import numpy as np
from scipy import special, linalg
from scipy.fftpack import fft
from scipy import fftpack, linalg, special
from scipy._lib.six import string_types

__all__ = ['boxcar', 'triang', 'parzen', 'bohman', 'blackman', 'nuttall',
Expand Down Expand Up @@ -1324,13 +1323,13 @@ def chebwin(M, at, sym=True):
# Appropriate IDFT and filling up
# depending on even/odd M
if M % 2:
w = np.real(fft(p))
w = np.real(fftpack.fft(p))
n = (M + 1) // 2
w = w[:n]
w = np.concatenate((w[n - 1:0:-1], w))
else:
p = p * np.exp(1.j * np.pi / M * np.r_[0:M])
w = np.real(fft(p))
w = np.real(fftpack.fft(p))
n = M // 2 + 1
w = np.concatenate((w[n - 1:0:-1], w[1:n]))
w = w / max(w)
Expand Down