In [2]:
# ----------------------------------------------- AUXILIARY OPERATIONS ----------------------------------------------- 

def validate_params(N, p, q, d):
    """ Check that the NTRU parameters (N, p, q, d) are mutually consistent.

        N -- ring dimension (polynomials are taken modulo x^N - 1); must be prime
        p -- small modulus; must be prime
        q -- large modulus; must be larger than p and coprime to it
        d -- weight parameter for the random polynomials; must be >= 1, and
             2*d + 1 must not exceed N because f gets d + 1 coefficients equal
             to 1 and d coefficients equal to -1

        Raises AssertionError (with a descriptive message) on the first violated condition.
    """
    # ZZ(...) accepts plain Python ints as well as Sage Integers before calling is_prime()
    assert ZZ(N).is_prime(), "N is not a prime"
    assert ZZ(p).is_prime(), "p is not a prime"
    assert q > p and gcd(p,q) == 1, "q must be larger than p and the greatest common divider of p and q is 1"
    # encrypt draws r with d - 1 coefficients equal to -1, so d = 0 would request a negative count
    assert d >= 1, "d must be at least 1"
    assert (2*d + 1) <= N, "2*d + 1 must not exceed N (f needs 2*d + 1 nonzero coefficients)"
    # NOTE(review): q > (6*d + 1) * p is the usual sufficient condition for decryption to
    # always succeed; it is intentionally not enforced here — confirm whether it should be.
        
def find_degree(coefs_list):
    """ Return the index of the highest nonzero entry of coefs_list.

        Assuming coefs_list[i] is the coefficient of x^i, this is the degree of the
        corresponding polynomial. Returns None when the list is empty or all entries are zero.
    """
    # scan from the highest index down and stop at the first nonzero coefficient
    candidates = (idx for idx in reversed(range(len(coefs_list))) if coefs_list[idx] != 0)
    return next(candidates, None)

def invertmodprime(f, p, N):
    ''' Invert the Zx polynomial f in the ring (Z/pZ)[x] / (x^N - 1).

        p is assumed to be prime, so the coefficient ring Integers(p) is a field.
        Returns a Zx polynomial h (coefficients are the lifted residues) such that
        h ~ f = 1 (mod p). Raises an exception if f has no such inverse
        (presumably ZeroDivisionError from Sage — confirm).'''
    # build (Z/pZ)[x] first, then quotient it by x^N - 1
    ring_mod_p = Zx.change_ring(Integers(p))
    quotient_ring = ring_mod_p.quotient(x^N - 1)
    inverse = 1 / quotient_ring(f)
    # lift back to a polynomial over Z/pZ, then reinterpret its coefficients as integers
    return Zx(lift(inverse))

def invertmodpowerof2(a, p, N):
    """ Invert the Zx polynomial a modulo x^N - 1 and modulo p, where p is a power of 2.

        Starts from the inverse modulo 2 (computed by invertmodprime) and lifts it with the
        Newton iteration b <- b * (2 - a*b). If a*b = 1 + m*t, then
        a*b*(2 - a*b) = 1 - m^2*t^2, so each step squares the modulus the inverse is valid for.

        a -- Zx polynomial to invert
        p -- modulus, must be a power of 2 (p >= 2)
        N -- ring dimension (products are reduced modulo x^N - 1)

        returns a Zx polynomial b with coefficients in [0, p) such that a ~ b = 1 (mod p)
        raises ValueError if p is not a power of 2; whatever invertmodprime raises is
        propagated when a is not invertible modulo 2
    """
    modulus = int(p)
    # exact integer power-of-two test (a float log can round to the wrong exponent)
    if modulus < 2 or modulus & (modulus - 1):
        raise ValueError("p must be a power of 2, got %s" % p)

    b = invertmodprime(a, 2, N)
    precision = 2
    while precision < modulus:
        # both are powers of 2 and modulus <= precision^2 is allowed here, so modulus
        # divides precision^2 and capping keeps the result correct modulo `modulus`
        precision = min(precision * precision, modulus)
        correction = 2 - convolution(a, b, N)
        b = balancedmod(convolution(b, correction, N), precision, N)

    # return non-negative representatives, matching what invertmodprime's lift() yields
    return Zx([b[i] % modulus for i in range(N)])

def balancedmod(f, q, N):
    ''' Reduce each of the first N coefficients of the Zx polynomial f modulo q into a
        centered range: [-(q-1)/2, +(q-1)/2] for an odd q, [-q/2, +q/2-1] for an even q.
        Coefficients of degree >= N are not read, so f is expected to be already
        reduced modulo x^N-1.
        returns Zx reduced polynomial'''
    half = q // 2
    # shift by half, reduce into [0, q), shift back to centre the range around 0
    centered = [(f[i] + half) % q - half for i in range(N)]
    return Zx(centered)

def convolution(f, g, N):
    ''' NTRU "star" product: ordinary polynomial multiplication of f and g followed by
        reduction modulo x^N-1 (so x^N wraps around to 1, x^(N+1) to x, and so on).
        returns Zx polynomial'''
    ring_modulus = x^N - 1
    product = f * g
    return product % ring_modulus

# ----------------------------------------------- SETUP -----------------------------------------------

def generate_polynomial(d1, d2, N):
    ''' generates a random ternary polynomial of degree < N with exactly d1 coefficients
        equal to 1, d2 coefficients equal to -1 and the remaining N - d1 - d2 equal to 0;
        the positions are chosen by shuffling the coefficient list
        raises ValueError if d1 or d2 is negative or d1 + d2 > N
        returns Zx polynomial '''
    # without this check a too-large d1 + d2 (or a negative count) silently yields a
    # list longer than N, i.e. a polynomial of degree >= N
    if d1 < 0 or d2 < 0 or d1 + d2 > N:
        raise ValueError("need d1 >= 0, d2 >= 0 and d1 + d2 <= N, got d1=%s, d2=%s, N=%s" % (d1, d2, N))

    coefficients = [1]*d1 + [-1]*d2 + [0]*(N - d1 - d2)
    shuffle(coefficients)

    return Zx(coefficients)

def generate_keys(N, p, q, d, p1 = None, p2= None, max_attempts = 1000):
    ''' generates a public and private key pair, based on provided parameters

        p1, p2 -- optional fixed f and g polynomials; when omitted they are drawn at random
                  (f with d + 1 coefficients equal to 1 and d equal to -1, g with d of each)
        max_attempts -- how many random f to try before giving up (a random f may not be
                        invertible)

        returns Zx public key and a secret key as a tuple of Zx f (private key) and Zx F_p
        raises ValueError if the supplied p1 is not invertible modulo q or p, and
        RuntimeError if no invertible random f was found within max_attempts'''
    for _ in range(max_attempts):
        f = p1 if p1 is not None else generate_polynomial(d + 1, d, N)
        g = p2 if p2 is not None else generate_polynomial(d, d, N)
        try:
            # formula: find f_q, where: f_q (*) f = 1 (mod q)
            # invertmodprime works over Integers(q), so q is assumed prime here
            f_q = invertmodprime(f, q, N)

            # formula: find f_p, where: f_p (*) f = 1 (mod p)
            # p is assumed prime (ntru_session checks this via validate_params)
            f_p = invertmodprime(f, p, N)
        except (ArithmeticError, ValueError) as exc:
            # a fixed f fails identically on every retry, so report it instead of looping forever
            if p1 is not None:
                raise ValueError("the supplied f is not invertible modulo q and/or p") from exc
            # this random f is not invertible; draw a new one
            continue
        break
    else:
        raise RuntimeError("no invertible f found after %d attempts (invertmodprime needs prime p and q)" % max_attempts)

    # formula: public key = p * F_q ~ g (mod q), coefficients balanced into [-q/2, +q/2]
    public_key = balancedmod(p * convolution(f_q, g, N), q, N)

    # secret key is a tuple containing a private key (f) and variable f_p needed for decryption
    secret_key = f, f_p

    return public_key, secret_key
    
#---------------------------------- ENCRYPTION -----------------------------------------
def generate_message(N, p, q, d):
    ''' builds a random plaintext polynomial with N coefficients, each drawn uniformly
        from {-1, 0, 1} (p, q and d are accepted for a uniform call signature but unused)
        returns Zx polynomial'''
    coefficients = []
    for _ in range(N):
        # randrange(3) is 0, 1 or 2; shifting by one gives the ternary alphabet decryption expects
        coefficients.append(randrange(3) - 1)
    return Zx(coefficients)

def encrypt(message, public_key, N, p, q, d, r = None):
    ''' performs encryption of a given message using a provided public key

        r -- optional blinding polynomial; when omitted a random one with d coefficients
             equal to 1 and d - 1 equal to -1 is generated

        returns Zx encrypted message'''
    # test against None rather than truthiness: an explicitly passed zero polynomial is
    # falsy and would otherwise be silently replaced by a random r
    if r is None:
        r = generate_polynomial(d, d - 1, N)

    # formula: encrypted_message = r ~ public_key + message (mod q)
    # (the factor p is already part of public_key, see generate_keys)
    # coefficients of the result are balanced into [-q/2, +q/2]
    return balancedmod(convolution(public_key, r, N) + message, q, N)


def decrypt(encrypted_message, secret_key, N, p, q, d):
    ''' recovers the plaintext from a ciphertext using the secret key (f, f_p)
        produced by generate_keys

        returns Zx decrypted message'''
    private_f, inverse_f_mod_p = secret_key

    # a = f ~ encrypted_message, centred modulo q into [-q/2, +q/2]
    centered = balancedmod(convolution(encrypted_message, private_f, N), q, N)

    # m = f_p ~ a, centred modulo p
    return balancedmod(convolution(centered, inverse_f_mod_p, N), p, N)

from random import shuffle
from math import log2

# Zx is the univariate polynomial ring over the integers (Sage generator syntax);
# its generator x is used as a global by the functions defined above.
Zx.<x> = ZZ[]

def ntru_session(N, p, q, d, message = None, f = None, g = None, r = None):
    ''' runs one NTRU round trip: validates the parameters, generates keys, encrypts and
        decrypts a message, printing every stage and finally SUCCESS or FAIL

        message, f, g, r -- optional fixed polynomials (at most N coefficients each);
                            missing ones are generated randomly

        returns True if the decrypted message equals the original message, False otherwise
        raises AssertionError on invalid parameters or polynomials'''
    validate_params(N, p, q, d)

    print("N, p, q, d : ", N, p, q, d)

    for label, poly in (("message", message), ("f degree", f), ("g degree", g), ("r degree", r)):
        if poly is not None:
            assert len(poly.coefficients(sparse=False)) <= N, "invalid " + label

    if message is None:
        message = generate_message(N, p, q, d)
    public_key, secret_key = generate_keys(N, p, q, d, f, g)
    print("MESSAGE: " + str(message))
    encrypted_message = encrypt(message, public_key, N, p, q, d, r)
    print("ENCRYPTION: " + str(encrypted_message))
    decrypted_message = decrypt(encrypted_message, secret_key, N, p, q, d)
    print("DECRYPTION: " + str(decrypted_message))

    success = bool(message == decrypted_message)
    # report the outcome before returning so the message is actually printed
    print("SUCCESS" if success else "FAIL")
    return success
    
    

In [3]:
def test_lecture():
    ''' runs ntru_session on a fixed example (N=7, p=3, q=41, d=2) with known f, g, r and
        message, and prints TEST PASSED or TEST FAILED (with the error, if one was raised)'''
    # NOTE(review): these values look like the worked example from Hoffstein-Pipher-Silverman,
    # not Wikipedia (whose NTRUEncrypt example uses N=11, q=32) — confirm the printed label.
    print("TEST CASE: ntru data from wikipedia")
    try:
        f = x^6 - x^4 + x^3 + x^2 - 1
        g = x^6 + x^4 - x^2 - x
        r = x^6 - x^5 + x - 1
        message = -x^5 + x^3 + x^2 - x + 1
        res = ntru_session(7, 3, 41, 2, message, f, g, r)
        if res:
            print("TEST PASSED")
        else:
            print("TEST FAILED")
    except Exception as exc:
        # surface the reason; a bare except would also hide it and catch KeyboardInterrupt
        print("TEST FAILED: " + repr(exc))
    finally:
        print("*******\n")

test_lecture()

TEST CASE: ntru data from wikipedia
N, p, q, d :  7 3 41 2
MESSAGE: -x^5 + x^3 + x^2 - x + 1
ENCRYPTION: -10*x^6 + 19*x^5 + 4*x^4 + 2*x^3 - x^2 + 3*x - 16
DECRYPTION: -x^5 + x^3 + x^2 - x + 1
TEST PASSED
*******

