In [126]:
Zx.<x> = ZZ[]

#k = 3
#n = 2^k
#m = 2 * n
#q = 12 * 1024 + 1
#phi = x^n + 1
#sigma = 1.17 / sqrt(q / (2*n))   
RCDT = [3024686241123004913666, 1564742784480091954050, 636254429462080897535, 199560484645026482916, 47667343854657281903, 8595902006365044063, 1163297957344668388, 117656387352093658, 8867391802663976, 496969357462633, 20680885154299, 638331848991, 14602316184, 247426747, 3104126, 28824, 198, 1, 0]

In [127]:
def balance(f,q,n):
    g = list(((f[i] + q//2) % q) - q//2 for i in range(n))
    return Zx(g)

In [128]:
def merge(a,b,n):
    a = a.subs(x=x^2)
    b = x*b.subs(x=x^2)
    return a + b 

In [129]:
def split(f,n):
    f0 = list(f[2*i] for i in range(n//2+1))
    f1 = list(f[2*i+1] for i in range(n//2+1))
    return Zx(f0), Zx(f1)

In [150]:
def InnerProduct(a,b,n):
    s = [a[i]*b[i] for i in range(n)]
    return sum(s)

In [136]:
def EuclideanNorm(a,n):
    b = InnerProduct(a,a,n)
    return sqrt(float(b))

In [137]:
def FieldNorm(f, n):
    f0, f1 = split(f,n)
    iks = f.parent()([0, 1])
    return (f0^2 - iks * f1^2) % (iks^(n//2)+1)

In [138]:
def HermitianAdjointPoly(p, n):
    f=[p[0]]
    for i in range(1,n):    
        f.append(-p[n-i])
    return Zx(f) 

In [147]:
def gs_norm(f, g, q, n):
    T = Zx.change_ring(QQ).quotient(x^n+1) 
    # Using (3.9) with (3.8) or (3.10)    
    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)
    first = EuclideanNorm([*g.coefficients(sparse=False), *(-f).coefficients(sparse=False)], n)
    s1 = (q * T(f_star)) / T((f*f_star + g*g_star))
    s2 = (q * T(g_star)) / T((f*f_star + g*g_star))
    second = EuclideanNorm(list(s1) + list(s2), n)
    gamma = max(first, second)
    return gamma

In [166]:
from os import urandom

LN2 = 0.69314718056

def UniformBits(k):
    return int.from_bytes(bytes(list(floor(uniform(0, 256)) for i in range(k / 8))), 'little')

def BaseSampler():
    u = UniformBits(72)
    z_0 = 0
    RCDT = [3024686241123004913666, 1564742784480091954050, 636254429462080897535, 199560484645026482916, 47667343854657281903, 8595902006365044063, 1163297957344668388, 117656387352093658, 8867391802663976, 496969357462633, 20680885154299, 638331848991, 14602316184, 247426747, 3104126, 28824, 198, 1, 0]
    for i in range(0, 18):
        z_0 = z_0 + int(u<RCDT[i]) 
    return z_0

def ApproxExp(x, ccs):
    C = [0x00000004741183A3,0x00000036548CFC06,0x0000024FDCBF140A,0x0000171D939DE045,0x0000D00CF58F6F84,0x000680681CF796E3,0x002D82D8305B0FEA,0x011111110E066FD0,0x0555555555070F00,0x155555555581FF00,0x400000000002B400,0x7FFFFFFFFFFF4800,0x8000000000000000]
    y = C[0]
    z = floor(2^63*x)
    for i in range(1, 13):
        y = C[i] - (z*y) >> 63
    z = floor(2^63*ccs)
    y = (z*y) >> 63
    return y

def BerExp(x, ccs):
    s = floor(x/LN2)
    r = x - s*LN2
    s = min(s, 63)
    z = (2*ApproxExp(r, ccs) - 1) >> s
    for i in range(56, -8, -8):
        p = UniformBits(8)
        w = p - ((z >> i) & 0xFF)
        if int(w) == 0:
            break
    return int(w < 0)

def SamplerZ(mu, sigma, sigmamin, sigmamax):
    r = mu - int(floor(mu))
    ccs = sigmamin/sigma
    while True:
        z_0 = BaseSampler()
        b = UniformBits(8)&0x1
        z = b + (2*b-1)*z_0
        x = (z-r)^2/2/sigma^2 - z_0^2/2/sigmamax^2
        if BerExp(x, ccs) == 1:
            return z + int(floor(mu))

 

In [141]:
def gen_poly(n, q):
    phi = x^n + 1
    def D(mu=0):
        z = 0
        for i in range(1, 4096/n + 1):
            sigma_star = 1.17 * sqrt(q / 8192)
            sigmamin, sigmamax = 1.277833697, 1.8205
            zi = SamplerZ(mu, sigma_star, sigmamin, sigmamax)
            z += zi
        return z

    f = [0] * n
    g = [0] * n
    for i in range(n):
        f[i] = D()
        g[i] = D()
    f = Zx(f) % phi
    g = Zx(g) % phi
    return f, g


(166*x^7 + 128*x^6 + 186*x^5 + 122*x^4 + 151*x^3 + 59*x^2 + 161*x + 114, 114*x^7 + 166*x^6 + 106*x^5 + 103*x^4 + 128*x^3 + 143*x^2 + 66*x + 209)


In [152]:
def Reduce(f, g, F, G, n):
    
    T = Zx.change_ring(QQ).quotient(x^n+1) 
    
    iks = f.parent()([0, 1])
    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)
    while True:
        num = F*f_star + G*g_star
        num = T(num)
        den = f*f_star + g*g_star
        den = 1 / T(den)
        res = num * den
        k = Zx([int(round(elt)) for elt in res])
        F = F - k*f 
        G = G - k*g
        if all(elt == 0 for elt in k):
            break
    return f, g, F, G

In [144]:
def NTT(f, n, q):
    # Zp
    roots = (x^n + 1).roots(Integers(q))
    ans = [f.subs(x = i[0]) % q for i in roots]
    return ans

In [169]:
def NTRUSolve(f, g, n, q):
    if n == 1:
        # u, v are numbers
        gcd_, u, v = xgcd(f[0], g[0])
        print("gcd", u * f + v * g, "\n")
        if gcd_ != 1:
            return None, None, False
        F, G = -v*q, u*q
        print("F1, G1", F / q, G / q)
        return F, G, True
    else:
        # ▷ f′, g′, F′, G′ ∈ Z[x]/(x^n/2 + 1)
        # ▷ N as defined in either (3.25) or (3.26)
        f_ = FieldNorm(f, n) 
        g_ = FieldNorm(g, n) 
        print(n//2, f_, g_, sep="\n")
        F_, G_, flag = NTRUSolve(f_, g_, n//2, q)
        if flag:
            F = F_.subs(x=x^2) * g.subs(x=-x) 
            G = G_.subs(x=x^2) * f.subs(x=-x)
            print("F, G", F, G, sep="\n")
            f, g, F, G = Reduce(f, g, F, G, n)
            return F % (x^n +1), G % (x^n +1), flag
        else:
            return F_, G_, flag
            
        

In [170]:
def NTRUGen(q, n):
    # sigma = 1.17 / sqrt(q / (2*n)) 
    while True:

        f = -x^7 + 3*x^6 - x^4 + 4*x^3 + 6*x^2 - 2*x - 4
        g = x^7 - x^6 - 2*x^5 - 4*x^3 - 3*x^2 - x + 7
        
        if gs_norm(f, g, q, n) > (1.17 ** 2) * q:
            print("restart")

        print(f, g, sep="\n")
                
        F, G, flag = NTRUSolve(f, g, n, q)
        F, G = F % (x^n +1), G % (x^n +1)
        print("(f*G - g*F) % (x^n + 1) == q: ", (f*G - g*F) % (x^n + 1) == q)
        
        if not flag:
            continue
        else:
            F = Zx([int(coef) for coef in F.coefficients(sparse=False)])
            G = Zx([int(coef) for coef in G.coefficients(sparse=False)])
            break
            
    return f, g, F, G
    

NTRUGen(17, 8)

-x^7 + 3*x^6 - x^4 + 4*x^3 + 6*x^2 - 2*x - 4
x^7 - x^6 - 2*x^5 - 4*x^3 - 3*x^2 - x + 7
4
-51*x^3 + 51*x^2 - 54*x - 17
-33*x^3 - 4*x^2 - 47*x + 57
2
-2049*x + 3196
-1576*x + 6335
1
14412817
42616001
gcd 1 

F1, G1 5126443 15157932
F, G
137347660856*x + 552092278885
527996245356*x + 823560761424
F, G
15015*x^5 - 1820*x^4 + 100354*x^3 + 16363*x^2 + 112471*x + 136401
105060*x^5 + 105060*x^4 + 297237*x^3 + 150977*x^2 + 196938*x - 61999
F, G
-5*x^13 - 5*x^12 - 6*x^11 - 16*x^10 + 68*x^9 + x^8 + 17*x^7 - 33*x^6 - 8*x^5 + 160*x^4 + 64*x^3 - 172*x^2 + 20*x + 140
26*x^13 + 78*x^12 + 3*x^11 - 17*x^10 - 109*x^9 + 138*x^8 + 44*x^7 - 69*x^6 + 26*x^5 - 46*x^4 - 26*x^3 + 44*x^2 + 8*x - 16
(f*G - g*F) % (x^n + 1) == q:  True


(-x^7 + 3*x^6 - x^4 + 4*x^3 + 6*x^2 - 2*x - 4,
 x^7 - x^6 - 2*x^5 - 4*x^3 - 3*x^2 - x + 7,
 2*x^7 + x^6 + 2*x^5 - x^4 - 5*x^3 + 2*x - 5,
 -4*x^7 - x^6 - x^5 + x^4 + 3*x^3 + 2*x^2 - 2*x + 3)

In [178]:
def NTRUGen(q, n):
    # sigma = 1.17 / sqrt(q / (2*n)) 
    while True:
        while True:
            while True:

                f, g = gen_poly(n, q)
                print(f, g, sep="\n")

                if gs_norm(f, g, q, n) > (1.17 ** 2) * q:
                    print("restart norm\n")
                    continue
                break

            f_ntt = NTT(f, n, q)
            if any((elem == 0) for elem in f_ntt):
                print("restart ntt\n")
                continue
            break
                
        F, G, flag = NTRUSolve(f, g, n, q)
        
        if not flag:
            print("restart solve")
            continue
        else:
            F, G = F % (x^n +1), G % (x^n +1)
            F = Zx([int(coef) for coef in F.coefficients(sparse=False)])
            G = Zx([int(coef) for coef in G.coefficients(sparse=False)])
            print("(f*G - g*F) % (x^n + 1) == q", (f*G - g*F) % (x^n + 1) == q)
            break
            
    return f, g, F, G
    

NTRUGen(12 * 1024 + 1, 4)

373*x^3 + 303*x^2 + 206*x + 159
252*x^3 + 365*x^2 + 335*x + 275
2
193047*x + 87148
152029*x + 111240
1
44861918113
35487154441
gcd 1 

F1, G1 -17397551500 -13761997323
F, G
32503573735093121500*x - 23782946295060540000
32648337420452781309*x - 14738573039299336356
F, G
-14930748*x^5 + 21625885*x^4 - 16478419*x^3 + 11412330*x^2 + 4479955*x - 3677575
-20147968*x^5 + 16366848*x^4 - 10838221*x^3 + 8353719*x^2 + 159650*x - 123225
(f*G - g*F) % (x^n + 1) == q True


(373*x^3 + 303*x^2 + 206*x + 159,
 252*x^3 + 365*x^2 + 335*x + 275,
 205*x^3 + 95*x^2 + 107*x + 29,
 163*x^3 + 142*x^2 + 134*x + 104)

In [93]:
def KeyGen(sigma, q, n):
    f, g, F, G = NTRUGen(q, n)
    B = PolyToLattice4(g, G, -f, -F, n)
    B_ = FFT(B)
    B_adj = HermitianAdjointMatrix(B_, n2)
    G = mult_matrix(B_, B_adj)
    # Computing the LDL⋆ tree
    T = LDL(g00, g01, g10, g11, n)
    
    for leaf in T:
        leaf.value = sigma / sqrt(leaf.value)
        
    sk = (B_, T)
    f_q = invertmodprime(f, q, n)
    h = g * f_q % p
    pk = h
    return sk, pk