## Polynomial Multiplication Using NTT

In [6]:
import numpy as np
import gmpy2
from gmpy2 import mpz

# Parameters for NTT
# Modulus that every NTT operation in this notebook reduces by
# (presumably the BLS12-381 scalar-field prime -- TODO confirm).
MODULUS = mpz('52435875175126190479447740508185965837690552500527637822603658699938581184513')
# Base root of unity: get_root_of_unity() treats it as having order 2^32 and
# squares it down to the order required by a given transform size.
ROOT_OF_UNITY_BASE = mpz('10238227357739495823651030575849232062558860180284477541189508159991286009131') # 2^32-th root of unity modulo MODULUS

def get_root_of_unity(n):
    """Return an n-th root of unity modulo MODULUS for a power-of-two n.

    ROOT_OF_UNITY_BASE is treated as a root of unity of order 2^s (s = 32).
    Raising it to the power 2^(s - log2(n)) yields an element of order n,
    provided the base really is a primitive 2^s-th root.

    Args:
        n: Transform size; must be a positive power of two no larger than 2^32.

    Returns:
        The derived root of unity as a gmpy2 mpz.

    Raises:
        ValueError: If n is not a positive power of two, or if n > 2^32.
    """
    s = 32
    logn = n.bit_length() - 1
    # Reject n <= 0 explicitly: for n == 0, logn is -1 and the shift below
    # would fail with an unrelated "negative shift count" error.
    if n <= 0 or n != (1 << logn):
        raise ValueError("Expected n to be a power of 2.")
    if logn > s:
        raise ValueError("Expected logn <= s.")

    # (s - logn) successive squarings are the same as one modular power
    # with exponent 2^(s - logn).
    return gmpy2.powmod(ROOT_OF_UNITY_BASE, 1 << (s - logn), MODULUS)

def primitive_root_check(modulus, primitive_root):
    """Print which power-of-two order, if any, primitive_root has modulo modulus.

    Finds the smallest k in 0..32 with primitive_root^(2^k) == 1 (mod modulus).
    Because k is the smallest such exponent, the value is a primitive
    2^k-th root of unity. Prints "Primitive Root of 2^k" when such a k
    exists and "Not a Primitive Root" otherwise.

    Args:
        modulus: Modulus of the arithmetic.
        primitive_root: Candidate root of unity.

    Returns:
        None; the verdict is written to standard output.
    """
    max_log = 32
    order_log = None

    # Start at exponent 2^0 so that a value congruent to 1 is reported as
    # order 2^0 rather than being mistaken for a primitive 2nd root.
    value = primitive_root % modulus
    for k in range(max_log + 1):
        if value == 1:
            order_log = k
            break
        # One squaring moves from exponent 2^k to 2^(k+1).
        value = (value * value) % modulus

    if order_log is None:
        print("Not a Primitive Root")
    else:
        print("Primitive Root of 2^" + str(order_log))

def ntt(poly, omega, modulus, is_inv = False):
    """Compute the forward or inverse number-theoretic transform of poly.

    The input is copied, permuted into bit-reversed order, and then combined
    with iterative radix-2 butterflies whose group size doubles each stage.
    When omega is a primitive len(poly)-th root of unity modulo modulus, the
    forward result is X[k] = sum_j poly[j] * omega^(j*k) mod modulus.

    Args:
        poly: Sequence of coefficients; its length must be a power of two
            (an empty sequence is returned as an empty list).
        omega: Root of unity for the forward transform. Pass the same value
            for the inverse; it is inverted internally.
        modulus: Modulus of the arithmetic.
        is_inv: When truthy, use omega^-1 and scale the result by n^-1.

    Returns:
        A new list of gmpy2 mpz values; poly itself is not modified.

    Raises:
        ValueError: If len(poly) is not a power of two.
    """
    n = len(poly)
    array = [mpz(e) for e in poly]
    if n == 0:
        # Nothing to transform; also avoids inverting 0 in inverse mode.
        return array
    if n & (n - 1):
        # The bit-reversal and butterfly indexing below are only valid for
        # power-of-two sizes.
        raise ValueError("Expected len(poly) to be a power of 2.")

    # Bit-reverse permutation: j tracks the bit-reversed counterpart of i by
    # incrementing a counter from the most significant bit downwards.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            array[i], array[j] = array[j], array[i]

    if is_inv:
        omega = gmpy2.invert(omega, modulus)

    # Iterative NTT: merge groups of size 2, 4, ..., n.
    group_size = 2
    while group_size <= n:
        half = group_size // 2
        # Twiddle step for this stage: a root of unity of order group_size.
        group_omega_mult = gmpy2.powmod(omega, n // group_size, modulus)
        for group in range(0, n, group_size):
            factor = mpz(1)
            for lo in range(group, group + half):
                hi = lo + half
                bw = (factor * array[hi]) % modulus
                a = array[lo]
                array[lo] = (a + bw) % modulus
                array[hi] = (a - bw) % modulus
                factor = (factor * group_omega_mult) % modulus
        group_size *= 2

    if is_inv:
        # The inverse transform needs the extra 1/n scaling.
        n_inv = gmpy2.invert(n, modulus)
        array = [(x * n_inv) % modulus for x in array]

    return array

# Function to perform cyclic polynomial multiplication using the NTT
def polynomial_multiplication_ntt(a, b):
    """Multiply two equal-length coefficient lists via NTT, modulo MODULUS.

    Both inputs are transformed with a root of unity of order len(a),
    multiplied pointwise, and transformed back. No zero padding or twisting
    is applied, so the result is the cyclic (wrap-around) convolution of
    length len(a): terms of degree >= len(a) fold back onto the low-order
    coefficients. Pad the inputs with zeros beforehand for a full product.

    As a side effect the chosen omega and its order check are printed.

    Args:
        a: First coefficient list, lowest degree first; length must be a
            power of two.
        b: Second coefficient list, same length as a.

    Returns:
        List of len(a) result coefficients reduced modulo MODULUS.

    Raises:
        ValueError: If the lengths differ, or if the length is rejected by
            get_root_of_unity.
    """
    n = len(a)
    if len(b) != n:
        # A shorter or longer b would be transformed with a root of the
        # wrong order and give a meaningless product.
        raise ValueError("Expected a and b to have the same length.")

    omega = get_root_of_unity(n)
    print("omega :", hex(omega))
    primitive_root_check(MODULUS, omega)
    modulus = MODULUS

    ntt_a = ntt(a, omega, modulus)
    ntt_b = ntt(b, omega, modulus)
    # Pointwise product in the transform domain.
    ntt_c = [(x * y) % modulus for x, y in zip(ntt_a, ntt_b)]
    c = ntt(ntt_c, omega, modulus, is_inv=True)

    return c

def print_polynomial(a):
    """Print every coefficient of a in hex and decimal between two rules.

    Args:
        a: Iterable of integer coefficients; the position in the iterable is
            printed as the coefficient index.
    """
    rule = '-' * 50
    print(rule)
    for idx, value in enumerate(a):
        as_hex = hex(value)
        print(f"Coefficient {idx}: {as_hex} ({value})")
    print(rule)

# ===========================================================
# Test Code
## Parameter Check
# Print the modulus and base root, then report the power-of-two order of the
# base root.
_modulus = MODULUS
_primitive_root = ROOT_OF_UNITY_BASE

print("# Parameter Check")
print("- Modulus:", hex(MODULUS))
print("- Base Primitive Root:", hex(_primitive_root))
print("=> ", end ='')
primitive_root_check(_modulus, _primitive_root)
print()

## NTT & INTT Check
# Round trip: a forward transform followed by the inverse should print the
# original coefficients again.
print("# NTT Test")
original_poly = [1, 2, 3, 4]
omega = get_root_of_unity(len(original_poly))
print([int(x) for x in original_poly])
ntt_result = ntt(original_poly, omega, MODULUS)
print([int(x) for x in ntt_result])
inv_result = ntt(ntt_result, omega, MODULUS, is_inv=True)
print([int(x) for x in inv_result])
print()

## Polynomial Multiplication Test
print("# Example Test")
print("(2x+1)(5x+3) => (11x+13)")
# Inputs are literal coefficient lists, lowest degree first:
# a = 1 + 2x, b = 3 + 5x. The product is a length-2 cyclic convolution, so
# the 10x^2 term wraps onto the constant term (3 + 10 = 13).
a = [mpz(1), mpz(2)]
b = [mpz(3), mpz(5)]
c = polynomial_multiplication_ntt(a, b)
print([int(x) for x in a])
print([int(x) for x in b])
print([int(x) for x in c])

# Parameter Check
- Modulus: 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
- Base Primitive Root: 0x16a2a19edfe81f20d09b681922c813b4b63683508c2280b93829971f439f0d2b
=> Primitive Root of 2^32

# NTT Test
[1, 2, 3, 4]
[10, 52435875175126190472517450856038661200138013439152152266063153762407857258495, 52435875175126190479447740508185965837690552500527637822603658699938581184511, 6930289652147304637552539061375485556540504937530723926014]
[1, 2, 3, 4]

# Example Test
(2x+1)(5x+3) => (11x+13)
omega : 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000000
Primitive Root of 2^1
[1, 2]
[3, 5]
[13, 11]


## Test
### Verify Multiplication

In [9]:
# Double Check C++ output

def read_bigint_file(filename, element_size=32):
    """Load fixed-width little-endian unsigned integers from a binary file.

    The file is consumed in reads of element_size bytes until a read returns
    no data. Each chunk is decoded as an unsigned little-endian integer; a
    final chunk shorter than element_size is decoded from the bytes present.

    Args:
        filename: Path of the file to read.
        element_size: Number of bytes requested per integer.

    Returns:
        List of gmpy2 mpz values in file order.
    """
    with open(filename, "rb") as stream:
        # iter() with a sentinel keeps reading until an empty bytes object
        # signals end of file.
        chunks = iter(lambda: stream.read(element_size), b"")
        return [
            mpz(int.from_bytes(chunk, byteorder='little', signed=False))
            for chunk in chunks
        ]

def verify_output(polynomial, filename, element_size=32):
    """Compare computed coefficients against the integers stored in a file.

    Prints a diagnostic for a length difference or for the first differing
    element, or "[+] Verified" when everything matches.

    Args:
        polynomial: Computed coefficients.
        filename: Binary file holding the reference values.
        element_size: Bytes per stored integer, forwarded to read_bigint_file.

    Returns:
        True when lengths and all elements agree, False otherwise.
    """
    reference = read_bigint_file(filename, element_size)

    # A length difference is reported before any element is compared.
    calc_len, file_len = len(polynomial), len(reference)
    if calc_len != file_len:
        print("Length mismatch: calculated ({}) vs file ({})".format(calc_len, file_len))
        return False

    # Stop at the first differing position.
    for idx, (calc, ref) in enumerate(zip(polynomial, reference)):
        if calc != ref:
            print("Mismatch found at index {}: calculated ({}) vs file ({})".format(idx, calc, ref))
            return False

    print("[+] Verified")
    return True

def main():
    """Recompute the polynomial product and compare it with the stored output.

    Reads the two input polynomials from ../build/data, multiplies them with
    polynomial_multiplication_ntt (which works modulo the module-level
    MODULUS), and checks the result against output_c.txt via verify_output.
    The data files are opened in binary mode despite their .txt extension.
    """
    # Read input polynomials from binary files
    a = read_bigint_file("../build/data/input_a.txt")
    b = read_bigint_file("../build/data/input_b.txt")

    # Perform polynomial multiplication using the NTT
    c = polynomial_multiplication_ntt(a, b)
    # Compare against the reference output; the verdict is printed
    verify_output(c, "../build/data/output_c.txt")

if __name__ == "__main__":
    main()


omega : 0x50e0903a157988bab4bcd40e22f55448bf6e88fb4c38fb8a360c60997369df4e
Primitive Root of 2^5
[+] Verified


### DRAM Access

In [9]:
import numpy as np
import gmpy2
from gmpy2 import mpz, random_state

# Parameters for NTT
# Modulus that every NTT operation in this cell reduces by
# (presumably the BLS12-381 scalar-field prime -- TODO confirm).
MODULUS = mpz('52435875175126190479447740508185965837690552500527637822603658699938581184513')
# Base root of unity: get_root_of_unity() treats it as having order 2^32 and
# squares it down to the order required by a given transform size.
ROOT_OF_UNITY_BASE = mpz('10238227357739495823651030575849232062558860180284477541189508159991286009131') # 2^32-th root of unity modulo MODULUS

def get_root_of_unity(n):
    """Return an n-th root of unity modulo MODULUS for a power-of-two n.

    ROOT_OF_UNITY_BASE is treated as a root of unity of order 2^s (s = 32).
    Raising it to the power 2^(s - log2(n)) yields an element of order n,
    provided the base really is a primitive 2^s-th root.

    Args:
        n: Transform size; must be a positive power of two no larger than 2^32.

    Returns:
        The derived root of unity as a gmpy2 mpz.

    Raises:
        ValueError: If n is not a positive power of two, or if n > 2^32.
    """
    s = 32
    logn = n.bit_length() - 1
    # Reject n <= 0 explicitly: for n == 0, logn is -1 and the shift below
    # would fail with an unrelated "negative shift count" error.
    if n <= 0 or n != (1 << logn):
        raise ValueError("Expected n to be a power of 2.")
    if logn > s:
        raise ValueError("Expected logn <= s.")

    # (s - logn) successive squarings are the same as one modular power
    # with exponent 2^(s - logn).
    return gmpy2.powmod(ROOT_OF_UNITY_BASE, 1 << (s - logn), MODULUS)

def primitive_root_check(modulus, primitive_root):
    """Print which power-of-two order, if any, primitive_root has modulo modulus.

    Finds the smallest k in 0..32 with primitive_root^(2^k) == 1 (mod modulus).
    Because k is the smallest such exponent, the value is a primitive
    2^k-th root of unity. Prints "Primitive Root of 2^k" when such a k
    exists and "Not a Primitive Root" otherwise.

    Args:
        modulus: Modulus of the arithmetic.
        primitive_root: Candidate root of unity.

    Returns:
        None; the verdict is written to standard output.
    """
    max_log = 32
    order_log = None

    # Start at exponent 2^0 so that a value congruent to 1 is reported as
    # order 2^0 rather than being mistaken for a primitive 2nd root.
    value = primitive_root % modulus
    for k in range(max_log + 1):
        if value == 1:
            order_log = k
            break
        # One squaring moves from exponent 2^k to 2^(k+1).
        value = (value * value) % modulus

    if order_log is None:
        print("Not a Primitive Root")
    else:
        print("Primitive Root of 2^" + str(order_log))

def ntt(poly, omega, modulus, is_inv = False):
    """Compute the forward or inverse number-theoretic transform of poly.

    The input is copied, permuted into bit-reversed order, and then combined
    with iterative radix-2 butterflies whose group size doubles each stage.
    When omega is a primitive len(poly)-th root of unity modulo modulus, the
    forward result is X[k] = sum_j poly[j] * omega^(j*k) mod modulus.

    Args:
        poly: Sequence of coefficients; its length must be a power of two
            (an empty sequence is returned as an empty list).
        omega: Root of unity for the forward transform. Pass the same value
            for the inverse; it is inverted internally.
        modulus: Modulus of the arithmetic.
        is_inv: When truthy, use omega^-1 and scale the result by n^-1.

    Returns:
        A new list of gmpy2 mpz values; poly itself is not modified.

    Raises:
        ValueError: If len(poly) is not a power of two.
    """
    n = len(poly)
    array = [mpz(e) for e in poly]
    if n == 0:
        # Nothing to transform; also avoids inverting 0 in inverse mode.
        return array
    if n & (n - 1):
        # The bit-reversal and butterfly indexing below are only valid for
        # power-of-two sizes.
        raise ValueError("Expected len(poly) to be a power of 2.")

    # Bit-reverse permutation: j tracks the bit-reversed counterpart of i by
    # incrementing a counter from the most significant bit downwards.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j ^= bit
        if i < j:
            array[i], array[j] = array[j], array[i]

    if is_inv:
        omega = gmpy2.invert(omega, modulus)

    # Iterative NTT: merge groups of size 2, 4, ..., n.
    group_size = 2
    while group_size <= n:
        half = group_size // 2
        # Twiddle step for this stage: a root of unity of order group_size.
        group_omega_mult = gmpy2.powmod(omega, n // group_size, modulus)
        for group in range(0, n, group_size):
            factor = mpz(1)
            for lo in range(group, group + half):
                hi = lo + half
                bw = (factor * array[hi]) % modulus
                a = array[lo]
                array[lo] = (a + bw) % modulus
                array[hi] = (a - bw) % modulus
                factor = (factor * group_omega_mult) % modulus
        group_size *= 2

    if is_inv:
        # The inverse transform needs the extra 1/n scaling.
        n_inv = gmpy2.invert(n, modulus)
        array = [(x * n_inv) % modulus for x in array]

    return array


def generate_random_poly(size, modulus, seed=None):
    """Generate size random coefficients, each uniform in [0, modulus - 1].

    Coefficients are drawn with gmpy2.mpz_random, which samples the range
    directly. Drawing a fixed number of random bits and reducing modulo
    modulus would over-represent small residues whenever modulus is not a
    power of two.

    Args:
        size: Number of coefficients to generate.
        modulus: Exclusive upper bound for every coefficient.
        seed: Optional integer seed for the GMP random state. When omitted,
            random_state() is created with gmpy2's default seed.

    Returns:
        List of size gmpy2 mpz values.
    """
    # Initialize the GMP random state
    rand_state = random_state() if seed is None else random_state(seed)
    return [gmpy2.mpz_random(rand_state, modulus) for _ in range(size)]


# Round-trip check on a random polynomial of 2^9 coefficients.
poly_size = 2**9
# NOTE(review): this seeds NumPy's global generator only; the coefficients
# below come from generate_random_poly, which draws from a gmpy2 random state,
# so this call does not influence them -- confirm whether it is still needed.
np.random.seed(42)
original_poly = generate_random_poly(poly_size, MODULUS)
omega = get_root_of_unity(poly_size)

print("# Parameter Check")
print("- Modulus:", hex(MODULUS))
print("- Omega:", hex(omega))
print("=> ", end ='')
primitive_root_check(MODULUS, omega)


# Forward transform followed by the inverse with the same omega.
ntt_result = ntt(original_poly, omega, MODULUS)
inverse_ntt_result = ntt(ntt_result, omega, MODULUS, is_inv=True)

# The round trip must reproduce the input exactly.
if original_poly == inverse_ntt_result:
    print("\n[Success] NTT & INTT Results are Same")
else:
    print("\n[Fail] NTT & INTT Results are Different")

# Parameter Check
- Modulus: 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001
- Omega: 0x95166525526a65439feec240d80689fd697168a3a6000fe4541b8ff2ee0434e
=> Primitive Root of 2^9

[Success] NTT & INTT Results are Same
