In [1]:
from sympy import Poly, GF, symbols
from sympy import ntheory
from sympy.ntheory import residue_ntheory
import random
random.seed(0)

# Parameter sets (swap in one of these to experiment):
#   q, n = 7, 8        toy example; fine for the convolution checks in this cell,
#                      but 8 does not divide q - 1 = 6, so the NTT cells would
#                      raise ValueError.
#   q, n = 3329, 256   Kyber: GF(q) has a primitive 256-th root of unity but no
#                      primitive 512-th root.
# Active set is Dilithium: q = 2^23 - 2^13 + 1 = 8380417. Since 512 = 2n divides
# q - 1, GF(q) has the primitive 2n-th root the negative-wrapped NTT needs.
q = 2 ** 23 - 2 ** 13 + 1
n = 256

from enum import Enum

class Convolution(Enum):
    """Which polynomial ring a product is computed in.

    Linear: plain polynomial product; the ring is not a quotient ring.
    Cyclic: product reduced modulo (x^n - 1).
    NegativeWrapping: product reduced modulo (x^n + 1).
    """

    Linear = 1
    Cyclic = 2
    NegativeWrapping = 3

def generate_polynomial_ring(q: int, n: int, conv_type: Convolution):
    """Build the coefficient field, indeterminate and modulus for a convolution.

    Args:
        q: modulus of the coefficient field GF(q).
        n: degree of the modulus polynomial; must be a positive integer.
        conv_type: Convolution.Cyclic selects x^n - 1,
            Convolution.NegativeWrapping selects x^n + 1.

    Returns:
        (ff, indet, modulus_poly): the domain GF(q), the symbol ``x``, and the
        modulus polynomial as a ``Poly`` over ``ff``.

    Raises:
        ValueError: if n < 1 (the modulus would collapse to a constant).
        NotImplementedError: for Convolution.Linear, which has no quotient.
    """
    if n < 1:
        raise ValueError(f"n={n} must be a positive integer")
    if conv_type == Convolution.Cyclic:
        constant_term = -1
    elif conv_type == Convolution.NegativeWrapping:
        constant_term = 1
    else:
        raise NotImplementedError("Linear convolution currently not supported")

    ff = GF(q)
    indet = symbols("x")
    # Sympy lists coefficients from the highest degree down, so
    # [1, 0, ..., 0, c] (length n + 1) is x^n + c.
    modulus_coeffs = [1] + [0] * (n - 1) + [constant_term]
    modulus_poly = Poly(modulus_coeffs, indet, domain=ff)

    return ff, indet, modulus_poly

# Sanity-check both quotient rings by multiplying a random a(x) by b(x) = x.
ff, indet, modulus_poly = generate_polynomial_ring(
    q, n, Convolution.NegativeWrapping)

a_coeffs = [random.randint(1, q-1) for _ in range(n)]
# b(x) = x: coefficients run from the highest degree down, so the x^1
# coefficient is second from the end.
b_coeffs = [0] * n
b_coeffs[-2] = 1

a_poly = Poly(a_coeffs, indet, domain=ff)
b_poly = Poly(b_coeffs, indet, domain=ff)

# Modulo x^n + 1, multiplying by x shifts every coefficient up one degree and
# the overflowing x^n term wraps to the constant term with its sign flipped.
expected_coeffs = a_coeffs[1:] + [-a_coeffs[0]]
expected_poly = Poly(expected_coeffs, indet, domain=ff)
assert (a_poly * b_poly) % modulus_poly == expected_poly, \
    "Negative wrapping convolution failed"

# Modulo x^n - 1 the wrapped term keeps its sign.
ff, indet, modulus_poly = generate_polynomial_ring(q, n, Convolution.Cyclic)
expected_coeffs = a_coeffs[1:] + [a_coeffs[0]]
expected_poly = Poly(expected_coeffs, indet, domain=ff)
assert (a_poly * b_poly) % modulus_poly == expected_poly, \
    "Cyclic convolution failed"

In [2]:
def n_primitive_root(q: int, n: int):
    """Return a primitive n-th root of unity in GF(q).

    The result is g ** ((q - 1) // n), where g is the primitive root of q
    returned by sympy's ``primitive_root``. Because g generates the
    multiplicative group of order q - 1, the result has order exactly n.

    Args:
        q: prime field modulus.
        n: positive integer dividing q - 1.

    Returns:
        An element of GF(q) of multiplicative order n. This function never
        returns None; invalid inputs raise instead.

    Raises:
        ValueError: if q is not prime, if n is not positive, or if n does not
            divide q - 1.
    """
    if not ntheory.isprime(q):
        raise ValueError(f"Modulus {q} is not prime")
    if n < 1:
        raise ValueError(f"n={n} must be a positive integer")
    if (q - 1) % n != 0:
        raise ValueError(f"n={n} does not divide group order {q-1}")

    # First find a generator of GF(q)^*, then raise it to the cofactor power.
    generator = residue_ntheory.primitive_root(q)
    return GF(q)(generator) ** ((q - 1) // n)

# Cyclic convolution

In [3]:
def cyclic_ntt(polynomial, q, n):
    """Evaluate ``polynomial`` at every power of a primitive n-th root of unity.

    Entry j of the returned list is the value at w**(n - 1 - j), so the list
    runs from the highest power of w down to w**0; ``cyclic_inv_ntt`` expects
    exactly this ordering.
    """
    root = n_primitive_root(q, n)
    return [polynomial.eval(root ** exponent) for exponent in reversed(range(n))]

def cyclic_inv_ntt(points, q, n):
    """Invert ``cyclic_ntt``: recover polynomial coefficients from n evaluations.

    For i = 0..n-1 the coefficient of x^i is
        a_i = n^{-1} * sum_j points[n - 1 - j] * w^{-i j},
    where w is the primitive n-th root of unity from ``n_primitive_root``.

    Args:
        points: n evaluations in the order produced by ``cyclic_ntt``
            (entry n - 1 - j holds the value at w**j).
        q: prime modulus; n must divide q - 1.
        n: transform length.

    Returns:
        List of n coefficients ordered from the highest degree (x^{n-1}) down
        to the constant term, which is the layout ``sympy.Poly`` expects.
    """
    w_inv = n_primitive_root(q, n) ** (-1)
    n_inv = pow(n, -1, q)
    # w has order n, so w^{-ij} depends only on (i * j) mod n: precompute the n
    # distinct powers once instead of performing n^2 exponentiations.
    w_inv_powers = [w_inv ** k for k in range(n)]
    coeffs = [0 for _ in range(n)]

    for i in range(n):
        # Sympy lists coefficients from the highest degree down, so the
        # coefficient of x^i is stored at index n - 1 - i.
        coeffs[n - 1 - i] = n_inv * sum(
            points[n - 1 - j] * w_inv_powers[(i * j) % n] for j in range(n)
        )

    return coeffs

# Round trip: the inverse transform must give back the original coefficients.
a_poly = Poly([random.randint(1, q - 1) for _ in range(n)], indet, domain=ff)
a_ntt = cyclic_ntt(a_poly, q, n)
assert a_poly == Poly(cyclic_inv_ntt(a_ntt, q, n), indet, domain=ff), \
    "NTT inversion failed"

In [4]:
# Build the cyclic modulus x^n - 1 explicitly instead of relying on the
# `modulus_poly` global left behind by the first cell (a later cell reassigns
# it to x^n + 1), so this cell stays correct when re-run out of order.
cyclic_modulus = generate_polynomial_ring(q, n, Convolution.Cyclic)[2]

a_poly = Poly([random.randint(1, q-1) for _ in range(n)], indet, domain=ff)
b_poly = Poly([random.randint(1, q-1) for _ in range(n)], indet, domain=ff)
# Reference result: schoolbook product reduced modulo x^n - 1.
expected_prod = (a_poly * b_poly) % cyclic_modulus

# Pointwise multiplication of the transforms corresponds to cyclic convolution.
a_ntt = cyclic_ntt(a_poly, q, n)
b_ntt = cyclic_ntt(b_poly, q, n)
c_ntt = [ff(a) * ff(b) for (a, b) in zip(a_ntt, b_ntt)]
c_poly = Poly(cyclic_inv_ntt(c_ntt, q, n), indet, domain=ff)
assert c_poly == expected_prod, "NTT multiplication failed"

# Negative-wrapped convolution

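$$
\phi_{2n} \leftarrow \text{primitive } 2n\text{-th root of unity}, \qquad \omega \leftarrow \phi_{2n}^2
$$

Since $\phi_{2n}^n = -1$, the odd powers $\phi_{2n}^{2k+1}$ ($k = 0, \dots, n-1$) are exactly the $n$ roots of $x^n + 1$. Evaluating at them therefore computes a polynomial's image modulo $x^n + 1$, and pointwise products of these evaluations correspond to negative-wrapped convolution.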

In [5]:
# From here on, work in the negative-wrapped ring GF(q)[x] / (x^n + 1).
ff, indet, modulus_poly = generate_polynomial_ring(q, n, Convolution.NegativeWrapping)

def negwrap_ntt(poly: Poly, q, n):
    """Perform the negative-wrapped number theoretic transform of ``poly``.

    Evaluates ``poly`` at the odd powers phi**(2k + 1), k = 0..n-1, of a
    primitive 2n-th root of unity phi in GF(q). Those points are the n roots
    of x^n + 1, so pointwise products of transforms correspond to
    multiplication modulo x^n + 1.

    Args:
        poly: polynomial over GF(q) to transform.
        q: prime modulus; 2n must divide q - 1.
        n: transform length.

    Returns:
        List of n elements of GF(q). Entry j is the value at
        phi**(2 * (n - 1 - j) + 1), i.e. ordered from the highest odd power
        down, which is the ordering ``negwrap_inv_ntt`` expects.
    """
    # Build the field from the q we were given rather than a notebook global,
    # so the transform is correct for any valid q.
    field = GF(q)
    phi_2n = n_primitive_root(q, 2 * n)
    omega = phi_2n ** 2
    return [field(poly.eval(phi_2n * omega ** k)) for k in reversed(range(n))]

def negwrap_inv_ntt(points, q, n):
    """Invert ``negwrap_ntt``: recover the polynomial from its n evaluations.

    For i = 0..n-1 the coefficient of x^i is
        a_i = n^{-1} * phi^{-i} * sum_j points[n - 1 - j] * omega^{-i j},
    where phi is a primitive 2n-th root of unity in GF(q) and omega = phi^2.

    Args:
        points: n values in the order produced by ``negwrap_ntt``.
        q: prime modulus; 2n must divide q - 1.
        n: transform length.

    Returns:
        ``Poly`` over GF(q) in the module-level generator ``indet``.
    """
    field = GF(q)
    phi_2n = n_primitive_root(q, 2 * n)
    phi_inv = phi_2n ** -1
    omega_inv = phi_inv ** 2
    n_inv = field(n) ** -1
    # omega has order n, so omega^{-ij} depends only on (i * j) mod n: a table
    # of its n distinct powers replaces n^2 exponentiations with lookups.
    omega_inv_powers = [omega_inv ** k for k in range(n)]

    coeffs = [0 for _ in range(n)]
    for i in range(n):
        acc = sum(
            points[n - 1 - j] * omega_inv_powers[(i * j) % n] for j in range(n)
        )
        # Sympy lists coefficients from the highest degree down, so the
        # coefficient of x^i is stored at index n - 1 - i.
        coeffs[n - 1 - i] = n_inv * (phi_inv ** i) * acc
    # Use the field built from q (not a notebook global) for the result domain.
    return Poly(coeffs, indet, domain=field)


# Round trip: the inverse transform must recover the original polynomial.
# negwrap_inv_ntt already returns a Poly, so compare it directly.
a_poly = Poly([random.randint(1, q-1) for _ in range(n)], indet, domain=ff)
a_ntt = negwrap_ntt(a_poly, q, n)
assert negwrap_inv_ntt(a_ntt, q, n) == a_poly, "NTT inversion failed"

# Multiplication: pointwise products of the transforms must match the
# schoolbook product reduced modulo x^n + 1 (modulus_poly, set above).
a_poly = Poly([random.randint(1, q-1) for _ in range(n)], indet, domain=ff)
b_poly = Poly([random.randint(1, q-1) for _ in range(n)], indet, domain=ff)
expected_prod = (a_poly * b_poly) % modulus_poly

a_ntt = negwrap_ntt(a_poly, q, n)
b_ntt = negwrap_ntt(b_poly, q, n)
c_ntt = [ff(a) * ff(b) for (a, b) in zip(a_ntt, b_ntt)]
c_poly = negwrap_inv_ntt(c_ntt, q, n)
assert c_poly == expected_prod, "NTT multiplication failed"