Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added discrete module, transforms #14725

Merged
merged 11 commits into from May 24, 2018
2 changes: 2 additions & 0 deletions setup.py
Expand Up @@ -91,6 +91,7 @@
'sympy.crypto',
'sympy.deprecated',
'sympy.diffgeom',
'sympy.discrete',
'sympy.external',
'sympy.functions',
'sympy.functions.combinatorial',
Expand Down Expand Up @@ -312,6 +313,7 @@ def run(self):
'sympy.crypto.tests',
'sympy.deprecated.tests',
'sympy.diffgeom.tests',
'sympy.discrete.tests',
'sympy.external.tests',
'sympy.functions.combinatorial.tests',
'sympy.functions.elementary.tests',
Expand Down
1 change: 1 addition & 0 deletions sympy/__init__.py
Expand Up @@ -62,6 +62,7 @@ def __sympy_debug():
from .functions import *
from .ntheory import *
from .concrete import *
from .discrete import *
from .simplify import *
from .sets import *
from .solvers import *
Expand Down
9 changes: 9 additions & 0 deletions sympy/discrete/__init__.py
@@ -0,0 +1,9 @@
"""A module containing discrete functions.

Transforms - fft, ifft, ntt, intt, fwht, ifwht, fzt, ifzt, fmt, ifmt
Convolution - conv, conv_xor, conv_and, conv_or, conv_sub, conv_sup
Recurrence Evaluation - reval_hcc
"""


from .transforms import (fft, ifft, ntt, intt)
Empty file.
61 changes: 61 additions & 0 deletions sympy/discrete/tests/test_transforms.py
@@ -0,0 +1,61 @@
from __future__ import print_function, division

from sympy import sqrt
from sympy.core import S, Symbol, I
from sympy.core.compatibility import range
from sympy.discrete import fft, ifft, ntt, intt
from sympy.utilities.pytest import raises


def test_fft_ifft():
assert all(tf(ls) == ls for tf in (fft, ifft) \
for ls in ([], [S(5)/3]))

ls = list(range(6))
fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I,
-4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3,
-4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I,
2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2]

assert fft(ls) == fls
assert ifft(fls) == ls + [S.Zero]*2

ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
ifls = [S(9)/4 + 3*I, -7*I/4, S(3)/4 + I, -2 - I/4]

assert ifft(ls) == ifls
assert fft(ifls) == ls + [S.Zero]

x = Symbol('x', real=True)
raises(TypeError, lambda: fft(x))
raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3]))


def test_ntt_intt():
# prime modulo of the form (m*2**k + 1), sequence length should
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that the word is 'modulus'. 'modulo' is its ablative case that means 'by modulus'.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, will make the changes. Thanks.

# be a divisor of 2**k
p = 7*17*2**23 + 1
q = 2*500000003 + 1 # prime modulo only for (length = 1)
r = 2*3*5*7 # composite modulo

assert all(tf(ls, p) == ls for tf in (ntt, intt) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Backslash is not needed inside parentheses, etc. PEP 8:

The preferred way of wrapping long lines is by using Python's implied line continuation inside parentheses, brackets and braces. Long lines can be broken over multiple lines by wrapping expressions in parentheses. These should be used in preference to using a backslash for line continuation.

for ls in ([], [5]))

ls = list(range(6))
nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224, \
259751156, 12232587]

assert ntt(ls, p) == nls
assert intt(nls, p) == ls + [0]*2

ls = [1 + 2*I, 3 + 4*I, 5 + 6*I]
x = Symbol('x', integer=True)

raises(TypeError, lambda: ntt(x, p))
raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p))
raises(ValueError, lambda: intt(ls, p))
raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p))
raises(ValueError, lambda: ntt([3, 5, 6], q))
raises(ValueError, lambda: ntt([4, 5, 7], r))

assert ntt([1.0, 2.0, 3.0], p) == ntt([1, 2, 3], p)
227 changes: 227 additions & 0 deletions sympy/discrete/transforms.py
@@ -0,0 +1,227 @@
"""
Discrete Fourier Transform, Number Theoretic Transform,
Walsh Hadamard Transform, Zeta Transform, Mobius Transform
"""
from __future__ import print_function, division

from sympy.core import S, Symbol, sympify
from sympy.core.compatibility import as_int, range, iterable
from sympy.core.function import expand, expand_mul
from sympy.core.numbers import pi, I
from sympy.functions.elementary.exponential import exp
from sympy.functions.elementary.trigonometric import sin, cos
from sympy.ntheory import isprime, primitive_root
from sympy.utilities.iterables import ibin


#----------------------------------------------------------------------------#
# #
# Discrete Fourier Transform #
# #
#----------------------------------------------------------------------------#

def _fourier_transform(seq, symbolic, inverse=False):
"""Utility function for the Discrete Fourier Transform (DFT)"""

if not iterable(seq):
raise TypeError("Expected a sequence of numeric coefficients " +
"for Fourier Transform")

a = [sympify(arg) for arg in seq]
if any(x.has(Symbol) for x in a):
raise ValueError("Expected non-symbolic coefficients")

n = len(a)
if n < 2:
return a

b = n.bit_length() - 1
if n&(n - 1): # not a power of 2
b += 1
n = 2**b

a += [S.Zero]*(n - len(a))
for i in range(1, n):
j = int(ibin(i, b, str=True)[::-1], 2)
if i < j:
a[i], a[j] = a[j], a[i]

ang = -2*pi/n if inverse else 2*pi/n

if not symbolic:
ang = ang.evalf()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The precision should probably be respected here, too.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What changes are expected here, for deciding the precision?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that there are several possibilities. The precision could be given as a keyword parameter or derived from the argument values. The precision for ang should probably be higher than the expected output precision to minimize rounding errors. Probably some experimentation is necessary.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, will try them out. Thanks.


w = [cos(ang*i) + I*sin(ang*i) for i in range(n // 2)]

h = 2
while h <= n:
hf, ut = h // 2, n // h
for i in range(0, n, h):
for j in range(hf):
u, v = a[i + j], expand_mul(a[i + j + hf]*w[ut * j])
a[i + j], a[i + j + hf] = u + v, u - v
h *= 2

if inverse:
for i in range(n):
a[i] /= n

return a


def fft(seq, symbolic=True):
r"""
Performs the Discrete Fourier Transform (DFT) in the complex domain.

The sequence is automatically padded to the right with zeros, as the
radix 2 FFT requires the number of sample points to be a power of 2.

Parameters
==========

seq : iterable
The sequence on which DFT is to be applied.
symbolic : bool
Determines if DFT is to be performed using symbolic values or
numerical values.

Examples
========

>>> from sympy import fft, ifft

>>> fft([1, 2, 3, 4])
[10, -2 - 2*I, -2, -2 + 2*I]
>>> ifft(_)
[1, 2, 3, 4]

>>> ifft([1, 2, 3, 4])
[5/2, -1/2 + I/2, -1/2, -1/2 - I/2]
>>> fft(_)
[1, 2, 3, 4]

>>> fft([5])
[5]
>>> ifft([7])
[7]

References
==========

.. [1] https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
.. [2] http://mathworld.wolfram.com/FastFourierTransform.html

"""

return _fourier_transform(seq, symbolic=symbolic)


def ifft(seq, symbolic=True):
return _fourier_transform(seq, symbolic=symbolic, inverse=True)

ifft.__doc__ = fft.__doc__


#----------------------------------------------------------------------------#
# #
# Number Theoretic Transform #
# #
#----------------------------------------------------------------------------#

def _number_theoretic_transform(seq, q, inverse=False):
"""Utility function for the Number Theoretic transform (NTT)"""

if not iterable(seq):
raise TypeError("Expected a sequence of integer coefficients " +
"for Number Theoretic Transform")

q = as_int(q)
if isprime(q) == False:
raise ValueError("Expected prime modulo for " +
"Number Theoretic Transform")

a = [as_int(x) for x in seq]

n = len(a)
if n < 1:
return a

b = n.bit_length() - 1
if n&(n - 1):
b += 1
n = 2**b

if (q - 1) % n:
raise ValueError("Expected prime modulo of the form (m*2**k + 1)")

a += [0]*(n - len(a))
for i in range(1, n):
j = int(ibin(i, b, str=True)[::-1], 2)
if i < j:
a[i], a[j] = a[j] % q, a[i] % q

pr = primitive_root(q)

rt = pow(pr, (q - 1) // n, q)
if inverse:
rt = pow(rt, q - 2, q)

w = [1]*(n // 2)
for i in range(1, n // 2):
w[i] = w[i - 1] * rt % q

h = 2
while h <= n:
hf, ut = h // 2, n // h
for i in range(0, n, h):
for j in range(hf):
u, v = a[i + j], a[i + j + hf]*w[ut * j]
a[i + j], a[i + j + hf] = (u + v) % q, (u - v) % q
h *= 2

if inverse:
rv = pow(n, q - 2, q)
for i in range(n):
a[i] = a[i]*rv % q

return a


def ntt(seq, q):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the second argument here be p, to match the docstring?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes.

r"""
Performs the Number Theoretic Transform (NTT), which specializes the
Discrete Fourier Transform (DFT) over quotient ring Z/pZ for prime p
instead of complex numbers C.


The sequence is automatically padded to the right with zeros, as the
radix 2 NTT requires the number of sample points to be a power of 2.

Examples
========

>>> from sympy import ntt, intt
>>> ntt([1, 2, 3, 4], 3*2**8 + 1)
[10, 643, 767, 122]
>>> intt(_, 3*2**8 + 1)
[1, 2, 3, 4]
>>> intt([1, 2, 3, 4], 3*2**8 + 1)
[387, 415, 384, 353]
>>> ntt(_, 3*2**8 + 1)
[1, 2, 3, 4]

References
==========

.. [1] http://www.apfloat.org/ntt.html
.. [2] http://mathworld.wolfram.com/NumberTheoreticTransform.html

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A reference to Wikipedia could also be included: https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, will do.

"""

return _number_theoretic_transform(seq, q)


def intt(seq, q):
return _number_theoretic_transform(seq, q, inverse=True)

intt.__doc__ = ntt.__doc__