# A fast class for polynomials in SageMath

## Motivation

Fully homomorphic encryption (FHE) needs efficient arithmetic in its native ring $\mathbb{Z}[X] / \langle X^N+1 \rangle$.
Unfortunately, operations such as the reduction modulo $X^N+1$ and the evaluation of automorphisms like $X \longmapsto X^5$ are slow when implemented in plain SageMath.
The class below stores coefficients as NTL polynomials to speed these up, and it also provides a few other FHE-friendly operations, such as the centered reduction modulo $Q$.

## Benchmarks

We benchmark the code with the ring dimension $N = 2^{15}$, a standard choice in FHE, and a modulus $Q = 2^{1000}$.

In [395]:
from sage.libs.ntl import *
from sage.rings.polynomial.polynomial_integer_dense_ntl import *
import time

def set_ntl(element, modulus=None):
    """Wrap ``element`` (a coefficient list or NTL polynomial) as an NTL polynomial.

    Returns an ``ntl.ZZX`` (integer coefficients) when ``modulus`` is ``None``
    or 0, and an ``ntl.ZZ_pX`` with coefficients mod ``modulus`` otherwise.
    """
    over_integers = modulus is None or modulus == 0
    if not over_integers:
        return ntl.ZZ_pX(element, modulus)
    return ntl.ZZX(element)

class Poly:
    """Element of Z[X]/(X^N + 1), optionally with coefficients reduced mod a modulus.

    Coefficients live in ``self.c`` as an NTL polynomial: ``ntl.ZZX`` when the
    modulus is 0, ``ntl.ZZ_pX`` otherwise (see ``set_ntl``).

    Ring parameters are class attributes shared by every instance:
    ``N``/``N2`` (ring dimension and 2N), ``modulus`` (the "current" class-wide
    modulus), ``R`` (a SageMath ring used for printing and ``norm``), and the
    tables built by ``precomputations``. Call ``Poly.setup(N, modulus)`` before
    creating elements. Each instance additionally stores its own ``modulus``.
    """

    @classmethod
    def setup(cls, N=None, modulus=None):
        """Configure the class-wide ring parameters and rebuild the tables.

        Args:
            N: ring dimension. ``None`` keeps the dimension from a previous call.
                (A default bound to a global ``N`` would fail on a fresh kernel.)
            modulus: class-wide modulus; ``None`` or 0 means integer coefficients.

        Raises:
            ValueError: if ``N`` is ``None`` and no dimension was set before.
        """
        if N is None:
            N = getattr(cls, 'N', None)
            if N is None:
                raise ValueError("Poly.setup: no ring dimension N given and none set before")
        cls.N, cls.N2 = N, N * 2
        # here we store the current (class-wide) modulus
        cls.modulus = 0 if modulus is None else modulus
        cls.R = PolynomialRing(ZZ, 'x')  # for printing
        cls.precomputations()

    @classmethod
    def precomputations(cls):
        """Build the lookup tables and zero templates for the current N / modulus."""
        # indices_auto5[i] = 5*i mod 2N: exponent that X^i is sent to by X -> X^5
        cls.indices_auto5 = [ZZ((5 * i) % cls.N2) for i in range(cls.N)]
        # the same table as an NTL polynomial mod 2N (kept for external users)
        cls.indices_auto5_poly = ntl.ZZ_pX(cls.indices_auto5, cls.N2)

        cls.zero_array = [0] * cls.N
        # No ZZ_pX template over the integers: a modulus-0 NTL context is not
        # meaningful (NOTE(review): confirm NTL rejects it outright).
        cls.zero_pX = ntl.ZZ_pX(cls.zero_array, cls.modulus) if cls.modulus else None
        cls.zero_X = ntl.ZZX(cls.zero_array)

    @classmethod
    def _set_class_modulus(cls, modulus):
        """Switch the class-wide modulus, keeping the ``zero_pX`` template in sync."""
        if modulus == cls.modulus:
            return
        cls.modulus = modulus
        cls.zero_pX = ntl.ZZ_pX(cls.zero_array, modulus) if modulus else None

    @classmethod
    def random(cls, modulus=None):  # slow ~ 400ms, about 10x slower than NTL
        """Return a random element, sampled with SageMath's ``random_element``.

        ``modulus`` defaults to the class-wide modulus (0 = integer coefficients).
        """
        # TODO: maybe use the NTL random function instead
        m = cls.modulus if modulus is None else modulus
        base = PolynomialRing(Zmod(m), 'x')
        # build X^N + 1 from the ring's own generator, not the global symbolic `x`
        ring = base.quotient(base.gen() ** cls.N + 1)
        return Poly(set_ntl(ring.random_element().list(), m), m)

    def __init__(self, coeffs, modulus=None, warning=True):
        """Create an element.

        Args:
            coeffs: a list of at most N coefficients, another ``Poly`` (its NTL
                object is shared, not copied), or a ready-made NTL polynomial.
            modulus: element modulus; ``None`` uses the class-wide modulus. A
                non-zero value different from the class modulus re-runs
                ``Poly.setup`` so the class follows it.
            warning: print a notice when the class is set up again.

        Raises:
            ValueError: if a coefficient list is longer than N.
        """
        if modulus and modulus != self.modulus:
            if warning:
                print(f"Warning: modulus = {modulus} given is different from the class modulus {self.modulus}")
                print("Setting up class again...")
            Poly.setup(N=self.N, modulus=modulus)
        self.modulus = self.modulus if modulus is None else modulus

        if isinstance(coeffs, list):
            if len(coeffs) > self.N:
                raise ValueError(f"got {len(coeffs)} coefficients, at most N = {self.N} allowed")
            # use the effective modulus so the NTL type matches self.modulus
            self.c = set_ntl(coeffs, self.modulus)
        elif isinstance(coeffs, Poly):
            self.c = coeffs.c
        else:
            self.c = coeffs

    def _as_ntl(self, other):
        """Return ``other`` as an NTL polynomial to combine with ``self.c``.

        A ``Poly`` contributes its coefficients. Anything else is treated as a
        scalar and wrapped as a constant polynomial with ``self.modulus``.
        """
        if isinstance(other, Poly):
            return other.c
        return set_ntl([other], self.modulus)

    # ARITHMETIC OPERATORS

    def __add__(self, other):
        return Poly(self.c + self._as_ntl(other), self.modulus)

    def __radd__(self, other):  # e.g. scalar + Poly, or sum() starting at 0
        return Poly(self._as_ntl(other) + self.c, self.modulus)

    def __iadd__(self, other):
        self.c += self._as_ntl(other)
        return self

    def __sub__(self, other):
        return Poly(self.c - self._as_ntl(other), self.modulus)

    def __rsub__(self, other):  # scalar - Poly
        return Poly(self._as_ntl(other) - self.c, self.modulus)

    def __isub__(self, other):
        self.c -= self._as_ntl(other)
        return self

    def __neg__(self):
        return Poly(-self.c, self.modulus)

    ## MULTIPLICATION OPERATORS

    def __mul__(self, other):
        """Ring product with another ``Poly``, or scaling by a scalar."""
        if isinstance(other, Poly):  # 90ms, if you square it's 60ms
            product = self.mod_quo(self.c * other.c)
        else:  # integer multiplication, about 5ms
            product = self.c * self._as_ntl(other)
        return Poly(product, self.modulus)

    def __rmul__(self, other):  # scalar * Poly; only reached for non-Poly left operands
        return self * other

    def __imul__(self, other):
        # same semantics as __mul__ (including scalars), result stored in place
        self.c = (self * other).c
        return self

    def __pow__(self, exponent):
        """Square-and-multiply power for a non-negative integer exponent.

        Raises:
            ValueError: for negative exponents (inverses are not supported).
        """
        if exponent < 0:
            raise ValueError("negative exponents are not supported")
        if exponent == 0:
            return Poly([1], self.modulus)
        if exponent == 1:
            return self
        half = self ** (exponent // 2)
        result = half * half
        return self * result if exponent % 2 else result

    ## SCALING OPERATORS

    def rescale(self, other):
        """Apply NTL's ``_right_pshift`` with ``other`` to the coefficients.

        NOTE(review): confirm the exact ``_right_pshift`` semantics (division of
        the coefficients by ``other``?) in the Sage NTL wrapper. The element's
        modulus label is left unchanged.
        """
        return Poly(self.c._right_pshift(ntl.ZZ(other)), self.modulus)

    def __truediv__(self, other):  # 5-6ms
        """``rescale(other)`` followed by ``% self.modulus``; the modulus label is kept."""
        return self.rescale(other) % self.modulus

    # MODULAR OPERATORS

    def __mod__(self, modulus):  # fast, for the necessary cases 1-2ms
        """Return this element with coefficients taken mod ``modulus`` (0 = ZZX).

        Side effect: ``modulus`` becomes the class-wide modulus (and the
        ``zero_pX`` template is rebuilt to match it).
        """
        Poly._set_class_modulus(modulus)
        element_modulus = self.element_modulus()

        if modulus == 0:  # this does convert to ZZX!!
            if element_modulus == 0:
                return self
            return Poly(ntl.ZZX(self.c), 0)  # slow, takes 300ms

        if element_modulus == 0:  # slow! 400ms
            return Poly(ntl.ZZ_pX(self.c, modulus), modulus)

        if modulus == element_modulus:
            return self

        tmp = self.c.convert_to_modulus(ntl.ZZ_pContext(modulus))
        return Poly(tmp, modulus)

    def mod_quo(self, element, minus=True):  # 1-2ms
        """Reduce an NTL polynomial of degree < 2N mod X^N + 1 (or X^N - 1 if not ``minus``).

        Splits ``element`` into low part L (degree < N) and high part H, so that
        element = L + X^N * H. It then returns L - H (or L + H). Inputs of degree
        >= 2N are not fully reduced.
        """
        low = element.truncate(self.N)
        high = element.left_shift(-self.N)  # negative left shift drops the low N coefficients
        return low - high if minus else low + high

    ## SHIFTS AND AUTOMORPHISMS

    def __lshift__(self, n):  # about 1-2ms
        """Multiply by X^(n mod N) and reduce with ``minus=False`` (i.e. mod X^N - 1).

        NOTE(review): this is a cyclic rotation; wrapped coefficients are NOT
        negated as they would be in Z[X]/(X^N+1). Confirm this is intended.
        """
        temp = self.c.left_shift(n % self.N)
        return Poly(self.mod_quo(temp, minus=False), self.modulus)

    def __rshift__(self, n):
        """Rotate the other way: same as ``self << (N - n)``."""
        return self << (self.N - n)

    def _zero(self):
        """Return a fresh zero NTL polynomial of the same kind/modulus as this element."""
        if self.modulus == 0:
            return copy(self.zero_X)
        template = self.zero_pX
        if template is not None and template.get_modulus_context().modulus() == self.modulus:
            return copy(template)
        # the class template belongs to another modulus (e.g. after `%` on another element)
        return ntl.ZZ_pX(self.zero_array, self.modulus)

    def auto5(self):  # 11 ms atm
        """Apply the automorphism X -> X^5."""
        result = self._zero()
        for i, target in enumerate(self.indices_auto5):
            result[target] = self[i]  # X^i -> X^(5i mod 2N)
        return Poly(self.mod_quo(result), self.modulus)  # fold exponents >= N back with a sign flip

    def auto(self, index):  # 13 ms atm
        """Apply the automorphism X -> X^(5^index).

        ``index`` is taken mod N/2. An index of 0 returns ``self`` unchanged.
        """
        index = index % (self.N // 2)
        if index == 0:
            return self
        if index == 1:
            return self.auto5()
        # X^i -> X^(i * 5^index mod 2N); keep the step a plain int so the
        # computed exponents can be used directly as coefficient indices
        step = int(Zmod(self.N2)(5) ** index)
        result = self._zero()
        for i in range(self.N):
            result[(i * step) % self.N2] = self[i]
        return Poly(self.mod_quo(result), self.modulus)

    def auto_inverse(self):  # 2ms atm
        """Apply the automorphism X -> X^(-1) = -X^(N-1).

        a(X^-1) = a_0 - sum_{i>=1} a_i X^(N-i).
        """
        # reverse as if the degree were N-1, so lower-degree inputs map correctly
        result = -self.c.reverse(hi=self.N - 1).left_shift(1)
        # a_0 landed at X^N with a flipped sign: move it back to X^0
        result[0] = -result[self.N]
        return Poly(result.truncate(self.N), self.modulus)

    # ACCESSORS

    def __setitem__(self, key, value):
        self.c[key] = value

    def __getitem__(self, key):
        return self.c[key]

    def leading_coefficient(self):
        return self.c.leading_coefficient()

    ## CHECKS

    def is_zero(self):
        return self.c.is_zero()

    def is_one(self):
        return self.c.is_one()

    def is_monic(self):
        return self.c.is_monic()

    ## PRINTING AND REPRESENTATION

    def centered_list(self):
        """Return coefficients centered in [-Q//2, Q//2); the raw list when the modulus is 0.

        Goes through ``self % self.modulus``, which also sets the class-wide modulus.
        """
        if self.modulus == 0:
            return self.c.list()
        temp = set_ntl((self % self.modulus).c.list()).list()
        return [a if a < self.modulus // 2 else a - self.modulus for a in temp]

    def norm(self):  # slow, 400ms
        """Return the max absolute centered coefficient (0 for the zero polynomial)."""
        coeffs = self.R(self.centered_list()).list()
        return max((abs(a) for a in coeffs), default=0)

    def __repr__(self):  # doesn't need to be fast
        return str(self.R(self.centered_list()))

    # OTHER METHODS

    def __copy__(self):
        return Poly(copy(self.c), self.modulus)

    def __eq__(self, other):
        """Coefficient-wise equality with another ``Poly`` of the same modulus.

        Returns ``NotImplemented`` for non-``Poly`` operands. Prints a warning
        if either element's NTL modulus disagrees with its label.

        Raises:
            ValueError: if the two elements carry different moduli.
        """
        if not isinstance(other, Poly):
            return NotImplemented
        if other.modulus != self.modulus:
            raise ValueError("Different moduli!")
        other.check_modulus()
        self.check_modulus()
        return self.c == other.c

    def clear(self):  # Resets this polynomial to zero, changes in place
        self.c.clear()

    def element_modulus(self):
        """Return the modulus of the NTL context of ``self.c`` (0 for a ZZX)."""
        try:  ## ZZ_pX case
            return self.c.get_modulus_context().modulus()
        except AttributeError:  ## ZZX case: no modulus context
            return 0

    def check_modulus(self):
        """Return whether this element's modulus label matches its NTL modulus, warning if not."""
        tmp = (self.modulus == self.element_modulus())
        if not tmp:
            print(f"Warning: Class modulus = {self.modulus} is different from the element modulus {self.element_modulus()}!")
        return tmp

In [380]:
# baby cases: a tiny ring (N = 8, Q = 2^10) to sanity-check the class by hand
N = 8
Q = 2**10
Poly.setup(N, Q)
%time Poly.random()
a = Poly([1,2,3,4,5,6,7,8], Q)
# the same element built in plain SageMath for comparison (not used further in this cell)
a2 = PolynomialRing(Zmod(0), 'x').quotient(x**N + 1)([1,2,3,4,5,6,7,8])
# a.auto(2)
# exercise the modulus-switching branches of Poly.__mod__:
# mod Q -> mod 8 (convert_to_modulus), -> modulus 0 (ZZX), -> mod 1024 (ZZX -> ZZ_pX)
a = a % 8
a = a % 0
a = a % 1024


CPU times: user 718 µs, sys: 1 µs, total: 719 µs
Wall time: 740 µs
True


In [394]:
# benchmarks
# Full-size parameters from the text above: N = 2^15, Q = 2^1000.

N = 2**15
print(N)
Q = 2**1000
Poly.setup(N, Q)
a = Poly.random()
b = Poly.random()
# %time a << 324
%time b = a.auto5()
# %time b = a.auto(13)
# a already carries modulus Q, so this takes the `modulus == element_modulus` early return
a = a % Q
# times the ZZ_pX -> ZZX conversion branch of Poly.__mod__
%time b = a % 0

32768
CPU times: user 12.2 ms, sys: 987 µs, total: 13.1 ms
Wall time: 13.2 ms
CPU times: user 305 ms, sys: 3.04 ms, total: 308 ms
Wall time: 313 ms
