New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added discrete module, transforms #14725
Changes from 6 commits
99975e1
5d2bef1
8773627
a2698c3
5fa296b
b63a021
7635b86
14d03bf
c77648f
568bfc3
a962db8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""A module containing discrete functions. | ||
|
||
Transforms - fft, ifft, ntt, intt, fwht, ifwht, fzt, ifzt, fmt, ifmt | ||
Convolution - conv, conv_xor, conv_and, conv_or, conv_sub, conv_sup | ||
Recurrence Evaluation - reval_hcc | ||
""" | ||
|
||
|
||
from .transforms import (fft, ifft, ntt, intt) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from __future__ import print_function, division | ||
|
||
from sympy import sqrt | ||
from sympy.core import S, Symbol, I | ||
from sympy.core.compatibility import range | ||
from sympy.discrete import fft, ifft, ntt, intt | ||
from sympy.utilities.pytest import raises | ||
|
||
|
||
def test_fft_ifft(): | ||
assert all(tf(ls) == ls for tf in (fft, ifft) \ | ||
for ls in ([], [S(5)/3])) | ||
|
||
ls = list(range(6)) | ||
fls = [15, -7*sqrt(2)/2 - 4 - sqrt(2)*I/2 + 2*I, 2 + 3*I, | ||
-4 + 7*sqrt(2)/2 - 2*I - sqrt(2)*I/2, -3, | ||
-4 + 7*sqrt(2)/2 + sqrt(2)*I/2 + 2*I, | ||
2 - 3*I, -7*sqrt(2)/2 - 4 - 2*I + sqrt(2)*I/2] | ||
|
||
assert fft(ls) == fls | ||
assert ifft(fls) == ls + [S.Zero]*2 | ||
|
||
ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] | ||
ifls = [S(9)/4 + 3*I, -7*I/4, S(3)/4 + I, -2 - I/4] | ||
|
||
assert ifft(ls) == ifls | ||
assert fft(ifls) == ls + [S.Zero] | ||
|
||
x = Symbol('x', real=True) | ||
raises(TypeError, lambda: fft(x)) | ||
raises(ValueError, lambda: ifft([x, 2*x, 3*x**2, 4*x**3])) | ||
|
||
|
||
def test_ntt_intt(): | ||
# prime modulo of the form (m*2**k + 1), sequence length should | ||
# be a divisor of 2**k | ||
p = 7*17*2**23 + 1 | ||
q = 2*500000003 + 1 # prime modulo only for (length = 1) | ||
r = 2*3*5*7 # composite modulo | ||
|
||
assert all(tf(ls, p) == ls for tf in (ntt, intt) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Backslash is not needed inside parentheses, etc. PEP 8:
|
||
for ls in ([], [5])) | ||
|
||
ls = list(range(6)) | ||
nls = [15, 801133602, 738493201, 334102277, 998244350, 849020224, \ | ||
259751156, 12232587] | ||
|
||
assert ntt(ls, p) == nls | ||
assert intt(nls, p) == ls + [0]*2 | ||
|
||
ls = [1 + 2*I, 3 + 4*I, 5 + 6*I] | ||
x = Symbol('x', integer=True) | ||
|
||
raises(TypeError, lambda: ntt(x, p)) | ||
raises(ValueError, lambda: intt([x, 2*x, 3*x**2, 4*x**3], p)) | ||
raises(ValueError, lambda: intt(ls, p)) | ||
raises(ValueError, lambda: ntt([1.2, 2.1, 3.5], p)) | ||
raises(ValueError, lambda: ntt([3, 5, 6], q)) | ||
raises(ValueError, lambda: ntt([4, 5, 7], r)) | ||
|
||
assert ntt([1.0, 2.0, 3.0], p) == ntt([1, 2, 3], p) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,227 @@ | ||
""" | ||
Discrete Fourier Transform, Number Theoretic Transform, | ||
Walsh Hadamard Transform, Zeta Transform, Mobius Transform | ||
""" | ||
from __future__ import print_function, division | ||
|
||
from sympy.core import S, Symbol, sympify | ||
from sympy.core.compatibility import as_int, range, iterable | ||
from sympy.core.function import expand, expand_mul | ||
from sympy.core.numbers import pi, I | ||
from sympy.functions.elementary.exponential import exp | ||
from sympy.functions.elementary.trigonometric import sin, cos | ||
from sympy.ntheory import isprime, primitive_root | ||
from sympy.utilities.iterables import ibin | ||
|
||
|
||
#----------------------------------------------------------------------------# | ||
# # | ||
# Discrete Fourier Transform # | ||
# # | ||
#----------------------------------------------------------------------------# | ||
|
||
def _fourier_transform(seq, symbolic, inverse=False): | ||
"""Utility function for the Discrete Fourier Transform (DFT)""" | ||
|
||
if not iterable(seq): | ||
raise TypeError("Expected a sequence of numeric coefficients " + | ||
"for Fourier Transform") | ||
|
||
a = [sympify(arg) for arg in seq] | ||
if any(x.has(Symbol) for x in a): | ||
raise ValueError("Expected non-symbolic coefficients") | ||
|
||
n = len(a) | ||
if n < 2: | ||
return a | ||
|
||
b = n.bit_length() - 1 | ||
if n&(n - 1): # not a power of 2 | ||
b += 1 | ||
n = 2**b | ||
|
||
a += [S.Zero]*(n - len(a)) | ||
for i in range(1, n): | ||
j = int(ibin(i, b, str=True)[::-1], 2) | ||
if i < j: | ||
a[i], a[j] = a[j], a[i] | ||
|
||
ang = -2*pi/n if inverse else 2*pi/n | ||
|
||
if not symbolic: | ||
ang = ang.evalf() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The precision should probably be respected here, too. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What changes are expected here, for deciding the precision? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think that there are several possibilities. The precision could be given as a keyword parameter or derived from the argument values. The precision for There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, will try them out. Thanks. |
||
|
||
w = [cos(ang*i) + I*sin(ang*i) for i in range(n // 2)] | ||
|
||
h = 2 | ||
while h <= n: | ||
hf, ut = h // 2, n // h | ||
for i in range(0, n, h): | ||
for j in range(hf): | ||
u, v = a[i + j], expand_mul(a[i + j + hf]*w[ut * j]) | ||
a[i + j], a[i + j + hf] = u + v, u - v | ||
h *= 2 | ||
|
||
if inverse: | ||
for i in range(n): | ||
a[i] /= n | ||
|
||
return a | ||
|
||
|
||
def fft(seq, symbolic=True): | ||
r""" | ||
Performs the Discrete Fourier Transform (DFT) in the complex domain. | ||
|
||
The sequence is automatically padded to the right with zeros, as the | ||
radix 2 FFT requires the number of sample points to be a power of 2. | ||
|
||
Parameters | ||
========== | ||
|
||
seq : iterable | ||
The sequence on which DFT is to be applied. | ||
symbolic : bool | ||
Determines if DFT is to be performed using symbolic values or | ||
numerical values. | ||
|
||
Examples | ||
======== | ||
|
||
>>> from sympy import fft, ifft | ||
|
||
>>> fft([1, 2, 3, 4]) | ||
[10, -2 - 2*I, -2, -2 + 2*I] | ||
>>> ifft(_) | ||
[1, 2, 3, 4] | ||
|
||
>>> ifft([1, 2, 3, 4]) | ||
[5/2, -1/2 + I/2, -1/2, -1/2 - I/2] | ||
>>> fft(_) | ||
[1, 2, 3, 4] | ||
|
||
>>> fft([5]) | ||
[5] | ||
>>> ifft([7]) | ||
[7] | ||
|
||
References | ||
========== | ||
|
||
.. [1] https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm | ||
.. [2] http://mathworld.wolfram.com/FastFourierTransform.html | ||
|
||
""" | ||
|
||
return _fourier_transform(seq, symbolic=symbolic) | ||
|
||
|
||
def ifft(seq, symbolic=True): | ||
return _fourier_transform(seq, symbolic=symbolic, inverse=True) | ||
|
||
ifft.__doc__ = fft.__doc__ | ||
|
||
|
||
#----------------------------------------------------------------------------# | ||
# # | ||
# Number Theoretic Transform # | ||
# # | ||
#----------------------------------------------------------------------------# | ||
|
||
def _number_theoretic_transform(seq, q, inverse=False): | ||
"""Utility function for the Number Theoretic transform (NTT)""" | ||
|
||
if not iterable(seq): | ||
raise TypeError("Expected a sequence of integer coefficients " + | ||
"for Number Theoretic Transform") | ||
|
||
q = as_int(q) | ||
if isprime(q) == False: | ||
raise ValueError("Expected prime modulo for " + | ||
"Number Theoretic Transform") | ||
|
||
a = [as_int(x) for x in seq] | ||
|
||
n = len(a) | ||
if n < 1: | ||
return a | ||
|
||
b = n.bit_length() - 1 | ||
if n&(n - 1): | ||
b += 1 | ||
n = 2**b | ||
|
||
if (q - 1) % n: | ||
raise ValueError("Expected prime modulo of the form (m*2**k + 1)") | ||
|
||
a += [0]*(n - len(a)) | ||
for i in range(1, n): | ||
j = int(ibin(i, b, str=True)[::-1], 2) | ||
if i < j: | ||
a[i], a[j] = a[j] % q, a[i] % q | ||
|
||
pr = primitive_root(q) | ||
|
||
rt = pow(pr, (q - 1) // n, q) | ||
if inverse: | ||
rt = pow(rt, q - 2, q) | ||
|
||
w = [1]*(n // 2) | ||
for i in range(1, n // 2): | ||
w[i] = w[i - 1] * rt % q | ||
|
||
h = 2 | ||
while h <= n: | ||
hf, ut = h // 2, n // h | ||
for i in range(0, n, h): | ||
for j in range(hf): | ||
u, v = a[i + j], a[i + j + hf]*w[ut * j] | ||
a[i + j], a[i + j + hf] = (u + v) % q, (u - v) % q | ||
h *= 2 | ||
|
||
if inverse: | ||
rv = pow(n, q - 2, q) | ||
for i in range(n): | ||
a[i] = a[i]*rv % q | ||
|
||
return a | ||
|
||
|
||
def ntt(seq, q): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the second argument here be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. |
||
r""" | ||
Performs the Number Theoretic Transform (NTT), which specializes the | ||
Discrete Fourier Transform (DFT) over quotient ring Z/pZ for prime p | ||
instead of complex numbers C. | ||
|
||
|
||
The sequence is automatically padded to the right with zeros, as the | ||
radix 2 NTT requires the number of sample points to be a power of 2. | ||
|
||
Examples | ||
======== | ||
|
||
>>> from sympy import ntt, intt | ||
>>> ntt([1, 2, 3, 4], 3*2**8 + 1) | ||
[10, 643, 767, 122] | ||
>>> intt(_, 3*2**8 + 1) | ||
[1, 2, 3, 4] | ||
>>> intt([1, 2, 3, 4], 3*2**8 + 1) | ||
[387, 415, 384, 353] | ||
>>> ntt(_, 3*2**8 + 1) | ||
[1, 2, 3, 4] | ||
|
||
References | ||
========== | ||
|
||
.. [1] http://www.apfloat.org/ntt.html | ||
.. [2] http://mathworld.wolfram.com/NumberTheoreticTransform.html | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A reference to Wikipedia could also be included: https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay, will do. |
||
""" | ||
|
||
return _number_theoretic_transform(seq, q) | ||
|
||
|
||
def intt(seq, q): | ||
return _number_theoretic_transform(seq, q, inverse=True) | ||
|
||
intt.__doc__ = ntt.__doc__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think that the word is 'modulus'. 'modulo' is its ablative case that means 'by modulus'.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, will make the changes. Thanks.