diff --git a/.github/TEST_FAIL_TEMPLATE.md b/.github/TEST_FAIL_TEMPLATE.md new file mode 100644 index 0000000..3512972 --- /dev/null +++ b/.github/TEST_FAIL_TEMPLATE.md @@ -0,0 +1,12 @@ +--- +title: "{{ env.TITLE }}" +labels: [bug] +--- +The {{ workflow }} workflow failed on {{ date | date("YYYY-MM-DD HH:mm") }} UTC + +The most recent failing test was on {{ env.PLATFORM }} py{{ env.PYTHON }} +with commit: {{ sha }} + +Full run: https://github.com/{{ repo }}/actions/runs/{{ env.RUN_ID }} + +(This post will be updated if another test fails, as long as this issue remains open.) diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000..96505a9 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +# https://docs.github.com/github/administering-a-repository/configuration-options-for-dependency-updates + +version: 2 +updates: + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "weekly" + commit-message: + prefix: "ci(dependabot):" diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b2df549..cd4f303 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -2,77 +2,98 @@ name: CI on: push: - branches: - - main - tags: - - "v*" - pull_request: {} + branches: [main] + tags: [v*] + pull_request: workflow_dispatch: + schedule: + # run every week (for --pre release tests) + - cron: "0 0 * * 0" + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true jobs: + check-manifest: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - run: pipx run check-manifest + test: name: ${{ matrix.platform }} (${{ matrix.python-version }}) runs-on: ${{ matrix.platform }} strategy: fail-fast: false matrix: - python-version: [3.8, 3.9, '3.10'] - platform: [ubuntu-latest] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] + platform: [ubuntu-latest, macos-latest, windows-latest] steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - - name: Set up miniconda ${{ matrix.python-version }} - uses: conda-incubator/setup-miniconda@v2 + - name: ๐Ÿ Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - mamba-version: "*" - channels: conda-forge - channel-priority: true + cache-dependency-path: "pyproject.toml" + cache: "pip" - - name: Install dependencies - shell: bash -l {0} + - name: Install Dependencies run: | - python -m pip install --upgrade pip - mamba install pyopencl ocl-icd-system pocl scipy numpy - pip install -e .[testing] + python -m pip install -U pip + # if running a cron job, we add the --pre flag to test against pre-releases + python -m pip install .[test] ${{ github.event_name == 'schedule' && '--pre' || '' }} + + - name: ๐Ÿงช Run Tests + run: pytest --color=yes --cov --cov-report=xml --cov-report=term-missing - - name: Test - shell: bash -l {0} - run: pytest -v --color yes + - name: ๐Ÿ“ Report --pre Failures + if: failure() && github.event_name == 'schedule' + uses: JasonEtco/create-an-issue@v2 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + PLATFORM: ${{ matrix.platform }} + PYTHON: ${{ matrix.python-version }} + RUN_ID: ${{ github.run_id }} + TITLE: "[test-bot] pip install --pre is failing" + with: + filename: .github/TEST_FAIL_TEMPLATE.md + update_existing: true - name: Coverage - uses: codecov/codecov-action@v2 + uses: codecov/codecov-action@v3 deploy: name: Deploy needs: test - if: "success() && startsWith(github.ref, 'refs/tags/')" + if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule' runs-on: ubuntu-latest + permissions: + id-token: write + contents: write + steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@v4 + with: + fetch-depth: 0 - - name: Set up Python - uses: actions/setup-python@v2 + - name: ๐Ÿ Set up Python + uses: actions/setup-python@v4 with: python-version: "3.x" - - name: install + - name: ๐Ÿ‘ท Build run: | - git tag - pip install --upgrade pip - pip install -U build twine + python -m pip install build python -m build - twine check dist/* - ls -lh dist - - name: Build and publish - run: twine upload dist/* - env: - TWINE_USERNAME: __token__ - TWINE_PASSWORD: ${{ secrets.TWINE_API_KEY }} + - name: ๐Ÿšข Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 - uses: softprops/action-gh-release@v1 with: generate_release_notes: true + files: "./dist/*" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fab9ea8..b5da3f0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,38 +1,30 @@ +# enable pre-commit.ci at https://pre-commit.ci/ +# it adds: +# 1. auto fixing pull requests +# 2. auto updating the pre-commit configuration +ci: + autoupdate_schedule: monthly + autofix_commit_msg: "style(pre-commit.ci): auto fixes [...]" + autoupdate_commit_msg: "ci(pre-commit.ci): autoupdate" + repos: - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.1.0 - hooks: - - id: check-docstring-first - - id: end-of-file-fixer - - id: trailing-whitespace - - repo: https://github.com/asottile/setup-cfg-fmt - rev: v1.20.1 - hooks: - - id: setup-cfg-fmt - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 - hooks: - - id: flake8 - additional_dependencies: [flake8-typing-imports] - - repo: https://github.com/myint/autoflake - rev: v1.4 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort - - repo: https://github.com/psf/black - rev: 22.3.0 - hooks: - - id: black - - repo: https://github.com/asottile/pyupgrade - rev: v2.31.1 - hooks: - - id: pyupgrade - args: [--py37-plus] - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.942 - hooks: - - id: mypy + - repo: https://github.com/abravalheri/validate-pyproject + rev: v0.16 + hooks: + - id: validate-pyproject + + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: v0.4.1 + hooks: + - id: ruff + args: [--fix, --unsafe-fixes] + - id: ruff-format + + # - repo: https://github.com/pre-commit/mirrors-mypy + # rev: v1.9.0 + # hooks: + # - id: mypy + # files: "^src/" + # # # you have to add the things you want to type check against here + # # additional_dependencies: + # # - numpy diff --git a/LICENSE b/LICENSE index f4035c2..628c8c2 100644 --- a/LICENSE +++ b/LICENSE @@ -1,8 +1,6 @@ +BSD 3-Clause License -BSD License - -Copyright (c) 2021, Talley Lambert -All rights reserved. +Copyright (c) 2023, Talley Lambert Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: diff --git a/MANIFEST.in b/MANIFEST.in deleted file mode 100644 index a0bc85f..0000000 --- a/MANIFEST.in +++ /dev/null @@ -1,8 +0,0 @@ -include LICENSE -include README.md - -recursive-include tests * -recursive-exclude * __pycache__ -recursive-exclude * *.py[co] - -recursive-include docs *.md conf.py Makefile make.bat *.jpg *.png *.gif diff --git a/README.md b/README.md index 20cfdf2..c6036d7 100644 --- a/README.md +++ b/README.md @@ -6,8 +6,11 @@ [![tests](https://github.com/tlambert03/anyfft/workflows/tests/badge.svg)](https://github.com/tlambert03/anyfft/actions) [![codecov](https://codecov.io/gh/tlambert03/anyfft/branch/master/graph/badge.svg)](https://codecov.io/gh/tlambert03/anyfft) -anyfft is a thin compatibility layer that wraps various python FFT implementations in the `scipy.fft` API (if they do not already provide it). my motiviation is to have a package that I can install anywhere that will take advantage of the available hardware (CUDA, OpenCL, CPU fallback) without refactoring. - +anyfft is a thin compatibility layer that wraps various python FFT +implementations in the `scipy.fft` API (if they do not already provide it). my +motivation is to have a package that I can install anywhere that will take +advantage of the available hardware (CUDA, OpenCL, CPU fallback) without +refactoring. ```python import anyfft @@ -19,7 +22,7 @@ anyfft.fft(array, plugin='reikna') anyfft.reikna.fft(array) ``` -### current backends: +## current backends - numpy (`np.fft`) - scipy (`scipy.fft`) @@ -28,7 +31,7 @@ anyfft.reikna.fft(array) - reikna (OpenCL FFT) - cupy (CUDA FFT) -### available functions: +## available functions fft, fft2, fftn, ifft, ifft2, ifftn, fftshift, ifftshift, irfft (reikna WIP), irfft2 (reikna WIP), irfftn (reikna WIP), rfft, rfft2, rfftn diff --git a/anyfft/_backend.py b/anyfft/_backend.py deleted file mode 100644 index 23f0da2..0000000 --- a/anyfft/_backend.py +++ /dev/null @@ -1,23 +0,0 @@ -from . import reikna - - -class ReiknaBackend: - """The default backend for fft calculations - - Notes - ----- - We use the domain ``numpy.scipy`` rather than ``scipy`` because in the - future, ``uarray`` will treat the domain as a hierarchy. This means the user - can install a single backend for ``numpy`` and have it implement - ``numpy.scipy.fft`` as well. - """ - - __ua_domain__ = "numpy.scipy.fft" - - @staticmethod - def __ua_function__(method, args, kwargs): - fn = getattr(reikna, method.__name__, None) - - if fn is None: - return NotImplemented - return fn(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..63af39c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,123 @@ +# https://peps.python.org/pep-0517/ +[build-system] +requires = ["hatchling", "hatch-vcs"] +build-backend = "hatchling.build" + +# https://hatch.pypa.io/latest/config/metadata/ +[tool.hatch.version] +source = "vcs" + +# read more about configuring hatch at: +# https://hatch.pypa.io/latest/config/build/ +[tool.hatch.build.targets.wheel] +only-include = ["src"] +sources = ["src"] + +# https://peps.python.org/pep-0621/ +[project] +name = "anyfft" +dynamic = ["version"] +description = "wraps various FFT implementations with single interface" +readme = "README.md" +requires-python = ">=3.8" +license = { text = "BSD-3-Clause" } +authors = [{ name = "Talley Lambert", email = "talley.lambert@example.com" }] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: BSD License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Typing :: Typed", +] +dependencies = ["numpy", "reikna"] + +# https://peps.python.org/pep-0621/#dependencies-optional-dependencies +[project.optional-dependencies] +ocl = ["pyopencl"] +test = [ + "anyfft[ocl]", + "pytest", + "pytest-cov", + "pytest-benchmark", + "scipy", + "pooch", +] +dev = ["anyfft[test]", "ipython", "mypy", "pdbpp", "pre-commit", "rich", "ruff"] + +[project.urls] +homepage = "https://github.com/tlambert03/anyfft" +repository = "https://github.com/tlambert03/anyfft" + +# https://docs.astral.sh/ruff +[tool.ruff] +line-length = 88 +target-version = "py38" +src = ["src"] + +# https://docs.astral.sh/ruff/rules +[tool.ruff.lint] +pydocstyle = { convention = "numpy" } +select = [ + "E", # style errors + "W", # style warnings + "F", # flakes + # "D", # pydocstyle + "D417", # Missing argument descriptions in Docstrings + "I", # isort + "UP", # pyupgrade + "C4", # flake8-comprehensions + "B", # flake8-bugbear + "A001", # flake8-builtins + "RUF", # ruff-specific rules + "TCH", # flake8-type-checking + "TID", # flake8-tidy-imports +] +ignore = [ + "D401", # First line should be in imperative mood (remove to opt in) +] + +[tool.ruff.lint.per-file-ignores] +"tests/*.py" = ["D", "S"] + +# https://docs.astral.sh/ruff/formatter/ +[tool.ruff.format] +docstring-code-format = true +skip-magic-trailing-comma = false # default is false + +# https://mypy.readthedocs.io/en/stable/config_file.html +[tool.mypy] +files = "src/**/" +strict = true +disallow_any_generics = false +disallow_subclassing_any = false +show_error_codes = true +pretty = true + +# https://docs.pytest.org/ +[tool.pytest.ini_options] +minversion = "7.0" +testpaths = ["tests"] +filterwarnings = ["error"] + +# https://coverage.readthedocs.io/ +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "pragma: no cover", + "if TYPE_CHECKING:", + "@overload", + "except ImportError", + "\\.\\.\\.", + "raise NotImplementedError()", + "pass", +] + +[tool.coverage.run] +source = ["anyfft"] + +[tool.check-manifest] +ignore = [".pre-commit-config.yaml", ".ruff_cache/**/*", "tests/**/*"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 7efaf2d..0000000 --- a/setup.cfg +++ /dev/null @@ -1,92 +0,0 @@ -[metadata] -name = anyfft -description = wraps various FFT implementations with single interface -long_description = file: README.md -long_description_content_type = text/markdown -url = https://github.com/tlambert03/anyfft -author = Talley Lambert -author_email = talley.lambert@gmail.com -license = BSD-3-Clause -license_file = LICENSE -classifiers = - Development Status :: 2 - Pre-Alpha - License :: OSI Approved :: BSD License - Natural Language :: English - Programming Language :: Python :: 3 - Programming Language :: Python :: 3 :: Only - Programming Language :: Python :: 3.7 - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 -project_urls = - Source Code =https://github.com/tlambert03/anyfft - -[options] -packages = find: -install_requires = - numpy - reikna -python_requires = >=3.7 -setup_requires = - setuptools-scm -zip_safe = False - -[options.extras_require] -dev = - black - flake8 - flake8-docstrings - ipython - isort - jedi<0.18.0 - mypy - pre-commit - pydocstyle -ocl = - pyopencl<2021.2.7 -testing = - pyfftw - pyopencl<2021.2.7 - pytest - pytest-benchmark - pytest-cov - scipy - tox - tox-conda - -[bdist_wheel] -universal = 1 - -[flake8] -exclude = docs,_version.py,.eggs,examples -max-line-length = 88 -docstring-convention = numpy -ignore = D100, D213, D401, D413, D107, W503 - -[isort] -profile = black -src_paths = anyfft - -[pydocstyle] -match_dir = anyfft -convention = numpy -add_select = D402,D415,D417 -ignore = D100, D213, D401, D413, D107 - -[tool:pytest] -addopts = --benchmark-autosave --benchmark-columns=min,max,mean,stddev,rounds -filterwarnings = - error::: - ignore:`np.bool` is a deprecated alias::reikna - ignore:`np.bool` is a deprecated alias::pycuda - ignore:Non-empty compiler output encountered:: - -[mypy] -files = anyfft -warn_unused_configs = True -warn_unused_ignores = True -check_untyped_defs = True -implicit_reexport = False -show_column_numbers = True -show_error_codes = True -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 0277681..0000000 --- a/setup.py +++ /dev/null @@ -1,3 +0,0 @@ -import setuptools - -setuptools.setup(use_scm_version={"write_to": "anyfft/_version.py"}) diff --git a/anyfft/__init__.py b/src/anyfft/__init__.py similarity index 63% rename from anyfft/__init__.py rename to src/anyfft/__init__.py index 54738cf..7385f14 100644 --- a/anyfft/__init__.py +++ b/src/anyfft/__init__.py @@ -1,3 +1,13 @@ +"""wrapper of various fft libraries, with standard interface.""" + +from importlib.metadata import PackageNotFoundError, version + +try: + __version__ = version("anyfft") +except PackageNotFoundError: + # package is not installed + __version__ = "unknown" + from ._fft import ( # fftfreq,; rfftfreq, fft, fft2, @@ -34,5 +44,3 @@ "ifftshift", "ReiknaBackend", ] - -from ._backend import ReiknaBackend diff --git a/anyfft/_fft.py b/src/anyfft/_fft.py similarity index 69% rename from anyfft/_fft.py rename to src/anyfft/_fft.py index 80d657c..15c656e 100644 --- a/anyfft/_fft.py +++ b/src/anyfft/_fft.py @@ -17,10 +17,10 @@ def _get_fft_module(plugin): try: return importlib.import_module(f"anyfft.{plugin}") - except KeyError: + except KeyError as e: raise ValueError( f"unrecognized plugin: {plugin!r}. Options include: {set(_PLUGINS)}" - ) + ) from e def _implements(scipy_func): @@ -29,80 +29,66 @@ def _inner(*args, plugin="reikna", **kwargs): f = getattr(_get_fft_module(plugin), func.__name__) return f(*args, **kwargs) - _inner.__doc__ == scipy_func.__doc__ + _inner.__doc__ = scipy_func.__doc__ return _inner return plugin_func @_implements(scipy.fft.fft) -def fft(*args, **kwargs): - ... +def fft(*args, **kwargs): ... @_implements(scipy.fft.ifft) -def ifft(*args, **kwargs): - ... +def ifft(*args, **kwargs): ... @_implements(scipy.fft.fft2) -def fft2(*args, **kwargs): - ... +def fft2(*args, **kwargs): ... @_implements(scipy.fft.ifft2) -def ifft2(*args, **kwargs): - ... +def ifft2(*args, **kwargs): ... @_implements(scipy.fft.fftn) -def fftn(*args, **kwargs): - ... +def fftn(*args, **kwargs): ... @_implements(scipy.fft.ifftn) -def ifftn(*args, **kwargs): - ... +def ifftn(*args, **kwargs): ... @_implements(scipy.fft.rfft) -def rfft(*args, **kwargs): - ... +def rfft(*args, **kwargs): ... @_implements(scipy.fft.irfft) -def irfft(*args, **kwargs): - ... +def irfft(*args, **kwargs): ... @_implements(scipy.fft.rfft2) -def rfft2(*args, **kwargs): - ... +def rfft2(*args, **kwargs): ... @_implements(scipy.fft.irfft2) -def irfft2(*args, **kwargs): - ... +def irfft2(*args, **kwargs): ... @_implements(scipy.fft.rfftn) -def rfftn(*args, **kwargs): - ... +def rfftn(*args, **kwargs): ... @_implements(scipy.fft.irfftn) -def irfftn(*args, **kwargs): - ... +def irfftn(*args, **kwargs): ... @_implements(scipy.fft.fftshift) -def fftshift(*args, **kwargs): - ... +def fftshift(*args, **kwargs): ... @_implements(scipy.fft.ifftshift) -def ifftshift(*args, **kwargs): - ... +def ifftshift(*args, **kwargs): ... # @_implements(scipy.fft.fftfreq) diff --git a/src/anyfft/_fftconv.py b/src/anyfft/_fftconv.py new file mode 100644 index 0000000..3e9fb22 --- /dev/null +++ b/src/anyfft/_fftconv.py @@ -0,0 +1,347 @@ +import operator +from numbers import Number + +import numpy as np +import scipy.fft as sp_fft + + +def fftconvolve(in1, in2, mode="full", axes=None): + """Convolve two N-dimensional arrays using FFT. + + Convolve `in1` and `in2` using the fast Fourier transform method, with + the output size determined by the `mode` argument. + + This is generally much faster than `convolve` for large arrays (n > ~500), + but can be slower when only a few output values are needed, and can only + output float arrays (int or object array inputs will be cast to float). + + As of v0.19, `convolve` automatically chooses this method or the direct + method based on an estimation of which is faster. + + Parameters + ---------- + in1 : array_like + First input. + in2 : array_like + Second input. Should have the same number of dimensions as `in1`. + mode : str {'full', 'valid', 'same'}, optional + A string indicating the size of the output: + + ``full`` + The output is the full discrete linear convolution + of the inputs. (Default) + ``valid`` + The output consists only of those elements that do not + rely on the zero-padding. In 'valid' mode, either `in1` or `in2` + must be at least as large as the other in every dimension. + ``same`` + The output is the same size as `in1`, centered + with respect to the 'full' output. + axes : int or array_like of ints or None, optional + Axes over which to compute the convolution. + The default is over all axes. + + Returns + ------- + out : array + An N-dimensional array containing a subset of the discrete linear + convolution of `in1` with `in2`. + + See Also + -------- + convolve : Uses the direct convolution or FFT convolution algorithm + depending on which is faster. + oaconvolve : Uses the overlap-add method to do convolution, which is + generally faster when the input arrays are large and + significantly different in size. + + Examples + -------- + Autocorrelation of white noise is an impulse. + + >>> import numpy as np + >>> from scipy import signal + >>> rng = np.random.default_rng() + >>> sig = rng.standard_normal(1000) + >>> autocorr = signal.fftconvolve(sig, sig[::-1], mode="full") + + >>> import matplotlib.pyplot as plt + >>> fig, (ax_orig, ax_mag) = plt.subplots(2, 1) + >>> ax_orig.plot(sig) + >>> ax_orig.set_title("White noise") + >>> ax_mag.plot(np.arange(-len(sig) + 1, len(sig)), autocorr) + >>> ax_mag.set_title("Autocorrelation") + >>> fig.tight_layout() + >>> fig.show() + + Gaussian blur implemented using FFT convolution. Notice the dark borders + around the image, due to the zero-padding beyond its boundaries. + The `convolve2d` function allows for other types of image boundaries, + but is far slower. + + >>> from scipy import datasets + >>> face = datasets.face(gray=True) + >>> kernel = np.outer( + ... signal.windows.gaussian(70, 8), signal.windows.gaussian(70, 8) + ... ) + >>> blurred = signal.fftconvolve(face, kernel, mode="same") + + >>> fig, (ax_orig, ax_kernel, ax_blurred) = plt.subplots(3, 1, figsize=(6, 15)) + >>> ax_orig.imshow(face, cmap="gray") + >>> ax_orig.set_title("Original") + >>> ax_orig.set_axis_off() + >>> ax_kernel.imshow(kernel, cmap="gray") + >>> ax_kernel.set_title("Gaussian kernel") + >>> ax_kernel.set_axis_off() + >>> ax_blurred.imshow(blurred, cmap="gray") + >>> ax_blurred.set_title("Blurred") + >>> ax_blurred.set_axis_off() + >>> fig.show() + + """ + # in1 = np.asarray(in1) + # in2 = np.asarray(in2) + + if in1.ndim == in2.ndim == 0: # scalar inputs + return in1 * in2 + elif in1.ndim != in2.ndim: + raise ValueError("in1 and in2 should have the same dimensionality") + elif in1.size == 0 or in2.size == 0: # empty arrays + return np.array([]) + + in1, in2, axes = _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False) + + s1 = in1.shape + s2 = in2.shape + + shape = [ + max((s1[i], s2[i])) if i not in axes else s1[i] + s2[i] - 1 + for i in range(in1.ndim) + ] + + ret = _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=True) + + return _apply_conv_mode(ret, s1, s2, mode, axes) + + +def _apply_conv_mode(ret, s1, s2, mode, axes): + """Calculate the convolution result shape based on the `mode` argument. + + Returns the result sliced to the correct size for the given mode. + + Parameters + ---------- + ret : array + The result array, with the appropriate shape for the 'full' mode. + s1 : list of int + The shape of the first input. + s2 : list of int + The shape of the second input. + mode : str {'full', 'valid', 'same'} + A string indicating the size of the output. + See the documentation `fftconvolve` for more information. + axes : list of ints + Axes over which to compute the convolution. + + Returns + ------- + ret : array + A copy of `res`, sliced to the correct size for the given `mode`. + + """ + if mode == "full": + return ret + elif mode == "same": + return _centered(ret, s1) + elif mode == "valid": + shape_valid = [ + ret.shape[a] if a not in axes else s1[a] - s2[a] + 1 + for a in range(ret.ndim) + ] + return _centered(ret, shape_valid) + else: + raise ValueError("acceptable mode flags are 'valid'," " 'same', or 'full'") + + +def _centered(arr, newshape): + # Return the center newshape portion of the array. + newshape = np.asarray(newshape) + currshape = np.array(arr.shape) + startind = (currshape - newshape) // 2 + endind = startind + newshape + myslice = [slice(startind[k], endind[k]) for k in range(len(endind))] + return arr[tuple(myslice)] + + +def _freq_domain_conv(in1, in2, axes, shape, calc_fast_len=False): + """Convolve two arrays in the frequency domain. + + This function implements only base the FFT-related operations. + Specifically, it converts the signals to the frequency domain, multiplies + them, then converts them back to the time domain. Calculations of axes, + shapes, convolution mode, etc. are implemented in higher level-functions, + such as `fftconvolve` and `oaconvolve`. Those functions should be used + instead of this one. + + Parameters + ---------- + in1 : array_like + First input. + in2 : array_like + Second input. Should have the same number of dimensions as `in1`. + axes : array_like of ints + Axes over which to compute the FFTs. + shape : array_like of ints + The sizes of the FFTs. + calc_fast_len : bool, optional + If `True`, set each value of `shape` to the next fast FFT length. + Default is `False`, use `axes` as-is. + + Returns + ------- + out : array + An N-dimensional array containing the discrete linear convolution of + `in1` with `in2`. + + """ + if not len(axes): + return in1 * in2 + + if hasattr(in1.dtype, "kind"): + complex_result = in1.dtype.kind == "c" or in2.dtype.kind == "c" + else: + complex_result = in1.dtype.is_complex + + if calc_fast_len: + # Speed up FFT by padding to optimal size. + fshape = [sp_fft.next_fast_len(shape[a], not complex_result) for a in axes] + else: + fshape = shape + + if not complex_result: + fft, ifft = sp_fft.rfftn, sp_fft.irfftn + else: + fft, ifft = sp_fft.fftn, sp_fft.ifftn + + sp1 = fft(in1, fshape, axes=axes) + sp2 = fft(in2, fshape, axes=axes) + + if sp1.shape != sp2.shape: + breakpoint() + ret = ifft(sp1 * sp2, fshape, axes=axes) + + if calc_fast_len: + fslice = tuple([slice(sz) for sz in shape]) + ret = ret[fslice] + + return ret + + +def _init_freq_conv_axes(in1, in2, mode, axes, sorted_axes=False): + """Handle the axes argument for frequency-domain convolution. + + Returns the inputs and axes in a standard form, eliminating redundant axes, + swapping the inputs if necessary, and checking for various potential + errors. + + Parameters + ---------- + in1 : array + First input. + in2 : array + Second input. + mode : str {'full', 'valid', 'same'}, optional + A string indicating the size of the output. + See the documentation `fftconvolve` for more information. + axes : list of ints + Axes over which to compute the FFTs. + sorted_axes : bool, optional + If `True`, sort the axes. + Default is `False`, do not sort. + + Returns + ------- + in1 : array + The first input, possible swapped with the second input. + in2 : array + The second input, possible swapped with the first input. + axes : list of ints + Axes over which to compute the FFTs. + + """ + s1 = in1.shape + s2 = in2.shape + noaxes = axes is None + + _, axes = _init_nd_shape_and_axes(in1, shape=None, axes=axes) + + if not noaxes and not len(axes): + raise ValueError("when provided, axes cannot be empty") + + # Axes of length 1 can rely on broadcasting rules for multiply, + # no fft needed. + axes = [a for a in axes if s1[a] != 1 and s2[a] != 1] + + if sorted_axes: + axes.sort() + + if not all( + s1[a] == s2[a] or s1[a] == 1 or s2[a] == 1 + for a in range(in1.ndim) + if a not in axes + ): + raise ValueError("incompatible shapes for in1 and in2:" f" {s1} and {s2}") + + return in1, in2, axes + + +def _init_nd_shape_and_axes(x, shape, axes): + """Handles shape and axes arguments for nd transforms""" + noshape = shape is None + noaxes = axes is None + + if not noaxes: + axes = _iterable_of_int(axes, "axes") + axes = [a + x.ndim if a < 0 else a for a in axes] + + if any(a >= x.ndim or a < 0 for a in axes): + raise ValueError("axes exceeds dimensionality of input") + if len(set(axes)) != len(axes): + raise ValueError("all axes must be unique") + + if not noshape: + shape = _iterable_of_int(shape, "shape") + + if axes and len(axes) != len(shape): + raise ValueError( + "when given, axes and shape arguments" " have to be of the same length" + ) + if noaxes: + if len(shape) > x.ndim: + raise ValueError("shape requires more axes than are present") + axes = range(x.ndim - len(shape), x.ndim) + + shape = [x.shape[a] if s == -1 else s for s, a in zip(shape, axes)] + elif noaxes: + shape = list(x.shape) + axes = range(x.ndim) + else: + shape = [x.shape[a] for a in axes] + + if any(s < 1 for s in shape): + raise ValueError(f"invalid number of data points ({shape}) specified") + + return tuple(shape), list(axes) + + +def _iterable_of_int(x, name=None): + if isinstance(x, Number): + x = (x,) + + try: + x = [operator.index(a) for a in x] + except TypeError as e: + name = name or "value" + raise ValueError(f"{name} must be a scalar or iterable of integers") from e + + return x diff --git a/src/anyfft/backends/__init__.py b/src/anyfft/backends/__init__.py new file mode 100644 index 0000000..07f01d9 --- /dev/null +++ b/src/anyfft/backends/__init__.py @@ -0,0 +1,5 @@ +from ._jax import JaxBackend +from ._reikna import ReiknaBackend +from ._torch import TorchBackend + +__all__ = ["ReiknaBackend", "TorchBackend", "JaxBackend"] diff --git a/src/anyfft/backends/_base.py b/src/anyfft/backends/_base.py new file mode 100644 index 0000000..29ee2ee --- /dev/null +++ b/src/anyfft/backends/_base.py @@ -0,0 +1,20 @@ +import importlib +from types import FunctionType +from typing import Any + + +class _FFTBackend: + _source_: str = "anyfft.reikna" + __ua_domain__ = "numpy.scipy.fft" + + @classmethod + def __ua_function__(cls, method: FunctionType, args: Any, kwargs: Any) -> Any: + module = importlib.import_module(cls._source_) + if fn := getattr(module, method.__name__, None): + return cls.execute(fn, *args, **kwargs) + breakpoint() + return NotImplemented + + @classmethod + def execute(cls, func, *args: Any, **kwargs: Any) -> Any: + return func(*args, **kwargs) diff --git a/src/anyfft/backends/_jax.py b/src/anyfft/backends/_jax.py new file mode 100644 index 0000000..0ebe085 --- /dev/null +++ b/src/anyfft/backends/_jax.py @@ -0,0 +1,21 @@ +from typing import Any + +import jax.numpy as jnp + +from ._base import _FFTBackend + + +class JaxBackend(_FFTBackend): + _source_: str = "jax.numpy.fft" + + @classmethod + def __ua_convert__(cls, dispatchables, coerce): + if coerce: + return [jnp.asarray(d.value) for d in dispatchables] + return NotImplemented + + @classmethod + def execute(cls, func, *args: Any, **kwargs: Any) -> Any: + # if "axes" in kwargs: + # kwargs["dim"] = kwargs.pop("axes") + return func(*args, **kwargs) diff --git a/src/anyfft/backends/_reikna.py b/src/anyfft/backends/_reikna.py new file mode 100644 index 0000000..96915ae --- /dev/null +++ b/src/anyfft/backends/_reikna.py @@ -0,0 +1,5 @@ +from ._base import _FFTBackend + + +class ReiknaBackend(_FFTBackend): + _source_: str = "anyfft.reikna" diff --git a/src/anyfft/backends/_torch.py b/src/anyfft/backends/_torch.py new file mode 100644 index 0000000..7b84aa5 --- /dev/null +++ b/src/anyfft/backends/_torch.py @@ -0,0 +1,22 @@ +from typing import Any + +import torch +import torch.fft + +from ._base import _FFTBackend + + +class TorchBackend(_FFTBackend): + _source_: str = "torch.fft" + + @classmethod + def __ua_convert__(cls, dispatchables, coerce): + if coerce: + return [torch.as_tensor(d.value) for d in dispatchables] + return NotImplemented + + @classmethod + def execute(cls, func, *args: Any, **kwargs: Any) -> Any: + if "axes" in kwargs: + kwargs["dim"] = kwargs.pop("axes") + return func(*args, **kwargs) diff --git a/anyfft/cupy.py b/src/anyfft/cupy.py similarity index 100% rename from anyfft/cupy.py rename to src/anyfft/cupy.py diff --git a/anyfft/fftpack.py b/src/anyfft/fftpack.py similarity index 100% rename from anyfft/fftpack.py rename to src/anyfft/fftpack.py diff --git a/anyfft/gpyfft/__init__.py b/src/anyfft/gpyfft/__init__.py similarity index 100% rename from anyfft/gpyfft/__init__.py rename to src/anyfft/gpyfft/__init__.py diff --git a/anyfft/gpyfft/_fft.py b/src/anyfft/gpyfft/_fft.py similarity index 99% rename from anyfft/gpyfft/_fft.py rename to src/anyfft/gpyfft/_fft.py index 65f725a..ab5a237 100644 --- a/anyfft/gpyfft/_fft.py +++ b/src/anyfft/gpyfft/_fft.py @@ -3,10 +3,9 @@ from typing import TYPE_CHECKING, Union from warnings import filterwarnings +import numpy as np import pyopencl as cl import pyopencl.array as cla - -import numpy as np from gpyfft.fft import FFT from ._util import get_context @@ -35,7 +34,7 @@ def _normalize_axes(dshape, axes): try: return tuple(np.arange(len(dshape))[_axes]) except Exception as e: - raise TypeError(f"Cannot normalize axes {axes}: {e}") + raise TypeError(f"Cannot normalize axes {axes}: {e}") from e def _get_fft_plan(arr, axes=None, fast_math=False): diff --git a/anyfft/gpyfft/_util.py b/src/anyfft/gpyfft/_util.py similarity index 100% rename from anyfft/gpyfft/_util.py rename to src/anyfft/gpyfft/_util.py diff --git a/anyfft/numpy.py b/src/anyfft/numpy.py similarity index 100% rename from anyfft/numpy.py rename to src/anyfft/numpy.py diff --git a/src/anyfft/py.typed b/src/anyfft/py.typed new file mode 100644 index 0000000..07d3fbd --- /dev/null +++ b/src/anyfft/py.typed @@ -0,0 +1,5 @@ +You may remove this file if you don't intend to add types to your package + +Details at: + +https://mypy.readthedocs.io/en/stable/installed_packages.html#creating-pep-561-compatible-packages diff --git a/anyfft/pyfftw.py b/src/anyfft/pyfftw.py similarity index 79% rename from anyfft/pyfftw.py rename to src/anyfft/pyfftw.py index e5a4fad..1f1d251 100644 --- a/anyfft/pyfftw.py +++ b/src/anyfft/pyfftw.py @@ -1,4 +1,4 @@ -# import pyfftw.interfaces # noqa +# import pyfftw.interfaces from pyfftw.interfaces.scipy_fft import * # noqa # pyfftw.interfaces.cache.enable() diff --git a/anyfft/reikna/__init__.py b/src/anyfft/reikna/__init__.py similarity index 100% rename from anyfft/reikna/__init__.py rename to src/anyfft/reikna/__init__.py diff --git a/anyfft/reikna/_fft.py b/src/anyfft/reikna/_fft.py similarity index 99% rename from anyfft/reikna/_fft.py rename to src/anyfft/reikna/_fft.py index 9a80eae..97f984e 100644 --- a/anyfft/reikna/_fft.py +++ b/src/anyfft/reikna/_fft.py @@ -168,7 +168,7 @@ def _normalize_axes(dshape, axes): try: return tuple(np.arange(len(dshape))[_axes]) except Exception as e: - raise TypeError(f"Cannot normalize axes {axes}: {e}") + raise TypeError(f"Cannot normalize axes {axes}: {e}") from e def fft( diff --git a/anyfft/reikna/_fftconvolve.py b/src/anyfft/reikna/_fftconvolve.py similarity index 99% rename from anyfft/reikna/_fftconvolve.py rename to src/anyfft/reikna/_fftconvolve.py index 8afae4b..9bd1930 100644 --- a/anyfft/reikna/_fftconvolve.py +++ b/src/anyfft/reikna/_fftconvolve.py @@ -1,6 +1,5 @@ -from pyopencl.elementwise import ElementwiseKernel - import numpy as np +from pyopencl.elementwise import ElementwiseKernel from ._fft import fftn, ifftn from ._util import empty, get_thread, is_cluda_array, to_device diff --git a/anyfft/reikna/_fftshift.py b/src/anyfft/reikna/_fftshift.py similarity index 100% rename from anyfft/reikna/_fftshift.py rename to src/anyfft/reikna/_fftshift.py diff --git a/anyfft/reikna/_util.py b/src/anyfft/reikna/_util.py similarity index 100% rename from anyfft/reikna/_util.py rename to src/anyfft/reikna/_util.py diff --git a/anyfft/reikna/_version.py b/src/anyfft/reikna/_version.py similarity index 100% rename from anyfft/reikna/_version.py rename to src/anyfft/reikna/_version.py diff --git a/anyfft/reikna/from reikna.transformations import Annot.py b/src/anyfft/reikna/from reikna.transformations import Annot.py similarity index 90% rename from anyfft/reikna/from reikna.transformations import Annot.py rename to src/anyfft/reikna/from reikna.transformations import Annot.py index f1ead3d..08dcdf1 100644 --- a/anyfft/reikna/from reikna.transformations import Annot.py +++ b/src/anyfft/reikna/from reikna.transformations import Annot.py @@ -5,7 +5,7 @@ from reikna.transformations import Annotation, Parameter, Transformation api = cluda.ocl_api() -thr = api.Thread.create() +tread_ = api.Thread.create() def tform_r2c(arr): @@ -45,10 +45,10 @@ def tform_c2r(arr): plan.parameter.input.connect(r2c, r2c.output, new_input=r2c.input) c2r = tform_c2r(plan.parameter.output) plan.parameter.output.connect(c2r, c2r.input, new_output=c2r.output) -planc = plan.compile(thread=thr) +planc = plan.compile(thread=tread_) -arr_dev = thr.to_device(arr) -out_dev = thr.array(arr.shape, np.float32) +arr_dev = tread_.to_device(arr) +out_dev = tread_.array(arr.shape, np.float32) planc(out_dev, arr_dev) assert np.allclose(out_dev.get(), np.fft.fft(arr).real) diff --git a/anyfft/reikna_fft.py b/src/anyfft/reikna_fft.py similarity index 92% rename from anyfft/reikna_fft.py rename to src/anyfft/reikna_fft.py index 6aac7ff..975d1d7 100644 --- a/anyfft/reikna_fft.py +++ b/src/anyfft/reikna_fft.py @@ -32,7 +32,7 @@ # # fmt: off -''' +""" This module implements those functions that replace aspects of the :mod:`scipy.fft` module. This module *provides* the entire documented namespace of :mod:`scipy.fft`, but those functions that are not included here are @@ -42,12 +42,12 @@ equivalents in :mod:`scipy.fft`, though there are some corner cases in which this may not be true. -Some corner (mis)usages of :mod:`scipy.fft` may not transfer neatly. +Some corner misusages of :mod:`scipy.fft` may not transfer neatly. For example, using :func:`scipy.fft.fft2` with a non 1D array and a 2D `s` argument will return without exception whereas :func:`pyfftw.interfaces.scipy_fft.fft2` will raise a `ValueError`. -''' +""" import os import numpy as np @@ -103,7 +103,7 @@ def __ua_function__(method, args, kwargs): def _implements(scipy_func): - '''Decorator adds function to the dictionary of implemented functions''' + """Decorator adds function to the dictionary of implemented functions.""" def inner(func): _implemented[scipy_func] = func return func @@ -122,8 +122,8 @@ def _workers_to_threads(workers): if workers >= -_cpu_count: workers += 1 + _cpu_count else: - raise ValueError("workers value out of range; got {}, must not be" - " less than {}".format(workers, -_cpu_count)) + raise ValueError(f"workers value out of range; got {workers}, must not be" + f" less than {-_cpu_count}") elif workers == 0: raise ValueError("workers must not be zero") return workers @@ -132,24 +132,24 @@ def _workers_to_threads(workers): @_implements(_fft.fft) def fft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, *, plan=None): - '''Perform a 1D FFT. + """Perform a 1D FFT. The first six arguments are as per :func:`scipy.fft.fft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ return reikna.fft(x, n, axis, norm, overwrite_x) @_implements(_fft.ifft) def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 1D inverse FFT. + """Perform a 1D inverse FFT. The first six arguments are as per :func:`scipy.fft.ifft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.ifft(x, n, axis, norm, overwrite_x, planner_effort, threads, auto_align_input, @@ -159,12 +159,12 @@ def ifft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, @_implements(_fft.fft2) def fft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 2D FFT. + """Perform a 2D FFT. The first six arguments are as per :func:`scipy.fft.fft2`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.fft2(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -173,12 +173,12 @@ def fft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, @_implements(_fft.ifft2) def ifft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 2D inverse FFT. + """Perform a 2D inverse FFT. The first six arguments are as per :func:`scipy.fft.ifft2`; the rest of the arguments are documented in the :ref:`additional argument docs `. - ''' + """ threads = _workers_to_threads(workers) return _fft.ifft2(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -187,12 +187,12 @@ def ifft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, @_implements(_fft.fftn) def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform an n-D FFT. + """Perform an n-D FFT. The first six arguments are as per :func:`scipy.fft.fftn`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.fftn(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -201,12 +201,12 @@ def fftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, @_implements(_fft.ifftn) def ifftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform an n-D inverse FFT. + """Perform an n-D inverse FFT. The first six arguments are as per :func:`scipy.fft.ifftn`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.ifftn(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -215,12 +215,12 @@ def ifftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, @_implements(_fft.rfft) def rfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 1D real FFT. + """Perform a 1D real FFT. The first six arguments are as per :func:`scipy.fft.rfft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ x = np.asanyarray(x) if x.dtype.kind == 'c': raise TypeError('x must be a real sequence') @@ -232,12 +232,12 @@ def rfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, @_implements(_fft.irfft) def irfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 1D real inverse FFT. + """Perform a 1D real inverse FFT. The first six arguments are as per :func:`scipy.fft.irfft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.irfft(x, n, axis, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -246,12 +246,12 @@ def irfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, @_implements(_fft.rfft2) def rfft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 2D real FFT. + """Perform a 2D real FFT. The first six arguments are as per :func:`scipy.fft.rfft2`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ x = np.asanyarray(x) if x.dtype.kind == 'c': raise TypeError('x must be a real sequence') @@ -264,12 +264,12 @@ def rfft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, def irfft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 2D real inverse FFT. + """Perform a 2D real inverse FFT. The first six arguments are as per :func:`scipy.fft.irfft2`; the rest of the arguments are documented in the :ref:`additional argument docs `. - ''' + """ threads = _workers_to_threads(workers) return _fft.irfft2(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -278,12 +278,12 @@ def irfft2(x, s=None, axes=(-2, -1), norm=None, overwrite_x=False, @_implements(_fft.rfftn) def rfftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform an n-D real FFT. + """Perform an n-D real FFT. The first six arguments are as per :func:`scipy.fft.rfftn`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ x = np.asanyarray(x) if x.dtype.kind == 'c': raise TypeError('x must be a real sequence') @@ -295,12 +295,12 @@ def rfftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, @_implements(_fft.irfftn) def irfftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform an n-D real inverse FFT. + """Perform an n-D real inverse FFT. The first six arguments are as per :func:`scipy.fft.irfftn`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.irfftn(x, s, axes, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -309,12 +309,12 @@ def irfftn(x, s=None, axes=None, norm=None, overwrite_x=False, workers=None, @_implements(_fft.hfft) def hfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 1D Hermitian FFT. + """Perform a 1D Hermitian FFT. The first six arguments are as per :func:`scipy.fft.hfft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ threads = _workers_to_threads(workers) return _fft.hfft(x, n, axis, norm, overwrite_x, planner_effort, threads, auto_align_input, auto_contiguous) @@ -323,12 +323,12 @@ def hfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, @_implements(_fft.ihfft) def ihfft(x, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, planner_effort=None, auto_align_input=True, auto_contiguous=True): - '''Perform a 1D Hermitian inverse FFT. + """Perform a 1D Hermitian inverse FFT. The first six arguments are as per :func:`scipy.fft.ihfft`; the rest of the arguments are documented in the :ref:`additional argument docs`. - ''' + """ x = np.asanyarray(x) if x.dtype.kind == 'c': raise TypeError('x must be a real sequence') diff --git a/anyfft/scipy.py b/src/anyfft/scipy.py similarity index 100% rename from anyfft/scipy.py rename to src/anyfft/scipy.py diff --git a/anyfft/reikna/tests/test_reikna_fft.py b/tests/reikna/test_reikna_fft.py similarity index 97% rename from anyfft/reikna/tests/test_reikna_fft.py rename to tests/reikna/test_reikna_fft.py index e314137..655bc51 100644 --- a/anyfft/reikna/tests/test_reikna_fft.py +++ b/tests/reikna/test_reikna_fft.py @@ -1,12 +1,12 @@ +import numpy as np +import numpy.testing as npt import pytest +from scipy import datasets, fftpack, signal import anyfft.reikna -import numpy as np -import numpy.testing as npt from anyfft.reikna._util import empty_like, to_device -from scipy import fftpack, misc, signal -FACE = misc.face(gray=True).astype("float32") +FACE = datasets.face(gray=True).astype("float32") KERNEL = np.outer( signal.windows.gaussian(70, 8), signal.windows.gaussian(70, 8) ).astype("float32") @@ -62,7 +62,6 @@ def test_fft_inplace(): def test_fft_errors(): - with pytest.raises(TypeError): # existing OCLArray must be of complex type anyfft.reikna.fftn(to_device(FACE)) diff --git a/anyfft/tests/test_benchmarks.py b/tests/test_benchmarks.py similarity index 100% rename from anyfft/tests/test_benchmarks.py rename to tests/test_benchmarks.py index 5ceacb9..b5987d7 100644 --- a/anyfft/tests/test_benchmarks.py +++ b/tests/test_benchmarks.py @@ -1,10 +1,10 @@ -import pytest - -import anyfft import numpy as np import numpy.testing as npt +import pytest import scipy +import anyfft + @pytest.fixture(scope="session") def img(): diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 0c6805d..0000000 --- a/tox.ini +++ /dev/null @@ -1,28 +0,0 @@ -[tox] -envlist = py{37,38,39}-{linux,macos,windows} -toxworkdir=/tmp/.tox - -[gh-actions] -python = - 3.7: py37 - 3.8: py38 - 3.9: py39 - -[gh-actions:env] -PLATFORM = - ubuntu-latest: linux - macos-latest: macos - windows-latest: windows - -[testenv] -platform = - macos: darwin - linux: linux - windows: win32 -passenv = CI GITHUB_ACTIONS DISPLAY XAUTHORITY -setenv = - PYTHONPATH = {toxinidir} -extras = - testing -commands = - pytest -v --color=yes --basetemp={envtmpdir} {posargs} diff --git a/x.py b/x.py new file mode 100644 index 0000000..95a3734 --- /dev/null +++ b/x.py @@ -0,0 +1,60 @@ +import time +from typing import Any + +import matplotlib.pyplot as plt +import numpy as np +import torch +from scipy import fft +from skimage.data import cells3d + +# from scipy.signal import fftconvolve +from anyfft._fftconv import fftconvolve +from anyfft.backends import JaxBackend, TorchBackend + +ARY = cells3d()[:, 1].astype(np.float32) +ARY = np.tile(ARY, (7, 2, 2)) +KERN = np.zeros((32, 256, 256), dtype=np.float32) +KERN[6, 40, 40] = 1 +KERN[9, 64, 64] = 1 +KERN[12, 72, 72] = 1 +EXPECTED = None +DEVICE = "cpu" + + +def time_backend(backend: Any) -> np.ndarray: + ary = ARY + if backend is TorchBackend: + ary = torch.tensor(ARY, device=DEVICE, dtype=torch.float32) + kern = KERN + if backend is TorchBackend: + kern = torch.tensor(KERN, device=DEVICE, dtype=torch.float32) + start = time.perf_counter() + with fft.set_backend(backend, coerce=True): + result = fftconvolve(ary, kern, mode="same") + end = time.perf_counter() - start + if hasattr(result, "numpy"): + result = result.cpu().numpy() + elif hasattr(result, "get"): + result = result.get() + print(backend, "took", end) + print("L2 norm:", np.linalg.norm(result)) + if EXPECTED is not None: + if not np.allclose(result, EXPECTED, rtol=1e-2): + print(" not close") + return result # type: ignore + + +results = [] +EXPECTED = time_backend("scipy") +results.append(EXPECTED) +# time_backend(ReiknaBackend) + +results.append(time_backend(JaxBackend)) +results.append(time_backend(TorchBackend)) + + +fig, ax = plt.subplots(2, len(results), figsize=(12, 4)) +for i, r in enumerate(results): + ax[0, i].imshow(r[ARY.shape[0] // 2], cmap="gray") + ax[1, i].imshow(r[:, ARY.shape[1] // 2], cmap="gray") +plt.show()