In [139]:
# Polynomial rings: Zx = Z[x] (integer coefficients), Qx = Q[y] (rational coefficients).
Zx.<x> = ZZ[]
Qx.<y> = QQ[]
# Ring degree; polynomials are reduced modulo phi = x^n + 1.
n = 512
# sigmamin / sigmamax parameters forwarded to SamplerZ.
sigmamin = 1.277833697
sigmamax = 1.8205
# Modulus q = 12289.
q = 12*1024+1
phi = x^n+1
from os import urandom
from Crypto.Hash import SHAKE256
# ln(2), used by BerExp to split x = s*ln2 + r.
LN2 = 0.69314718056
# Passed to NaiveKeyGen as its sigma argument.
# NOTE(review): Falcon's sigma_{f,g} is 1.17 * sqrt(q / (2n)); this divides
# instead of multiplying - confirm the intended value.
sigma_star = 1.17 / sqrt(q / (2*n)) 
# Standard deviation used by normalize_tree to turn ffLDL leaves into sampler sigmas.
sigma_for_tree = 165.736617183
# Salt length in bytes (passed to urandom when signing).
SALT_LEN = 40

In [140]:
#author: Dmytro Husan

def UniformBits(k):
    """Return a uniformly random non-negative integer built from k // 8 random bytes.

    k is the requested number of bits and is expected to be a multiple of 8
    (only whole bytes are drawn). The bytes come from os.urandom, a
    cryptographically secure generator, which the Gaussian sampler needs.
    A `uniform(0, 256)`-based draw is not cryptographic, and its end point
    256 can appear through float rounding, which would make bytes() raise.
    """
    # int() so that Sage Integers / Rationals behave like Python ints here.
    return int.from_bytes(urandom(int(k) // 8), 'big')

#author: Dmytro Husan
def BaseSampler():
    """Falcon base sampler: draw a non-negative integer z0 using the RCDT table.

    A uniform 72-bit integer u is drawn and z0 is the number of the first 18
    table thresholds that are strictly greater than u. The trailing 0 entry
    of the table is never read.
    """
    u = UniformBits(72)
    rcdt = [3024686241123004913666, 1564742784480091954050, 636254429462080897535, 199560484645026482916, 47667343854657281903, 8595902006365044063, 1163297957344668388, 117656387352093658, 8867391802663976, 496969357462633, 20680885154299, 638331848991, 14602316184, 247426747, 3104126, 28824, 198, 1, 0]
    count = 0
    for threshold in rcdt[:18]:
        if u < threshold:
            count += 1
    return count

#author: Dmytro Husan
def ApproxExp(x, ccs):
    """Fixed-point approximation of 2**63 * ccs * exp(-x) (Falcon ApproxExp).

    x   : real number; BerExp passes r in [0, LN2).
    ccs : scaling factor (SamplerZ uses sigmamin / sigma).
    Returns a non-negative integer approximately equal to 2**63 * ccs * exp(-x).
    """
    C = [0x00000004741183A3,0x00000036548CFC06,0x0000024FDCBF140A,0x0000171D939DE045,0x0000D00CF58F6F84,0x000680681CF796E3,0x002D82D8305B0FEA,0x011111110E066FD0,0x0555555555070F00,0x155555555581FF00,0x400000000002B400,0x7FFFFFFFFFFF4800,0x8000000000000000]
    y = C[0]
    z = int(floor(2**63 * x))
    # Horner evaluation in 63-bit fixed point: y <- C[i] - ((z*y) >> 63).
    # The parentheses matter: `-` binds tighter than `>>`, so without them
    # the shift would apply to the whole difference.
    for c in C[1:]:
        y = c - ((z * y) >> 63)
    z = int(floor(2**63 * ccs))
    y = (z * y) >> 63
    return y

#author: Dmytro Husan
def BerExp(x, ccs):
    """Return 1 with probability approximately ccs * exp(-x), else 0 (Falcon BerExp).

    x is split as x = s*LN2 + r with 0 <= r < LN2. ApproxExp(r, ccs) gives
    2**63 * ccs * exp(-r); shifting right by s (capped at 63) multiplies by
    2**-s. The resulting 64-bit value is compared, most significant byte
    first, with fresh uniform bytes; the first byte that differs decides.
    """
    s = int(floor(x / LN2))
    r = x - s * LN2
    s = min(s, 63)
    z = (2 * ApproxExp(r, ccs) - 1) >> s
    for i in range(56, -8, -8):
        w = UniformBits(8) - ((z >> i) & 0xFF)
        # Only an equal byte defers the decision to the next byte; any
        # difference settles the comparison.
        if w != 0:
            break
    return int(w < 0)

#author: Dmytro Husan
def SamplerZ(mu, sigma, sigmamin, sigmamax):
    """Sample an integer from a discrete Gaussian centred at mu with std. dev. sigma (Falcon SamplerZ).

    Each candidate is z = b + (2b - 1) * z0, with z0 from BaseSampler and b a
    random bit. It is accepted when BerExp(x, sigmamin / sigma) returns 1,
    where x compares the target Gaussian with the base one. The accepted
    z is returned shifted by floor(mu).
    """
    mu_floor = int(floor(mu))
    frac = mu - mu_floor
    ccs = sigmamin / sigma
    while True:
        z0 = BaseSampler()
        b = UniformBits(8) & 0x1
        z = b + (2*b - 1)*z0
        x = (z - frac)**2/2/sigma**2 - z0**2/2/sigmamax**2
        if BerExp(x, ccs) == 1:
            return z + mu_floor


In [141]:
#author: Kateryna Makovetska
class Node:
    """Binary tree node used for the ffLDL tree.

    ``value`` holds an L10 polynomial for internal nodes and a scalar for
    leaves. A child equal to ``0`` means "no child". ``tree`` caches
    ``[value, leftchild, rightchild]`` and must be refreshed with
    ``update_tree`` after those fields change.
    """

    def __init__(self, value):
        self.value = value
        self.leftchild = 0
        self.rightchild = 0
        self.tree = [self.value,self.leftchild,self.rightchild]

    def update_tree(self):
        """Re-sync the cached ``tree`` list with the current fields."""
        self.tree = [self.value,self.leftchild,self.rightchild]

    def __str__(self):
        return '[' + str(self.value) + ',' + str(self.leftchild) + ','  + str(self.rightchild) + ']'

    def __repr__(self):
        # Same textual form as __str__ (children render recursively).
        return self.__str__()

    def print_tree(self, pref=""):
        """Return a multi-line drawing of the subtree rooted at this node.

        ``pref`` is the indentation prefix for this level. A node is drawn as
        internal when ``(value * y^0).degree()`` is truthy and it has a left
        child; otherwise it is drawn as a leaf. Uses the module-level
        generator ``y``.
        """
        leaf = "|—————> "
        top = "|_______"
        son1 = "|       "
        son2 = "        "
        width = len(top)

        a = ""
        if (self.value * y^0).degree() and self.leftchild:
            if (pref == ""):
                a += pref + str(self.value) + "\n"
            else:
                a += pref[:-width] + top + str(self.value) + "\n"
            # A child may be the placeholder 0 instead of a Node: recurse only
            # into real nodes instead of swallowing every exception.
            if isinstance(self.leftchild, Node):
                a += self.leftchild.print_tree(pref + son1)
            if isinstance(self.rightchild, Node):
                a += self.rightchild.print_tree(pref + son2)
            return a
        else:
            return (pref[:-width] + leaf + str(self.value) + "\n")
        
def normalize_tree(tree, sigma):
    """Replace each leaf value v of the ffLDL tree by N(sigma / sqrt(v)), in place.

    A node is treated as internal (and recursed into) when
    ``(value * y^0).degree()`` is truthy and both children are truthy.
    Every other node is normalised as a leaf and its cached ``tree`` list
    refreshed. Internal node values are left untouched.
    """
    is_internal = (tree.value * y^0).degree() and tree.leftchild and tree.rightchild
    if not is_internal:
        tree.value = N(sigma / sqrt(tree.value))
        tree.update_tree()
        return
    for child in (tree.leftchild, tree.rightchild):
        normalize_tree(child, sigma)

In [142]:
# formula 3.22 from page 28, author: Evgen Postulga
#changed 
def Merge(a,b):
    """Inverse of Split: return a(x^2) + x * b(x^2) (formula 3.22).

    ``a`` and ``b`` are univariate polynomials; the result is built from the
    generator of a's parent ring.
    """
    iks = a.parent()([0, 1])  # the generator of a's ring
    # Substitute positionally. A keyword substitution such as
    # ``a.subs(iks=...)`` only acts when the keyword equals the ring's
    # variable name ('x' or 'y' here), so it left a and b unchanged.
    return a(iks**2) + iks * b(iks**2)

In [143]:
# algorithm 16 page 45, author: Maxim Pushkar
def Balance(f,q,n):
    """Centre the first n coefficients of f modulo q.

    Each coefficient c becomes ((c + q//2) mod q) - q//2, which lies in
    [-(q//2), q - 1 - q//2]. The result is built in f's parent ring.
    """
    half = q // 2
    centred = [(f[i] + half) % q - half for i in range(n)]
    return f.parent()(centred)

# formula 3.21 from page 28, author: Evgen Postulga
def Split(f,n):
    """Split f into (f0, f1) with f = f0(x^2) + x*f1(x^2) (formula 3.21).

    f0 takes the even-indexed and f1 the odd-indexed coefficients among the
    first 2*(n//2) coefficients of f; both are built in f's parent ring.
    """
    ring = f.parent()
    stop = 2 * (n // 2)
    evens = [f[k] for k in range(0, stop, 2)]
    odds = [f[k] for k in range(1, stop, 2)]
    return ring(evens), ring(odds)



In [144]:
#author: Evgen Postulga ???
def Omega(n):
    """Return the n roots of x^n + 1, ordered for the recursive FFT.

    Starts from the roots of x^2 + 1 (cyclotomic_polynomial(4)) as given by
    complex_roots(), reversed. It then repeatedly replaces each root w by the
    pair (sqrt(w), -sqrt(w)) until n roots are listed. n must be a power of
    two >= 2; otherwise the loop never terminates.
    """
    roots = cyclotomic_polynomial(4).complex_roots()
    roots.reverse()
    count = 2
    while count != n:
        next_roots = []
        for w in roots:
            next_roots.append(sqrt(w))
            next_roots.append(-sqrt(w))
        roots = next_roots
        count = 2 * count
    return roots

# formula 3.18, page 27, author: Evgen Postulga
def FFT(f, n):
    """Evaluate polynomial f at every root listed by Omega(n) (formula 3.18)."""
    roots = Omega(n)
    return list(map(f.subs, roots))

# algorithm 1, page 29, author: Evgen Postulga
def SplitFFT(f):
    """Split an FFT vector into the FFT vectors of f0, f1 where f = f0(x^2) + x*f1(x^2) (algorithm 1).

    Entries 2i and 2i+1 are combined using w = Omega(n)[2i]:
    f0[i] = (f[2i] + f[2i+1]) / 2 and f1[i] = (f[2i] - f[2i+1]) / (2w).
    """
    n = len(f)
    w = Omega(n)
    half = n // 2
    f0 = [0.5 * (f[2*i] + f[2*i + 1]) for i in range(half)]
    f1 = [0.5 * (f[2*i] - f[2*i + 1]) / w[2*i] for i in range(half)]
    return f0, f1

# algorithm 2, page 29, author: Evgen Postulga
def MergeFFT(f0, f1):
    """Inverse of SplitFFT: rebuild the FFT vector of f = f0(x^2) + x*f1(x^2) (algorithm 2).

    For each i, with w = Omega(2*len(f0))[2i], outputs f0[i] + w*f1[i]
    followed by f0[i] - w*f1[i].
    """
    w = Omega(2 * len(f0))
    merged = []
    for i, even in enumerate(f0):
        twisted = w[2*i] * f1[i]
        merged.extend([even + twisted, even - twisted])
    return merged

# https://github.com/tprest/falcon.py/blob/master/fft.py
def invFFT(f_fft):
    """Inverse FFT: recover a polynomial's coefficient list from its FFT vector.

    Recursively splits with SplitFFT, inverts both halves, and interleaves
    them: the even half gives f[2i] and the odd half gives f[2i+1]. For
    length <= 2, the two coefficients are the real and imaginary parts of
    f_fft[0]. Adapted from falcon.py's fft.py.
    """
    n = len(f_fft)
    if n <= 2:
        value = f_fft[0]
        return [value.real(), value.imag()]

    even_fft, odd_fft = SplitFFT(f_fft)
    even = invFFT(even_fft)
    odd = invFFT(odd_fft)

    coeffs = [0] * n
    for i in range(n // 2):
        coeffs[2*i] = even[i]
        coeffs[2*i + 1] = odd[i]
    return coeffs


In [145]:
#these are from https://github.com/tprest/falcon.py/blob/master/fft.py
def add_fft(f, g):
    """Pointwise sum of two FFT vectors of equal length (asserted)."""
    assert len(f) == len(g)
    return [value + g[idx] for idx, value in enumerate(f)]

def neg_fft(f):
    """Pointwise negation of an FFT vector."""
    return [-value for value in f]

def sub_fft(f, g):
    """Pointwise difference f - g of two FFT vectors of equal length (asserted).

    Computed as f[i] + (-g[i]), i.e. adding the negation.
    """
    assert len(f) == len(g)
    return [value + (-g[idx]) for idx, value in enumerate(f)]

def mul_fft(f_fft, g_fft):
    """Pointwise product of two FFT vectors; the length is taken from f_fft."""
    return [value * g_fft[idx] for idx, value in enumerate(f_fft)]

def div_fft(f_fft, g_fft):
    """Pointwise quotient f_fft / g_fft of two FFT vectors of equal length (asserted)."""
    assert len(f_fft) == len(g_fft)
    return [value / g_fft[idx] for idx, value in enumerate(f_fft)]


In [146]:
# author: Evgen Postulga
def CyclicRotate(input, n):
    """Rotate a list right by n positions: the last n items move to the front.

    For n = 0 a copy of the list is returned.
    """
    head = input[0:-n]
    tail = input[-n:]
    return tail + head

In [147]:
# author: Evgen Postulga
def HermitianAdjointPoly(p, n):
    """Adjoint of p in Z[x]/(x^n + 1): p*(x) = p_0 - sum_{i=1}^{n-1} p_{n-i} x^i.

    The result is built in p's parent ring.
    """
    coeffs = [p[0]] + [-p[n - i] for i in range(1, n)]
    return p.parent()(coeffs)

In [148]:
# author: Evgen Postulga
def LatticeToPoly4(M, n_2):
    """Read the four n_2-length blocks from rows 0 and n_2 of a block matrix.

    Returns (M[0][:n_2], M[0][n_2:], M[n_2][:n_2], M[n_2][n_2:]).
    """
    top_row, bottom_row = M[0], M[n_2]
    return top_row[:n_2], top_row[n_2:], bottom_row[:n_2], bottom_row[n_2:]

In [149]:
# author: Evgen Postulga
def PolyToLattice4(p00, p01, p10, p11, n):
    """Expand a 2x2 matrix of polynomials in Z[x]/(x^n + 1) into a 2n x 2n matrix.

    Row i of the block for a polynomial p holds the coefficients of x^i * p
    reduced modulo x^n + 1. This is a negacyclic rotation: coefficients that
    wrap past x^(n-1) change sign, because x^n = -1.

    Returns a list of 2n rows of length 2n. Rows 0..n-1 are
    [x^i*p00 | x^i*p01] and rows n..2n-1 are [x^i*p10 | x^i*p11].

    Raises ValueError if any polynomial has more than n coefficients.
    """
    def padded(poly):
        # Coefficient list padded with zeros to exactly n entries.
        coeffs = list(poly.coefficients(sparse=False))
        if len(coeffs) > n:
            raise ValueError("polynomial degree must be < n")
        return coeffs + [0] * (n - len(coeffs))

    def negacyclic_row(coeffs, i):
        # x^i * p: shift right by i; the i coefficients that wrap are negated.
        return [-c for c in coeffs[n - i:]] + coeffs[:n - i]

    p = [padded(poly) for poly in (p00, p01, p10, p11)]
    M = []
    for left, right in ((p[0], p[1]), (p[2], p[3])):
        for i in range(n):
            M.append(negacyclic_row(left, i) + negacyclic_row(right, i))
    return M

In [150]:
# author: Evgen Postulga
def InnerProduct(a,b,n):
    """Hermitian inner product: sum of a[i] * conj(b[i]) over the first n entries."""
    return sum(a[k] * b[k].conjugate() for k in range(n))

# formula 3.9 from page 24, author: Evgen Postulga
def EuclideanNorm(a,n):
    """Euclidean norm sqrt(<a, a>) over the first n entries of a (formula 3.9)."""
    squared = InnerProduct(a, a, n)
    return sqrt(squared)

# formula 3.25 from page 30, author: Maxim Pushkar
def FieldNorm(f, n):
    """Field norm N(f) = f0^2 - x*f1^2 mod (x^(n/2) + 1), where f = f0(x^2) + x*f1(x^2) (formula 3.25)."""
    even, odd = Split(f,n)
    gen = f.parent()([0, 1])
    modulus = gen**(n//2) + 1
    return (even**2 - gen * odd**2) % modulus


#author: Pavlo Pinchuk 
def NTT(f, n, q):
    """Evaluate f at every root of x^n + 1 in Z/qZ and reduce each value mod q.

    Values are listed in the order in which Sage's roots() returns the roots.
    """
    modulus_roots = (x^n + 1).roots(Integers(q))
    return [f.subs(x=root) % q for root, _multiplicity in modulus_roots]

In [151]:
#author: Kateryna Makovetska
def LDL(g00, g01, g10, g11, n):
    """LDL* decomposition of the 2x2 matrix [[g00, g01], [g10, g11]] of polynomials.

    Computes in the FFT domain:
    L10 = g10 / g00 and D11 = g11 - L10 * FFT(adj(L10)) * g00.
    Returns L = [1, 0, L10, 1] and D = [g00, 0, 0, D11]; L10 and D11 are
    converted back to Qx polynomials. g01 is not read.
    """
    one, zero = 1*y^0, 0*y^0
    g00_fft = FFT(g00,n)
    l10_fft = div_fft(FFT(g10,n), g00_fft)
    l10 = Qx(invFFT(l10_fft))
    l10_adj_fft = FFT(HermitianAdjointPoly(l10, n), n)
    d11_fft = sub_fft(FFT(g11,n), mul_fft(mul_fft(l10_fft, l10_adj_fft), g00_fft))
    return [one, zero, l10, one], [g00, zero, zero, Qx(invFFT(d11_fft))]

In [152]:
#author: Kateryna Makovetska
def ffLDL(g00, g01, g10, g11, n):
    """Build the ffLDL tree of the 2x2 Gram matrix [[g00, g01], [g10, g11]].

    Inputs are anything Qx() accepts (coefficient lists or polynomials).
    Each node stores L10 from LDL as a Qx polynomial. For n == 2 the
    children are leaves holding |Re(FFT(D00)[0])| and |Re(FFT(D11)[0])|.
    Otherwise D00 and D11 are split in the FFT domain, and each half-size
    matrix [[d0, d1], [adj(d1), d0]] is decomposed recursively.
    """
    L,D = LDL(Qx(g00), Qx(g01), Qx(g10), Qx(g11), n)
    L10 = Qx(L[2])
    tree = Node(L10)
    D00 = Qx(D[0])
    D11 = Qx(D[3])
    if n == 2:
        # Keep the full real value: truncating with int() would discard the
        # fractional part that normalize_tree turns into a sampler sigma.
        tree.leftchild = Node(abs(FFT(D00,n)[0].real()))
        tree.rightchild = Node(abs(FFT(D11,n)[0].real()))
        tree.update_tree()
        return tree
    d00,d01 = SplitFFT(FFT(D00,n))
    d10,d11 = SplitFFT(FFT(D11,n))
    d01_star = HermitianAdjointPoly(Qx(invFFT(d01)),n//2)
    d11_star = HermitianAdjointPoly(Qx(invFFT(d11)),n//2)
    tree.leftchild = ffLDL(invFFT(d00),invFFT(d01),d01_star,invFFT(d00),n//2)
    tree.rightchild = ffLDL(invFFT(d10),invFFT(d11),d11_star,invFFT(d10),n//2)
    # Keep the cached [value, left, right] list in sync with the children.
    tree.update_tree()
    return tree
           

In [153]:
#author: Karina Ilchenko
def Reduce(f, g, F, G, n):
    """Size-reduce (F, G) against (f, g) in Z[x]/(x^n + 1).

    Repeatedly computes k = round((F f* + G g*) / (f f* + g g*)) in
    Q[x]/(x^n + 1) and subtracts (k f, k g) from (F, G), stopping once k is
    zero. Returns (f, g, F, G); f and g are returned unchanged.
    """
    TT = Zx.change_ring(QQ).quotient(x^n+1)

    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)
    # The denominator depends only on f and g, which never change in the
    # loop, so the (expensive) quotient-ring inversion is done once.
    den = 1 / TT(f*f_star + g*g_star)
    while True:
        num = TT(F*f_star + G*g_star)
        res = num * den
        k = Zx([int(round(elt)) for elt in res])
        F = F - k*f
        G = G - k*g
        if all(elt == 0 for elt in k):
            break
    return f, g, F, G



In [154]:
#author: Karina Ilchenko and Pavlo Pinchuk 
def NTRUSolve(f, g, n, q):
    """Solve f*G - g*F = q in Z[x]/(x^n + 1) (Falcon NTRUSolve).

    Recurses on the field norms of f and g down to n = 1, where an extended
    GCD of the constant terms gives the base solution. Each level lifts the
    half-size solution as F = F'(x^2) g(-x), G = G'(x^2) f(-x) and
    size-reduces it with Reduce.

    Returns (F, G, True) on success, or (None, None, False) when the
    base-case gcd is not 1.
    """
    if n == 1:
        d, u, v = xgcd(f[0], g[0])
        if d != 1:
            return None, None, False
        return -v*q, u*q, True

    # Half-size problem on N(f), N(g) in Z[x]/(x^(n/2) + 1).
    F_half, G_half, ok = NTRUSolve(FieldNorm(f, n), FieldNorm(g, n), n//2, q)
    if not ok:
        return F_half, G_half, ok

    F = F_half.subs(x=x^2) * g.subs(x=-x)
    G = G_half.subs(x=x^2) * f.subs(x=-x)
    _, _, F, G = Reduce(f, g, F, G, n)
    modulus = x^n + 1
    return F % modulus, G % modulus, ok

In [155]:
#author: Pavlo Pinchuk
def HashToPoint(salt, message, q, n):
    """Hash (salt, message) to a polynomial in Zx with n coefficients in [0, q).

    SHAKE256(salt || message) is read two bytes at a time as big-endian
    16-bit values. Values >= k*q with k = floor(2^16 / q) are rejected, so
    reducing the accepted values mod q introduces no bias. The result is
    reduced modulo the module-level phi.
    """
    k = int((2**16) // q)
    limit = k * q

    shake = SHAKE256.new()
    shake.update(salt)
    shake.update(message)

    coeffs = []
    while len(coeffs) < n:
        chunk = shake.read(2)
        value = (chunk[0] << 8) + chunk[1]
        if value < limit:
            coeffs.append(value % q)
    return Zx(coeffs) % phi

In [156]:
#author: Karina Ilchenko and Pavlo Pinchuk
def NTRUGen(q, n):
    """Generate NTRU polynomials (f, g, F, G) with f*G - g*F = q mod (x^n + 1).

    Each coefficient of f and g is a sum of SamplerZ outputs, drawn over
    range(1, 4096/n + 1). A draw is rejected if gs_norm exceeds
    (1.17 ** 2) * q, or if f has a zero NTT value mod q (f would not be
    invertible mod q). (F, G) then come from NTRUSolve; if it reports
    failure, everything is drawn again.

    NOTE(review): gs_norm passes 2n-long coordinate lists to EuclideanNorm
    with length n, so only the first n coordinates are measured. It also
    compares a norm (not a squared norm) against 1.17**2 * q. The Falcon
    reference compares the squared norm over all 2n coordinates, so this
    check is far more lenient. Tightening it only makes sense with a
    correctly calibrated SamplerZ - confirm before changing.
    """

    def gen_poly(n, q):
        # Per-sample standard deviation and sampler bounds; they do not
        # depend on the loop, so they are computed once.
        sigma_fg = 1.17 * sqrt(q / 8192)
        lo, hi = 1.277833697, 1.8205

        def D(mu=0):
            acc = 0
            for _ in range(1, 4096/n + 1):
                acc += SamplerZ(mu, sigma_fg, lo, hi)
            return acc

        f_coeffs, g_coeffs = [], []
        for _ in range(n):
            # Draws are interleaved: f[i] then g[i] for each i.
            f_coeffs.append(D())
            g_coeffs.append(D())
        return Zx(f_coeffs) % phi, Zx(g_coeffs) % phi

    def gs_norm(f, g, q, n):
        TT = Zx.change_ring(QQ).quotient(x^n+1)
        f_star = HermitianAdjointPoly(f, n)
        g_star = HermitianAdjointPoly(g, n)
        first = EuclideanNorm([*g.coefficients(sparse=False), *(-f).coefficients(sparse=False)], n)
        ffgg = TT((f*f_star + g*g_star))
        s1 = (q * TT(f_star)) / ffgg
        s2 = (q * TT(g_star)) / ffgg
        second = EuclideanNorm(list(s1) + list(s2), n)
        return max(first, second)

    while True:
        f, g = gen_poly(n, q)
        if gs_norm(f, g, q, n) > (1.17 ** 2) * q:
            continue
        if 0 in NTT(f, n, q):
            continue
        F, G, flag = NTRUSolve(f, g, n, q)
        if not flag:
            continue
        F, G = F % (x^n +1), G % (x^n +1)
        F = Zx([int(coef) for coef in F.coefficients(sparse=False)])
        G = Zx([int(coef) for coef in G.coefficients(sparse=False)])
        return f, g, F, G

In [157]:
#author Karina Ilchenko
def NaiveKeyGen(sigma, q, n):
    """Generate a key pair for the naive (rounding-based) signature scheme.

    Returns (sk, h):
      sk = [[g, -f], [G, -F]], the NTRU basis from NTRUGen;
      h  = g * f^-1 mod (phi, q), the public key.
    ``sigma`` is accepted for interface compatibility but is not used.
    """
    f, g, F, G = NTRUGen(q, n)
    sk = [[g, -f], [G, -F]]

    # f^-1 is computed in Z_q[x]/(x^n + 1) and lifted back to integer coefficients.
    TT = Zx.change_ring(Integers(q)).quotient(x^n+1)
    f_q = Zx(lift(1 / TT(f)))
    h = g * f_q % phi % q
    return sk, h

In [158]:
#author Karina Ilchenko
def NaiveSign(message, B, q, n):
    """Sign ``message`` by Babai round-off with the secret basis B (no Gaussian sampling).

    B = [[g, -f], [G, -F]] with f*G - g*F = q, so
    B^-1 = (1/q) * [[-F, f], [-G, g]]. A fresh salt of SALT_LEN random bytes
    is hashed with the message to c in Z_q[x]/(phi). The target
    t = (c, 0) * B^-1 is rounded coefficient-wise to z, and
    s = (c, 0) - z*B, which satisfies s1 + s2*h = c mod (phi, q).

    Returns (salt, s2, beta), where beta is the value NaiveVerify recomputes
    and compares against.
    """
    noise = urandom(SALT_LEN)

    # Hash value c in Z_q[x]/(phi).
    c = HashToPoint(noise, message, q, n)

    def coeffs(p):
        # Exactly n coefficients. Indexing past the degree yields 0, so
        # polynomials whose top coefficients vanish stay aligned.
        return [p[i] for i in range(n)]

    B_inv = [[B[1][1]/q, -B[0][1]/q], [-B[1][0]/q, B[0][0]/q]]

    # t = (c, 0) * B_inv = (c * B_inv[0][0], c * B_inv[0][1]).
    t0 = c * B_inv[0][0] % phi
    t1 = c * B_inv[0][1] % phi
    z1 = Zx([int(round(float(el))) for el in coeffs(t0)])
    z2 = Zx([int(round(float(el))) for el in coeffs(t1)])

    # z * B with B = [[g, -f], [G, -F]].
    zb1 = (z1 * B[0][0] + z2 * B[1][0]) % phi
    zb2 = (z1 * B[0][1] + z2 * B[1][1]) % phi

    # s = (c, 0) - z*B
    s1 = Zx([a - b for a, b in zip(coeffs(c), coeffs(zb1))]) % q
    s2 = Zx([-b for b in coeffs(zb2)]) % q

    beta = EuclideanNorm(Balance(s1, q, n).coefficients(sparse=False) + Balance(s2, q, n).coefficients(sparse=False), n)
    return noise, s2, beta

In [159]:
#author Karina Ilchenko
def NaiveVerify(message, noise, s2, h, q, n, beta):
    """Check a NaiveSign signature; returns True or False.

    Recomputes c from (noise, message) and s1 = (c - s2*h) mod phi. It then
    accepts iff EuclideanNorm of the balanced coefficients of (s1, s2) (first
    n entries) equals ``beta``.

    NOTE(review): ``beta`` comes from the signer, so this checks the
    signature's own claim rather than a fixed public bound - confirm before
    relying on it beyond a demo.
    """
    c = HashToPoint(noise, message, q, n)
    s1 = (c - s2 * h) % phi
    coords = Balance(s1, q, n).coefficients(sparse=False) + Balance(s2, q, n).coefficients(sparse=False)
    return bool(EuclideanNorm(coords, n) == beta)

In [160]:
# Naive scheme demo: generate keys, sign b"message", then verify the genuine
# message and a different one with the same signature.
sk, h = NaiveKeyGen(sigma_star, q, n)

message = b"message"
noise, signature, beta_naive  = NaiveSign(message, sk, q, n)

print(NaiveVerify(message, noise, signature, h, q, n, beta_naive))
print(NaiveVerify(b"qqqqqqqqqq4444qq", noise, signature, h, q, n, beta_naive))


True
False


In [161]:
#author: Kateryna Makovetska
def ffSampling(t0_, t1_, T, sigmamin, sigmamax, q, n):
    """Fast Fourier sampling: sample (z0, z1) close to the target (t0_, t1_).

    t0_, t1_ : target vectors in FFT representation (length n).
    T        : ffLDL tree (Node). Internal nodes hold L10 as a Qx polynomial
               in coefficient form; leaves hold the normalised sigma.
    Returns (z0, z1) in FFT representation. ``q`` is not used.
    """
    # Naming: a trailing _ marks a full-length FFT vector at this level;
    # t0__ is the updated target t0' (double underscore = prime).
    if n == 1:
        sigma = T.value
        z0_ = SamplerZ(t0_[0].real(),sigma ,sigmamin, sigmamax)
        z1_ = SamplerZ(t1_[0].real(),sigma ,sigmamin, sigmamax)
        return [z0_],[z1_]

    # ffLDL stores L10 in coefficient form, but the update below is a
    # pointwise product in the FFT domain, so transform it first.
    l = FFT(T.value, n)
    T0 = T.leftchild
    T1 = T.rightchild
    t1 = SplitFFT(t1_)

    z1 = ffSampling(*t1, T1, sigmamin, sigmamax, q, n//2)
    z1_ = MergeFFT(*z1)
    # t0' = t0 + (t1 - z1) * L10
    t0__ = add_fft(t0_, mul_fft(sub_fft(t1_,z1_),l))
    t0 = SplitFFT(t0__)
    z0 = ffSampling(*t0, T0, sigmamin, sigmamax, q, n//2)
    z0_ = MergeFFT(*z0)
    return z0_,z1_

In [162]:
#author: Dmytro Husan
def KeyGen(phi, n):
    """Falcon-style key generation.

    Returns (sk, pk):
      sk = (B_, T): B_ is the FFT of the basis [[g, -f], [G, -F]] flattened
           as (g, -f, G, -F); T is the ffLDL tree of the Gram matrix B B*,
           normalised with sigma_for_tree.
      pk = g * f^-1 mod (phi, q).
    Uses the module-level q.
    """
    f, g, F, G = NTRUGen(q, n)
    B_ = (FFT(g, n), FFT(-f, n), FFT(G, n), FFT(-F, n))
    # Adjoints, in the order (g*, G*, (-f)*, (-F)*).
    B_star = (FFT(HermitianAdjointPoly(g,n),n), FFT(HermitianAdjointPoly(G, n),n), FFT(HermitianAdjointPoly(-f,n),n), FFT(HermitianAdjointPoly(-F,n),n))
    # Gram matrix Gr[i][j] = sum_k B[i][k] * adj(B[j][k]).
    gram00 = add_fft(mul_fft(B_[0],B_star[0]),  mul_fft(B_[1],B_star[2]))
    gram01 = add_fft(mul_fft(B_[0],B_star[1]),  mul_fft(B_[1],B_star[3]))
    gram10 = add_fft(mul_fft(B_[2],B_star[0]),  mul_fft(B_[3],B_star[2]))
    gram11 = add_fft(mul_fft(B_[2],B_star[1]),  mul_fft(B_[3],B_star[3]))

    def to_coeffs(p_fft):
        return Qx(invFFT(p_fft)).coefficients(sparse = False)

    # ffLDL computes L10 = g10 / g00, so its third argument must be the
    # lower-left entry Gr[1][0] = G g* + F f*, not its adjoint Gr[0][1].
    T = ffLDL(to_coeffs(gram00), to_coeffs(gram01), to_coeffs(gram10), to_coeffs(gram11), n)
    normalize_tree(T, sigma_for_tree)
    sk = (B_, T)
    TT = Zx.change_ring(Integers(q)).quotient(x^n+1)
    f_q = Zx(lift(1 / TT(f)))
    pk = g*f_q % phi % q
    return sk, pk




In [163]:
# Falcon key generation: sk = (B_, T) with the FFT basis and ffLDL tree; pk = h.
sk,pk = KeyGen(phi, n)

In [164]:
# Unpack the secret key: B is the FFT basis tuple and t the ffLDL tree.
m,t = sk
B = m

In [165]:
#written by Maxim Pushkar, Dmytro Husan, Karina Ilchenko and Pavlo Pinchuk
def Sign(message,B_,T,sigmamin,sigmamax,q,n,h):
    """Sign ``message`` with the Falcon-style secret key (B_, T).

    B_ is (FFT(g), FFT(-f), FFT(G), FFT(-F)); T is the normalised ffLDL tree.
    A fresh salt r of SALT_LEN random bytes is hashed with the message to c.
    The target t = (c, 0) * B^-1 is sampled to z with ffSampling, and
    s = (c, 0) - z*B is computed in the FFT domain.

    Returns (r, s2, bound):
      r     - the salt;
      s2    - the second half of s, rounded to an integer polynomial;
      bound - ||s||^2 / (2n), computed on the FFT-domain vector s.
    ``h`` is accepted but not used.
    """
    # A fresh random salt per signature (as in NaiveSign). A fixed salt
    # would give every signature of a given message the same target.
    r = urandom(SALT_LEN)
    g1 = B_[0]
    f1 = B_[1]
    G1 = B_[2]
    F1 = B_[3]
    c = HashToPoint(r, message, q, n)
    c_fft = FFT(c,n)
    f = div_fft(neg_fft(B_[1]),[q]*len(B_[1]))   # FFT(f) / q
    F = div_fft(B_[3],[q]*len(B_[1]))            # FFT(-F) / q, i.e. F = -F here

    # t = (c, 0) * B^-1 = (c*(-F)/q, c*f/q)
    t0 = mul_fft(c_fft,F)
    t1 = mul_fft(c_fft,f)

    z1,z2 = ffSampling(t0, t1, T, sigmamin, sigmamax, q, n)
    # s = (c, 0) - z*B, both halves concatenated.
    zb_first = add_fft(mul_fft(z1,g1),mul_fft(z2,G1))
    zb_second = add_fft(mul_fft(z1,f1),mul_fft(z2,F1))
    s = sub_fft(c_fft + [0]*n, zb_first + zb_second)

    s2 = Zx(list(round(j) for j in invFFT(s[n:])))
    # TODO: check s1 + s2*h == c (mod phi, mod q).
    bound = EuclideanNorm(s,2*n)^2/(2*n)
    return r,s2,bound

In [166]:

# Sign b'hello': r is the salt, s the integer polynomial s2, bound the signer's norm value.
r,s,bound = Sign(b'hello',B,t,sigmamin,sigmamax,q,n,pk)


In [167]:
#written by Maxim Pushkar, Dmytro Husan, Karina Ilchenko and Pavlo Pinchuk
def Verify(message, r, s, pk, bound, q, n):
    """Check a signature produced by Sign; prints the verdict and returns True/False.

    Recomputes c from (r, message) and s1 = Balance((c - s2*h) mod (x^n + 1, q)).
    It accepts iff ||(FFT(s1), FFT(s2))||^2 / (2n) equals ``bound``, both
    rounded to one decimal.

    NOTE(review): ``bound`` is supplied by the signer, so this compares
    against the signature's own claim instead of a fixed public bound
    (Falcon rejects when ||(s1, s2)||^2 exceeds a public beta^2). Confirm
    before using this beyond a demo.
    """
    c = HashToPoint(r, message, q, n)
    h = pk
    s2 = s  # s2 in Z[x]/(x^n + 1)
    s1 = Balance((c - s2*h) % (x^n + 1) % q,q,n)
    observed = round((EuclideanNorm(FFT(s1,n) + FFT(s2,n),2*n))^2 / (2*n),1)
    expected = round(bound,1)
    if observed == expected:
        print("signature accepted")
        return True

    print("signature rejected")
    return False

In [168]:
# Verify the message that was signed.
print(Verify(b'hello',r,s,pk,bound,q,n))

signature accepted
True


In [169]:
# Verify a different message against the same signature.
print(Verify(b'heello',r,s,pk,bound,q,n))

signature rejected
False


In [170]:
# Sign a second message with the same key.
r1,s1,bound1 = Sign(b'Karina',B,t,sigmamin,sigmamax,q,n,pk)

In [171]:
# Verify the second signature.
print(Verify(b'Karina',r1,s1,pk,bound1,q,n))

signature accepted
True
