## Polynomial Multiplication Using NTT

In [23]:
import numpy as np
import gmpy2
from gmpy2 import mpz

# Parameters for NTT
MODULUS = mpz('52435875175126190479447740508185965837690552500527637822603658699938581184513')
ROOT_OF_UNITY_BASE = mpz('10238227357739495823651030575849232062558860180284477541189508159991286009131') # 2^32

def get_root_of_unity(n):
    logn = n.bit_length() - 1
    s = 32
    if n != (1 << logn):
        raise ValueError("Expected n to be a power of 2.")
    if logn > s:
        raise ValueError("Expected logn <= s.")
    
    # Calculate the root of unity for the given n
    omega = ROOT_OF_UNITY_BASE
    for _ in range(s, logn, -1):
        omega = gmpy2.powmod(omega, 2, MODULUS)
    
    return omega

def ntt(poly, omega, modulus, is_inv = False):
    n = len(poly)
    logn = n.bit_length() - 1
    array = [mpz(e) for e in poly]
    
    # Bit-reverse permutation
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            array[i], array[j] = array[j], array[i]

    if is_inv == True:
        omega = gmpy2.invert(omega, modulus)
    
    # Iterative NTT
    group_size = 1
    for stage in range(1, logn+1):
        group_size *= 2
        # group_omega_mult = gmpy2.powmod(omega, n//group_size, modulus)   # CC
        group_omega_mult = gmpy2.powmod(omega, 2*n//group_size, modulus) # NWC
        for group in range(0, n, group_size):
            # factor = mpz(1) # CC
            factor = omega  # NWC
            for operation in range(group_size // 2):
                bw = (factor * array[group + operation + group_size // 2]) % modulus
                a  = array[group + operation]
                array[group + operation]                   = (a + bw) % modulus
                array[group + operation + group_size // 2] = (a - bw) % modulus
                factor = (factor * group_omega_mult) % modulus
    
    if(is_inv):
        n_inv = pow(n, -1, modulus)
        array = [(x * n_inv) % modulus for x in array]

    return array

# Function to perform polynomial multiplication using NWC
def polynomial_multiplication_ntt(a, b):
    n = len(a)
    omega = get_root_of_unity(2*n)
    print("omega :", omega)
    modulus = MODULUS

    ntt_a = ntt(a, omega, modulus)
    ntt_b = ntt(b, omega, modulus)
    ntt_c = [(ntt_a[i] * ntt_b[i]) % modulus for i in range(n)]
    c = ntt(ntt_c, omega, modulus, is_inv=True)
    
    return c

def print_polynomial(a):
    print('-'*50)
    for i, coeff in enumerate(a):
        print(f"Coefficient {i}: {hex(coeff)} ({coeff})")
    print('-'*50)


# Test Code
original_poly = [1, 2, 3, 4]
modulus = mpz('0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001')

omega = get_root_of_unity(2*len(original_poly))

ntt_result = ntt(original_poly, omega, modulus)
inv_result = ntt(ntt_result, omega, modulus, is_inv=True)

print("# NTT Test")
print_polynomial(original_poly)
print_polynomial(inv_result)

# NTT Test
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
Coefficient 2: 0x3 (3)
Coefficient 3: 0x4 (4)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x17eca7483bff965dbb7b51f57d45212903bf46cb83d2d5d5b4e01d1972ff3ff (676332872736091715400932514805816817517058529951182600044600057210235515903)
Coefficient 1: 0x25445706b42db4fbb8700f1eba81edce3b3a4d680b1db0c9c47dffac3a36bbd3 (16856321630652652541711063120926853213552839024594827881273456975097823149011)
Coefficient 2: 0xf130fef4fccb1eca790c9e4675a164dda55545e2cfa8636adac7d875d3d83a8 (6818372801006118123099606233844428212578938972943040083795105839868077245352)
Coefficient 3: 0x5bf12a10628c24f182738a1288ef557d383a6ea08b1ad2a56d5872fdc6a109e8 (41586569678772800452265997504835559933602557602265551713680108862036018006504)
--------------------------------------------------


In [24]:
# Read input polynomials from binary files
a = [mpz(1), mpz(2)]
b = [mpz(3), mpz(5)]

# Perform polynomial multiplication using NWC and NTT
c = polynomial_multiplication_ntt(a, b)

# Print the result
print_polynomial(a)
print_polynomial(b)
print_polynomial(c)

omega : 3465144826073652318776269530687742778270252468765361963008
--------------------------------------------------
Coefficient 0: 0x1 (1)
Coefficient 1: 0x2 (2)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x3 (3)
Coefficient 1: 0x5 (5)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x4f7e03342261b2b584c1b0016261b00000008fffffffffff7 (31186303434662870868986425776189685004432272218888257667063)
Coefficient 1: 0x11aa3999cec0609a1d8060004ec0600000002000000000002 (6930289652147304637552539061375485556540504937530723926018)
--------------------------------------------------


In [52]:
# Double Check C++ output

def read_bigint_file(filename, element_size=32):
    values = []
    with open(filename, "rb") as f:
        while True:
            data = f.read(element_size)
            if not data:
                break
            value = int.from_bytes(data, byteorder='little', signed=False)
            values.append(mpz(value))
    return values

def main():
    # Read input polynomials from binary files
    a = read_bigint_file("../build/data/input_a.txt")
    b = read_bigint_file("../build/data/input_b.txt")
    
    # Parameters for NTT
    modulus = mpz('0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001')
    omega = mpz('23674694431658770659612952115660802947967373701506253797663184111817857449850')
    # Primitive 2n-th root of unity

    # Perform polynomial multiplication using NWC and NTT
    c = polynomial_multiplication_ntt(a, b, omega, modulus)

    # Print the result
    print_polynomial(a)
    print_polynomial(b)
    print_polynomial(c)

if __name__ == "__main__":
    main()


--------------------------------------------------
Coefficient 0: 0x103ddbb56511e8ad72635cf9fcb37761303b3357b585ddef4196433a9e05aeb1 (7346299621127946297075567016002085968144408807773727290984631736966461107889)
Coefficient 1: 0x61985853a9a672a5e2c58180c6335eb2654606026d8b7d8faa1510dc01cc570c (44143516675643494996759749620782296438877098200381385840048821502585403627276)
Coefficient 2: 0x8783ccde718c38770a453132f7730d10f19ccc27b28049ed5c860bb77ba5732 (3830944092346738032114966288029954732053062639602335535784584908736449632050)
Coefficient 3: 0x66d98795a1b03fb5f62df7f91cfaf7df144595481193a4d07eb8ae0bcc151f9c (46520252138366734683603620910390867249064017136673153941519093918913103470492)
--------------------------------------------------
--------------------------------------------------
Coefficient 0: 0x6737126bed7b2f0236e93a38fcb61b9f5edb692295571e55c0f53ee63303b5fb (46685527133797472686420480590910538180586674551187103060334512123811375330811)
Coefficient 1: 0x34515984d975851fe2b3d44

In [18]:
A = 0x69a06a3b3f6677a49a1bbbc34e74e4fad0f07d0ded2040f0679455341ae888c9
B = 0x5f699348721708d60f83e3205a8a172b30584d19880212f22b6d0c83fad7c5e6
P = 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
omega = 3465144826073652318776269530687742778270252468765361963008

A_o = (A + B * omega) % P
A_e = (A - B * omega) % P

print(hex(o))
print(hex(e))

0x3ff007093d47cd14370f5585a6f9a02850bdac81591dce2ad4de45951169918e
0x1f63261a17e7a4ecc9ee49f8ec4e51c7fd65a997812457b6fa4a64d424678003


### Root of Unity Check

In [3]:
import gmpy2

# Given values
_modulus = mpz("52435875175126190479447740508185965837690552500527637822603658699938581184513")
_primitive_root = mpz("52435875175126190479447740508185965837690552500527637822603658699938581184512")

n = 0
_poly_size = 1
for i in range (32):
    _poly_size *= 2
    if gmpy2.powmod(_primitive_root, _poly_size, _modulus) == 1:
        n = _poly_size.bit_length() - 1
        break

if n == 0:
    print("Not a Primitive Root")
else : 
    print("Primitive Root of 2^" + str(n))



Primitive Root of 2^1
