# Fast Discrete Fourier Transform

Below, we implement two classes:
- `DFT`: computes the fast discrete Fourier transform of a vector of complex numbers, mainly with numpy at fixed (53-bit) precision, with an mpmath path for higher precision. It serves as the foundation for the second class.
- `FHE_DFT`: the homomorphic DFT, built on top of `DFT`. Its two main methods are the homomorphic Slot2Coeff (used in bootstrapping) and `encode_clear`, which encodes a vector of complex values into an integer-coefficient polynomial. An experimental Coeff2Slot is also included.


In [1]:
%%writefile DFT_class.sage

import numpy as np
from mpmath import mp, mpc
load("sagefhepoly/polyfhe.sage")
load("root_of_unity.sage")

def br(bit, length):
    """Bit-reversal: reverse the binary digits of `bit`, zero-padded to `length` digits.

    Example: br(1, 3) == 4 because 001 reversed is 100. A value wider than
    `length` digits is reversed over its full binary width, since padding
    never truncates the string.
    """
    spec = "0" + str(length) + "b"
    return int(format(bit, spec)[::-1], 2)

class DFT: # a module to perform the DFT in the clear
    """Clear-text (non-homomorphic) fast DFT over the N/2 complex slots of a degree-N polynomial.

    All lookup tables and twiddle factors are built once in ``__init__``.
    The twiddles come from ``root``, which is loaded from root_of_unity.sage.
    """

    # Class-level alias of the most recently generated ``roots`` table.
    # It is overwritten by every ``gen_roots`` call, i.e. by every new instance.
    roots = None

    def __init__(self, N, precision=53) -> None:
        """Set up the DFT for ring dimension ``N``.

        N: ring dimension; must be a positive power of 2.
        precision: bit precision of the bit-reversed twiddle tables. Values
            above 53 also set mpmath's *global* ``mp.dps``.
        """
        assert N > 0 and N & (N - 1) == 0, "N must be a power of 2"
        self.N = N
        self.prec = precision
        if precision > 53:
            # Bits -> decimal digits. Note: this changes mpmath's global context.
            mp.dps = precision * log(2, 10)
        self.log_N = ZZ(log(N, 2))

        self.gen_BR_lookup()
        self.gen_roots()
        self.encoding_precomputation()

    def gen_BR_lookup(self):
        """Build ``self.BR``, where ``self.BR[l]`` is the bit-reversal permutation of range(2**l), for l = 0..log_N."""
        assert hasattr(self, 'N'), "N is not defined, initialize the DFT class first"
        self.BR = [np.array([br(i, l) for i in range(2**l)], dtype=int) for l in range(self.log_N + 1)]

    def gen_roots(self): # 2 MB for N = 2^15
        """Precompute the twiddle factors of butterfly stages k = 1..log_N-1.

        With step = N / 2**(k+1), and each table holding 2**(k-1) entries at stage k:
            roots_BR[k][i]     = root(2N,  5**BR[k-1][i] * step, self.prec)
            roots_BR_inv[k][i] = root(2N, -5**BR[k-1][i] * step, self.prec)
            roots[k][i]        = root(2N,  5**i          * step, 53)
        Index 0 of each table is an unused placeholder.
        """
        self.roots_BR     = [[]] + [[0] * 2**(k-1) for k in range(1, self.log_N)]
        self.roots        = [[]] + [[0] * 2**(k-1) for k in range(1, self.log_N)]
        self.roots_BR_inv = [[]] + [[0] * 2**(k-1) for k in range(1, self.log_N)]
        for k in range(1, self.log_N):
            step = self.N // 2**(k+1)
            self.roots_BR[k]     = [root(2*self.N, 5**(self.BR[k-1][i]) * step, self.prec) for i in range(2**(k-1))]
            self.roots_BR_inv[k] = [root(2*self.N,-5**(self.BR[k-1][i]) * step, self.prec) for i in range(2**(k-1))]
            # NOTE(review): roots are always 53-bit, even when self.prec > 53.
            self.roots[k]        = [root(2*self.N, 5**(i)               * step, 53) for i in range(2**(k-1))]
        DFT.roots = self.roots

    def convert_poly_to_C(self, poly, precision: int=53):
        """Return the polynomial's full coefficient list as a numpy array.

        For precision > 53 the entries are mpmath ``mpc`` objects (dtype=object).
        Otherwise they are numpy complex128 values.
        """
        if precision > 53:
            return np.array([mpc(ZZ(i)) for i in poly.list(full = True)], dtype=object)
        else:
            return np.array([ZZ(i) for i in poly.list(full = True)], dtype=complex)

    def decode_no_to_bo(self, poly, precision: int=53): # 369ms
        """Decode a polynomial into N/2 complex slots, returned in bit-reversed order.

        Coefficient i becomes the real part and coefficient i + N/2 the
        imaginary part of entry i. Then log_N - 1 in-place butterfly stages
        with twiddles ``roots_BR`` are applied.
        """
        p = self.convert_poly_to_C(poly, precision)
        v = p[:self.N//2] + p[self.N//2:] * 1j
        for k in range(1, self.log_N):
            step = self.N // 2**(k+1)
            for i in range(2**(k-1)):
                for j in range(step):
                    index = i * step * 2 + j
                    a = v[index]
                    prod = self.roots_BR[k][i] * v[index + step]
                    v[index] = a + prod
                    v[index + step] = a - prod
        return v

    def decode_bo_to_no(self, poly, precision: int=53): # 353ms
        """Decode a polynomial into N/2 complex slots, from bit-reversed to normal order.

        Even-indexed coefficients form the real parts and odd-indexed ones the
        imaginary parts. The butterflies use ``self.roots``, which are 53-bit.
        """
        p = self.convert_poly_to_C(poly, precision)
        v = p[::2] + p[1::2] * 1j
        for k in range(1, self.log_N):
            step = self.N // 2**(k+1)
            shift1, shift2 = 2**(k-1), 2**k
            for i in range(shift1):
                for j in range(step):
                    index = i + j * shift2
                    a = v[index]
                    prod = self.roots[k][i] * v[index + shift1]
                    v[index] = a + prod
                    v[index + shift1] = a - prod
        return v

    def encode_bo_to_no(self, vect, precision: int=53): # 817ms
        """Encode N/2 slots given in bit-reversed order into N real coefficients.

        The stages run in reverse with twiddles ``roots_BR_inv``.
        OVERWRITES ``vect`` in place. ``precision`` is accepted for API
        symmetry but is unused. Returns the real parts followed by the
        imaginary parts (length N).
        """
        v = vect
        assert len(v) == self.N // 2, "Input vector has the wrong length"
        assert isinstance(v, np.ndarray), "Input vector must be a numpy array"
        for k in reversed(range(1, self.log_N)):
            step = self.N // 2**(k+1)  # butterfly span; invariant over i
            for i in range(2**(k-1)):
                for j in range(step):
                    index = i * step * 2 + j
                    a = v[index]
                    b = v[index + step]
                    v[index] = (a + b)
                    v[index + step] = (a - b) * self.roots_BR_inv[k][i]
            v /= 2  # one factor 1/2 per stage, i.e. 1/(N/2) overall
        return np.append(v.real, v.imag)

    def encode_no_to_no(self, vect, precision: int=53): # 827ms
        """Encode a normal-order slot vector.

        Bit-reverses the vector into a new array, then calls
        ``encode_bo_to_no``. The input itself is not modified.
        """
        v = vect
        assert isinstance(v, np.ndarray), "Input vector must be a numpy array"
        lg = ZZ(log(len(v), 2))
        reverse = np.array([v[br(i, lg)] for i in range(len(v))], dtype=type(v[0]))
        return self.encode_bo_to_no(reverse, precision)

    def encode_fast(self, vec, bitrev: bool=False): # 1 ms
        """Vectorized encoding of N/2 slots into N real coefficients.

        vec: numpy array of N/2 slot values, in normal order, or in
            bit-reversed order when ``bitrev`` is True.
        Returns the real parts followed by the imaginary parts (length N).
        The caller's array is never modified. Integer and real inputs are
        promoted to complex before the butterflies.
        """
        assert isinstance(vec, np.ndarray), "Input vector must be a numpy array"
        assert hasattr(self, 'sequence'), "Precomputation not done, call encoding_precomputation() first"
        if bitrev:
            # The butterflies below write in place through reshaped views; work on a copy.
            vec = vec.copy()
        else:
            vec = vec[self.BR[-2]]  # advanced indexing already returns a fresh array
        if vec.dtype.kind not in "cO":
            # Real/int arrays would drop the imaginary part of (a - b) * twiddle
            # on assignment (or reject the in-place division), so promote.
            vec = vec.astype(complex)
        delta = 2
        for index in range(1, self.log_N):
            # Rows of delta//2 entries; even rows pair with the following odd rows.
            vec = np.reshape(vec, (self.N//delta, delta//2))
            a, b = vec[::2], vec[1::2]
            vec[::2], vec[1::2] = (a + b), (a - b) * self.sequence[index]
            vec /= 2
            delta *= 2
        vec = np.reshape(vec, self.N//2)
        return np.append(vec.real, vec.imag)

    def encoding_precomputation(self):
        """Build ``self.sequence``, the per-stage inverse twiddle columns used by ``encode_fast``.

        sequence[s] (s = 1..log_N-1) is an array of shape (N/2 / 2**s, 1)
        holding the inverses of root(2N, 5**BR[k-1][i] * N/2**(k+1)) for
        stage k = log_N - s. sequence[0] is an unused placeholder.
        """
        delta, list_roots = 2, []
        for k in range(self.log_N - 1, 0, -1):
            for i in range(self.N // 2 // delta):
                u = root(2*self.N, (5 ** self.BR[k-1][i]) * self.N // (2**(k+1)))
                list_roots.append(u)
            delta *= 2
        sequence, index, delta = [[]], 0, 2
        for _ in range(self.log_N - 1):
            step = self.N // 2 // delta
            # Column vector so it broadcasts across each row of delta//2 entries in encode_fast.
            A = [[list_roots[i] ** (-1)] for i in range(index, index + step)]
            sequence.append(np.array(A))
            index += step
            delta *= 2
        self.sequence = sequence

Overwriting DFT_class.sage


In [2]:
%%writefile FHE_DFT_class.sage

load("DFT_class.sage")

def next_power_of_2(x):
    """Return the smallest power of two that is >= x (returns 1 for any x <= 1).

    Rounds x *up* before taking the bit length. Rounding down (floor) would
    return a power of two smaller than a non-integer x, e.g. 4 for x = 4.5.
    """
    import math
    n = int(math.ceil(x))
    if n <= 1:
        return 1
    return 1 << int(n - 1).bit_length()

class FHE_DFT(DFT):
    """Homomorphic DFT (Slot2Coeff / Coeff2Slot) built on top of the clear-text DFT.

    The per-stage diagonal coefficients are precomputed as plaintext
    polynomials (see ``precompute``). They are multiplied with automorphisms
    of the input polynomial/ciphertext.
    """

    def __init__(self, N, slots, modulus, precision=53) -> None:
        """Set up the homomorphic DFT.

        N: ring dimension (power of 2).
        slots: number of complex slots n; a positive power of 2 with n <= N/2.
        modulus: default modulus used by ``encode_clear``.
        precision: default scaling factor is delta = 2**precision; also
            forwarded to DFT.
        """
        super().__init__(N, precision)
        assert slots > 0 and slots & (slots - 1) == 0, "slots must be a power of 2"
        assert slots <= N//2, "slots must be <= N/2"
        assert hasattr(Poly, "N"),  "Poly module not loaded"
        self.delta = 2 ** precision
        self.n = slots
        self.log_n = ZZ(log(slots, 2))
        self.mod = modulus
        if self.n > 1:
            self.precompute()

    # this will be applied to encode plaintext values into the polynomial ring
    def encode_clear(self, values, delta=None, modulus = None, bitrev: bool=False): # 1.2 ms
        """Encode N/2 complex values into a Poly scaled by ``delta`` and rounded to integers.

        values: list or numpy array of N/2 values (bit-reversed order if ``bitrev``).
        delta: scaling factor; defaults to self.delta.
        modulus: modulus of the returned Poly; defaults to self.mod.
        The caller's array is never modified.
        """
        if isinstance(values, list):
            values = np.array(values, dtype=complex)
        assert isinstance(values, np.ndarray), "Input vector must be a numpy array"
        if values.dtype.kind in "biuf":
            values = values.astype(complex)  # real/int input: promote (also copies)
        elif bitrev:
            values = values.copy()  # encode_fast would otherwise work in place on it
        modulus = self.mod if modulus is None else modulus
        complex_encoding = super().encode_fast(values, bitrev=bitrev)
        maximum = np.max(np.abs(complex_encoding))
        d = self.delta if delta is None else delta
        if maximum * d >= 2**63 - 1:
            # Encoding is too large for numpy's int64.
            # To keep the fast path, first scale up to ~2**62 in numpy...
            first_scale = int((2**62) // next_power_of_2(maximum))
            # ...rounding (like the normal path below) rather than truncating toward zero.
            scaled = np.round(complex_encoding * first_scale).astype(int)
            p = Poly(set_ntl(scaled, modulus), modulus)
            return p * (d // first_scale) # and now we scale the rest in NTL
        int_array = np.round(complex_encoding * d).astype(int)
        return Poly(set_ntl(int_array, modulus), modulus)

    def precompute(self, scaling = True):
        """Precompute the plaintext polynomials (encoded with modulus=0) for each DFT stage.

        Fills ``self.poly_roots_S2C`` and ``self.poly_roots_C2S`` with log_n
        entries each. Entry 0 holds two polynomials plus a 0 placeholder;
        entries k >= 1 hold three polynomials [v0, v1, v2].
        scaling: when True, folds 1/(2i) into stage 0 and 2/n into the last
            Slot2Coeff stage.
        """
        # We create the polynomial encoding for the primitive roots
        # First, for the Slot2Coeff precomputation
        v = [[0, 0, 0] for _ in range(self.log_n)]
        scale = 1/(2*I) if scaling else 1
        li1 = [scale     for _ in range(self.n // 2)]
        li2 = [scale * I for _ in range(self.n // 2)]
        v[0][0] = (li1 + li2) * (self.N // self.n // 2)
        v[0][1] = (li2 + li1) * (self.N // self.n // 2)
        v[0] = [self.encode_clear(i, modulus=0) for i in v[0][:2]] + [0]

        delta, s = self.N // 2, 1 # s = "scale"
        for k in range(1, self.log_n):
            if k == self.log_n - 1 and scaling: s = 2 / self.n
            # This instance's own table; the class attribute DFT.roots only
            # reflects the most recently constructed DFT.
            r = self.roots[k]
            v[k][0] = [[s, -r[i] * s][j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k][1] = [[r[i] * s, 0] [j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k][2] = [[0, s]        [j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k] = [self.encode_clear(i, modulus=0) for i in v[k]]
            delta //= 2
        self.poly_roots_S2C = v

        # Now for the Coeff2Slot precomputation
        v = [[0, 0, 0] for _ in range(self.log_n)]
        li1 = ([0.5] * (self.n//2) + [0.5*I] * (self.n//2)) * (self.N // self.n // 2)
        li2 = ([0.5] * (self.n//2) + [-0.5*I] * (self.n//2)) * (self.N // self.n // 2)
        v[0] = [self.encode_clear(i, modulus=0) for i in [li1, li2]] + [0]
        for k in range(self.log_n-1, 0, -1):
            delta = self.N // (2**k)
            r = self.roots[k]
            v[k][0] = [[0.5, -0.5 / r[i]][j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k][1] = [[0.5, 0]          [j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k][2] = [[0  , 0.5 / r[i]] [j] for _ in range(delta // 2) for j in range(2) for i in range(self.N//2//delta)]
            v[k] = [self.encode_clear(i, modulus=0) for i in v[k]]
        self.poly_roots_C2S = v


    ## this is the ONLY method, that will be used for bootstrapping
    def Slot2Coeff(self, poly):
        """Homomorphic Slot2Coeff; the input slots are in bit-reversed order.

        poly: a Poly, or an object whose ``.a`` is a Poly (RLWE/CKKS ciphertext).
        Each stage combines the input with its automorphisms using the
        precomputed polynomials, then applies ``>> 1``.
        NOTE(review): ``>> 1`` is assumed to be a one-level rescale; confirm in polyfhe.
        """
        p = poly
        if self.n == 1: # dividing by I
            return -p.monomial_shift(self.N//2)
        assert isinstance(p, Poly) or isinstance(p.a, Poly), "Input must be a polynomial/RLWE/CKKS"
        p = p * self.poly_roots_S2C[0][0] + p.auto(self.n // 2) * self.poly_roots_S2C[0][1]
        p = p >> 1
        for k in range(1, self.log_n):
            v0, v1, v2 = self.poly_roots_S2C[k][0], self.poly_roots_S2C[k][1], self.poly_roots_S2C[k][2]
            step = 2 ** (k-1)
            p = p * v0 + p.auto(step) * v1 + p.auto(-step) * v2
            p = p >> 1
        return p

    # may not work properly
    def Coeff2Slot(self, poly): # 168 ms
        """Homomorphic Coeff2Slot (experimental); the output slot order is unverified.

        Runs the precomputed stages in reverse order compared to Slot2Coeff,
        then combines with ``auto_inverse`` using the stage-0 polynomials.
        """
        p = poly
        if self.n == 1: return poly.monomial_shift(self.N//2)
        assert isinstance(p, Poly) or isinstance(p.a, Poly), "Input must be a polynomial/RLWE/CKKS"
        for k in range(self.log_n-1, 0, -1):
            v0, v1, v2 = self.poly_roots_C2S[k][0], self.poly_roots_C2S[k][1], self.poly_roots_C2S[k][2]
            step = 2 ** (k-1)
            p = (p * v0 + p.auto(step) * v1 + p.auto(-step) * v2)
            p = p >> 1
        p = p.auto_inverse() * self.poly_roots_C2S[0][0] + p * self.poly_roots_C2S[0][1]
        return p >> 1


Overwriting FHE_DFT_class.sage


In [3]:
from time import time

if __name__ == "__main__":
    load("FHE_DFT_class.sage")
    # Testing: time FHE_DFT construction, then benchmark Slot2Coeff with n = N/2 slots.
    N = 2**12
    n = N // 2
    Q = 2**1000  # modulus passed to FHE_DFT (default modulus of encode_clear)
    precision = 20  # FHE_DFT's default scaling factor becomes 2**20
    # presumably configures the Poly ring for degree N; meaning of the second argument is not visible here
    Poly.setup(N, N*2)
    t = time()
    F = FHE_DFT(N, n, Q, precision=precision)
    print("Init time:", time() - t)

    # NOTE(review): test_array is unused; its only consumer is commented out below.
    test_array = np.array([1] + [1 for i in range(N//2-1)], dtype=complex)
    #print(F.encode_clear(test_array), 2**20)
    # Polynomial with coefficients 0..N-1; what `% 0` does depends on Poly (TODO confirm).
    test_poly = Poly([i for i in range(N)], N*2) % 0
    # IPython magic: this cell only runs inside a notebook.
    %timeit F.Slot2Coeff(test_poly)

Init time: 0.7379488945007324
92 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
