In [370]:
def center(x,q):
    """Return the centered representative of x modulo q.

    x may be anything Integer() accepts (int, Integer, IntegerModRing element).
    For odd q the result lies in [-(q-1)/2, (q-1)/2]; for even q it lies in
    [-(q/2 - 1), q/2] (previously even q raised TypeError).
    Returns 0 when q % 1 != 0 (kept for backward compatibility).
    """
    if q % 1 != 0:
        return 0
    # Floor division keeps qhalf integral for even q as well as odd q.
    qhalf = (Integer(q) - 1) // 2
    return (Integer(x) + qhalf) % q - qhalf

In [371]:
def find_n_cyclotomic_trinomial(lbits,hbits):
    """Return, ascending, every n = 2^i * 3^j (1 <= i <= hbits, 0 <= j <= hbits)
    with 2^lbits <= n <= 2^hbits.

    These are exactly the n shapes (2^a or 2^a * 3^c, a >= 1) that
    find_ntt_prime_cyclotomic accepts.
    """
    lo, hi = 2^lbits, 2^hbits
    candidates = [2^i * 3^j
                  for i in range(1, hbits + 1)
                  for j in range(hbits + 1)]
    return sorted(m for m in candidates if lo <= m <= hi)

In [372]:
def find_ntt_prime_cyclotomic(n,b,lbits,hbits):
    """List the primes q = 1 (mod 3n/b) with 2^lbits <= q <= 2^hbits, ascending.

    For such primes Z_q has elements of multiplicative order 3n/b (the root of
    unity gen_zetas works with).

    Returns 0 (not a list) when the parameters are unsupported:
      * n is not 2^a or 2^a * 3^c, or
      * b does not divide n, or b > n/2.
    """
    # Only n = 2^a or n = 2^a * 3^c is supported.
    primes_of_n = [p for p, _ in factor(n)]
    if primes_of_n not in ([2], [2, 3]):
        return 0

    # b must divide n and leave at least a quadratic split (b <= n/2).
    if n % b != 0 or n/2 < b:
        return 0

    w_order = 3*n/b
    lo, hi = 2^lbits, 2^hbits

    qs = []
    q = w_order + 1          # candidates are w_order*k + 1 for k = 1, 2, ...
    while q <= hi:
        if q >= lo and q in Primes():
            qs.append(q)
        q += w_order
    return qs

In [373]:
def find_generator(q):
    """Return every generator of the multiplicative group of Z_q, ascending,
    as elements of IntegerModRing(q).

    Uses the test: g generates iff g^((q-1)/p) != 1 for every prime p | q-1.
    The test is meaningful when q is prime (so that the group has order q-1).
    For q = 2 this now returns [1] as a Z_q element; previously it returned
    the raw range(1, 2) because q-1 has no prime factors.
    """
    Zq = IntegerModRing(q)
    # Hoisted: one exponent per distinct prime factor of q-1.
    exponents = [(q - 1) // p for p, _ in factor(q - 1)]
    return [Zq(g) for g in range(1, q)
            if all(Zq(g)^e != 1 for e in exponents)]

In [374]:
def find_w_order_cyclotomic_trinomial(n,b):
    """Order 3n/b of the root of unity w used for ring degree n and base degree b.

    Computed with exact '/', so the result is a rational (integral when b | 3n).
    """
    order = 3 * n / b
    return order

In [375]:
def find_w_cyclotomic_trinomial(n,q,b):
    """Return, sorted and de-duplicated, the values g^((q-1)/(3n/b)) over all
    generators g of Z_q^* (i.e. elements of multiplicative order 3n/b).

    Raises (via Integer()) when 3n/b does not divide q-1.
    """
    exponent = Integer((q - 1) / find_w_order_cyclotomic_trinomial(n, b))
    distinct_roots = {g^exponent for g in find_generator(q)}
    return sorted(distinct_roots)

In [376]:
def gen_zetas(n,q,b,w=None):
    """Generate the twiddle-factor table ("zetas") for degree n modulo q,
    stopping at base polynomials of degree b.

    Every entry is a power of w multiplied by 2^16 and centered mod q (see
    `center`). Order of entries:
      1. 2^16 mod q itself,
      2. the first-level radix-2 twiddle,
      3. two twiddles per node for each radix-3 level,
      4. one twiddle per node for each remaining radix-2 level.

    Args:
        n: ring degree.
        q: modulus.
        b: base-polynomial degree; n/b must be 2^r2 * 3^r3 with r2 >= 1.
        w: element of Z_q of multiplicative order 3n/b. Defaults to
           min(find_w_cyclotomic_trinomial(n, q, b)), the choice the notebook
           makes with `w = min(ws)`. Pass it explicitly; the default only
           exists so three-argument calls keep working (they used to read the
           notebook-global `w`).

    Returns:
        list of n/b centered integers.

    Raises:
        ValueError: if b does not divide n, or n/b is not 2^r2 * 3^r3 with r2 >= 1.
    """
    Zq = IntegerModRing(q)

    if n % b != 0:
        raise ValueError("b must divide n")
    m = Integer(n // b)
    if m <= 0:
        raise ValueError("n/b must be positive")
    radix2 = m.valuation(2)
    radix3 = m.valuation(3)
    # The first split is always radix-2, and only the primes 2 and 3 are handled.
    if radix2 < 1 or m != 2^radix2 * 3^radix3:
        raise ValueError("n/b must be 2^r2 * 3^r3 with r2 >= 1")
    level = radix2 + radix3

    if w is None:
        w = min(find_w_cyclotomic_trinomial(n, q, b))
    w = Zq(w)

    def scaled(e):
        # Centered representative of w^e * 2^16 mod q.
        return center(Integer(w^e * 2^16), q)

    # tree[l, i] holds the exponent of w for node i at level l; row 0 holds the order.
    tree = zero_matrix(ZZ, level + 1, m)
    tree[0, 0] = find_w_order_cyclotomic_trinomial(n, b)

    zetas = [center(Integer(Zq(2^16)), q)]

    # Level 1 (radix-2): exponents w_order/6 and 5*w_order/6, the two
    # primitive 6th roots of unity.
    tree[1, 0] = tree[0, 0] / 6
    tree[1, 1] = 5 * tree[0, 0] / 6
    zetas.append(scaled(tree[1, 0]))

    # Radix-3 levels: exponent e splits into e/3, e/3 + w_order/3, e/3 + 2*w_order/3.
    for l in range(1, radix3 + 1):
        for i in range(2 * 3^(l - 1)):
            e = tree[l, i] / 3
            tree[l + 1, 3*i    ] = e
            tree[l + 1, 3*i + 1] = e + tree[0, 0] / 3
            tree[l + 1, 3*i + 2] = e + 2 * tree[0, 0] / 3
            zetas.append(scaled(tree[l + 1, 3*i]))
            zetas.append(scaled(tree[l + 1, 3*i] * 2))

    # Remaining radix-2 levels: exponent e splits into e/2 and e/2 + w_order/2.
    for l in range(radix3 + 1, level):
        for i in range(2 * 3^radix3 * 2^(l - (radix3 + 1))):
            e = tree[l, i] / 2
            tree[l + 1, 2*i    ] = e
            tree[l + 1, 2*i + 1] = e + tree[0, 0] / 2
            zetas.append(scaled(tree[l + 1, 2*i]))

    return zetas

In [377]:
def print_zetas(zetas):
    """Print `zetas` as a C array declaration `const int16_t zetas[N] = {...};`.

    Values are formatted with "%5d", eight per line. Every full line ends with
    a comma; the last line has none.

    Raises:
        ValueError: if `zetas` is empty (previously an opaque IndexError).
    """
    count = len(zetas)
    if count == 0:
        raise ValueError("zetas must be non-empty")

    print("const int16_t zetas[%d] = {" % count)
    cells = ["%5d" % z for z in zetas]
    rows = [cells[k:k + 8] for k in range(0, count, 8)]
    for row in rows[:-1]:
        print("\t" + ", ".join(row) + ",")
    print("\t" + ", ".join(rows[-1]))
    print("};")

In [378]:
def print_params(n,q,d,w):
    """Print the zetas table and the scalar constants for ring degree n,
    modulus q, base-polynomial degree d and root of unity w (order 3n/d).

    Output, in order:
      * the zetas table (gen_zetas + print_zetas);
      * w^(w_order/3) * 2^16, centered mod q (under both ====ntt==== and ====invntt====);
      * (z - z^5)^-1 * 2^16 with z = w^(w_order/6);
      * 2^-1 * 2^16;
      * "level" = 2^16 * 2^(1-r2) * 3^(-r3), where n/d = 2^r2 * 3^r3;
      * 2^32 mod q (basemul);
      * QINV = q^-1 mod 2^16 (not centered).
    All values except QINV are centered mod q.
    """
    # The body previously read the notebook-global `b` and ignored `d`.
    b = d

    # NOTE(review): gen_zetas is not given `w`, so it picks its own root of unity.
    # Confirm it equals `w`, otherwise the table and the constants below disagree.
    zetas = gen_zetas(n, q, b)
    print_zetas(zetas)

    Zq = IntegerModRing(q)
    w_order = find_w_order_cyclotomic_trinomial(n, b)
    w_third = center(Integer(Zq(w)^(w_order/3)*2^16), q)

    print("====ntt====")
    print("w %d" % w_third)

    print("====invntt====")
    # NOTE(review): the same constant as in the ntt section is printed here;
    # confirm the inverse transform does not expect its inverse instead.
    print("w %d" % w_third)
    z = Zq(w)^(w_order/6)
    print("(z - z^5)^-1 %d" % center(Integer(Zq(2^16) * Zq(z - z^5)^-1), q))
    print("2^-1 %d" % center(Integer(Zq(2^-1 *2^16)), q))

    # n/b = 2^radix2 * 3^radix3; Integer() raises if b does not divide n.
    m = Integer(n / b)
    radix2 = m.valuation(2)
    radix3 = m.valuation(3)

    print("level %d" % center(Integer(Zq(2^16*2^(-radix2+1)*3^(-radix3))), q))
    print("====basemul====")
    print("basemul : %d" % center(Integer(Zq(2^32)), q))
    print("====reduce.h====")
    print("QINV : %d" % (q^-1 % 2^16))



In [379]:
# Candidate ring degrees n = 2^i * 3^j in [2^9, 2^11] and the base degrees b to try.
n_set = find_n_cyclotomic_trinomial(9, 11)
b_set = [1, 2, 3, 4, 6]

In [380]:
# For every (n, b) pair, list the suitable primes q below 2^15.
for n in n_set:
    print("n : ", n)
    for b in b_set:
        qs = find_ntt_prime_cyclotomic(n, b, 0, 15)
        if qs == 0:
            continue  # unsupported (n, b) combination
        print("b : ", b, ", qs : ", qs)
    print("")

n :  512
b :  1 , qs :  [7681, 10753, 12289, 15361, 18433, 23041, 26113, 32257]
b :  2 , qs :  [769, 7681, 10753, 12289, 14593, 15361, 18433, 22273, 23041, 26113, 26881, 31489, 32257]
b :  4 , qs :  [769, 1153, 2689, 3457, 4993, 6529, 7297, 7681, 9601, 10369, 10753, 12289, 13441, 14593, 15361, 18049, 18433, 20353, 21121, 22273, 23041, 26113, 26497, 26881, 29569, 31489, 31873, 32257]

n :  576
b :  1 , qs :  [3457, 8641, 10369, 12097, 19009]
b :  2 , qs :  [2593, 3457, 8641, 10369, 12097, 16417, 19009, 21601, 25057, 28513, 30241]
b :  3 , qs :  [577, 1153, 3457, 6337, 7489, 8641, 10369, 12097, 13249, 14401, 18433, 19009, 20161, 21313, 23041, 26497, 27073, 30529, 32257]
b :  4 , qs :  [433, 1297, 2161, 2593, 3457, 3889, 6481, 8209, 8641, 10369, 12097, 15121, 16417, 17713, 19009, 19441, 21169, 21601, 23761, 25057, 28081, 28513, 30241, 32401]
b :  6 , qs :  [577, 1153, 2017, 2593, 3169, 3457, 6337, 7489, 8353, 8641, 8929, 10369, 10657, 12097, 13249, 13537, 14401, 16417, 16993, 17569, 18433

In [385]:
# Chosen parameter set: ring degree n, modulus q = 10*(3n/b) + 1, base degree b.
n, q, b = 1296, 9721, 4

In [386]:
ws = find_w_cyclotomic_trinomial(n, q, b)  # all elements of order 3n/b mod q, sorted

In [387]:
w = min(ws)  # pick the smallest root of order 3n/b as the NTT root of unity

In [388]:
# Emit the zetas table and scalar constants for the chosen (n, q, b, w).
print_params(n, q, b, w)

const int16_t zetas[324] = {
	-2511,  -173,   147, -4782,  4502,  -147,   894, -4360,
	-3036,  1904,     2, -1866,   383,   561,  -583, -3848,
	 4720,  3036, -2746,    67,   322,  3110,  4713, -4212,
	 1310,  2584, -4253,   245,  3352,  4263,  3347,  3173,
	 1682,  4661,  1856,  -555, -4263,  2832, -4704,  2616,
	-4185,  2746, -1386, -4528, -2535,   925,  -786, -3347,
	  555, -2701,   -58, -4225,   -64,  4253,  2093,   164,
	-3389, -2979, -4410,  2603, -3342, -1800,  2346, -2157,
	-1769,  -592,  1212,    77,  2343,  3284,  3591,  2704,
	 -292,  2676,    13,  3790,  2696, -2101,  3736, -3290,
	  100,  1080,  2792,  -650,  1071,  2679,  1517,  1898,
	-4695,  1918, -1413,  3595, -2900,  4227, -3200, -2274,
	-1896, -2233,  4612,  1974,    61,  -648,  -377, -1098,
	 -416,  2281,  1217, -3083, -4776,  -269, -1918,  1152,
	-1446,  2547,   650, -2975, -1294, -3160,  1589,  3342,
	-4721, -2438,  3506, -1593,  -489,  3389,   592, -3678,
	-1358, -3727,  2524,  1253,   815,   773, -4464,  1785,
	 

In [389]:
Zq = IntegerModRing(q)
# 867^-1 * 2^16 mod q, centered.
# NOTE(review): the meaning of the constant 867 is not documented here.
center(Integer(Zq(867)^-1 * 2^16), q)

-1382