# A fast class for polynomials in SageMath

## Motivation

FHE needs efficient operations in its native ring $\mathbb Z[X] / \langle X^N+1 \rangle $.
Unfortunately, some of these operations are slow, for example the reduction mod $X^N+1$ and the evaluation of automorphisms such as $X \longmapsto X^5$.
There are a few more FHE-friendly operations, such as the centered reduction mod $Q$.

## Benchmarks

We test our code with the FHE standard $N=2^{15}$ and a modulus $Q = 2^{1000}$.

In [1]:
from sage.libs.ntl import *

def set_ntl(element, modulus=None):
    """Wrap ``element`` in the NTL polynomial type matching ``modulus``.

    A missing (``None``) or zero modulus selects ``ntl.ZZX``; any other
    modulus selects ``ntl.ZZ_pX`` built with that modulus.
    """
    wants_integers = modulus is None or modulus == 0
    if wants_integers:
        return ntl.ZZX(element)
    return ntl.ZZ_pX(element, modulus)

class Poly:
    """Polynomial with at most N coefficients, stored as an NTL polynomial.

    Products are reduced modulo X^N + 1 (see ``mod_quo``). Coefficients live
    in ``ntl.ZZX`` when the modulus is 0 and in ``ntl.ZZ_pX`` otherwise.

    ``Poly.setup(N, modulus)`` must be called before any instance is created:
    it stores the ring degree ``N``, the default ``modulus`` and the Sage
    polynomial ring ``R`` used for printing as class attributes. Every
    instance additionally remembers its own ``modulus``, so results of
    ``%`` and of arithmetic keep the modulus of their operands without
    touching the class-wide default.
    """

    @classmethod
    def setup(cls, N=None, modulus=None):
        """Configure the class-wide ring degree and default modulus.

        ``N`` may only be omitted once a previous ``setup`` call has stored a
        degree; otherwise ``ValueError`` is raised. A ``modulus`` of ``None``
        is stored as 0 (integer coefficients).
        """
        # The degree used to default to a global ``N`` captured when the class
        # body was executed, which fails if that global does not exist yet.
        if N is None:
            if not hasattr(cls, 'N'):
                raise ValueError("Poly.setup() needs the ring degree N")
            N = cls.N
        cls.N = N
        cls.modulus = 0 if modulus is None else modulus
        cls.R = PolynomialRing(ZZ, 'x')  # for printing

    @classmethod
    def random(cls, modulus=None):  # slow ~ 400ms, about 10x slower than NTL
        """Return a random element, using the class modulus unless one is given."""
        assert hasattr(cls, 'N'), "Poly.setup() must be called first"
        m = cls.modulus if modulus is None else modulus
        base = PolynomialRing(Zmod(m), 'x')
        # Use the ring's own generator rather than the global ``x``, which the
        # user may have rebound.
        quotient = base.quotient(base.gen()**cls.N + 1)
        # Pass ``m`` on so the instance is labelled with the modulus its
        # coefficients were actually drawn from.
        return cls(set_ntl(quotient.random_element().list(), m), m)

    def __init__(self, coeffs, modulus=None):
        """Build a polynomial from a list, another ``Poly`` or an NTL polynomial.

        ``modulus=None`` means "inherit": from ``coeffs`` when it is a
        ``Poly``, otherwise from the class default set by ``setup``. A list
        must not hold more than ``N`` coefficients. A ``Poly`` or raw NTL
        argument is stored by reference, not copied.
        """
        if modulus is None:
            modulus = coeffs.modulus if isinstance(coeffs, Poly) else type(self).modulus
        self.modulus = modulus

        if isinstance(coeffs, list):
            assert len(coeffs) <= self.N
            # Build with the resolved modulus so the NTL type always agrees
            # with ``self.modulus``.
            self.c = set_ntl(coeffs, modulus)
        elif isinstance(coeffs, Poly):
            self.c = coeffs.c
        else:
            self.c = coeffs

    def __mod__(self, modulus):  # fast, 1-2ms
        """Return this polynomial with coefficients taken modulo ``modulus``.

        ``modulus == 0`` lifts the coefficients to an integer polynomial.
        The class-wide default modulus is left untouched.
        """
        if modulus == 0:
            return Poly(set_ntl(self.c.list()), 0)
        convert = getattr(self.c, 'convert_to_modulus', None)
        if convert is None:
            # Integer-backed polynomial: rebuild it coefficient by coefficient.
            reduced = set_ntl(self.c.list(), modulus)
        else:
            reduced = convert(ntl.ZZ_pContext(modulus))
        return Poly(reduced, modulus)

    def __repr__(self):  # doesn't need to be fast
        """Print as a Sage polynomial; nonzero moduli use centered coefficients."""
        if self.modulus == 0:
            return str(self.R(self.c.list()))
        q = self.modulus
        half = q // 2
        # we perform the central reduction in [-Q//2, Q//2)
        lifted = set_ntl((self % q).c.list()).list()  # reduced in ZZX
        centered = [a if a < half else a - q for a in lifted]
        return str(self.R(centered))

    def __setitem__(self, key, value):
        """Set the coefficient at index ``key`` on the underlying NTL polynomial."""
        self.c[key] = value

    def __getitem__(self, key):
        """Return the coefficient at index ``key`` of the underlying NTL polynomial."""
        return self.c[key]

    def __add__(self, other):
        """Return ``self + other``; the result keeps ``self``'s modulus."""
        return Poly(self.c + other.c, self.modulus)

    def __iadd__(self, other):
        """Add ``other`` in place."""
        self.c += other.c
        return self

    def __sub__(self, other):
        """Return ``self - other``; the result keeps ``self``'s modulus."""
        return Poly(self.c - other.c, self.modulus)

    def __isub__(self, other):
        """Subtract ``other`` in place."""
        self.c -= other.c
        return self

    def __neg__(self):
        """Return ``-self``; the result keeps ``self``'s modulus."""
        return Poly(-self.c, self.modulus)

    def mod_quo(self, element, minus=True):
        """Fold the coefficients of degree >= N of an NTL polynomial back once.

        With ``minus=True`` the high part is subtracted from the low part
        (X^N = -1, i.e. reduction modulo X^N + 1); with ``minus=False`` it is
        added (X^N = 1). Only a single fold is performed, so ``element`` must
        have degree below 2N.
        """
        low = element.truncate(self.N)
        high = element.left_shift(-self.N)  # negative left shift = right shift
        return low - high if minus else low + high

    def __mul__(self, other):  # about 90ms, if you square it's about 60ms
        """Return the product reduced modulo X^N + 1."""
        return Poly(self.mod_quo(self.c * other.c), self.modulus)

    def __imul__(self, other):
        """Multiply by ``other`` in place, reducing modulo X^N + 1."""
        self.c = self.mod_quo(self.c * other.c)
        return self

    def __pow__(self, n):
        """Return ``self`` raised to the non-negative integer power ``n``.

        Uses square-and-multiply on top of ``__mul__``. Raises ``TypeError``
        for non-integer exponents and ``ValueError`` for negative ones.
        """
        import operator

        n = operator.index(n)
        if n < 0:
            raise ValueError("negative exponents are not supported")
        result = Poly([1], self.modulus)
        base = self
        while n:
            if n & 1:
                result = result * base
            n >>= 1
            if n:
                base = base * base
        return result

    def __truediv__(self, other):
        """Division is not implemented yet."""
        # Fail loudly instead of silently evaluating ``a / b`` to None.
        raise NotImplementedError("Poly division is not implemented")

    def __lshift__(self, n):  # about 1-2ms
        """Rotate the N coefficients cyclically by ``n`` positions.

        NOTE: the wrapped coefficients are added back with a plus sign
        (``minus=False``) and ``n`` is reduced mod N, so this is a cyclic
        rotation of the coefficient vector, not multiplication by X^n in
        the ring modulo X^N + 1 (which would flip the sign on wrap-around).
        """
        shifted = self.c.left_shift(n % self.N)
        return Poly(self.mod_quo(shifted, minus=False), self.modulus)

    def __rshift__(self, n):
        """Rotate cyclically in the opposite direction of ``<<``."""
        return self << (self.N - n)

    def auto5(self):
        """Placeholder for the automorphism X -> X^5; not implemented yet."""
        # Fail loudly instead of silently returning None.
        raise NotImplementedError("Poly.auto5 is not implemented")

In [2]:
# baby cases
# Small parameters (N = 4, Q = 2^10) so the printed results are easy to verify.
N = 4
# This prime modulus is computed but immediately overwritten by the next line.
Q = next_prime(2**10)
Q = 2**10
Poly.setup(N, Q)
%time Poly.random()
a = Poly([1,2,3,4], Q)
# (1 + 2x + 3x^2 + 4x^3)^2 reduced mod x^4 + 1; coefficients print centered mod Q.
print(a * a)
# In-place squaring: afterwards `a` holds the same polynomial as `a * a` above.
%time a *= a
print(a)

CPU times: user 25.6 ms, sys: 2.47 ms, total: 28.1 ms
Wall time: 71.5 ms
20*x^3 - 6*x^2 - 20*x - 24
CPU times: user 6 µs, sys: 1e+03 ns, total: 7 µs
Wall time: 6.91 µs
20*x^3 - 6*x^2 - 20*x - 24


In [3]:
# benchmarks

# NOTE(review): the Benchmarks text announces N = 2^15, but this cell runs with
# N = 2^4 — confirm which size the recorded timings are meant to reflect.
N = 2**4
Q = 2**1000
Poly.setup(N, Q)
%time a = Poly.random()
b = Poly.random()

# Poly.__lshift__ reduces the shift amount mod N, so 324 acts like 324 % 16 = 4.
%time a << 324
# Coefficient at index 0 of `a` (the shift above returns a new Poly and is discarded).
print(a[0])



CPU times: user 5.39 ms, sys: 439 µs, total: 5.83 ms
Wall time: 11.1 ms
CPU times: user 10 µs, sys: 0 ns, total: 10 µs
Wall time: 11 µs
9824651439357964690538867958447619784132110040075901022547355111970917342425386668252313818066907981753421110753582497699969924416128622128469056884622678367391024041750715045392979793666492304532998147010344784834466714381375926919774390140133788544975913127778763782877653685852648471714915086410536
