In [2]:
# ---- Main NTRU parameters ----
N = 11   # polynomials are taken modulo x^N - 1
p = 3    # small modulus (message / f_p arithmetic)
q = 32   # large modulus (public key / ciphertext arithmetic)
d = 4    # sets how many +1/-1 coefficients generate_polynomial places in f, g and r

# Zx: the ring of polynomials in x with integer coefficients; x is its generator
Zx = PolynomialRing(ZZ, 'x')
x = Zx.gen()

# ----------------------------------------------- AUXILIARY OPERATIONS ----------------------------------------------- 
def findDegree(coefs_list):
    """Return the index of the highest nonzero entry of a coefficient list.

    The list is read lowest-degree-first, so the result is the degree of the
    polynomial it describes. Returns None when the list is empty or every
    entry is zero.
    """
    top_indices = (idx for idx in reversed(range(len(coefs_list))) if coefs_list[idx] != 0)
    return next(top_indices, None)
        
def compareDegrees(coefs_list1, coefs_list2):
    """Return True if the first coefficient list has strictly lower degree than the second.

    Degree is the index of the highest nonzero entry (lists are lowest-degree
    first). Unlike findDegree, an empty or all-zero list counts as degree 0 here.
    """
    def _top_index(coefs):
        # Scan from the high end; fall back to 0 when nothing is nonzero.
        for idx in reversed(range(len(coefs))):
            if coefs[idx] != 0:
                return idx
        return 0

    return _top_index(coefs_list1) < _top_index(coefs_list2)

def invertmodprime(f,p):
    """Return the inverse of f in (Z/pZ)[x]/(x^N - 1), lifted back into Zx.

    Follows the structure of the "almost inverse" algorithm (NTRU Technical
    Report #14). All working polynomials live in the ordinary ring (Z/pZ)[z],
    NOT in the quotient by z^N - 1. Exact division by z and the degree
    comparisons must act on the actual polynomials. Mixing in quotient-ring
    elements lets g = z^N - 1 collapse to 0 and the loop never terminate.

    Args:
        f: polynomial in Zx.
        p: coefficient modulus; Z/pZ must be a field (p prime), because the
           algorithm divides by coefficients.

    Returns:
        Polynomial in Zx whose product with f is 1 in (Z/pZ)[x]/(x^N - 1).

    Raises:
        AssertionError: if f has no inverse in (Z/pZ)[x]/(x^N - 1).
    """
    Rp.<z> = PolynomialRing(Integers(p))
    modulus = z^N - 1
    a = Rp(f.list())          # reduce every coefficient of f modulo p

    # Raised explicitly (not via `assert`) so the check survives `python -O`;
    # without it a non-invertible input would make the loop below spin forever.
    if a.gcd(modulus) != 1:
        raise AssertionError(f"inverse polynomial for {a} doesn't exist in (Z/Zp)[var]/(var^N-1)")

    # Invariants, modulo (p, z^N - 1):  b*a == z^k * f   and   c*a == z^k * g
    k = 0
    b, c = Rp(1), Rp(0)
    f, g = a, modulus         # f is rebound: from here on it is the working polynomial

    while True:
        if f == 0:            # cannot happen for invertible a; guard against hanging
            raise AssertionError(f"inverse polynomial for {a} doesn't exist in (Z/Zp)[var]/(var^N-1)")

        # Strip factors of z; the division is exact because f(0) == 0.
        while f[0] == 0:
            f = f // z
            c = (c * z) % modulus   # reducing keeps degrees small; invariants are mod z^N - 1
            k += 1

        if f.degree() == 0:
            # f is a nonzero constant f0: a * (b / f0) == z^k, so the inverse is
            # z^(-k) * b / f0, and z^(-k) == z^((N - k) mod N) since z^N == 1.
            b = (1 / f[0]) * b
            inverse = (z^((N - k) % N) * b) % modulus
            return Zx(inverse)

        if f.degree() < g.degree():
            f, g = g, f
            b, c = c, b

        # Cancel the constant term of f (g[0] != 0: g is z^N - 1 or a stripped f).
        u = f[0] * (1 / g[0])
        f = f - u * g
        b = b - u * c

        
def invertmodpowerof2(a, p):
    """Return the inverse of a in (Z/pZ)[x]/(x^N - 1), where p is a power of two.

    Starts from the inverse modulo 2 (invertmodprime) and lifts it with the
    Newton/Hensel step b <- b*(2 - a*b). Each step squares the modulus
    (2 -> 4 -> 16 -> 256 ...) until it reaches the target, then the result is
    reduced modulo the target.

    Args:
        a: polynomial in Zx.
        p: target modulus (q in this file). Despite the name it is not prime;
           it is expected to be a power of two.

    Returns:
        Polynomial in Zx reduced modulo 2^r and x^N - 1, where 2^r is the
        largest power of two not exceeding p (i.e. p itself for a power of two).

    Raises:
        AssertionError: propagated from invertmodprime when a is not invertible mod 2.
    """
    # Exact integer floor(log2(p)); avoids floating-point math.log rounding and
    # does not depend on `math` having been imported elsewhere.
    r = int(p).bit_length() - 1
    target = 2^r

    modulus = 2
    b = invertmodprime(a, 2)

    while modulus < target:
        modulus = modulus^2
        b = b * (2 - a*b) % modulus % (x^N - 1)

    return b % target % (x^N - 1)


def balancedmod(f,q):
    """Reduce coefficients 0..N-1 of f modulo q into the centered range.

    Each coefficient c becomes ((c + q//2) mod q) - q//2, e.g. [-1, 1] for q = 3
    and [-16, 15] for q = 32. Returns a new polynomial in Zx.
    """
    half = q // 2
    centered = [(f[i] + half) % q - half for i in range(N)]
    return Zx(centered)

def convolution(f,g):
    """Multiply f and g and reduce the product modulo x^N - 1 (cyclic convolution)."""
    product = f * g
    return product % (x^N - 1)

# ----------------------------------------------- BASIC SETUP -----------------------------------------------
def validate_params():
    """Check the global NTRU parameters p and q.

    Enforces the requirements stated in generate_keys' error message: p is
    prime, q is a power of two, q > p, and gcd(p, q) == 1.
    invertmodpowerof2 assumes q is a power of two; otherwise it silently
    inverts modulo a smaller power of two. invertmodprime divides by
    coefficients modulo p, so Z/pZ must be a field.

    Returns:
        True if every requirement holds, False otherwise.
    """
    q_is_power_of_two = q > 0 and (q & (q - 1)) == 0
    return bool(is_prime(p) and q_is_power_of_two and q > p and gcd(p, q) == 1)

def generate_polynomial(d1, d2):
    """Return a random polynomial in Zx with d1 coefficients equal to 1,
    d2 equal to -1 and the remaining N - d1 - d2 equal to 0, in shuffled positions.
    """
    # The requested nonzero coefficients must fit into the N available slots.
    assert (d1 + d2) <= N

    coefficients = [1] * d1
    coefficients.extend([-1] * d2)
    coefficients.extend([0] * (N - d1 - d2))
    shuffle(coefficients)
    return Zx(coefficients)

# ----------------------------------------------- MAIN SETUP -----------------------------------------------
def generate_keys(poly1 = None, poly2= None):
    """Generate an NTRU key pair.

    Args:
        poly1: private polynomial f. If None, a random one with d+1 ones and
            d minus-ones is drawn, re-drawing until it is invertible mod q and mod p.
        poly2: polynomial g. If None, a random one with d ones and d minus-ones
            is drawn. Each argument is handled independently, so either may be given alone.

    Returns:
        (public_key, secret_key), where public_key = p * f_q * g balanced mod q and
        secret_key = (f, f_p). If validate_params() fails, an error is printed
        and None is returned.

    Raises:
        Whatever the inversion routines raise (AssertionError for a
        non-invertible f), when poly1 was supplied. A fixed f can never become
        invertible, so retrying would loop forever.
    """
    if not validate_params():
        print("Provided params are not correct. q and p should be co-prime, q should be a power of 2 considerably larger than p and p should be prime.")
        return None

    while True:
        f = poly1 if poly1 is not None else generate_polynomial(d+1, d)
        g = poly2 if poly2 is not None else generate_polynomial(d, d)
        try:
            f_q = invertmodpowerof2(f, q)
            f_p = invertmodprime(f, p)
        except Exception:
            if poly1 is not None:
                raise
            continue          # random f was not invertible: draw a new one
        break

    public_key = balancedmod(p * convolution(f_q, g), q)
    secret_key = f, f_p
    return public_key, secret_key

#---------------------------------- ENCRYPTION -----------------------------------------
def generate_message():
    """Return a random message polynomial in Zx with N coefficients, each randrange(3) - 1 (so in {-1, 0, 1})."""
    coefficients = [randrange(3) - 1 for _ in range(N)]
    return Zx(coefficients)

def encrypt(message, public_key, r = None):
    """Encrypt message as public_key * r + message, balanced modulo q.

    r is the blinding polynomial; when omitted, a random one with d ones and
    d-1 minus-ones is generated.
    """
    blinding = generate_polynomial(d, d-1) if r is None else r
    masked = convolution(public_key, blinding) + message
    return balancedmod(masked, q)


def decrypt(encrypted_message, secret_key):
    """Recover the message from a ciphertext using secret_key = (f, f_p).

    First computes f * e balanced modulo q, then multiplies by f_p (the inverse
    of f modulo p) and balances modulo p.
    """
    private_f, private_f_p = secret_key
    centered = balancedmod(convolution(encrypted_message, private_f), q)
    recovered = convolution(centered, private_f_p)
    return balancedmod(recovered, p)

def start():
    """Run one NTRU round trip on fixed polynomials and print each stage plus Passed/Failed."""
    f_poly = -1 + x + x^2 - x^4 + x^6 + x^9 - x^10
    g_poly = -1 + x^2 + x^3 + x^5 - x^8 - x^10
    public_key, secret_key = generate_keys(f_poly, g_poly)

    message = -1 + x^3 - x^4 - x^8 + x^9 + x^10
    print("Message: " + str(message))

    blinding = -1 + x^2 + x^3 + x^4 - x^5 - x^7
    encrypted_message = encrypt(message, public_key, blinding)
    print("Encryption: " + str(encrypted_message))

    decrypted_message = decrypt(encrypted_message, secret_key)
    print("Decryption: " + str(decrypted_message))

    print("Passed" if message == decrypted_message else "Failed")
#----------------------------------------------------------------------------------------
#-------------------------------------- MAIN --------------------------------------------
#----------------------------------------------------------------------------------------

# Run the demo only when executed as a script/notebook, not when imported.
if __name__ == "__main__":
    start()

Message: x^10 + x^9 - x^8 - x^4 + x^3 - 1
Encryption: -13*x^10 + 6*x^9 - 7*x^8 + 7*x^7 - 2*x^6 - 16*x^5 + 14*x^4 - 8*x^3 - 6*x^2 + 11*x + 14
Decryption: x^10 + x^9 - x^8 - x^4 + x^3 - 1
Passed
