In [1]:
#
# NTRU (NTRUEncrypt public-key cryptosystem)
#

# Implementation based on the book
# Cryptography: Theory and Practice (4th edition)
# by Douglas R. Stinson and Maura B. Paterson

In [2]:
#
# imports
#

from collections import namedtuple
from functools import reduce

In [3]:
#
# Define some functions used later
#

# Compare two polynomials in a ring
def equals_poly(p1, p2, ring):
    """Return True when p1 and p2 are the same element of `ring`.

    Both arguments are converted with `ring(...)` and their complete
    coefficient lists are compared. Comparing the whole lists (instead of
    zipping them element-wise with `map`) matters. `map` over two
    iterables stops at the shorter one, so a polynomial whose coefficients
    are a prefix of the other's (e.g. 1 + x vs 1 + x + x^2) would
    otherwise be reported as equal.
    """
    return list(ring(p1)) == list(ring(p2))


# Calculate a mods
def mods_hof(modulus):
    """Build a centered-modulo ("mods") function for `modulus`.

    The returned function maps `number` to the representative of
    `number mod modulus` in the range (-modulus/2, modulus/2]. It does
    this by taking the residue in [0, modulus) and subtracting `modulus`
    when the residue exceeds floor(modulus/2).

    Both branches return a Sage Integer. Previously the non-negative
    branch returned the IntegerMod produced by `mod(...)`, so any later
    arithmetic on that value silently stayed modulo `modulus`.
    """
    half = modulus // 2  # equals floor(modulus/2) for positive modulus

    def _mods(number):
        # Integer(...) lifts the residue (an IntegerMod) back to ZZ
        reduced = Integer(mod(number, modulus))
        if reduced > half:
            return reduced - modulus
        return reduced

    return _mods


# Returns a random (-1,0 or 1)
def rnd_coef(_a):
    """Draw one coefficient from {-1, 0, 1}, each with probability 1/3.

    The argument is ignored; it only exists so the function can be used
    with `map` over an index range.
    """
    draw = random()
    if draw >= 2/3:
        return 1
    if draw >= 1/3:
        return 0
    return -1


# define random coef function
def rnd_coefs(degree):
    """Return a list of degree + 1 random coefficients from {-1, 0, 1}.

    Coefficient i is produced by rnd_coef(i), in increasing index order.
    """
    return [rnd_coef(index) for index in range(degree + 1)]

In [90]:
#
# Define an instance of NTRUEncrypt
#

# NTRUE.keys(F=None, G=None): generates and returns (public_key, private_key)
# NTRUE.encrypt(message, public_key, r=None): encrypts a message with the public key
# NTRUE.decrypt(encrypted, private_key): decrypts a ciphertext with the private key

# q must be a prime number, so that Integers(q) is a field

def NTRUE(p,q,N):
    """Build an NTRUEncrypt instance with parameters p, q and N.

    All polynomial arithmetic happens in Integers(q)[x]/(x^N - 1). Every
    result is lifted back to ZZ[x] with its coefficients reduced "mods q"
    (centered in (-q/2, q/2]) by mods_hof(q). q is expected to be prime.

    Args:
        p: small modulus used by decrypt to recover the message.
        q: modulus of the coefficient ring.
        N: arithmetic is done modulo x^N - 1.

    Returns:
        A namedtuple "NTRUE" with fields ring, mul_poly, mods, rnd_poly,
        scalar_mul, inverse, get_f, get_g, get_h, sum_poly, encrypt,
        decrypt and keys.
    """
    # ring defs: ZZ[x] holds lifted results, qring does the arithmetic
    ring = PolynomialRing(ZZ, 'x')
    _ring = PolynomialRing(Integers(q), 'x')
    # Build x^N - 1 from the ring's own generator instead of the global
    # symbolic `x`, so this does not depend on that name being defined
    # (and not rebound) in the session.
    modulus = _ring.gen()**N - 1
    qring = _ring.quotient(modulus, 'x')

    # centered reductions: mods q for arithmetic, mods p for decryption
    mods = mods_hof(q)
    mods_p = mods_hof(p)

    def _mods_list(l, mods_func=mods):
        # apply mods_func to every coefficient in l
        return [mods_func(c) for c in l]

    # multiply two polynomials mod x^N - 1, coefficients reduced mods q
    def mul_poly(p1, p2):
        # qring elements are already reduced modulo x^N - 1
        res = qring(p1) * qring(p2)
        return ring(_mods_list(list(res)))

    # multiply every coefficient of poly by number, reducing mods q
    # (no reduction modulo x^N - 1 happens here)
    def scalar_mul(poly, number):
        coefs = ring(poly).list()
        return ring([mods(number*c) for c in coefs])

    # sum two polynomials mod x^N - 1, coefficients reduced mods q
    def sum_poly(p1, p2):
        res = qring(p1) + qring(p2)
        return ring(_mods_list(list(res)))

    # Random polynomial with coefficients in {-1, 0, 1} and degree at most
    # `degree`. The default is N - 1 because elements of Z[x]/(x^N - 1)
    # have N coefficients. A degree-N term would wrap onto the constant
    # term once reduced and could leave a coefficient outside {-1, 0, 1}.
    def rnd_poly(degree=N - 1):
        # Redraw until |sum of coefficients| < 5.
        # NOTE(review): the reason for this bound is not stated. It keeps
        # F(1) small, so f(1) = p*F(1) + 1 cannot vanish mod q (x - 1
        # divides x^N - 1, which would make f non-invertible). Confirm.
        while True:
            coefs = rnd_coefs(degree)
            if abs(sum(coefs)) < 5:
                return ring(coefs)

    # inverse of poly in Integers(q)[x]/(x^N - 1), lifted with mods q.
    # Sage raises an error if poly is not invertible in that ring.
    def inverse(poly):
        inv = 1/qring(poly)
        return ring(_mods_list(list(inv)))

    # f = p*F + 1 (private key); a random F is drawn when none is given
    def get_f(_F=None):
        # `is None` (not `or`), so an explicitly passed zero polynomial is
        # used instead of being silently replaced by a random one
        F = rnd_poly() if _F is None else _F
        return scalar_mul(F, p) + 1

    # g = p*G; a random G is drawn when none is given
    def get_g(_G=None):
        G = rnd_poly() if _G is None else _G
        return scalar_mul(G, p)

    # h = f^-1 * g (public key)
    def get_h(f, g):
        return mul_poly(inverse(f), g)

    # keys: returns (public key h, private key f)
    def keys(F=None, G=None):
        f = get_f(F)
        g = get_g(G)
        h = get_h(f, g)
        return (h, f)

    # Encryption: y = r*h + m; a random r is drawn when none is given
    def encrypt(m, h, _r=None):
        r = rnd_poly() if _r is None else _r
        return sum_poly(mul_poly(r, h), m)

    # Decryption: a = y*f reduced mods q, then each coefficient mods p.
    # With f = 1 + p*F and g = p*G, y*f = p*r*G + m + p*m*F (mod q). When
    # those coefficients already lie in (-q/2, q/2] and m is reduced
    # mods p, the final reduction leaves exactly m.
    def decrypt(y, f):
        a = mul_poly(y, f)
        return ring(_mods_list(list(a), mods_p))

    ntrue = namedtuple("NTRUE", ["ring", "mul_poly", "mods", "rnd_poly", "scalar_mul", "inverse", "get_f", "get_g", "get_h", "sum_poly", "encrypt", "decrypt", "keys"])

    return ntrue(ring, mul_poly, mods, rnd_poly, scalar_mul, inverse, get_f, get_g, get_h, sum_poly, encrypt, decrypt, keys)



In [91]:
#
# Running example 9.1 from the book
#

# NTRUEncrypt parameters used by the book's example
p, q, N = 3, 31, 23

ntrue = NTRUE(p, q, N)

# Fixed polynomials from the example: F and G build the key pair,
# m is the plaintext and r is the polynomial used by encrypt (y = r*h + m).
F = x^18-x^9+x^8-x^4-x^2
G = x^17+x^12+x^9+x^3-x
m = x^15-x^12+x^7-1
r = x^19+x^10+x^6-x^2

# Key generation, encryption and decryption of m
public_key, private_key = ntrue.keys(F, G)
y = ntrue.encrypt(m, public_key, r)
m_line = ntrue.decrypt(y, private_key)

# Show each intermediate value under its heading
for heading, value in [
    ("private key (f) is ", private_key),
    ("\npublic key (h) is", public_key),
    ("\nmessage (m) is", m),
    ("\nencrypted message (y) is", y),
    ("\ndecrypted message (m') is", m_line),
]:
    print(heading)
    print(value)

print("\nThe message was correctly decrypted?")
print(equals_poly(m, m_line, ntrue.ring))



private key (f) is 
3*x^18 - 3*x^9 + 3*x^8 - 3*x^4 - 3*x^2 + 1

public key (h) is
-13*x^22 - 15*x^21 + 12*x^19 - 14*x^18 + 8*x^16 - 14*x^15 - 6*x^14 + 14*x^13 - 3*x^12 + 7*x^11 - 5*x^10 - 14*x^9 + 3*x^8 + 10*x^7 + 5*x^6 - 8*x^5 + 4*x^2 + x + 8

message (m) is
x^15 - x^12 + x^7 - 1

encrypted message (y) is
5*x^22 - 15*x^21 + 4*x^20 + 8*x^19 + 10*x^18 - 15*x^17 + 6*x^16 + 8*x^15 - 8*x^14 + 3*x^13 - 10*x^12 - 7*x^11 - x^10 - 9*x^9 + 12*x^8 - 14*x^7 + 15*x^6 - 10*x^5 + 15*x^4 - 14*x^3 - 5*x^2 - 15*x - 3

decrypted message (m') is
x^15 - x^12 + x^7 - 1

The message was correctly decrypted?
True


In [95]:
#
# Running a random example
#

# parameters
p=3
q=2053
N=401

# Fix Sage's random seed so that Restart & Run All reproduces the same
# keys, message and ciphertext (rnd_coef draws from random()).
SEED = 0
set_random_seed(SEED)

# instance of NTRUEncrypt
ntrue = NTRUE(p,q,N)

# get keys (F and G are drawn at random)
pub_k, pri_k = ntrue.keys()

# Generate a random message (degree at most 10, coefficients in {-1, 0, 1}):
m = ntrue.rnd_poly(10)

# Encryption (with a random r):
y = ntrue.encrypt(m, pub_k)

# Decryption:
m_line = ntrue.decrypt(y, pri_k)

print("\nThe message was correctly decrypted?")
print(equals_poly(m, m_line, ntrue.ring))



True
