## Polynomial Multiplication Using NTT

In [3]:
import numpy as np
import gmpy2
from gmpy2 import mpz


# Parameters for NTT
MODULUS = mpz('52435875175126190479447740508185965837690552500527637822603658699938581184513')
ROOT_OF_UNITY_BASE = mpz('10238227357739495823651030575849232062558860180284477541189508159991286009131') # 2^32

def get_root_of_unity(n):
    logn = n.bit_length() - 1
    s = 32
    if n != (1 << logn):
        raise ValueError("Expected n to be a power of 2.")
    if logn > s:
        raise ValueError("Expected logn <= s.")
    
    # Calculate the root of unity for the given n
    omega = ROOT_OF_UNITY_BASE
    for _ in range(s, logn, -1):
        omega = gmpy2.powmod(omega, 2, MODULUS)
    
    return omega

def ntt(values, omega, modulus):
    n = len(values)
    logn = n.bit_length() - 1
    result = [mpz(v) for v in values]
    
    # Bit-reverse permutation
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            result[i], result[j] = result[j], result[i]
    
    # Iterative NTT
    len_ = 2
    for s in range(1, logn + 1):
        m = len_
        len_ *= 2
        omega_m = gmpy2.powmod(omega, n // m, modulus)
        for k in range(0, n, m):
            factor = mpz(1)
            for j in range(m // 2):
                t = (factor * result[k + j + m // 2]) % modulus
                u = result[k + j]
                result[k + j] = (u + t) % modulus
                result[k + j + m // 2] = (u - t) % modulus
                factor = (factor * omega_m) % modulus
    
    return result

def inverse_ntt(values, omega_inv, modulus):
    n = len(values)
    result = ntt(values, omega_inv, modulus)
    n_inv = pow(n, -1, modulus)
    return [(x * n_inv) % modulus for x in result]

# Function to perform polynomial multiplication using NWC
def polynomial_multiplication_ntt(a, b):
    n = len(a)
    omega = get_root_of_unity(2*n)
    print("omega :", omega)
    modulus = MODULUS

    # Forward NTT on both polynomials
    ntt_a = ntt(a, omega, modulus)
    ntt_b = ntt(b, omega, modulus)
    
    # Pointwise multiplication in NTT domain
    ntt_c = [(ntt_a[i] * ntt_b[i]) % modulus for i in range(n)]
    
    # Inverse NTT
    omega_inv = gmpy2.invert(omega, modulus)

    c = inverse_ntt(ntt_c, omega_inv, modulus)
    
    return c

def print_polynomial(a):
    print('-'*50)
    for i, coeff in enumerate(a):
        print(f"Coefficient {i}: {hex(coeff)} ({coeff})")
    print('-'*50)

In [4]:
# Read input polynomials from binary files
a = [mpz(1), mpz(2)]
b = [mpz(3), mpz(5)]

# Perform polynomial multiplication using NWC and NTT
c = polynomial_multiplication_ntt(a, b)

# Print the result
print_polynomial(a)
print_polynomial(b)
print_polynomial(c)



omega : 23674694431658770659612952115660802947967373701506253797663184111817857449850
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x3 (3)
Coefficient 1: 0x5 (5)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0xd (13)
Coefficient 1: 0xb (11)
--------------------------------------------------


In [52]:
# Double Check C++ output

def read_bigint_file(filename, element_size=32):
    values = []
    with open(filename, "rb") as f:
        while True:
            data = f.read(element_size)
            if not data:
                break
            value = int.from_bytes(data, byteorder='little', signed=False)
            values.append(mpz(value))
    return values

def main():
    # Read input polynomials from binary files
    a = read_bigint_file("../build/data/input_a.txt")
    b = read_bigint_file("../build/data/input_b.txt")
    
    # Parameters for NTT
    modulus = mpz('0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001')
    omega = mpz('23674694431658770659612952115660802947967373701506253797663184111817857449850')
    # Primitive 2n-th root of unity

    # Perform polynomial multiplication using NWC and NTT
    c = polynomial_multiplication_ntt(a, b, omega, modulus)

    # Print the result
    print_polynomial(a)
    print_polynomial(b)
    print_polynomial(c)

if __name__ == "__main__":
    main()


--------------------------------------------------
Coefficient 0: 0x103ddbb56511e8ad72635cf9fcb37761303b3357b585ddef4196433a9e05aeb1 (7346299621127946297075567016002085968144408807773727290984631736966461107889)
Coefficient 1: 0x61985853a9a672a5e2c58180c6335eb2654606026d8b7d8faa1510dc01cc570c (44143516675643494996759749620782296438877098200381385840048821502585403627276)
Coefficient 2: 0x8783ccde718c38770a453132f7730d10f19ccc27b28049ed5c860bb77ba5732 (3830944092346738032114966288029954732053062639602335535784584908736449632050)
Coefficient 3: 0x66d98795a1b03fb5f62df7f91cfaf7df144595481193a4d07eb8ae0bcc151f9c (46520252138366734683603620910390867249064017136673153941519093918913103470492)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x6737126bed7b2f0236e93a38fcb61b9f5edb692295571e55c0f53ee63303b5fb (46685527133797472686420480590910538180586674551187103060334512123811375330811)
Coefficient 1: 0x34515984d975851fe2b3d44

In [18]:
A = 0x69a06a3b3f6677a49a1bbbc34e74e4fad0f07d0ded2040f0679455341ae888c9
B = 0x5f699348721708d60f83e3205a8a172b30584d19880212f22b6d0c83fad7c5e6
P = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
omega = 3465144826073652318776269530687742778270252468765361963008

A_o = (A + B * omega) % P
A_e = (A - B * omega) % P

print(hex(o))
print(hex(e))

0x3ff007093d47cd14370f5585a6f9a02850bdac81591dce2ad4de45951169918e
0x1f63261a17e7a4ecc9ee49f8ec4e51c7fd65a997812457b6fa4a64d424678003


### Root of Unity Check

In [55]:
import gmpy2

def verify_root_of_unity(modulus_str, root_of_unity_str, s, n):
    # Initialize modulus and root of unity
    modulus = gmpy2.mpz(modulus_str)
    root_of_unity = gmpy2.mpz(root_of_unity_str)
    
    # Verify that n is a power of 2 and n <= 2^s
    if not (n & (n - 1) == 0) or n > (1 << s):
        print("Invalid n: n must be a power of 2 and less than or equal to 2^s.")
        return
    
    # Calculate the expected exponent
    logn = n.bit_length() - 1
    current_s = s
    omega = root_of_unity
    
    # Calculate omega^(2^(s-logn))
    while current_s > logn:
        omega = gmpy2.powmod(omega, 2, modulus)
        current_s -= 1
    
    # Verify that omega^n % modulus == 1
    result = gmpy2.powmod(omega, n, modulus)
    if result != 1:
        print("The given root_of_unity is NOT a valid 2^{}-th root of unity.".format(logn))
        return
    
    print("The given root_of_unity is a valid 2^{}-th root of unity.".format(logn))

# Given values
modulus_str = "52435875175126190479447740508185965837690552500527637822603658699938581184513"
root_of_unity_str = "10238227357739495823651030575849232062558860180284477541189508159991286009131"
s = 32
n = 2**32

# Verify root of unity
verify_root_of_unity(modulus_str, root_of_unity_str, s, n)


The given root_of_unity is a valid 2^32-th root of unity.
