# In [1]:
from Crypto.Hash import SHAKE256
from os import urandom

LN2 = 0.69314718056

n = 512
q = 12 * 1024 + 1
sigma = 1.17 / sqrt(q / (2*n)) 
SALT_LEN = 40
Zx.<x> = ZZ[]
RCDT = [3024686241123004913666, 1564742784480091954050, 636254429462080897535, 199560484645026482916, 47667343854657281903, 8595902006365044063, 1163297957344668388, 117656387352093658, 8867391802663976, 496969357462633, 20680885154299, 638331848991, 14602316184, 247426747, 3104126, 28824, 198, 1, 0]

def balance(f, q, n):
    """Return the polynomial whose coefficients are f[0..n-1] reduced to centered residues mod q.

    Each coefficient a becomes ((a + q//2) mod q) - q//2, which lies in
    [-(q//2), q - 1 - q//2]. Coefficients at or beyond index n are discarded.
    """
    half_q = q // 2
    centered = [(f[idx] + half_q) % q - half_q for idx in range(n)]
    return Zx(centered)

def split(f, n):
    """Split f into (f0, f1) with f(x) = f0(x^2) + x * f1(x^2).

    Reads coefficients at indices 0 .. n+1. This assumes f[k] returns 0 past the
    polynomial's degree, as Sage polynomials do. For f of degree < n the last
    entries are therefore zero and drop out of the resulting Zx polynomials.
    """
    count = n // 2 + 1
    evens = [f[2 * j] for j in range(count)]
    odds = [f[2 * j + 1] for j in range(count)]
    return Zx(evens), Zx(odds)

def InnerProduct(a, b, n):
    """Return the dot product of the first ``n`` entries of ``a`` and ``b``.

    Entries beyond index n-1 are ignored. For plain lists shorter than ``n``,
    indexing raises IndexError.
    """
    return sum(a[i] * b[i] for i in range(n))

def EuclideanNorm(a, n):
    """Return the Euclidean (L2) norm of the first ``n`` entries of ``a``.

    Only a[0:n] contributes. Callers wanting the norm of the whole sequence must
    pass len(a) as ``n``. The sum of squares is converted to float before the
    square root is taken.
    """
    sum_of_squares = InnerProduct(a, a, n)
    return sqrt(float(sum_of_squares))

def FieldNorm(f, n):
    """Return f0^2 - x*f1^2 reduced mod x^(n/2) + 1, where f(x) = f0(x^2) + x*f1(x^2).

    This maps a polynomial of the degree-n ring to the degree-n/2 ring; NTRUSolve
    uses it to descend one recursion level.
    """
    f_even, f_odd = split(f, n)
    y = f.parent()([0, 1])  # the generator, built in f's own polynomial ring
    half_modulus = y^(n//2) + 1
    return (f_even^2 - y * f_odd^2) % half_modulus

def HermitianAdjointPoly(p, n):
    """Return the adjoint p*(x) = p[0] - sum_{i=1}^{n-1} p[n-i] x^i.

    In Z[x]/(x^n + 1) this is p(1/x), because x^{-i} = -x^{n-i}.
    """
    reflected = [p[0]] + [-p[n - i] for i in range(1, n)]
    return Zx(reflected)

def UniformBits(k):
    """Return a uniformly random integer with ``k`` bits, i.e. 0 <= result < 2**k.

    Randomness comes from ``os.urandom`` (the OS CSPRNG). The previous
    ``random``-style ``uniform`` source is not suitable for key generation or
    signing. When ``k`` is not a multiple of 8, whole bytes are drawn and the
    surplus low bits are shifted out, so the result has exactly ``k`` bits.
    """
    k = int(k)
    nbytes = (k + 7) // 8
    value = int.from_bytes(urandom(nbytes), 'big')
    return int(value >> (8 * nbytes - k))

def BaseSampler():
    """Sample a non-negative integer z0 by inverse-CDT lookup in the module-level RCDT table.

    Draws a uniform 72-bit integer u and returns how many RCDT thresholds are
    strictly greater than u. The table is decreasing, so this count is the
    sampled value. The shared module-level table is used instead of a private
    copy, so the two cannot drift apart. Its trailing 0 entry never exceeds u,
    so iterating over the whole table matches counting the first 18 entries.
    """
    u = UniformBits(72)
    z0 = 0
    for threshold in RCDT:
        z0 += int(u < threshold)
    return z0

def ApproxExp(x, ccs):
    """Return an integer approximation of 2^63 * ccs * exp(-x).

    C holds fixed-point (scale 2^63) polynomial coefficients for exp(-x). They
    are evaluated with Horner's rule in 63-bit fixed point, and the result is
    then scaled by ccs. Intended for 0 <= x < ln 2 and 0 < ccs <= 1, the range
    BerExp passes in.
    """
    C = [0x00000004741183A3,0x00000036548CFC06,0x0000024FDCBF140A,0x0000171D939DE045,0x0000D00CF58F6F84,0x000680681CF796E3,0x002D82D8305B0FEA,0x011111110E066FD0,0x0555555555070F00,0x155555555581FF00,0x400000000002B400,0x7FFFFFFFFFFF4800,0x8000000000000000]
    y = C[0]
    z = floor(2**63 * x)
    for i in range(1, 13):
        # The shift must apply to the product only: in Python, binary minus binds
        # tighter than >>, so the parentheses around the shift are required.
        y = C[i] - ((z * y) >> 63)
    z = floor(2**63 * ccs)
    y = (z * y) >> 63
    return y

def BerExp(x, ccs):
    """Return 1 with probability approximately ccs * exp(-x), else 0.

    Writes x = s*ln2 + r with 0 <= r < ln2 and computes z ~ 2^64 * ccs * exp(-r) / 2^s.
    It then compares z against a fresh uniform 64-bit value, one byte at a time
    from the most significant byte. The first byte that differs decides the
    outcome; equal bytes move on to the next one. Returns 1 iff the uniform
    value is below z.
    """
    s = floor(x / LN2)
    r = x - s * LN2
    s = min(s, 63)  # larger shifts would make z zero anyway
    z = (2 * ApproxExp(r, ccs) - 1) >> s
    for i in range(56, -8, -8):
        p = UniformBits(8)
        w = p - ((z >> i) & 0xFF)
        if w != 0:
            # This byte differs from z's byte, so the comparison is decided here.
            break
    return int(w < 0)

def SamplerZ(mu, sigma, sigmamin, sigmamax):
    """Sample an integer near ``mu`` from a discrete Gaussian of width ``sigma`` by rejection.

    Each attempt works as follows:
    - Draw z0 >= 0 from BaseSampler.
    - Draw a random bit b and form z = b + (2b - 1) * z0.
    - Accept with probability BerExp(x, sigmamin/sigma), where
      x = (z - r)^2 / (2 sigma^2) - z0^2 / (2 sigmamax^2) and r = mu - floor(mu).

    An accepted z is returned shifted by floor(mu). Loops until an attempt is
    accepted.
    """
    mu_floor = int(floor(mu))
    r = mu - mu_floor
    ccs = sigmamin/sigma
    while True:
        z0 = BaseSampler()
        sign_bit = UniformBits(8) & 0x1
        z = sign_bit + (2*sign_bit - 1) * z0
        x = (z-r)^2/2/sigma^2 - z0^2/2/sigmamax^2
        accepted = BerExp(x, ccs) == 1
        if accepted:
            return z + mu_floor

def Reduce(f, g, F, G, n):
    """Size-reduce (F, G) against (f, g) in Q[x]/(x^n + 1) by repeated round-off.

    Each round computes k = round((F f* + G g*) / (f f* + g g*)) with exact
    rational arithmetic in the quotient ring. It then replaces F -= k f and
    G -= k g (products reduced mod x^n + 1), and stops after the round where k
    is zero. Returns (f, g, F, G); f and g are returned unchanged.
    """
    modulus = x^n + 1
    T = Zx.change_ring(QQ).quotient(modulus)

    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)
    # The denominator depends only on f and g, so invert it once, outside the loop.
    den_inv = 1 / T(f*f_star + g*g_star)
    while True:
        num = T(F*f_star + G*g_star)
        k = Zx([int(round(coef)) for coef in num * den_inv])
        F = F - (k * f) % modulus
        G = G - (k * g) % modulus
        if k.is_zero():
            break
    return f, g, F, G

def NTT(f, n, q):
    """Evaluate ``f`` at every root of x^n + 1 in Z/qZ.

    This is naive multipoint evaluation (one substitution per root), not a fast
    transform. Roots are recomputed on every call. Returns a list with one
    residue mod q per root, in the order Sage's ``roots`` lists them.
    """
    root_pairs = (x^n + 1).roots(Integers(q))  # list of (root, multiplicity)
    return [f.subs(x=root) % q for root, _multiplicity in root_pairs]

def NTRUSolve(f, g, n, q):
    """Recursively find F, G with f*G - g*F = q in Z[x]/(x^n + 1).

    Base case (n == 1): xgcd gives u*f + v*g = gcd. When the gcd is 1, F = -v*q and
    G = u*q satisfy the equation. Otherwise the result is (None, None, False).

    Recursive case: solve for the field norms of f and g at degree n/2. Lift with
    F = F'(x^2) * g(-x) and G = G'(x^2) * f(-x), then size-reduce with Reduce.

    Returns (F, G, True) on success; F and G are reduced mod x^n + 1 except in the
    base case. Returns (None, None, False) when the gcd condition fails at the
    bottom of the recursion.
    """
    if n == 1:
        divisor, u, v = xgcd(f[0], g[0])
        if divisor != 1:
            return None, None, False
        return -v*q, u*q, True

    F_half, G_half, ok = NTRUSolve(FieldNorm(f, n), FieldNorm(g, n), n//2, q)
    if not ok:
        return F_half, G_half, ok

    F = F_half.subs(x=x^2) * g.subs(x=-x)
    G = G_half.subs(x=x^2) * f.subs(x=-x)
    f, g, F, G = Reduce(f, g, F, G, n)
    modulus = x^n + 1
    return F % modulus, G % modulus, ok
            
def NTRUGen(q, n):
    """Generate NTRU secret polynomials (f, g, F, G) in Z[x]/(x^n + 1).

    Resamples (f, g) until all three conditions hold:
    1. gs_norm(f, g) passes the bound check below.
    2. No evaluation of f at the roots of x^n + 1 mod q is zero (f invertible mod q).
    3. NTRUSolve succeeds.

    Returns (f, g, F, G), where F and G are reduced mod x^n + 1 and rebuilt as
    Zx polynomials with plain int coefficients.
    """

    def gen_poly(n, q):
        """Sample f and g, each with n coefficients drawn independently from D()."""
        # Loop invariants of D(): every sample uses the same widths.
        sigma_star = 1.17 * sqrt(q / 8192)
        sigmamin, sigmamax = 1.277833697, 1.8205

        def D(mu=0):
            # Sum of 4096/n samples of width sigma_star. Ideally its variance is
            # (4096/n) * sigma_star^2 = 1.17^2 * q / (2n).
            # Floor division keeps the sample count an integer.
            z = 0
            for _ in range(4096 // n):
                z += SamplerZ(mu, sigma_star, sigmamin, sigmamax)
            return z

        f = [0] * n
        g = [0] * n
        for i in range(n):
            f[i] = D()
            g[i] = D()
        f = Zx(f) % (x^n + 1)
        g = Zx(g) % (x^n + 1)
        return f, g

    def gs_norm(f, g, q, n):
        """Return max(||(g, -f)||, ||(q f*/(f f* + g g*), q g*/(f f* + g g*))||)."""
        T = Zx.change_ring(QQ).quotient(x^n+1)
        f_star = HermitianAdjointPoly(f, n)
        g_star = HermitianAdjointPoly(g, n)
        # EuclideanNorm only sums the first `count` entries, so pass each vector's
        # full length; otherwise half of every vector would be ignored.
        first_vec = [*g.coefficients(sparse=False), *(-f).coefficients(sparse=False)]
        first = EuclideanNorm(first_vec, len(first_vec))
        den = T(f*f_star + g*g_star)
        s1 = (q * T(f_star)) / den
        s2 = (q * T(g_star)) / den
        second_vec = list(s1) + list(s2)
        second = EuclideanNorm(second_vec, len(second_vec))
        return max(first, second)

    while True:
        f, g = gen_poly(n, q)
        # NOTE(review): gs_norm returns a norm (not a squared norm) but is compared
        # against 1.17^2 * q, so this check almost never rejects. The Falcon
        # specification appears to bound the norm by 1.17 * sqrt(q) (confirm).
        # Tighten it only once SamplerZ's output distribution has been validated;
        # otherwise keygen may never terminate.
        if gs_norm(f, g, q, n) > (1.17 ** 2) * q:
            continue
        # f must be invertible mod q: no zero evaluation at the roots of x^n + 1.
        if 0 in NTT(f, n, q):
            continue
        F, G, flag = NTRUSolve(f, g, n, q)
        if flag:
            break

    F, G = F % (x^n + 1), G % (x^n + 1)
    F = Zx([int(coef) for coef in F.coefficients(sparse=False)])
    G = Zx([int(coef) for coef in G.coefficients(sparse=False)])
    return f, g, F, G

def NaiveKeyGen(sigma, q, n):
    """Generate a key pair for the naive scheme.

    Returns (sk, h):
    - sk is the secret basis [[g, -f], [G, -F]] from NTRUGen.
    - h = g * f^{-1} mod (q, x^n + 1) is the public key.

    ``sigma`` is accepted for interface compatibility but is not used.
    """
    f, g, F, G = NTRUGen(q, n)
    sk = [[g, -f], [G, -F]]

    # Invert f in (Z/qZ)[x]/(x^n + 1), then lift back to integer coefficients.
    T = Zx.change_ring(Integers(q)).quotient(x^n+1)
    f_inv = Zx(lift(1 / T(f)))
    h = g * f_inv % (x^n + 1) % q
    return sk, h

def HashToPoint(salt, message, q, n):
    """Hash ``salt || message`` with SHAKE256 to a polynomial with ``n`` coefficients in [0, q).

    Reads 16-bit big-endian values from the SHAKE256 output stream. Values >= k*q,
    with k = floor(2^16 / q), are rejected, so reducing the accepted ones mod q
    is unbiased. Returns the resulting Zx polynomial reduced mod x^n + 1.
    """
    k = int((2**16) // q)
    limit = k * q

    shake = SHAKE256.new()
    shake.update(salt)
    shake.update(message)

    hashed = []
    while len(hashed) < n:
        elt = int.from_bytes(shake.read(2), 'big')
        if elt < limit:
            hashed.append(elt % q)
    return Zx(hashed) % (x^n + 1)

def NaiveSign(message, B, q, n):
    """Sign ``message`` with the secret basis B = [[g, -f], [G, -F]] by Babai round-off.

    Steps:
    1. Hash a fresh random salt and the message to c.
    2. Compute t = (c, 0) * B^-1 and round it coefficient-wise to an integer vector z.
    3. Output s = (c, 0) - z * B. Every such s satisfies s1 + s2*h = c (mod q),
       and rounding the correct t keeps s short.

    Returns (noise, s2, beta):
    - noise is the salt.
    - s2 is the second signature component with coefficients reduced mod q.
    - beta is EuclideanNorm of the balanced (s1, s2) coefficient lists, computed
      exactly as NaiveVerify recomputes it.

    NOTE(review): rounding is deterministic (no randomized sampler), so this is a
    didactic signer, not Falcon's ffSampling-based one.
    """
    modulus = x^n + 1
    noise = urandom(SALT_LEN)
    c = HashToPoint(noise, message, q, n)

    # det(B) = f*G - g*F = q, so B^-1 = (1/q) * [[-F, f], [-G, g]].
    B_inv = [[B[1][1]/q, -B[0][1]/q], [-B[1][0]/q, B[0][0]/q]]

    # t = (c, 0) * B^-1 = (c * B_inv[0][0], c * B_inv[0][1]).
    t0 = c * B_inv[0][0] % modulus
    t1 = c * B_inv[0][1] % modulus

    # padded_list(n) keeps each half exactly n long even when leading
    # coefficients vanish, so the two halves cannot bleed into each other.
    z1 = Zx([int(round(float(el))) for el in t0.padded_list(n)])
    z2 = Zx([int(round(float(el))) for el in t1.padded_list(n)])

    # s = (c, 0) - z * B, one component at a time in Z[x]/(x^n + 1).
    zb1 = (z1 * B[0][0] + z2 * B[1][0]) % modulus
    zb2 = (z1 * B[0][1] + z2 * B[1][1]) % modulus
    s1 = (c - zb1) % q
    s2 = (-zb2) % q

    # NOTE(review): EuclideanNorm is given n while the list holds up to 2n entries,
    # so only the s1 part contributes. Kept as-is because NaiveVerify must compute
    # the identical value.
    beta = EuclideanNorm(balance(s1, q, n).coefficients(sparse=False) + balance(s2, q, n).coefficients(sparse=False), n)

    return noise, s2, beta

def NaiveVerify(message, noise, s2, h, q, n, beta):
    """Recompute s1 = c - s2*h for the hashed message and check its norm against ``beta``.

    Returns True iff EuclideanNorm of the balanced (s1, s2) coefficient lists,
    with length argument n, equals ``beta`` exactly.

    NOTE(review): the norm is never compared with a fixed acceptance bound, and
    ``beta`` is supplied by the signer. Confirm whether this check is meant to be
    a real authenticity test.
    """
    digest_poly = HashToPoint(noise, message, q, n)
    s1 = (digest_poly - s2 * h) % (x^n + 1)
    coeffs = (balance(s1, q, n).coefficients(sparse=False)
              + balance(s2, q, n).coefficients(sparse=False))
    return bool(EuclideanNorm(coeffs, n) == beta)
    
def test():
    """Run one key-generation / sign / verify round trip and print "Passed" or "Failed"."""
    secret_key, public_key = NaiveKeyGen(sigma, q, n)

    msg = b"hello_world!"
    salt, sig, norm_value = NaiveSign(msg, secret_key, q, n)

    verified = NaiveVerify(msg, salt, sig, public_key, q, n, norm_value)
    print("Passed" if verified else "Failed")


test()

# Output: Passed
