In [2]:
from ctypes import *
import numpy as np
_so = cdll.LoadLibrary('./xmkckks.so')
import copy
import time


In [3]:
class _Ldouble(Structure):
    _fields_ = [
        ('data', POINTER(c_double)),
        ('size', c_size_t)
    ]

class _Luint64(Structure):
    _fields_ = [
        ('data', POINTER(c_ulonglong)),
        ('size', c_size_t)
    ]

# class _Params(Structure):
#     _fields_ = [
#         ('qi', _Luint64),
#         ('pi', _Luint64),

#         ('logN', c_int),
#         ('logSlots', c_int),
#         ('gamma', c_int),

#         ('scale', c_double),
#         ('sigma', c_double)
#     ]
    
class _ParametersLiteral(Structure):
    _fields_ = [
        ('qi', _Luint64),
        ('pi', _Luint64),

        ('logN', c_int),
        ('logSlots', c_int),

        ('scale', c_double),
        ('sigma', c_double)
    ]
    
class _Poly(Structure):
    _fields_ = [
        ('coeffs', POINTER(_Luint64)),
        ('IsNTT', c_bool),
        ('IsMForm', c_bool),
        ('size', c_size_t)
    ]

# class _PolyPair(Structure):
#     _fields_ = [
#         ('p0', _Poly),
#         ('p1', _Poly)
#     ]
    
class _PolyQP(Structure):
    _fields_ = [
        ('Q', POINTER(_Poly)),
        ('P', POINTER(_Poly))
    ]

class _PolyQPPair(Structure):
    _fields_ = [
        ('qp0', _PolyQP),
        ('qp1', _PolyQP)
    ]
    
class _Share(Structure):
    _fields_ = [
        ('data', POINTER(_Poly)),
        ('size', c_size_t)
    ]
    
class _Ciphertext(Structure):
    _fields_ = [
        ('data', POINTER(_Poly)),
        ('size', c_size_t),
        ('idxs', POINTER(c_int)),
        ('scale', c_double),
        # ('isNTT', c_bool)
    ]

class _Data(Structure):
    _fields_ = [
        ('data', POINTER(_Ciphertext)),
        ('size', c_size_t)
    ]

class _MPHEServer(Structure):
    _fields_ = [
        # ('params', _Params),
        ('paramsLiteral', _ParametersLiteral),
        ('crs', _Poly),
        ('sk', _PolyQP),
        ('pk', _PolyQPPair),

        # ('secretKey', _Poly),
        ('data', _Data),
        ('idx', c_int),
    ]



In [4]:
_newMPHEServer = _so.newMPHEServer
_newMPHEServer.restype = POINTER(_MPHEServer)

_encryptFromPk = _so.encryptFromPk
_encryptFromPk.argtypes = [ POINTER(_PolyQPPair), POINTER(c_double), c_size_t, c_int]
_encryptFromPk.restype = POINTER(_Ciphertext)

_partialDecrypt = _so.partialDecrypt
_partialDecrypt.argtypes = [ POINTER(_PolyQP), POINTER(_Ciphertext), c_int]
_partialDecrypt.restype = POINTER(_Ciphertext)

_ringQAddLvl = _so.ringQAddLvl
_ringQAddLvl.argtypes = [ POINTER(_Ciphertext), c_int, POINTER(_Ciphertext), c_int]
_ringQAddLvl.restype = POINTER(_Ciphertext)

_decodeAfterPartialDecrypt = _so.decodeAfterPartialDecrypt
_decodeAfterPartialDecrypt.argtypes = [ POINTER(_Ciphertext) ]
_decodeAfterPartialDecrypt.restype = POINTER(_Ldouble)

_addCTs = _so.addCTs
_addCTs.argtypes = [ POINTER(_Ciphertext), POINTER(_Ciphertext)]
_addCTs.restype = POINTER(_Ciphertext)

_multiplyCTConst = _so.multiplyCTConst
_multiplyCTConst.argtypes = [ POINTER(_Ciphertext), c_double]
_multiplyCTConst.restype = POINTER(_Ciphertext)

_addRingPs = _so.addRingPs
_addRingPs.argtypes = [ POINTER(_Poly), POINTER(_Poly)]
_addRingPs.restype = POINTER(_Poly)

# _genCRS = _so.genCRS
# _genCRS.argtypes = [ POINTER(_Params) ]
# _genCRS.restype = POINTER(_Poly)

# ----------------

### Wrapper Classes (pickle-able) ###

# class Params:
#     def __init__(self, _params):
#         self.qi = _Conversion.from_luint64(_params.qi)
#         self.pi = _Conversion.from_luint64(_params.pi)
#         self.logN = _params.logN
#         self.logSlots = _params.logSlots
#         self.scale = _params.scale
#         self.sigma = _params.sigma
    
#     # So we can send to Lattigo
#     def make_structure(self):
#         _params = _Params()
        
#         _params.qi = _Conversion.to_luint64(self.qi)
#         _params.pi = _Conversion.to_luint64(self.pi)
#         _params.logN = self.logN
#         _params.logSlots = self.logSlots
#         _params.scale = self.scale
#         _params.sigma = self.sigma

#         return _params

class ParametersLiteral:
    def __init__(self, _paramsLiteral):
        self.qi = _Conversion.from_luint64(_paramsLiteral.qi)
        self.pi = _Conversion.from_luint64(_paramsLiteral.pi)
        self.logN = _paramsLiteral.logN
        self.logSlots = _paramsLiteral.logSlots
        self.scale = _paramsLiteral.scale
        self.sigma = _paramsLiteral.sigma
    
    # So we can send to Lattigo
    def make_structure(self):
        _paramsLiteral = _ParametersLiteral()
        
        _paramsLiteral.qi = _Conversion.to_luint64(self.qi)
        _paramsLiteral.pi = _Conversion.to_luint64(self.pi)
        _paramsLiteral.logN = self.logN
        _paramsLiteral.logSlots = self.logSlots
        _paramsLiteral.scale = self.scale
        _paramsLiteral.sigma = self.sigma

        return _paramsLiteral

# use self.data instead of value used in go to be compatible with helper func to_list_with_conv() 
class Ciphertext:
    def __init__(self, _ct):
        self.data = [ None ] * _ct.size
        self.idxs = [ None ] * _ct.size

        for i in range(_ct.size):
            # self.data[i] = _Conversion.from_poly(_ct.data[i])
            self.data[i] = Poly(_ct.data[i])
            self.idxs[i] = _ct.idxs[i]
        self.scale = _ct.scale
        # self.idx = _ct.idx
        # self.isNTT = _ct.isNTT
    
    # So we can send to Lattigo
    def make_structure(self):
        _ct = _Ciphertext()

        data = [ None ] * len(self.data)
        idxs = [ None ] * len(self.idxs)
        for i in range(len(self.data)):
            data[i] = self.data[i].make_structure()
            idxs[i] = self.idxs[i]

        _ct.size = len(data)
        _ct.data = (_Poly * _ct.size)(*data)
        _ct.scale = self.scale
        _ct.idxs = (c_int * _ct.size)(*idxs)
        # _ct.idx = self.idx
        # _ct.isNTT = self.isNTT

        return _ct

class Poly:
    def __init__(self, _poly):
        self.coeffs = [ None ] * _poly.size
        
        for i in range(_poly.size):
            self.coeffs[i] = _Conversion.from_luint64(_poly.coeffs[i])
        
        self.IsNTT = _poly.IsNTT
        self.IsMForm = _poly.IsMForm
    
    # So we can send to Lattigo
    def make_structure(self):
        _poly = _Poly()

        coeffs = [ None ] * len(self.coeffs)
        
        for i in range(len(self.coeffs)):
            coeffs[i] = _Conversion.to_luint64(self.coeffs[i])
        
        _poly.size = len(coeffs)
        _poly.coeffs = (_Luint64 * _poly.size)(*coeffs)
        _poly.IsNTT = self.IsNTT
        _poly.IsMForm = self.IsMForm

        return _poly

class PolyQP:
    def __init__(self, _polyQP):
        self.Q = Poly(_polyQP.Q.contents)
        self.P = Poly(_polyQP.P.contents)
    
    # So we can send to Lattigo
    def make_structure(self):
        _polyQP = _PolyQP()
        
        _polyQP.Q.contents = self.Q.make_structure()
        _polyQP.P.contents = self.P.make_structure()

        return _polyQP

# Server that has Multi-Party Homomorphic Encryption functionality
class MPHEServer:
    def __init__(self, server_id):
        _server_ptr = _newMPHEServer(server_id)
        _server = _server_ptr.contents

        self.paramsLiteral = ParametersLiteral(_server.paramsLiteral) # implemented but currently not used, security parameters hardcoded in export.go
        self.sk = PolyQP(_server.sk)
        self.pk = _Conversion.from_polyQPpair(_server.pk)
        # self.sk = _Conversion.from_polyQP(_server.sk)
        # self.pk = _Conversion.from_polyQPpair(_server.pk)
        # self.crs = _Conversion.from_poly(_server.crs)
        # self.secret_key = _Conversion.from_poly(_server.secretKey)
        # self.data = []  # NOTE: always have this as decryptable by secret_key
        self.idx = _server.idx
    
    def encryptFromPk(self, data):
        # params = self.params.make_structure()
        # sk = _Conversion.to_poly(self.secret_key)
        pk = _Conversion.to_polyQPpair(self.pk)

        data_ptr = (c_double * len(data))(*data)
        enc_ct = _encryptFromPk(byref(pk), data_ptr, len(data), self.idx)

        # self.data = _Conversion.from_data(enc_ct.contents)
        self.data = Ciphertext(enc_ct.contents)

        return self.data
    
    def partialDecrypt(self, ciphertext):
        # params = self.params.make_structure()
        sk = self.sk.make_structure()
        # ct = _Conversion.to_data(self.data)
        ct = ciphertext.make_structure()

        partial_dec_ct = _partialDecrypt(byref(sk), byref(ct), self.idx)
        # dec_data = _Conversion.to_list(dec_data.contents)

        return Ciphertext(partial_dec_ct.contents)

    def ringAddLvl(self, ct1, ct1_idx, ct2, ct2_idx):
        op1 = ct1.make_structure()
        op2 = ct2.make_structure()
        op1 = _ringQAddLvl(op1, ct1_idx, op2, ct2_idx)

        return Ciphertext(op1.contents)

    def aggregate_pds(self, ct_pd_list, client_ids):
        # Aggregate partially decrypted ciphertexts, requires to be called by the server.
        # pd_list: a list of partilally decrypted ciphertexts
        # ct_pd_list: a list of client ids involved in the encryption
        # Return: ct_pd_agg is the decrytped ciphertext, 
        # which can be sent for decryption and decoding by calling server.decodeAfterPartialDecrypt(ct_pd_agg)
    
        # size mismatch of partially decrypted ciphertexts and client ids involved in the encryption
        if len(ct_pd_list) != len(client_ids):
            raise Exception("aggregate_pds(): ct_pd_list has a length of " + str(len(ct_pd_list)) + 
                            ", but client_ids has a length of " + str(len(client_ids)))
        
        # empty list of ciphertexts
        if len(ct_pd_list) == 0:
            raise Exception("aggregate_pds(): len(ct_pd_list) is 0.")
        else:
            # add polynomial ring on Ciphertext["0"] and Ciphertext["client_id"]
            ct_pd_agg = self.ringAddLvl(ct_pd_list[0], 0, ct_pd_list[0], client_ids[0])
            if len(ct_pd_list) > 1:
                for ct_id in range(1, len(ct_pd_list)):
                    ct_pd_agg = self.ringAddLvl(ct_pd_agg, 0, ct_pd_list[ct_id], client_ids[ct_id])
            return ct_pd_agg
    
    def addRingPs(self, rP1, rP2):
        ringP1 = rP1.make_structure()
        ringP2 = rP2.make_structure()
        sumRingP = _addRingPs(ringP1, ringP2)
        return Poly(sumRingP.contents)
    
    def aggregate_ringPs(self, ringP_list):
        # Generate the aggregated public key (the P ring in the public key) based on xMKCKKS (https://arxiv.org/abs/2104.06824)
        # pring_list: a list of P rings in the public keys' first QPRings, access by client_1.pk[0].P
        # Return: sum_ringP, a Poly ring that to be shared by all clients
        if len(ringP_list) == 0 or len(ringP_list) == 1:
            raise Exception("aggregate_ringPs(): ringP_list has a length of " + str(len(ringP_list)))
        else:
            sum_ringP = self.addRingPs(ringP_list[0], ringP_list[1])
            if len(ringP_list) > 2:
                for pk_id in range(2, len(ringP_list)):
                    self.addRingPs(sum_ringP, ringP_list[pk_id])
            return sum_ringP

    def decodeAfterPartialDecrypt(self, ciphertext):
        ct = ciphertext.make_structure()
        res = _decodeAfterPartialDecrypt(ct)
        return _Conversion.from_ldouble(res.contents)

    def addCTs(self, ct1, ct2):
        op1 = ct1.make_structure()
        op2 = ct2.make_structure()
        res = _addCTs(op1, op2)
        return Ciphertext(res.contents)

    def multiplyCTConst(self, ct1, const):
        op1 = ct1.make_structure()
        res = _multiplyCTConst(op1, const)
        return Ciphertext(res.contents)

        
    # def gen_crs(self):
    #     params = self.params.make_structure()

    #     crs = _genCRS(byref(params))
    #     self.crs = _Conversion.from_poly(crs.contents)

    #     return self.crs
    
    # def col_key_gen(self, ckg_shares):
    #     params = self.params.make_structure()
    #     sk = _Conversion.to_poly(self.secret_key)
    #     crs = _Conversion.to_poly(self.crs)
    #     shares_ptr = _Conversion.to_ptr(ckg_shares, _Conversion.to_share, _Share)
        
    #     cpk = _colKeyGen(byref(params), byref(sk), byref(crs), shares_ptr, len(ckg_shares))

    #     return _Conversion.from_polypair(cpk.contents)

    # def col_key_switch(self, agg, cks_shares):
    #     params = self.params.make_structure()
    #     data = _Conversion.to_data(agg)
    #     shares_ptr = _Conversion.to_ptr(cks_shares, _Conversion.to_share, _Share)

    #     switched_data = _colKeySwitch(byref(params), byref(data), shares_ptr, len(cks_shares))
    #     self.data = _Conversion.from_data(switched_data.contents)

    # def aggregate(self, updates):
    #     params = self.params.make_structure()
    #     data_ptr = _Conversion.to_ptr(updates, _Conversion.to_data, _Data)

    #     agg = _aggregate(byref(params), data_ptr, len(updates))

    #     return _Conversion.from_data(agg.contents)

    # def average(self, n):
    #     params = self.params.make_structure()
    #     data = _Conversion.to_data(self.data)

    #     avg_data = _mulByConst(byref(params), byref(data), 1/n)
    #     self.data = _Conversion.from_data(avg_data.contents)

    # # DEBUG: Decrypts its data then prints contents
    # def print_data(self):
    #     params = self.params.make_structure()
    #     sk = _Conversion.to_poly(self.secret_key)
    #     ct = _Conversion.to_data(self.data)

    #     dec_data = _decrypt(byref(params), byref(sk), byref(ct))
    #     dec_data = _Conversion.to_list(dec_data.contents)

    #     print('Decrypted SERVER data:\n\t', dec_data)

# Performs conversion between Structures (which contain pointers) to pickle-able classes
class _Conversion:
    # (FYI) Convert to numpy array: https://stackoverflow.com/questions/4355524/getting-data-from-ctypes-array-into-numpy

    # Generic array type Structure to list

    def to_list(_l):
        l = [ None ] * _l.size

        for i in range(_l.size):
            l[i] = _l.data[i]
        
        return l

    def to_list_with_conv(_l, conv):
        l = [ None ] * _l.size

        for i in range(_l.size):
            l[i] = conv(_l.data[i])
        
        return l

    def to_ptr(l, conv, t):
        lt = [ None ] * len(l)

        for i in range(len(l)):
            lt[i] = conv(l[i])
        
        return (t * len(lt))(*lt)

    ### _Luint64 (list of uint64)

    def from_luint64(_luint64):
        return _Conversion.to_list(_luint64)

    def to_luint64(l):
        luint64 = _Luint64()

        luint64.size = len(l)
        luint64.data = (c_ulonglong * luint64.size)(*l)

        return luint64

    ### _Ldouble (list of double)

    def from_ldouble(_ldouble):
        return _Conversion.to_list(_ldouble)

    def to_ldouble(l):
        ldouble = _Ldouble()

        ldouble.size = len(l)
        ldouble.data = (c_ulonglong * ldouble.size)(*l)

        return _ldouble
    
    # _Poly (list of Coefficients (Luint64))
        
    # def from_poly(_poly):
    #     coeffs = [ None ] * _poly.size

    #     for i in range(_poly.size):
    #         coeffs[i] = _Conversion.from_luint64(_poly.coeffs[i])

        
        
    #     return coeffs
    
    # def to_poly(coeffs):
    #     list_luint64 = [ None ] * len(coeffs)

    #     for i in range(len(coeffs)):
    #         list_luint64[i] = _Conversion.to_luint64(coeffs[i])
        
    #     _poly = _Poly()
    #     _poly.size = len(list_luint64)
    #     _poly.coeffs = (_Luint64 * _poly.size)(*list_luint64)

    #     return _poly

    # _PolyPair (list[2] of Poly)
    
    def from_polyQPpair(_qpp):
        qpp = [ None ] * 2

        qpp[0] = PolyQP(_qpp.qp0)
        qpp[1] = PolyQP(_qpp.qp1)
        
        return qpp

    def to_polyQPpair(qpp):        
        _qpp = _PolyQPPair()

        if len(qpp) != 2:
            print('ERROR: Only a list of size 2 makes a pair (not {})'.format(len(qpp)))
            return None

        _qpp.qp0 = qpp[0].make_structure()
        _qpp.qp1 = qpp[1].make_structure()

        return _qpp
        
    # def from_polypair(_pp):
    #     pp = [ None ] * 2

    #     pp[0] = _Conversion.from_poly(_pp.p0)
    #     pp[1] = _Conversion.from_poly(_pp.p1)
        
    #     return pp

    # def to_polypair(pp):        
    #     _pp = _PolyPair()

    #     if len(pp) != 2:
    #         print('ERROR: Only a list of size 2 makes a pair (not {})'.format(len(pp)))
    #         return None

    #     _pp.p0 = _Conversion.to_poly(pp[0])
    #     _pp.p1 = _Conversion.to_poly(pp[1])

    #     return _pp

    ### _Share (list of Poly)

    def from_share(_share):        
        return _Conversion.to_list_with_conv(_share, _Conversion.from_poly)

    def to_share(share):
        list_poly = [ None ] * len(share)

        for i in range(len(share)):
            list_poly[i] = _Conversion.to_poly(share[i])
        
        _share = _Share()
        _share.size = len(list_poly)
        _share.data = (_Poly * _share.size)(*list_poly)

        return _share

    ### _Data (list of Ciphertext)

    def from_data(_data):
        return _Conversion.to_list_with_conv(_data, Ciphertext)
    
    def to_data(data):
        list_ciphertext = [ None ] * len(data)

        for i in range(len(data)):
            list_ciphertext[i] = data[i].make_structure()
        
        _data = _Data()
        _data.size = len(list_ciphertext)
        _data.data = (_Ciphertext * _data.size)(*list_ciphertext)

        return _data

In [5]:
def print_decrypted(expected, decrypted, rounding=2, first_k=10):
    # Print expected results without homomorphic encryption (HE) and decrypted results after HE operations
    # expected: a list of expected results
    # decrypted: a list of decrypted results
    # rounding: rounding points for eliminating the error
    # first_k: display the first k elements to avoid showing all numbers (2^logSlots) in decrypted results
    
    # if first_k < len(expected):
    #     print("Exptected:\n", expected[:first_k])
    # else:
    print("Exptected:\n", expected[:first_k])
    print("\nBefore rounding (first 10 elements):\n", decrypted[:first_k])
    print("\nAfter rounding (first 10 elements):\n", np.round(decrypted, rounding)[:first_k])

## Initialization

Note: Security parameters are hard coded in export.go

In [6]:
start_time=time.time()
server = MPHEServer(server_id=0) # id for FL server has to be 0
client_1 = MPHEServer(server_id=1) # id for FL client starts from 1, can be any unique non-0 integer
client_2 = MPHEServer(server_id=2) # id for FL client starts from 1, can be any unique non-0 integer
client_3 = MPHEServer(server_id=3) # id for FL client starts from 1, can be any unique non-0 integer
client_4 = MPHEServer(server_id=4) # id for FL client starts from 1, can be any unique non-0 integer

01234

In [7]:
# num_decimal = 2
weights1 = np.random.random(size=8192)#[ 0.1, 0.2, 2.1, -2.2 ]
weights2 = np.random.random(size=8192)#[ 0.2, 0.3, 1.2, -1.2 ]
weights3 = np.random.random(size=8192)#[ 0.3, 0.5, 1.5, -1.5 ]
weights4 = np.random.random(size=8192)#[ 100, 200, 300, 400 ]

In [8]:
print(weights1)
print(weights2)
print(weights3)
print(weights4)

[0.66980947 0.34271767 0.54271665 ... 0.90311834 0.83146099 0.81431416]
[0.12213238 0.30562099 0.2798189  ... 0.34823306 0.79593998 0.93182322]
[0.7380102  0.16674721 0.6034131  ... 0.17261962 0.4624206  0.54360165]
[0.56860346 0.9461559  0.55968274 ... 0.56031052 0.41785735 0.15272541]


## Key Aggregation

In [9]:
ringP_list = [client_1.pk[0].P, client_2.pk[0].P, client_3.pk[0].P, client_4.pk[0].P] # form a list of P rings in all clients' public keys
agg_ringP = server.aggregate_ringPs(ringP_list) # generate aggregated ring P for xMKCKKS

# assign aggregated ring P to the public keys of all clients
client_1.pk[0].P = copy.deepcopy(agg_ringP)
client_2.pk[0].P = copy.deepcopy(agg_ringP)
client_3.pk[0].P = copy.deepcopy(agg_ringP)
client_4.pk[0].P = copy.deepcopy(agg_ringP)

## Encryption

In [10]:
ct1 = client_1.encryptFromPk(weights1)
ct2 = client_2.encryptFromPk(weights2)
ct3 = client_3.encryptFromPk(weights3)
ct4 = client_4.encryptFromPk(weights4)

## Homomorphic Addition & Multiplication (ct * const)

In [None]:
ct5 = server.addCTs(ct1, ct2)
ct6 = server.addCTs(ct5, ct3)
ct7 = server.addCTs(ct6, ct4)

ct6 = server.multiplyCTConst(ct6, 2.5)

## Partial Decryption & Aggregation

In [37]:
ct7_pd1 = client_1.partialDecrypt(ct7)
ct7_pd2 = client_2.partialDecrypt(ct7)
ct7_pd3 = client_3.partialDecrypt(ct7)
ct7_pd4_tmp = client_4.partialDecrypt(ct7)

ct7_pd4 = ct7_pd4_tmp

In [38]:
del ct7_pd4_tmp

In [None]:
ct4_pd_agg = server.aggregate_pds([ct7_pd3, ct7_pd2, ct7_pd1, ct7_pd4], [3,2,1,4]) # needs to ensure the order of client_ids is the same as the order of ct_pd_list

## Decoding (& remaining steps after obtaining agg_pd) 

In [47]:
dec_ct4 = server.decodeAfterPartialDecrypt(ct4_pd_agg)
expected_ct4 = np.sum((np.array(weights1), np.array(weights2), np.array(weights3), np.array(weights4)), axis=0)
print_decrypted(expected_ct4, dec_ct4, rounding=6, first_k=10)
print("--- %s seconds ---" % (time.time() - start_time))

Exptected:
 [2.0985555  1.76124176 1.98563139 1.50926414 2.35214468 1.93195578
 0.82043896 2.08554052 0.75392891 2.7080138 ]

Before rounding (first 10 elements):
 [2.098555503323388, 1.7612417593518142, 1.9856313903453031, 1.5092641424317452, 2.352144677606473, 1.931955775837633, 0.8204389618539836, 2.0855405184506526, 0.7539289140665433, 2.70801380042919]

After rounding (first 10 elements):
 [2.098556 1.761242 1.985631 1.509264 2.352145 1.931956 0.820439 2.085541
 0.753929 2.708014]
--- 97.9135468006134 seconds ---


In [41]:
ct5_pd1 = client_1.partialDecrypt(ct5)
ct5_pd2 = client_2.partialDecrypt(ct5)

In [42]:
ct5_pd_agg = server.aggregate_pds([ct5_pd1, ct5_pd2], [1,2])

In [43]:
dec_ct5 = server.decodeAfterPartialDecrypt(ct5_pd_agg)
expected_ct5 = np.sum((np.array(weights1), np.array(weights2)), axis=0)
print_decrypted(expected_ct5, dec_ct5, rounding=6, first_k=10)

Exptected:
 [0.79194185 0.64833866 0.82253555 0.37014792 1.38915026 0.74055453
 0.05543939 0.67515782 0.67224453 1.15170554]

Before rounding (first 10 elements):
 [0.7919418476647007, 0.6483386566174736, 0.8225355499653848, 0.3701479199005293, 1.3891502578696446, 0.7405545339454237, 0.055439388599688555, 0.6751578170328273, 0.672244526567014, 1.1517055433949577]

After rounding (first 10 elements):
 [0.791942 0.648339 0.822536 0.370148 1.38915  0.740555 0.055439 0.675158
 0.672245 1.151706]


In [44]:
ct6_pd1 = client_1.partialDecrypt(ct6)
ct6_pd2 = client_2.partialDecrypt(ct6)
ct6_pd3 = client_3.partialDecrypt(ct6)

In [45]:
ct6_pd_agg = server.aggregate_pds([ct6_pd1, ct6_pd2, ct6_pd3], [1,2,3])

In [46]:
dec_ct6 = server.decodeAfterPartialDecrypt(ct6_pd_agg)
expected_ct6 = np.sum((np.array(weights1), np.array(weights2), np.array(weights3)), axis=0) * 2.5
print_decrypted(expected_ct6, dec_ct6, rounding=6, first_k=10)

Exptected:
 [3.82488011 2.03771466 3.56487162 1.99933001 5.80148385 3.07293388
 0.24397655 3.19685931 1.87048005 5.19649001]

Before rounding (first 10 elements):
 [3.8248801090569553, 2.037714659843706, 3.5648716237061224, 1.9993300100769118, 5.801483845717021, 3.072933884208135, 0.24397654526134938, 3.1968593117971262, 1.870480052688438, 5.196490006370599]

After rounding (first 10 elements):
 [3.82488  2.037715 3.564872 1.99933  5.801484 3.072934 0.243977 3.196859
 1.87048  5.19649 ]


## Noise Budget

In [21]:
import math
Scale = server.paramsLiteral.scale
LogSlots = server.paramsLiteral.logSlots

In [22]:
noise_budget = -math.log2(Scale)+(LogSlots)+8
print("Noise Budget:", noise_budget)

# Check noise level for the first 10 elements
for i in range(10):
# for i in range(len(dec_ct5)):
    delta = dec_ct4[i] - np.round(dec_ct4[i], 2)
    noise = math.log2(np.abs(delta))
    if noise <= noise_budget:
        print(i, "Pass.", 'Noise:', noise)
    else:
        print(i, "Fail.", 'Noise:', noise)

Noise Budget: -31.0
0 Fail. Noise: -7.682682834429518
1 Fail. Noise: -13.241110206729507
2 Fail. Noise: -10.040988064462784
3 Fail. Noise: -9.816959364762223
4 Fail. Noise: -7.726767798591919
5 Fail. Noise: -8.356439185313935
6 Fail. Noise: -7.917955622802751
7 Fail. Noise: -13.999685491640657
8 Fail. Noise: -7.810009240763777
9 Fail. Noise: -7.833778463875343
