# A fast class for polynomials in SageMath

## Motivation

Fully homomorphic encryption (FHE) needs efficient operations in its native ring $\mathbb Z[X] / \langle X^N+1 \rangle$.
Unfortunately, some of these operations are slow in SageMath, for example the reduction mod $X^N+1$ and the evaluation of automorphisms such as $X \longmapsto X^5$.
There are a few more FHE-friendly operations that we want to support, such as the centered reduction mod $Q$.

## Benchmarks

We benchmark our code with the ring degree $N = 2^{15}$, a standard choice in FHE, and the modulus $Q = 2^{1000}$.

In [64]:
from sage.libs.ntl import *
import time

def set_ntl(element, modulus=None):
    """Wrap ``element`` in the NTL polynomial type matching ``modulus``.

    Returns ``ntl.ZZX(element)`` when ``modulus`` is ``0`` or ``None``
    (integer coefficients) and ``ntl.ZZ_pX(element, modulus)`` otherwise.
    """
    integer_coefficients = modulus == 0 or modulus is None
    return ntl.ZZX(element) if integer_coefficients else ntl.ZZ_pX(element, modulus)

class Poly:
    """Polynomial whose coefficients live in an NTL ``ZZX`` / ``ZZ_pX`` object.

    Products are reduced modulo ``X^N + 1`` (see ``mod_quo``).  Call
    ``Poly.setup(N, modulus)`` once before creating instances: it stores the
    ring degree ``N``, the default coefficient modulus and a few lookup
    tables on the class.

    Every instance carries its own ``modulus`` (``0`` means integer
    coefficients).  The class-level ``modulus`` is only the default that the
    constructor falls back to; no operator changes it.
    """

    @classmethod
    def setup(cls, N=None, modulus=None):
        """Configure the ring degree and the default coefficient modulus.

        Parameters:
            N: ring degree (polynomials are reduced modulo ``X^N + 1``).  If
                omitted, the degree of a previous ``setup`` call is reused.
            modulus: default coefficient modulus; ``None`` or ``0`` selects
                integer coefficients.

        Raises:
            ValueError: if no degree is available or ``N`` is not positive.
        """
        # The default used to be ``N=N``, which is evaluated when the class is
        # defined and raises NameError unless a global ``N`` already exists.
        if N is None:
            N = getattr(cls, "N", None)
        if N is None:
            raise ValueError("Poly.setup() needs the ring degree N")
        if N < 1:
            raise ValueError(f"N must be a positive integer, got {N}")
        cls.N, cls.N2 = N, N * 2
        cls.modulus = 0 if modulus is None else modulus
        cls.R = PolynomialRing(ZZ, 'x')  # only used for printing
        cls.precomputations()

    @classmethod
    def precomputations(cls):
        """Build the lookup tables that depend on ``N`` and the default modulus."""
        # Exponent that X^i is sent to by X -> X^5, taken modulo 2N.
        cls.auto5indices = [ZZ((5 * i) % cls.N2) for i in range(cls.N)]
        cls.zero_array = [0] * cls.N
        cls.zero_X = ntl.ZZX(cls.zero_array)
        # A ZZ_pX template is only built for a nonzero default modulus;
        # modulus 0 stands for integer coefficients and uses zero_X instead.
        if cls.modulus != 0:
            cls.zero_pX = ntl.ZZ_pX(cls.zero_array, cls.modulus)
        else:
            cls.zero_pX = None

    @classmethod
    def _zero_coeffs(cls, modulus):
        """Return a fresh NTL zero polynomial for the given modulus."""
        if modulus == 0:
            return copy(cls.zero_X)
        if modulus == cls.modulus and cls.zero_pX is not None:
            return copy(cls.zero_pX)  # fast path: precomputed template
        return set_ntl(cls.zero_array, modulus)

    @classmethod
    def random(cls, modulus=None):  # slow ~ 400ms, about 10x slower than NTL
        """Return a random polynomial of degree below ``N``.

        Parameters:
            modulus: coefficient modulus; defaults to the class default.
        """
        # maybe we use the NTL random function instead
        m = cls.modulus if modulus is None else modulus
        base = PolynomialRing(Zmod(m), 'x')
        # Use the ring's own generator instead of the global symbol ``x``,
        # which the user may have rebound.
        ring = base.quotient(base.gen()**cls.N + 1)
        return Poly(set_ntl(ring.random_element().list(), m), m)

    def __init__(self, coeffs, modulus=None):
        """Create a polynomial.

        Parameters:
            coeffs: a list/tuple of at most ``N`` coefficients, another
                ``Poly`` (its NTL object is shared, not copied), or an NTL
                polynomial object that is stored as-is.
            modulus: coefficient modulus of this instance.  ``None`` means
                "inherit": from ``coeffs`` if it is a ``Poly``, otherwise the
                class default.

        Raises:
            ValueError: if a coefficient list is longer than ``N``.
        """
        if isinstance(coeffs, Poly):
            self.modulus = coeffs.modulus if modulus is None else modulus
            self.c = coeffs.c
            return

        self.modulus = type(self).modulus if modulus is None else modulus
        if isinstance(coeffs, (list, tuple)):
            if len(coeffs) > self.N:
                raise ValueError(
                    f"Too many coefficients: {len(coeffs)} > N = {self.N}"
                )
            # Use the resolved modulus: passing the raw argument (possibly
            # None) would build a ZZX even when a modulus is configured.
            self.c = set_ntl(list(coeffs), self.modulus)
        else:
            # Raw NTL object; the caller is responsible for it matching
            # ``modulus``.
            self.c = coeffs

    # ARITHMETIC OPERATORS
    # Results always carry the modulus of the left operand.

    def __add__(self, other):
        """Return ``self + other``."""
        return Poly(self.c + other.c, self.modulus)

    def __iadd__(self, other):
        """Add ``other`` to this polynomial in place."""
        self.c += other.c
        return self

    def __sub__(self, other):
        """Return ``self - other``."""
        return Poly(self.c - other.c, self.modulus)

    def __isub__(self, other):
        """Subtract ``other`` from this polynomial in place."""
        self.c -= other.c
        return self

    def __neg__(self):
        """Return ``-self``."""
        return Poly(-self.c, self.modulus)

    def __mul__(self, other):  # about 90ms, if you square it's about 60ms
        """Return the product reduced modulo ``X^N + 1``."""
        product = self.c * other.c
        return Poly(self.mod_quo(product), self.modulus)

    def __imul__(self, other):
        """Multiply by ``other`` in place, reducing modulo ``X^N + 1``."""
        self.c *= other.c
        self.c = self.mod_quo(self.c)
        return self

    def __pow__(self, n):
        """Return ``self**n`` for an integer ``n >= 0`` (square-and-multiply).

        Raises:
            TypeError: if ``n`` is not an integer.
            ValueError: if ``n`` is negative.
        """
        import operator

        n = operator.index(n)
        if n < 0:
            raise ValueError("Negative exponents are not supported")
        result = Poly([1], self.modulus)
        base = self
        while n:
            if n & 1:
                result = result * base
            n >>= 1
            if n:  # skip the last, unused squaring
                base = base * base
        return result

    def __truediv__(self, other):
        """Division is not implemented; fail loudly instead of returning None."""
        raise NotImplementedError("Division of Poly objects is not implemented")

    # MODULAR  OPERATORS

    def __mod__(self, modulus):  # fast, 1-2ms
        """Return this polynomial with coefficients converted to ``modulus``.

        ``modulus == 0`` yields integer coefficients.  The class-wide default
        modulus is left untouched.
        """
        if modulus == 0:
            if self.modulus == 0:
                return Poly(self.c, 0)
            # Lift ZZ_pX -> ZZX through the coefficient list.
            return Poly(set_ntl(self.c.list()), 0)
        if self.modulus == 0:
            # ZZX -> ZZ_pX: rebuild from the coefficient list.
            return Poly(set_ntl(self.c.list(), modulus), modulus)
        tmp = self.c.convert_to_modulus(ntl.ZZ_pContext(modulus))
        return Poly(tmp, modulus)

    def mod_quo(self, element, minus=True):
        """Fold an NTL polynomial of degree below ``2N`` down to degree below ``N``.

        The part above ``X^N`` is shifted down by ``N`` and subtracted from
        the low part (``minus=True``, reduction modulo ``X^N + 1``) or added
        to it (``minus=False``, reduction modulo ``X^N - 1``).
        """
        low = element.truncate(self.N)
        high = element.left_shift(-self.N)
        return low - high if minus else low + high

    ## SHIFTS AND AUTOMORPHISMS

    def __lshift__(self, n):  # about 1-2ms
        """Rotate the coefficients cyclically by ``n`` positions.

        NOTE: wrapped-around coefficients keep their sign
        (``mod_quo(..., minus=False)``), so this is multiplication by ``X^n``
        modulo ``X^N - 1``, not modulo ``X^N + 1``.
        """
        temp = self.c.left_shift(n % self.N)
        return Poly(self.mod_quo(temp, minus=False), self.modulus)

    def __rshift__(self, n):
        """Rotate the coefficients cyclically by ``n`` positions the other way."""
        return self << (self.N - n)

    def auto5(self):  # 11 ms atm
        """Apply the automorphism ``X -> X^5``."""
        result = self._zero_coeffs(self.modulus)
        coeffs = self.c
        for i, target in enumerate(self.auto5indices):
            result[target] = coeffs[i]
        # Exponents were taken modulo 2N, so one fold finishes the reduction.
        return Poly(self.mod_quo(result), self.modulus)

    # OTHER METHODS

    def __copy__(self):
        """Return a copy that owns its own NTL object."""
        return Poly(copy(self.c), self.modulus)

    def zero(self, modulus=None):
        """Return the zero polynomial for ``modulus``.

        ``None`` selects this instance's modulus; ``0`` selects integer
        coefficients.
        """
        m = self.modulus if modulus is None else modulus
        return Poly(self._zero_coeffs(m), m)

    def __repr__(self):  # doesn't need to be fast
        """Print the polynomial; modular coefficients are shown centered."""
        if self.modulus != 0:
            # we perform the central reduction in [-Q//2, Q//2)
            temp = set_ntl((self % self.modulus).c.list()).list()  # reduced in ZZX
            temp = [a if a < self.modulus//2 else a - self.modulus for a in temp]
            return str(self.R(temp))
        else:
            return str(self.R(self.c.list()))

    def __setitem__(self, key, value):
        """Set the coefficient of ``X^key``."""
        self.c[key] = value

    def __getitem__(self, key):
        """Return the coefficient of ``X^key``."""
        return self.c[key]

In [69]:
# baby cases
# Small parameters (N = 4, Q = 2^10) so that the printed results can be
# checked by hand against the cell output.
N = 4
Q = 2**10
Poly.setup(N, Q)
# %time is an IPython magic: it times the single statement that follows.
%time Poly.random()
# a = 1 + 2x + 3x^2 + 40x^3 with coefficients mod Q
a = Poly([1,2,3,40], Q)
print(a)

# image of a under X -> X^5
print(a.auto5())
# a with coefficients converted to modulus 4 (printed centered)
print(a % 4)
# print a and its automorphism again after the `%` call
# (presumably to confirm that `a` itself was not modified)
print(a)
print(a.auto5())


CPU times: user 421 µs, sys: 5 µs, total: 426 µs
Wall time: 444 µs
40*x^3 + 3*x^2 + 2*x + 1
-40*x^3 + 3*x^2 - 2*x + 1
-x^2 - 2*x + 1
40*x^3 + 3*x^2 + 2*x + 1
-40*x^3 + 3*x^2 - 2*x + 1


In [70]:
# benchmarks
# Parameters from the "Benchmarks" section above: N = 2^15, Q = 2^1000.

N = 2**15
Q = 2**1000
Poly.setup(N, Q)
# %time is an IPython magic: it times the single statement that follows.
%time a = Poly.random()
b = Poly.random()

# coefficient shift by 324 positions
%time a << 324
# automorphism X -> X^5
%time a.auto5()
# zero polynomial with the same modulus as a
%time a.zero()


CPU times: user 362 ms, sys: 36.1 ms, total: 398 ms
Wall time: 408 ms
CPU times: user 2 ms, sys: 2.42 ms, total: 4.43 ms
Wall time: 6.35 ms
CPU times: user 13.2 ms, sys: 1.59 ms, total: 14.8 ms
Wall time: 15.3 ms
CPU times: user 12.3 ms, sys: 1.5 ms, total: 13.8 ms
Wall time: 14.1 ms
CPU times: user 11.9 ms, sys: 1.59 ms, total: 13.5 ms
Wall time: 13.5 ms
CPU times: user 12.4 ms, sys: 1.65 ms, total: 14 ms
Wall time: 14 ms
CPU times: user 6 µs, sys: 0 ns, total: 6 µs
Wall time: 6.91 µs


0