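## Polynomial Multiplication Using NTT

Multiplies polynomials modulo $x^n - 1$ (cyclic convolution) over the prime field given by `MODULUS` (the BLS12-381 scalar field) using an iterative radix-2 number-theoretic transform. It also recomputes products for the input files written by the C++ build, so the two can be compared.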

In [None]:
import numpy as np
import gmpy2
from gmpy2 import mpz

# Parameters for NTT
# Prime field modulus. It is the same value as the hex literal
# 0x73eda753...00000001 used later in this notebook; presumably the BLS12-381
# scalar field order r (TODO confirm against the C++ side).
MODULUS = mpz('52435875175126190479447740508185965837690552500527637822603658699938581184513')
# Per the trailing note, assumed to be a primitive 2^32-th root of unity mod
# MODULUS. get_root_of_unity() squares it down to the requested power-of-two order.
ROOT_OF_UNITY_BASE = mpz('10238227357739495823651030575849232062558860180284477541189508159991286009131') # 2^32

def get_root_of_unity(n):
    """Return a primitive n-th root of unity modulo MODULUS.

    ROOT_OF_UNITY_BASE is taken to have multiplicative order 2**32. Raising it
    to the power 2**(32 - log2 n) therefore yields an element of order exactly
    n. This is the same as squaring it (32 - log2 n) times.

    Args:
        n: transform length; must be a power of two with 1 <= n <= 2**32.

    Returns:
        An `mpz` omega with omega**n == 1 (mod MODULUS). It is primitive
        provided the base really has order 2**32.

    Raises:
        ValueError: if n is not a positive power of two, or exceeds 2**32.
    """
    # log2 of the order of ROOT_OF_UNITY_BASE (the field's two-adicity).
    s = 32
    # Check n >= 1 explicitly: n == 0 would otherwise fail with an unrelated
    # "negative shift count" error.
    if n < 1 or n & (n - 1):
        raise ValueError("Expected n to be a power of 2.")
    logn = n.bit_length() - 1
    if logn > s:
        raise ValueError("Expected logn <= s.")

    # One modular exponentiation replaces (s - logn) successive squarings.
    return gmpy2.powmod(ROOT_OF_UNITY_BASE, 1 << (s - logn), MODULUS)

def ntt(poly, omega, modulus, is_inv = False):
    """Iterative radix-2 (Cooley-Tukey, decimation-in-time) number-theoretic transform.

    Forward: returns A with A[k] = sum_j poly[j] * omega**(j*k) mod `modulus`.
    Inverse (is_inv=True): omega is replaced by its modular inverse, and every
    output is multiplied by n**-1 mod `modulus`. As a result,
    ntt(ntt(p, w, m), w, m, is_inv=True) reproduces p reduced mod m.

    Args:
        poly: sequence of n coefficients (anything `mpz()` accepts). n must be
            a power of two (1 included). An empty sequence yields [].
        omega: a primitive n-th root of unity modulo `modulus` (not verified).
        modulus: prime modulus. omega and n must be invertible modulo it for
            the inverse transform.
        is_inv: compute the inverse transform instead of the forward one.

    Returns:
        A new list of n `mpz` values in [0, modulus). `poly` is not modified.

    Raises:
        ValueError: if len(poly) is not a power of two.
    """
    n = len(poly)
    if n == 0:
        return []
    # Without this check the bit-reversal below silently produces garbage.
    if n & (n - 1):
        raise ValueError("Expected n to be a power of 2.")

    # Copy and reduce up front, so the output is canonical even when n == 1
    # (no butterflies run in that case).
    array = [mpz(e) % modulus for e in poly]

    # In-place bit-reversal permutation, so the iterative butterflies below
    # can combine adjacent halves in natural order.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            array[i], array[j] = array[j], array[i]

    if is_inv:
        omega = gmpy2.invert(omega, modulus)

    # Butterfly stages: group sizes 2, 4, ..., n.
    group_size = 1
    while group_size < n:
        half = group_size
        group_size *= 2
        # omega**(n / group_size) is a primitive group_size-th root of unity.
        step = gmpy2.powmod(omega, n // group_size, modulus)
        for start in range(0, n, group_size):
            factor = mpz(1)
            for k in range(start, start + half):
                bw = (factor * array[k + half]) % modulus
                a = array[k]
                array[k] = (a + bw) % modulus
                array[k + half] = (a - bw) % modulus
                factor = (factor * step) % modulus

    if is_inv:
        # Scale by n^-1 to undo the factor of n picked up by the round trip.
        n_inv = pow(n, -1, modulus)
        array = [(x * n_inv) % modulus for x in array]

    return array

# Polynomial multiplication modulo x^n - 1 (cyclic convolution) via NTT
def polynomial_multiplication_ntt(a, b, verbose=True):
    """Multiply two length-n coefficient lists modulo (x**n - 1) over MODULUS.

    This computes the *cyclic* (positive-wrapped) convolution:
    c[k] = sum over i + j == k (mod n) of a[i] * b[j], reduced mod MODULUS.
    It is NOT the negative-wrapped convolution (mod x**n + 1). It equals the
    ordinary polynomial product only when the inputs are zero-padded so that
    deg(a) + deg(b) < n.

    Args:
        a, b: coefficient sequences of the same length n. n must be a power
            of two no larger than 2**32.
        verbose: print the root of unity used. The default True keeps the
            notebook's original output.

    Returns:
        List of n `mpz` coefficients in [0, MODULUS).

    Raises:
        ValueError: if len(a) != len(b), or n is not a valid transform length.
    """
    n = len(a)
    # Mismatched lengths previously caused an IndexError or a silently wrong
    # result (b would be transformed with the wrong-order root).
    if len(b) != n:
        raise ValueError(f"Expected len(a) == len(b), got {n} and {len(b)}.")

    omega = get_root_of_unity(n)
    if verbose:
        print("omega :", omega)
    modulus = MODULUS

    ntt_a = ntt(a, omega, modulus)
    ntt_b = ntt(b, omega, modulus)
    # Pointwise product in the evaluation domain = cyclic convolution.
    ntt_c = [(x * y) % modulus for x, y in zip(ntt_a, ntt_b)]
    return ntt(ntt_c, omega, modulus, is_inv=True)

def print_polynomial(a):
    """Print each coefficient of `a` in hex and decimal, framed by dashed rules.

    Args:
        a: iterable of integers (int or mpz). Element i is printed as
           "Coefficient i: <hex> (<decimal>)".
    """
    rule = '-' * 50
    print(rule)
    index = 0
    for value in a:
        print(f"Coefficient {index}: {hex(value)} ({value})")
        index += 1
    print(rule)


# Test Code: a forward NTT followed by an inverse NTT must reproduce the input.
original_poly = [1, 2, 3, 4]
# Reuse the module constant instead of re-declaring the same prime as
# 0x73eda753...00000001; omega below is derived from MODULUS, so they must match.
modulus = MODULUS

omega = get_root_of_unity(len(original_poly))

ntt_result = ntt(original_poly, omega, modulus)
inv_result = ntt(ntt_result, omega, modulus, is_inv=True)

print("# NTT Test")
print_polynomial(original_poly)
print_polynomial(inv_result)

# Fail loudly instead of relying on a visual comparison of the two printouts.
assert inv_result == original_poly, "NTT round trip did not reproduce the input"

# NTT Test
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
Coefficient 2: 0x3 (3)
Coefficient 3: 0x4 (4)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
Coefficient 2: 0x3 (3)
Coefficient 3: 0x4 (4)
--------------------------------------------------


In [5]:
# Small hand-checkable example: (1 + 2x)(3 + 5x) = 3 + 11x + 10x^2.
# Reducing modulo x^2 - 1 (cyclic convolution with n = 2) folds 10x^2 into
# the constant term, giving 13 + 11x.
a = [mpz(1), mpz(2)]
b = [mpz(3), mpz(5)]

# Multiply via NTT (cyclic convolution modulo x^n - 1)
c = polynomial_multiplication_ntt(a, b)

# Print the result
print_polynomial(a)
print_polynomial(b)
print_polynomial(c)

assert c == [13, 11], "unexpected cyclic product for (1 + 2x)(3 + 5x) mod x^2 - 1"

omega : 3465144826073652318776269530687742778270252468765361963008
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x3 (3)
Coefficient 1: 0x5 (5)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0xd (13)
Coefficient 1: 0xb (11)
--------------------------------------------------


In [8]:
# Double Check C++ output

def read_bigint_file(filename, element_size=32):
    """Read a binary file of fixed-width, little-endian, unsigned integers.

    The file is consumed in chunks of `element_size` bytes. Each chunk is
    decoded with int.from_bytes(..., byteorder='little', signed=False) and
    wrapped in `mpz`.

    Args:
        filename: path of the file to read (opened in binary mode).
        element_size: bytes per integer (default 32).

    Returns:
        List of `mpz` values in file order (empty for an empty file).

    Raises:
        ValueError: if element_size < 1, or the file length is not a multiple
            of element_size. Previously a truncated trailing element was
            silently decoded as a smaller number.
    """
    if element_size < 1:
        raise ValueError("element_size must be a positive integer.")

    values = []
    with open(filename, "rb") as f:
        # iter(callable, sentinel) keeps reading until f.read returns b"" at EOF.
        for chunk in iter(lambda: f.read(element_size), b""):
            if len(chunk) != element_size:
                raise ValueError(
                    f"{filename}: trailing {len(chunk)} byte(s) do not form a "
                    f"complete {element_size}-byte element."
                )
            values.append(mpz(int.from_bytes(chunk, byteorder='little', signed=False)))
    return values

def main(data_dir="../build/data"):
    """Recompute the NTT product of the C++ test inputs and print it.

    Reads input_a.txt and input_b.txt from `data_dir` with read_bigint_file.
    The files are opened in binary mode despite the .txt suffix. The two
    polynomials are multiplied with polynomial_multiplication_ntt (cyclic
    convolution modulo x^n - 1), and a, b and the product are printed for
    comparison with the C++ output.

    Args:
        data_dir: directory holding the input files. The default is the
            originally hard-coded relative path.
    """
    # Read input polynomials from binary files
    a = read_bigint_file(f"{data_dir}/input_a.txt")
    b = read_bigint_file(f"{data_dir}/input_b.txt")

    # Multiply via NTT; the field modulus is MODULUS inside the helper, so no
    # local copy of it is needed here.
    c = polynomial_multiplication_ntt(a, b)

    # Print the result
    print_polynomial(a)
    print_polynomial(b)
    print_polynomial(c)

if __name__ == "__main__":
    main()


omega : 14788168760825820622209131888203028446852016562542525606630160374691593895118
--------------------------------------------------
Coefficient 0: 0x4b7d47d025435094dae81f85c2ce991dbcffe828b91d97e4606e0608ea4defc8 (34144815162426052549053130451593758100829933428756958429871266687858317848520)
Coefficient 1: 0x6d1cd0753cef52936afe549f2560daf08319cbc7547b8829aea7f9cc7e01f252 (49353010937360912679503738262692982172246841030904359723178893100382196724306)
Coefficient 2: 0x552ff07754c436720774f5e6d1e880d36e35d9d97fe5837432c3e0fc48ea0c63 (38531293577906059915352461199704451367706035671621843687369485518822173183075)
Coefficient 3: 0x50d6f437e5095a6c89e1bab4cd68279cdf977898a7a48a497e1b214fc8cb95ad (36564818691549931390331471265020958081606866926664567714799240820433718318509)
Coefficient 4: 0x426baf6291b85cdd6b96e30f7b442857bde6eb3d7b769601af4346e58fbf8110 (30042911105458455215085515726541680193528287050962282037862571049252791746832)
Coefficient 5: 0x24a6bffe8b8716905c3f56ac79b413fd3c50

In [18]:
# Hand check of one butterfly: (A, B) -> (A + B*omega, A - B*omega) mod P.
# NOTE(review): omega is hard-coded from an "omega :" line printed by an
# earlier run. In `ntt`, the single butterfly for n = 2 multiplies by a
# twiddle factor of 1, not omega. Confirm which value this check is meant
# to reproduce.
A = 0x69a06a3b3f6677a49a1bbbc34e74e4fad0f07d0ded2040f0679455341ae888c9
B = 0x5f699348721708d60f83e3205a8a172b30584d19880212f22b6d0c83fad7c5e6
P = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
omega = 3465144826073652318776269530687742778270252468765361963008

A_o = (A + B * omega) % P
A_e = (A - B * omega) % P

# Previously printed the undefined names `o` / `e` (NameError on a fresh kernel).
print(hex(A_o))
print(hex(A_e))

0x3ff007093d47cd14370f5585a6f9a02850bdac81591dce2ad4de45951169918e
0x1f63261a17e7a4ecc9ee49f8ec4e51c7fd65a997812457b6fa4a64d424678003


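### Root of Unity Check

Finds the 2-power multiplicative order of a candidate root $x$: the smallest $k$ with $x^{2^k} \equiv 1 \pmod p$. Such an $x$ is a primitive $2^k$-th root of unity.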

In [3]:
import gmpy2

# Given values
_modulus = mpz("52435875175126190479447740508185965837690552500527637822603658699938581184513")
_primitive_root = mpz("52435875175126190479447740508185965837690552500527637822603658699938581184512")

def _two_adic_log_order(x, modulus, max_log=32):
    """Return k such that x has multiplicative order exactly 2**k modulo `modulus`.

    Tries k = 0, 1, ..., max_log in increasing order and returns the first k
    with x**(2**k) == 1 (mod modulus). Because it is the first such k, the
    order is exactly 2**k. Starting at k = 0 means x == 1 is correctly
    reported as order 2**0; the previous loop started at 2**1.

    Returns:
        The exponent k, or None if x**(2**max_log) != 1. In that case the
        order is not a power of two no larger than 2**max_log.
    """
    power = x % modulus  # x**(2**0)
    for k in range(max_log + 1):
        if power == 1:
            return k
        power = power * power % modulus  # x**(2**(k + 1))
    return None

n = _two_adic_log_order(_primitive_root, _modulus)

if n is None:
    print("Not a Primitive Root")
else:
    print("Primitive Root of 2^" + str(n))



Primitive Root of 2^1
