In [1]:
from Crypto.Hash import SHAKE256
from os import urandom
from random import shuffle
from math import log2
from sage.modules.misc import gram_schmidt

# Sage preparser syntax: Zx is the polynomial ring Z[x] with generator x
# (`^` below is exponentiation under the Sage preparser).
Zx.<x> = ZZ[]
k = 5
n = 2^k          # ring degree n = 2^k = 32 (a power of two)
q = 12289        # modulus; 12289 = 192 * 64 + 1, so q % (2n) == 1
phi = x^n + 1    # reduction polynomial: arithmetic is in Z[x]/(x^n + 1)
SALT_LEN = 40    # length in bytes of the random salt (`noise`) given to HashToPoint

In [2]:
def Balance(f, q, n):
    """
    Center the first n coefficients of f modulo q.

    Each coefficient c becomes ((c + q//2) mod q) - q//2, which lies in
    [-(q//2), q - 1 - q//2]. The result is rebuilt in f's parent ring.
    """
    half = q // 2
    centered = [(f[idx] + half) % q - half for idx in range(n)]
    return f.parent()(centered)

def Split(f, n):
    """
    Split f(x) = f0(x^2) + x * f1(x^2) into (f0, f1), using f's first n coefficients.

    Returns two elements of f's parent ring with n // 2 coefficients each.
    The first holds the even-index coefficients of f and the second the odd-index ones.
    """
    half = n // 2
    ring = f.parent()
    evens = [f[2 * j] for j in range(half)]
    odds = [f[2 * j + 1] for j in range(half)]
    return ring(evens), ring(odds)

def InnerProduct(a, b, n):
    """Hermitian inner product: sum of a[i] * conj(b[i]) over the first n entries."""
    return sum(a[idx] * b[idx].conjugate() for idx in range(n))

def EuclideanNorm(a, n):
    """
    Euclidean (l2) norm of the first n entries of a.

    The squared norm comes from InnerProduct and is converted to float
    before the square root is taken.
    """
    squared = float(InnerProduct(a, a, n))
    return sqrt(squared)

def FieldNorm(f, n):
    """
    Field norm of f from Z[x]/(x^n + 1) down to Z[x]/(x^(n/2) + 1).

    With f(x) = f0(x^2) + x * f1(x^2) (see Split), it returns
    N(f) = f0^2 - x * f1^2 reduced mod x^(n/2) + 1.
    """
    even, odd = Split(f, n)
    var = f.parent()([0, 1])          # the generator x of f's ring
    modulus = var**(n // 2) + 1
    return (even**2 - var * odd**2) % modulus

def HermitianAdjointPoly(p, n):
    """
    Hermitian adjoint p* of p in Z[x]/(x^n + 1), i.e. p(1/x) reduced mod x^n + 1.

    Since 1/x = -x^(n-1) there, the coefficients are p*[0] = p[0] and
    p*[i] = -p[n - i] for 1 <= i < n. The result is built in p's parent ring.
    """
    adjoint = [p[0]] + [-p[n - idx] for idx in range(1, n)]
    return p.parent()(adjoint)

def CyclicRotate(input, n):
    """
    Multiply the coefficient list `input` by x^n in Z[x]/(x^L + 1), L = len(input).

    Intended for 0 <= n < L. The list is rotated right by n positions and the n
    entries that wrapped around are negated, because x^L = -1. Returns a new
    list; `input` is not modified.
    """
    rotated = input[-n:] + input[:-n]
    for idx in range(n):
        rotated[idx] = -rotated[idx]
    return rotated

def PolyToLattice4(p00, p01, p10, p11, n):
    """
    Expand the 2x2 polynomial matrix [[p00, p01], [p10, p11]] over Z[x]/(x^n + 1)
    into the rows of a 2n x 2n integer matrix.

    Row i (0 <= i < n) is (x^i * p00 | x^i * p01). Row n + i is
    (x^i * p10 | x^i * p11). Each product is written as its n coefficients
    mod x^n + 1, computed by CyclicRotate. Returns a list of 2n Sage vectors.

    Raises ValueError if a polynomial has more than n coefficients, i.e. it is
    not reduced mod x^n + 1. Such input cannot be zero-padded to length n.
    """
    def coeffs(poly):
        # Dense coefficient list, zero-padded to exactly n entries.
        dense = list(poly.coefficients(sparse=False))
        if len(dense) > n:
            raise ValueError(
                "PolyToLattice4: polynomial has degree >= n; reduce it mod x^n + 1 first")
        return dense + [0] * (n - len(dense))

    def rows(left, right):
        # The n rows (x^i * left | x^i * right) for i = 0 .. n-1.
        a, b = coeffs(left), coeffs(right)
        return [vector(CyclicRotate(a, i) + CyclicRotate(b, i)) for i in range(n)]

    return rows(p00, p01) + rows(p10, p11)

def NTT(f, n, q):
    """
    Evaluate f at every root of x^n + 1 in Z/qZ.

    Returns one value per root reported by Sage's `roots(Integers(q))`,
    reduced mod q. NTRUGen rejects f when any of these values is 0.
    Note: the roots are recomputed on every call.
    """
    evaluations = []
    for root, _multiplicity in (x**n + 1).roots(Integers(q)):
        evaluations.append(f.subs(x=root) % q)
    return evaluations


In [3]:
def UniformBits(k):
    """
    Return a uniformly random integer in [0, 2^k).

    The bytes come from os.urandom, the operating-system CSPRNG. This matters
    because the value feeds key generation, Gaussian sampling and encryption
    randomness. k does not have to be a multiple of 8: ceil(k/8) bytes are read
    and the surplus low-order bits are shifted out.
    """
    k = int(k)
    nbytes = (k + 7) // 8
    value = int.from_bytes(urandom(nbytes), 'big')
    return value >> (8 * nbytes - k)

def BaseSampler():
    """
    Draw z_0 >= 0 from the base sampler's reverse cumulative distribution table (RCDT).

    A uniform 72-bit value u is drawn. z_0 is the number of the first 18 table
    entries that are strictly greater than u. (This presumably mirrors Falcon's
    BaseSampler for sigma_max = 1.8205 -- confirm against the spec.)
    """
    RCDT = [3024686241123004913666, 1564742784480091954050, 636254429462080897535, 199560484645026482916, 47667343854657281903, 8595902006365044063, 1163297957344668388, 117656387352093658, 8867391802663976, 496969357462633, 20680885154299, 638331848991, 14602316184, 247426747, 3104126, 28824, 198, 1, 0]
    u = UniformBits(72)
    # Only the first 18 entries are compared; the trailing 0 is never read.
    return sum((int(u < threshold) for threshold in RCDT[:18]), 0)

def ApproxExp(x, ccs):
    """
    Fixed-point approximation of 2^63 * ccs * exp(-x) (Falcon's ApproxExp).

    x   : real, expected in [0, ln 2) (BerExp passes its reduced argument r)
    ccs : positive scaling factor
    Returns a non-negative integer.

    The polynomial C is evaluated by Horner's rule in 63-bit fixed point:
    with z = floor(2^63 * x), each step computes y <- C[i] - (z * y) / 2^63,
    i.e. y <- C[i] - x * y, ending at y ~ 2^63 * exp(-x).
    """
    C = [0x00000004741183A3,0x00000036548CFC06,0x0000024FDCBF140A,0x0000171D939DE045,0x0000D00CF58F6F84,0x000680681CF796E3,0x002D82D8305B0FEA,0x011111110E066FD0,0x0555555555070F00,0x155555555581FF00,0x400000000002B400,0x7FFFFFFFFFFF4800,0x8000000000000000]
    y = C[0]
    z = floor(2**63 * x)
    for c in C[1:]:
        # The shift must be parenthesised: `c - (z*y) >> 63` parses as
        # `(c - z*y) >> 63` because `-` binds tighter than `>>`.
        y = c - ((z * y) >> 63)
    z = floor(2**63 * ccs)
    return (z * y) >> 63

def BerExp(x, ccs):
    """
    Bernoulli trial with success probability ~ ccs * exp(-x) (Falcon's BerExp).

    x is written as s * ln2 + r, so exp(-x) = 2^-s * exp(-r). The threshold is
    z = (2 * ApproxExp(r, ccs) - 1) >> s. When ApproxExp ~ 2^63 * ccs * exp(-r),
    this is ~ 2^64 * ccs * exp(-x). z is compared lazily, one byte at a time
    from the most significant, with a uniform 64-bit value. The first byte that
    differs decides the outcome.

    Returns 1 (accept) or 0 (reject).
    NOTE(review): assumes x >= 0. A negative x makes s negative, and the shift
    below is then not meaningful.
    """
    LN2 = 0.69314718056
    s = floor(x/LN2)
    r = x - s*LN2
    s = min(s, 63)
    z = (2*ApproxExp(r, ccs) - 1) >> s
    for i in range(56, -8, -8):
        p = UniformBits(8)
        w = p - ((z >> i) & 0xFF)
        # Equal bytes leave the comparison undecided, so continue to the next
        # byte. Stop at the first byte that differs.
        if w != 0:
            break
    return int(w < 0)

def SamplerZ(mu, sigma, sigmamin, sigmamax):
    """
    Integer sampler centred at mu with width sigma (structure of Falcon's SamplerZ).

    A candidate is built from z0 = BaseSampler() and a random sign bit b as
    z = b + (2b - 1) * z0. It is accepted when BerExp(x, ccs) == 1, where
        r   = mu - floor(mu)
        x   = (z - r)^2 / (2 sigma^2) - z0^2 / (2 sigmamax^2)
        ccs = sigmamin / sigma
    Returns the accepted z plus floor(mu).
    Falcon assumes sigmamin <= sigma <= sigmamax; that is not checked here.
    """
    shift = int(floor(mu))
    frac = mu - shift
    ccs = sigmamin/sigma
    while True:
        z0 = BaseSampler()
        sign = UniformBits(8) & 0x1
        cand = sign + (2*sign - 1)*z0
        exponent = (cand - frac)**2/2/sigma**2 - z0**2/2/sigmamax**2
        if BerExp(exponent, ccs) == 1:
            return cand + shift

In [4]:
def gen_poly(n, q):
    """
    Sample the NTRU secret polynomials f and g, each with n coefficients.

    Every coefficient is the sum of 4096 // n independent
    SamplerZ(0, sigma_star) draws, where sigma_star = 1.17 * sqrt(q / 8192).
    If SamplerZ is exact, the standard deviation is 1.17 * sqrt(q / (2n)).
    Returns (f, g) as Zx polynomials reduced mod phi.
    """
    # Loop-invariant sampler parameters, computed once instead of per draw.
    sigma_star = 1.17 * sqrt(q / 8192)
    sigmamin, sigmamax = 1.277833697, 1.8205
    # Integer division: range() needs an integer count of draws.
    draws = 4096 // n

    def D(mu=0):
        # One coefficient: the sum of `draws` discrete Gaussian samples centred at mu.
        z = 0
        for _ in range(draws):
            z += SamplerZ(mu, sigma_star, sigmamin, sigmamax)
        return z

    f = [0] * n
    g = [0] * n
    for i in range(n):
        f[i] = D()
        g[i] = D()
    return Zx(f) % phi, Zx(g) % phi

In [5]:
def gs_norm(f, g, q, n):
    """
    Gram-Schmidt norm of the NTRU basis generated by (f, g), as in Falcon's NTRUGen:

        gamma = max(||(g, -f)||, ||(q f* / (f f* + g g*), q g* / (f f* + g g*))||)

    p* is HermitianAdjointPoly(p, n), and the divisions are taken in
    Q[x]/(x^n + 1). Both norms run over all 2n coordinates.
    """
    TT = Zx.change_ring(QQ).quotient(x**n + 1)

    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)

    # EuclideanNorm(vec, m) only sums the first m entries, so pass the full
    # length. Passing n would drop the second half of each vector.
    first_vec = list(g.coefficients(sparse=False)) + list((-f).coefficients(sparse=False))
    first = EuclideanNorm(first_vec, len(first_vec))

    den = TT(f*f_star + g*g_star)
    second_vec = list(q * TT(f_star) / den) + list(q * TT(g_star) / den)
    second = EuclideanNorm(second_vec, len(second_vec))

    return max(first, second)
    

In [6]:
def Reduce(f, g, F, G, n):
    """
    Size-reduce (F, G) against (f, g) in Z[x]/(x^n + 1) (Falcon's Reduce step).

    It repeatedly computes k = round((F f* + G g*) / (f f* + g g*)) in
    Q[x]/(x^n + 1), using exact rational arithmetic with
    p* = HermitianAdjointPoly(p, n). It then replaces (F, G) by
    (F - k f, G - k g), stopping when k == 0. Because k has integer
    coefficients, f * G - g * F does not change.

    Returns (f, g, F, G). f and g are untouched; F and G are reduced mod x^n + 1.
    """
    TT = Zx.change_ring(QQ).quotient(x**n + 1)
    # Reduce modulo this level's x^n + 1. The global phi only matches the top
    # level of the NTRUSolve recursion.
    modulus = x**n + 1

    f_star = HermitianAdjointPoly(f, n)
    g_star = HermitianAdjointPoly(g, n)
    # The denominator does not depend on F or G, so invert it once.
    den_inv = 1 / TT(f*f_star + g*g_star)

    F = F % modulus
    G = G % modulus
    while True:
        res = TT(F*f_star + G*g_star) * den_inv
        k = Zx([int(round(c)) for c in res])
        if k == 0:
            break
        F = (F - k*f) % modulus
        G = (G - k*g) % modulus

    return f, g, F, G


In [7]:
def NTRUSolve(f, g, n, q):
    """
    Find F, G with f * G - g * F = q in Z[x]/(x^n + 1) (Falcon's NTRUSolve).

    It recurses on the field norms N(f), N(g) in Z[x]/(x^(n/2) + 1) down to
    n == 1, where the equation is solved with the extended gcd. Each level lifts
    the solution back with F = F'(x^2) g(-x) and G = G'(x^2) f(-x), then
    size-reduces it with Reduce.

    Returns (F, G, True) on success. Returns (None, None, False) when
    gcd(f, g) != 1 at the base of the recursion. At n == 1, F and G are integers.
    """
    if n == 1:
        # Sage's xgcd returns (d, u, v) with d = u*f + v*g.
        gcd_, u, v = xgcd(f[0], g[0])
        if gcd_ != 1:
            return None, None, False
        return -v*q, u*q, True

    # f', g', F', G' live in Z[x]/(x^(n/2) + 1).
    f_ = FieldNorm(f, n)
    g_ = FieldNorm(g, n)
    F_, G_, flag = NTRUSolve(f_, g_, n//2, q)
    if not flag:
        return F_, G_, flag

    # Reduce modulo this level's x^n + 1. The global phi only matches the top level.
    modulus = x**n + 1
    F = F_.subs(x=x**2) * g.subs(x=-x) % modulus
    G = G_.subs(x=x**2) * f.subs(x=-x) % modulus
    _, _, F, G = Reduce(f, g, F, G, n)

    return F % modulus, G % modulus, flag

In [8]:
def NTRUGen(q, n):
    """
    Generate the NTRU secret f, g, F, G with f * G - g * F = q in Z[x]/(x^n + 1).

    Candidates (f, g) from gen_poly are rejected and resampled when:
      * their Gram-Schmidt norm gamma exceeds 1.17 * sqrt(q) (Falcon's bound),
      * f has a zero NTT value (f is not invertible mod q), or
      * NTRUSolve finds no solution.
    Returns (f, g, F, G) as Zx polynomials; F and G have Python-int coefficients.
    """
    while True:
        f, g = gen_poly(n, q)

        # Falcon's condition gamma > 1.17 * sqrt(q), written with squares to
        # avoid a (possibly symbolic) square root.
        gamma = gs_norm(f, g, q, n)
        if gamma**2 > (1.17 ** 2) * q:
            continue

        if 0 in NTT(f, n, q):
            continue

        F, G, flag = NTRUSolve(f, g, n, q)
        if not flag:
            continue

        modulus = x**n + 1
        F = Zx([int(coef) for coef in (F % modulus).coefficients(sparse=False)])
        G = Zx([int(coef) for coef in (G % modulus).coefficients(sparse=False)])
        return f, g, F, G

In [9]:
def Master_Keygen(n, q):
    """
    Generate the IBE master key pair over Z_q[x]/(x^n + 1).

    n is a power of 2 and q is a prime with q = 1 mod 2n. NTRUGen supplies
    polynomials f, g, F, G with f * G - g * F = q.

    return: (B, h)
        B = [[g, -f], [G, -F]]        -- secret key (short basis)
        h = g * f^(-1) mod (phi, q)   -- public key, coefficients centered by Balance
    """
    f, g, F, G = NTRUGen(q, n)
    secret_basis = [[g, -f], [G, -F]]

    # Invert f in Z_q[x]/(x^n + 1), lift back to Z[x] and center the coefficients.
    Rq = Zx.change_ring(Integers(q)).quotient(x**n + 1)
    f_inv = Balance(Zx(lift(1 / Rq(f))), q, n)

    public_key = g * f_inv % phi % q
    return secret_basis, Balance(public_key, q, n)


In [10]:
def Gaussian_Sampler(B, sigma, c):
    """
    Klein / GPV randomized nearest-plane sampler.

    Returns a vector v of the lattice spanned by the rows of B, sampled around
    the target c. The Gram-Schmidt vectors b_i* are visited from last to first.
    Each coordinate z_i is drawn with SamplerZ, centred at <c_i, b_i*>/||b_i*||^2
    with width sigma/||b_i*||, and then c_i <- c_i - z_i * b_i.

    B: list of basis rows (Extract passes the 2n x 2n basis from PolyToLattice4).
    c: target coordinates, len(B[0]) entries.
    The dimension is taken from B, not from the global n.
    NOTE(review): SamplerZ is called with Falcon's bounds (1.277833697, 1.8205);
    sigma/||b_i*|| is not checked against that range.
    """
    dim = len(B)
    B_star, _ = gram_schmidt(B)
    v = vector([0] * len(B[0]))
    c_i = vector(c)
    for i in reversed(range(dim)):
        b_i = vector(B[i])
        # ||b_i*|| is needed twice, so compute it once.
        norm_i = EuclideanNorm(list(B_star[i]), len(B_star[i]))
        center = c_i.dot_product(B_star[i]) / norm_i**2
        sigma_i = sigma / norm_i
        z_i = SamplerZ(center, sigma_i, 1.277833697, 1.8205)
        c_i = c_i - z_i * b_i
        v = v + z_i * b_i
    return v


In [11]:
def HashToPoint(noise, user_id, q, n):
    """
    Hash (noise, user_id) to a polynomial of degree < n with coefficients uniform in [0, q).

    user_id: identity bytes, e.g. b"surname"
    noise:   salt bytes, so that people with the same surname get different hashes
    return:  Zx polynomial, already reduced mod phi and mod q

    As in Falcon's HashToPoint, 16-bit big-endian chunks are read from SHAKE-256.
    A chunk is kept only if it is below k*q with k = floor(2^16 / q). This
    rejection makes `chunk % q` exactly uniform.

    Raises ValueError if q > 2^16, since k would be 0 and no chunk could be accepted.
    """
    if q > (1 << 16):
        raise ValueError("The modulus is too large")
    bound = ((1 << 16) // q) * q

    shake = SHAKE256.new()
    shake.update(noise)
    shake.update(user_id)

    hashed = []
    while len(hashed) < n:
        chunk = shake.read(2)
        elt = (chunk[0] << 8) + chunk[1]
        if elt < bound:
            hashed.append(elt % q)

    return Zx(hashed)

In [12]:
def Extract(B, user_id, n, q):
    """
    Derive the identity secret key s2 of `user_id` from the master basis B.

    {s1 + s2 * h} = t  mod (phi, q), where t = HashToPoint(noise, user_id, q, n).
    Gaussian_Sampler draws a lattice point close to (t, 0) over the basis
    PolyToLattice4(B), and subtracting it from (t, 0) leaves the short vector
    (s1, s2). Only s2 is returned; s1 = t - s2 * h is implied.

    The key is generated once per (user_id, noise, B) and cached in the global
    dict `_SK_ID_CACHE`. Two different keys for the same t would differ by a
    short vector of the master lattice, so one identity must always get the same key.
    The result is also stored in the global SK_id.
    """
    global SK_id

    cache = globals().setdefault("_SK_ID_CACHE", {})
    cache_key = (user_id, noise, tuple(tuple(row) for row in B))
    if cache_key in cache:
        SK_id = cache[cache_key]
        return SK_id

    # padded_list keeps all n hash coefficients even when the top ones are 0,
    # so the target vector always has length 2n.
    t = HashToPoint(noise, user_id, q, n).padded_list(n) + [0]*n

    X = PolyToLattice4(B[0][0], B[0][1], B[1][0], B[1][1], n)

    eps = 1/2/n
    gamma = gs_norm(-B[0][1], B[0][0], q, n)
    sigma = gamma/float(pi)*sqrt(ln(2+2/eps)/2)

    s = vector(t) - Gaussian_Sampler(X, sigma, t)
    s2 = Zx(list(s[n:]))

    cache[cache_key] = s2
    SK_id = s2
    return SK_id
    

In [50]:
def gen_small_ternar_poly(d1, d2, n):
    """
    Random ternary polynomial with d1 coefficients equal to +1, d2 equal to -1
    and n - d1 - d2 zeros, in a uniformly shuffled order.

    Returns a Zx polynomial.
    """
    coeffs = [1] * d1
    coeffs.extend([-1] * d2)
    coeffs.extend([0] * (n - d1 - d2))
    shuffle(coeffs)
    return Zx(coeffs)

In [51]:

#----------------------------------------------------

def HashToBin(info, m, n, q):
    """
    H0: {0, 1}^n -> {0, 1}^m, instantiated with SHAKE-256.

    info : sequence of key bits (0/1). It is normalized to exactly n bits
           (truncated, or zero-padded on the right), so the encryptor's and
           decryptor's copies of the key hash identically.
    m    : number of output bits. Any m >= 0 works because SHAKE-256 is an
           extendable-output function.
    q    : unused; kept for interface compatibility.
    Returns a list of m ints in {0, 1}, most significant bit of each digest byte first.
    """
    from hashlib import shake_256

    bits = [int(b) % 2 for b in list(info)[:n]]
    bits += [0] * (n - len(bits))
    digest = shake_256(bytes(bits)).digest((m + 7) // 8)
    return [(byte >> (7 - j)) & 1 for byte in digest for j in range(8)][:m]

In [52]:
def gen_binar_list(d, n):
    """
    Random 0/1 list with d ones and n - d zeros, in a uniformly shuffled order.
    """
    bits = [1] * d
    bits.extend([0] * (n - d))
    shuffle(bits)
    return bits

In [69]:
def Encrypt(message, h, user_id, n, q):
    """
    Encrypt a bit list for identity `user_id` under the master public key h.

    message - list of bits in {0, 1}^m
    Encryption follows a key-encapsulation pattern: a random binary key k is
    encoded in (u, v), and HashToBin(k) is used as a one-time pad for the message.
        t = HashToPoint(noise, user_id)      (uses the global `noise`)
        u = r * h + e1                       mod (phi, q)
        v = r * t + e2 + (q // 2) * k        mod (phi, q)
    r, e1 and e2 are random small ternary polynomials.

    Returns (u, v, encrypted_message).
    """
    def small_ternary():
        # Independent counts (each in [0, n/2)) of +1 and -1 coefficients.
        plus = UniformBits(800) % n // 2
        minus = UniformBits(800) % n // 2
        return gen_small_ternar_poly(plus, minus, n)

    m = len(message)

    r = small_ternary()
    e1 = small_ternary()
    e2 = small_ternary()

    # The encapsulated key k: a random binary list of length n.
    weight = UniformBits(800) % n
    key_bits = gen_binar_list(weight, n)
    pad = HashToBin(key_bits, m, n, q)   # in {0, 1}^m

    t = HashToPoint(noise, user_id, q, n)

    u = (r * h + e1) % phi % q
    v = (r * t + e2 + (q//2) * Zx(key_bits)) % phi % q

    # XOR the message with the pad.
    ciphertext_bits = [(pad[i] + message[i]) % 2 for i in range(m)]
    return u, v, ciphertext_bits
    

In [70]:
def Decrypt(SK_id, u, v, encrypted_message, n, q):
    """
    Recover the message bits from a ciphertext (u, v, encrypted_message).

    SK_id is the identity key s2. w = v - u * s2 (mod phi) is centered mod q.
    Given Encrypt's u and v and s1 + s2 * h = t, w = (q//2) * k plus the term
    r*s1 + e2 - e1*s2, which is assumed small. Each key bit is therefore
    k_i = round(w_i / (q//2)) mod 2.
    The recovered k is hashed with HashToBin and XORed with encrypted_message.
    Returns the list of m decrypted bits.
    """
    s2 = SK_id
    m = len(encrypted_message)

    w = Balance((v - u * s2) % phi, q, n)

    # padded_list(n) keeps exactly n key bits. coefficients(sparse=False)
    # drops trailing zero coefficients, which would shorten k and could leave
    # HashToBin with fewer than m bits.
    k = [int(round(c)) % 2 for c in (w / (q//2)).padded_list(n)]
    Hk = HashToBin(k, m, n, q)

    # XOR with the pad.
    return [(Hk[i] + encrypted_message[i]) % 2 for i in range(m)]


In [71]:
# Session setup for the IBE demo: salt, master key pair, and the key for b"Karina".
# Random salt ("noise") that HashToPoint mixes with each user id.
noise = urandom(SALT_LEN)
# Reset the global identity key before deriving it below.
SK_id = None

# Master key pair: B = [[g, -f], [G, -F]] (secret), h = g / f mod q (public).
B, h = Master_Keygen(n, q)
# Identity key for user b"Karina"; Extract returns only s2.
SK_id = Extract(B, b"Karina", n, q)


In [72]:
def test_IBE(message, B, h, SK_id, noise):
    """
    Encrypt `message` for identity b"Karina", decrypt it with SK_id, and print
    the message, the ciphertext bits, the decryption, and SUCCESS or FAIL.

    Note: B and noise are not used here; Encrypt reads the global `noise`.
    """
    print("MESSAGE:    ", message)

    u, v, ciphertext = Encrypt(message, h, b"Karina", n, q)
    print("ENCRYPTION: ", ciphertext)

    recovered = Decrypt(SK_id, u, v, ciphertext, n, q)
    print("DECRYPTION: ", recovered)

    print("SUCCESS" if recovered == message else "FAIL")

In [73]:
# Round-trip a 15-bit all-ones message.
test_IBE([1]*15, B, h, SK_id, noise)

MESSAGE:     [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
ENCRYPTION:  [1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 0, 1, 1, 0, 0]
DECRYPTION:  [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
SUCCESS


In [77]:
# Round-trip a 15-bit all-zeros message.
test_IBE([0]*15, B, h, SK_id, noise)

MESSAGE:     [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
ENCRYPTION:  [1, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1]
DECRYPTION:  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
SUCCESS


In [78]:
# Round-trip an 18-bit alternating pattern.
test_IBE([1,0,1,0,1,0,1,0,1]*2, B, h, SK_id, noise)

MESSAGE:     [1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1]
ENCRYPTION:  [0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 0, 1]
DECRYPTION:  [1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1]
SUCCESS
