Skip to content

Commit

Permalink
gh-36456: implement .interpolation() method for ProductTree
Browse files Browse the repository at this point in the history
    
The `ProductTree` class has a `.remainders()` method which can, among
other things, be used to implement the Fast Fourier Transform. In this
patch we add a corresponding `.interpolation()` method, which can, among
other things, be used to implement the *inverse* Fast Fourier Transform.
Its functionality is equivalent to `CRT_list()`, but caching the
product-tree structure makes it significantly faster for repeated
invocations.
    
URL: #36456
Reported by: Lorenz Panny
Reviewer(s): Kwankyu Lee
  • Loading branch information
Release Manager committed Jan 12, 2024
2 parents d30ac1c + 07e2c29 commit 2552ba0
Showing 1 changed file with 57 additions and 2 deletions.
59 changes: 57 additions & 2 deletions src/sage/rings/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@ class ProductTree:
sage: R.<x> = F[]
sage: ms = [x - a^i for i in range(1024)] # roots of unity
sage: ys = [F.random_element() for _ in range(1024)] # input vector
sage: zs = ProductTree(ms).remainders(R(ys)) # compute FFT!
sage: tree = ProductTree(ms)
sage: zs = tree.remainders(R(ys)) # compute FFT!
sage: zs == [R(ys) % m for m in ms]
True
Similarly, the :meth:`interpolation` method can be used to implement
the inverse Fast Fourier Transform::
sage: tree.interpolation(zs).padded_list(len(ys)) == ys
True
This class encodes the tree as *layers*: Layer `0` is just a tuple
of the leaves. Layer `i+1` is obtained from layer `i` by replacing
each pair of two adjacent elements by their product, starting from
Expand Down Expand Up @@ -177,7 +184,6 @@ def remainders(self, x):
The base ring must support the ``%`` operator for this
method to work.
INPUT:
- ``x`` -- an element of the base ring of this product tree
Expand All @@ -199,6 +205,55 @@ def remainders(self, x):
X = [X[i // 2] % V[i] for i in range(len(V))]
return X

_crt_bases = None

def interpolation(self, xs):
r"""
Given a sequence ``xs`` of values, one per leaf, return a
single element `x` which is congruent to the `i`\th value in
``xs`` modulo the `i`\th leaf, for all `i`.
This is an explicit version of the Chinese remainder theorem;
see also :meth:`CRT`. Using this product tree is faster for
repeated calls since the required CRT bases are cached after
the first run.
The base ring must support the :func:`xgcd` function for this
method to work.
EXAMPLES::
sage: from sage.rings.generic import ProductTree
sage: vs = prime_range(100)
sage: tree = ProductTree(vs)
sage: tree.interpolation([1, 1, 2, 1, 9, 1, 7, 15, 8, 20, 15, 6, 27, 11, 2, 6, 0, 25, 49, 5, 51, 4, 19, 74, 13])
1085749272377676749812331719267
This method is faster than :func:`CRT` for repeated calls with
the same moduli::
sage: vs = prime_range(1000,2000)
sage: rs = lambda: [randrange(1,100) for _ in vs]
sage: tree = ProductTree(vs)
sage: %timeit CRT(rs(), vs) # not tested
372 µs ± 3.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
sage: %timeit tree.interpolation(rs()) # not tested
146 µs ± 479 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
"""
if self._crt_bases is None:
from sage.arith.misc import CRT_basis
self._crt_bases = []
for V in self.layers[:-1]:
B = tuple(CRT_basis(V[i:i+2]) for i in range(0, len(V), 2))
self._crt_bases.append(B)
if len(xs) != len(self.layers[0]):
raise ValueError('number of given elements must equal the number of leaves')
for basis, layer in zip(self._crt_bases, self.layers[1:]):
xs = [sum(c*x for c, x in zip(cs, xs[2*i:2*i+2])) % mod
for i, (cs, mod) in enumerate(zip(basis, layer))]
assert len(xs) == 1
return xs[0]


def prod_with_derivative(pairs):
r"""
Expand Down

0 comments on commit 2552ba0

Please sign in to comment.