diff --git a/setup.py b/setup.py index 196a6ba655ce..a4327e32eed1 100755 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ 'sympy.crypto', 'sympy.deprecated', 'sympy.diffgeom', + 'sympy.discrete', 'sympy.external', 'sympy.functions', 'sympy.functions.combinatorial', @@ -312,6 +313,7 @@ def run(self): 'sympy.crypto.tests', 'sympy.deprecated.tests', 'sympy.diffgeom.tests', + 'sympy.discrete.tests', 'sympy.external.tests', 'sympy.functions.combinatorial.tests', 'sympy.functions.elementary.tests', diff --git a/sympy/__init__.py b/sympy/__init__.py index 21936f61b96a..982b6fa66c88 100644 --- a/sympy/__init__.py +++ b/sympy/__init__.py @@ -62,6 +62,7 @@ def __sympy_debug(): from .functions import * from .ntheory import * from .concrete import * +from .discrete import * from .simplify import * from .sets import * from .solvers import * diff --git a/sympy/discrete/__init__.py b/sympy/discrete/__init__.py new file mode 100644 index 000000000000..39b406ab7269 --- /dev/null +++ b/sympy/discrete/__init__.py @@ -0,0 +1,9 @@ +"""A module containing discrete functions. + +Transforms - fft, ifft, ntt, intt, fwht, ifwht, fzt, ifzt, fmt, ifmt +Convolution - conv, conv_xor, conv_and, conv_or, conv_sub, conv_sup +Recurrence Evaluation - reval_hcc +""" + + +from .transforms import (fft, ifft, ntt, intt) diff --git a/sympy/discrete/tests/__init__.py b/sympy/discrete/tests/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/sympy/discrete/tests/test_transforms.py b/sympy/discrete/tests/test_transforms.py new file mode 100644 index 000000000000..d053903a4909 --- /dev/null +++ b/sympy/discrete/tests/test_transforms.py @@ -0,0 +1,61 @@ +from __future__ import print_function, division + +from sympy import sqrt +from sympy.core import S, Symbol, I +from sympy.core.compatibility import range +from sympy.discrete import fft, ifft, ntt, intt +from sympy.utilities.pytest import raises + + +def test_fft_ifft(): + assert all(tf(ls) == ls for tf in (fft, ifft) + for ls in ([], [S(5)/3])) + + ls = list(range(6)) + fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I, + -4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3, + -4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I, + 2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2] + + assert fft(ls) == fls + assert ifft(fls) == ls + [S.Zero]*2 + + ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] + ifls = [S(9)/4 + 3*I, -7*I/4, S(3)/4 + I, -2 - I/4] + + assert ifft(ls) == ifls + assert fft(ifls) == ls + [S.Zero] + + x = Symbol('x', real=True) + raises(TypeError, lambda: fft(x)) + raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3])) + + +def test_ntt_intt(): + # prime moduli of the form (m*2**k + 1), sequence length + # should be a divisor of 2**k + p = 7*17*2**23 + 1 + q = 2*500000003 + 1 # only for sequences of length 1 or 2 + r = 2*3*5*7 # composite modulus + + assert all(tf(ls, p) == ls for tf in (ntt, intt) + for ls in ([], [5])) + + ls = list(range(6)) + nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224, + 259751156, 12232587] + + assert ntt(ls, p) == nls + assert intt(nls, p) == ls + [0]*2 + + ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] + x = Symbol('x', integer=True) + + raises(TypeError, lambda: ntt(x, p)) + raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p)) + raises(ValueError, lambda: intt(ls, p)) + raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p)) + raises(ValueError, lambda: ntt([3, 5, 6], q)) + raises(ValueError, lambda: ntt([4, 5, 7], r)) + + assert ntt([1.0, 2.0, 3.0], p) == ntt([1, 2, 3], p) diff --git a/sympy/discrete/transforms.py b/sympy/discrete/transforms.py new file mode 100644 index 000000000000..1c72cc9a9b82 --- /dev/null +++ b/sympy/discrete/transforms.py @@ -0,0 +1,241 @@ +""" +Discrete Fourier Transform, Number Theoretic Transform, +Walsh Hadamard Transform, Zeta Transform, Mobius Transform +""" +from __future__ import print_function, division + +from sympy.core import S, Symbol, sympify +from sympy.core.compatibility import as_int, range, iterable +from sympy.core.function import expand, expand_mul +from sympy.core.numbers import pi, I +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import sin, cos +from sympy.ntheory import isprime, primitive_root +from sympy.utilities.iterables import ibin + + +#----------------------------------------------------------------------------# +# # +# Discrete Fourier Transform # +# # +#----------------------------------------------------------------------------# + +def _fourier_transform(seq, dps, inverse=False): + """Utility function for the Discrete Fourier Transform (DFT)""" + + if not iterable(seq): + raise TypeError("Expected a sequence of numeric coefficients " + + "for Fourier Transform") + + a = [sympify(arg) for arg in seq] + if any(x.has(Symbol) for x in a): + raise ValueError("Expected non-symbolic coefficients") + + n = len(a) + if n < 2: + return a + + b = n.bit_length() - 1 + if n&(n - 1): # not a power of 2 + b += 1 + n = 2**b + + a += [S.Zero]*(n - len(a)) + for i in range(1, n): + j = int(ibin(i, b, str=True)[::-1], 2) + if i < j: + a[i], a[j] = a[j], a[i] + + ang = -2*pi/n if inverse else 2*pi/n + + if dps is not None: + ang = ang.evalf(dps + 2) + + w = [cos(ang*i) + I*sin(ang*i) for i in range(n // 2)] + + h = 2 + while h <= n: + hf, ut = h // 2, n // h + for i in range(0, n, h): + for j in range(hf): + u, v = a[i + j], expand_mul(a[i + j + hf]*w[ut * j]) + a[i + j], a[i + j + hf] = u + v, u - v + h *= 2 + + if inverse: + a = [(x/n).evalf(dps) for x in a] if dps is not None \ + else [x/n for x in a] + + return a + + +def fft(seq, dps=None): + r""" + Performs the Discrete Fourier Transform (DFT) in the complex domain. + + The sequence is automatically padded to the right with zeros, as the + radix 2 FFT requires the number of sample points to be a power of 2. + + Parameters + ========== + + seq : iterable + The sequence on which DFT is to be applied. + dps : Integer + Specifies the number of decimal digits for precision. + + Examples + ======== + + >>> from sympy import fft, ifft + + >>> fft([1, 2, 3, 4]) + [10, -2 - 2*I, -2, -2 + 2*I] + >>> ifft(_) + [1, 2, 3, 4] + + >>> ifft([1, 2, 3, 4]) + [5/2, -1/2 + I/2, -1/2, -1/2 - I/2] + >>> fft(_) + [1, 2, 3, 4] + + >>> ifft([1, 7, 3, 4], dps=15) + [3.75, -0.5 - 0.75*I, -1.75, -0.5 + 0.75*I] + >>> fft(_) + [1.0, 7.0, 3.0, 4.0] + + >>> fft([5]) + [5] + >>> ifft([7]) + [7] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm + .. [2] http://mathworld.wolfram.com/FastFourierTransform.html + + """ + + return _fourier_transform(seq, dps=dps) + + +def ifft(seq, dps=None): + return _fourier_transform(seq, dps=dps, inverse=True) + +ifft.__doc__ = fft.__doc__ + + +#----------------------------------------------------------------------------# +# # +# Number Theoretic Transform # +# # +#----------------------------------------------------------------------------# + +def _number_theoretic_transform(seq, p, inverse=False): + """Utility function for the Number Theoretic transform (NTT)""" + + if not iterable(seq): + raise TypeError("Expected a sequence of integer coefficients " + + "for Number Theoretic Transform") + + p = as_int(p) + if isprime(p) == False: + raise ValueError("Expected prime modulus for " + + "Number Theoretic Transform") + + a = [as_int(x) for x in seq] + + n = len(a) + if n < 1: + return a + + b = n.bit_length() - 1 + if n&(n - 1): + b += 1 + n = 2**b + + if (p - 1) % n: + raise ValueError("Expected prime modulus of the form (m*2**k + 1)") + + a += [0]*(n - len(a)) + for i in range(1, n): + j = int(ibin(i, b, str=True)[::-1], 2) + if i < j: + a[i], a[j] = a[j] % p, a[i] % p + + pr = primitive_root(p) + + rt = pow(pr, (p - 1) // n, p) + if inverse: + rt = pow(rt, p - 2, p) + + w = [1]*(n // 2) + for i in range(1, n // 2): + w[i] = w[i - 1] * rt % p + + h = 2 + while h <= n: + hf, ut = h // 2, n // h + for i in range(0, n, h): + for j in range(hf): + u, v = a[i + j], a[i + j + hf]*w[ut * j] + a[i + j], a[i + j + hf] = (u + v) % p, (u - v) % p + h *= 2 + + if inverse: + rv = pow(n, p - 2, p) + for i in range(n): + a[i] = a[i]*rv % p + + return a + + +def ntt(seq, p): + r""" + Performs the Number Theoretic Transform (NTT), which specializes the + Discrete Fourier Transform (DFT) over quotient ring Z/pZ for prime p + instead of complex numbers C. + + + The sequence is automatically padded to the right with zeros, as the + radix 2 NTT requires the number of sample points to be a power of 2. + + Parameters + ========== + + seq : iterable + The sequence on which DFT is to be applied. + p : Integer + Prime modulus of the form (m*2**k + 1) to be used for performing + NTT on the sequence. + + Examples + ======== + + >>> from sympy import ntt, intt + >>> ntt([1, 2, 3, 4], 3*2**8 + 1) + [10, 643, 767, 122] + >>> intt(_, 3*2**8 + 1) + [1, 2, 3, 4] + >>> intt([1, 2, 3, 4], 3*2**8 + 1) + [387, 415, 384, 353] + >>> ntt(_, 3*2**8 + 1) + [1, 2, 3, 4] + + References + ========== + + .. [1] http://www.apfloat.org/ntt.html + .. [2] http://mathworld.wolfram.com/NumberTheoreticTransform.html + .. [3] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general) + + """ + + return _number_theoretic_transform(seq, p) + + +def intt(seq, p): + return _number_theoretic_transform(seq, p, inverse=True) + +intt.__doc__ = ntt.__doc__