In [7]:
import numpy as np
import sympy as sp
from math import cos, sin, pi, log


In [38]:
from numpy.polynomial import Polynomial

class CKKSEncoder:
    """Basic CKKS encoder to encode complex vectors into polynomials."""
    
    def __init__(self, M: int):
        """Initialization of the encoder for M a power of 2. 
        
        xi, which is an M-th root of unity will, be used as a basis for our computations.
        """
        self.xi = np.exp(2 * np.pi * 1j / M)
        self.M = M
        
    @staticmethod
    def vandermonde(xi: np.complex128, M: int) -> np.array:
        """Computes the Vandermonde matrix from a m-th root of unity."""
        
        N = M //2
        matrix = []
        # We will generate each row of the matrix
        for i in range(N):
            # For each row we select a different root
            root = xi ** (2 * i + 1)
            row = []

            # Then we store its powers
            for j in range(N):
                row.append(root ** j)
            matrix.append(row)
        return matrix
    
    def sigma_inverse(self, b: np.array) -> Polynomial:
        """Encodes the vector b in a polynomial using an M-th root of unity."""

        # First we create the Vandermonde matrix
        A = CKKSEncoder.vandermonde(self.xi, self.M)

        # Then we solve the system
        coeffs = np.linalg.solve(A, b)

        # Finally we output the polynomial
        p = Polynomial(coeffs)
        return p

    def sigma(self, p: Polynomial) -> np.array:
        """Decodes a polynomial by applying it to the M-th roots of unity."""

        outputs = []
        N = self.M //2

        # We simply apply the polynomial on the roots
        for i in range(N):
            root = self.xi ** (2 * i + 1)
            output = p(root)
            outputs.append(output)
        return np.array(outputs)
b = CKKSEncoder(8)
b.sigma_inverse([2, 0, 2, 0])

Polynomial([ 1.00000000e+00-7.39557099e-32j,  3.14018492e-16+0.00000000e+00j,
       -2.46519033e-32-1.00000000e+00j, -3.14018492e-16+9.86076132e-32j], domain=[-1.,  1.], window=[-1.,  1.], symbol='x')

In [28]:
def embedding(coeffs):
    """Computes a variant of the canonical embedding on the given coefficients.

    Computes the canonical embedding which consists of evaluating a given polynomial at roots of unity
    that are indexed 1 (mod 4), w, w^5, w^9, ...
    The evaluations are returned in the order: w, w^5, w^(5^2), ...

    Args:
        coeffs (list): List of complex numbers to transform.

    Returns:
        List of transformed coefficients.
    """
    fft_length = 8
    num_slots = fft_length // 4
    roots_of_unity = [0] * fft_length
    roots_of_unity_inv = [0] * fft_length
    for i in range(fft_length):
        angle = 2 * pi * i / fft_length
        roots_of_unity[i] = complex(cos(angle), sin(angle))
        roots_of_unity_inv[i] = complex(cos(-angle), sin(-angle))
    rot_group = [1] * num_slots
    for i in range(1, num_slots):
        rot_group[i] = (5 * rot_group[i - 1]) % fft_length

    num_coeffs = len(coeffs)
    result = coeffs
    log_num_coeffs = int(log(num_coeffs, 2))
    for logm in range(1, log_num_coeffs + 1):
        idx_mod = 1 << (logm + 2)
        gap = fft_length // idx_mod
        for j in range(0, num_coeffs, (1 << logm)):
            for i in range(1 << (logm - 1)):
                index_even = j + i
                index_odd = j + i + (1 << (logm - 1))

                rou_idx = (rot_group[i] % idx_mod) * gap
                omega_factor = roots_of_unity[rou_idx] * result[index_odd]

                butterfly_plus = result[index_even] + omega_factor
                butterfly_minus = result[index_even] - omega_factor

                result[index_even] = butterfly_plus
                result[index_odd] = butterfly_minus
    return result

In [33]:
embedding([2,0,2,0])

[(4+0j), (4+0j), 0j, 0j]

In [32]:
np.fft.ifft2([2, 0, 2, 0], n=4)

array([1.+0.j, 0.+0.j, 1.+0.j, 0.+0.j])

In [25]:
import numpy as np

def canonical_embedding_encoder(message, n):
    # if len(message) != n:
    #     raise ValueError("Message length must be equal to the polynomial degree n.")
    
    encoded_poly = np.fft.fft(message, n=n)
    # encoded_poly = np.round(encoded_poly.real).astype(int)

    return encoded_poly

def canonical_embedding_decoder(encoded_poly, n):
    decoded_poly = np.fft.ifft(encoded_poly, n=n)
    # decoded_message = np.round(decoded_poly.real).astype(int)

    return decoded_poly

# Example usage
message = [2, 0, 2, 0]
n = 8

# Encoding the message
encoded_poly = canonical_embedding_encoder(message, n)
print("Encoded Polynomial Coefficients:", encoded_poly)

# Decoding the polynomial
decoded_message = canonical_embedding_decoder(encoded_poly, n)
print("Decoded Message:", decoded_message)

Encoded Polynomial Coefficients: [4.+0.j 2.-2.j 0.+0.j 2.+2.j 4.+0.j 2.-2.j 0.+0.j 2.+2.j]
Decoded Message: [2.+0.j 0.+0.j 2.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j 0.+0.j]
