# A fast class for polynomials in SageMath

## Motivation

FHE needs efficient operations in its native ring $\mathbb Z[X] / \langle X^N+1 \rangle$.
Unfortunately, the reduction modulo $X^N+1$ and the evaluation of automorphisms such as $X \longmapsto X^5$ are slow.
A few other FHE-friendly operations, such as the centered reduction modulo $Q$, are needed as well.

## Benchmarks

We benchmark our code with the FHE-standard ring degree $N = 2^{15}$ and a modulus $Q = 2^{1000}$.

In [1]:
from sage.libs.ntl import *
import time

def set_ntl(element, modulus=None):
    """Wrap ``element`` in the NTL polynomial type matching ``modulus``.

    A nonzero ``modulus`` yields an ``ntl.ZZ_pX`` built with that modulus;
    ``None`` or ``0`` yields a plain integer polynomial ``ntl.ZZX``.
    """
    use_modular_ring = modulus is not None and modulus != 0
    if use_modular_ring:
        return ntl.ZZ_pX(element, modulus)
    return ntl.ZZX(element)

class Poly:
    """Polynomial with at most ``N`` coefficients, stored as an NTL polynomial.

    The coefficients live in ``self.c``, which is an ``ntl.ZZX`` when the
    modulus is 0 and an ``ntl.ZZ_pX`` otherwise (see ``set_ntl``).  Products
    and automorphisms are folded back to ``N`` coefficients with
    :meth:`mod_quo`, i.e. reduced modulo ``X^N + 1``.

    Ring degree, modulus and lookup tables are *class-level* state shared by
    every instance; call :meth:`setup` once before creating polynomials.
    Note that ``poly % Q`` rebinds the class-level modulus (see
    :meth:`__mod__`).
    """

    @classmethod
    def setup(cls, N=None, modulus=None):
        """Initialise the shared ring parameters and lookup tables.

        Args:
            N: ring degree (number of coefficients).  Required.
            modulus: coefficient modulus; ``None`` is stored as 0 and
                selects integer coefficients.

        Raises:
            TypeError: if ``N`` is not given.
        """
        # The default used to be the global name ``N`` evaluated at class
        # definition time, which is never a usable ring degree.
        if N is None:
            raise TypeError("Poly.setup() requires the ring degree N")
        cls.N, cls.N2 = N, N*2
        cls.modulus = 0 if modulus is None else modulus
        cls.R = PolynomialRing(ZZ, 'x') # for printing
        cls.precomputations()

    @classmethod
    def precomputations(cls):
        """Build the tables used by the automorphisms and by :meth:`zero`."""
        # Target exponent 5*i mod 2N of every source coefficient i.
        five = Zmod(cls.N2)(5)
        cls.indices_auto5 = [ZZ((i * five) % (cls.N2)) for i in range(cls.N)]
        cls.indices_auto5_poly = ntl.ZZ_pX(cls.indices_auto5, cls.N2)

        cls.zero_array = [0] * cls.N
        # A ZZ_pX cannot be built for modulus 0, so the modular zero template
        # only exists when a real modulus is configured.
        if cls.modulus != 0:
            cls.zero_pX = ntl.ZZ_pX(cls.zero_array, cls.modulus)
        else:
            cls.zero_pX = None
        cls.zero_X = ntl.ZZX(cls.zero_array)

    @classmethod
    def random(cls, modulus=None): # slow ~ 400ms, about 10x slower than NTL
        """Return a random element of ``Z_m[X]/(X^N + 1)``.

        ``m`` is ``modulus`` when given, otherwise the class modulus.
        """
        # maybe we use the NTL random function instead
        m = modulus if modulus is not None else cls.modulus
        base = PolynomialRing(Zmod(m), 'x')
        # Use the ring's own generator rather than the global symbolic ``x``,
        # which the user may have rebound.
        ring = base.quotient(base.gen()**cls.N + 1)
        return cls(set_ntl(ring.random_element().list(), m))

    def __init__(self, coeffs, modulus=None):
        """Wrap ``coeffs`` as a polynomial.

        Args:
            coeffs: a list of at most ``N`` coefficients, another ``Poly``
                (its NTL object is shared, not copied), or an NTL polynomial
                that is stored as is.
            modulus: optional modulus; a nonzero value must equal the
                class-level modulus.
        """
        if modulus:
            assert modulus == self.modulus, f"Modulus mismatch {modulus} != {self.modulus}"
        self.modulus = self.modulus if modulus is None else modulus

        if isinstance(coeffs, list):
            assert len(coeffs) <= self.N
            # Build with the *resolved* modulus so that Poly([...]) without
            # an explicit modulus lands in the configured ring instead of ZZX.
            self.c = set_ntl(coeffs, self.modulus)
        elif isinstance(coeffs, Poly):
            self.c = coeffs.c
        else:
            self.c = coeffs

    # ARITHMETIC OPERATORS

    def __add__(self, other):
        """Return ``self + other`` as a new ``Poly``."""
        return Poly(self.c + other.c)

    def __radd__(self, other):
        """Return ``other + self``.

        Python only calls this when ``other`` is not a ``Poly``; the additive
        identity 0 is accepted so that ``sum(polys)`` works.
        """
        if isinstance(other, Poly):
            return Poly(self.c + other.c)
        if other == 0:
            return Poly(copy(self.c))
        return NotImplemented

    def __iadd__(self, other):
        """Add ``other`` in place."""
        self.c += other.c
        return self

    def __sub__(self, other):
        """Return ``self - other`` as a new ``Poly``."""
        return Poly(self.c - other.c)

    def __rsub__(self, other):
        """Return ``other - self``; besides a ``Poly`` only 0 is accepted."""
        if isinstance(other, Poly):
            return Poly(other.c - self.c)
        if other == 0:
            return Poly(-self.c)
        return NotImplemented

    def __isub__(self, other):
        """Subtract ``other`` in place."""
        self.c -= other.c
        return self

    def __neg__(self):
        """Return ``-self`` as a new ``Poly``."""
        return Poly(-self.c)

    def __mul__(self, other): # about 90ms, if you square it's about 60ms
        """Return the product reduced modulo ``X^N + 1``."""
        product = self.c * other.c
        return Poly(self.mod_quo(product))

    def __imul__(self, other):
        """Multiply by ``other`` in place, reducing modulo ``X^N + 1``."""
        self.c *= other.c
        self.c = self.mod_quo(self.c)
        return self

    def __pow__(self, n):
        """Return ``self ** n`` by square-and-multiply.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError("negative exponents are not supported")
        result = Poly(set_ntl([1], self.modulus))
        base = self
        while n > 0:
            if n % 2 == 1:
                result = result * base
            n //= 2
            if n > 0: # skip the final, unused squaring
                base = base * base
        return result

    def __truediv__(self, other):
        """Division is not implemented; Python raises ``TypeError`` for ``/``."""
        return NotImplemented

    # MODULAR  OPERATORS

    def __mod__(self, modulus): # fast, 1-2ms
        """Return ``self`` with coefficients taken modulo ``modulus``.

        Side effect: rebinds the class-level ``Poly.modulus`` to ``modulus``.
        For ``modulus == 0`` the NTL object is reused unchanged.
        """
        Poly.modulus = modulus
        if modulus == 0:
            return Poly(self.c, 0)
        else:
            tmp = self.c.convert_to_modulus(ntl.ZZ_pContext(modulus))
            return Poly(tmp, modulus)

    def mod_quo(self, element, minus=True): # 1-2ms
        """Fold an NTL polynomial back to ``N`` coefficients.

        The coefficients from index ``N`` upwards are shifted down by ``N``
        and subtracted from the low part (``minus=True``, reduction modulo
        ``X^N + 1``) or added to it (``minus=False``).
        """
        low = element.truncate(self.N)
        high = element.left_shift(-self.N)
        return low - high if minus else low + high


    ## SHIFTS AND AUTOMORPHISMS

    def __lshift__(self, n): # about 1-2ms
        """Shift the coefficients up by ``n % N`` positions with wrap-around.

        NOTE(review): the wrap uses ``minus=False``, so coefficients re-enter
        without a sign change (a cyclic rotation, not multiplication by
        ``X^n`` modulo ``X^N + 1``) -- confirm this is intended.
        """
        temp = self.c.left_shift(n % self.N)
        return Poly(self.mod_quo(temp, minus=False))

    def __rshift__(self, n):
        """Shift down by ``n``, expressed as a left shift by ``N - n``."""
        return self << (self.N - n)

    def auto5(self): # 11 ms atm
        """Apply the automorphism ``X -> X^5``."""
        result = copy(self.zero_X) if self.modulus == 0 else copy(self.zero_pX)
        # Coefficient i moves to exponent 5*i mod 2N; exponents >= N are
        # folded back (with sign) by mod_quo.
        for i in range(self.N):
            result[self.indices_auto5[i]] = self[i]
        return Poly(self.mod_quo(result), self.modulus)

    def auto(self, index): # 13 ms atm
        """Apply the automorphism ``X -> X^(5^index)``.

        ``index`` is reduced modulo ``N // 2``; index 0 returns ``self``
        itself (not a copy).
        """
        index = index % (self.N // 2)
        if index == 0: # index must be >= 1
            return self
        elif index == 1:
            return self.auto5()
        exponent = Zmod(self.N2)(5) ** (index - 1)
        # we use the NTL library to calculate the indices of the automorphism,
        # by storing the indices in a polynomial mod 2*N
        indices = self.indices_auto5_poly * ntl.ZZ_pX([exponent], self.N2)
        # same as auto5, but with a different permutation
        result = copy(self.zero_X) if self.modulus == 0 else copy(self.zero_pX)
        for i in range(self.N):
            result[indices[i]] = self[i]
        return Poly(self.mod_quo(result), self.modulus)

    def auto_inverse(self): # 1.5ms atm
        """Apply ``X -> X^(-1)``: ``a_0 - sum_{i>=1} a_i X^(N-i)``."""
        # Reverse as a polynomial of degree exactly N-1; reversing at the
        # actual degree would misplace coefficients whenever a_{N-1} == 0.
        result = -self.c.reverse(self.N - 1).left_shift(1)
        # After the shift a_0 sits (negated) at X^N; move it back to X^0
        # with its original sign.
        result[0] = -result[self.N]
        result[self.N] = 0
        return Poly(result, self.modulus)

    # ACCESSORS

    def __setitem__(self, key, value):
        """Set coefficient ``key`` of the underlying NTL polynomial."""
        self.c[key] = value

    def __getitem__(self, key):
        """Return coefficient ``key`` of the underlying NTL polynomial."""
        return self.c[key]

    # OTHER METHODS

    def __copy__(self):
        """Return a ``Poly`` holding a copy of the NTL polynomial."""
        return Poly(copy(self.c), self.modulus)

    def zero(self, modulus=None): # gives back the zero polynomial in the correct ring
        """Return the zero polynomial.

        Args:
            modulus: ``None`` uses this instance's modulus; ``0`` gives the
                integer zero polynomial; any other value gives the zero
                polynomial for that modulus.
        """
        if modulus is not None:
            # ``modulus == 0`` used to be unreachable behind ``if modulus:``.
            if modulus == 0:
                return Poly(copy(self.zero_X), 0)
            return Poly(ntl.ZZ_pX(self.zero_array, modulus), modulus)
        zero = copy(self.zero_pX) if self.modulus != 0 else copy(self.zero_X)
        return Poly(zero, self.modulus)

    def __repr__(self): # doesn't need to be fast
        """Print as an integer polynomial with centrally reduced coefficients."""
        if self.modulus != 0:
            # Convert directly instead of going through ``%``, which would
            # rebind the class-level modulus as a side effect of printing.
            reduced = self.c.convert_to_modulus(ntl.ZZ_pContext(self.modulus))
            temp = set_ntl(reduced.list()).list() # reduced in ZZX
            # central reduction to [-(Q//2), (Q-1)//2], symmetric for odd Q too
            half = (self.modulus - 1) // 2
            temp = [a if a <= half else a - self.modulus for a in temp]
            return str(self.R(temp))
        else:
            return str(self.R(self.c.list()))
    
    

In [2]:
# baby cases
# Small sanity check: degree N = 8, modulus Q = 2^10.
N = 8
Q = 2**10
Poly.setup(N, Q)
# IPython magic: times one call to Poly.random(); the result is discarded.
%time Poly.random()
a = Poly([1,2,3,40,5,6,70,8], Q)
print(a)
# The value of this expression is not printed (it is not the cell's last one).
a.auto(2)
# Last expression of the cell, so its repr is echoed as the cell output.
a.auto_inverse()


CPU times: user 24.4 ms, sys: 2.85 ms, total: 27.2 ms
Wall time: 58.4 ms
8*x^7 + 70*x^6 + 6*x^5 + 5*x^4 + 40*x^3 + 3*x^2 + 2*x + 1


-2*x^7 - 3*x^6 - 40*x^5 - 5*x^4 - 6*x^3 - 70*x^2 - 8*x + 1

In [5]:
# benchmarks
# Timings at the parameters quoted in the text: N = 2^15, Q = 2^1000.

N = 2**15
print(N)
Q = 2**1000
Poly.setup(N, Q)
a = Poly.random()
# NOTE(review): b is drawn twice; the first value is overwritten immediately.
b = Poly.random()
b = Poly.random()

# Each %time line below is an IPython magic timing a single statement.
%time a + b
# %time a << 324
%time b = a.auto5()
%time b = a.auto(13)
%time b = a.auto_inverse()


32768
CPU times: user 1.04 ms, sys: 153 µs, total: 1.19 ms
Wall time: 1.2 ms
CPU times: user 13.3 ms, sys: 1.65 ms, total: 14.9 ms
Wall time: 15 ms
CPU times: user 15.5 ms, sys: 2.18 ms, total: 17.7 ms
Wall time: 18.1 ms
CPU times: user 1.75 ms, sys: 662 µs, total: 2.41 ms
Wall time: 2.42 ms


In [4]:
# Scratch cell: plain integer addition, independent of the Poly class.
# Note that it rebinds b, which held a Poly in the benchmark cell.
b = 3
c = 4
d = b + c
print(d)

7
