In [1]:
from sage.all import *
import sys
import math
from random import shuffle

In [2]:
def set_params(N1, p1, q1, d1):
    ''' stores the NTRU parameters as module-level globals that every other function reads:
        N1 -> N (number of coefficients, ring is Z[x]/(x^N-1)),
        p1 -> p (small modulus), q1 -> q (large modulus),
        d1 -> d (controls how many +1/-1 coefficients generated f, g and r get) '''
    global N, p, q, d
    N, p, q, d = N1, p1, q1, d1
    

In [3]:
Zx.<x> = ZZ[]

def find_degree(coefs_list):
    ''' returns the index of the last nonzero entry of a coefficient list (i.e. the polynomial degree),
        or None when the list is empty or every entry is zero '''
    return next((idx for idx in reversed(range(len(coefs_list))) if coefs_list[idx] != 0), None)

def invertmodprime(f,p):
    ''' computes the inverse of a Zx polynomial f in the ring (Z/p)[x]/(x^N-1).
        p must be prime: the loop divides by coefficients, so Integers(p) has to be a field.
        Reads the globals N, x and Zx.

        Loop invariants (working mod p and mod z^N-1), with a = f mod p the original input:
            b*a = Z^k * f    and    c*a = Z^k * g
        When f has been reduced to a nonzero constant f0, (b/f0)*a = Z^k, so Z^(N-k) * b/f0
        is the inverse of a (because Z^N = 1 in the quotient ring).

        raises AssertionError when gcd(a, z^N-1) is not a constant, i.e. f is not invertible mod p
        returns Zx polynomial (the inverse, lifted back to integer coefficients)'''
    # Zq = (Z/p)[z]; ZQphi = Zq/(z^N-1), where Z is the image of z
    Zq.<z> = PolynomialRing(Integers(p))
    ZQphi.<Z> = Zq.quotient(z^N-1)
    # reduce the coefficients mod p and move the polynomial from Zx into Zq
    a = f % p
    a = a.subs(x=z)
    # k counts how many times f has been divided by Z; b, c are the multipliers from the invariants
    k = 0
    b = 1*z^0
    c = 0*z^0
    f = a 
    g = z^N-1
    
    # a is invertible iff gcd(a, z^N-1) is a nonzero constant; a gcd over a field is expected
    # to be monic, so in practice this accepts gcd == 1
    assert a.gcd(g) in {i for i in range(p)}
        
    while True:
        # strip factors of Z from f while its constant term is 0 (exact division).
        # Dividing by Z coerces f into the quotient ring ZQphi; that is harmless because the
        # invariants only need to hold modulo z^N-1. c is multiplied by Z to keep c*a = Z^k * g.
        while list(f)[0] == 0:
            f = f / Z
            c = c * Z
            k += 1
        
        # f is a nonzero constant f0: normalize b by 1/f0 and undo the k divisions by Z
        if find_degree(list(f)) == 0:
            b = 1/list(f)[0] * b
            ans = Z^(N-k) * b
            return Zx(ans.lift())
        
        # keep deg(f) >= deg(g); the multipliers are swapped along with the polynomials
        if find_degree(list(f)) < find_degree(list(g)):
            f, g = g, f
            b, c = c, b
        
        # cancel f's constant term using g (g's constant term is nonzero at this point), then
        # update b the same way so b*a = Z^k * f still holds
        u = list(f)[0] * (1/list(g)[0])
        f = f - u*g
        b = b - u*c

        
def invertmodpowerof2(a, p):
    ''' computes the inverse of a Zx polynomial a in the ring (Z/2^r)[x]/(x^N-1), where 2^r is the
        largest power of 2 not exceeding p (generate_keys passes q here, which should be a power of 2).
        The inverse modulo 2 comes from invertmodprime. It is then lifted with the Newton step
        b <- b * (2 - a*b): if a*b = 1 (mod m), the new b satisfies a*b = 1 (mod m^2).
        Reads the globals N and x.
        raises AssertionError (from invertmodprime) if a is not invertible modulo 2
        returns Zx polynomial of degree < N with coefficients reduced modulo 2^r'''
    # exact floor(log2(p)); the former int(math.log(p, 2)) is a float computation that can
    # land just below an integer and be truncated to the wrong exponent
    r = int(p).bit_length() - 1
    target = 2^r

    b = invertmodprime(a, 2)
    modulus = 2
    while modulus < target:
        # each Newton step squares the modulus for which b is a correct inverse
        modulus = modulus^2
        b = b * (2 - a*b) % modulus % (x^N-1)

    # reduce by x^N-1 first, so the coefficient reduction modulo 2^r is the final step
    return b % (x^N - 1) % target
   

def balancedmod(f,q):
    ''' reduces each of the first N coefficients of a Zx polynomial f modulo q into a range centered on zero:
        [-(q-1)/2, +(q-1)/2] for odd q and [-q/2, +q/2-1] for even q.
        Reads the globals N and Zx.
        returns Zx reduced polynomial'''
    half = q // 2
    centered = [(f[k] + half) % q - half for k in range(N)]
    return Zx(centered)

def convolution(f,g):
    ''' multiplies two Zx polynomials in the NTRU ring Z[x]/(x^N-1): an ordinary product
        reduced by x^N-1, so x^N wraps around to 1, x^(N+1) to x, and so on.
        Reads the globals N and x.
        returns Zx polynomial'''
    ring_modulus = x^N - 1
    product = f * g
    return product % ring_modulus


def validate_params():
    ''' checks that the global parameters p and q satisfy the requirements the rest of this file relies on:
          - p is prime (invertmodprime works in Integers(p) and divides by coefficients, so it needs a field),
          - q is a power of 2 (invertmodpowerof2 lifts an inverse mod 2 up to a power of 2),
          - q > p and gcd(p, q) == 1.
        Without the first two checks, generate_keys would retry key generation forever.
        returns True if all conditions hold, False otherwise'''
    q_int = int(q)
    q_is_power_of_2 = q_int >= 2 and (q_int & (q_int - 1)) == 0
    return bool(is_prime(p) and q_is_power_of_2 and q > p and gcd(p, q) == 1)

def generate_polynomial(d1, d2):
    ''' generates a random ternary polynomial with d1 coefficients equal to 1, d2 equal to -1 and
        the remaining N-d1-d2 equal to 0, placed in a random order. Reads the globals N and Zx.
        returns Zx polynomial '''
    zero_count = N - d1 - d2
    # all the nonzero coefficients must fit into the N available positions
    assert zero_count >= 0
    coefficients = [1]*d1 + [-1]*d2 + [0]*zero_count
    shuffle(coefficients)
    return Zx(coefficients)

# ----------------------------------------------- MAIN SETUP -----------------------------------------------
def generate_keys(poly1 = None, poly2= None):
    ''' generates an NTRU public/private key pair from the global parameters N, p, q, d
        poly1, poly2 - optional polynomials to use as f and g; unless both are given, both are drawn at random
        returns (public_key, secret_key):
            public_key - Zx polynomial h = p * F_q ~ g (mod q, balanced), where F_q is the inverse of f mod q
            secret_key - tuple (f, F_p), where F_p is the inverse of f mod p (needed for decryption)
        If validate_params() rejects the parameters, prints an error message and returns None.
        raises ValueError if the provided poly1 is not invertible modulo p or q'''
    if not validate_params():
        print("Provided params are not correct. q and p should be co-prime, q should be a power of 2 considerably larger than p and p should be prime.")
        return None

    use_given = poly1 is not None and poly2 is not None

    #   a random f is not always invertible, so keep drawing new pairs until both inverses exist
    while True:
        if use_given:
            f, g = poly1, poly2
        else:
            # f has d+1 coefficients equal to 1 and d equal to -1; g has d of each
            f = generate_polynomial(d+1, d)
            g = generate_polynomial(d, d)
        try:
            f_q = invertmodpowerof2(f,q)
            f_p = invertmodprime(f,p)
            break
        except Exception as err:
            # a fixed f that failed once will fail every time, so retrying would loop forever;
            # catching Exception (not a bare except) also lets KeyboardInterrupt through
            if use_given:
                raise ValueError("provided polynomial f is not invertible modulo p or q") from err

    # formula: public key h = p * F_q ~ g (mod q), coefficients balanced around zero
    public_key = balancedmod(p * convolution(f_q,g),q)

    # secret key is a tuple containing the private key (f) and F_p, which decryption needs
    secret_key = f,f_p

    return public_key,secret_key

#---------------------------------- ENCRYPTION -----------------------------------------
def generate_message():
    ''' builds a random message polynomial with N coefficients, each drawn uniformly from {-1, 0, 1}.
        Reads the globals N and Zx.
        returns Zx polynomial'''
    return Zx([randrange(3) - 1 for _ in range(N)])

def encrypt(message, public_key, r = None):
    ''' encrypts a message polynomial with the given public key h
        r - optional blinding polynomial; when None a random one with d coefficients equal to 1
            and d-1 equal to -1 is generated
        returns Zx encrypted message'''
    blinding = generate_polynomial(d, d-1) if r is None else r

    # formula: e = r ~ h + m (mod q), with coefficients balanced around zero.
    # The factor p is already part of the public key (generate_keys builds h = p * F_q ~ g),
    # so it is not multiplied in again here.
    noisy = convolution(public_key, blinding) + message
    return balancedmod(noisy, q)


def decrypt(encrypted_message, secret_key):
    ''' decrypts a ciphertext with the secret key tuple (f, F_p) produced by generate_keys
        returns Zx decrypted message'''
    private_f, inverse_f_p = secret_key

    # step 1: a = f ~ e (mod q), coefficients balanced around zero
    centered = balancedmod(convolution(encrypted_message, private_f), q)

    # step 2: m = F_p ~ a (mod p), balanced the same way
    return balancedmod(convolution(centered, inverse_f_p), p)


In [4]:
def set_polynoms(m = None, f = None, g = None, r = None):
    ''' runs one complete NTRU round trip (key generation, encryption, decryption) and prints every stage
        m - message polynomial; when None, a random message from generate_message() is used
            (passing None straight to encrypt would fail)
        f, g - optional key polynomials forwarded to generate_keys (random keys unless both are given)
        r - optional blinding polynomial forwarded to encrypt (random when None)
        prints "TEST PASSED" when decryption recovers the message, "TEST FAILED" otherwise'''
    public_key, secret_key = generate_keys(f, g)
    message = generate_message() if m is None else m
    print("MESSAGE: " + str(message))

    encrypted_message = encrypt(message, public_key, r)
    print("ENCRYPTION: " + str(encrypted_message))

    decrypted_message = decrypt(encrypted_message, secret_key)
    print("DECRYPTION: " + str(decrypted_message))

    if message == decrypted_message:
        print("TEST PASSED")
    else:
        print("TEST FAILED")

In [7]:
# random round trip with N = 11, p = 3, q = 32, d = 3
set_params(11, 3, 32, 3)

public_key, secret_key = generate_keys()

message = generate_message()
print("MESSAGE: " + str(message))

encrypted_message = encrypt(message, public_key)
print("ENCRYPTION: " + str(encrypted_message))

decrypted_message = decrypt(encrypted_message, secret_key)
print("DECRYPTION: " + str(decrypted_message))

print("TEST PASSED" if message == decrypted_message else "TEST FAILED")

MESSAGE: -x^10 + x^9 - x^8 - x^6 + x^5 - x^3 + x - 1
ENCRYPTION: 10*x^10 - x^9 + 6*x^8 + 2*x^7 + 13*x^6 + 12*x^5 - 4*x^4 - 7*x^3 - 8*x^2 - 13*x - 12
DECRYPTION: -x^10 + x^9 - x^8 - x^6 + x^5 - x^3 + x - 1
TEST PASSED


In [8]:
# an example from Wikipedia, with fixed message m, key polynomials f, g and blinding polynomial r
m = -1 + x^3 - x^4 - x^8 + x^9 + x^10
f = -1 + x + x^2 - x^4 + x^6 + x^9 - x^10
g = -1 + x^2 + x^3 + x^5 - x^8 - x^10
r = -1 + x^2 + x^3 + x^4 - x^5 - x^7
set_params(11, 3, 32, 4)
set_polynoms(m, f, g, r)

MESSAGE: x^10 + x^9 - x^8 - x^4 + x^3 - 1
ENCRYPTION: -13*x^10 + 6*x^9 - 7*x^8 + 7*x^7 - 2*x^6 - 16*x^5 + 14*x^4 - 8*x^3 - 6*x^2 + 11*x + 14
DECRYPTION: x^10 + x^9 - x^8 - x^4 + x^3 - 1
TEST PASSED
