In [1]:
"""Example of CKKS multiplication."""

from ckks.ckks_encoder import CKKSEncoder
from ckks.ckks_key_generator import CKKSKeyGenerator
from ckks.ckks_parameters import CKKSParameters
from ckks.ckks_evaluator import CKKSEvaluator


poly_degree = 8
ciph_modulus = 1 << 600
big_modulus = 1 << 1200
scaling_factor = 1 << 30
params = CKKSParameters(poly_degree=poly_degree,
                        ciph_modulus=ciph_modulus,
                        big_modulus=big_modulus,
                        scaling_factor=scaling_factor)
key_generator = CKKSKeyGenerator(params)

public_key = key_generator.public_key
secret_key = key_generator.secret_key
relin_key = key_generator.relin_key

encoder = CKKSEncoder(params)
evaluator = CKKSEvaluator(params)

message1 = [0.5, 0.3 + 0.2j, 0.78, 0.88j]
message2 = [0.2, 0.11, 0.4 + 0.67j, 0.9 + 0.99j]
plain1 = encoder.encode(message1, scaling_factor)
plain2 = encoder.encode(message2, scaling_factor)

print(encoder.decode(plain2))


[(0.4024999998509884+0.4150000000372529j), (-0.26337365712970495+0.14156470727175474j), (-0.12904698681086302+0.01590990275144577j), (-0.2616360504180193-0.32214565947651863j)]
[(0.20000000042976923+2.5262619773869233e-09j), (0.11000000033253365+6.007523922768598e-10j), (0.4000000000937715+0.669999998864071j), (0.8999999985478793+0.9899999981579262j)]


In [52]:
import numpy as np
import sympy as sp

def polynomial(x):
    return x**2 + 1

def roots_of_unity(n):
    """Calculate the n-th roots of unity."""
    return np.exp(2j * np.pi * np.arange(n) / n)

def reconstruct_polynomial(fft_length):
    # Ensure the FFT length is at least 4
    if fft_length < 4:
        raise ValueError("FFT length must be at least 4.")
    
    # Calculate the n-th roots of unity
    roots = roots_of_unity(fft_length)
    print(roots)
    # Evaluate the polynomial at the roots of unity
    evaluations = [polynomial(r) for r in roots]
    
    # Use numpy's IFFT to get the polynomial coefficients
    coefficients = np.fft.ifft(evaluations).real

    # Create the polynomial using sympy
    x = sp.symbols('x')
    reconstructed_poly = sum(c * x**i for i, c in enumerate(coefficients))

    # Print results
    print(f"Evaluations at the {fft_length}-th roots of unity: {evaluations}")
    print(f"Coefficients from IFFT: {coefficients}")
    print(f"Reconstructed Polynomial: {reconstructed_poly}")

    # Convert to sympy polynomial for exact representation
    reconstructed_poly_sympy = sp.poly(reconstructed_poly, x)
    print(f"Sympy Polynomial: {reconstructed_poly_sympy}")

# Example usage with FFT length 8
reconstruct_polynomial(fft_length=16)

[ 1.00000000e+00+0.00000000e+00j  9.23879533e-01+3.82683432e-01j
  7.07106781e-01+7.07106781e-01j  3.82683432e-01+9.23879533e-01j
  6.12323400e-17+1.00000000e+00j -3.82683432e-01+9.23879533e-01j
 -7.07106781e-01+7.07106781e-01j -9.23879533e-01+3.82683432e-01j
 -1.00000000e+00+1.22464680e-16j -9.23879533e-01-3.82683432e-01j
 -7.07106781e-01-7.07106781e-01j -3.82683432e-01-9.23879533e-01j
 -1.83697020e-16-1.00000000e+00j  3.82683432e-01-9.23879533e-01j
  7.07106781e-01-7.07106781e-01j  9.23879533e-01-3.82683432e-01j]
Evaluations at the 16-th roots of unity: [np.complex128(2+0j), np.complex128(1.7071067811865475+0.7071067811865476j), np.complex128(1+1.0000000000000002j), np.complex128(0.29289321881345254+0.7071067811865477j), np.complex128(1.2246467991473532e-16j), np.complex128(0.29289321881345254-0.7071067811865475j), np.complex128(0.9999999999999998-1j), np.complex128(1.7071067811865475-0.7071067811865477j), np.complex128(2-2.4492935982947064e-16j), np.complex128(1.707106781186548+0.70

In [67]:
import numpy as np

def canonical_embedding_encoder(message, n, q):
    # if len(message) != n:
    #     raise ValueError("Message length must be equal to the polynomial degree n.")
    
    encoded_poly = np.fft.fft(message, n=n)
    encoded_poly = np.round(encoded_poly.real).astype(int) % q

    return encoded_poly

def canonical_embedding_decoder(encoded_poly, n, q):
    decoded_poly = np.fft.ifft(encoded_poly, n=n)
    decoded_message = np.round(decoded_poly.real).astype(int) % q

    return decoded_message

# Example usage
message = [2, 0, 2, 0]
n = 16
q = 16

# Encoding the message
encoded_poly = canonical_embedding_encoder(message, n, q)
print("Encoded Polynomial Coefficients:", encoded_poly)

# Decoding the polynomial
decoded_message = canonical_embedding_decoder(encoded_poly, n, q)
print("Decoded Message:", decoded_message)

Encoded Polynomial Coefficients: [4 3 2 1 0 1 2 3 4 3 2 1 0 1 2 3]
Decoded Message: [2 0 1 0 0 0 0 0 0 0 0 0 0 0 1 0]


In [59]:
np.fft.fft([2,0,2,0], n=4)

array([4.+0.j, 0.+0.j, 4.+0.j, 0.+0.j])

In [63]:
from ckks.util.ntt import FFTContext
FFTContext(16).embedding_inv([2,0,2,0])

[1.0, 0j, (0.7071067811865476-0.7071067811865476j), 0j]

In [11]:
import numpy as np

def canonical_embedding(coefficients):
    """
    Computes the canonical embedding of a polynomial evaluated at specific roots of unity.
    
    The polynomial is evaluated at w, w^5, w^25, ..., where w is a primitive root of unity,
    and the powers of w are indexed 1 (mod 4) such as 1, 5, 5^2, 5^3, ...
    
    :param coefficients: List of coefficients of the polynomial (from lowest to highest degree).
    :return: List of evaluations of the polynomial at specific roots of unity.
    """
    # Degree of the polynomial is determined by the number of coefficients.
    n = len(coefficients)
    
    # Primitive nth root of unity, where n is large enough to avoid truncation errors
    w = np.exp(2j * np.pi / n)
    
    # Powers of w (1 mod 4): 1, 5, 25, 125, ...
    powers_of_w = []
    power = 1
    while power < n:
        powers_of_w.append(w**power)
        power *= 5
    
    # Polynomial evaluation at w^1, w^5, w^25, ...
    evaluations = []
    for w_power in powers_of_w:
        eval_value = sum(c * (w_power**i) for i, c in enumerate(coefficients))
        evaluations.append(eval_value)
    
    return evaluations

In [61]:
print(encoder.fft.embedding([1, 2, 3, 4]))
print(canonical_embedding([(0.4024999998509884+0.4150000000372529j), (-0.26337365712970495+0.14156470727175474j), (-0.12904698681086302+0.01590990275144577j), (-0.2616360504180193-0.32214565947651863j)]))

[(6.499813138042575+6.58220533833497j), (1.808830921755325-1.8042950079974285j), (-0.25717245092328955-2.3395646512156842j), (-4.0514716088746106-2.4383456791218574j)]
[np.complex128(0.06783661991357809+0.3973524905741215j)]


In [116]:
from math import sqrt
-(35 + 15*sqrt(6))**(1/3)/3 + 1/3 + 5/(3*(35 + 15*sqrt(6))**(1/3))

-0.650629191439388

In [90]:
from math import cos, sin, pi, log
fft_length = 16
num_slots = fft_length // 4
roots_of_unity = [0] * fft_length
roots_of_unity_inv = [0] * fft_length
for i in range(fft_length):
    angle = 2 * pi * i / fft_length
    roots_of_unity[i] = complex(cos(angle), sin(angle))
    roots_of_unity_inv[i] = complex(cos(-angle), sin(-angle))
rot_group = [1] * num_slots
for i in range(1, num_slots):
    rot_group[i] = (5 * rot_group[i - 1]) % fft_length

In [103]:
from ckks.util.bit_operations import bit_reverse_vec, reverse_bits
coeffs = [1,3,7,15]
num_coeffs = len(coeffs)
result = coeffs.copy()
log_num_coeffs = int(log(num_coeffs, 2))

for logm in range(log_num_coeffs, 0, -1):
    idx_mod = 1 << (logm + 2)
    gap = fft_length // idx_mod
    for j in range(0, num_coeffs, 1 << logm):  # j=0, 2^(logm), 2^(logm+1), ...
        for i in range(1 << (logm - 1)): # i=0...2^(logm-1)
            index_even = j + i # 0...2^(logm-1) + [0, 1, 2, 4]
            index_odd = j + i + (1 << (logm - 1))
            rou_idx = (rot_group[i] % idx_mod) * gap

            butterfly_plus = result[index_even] + result[index_odd]
            butterfly_minus = result[index_even] - result[index_odd]
            butterfly_minus *= roots_of_unity_inv[rou_idx]
            result[index_even] = butterfly_plus
            print(result)
            result[index_odd] = butterfly_minus
            print(result)

to_scale_down = bit_reverse_vec(result)

for i in range(num_coeffs):
    to_scale_down[i] /= num_coeffs


[8, 3, 7, 15]
[8, 3, (-5.54327719506772+2.2961005941905386j), 15]
[8, 18, (-5.54327719506772+2.2961005941905386j), 15]
[8, 18, (-5.54327719506772+2.2961005941905386j), (4.592201188381077+11.08655439013544j)]
[26, 18, (-5.54327719506772+2.2961005941905386j), (4.592201188381077+11.08655439013544j)]
[26, (-7.0710678118654755+7.0710678118654755j), (-5.54327719506772+2.2961005941905386j), (4.592201188381077+11.08655439013544j)]
[26, (-7.0710678118654755+7.0710678118654755j), (-0.9510760066866428+13.382654984325978j), (4.592201188381077+11.08655439013544j)]
[26, (-7.0710678118654755+7.0710678118654755j), (-0.9510760066866428+13.382654984325978j), (-13.38265498432598+0.9510760066866428j)]


In [100]:
1 << (log_num_coeffs - 1)

2