In [1]:
import sys
import os

# Directory that contains HE_functions.py. Set the HE_FUNCTIONS_DIR environment
# variable to point elsewhere instead of editing this machine-specific default.
HE_FUNCTIONS_DIR = os.path.abspath(
    os.environ.get("HE_FUNCTIONS_DIR", "/home/theo_ubuntu/Diplomatic_incident/HE/TenSEAL/My_HE_FHE/")
)
# Guard so re-running this cell does not append the same directory to sys.path again.
if HE_FUNCTIONS_DIR not in sys.path:
    sys.path.append(HE_FUNCTIONS_DIR)

from HE_functions import Ckks_init, Encrypt_model, Dencrypt_model

In [2]:
import csv

def load_csv_as_dicts(file_path):
    """
    Load data from a CSV file into a list of dictionaries.

    The first row is the header. Each cell is converted to ``int`` if
    possible, otherwise to ``float``, otherwise kept as the original string.
    Non-string values produced by ``csv.DictReader`` for ragged rows (``None``
    for missing trailing fields, a list under the ``None`` key for surplus
    fields) are kept unchanged instead of crashing the load.

    Args:
        file_path (str): Path to the CSV file.

    Returns:
        list: One dict per data row, mapping column name to parsed value.
    """

    def _parse(value):
        # Only strings can be parsed; leave None / list values from ragged rows alone
        # (int(None) would raise TypeError, which the ValueError handlers don't catch).
        if not isinstance(value, str):
            return value
        try:
            return int(value)
        except ValueError:
            pass
        try:
            return float(value)
        except ValueError:
            return value

    # newline="" is what the csv module documents for files it reads, so quoted
    # fields containing line breaks are handled correctly.
    with open(file_path, "r", newline="") as csvfile:
        reader = csv.DictReader(csvfile)
        return [{key: _parse(value) for key, value in row.items()} for row in reader]

# Each row is one CKKS parameter set: 'degree' (polynomial modulus degree, parsed
# to int) and 'coeff' (bit sizes kept as a string such as "[20, 20, 14]").
VALID_CONFIG_PATH = "valid_config_2.csv"
valid_config = load_csv_as_dicts(VALID_CONFIG_PATH)

# Preview the last configuration.
valid_config[-1]

{'degree': 2048, 'coeff': '[20, 20, 14]'}

In [3]:
from typing import List, Tuple
import numpy as np

def batch_parameters(params: List[np.ndarray], slot_count: int) -> List[List[np.ndarray]]:
    """Split a flat parameter sequence into consecutive chunks of at most ``slot_count``.

    In this notebook ``slot_count`` is ``degree // 2``, so each chunk fits one
    CKKS vector. The final chunk holds the remainder and may be shorter.

    Args:
        params: Flat, sliceable sequence of scalars (a list or 1-D array;
            presumably the first element returned by ``flatten_params`` — confirm).
        slot_count: Maximum number of values per chunk; must be >= 1.

    Returns:
        List of consecutive slices of ``params`` covering every element in order.

    Raises:
        ValueError: If ``slot_count`` < 1. Without this check a negative value
            makes ``range`` step backwards and silently returns ``[]``,
            dropping every parameter.
    """
    if slot_count < 1:
        raise ValueError(f"slot_count must be >= 1, got {slot_count}")
    return [params[start:start + slot_count] for start in range(0, len(params), slot_count)]

In [4]:
import tenseal as ts
import math
from models import Net_Mnist, Net_Cifar
from model_utils import get_parameters, flatten_params, reshape_params
from model_utils import set_parameters, test_flatten_reshape

In [5]:
def convert_size(size_bytes):
    """Format a byte count as a human-readable string with 1024-based units.

    Args:
        size_bytes: Number of bytes (int or float). Negative values keep their sign.

    Returns:
        str: ``"0B"`` for zero, otherwise ``"<value> <unit>"`` with the value
        rounded to 2 decimals, e.g. ``convert_size(1536) == "1.5 KB"``.
        Sizes beyond the largest unit are expressed in ``YB``.
    """
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    magnitude = abs(size_bytes)
    # Choose the unit by exact comparison instead of floor(log(x, 1024)). The log
    # form is negative for sizes below 1 (indexing from the end of size_name), fails
    # on negative sizes, overruns past "YB", and may be off by one from float rounding.
    i = 0
    while i < len(size_name) - 1 and magnitude >= 1024 ** (i + 1):
        i += 1
    p = math.pow(1024, i)
    s = round(size_bytes / p, 2)
    return "%s %s" % (s, size_name[i])


In [6]:
def _parse_coeff_modulus(coeff_text):
    """Parse a 'coeff' CSV cell such as "[20, 20, 14]" into a list of ints."""
    return [int(bits) for bits in coeff_text.strip("[]").split(",")]


def test_baching_HE(num_clients, global_model):
    """Run one encrypted aggregation round per CKKS config in ``valid_config``.

    For each config row this simulates ``num_clients`` clients. Each client
    flattens ``global_model``'s parameters, splits them into chunks of
    ``degree // 2`` values and encrypts each chunk with ``ts.ckks_vector``.
    The server then sums the clients' ciphertexts batch by batch, decrypts,
    divides by ``num_clients`` (same averaging as the single-round cell) and
    reshapes the result back to the model's parameter shapes. The result is
    checked with ``test_flatten_reshape``.

    NOTE(review): CKKS decryption is approximate; if ``test_flatten_reshape``
    compares values exactly it will need a tolerance — confirm.

    Args:
        num_clients: Number of simulated clients; must be >= 1.
        global_model: Model whose parameters every client encrypts. It is not
            modified; the decrypted average is loaded into a deep copy so every
            config starts from the same original parameters.

    Returns:
        list: One ``[degree, slot_count, coeff_modulus, sum_ciphertext_size]``
        entry per config, where ``sum_ciphertext_size`` is the total length of
        all clients' serialized ciphertexts.

    Raises:
        ValueError: If ``num_clients`` < 1.
    """
    import copy  # local import keeps this cell self-contained

    if num_clients < 1:
        raise ValueError(f"num_clients must be >= 1, got {num_clients}")

    statistics = []

    for config in valid_config:
        degree = config['degree']
        slot_count = degree // 2  # one CKKS vector holds degree/2 values
        coeff_modulus = _parse_coeff_modulus(config['coeff'])
        print(f"degree: {degree}, slot_count: {slot_count}, coeff_modulus: {coeff_modulus}")

        context = Ckks_init(degree, coeff_modulus)

        original_params = get_parameters(global_model)
        flattened = flatten_params(original_params)  # element 0: flat values, element 1: shapes
        flat_params, global_shapes = flattened[0], flattened[1]
        print(len(flat_params))

        ## =========== CLIENTS ==================== ##
        # Every client starts from the global model (local training would go here)
        # and encrypts ITS OWN batches.
        sum_ciphertext_size = 0
        clients_encrypted_batches = []
        for _ in range(num_clients):
            client_flat = flatten_params(get_parameters(global_model))[0]
            encrypted_batches = []
            for batch in batch_parameters(client_flat, slot_count):
                encrypted_batch = ts.ckks_vector(context, batch)
                encrypted_batches.append(encrypted_batch)
                sum_ciphertext_size += len(encrypted_batch.serialize())
            clients_encrypted_batches.append(encrypted_batches)
        print(len(clients_encrypted_batches))

        ## =========== SERVER ===================== ##
        # zip(*) groups the i-th ciphertext of every client; sum() adds them homomorphically.
        summed_batches = [sum(client_batches) for client_batches in zip(*clients_encrypted_batches)]
        print(len(summed_batches))
        averaged_params = np.concatenate(
            [np.array(summed.decrypt()) for summed in summed_batches], axis=0
        ) / num_clients
        print(len(averaged_params))

        reshaped_batch = reshape_params(averaged_params, global_shapes)
        # Load into a copy: writing CKKS-approximated values into global_model would
        # change what every following config encrypts.
        set_parameters(copy.deepcopy(global_model), reshaped_batch)
        test_flatten_reshape(original_params, reshaped_batch)

        print(f"Sum ciphertext size: {convert_size(sum_ciphertext_size)} ")

        statistics.append([degree, slot_count, coeff_modulus, sum_ciphertext_size])

    return statistics


In [7]:
# Experiment configuration.
num_clients = 10  # number of simulated clients whose parameters are encrypted and aggregated
global_model = Net_Cifar()

# The sweep encrypts every client's model once per row of valid_config, which can
# take a while; set this to True to run it instead of uncommenting code.
RUN_PARAM_SWEEP = False
if RUN_PARAM_SWEEP:
    statistics = test_baching_HE(num_clients, global_model)

In [10]:
# One encrypted FedAvg round using row `Selected_key` of valid_config (-1 = last row).
Selected_key = -1

degree = valid_config[Selected_key]['degree']
slot_count = degree // 2  # one CKKS vector holds degree/2 values
coeff_modulus = list(map(int, valid_config[Selected_key]['coeff'].strip("[]").split(",")))
print(f"degree: {degree}, slot_count: {slot_count}, coeff_modulus: {coeff_modulus}")

context = Ckks_init(degree,coeff_modulus)

original_params = get_parameters(global_model)
Global_shapes = flatten_params(original_params)[1]
print("reshaped_batch",Global_shapes)

## ============================== CLIENTS ================================== ##
# Each client flattens its parameters, splits them into slot_count-sized batches
# and encrypts every batch as one CKKS vector.
clients_encrypted_batches = []
for client_id in range(num_clients):
    params = get_parameters(global_model)

    ## training (placeholder: every client currently sends the unmodified global model)

    flat_params = flatten_params(params)[0]
    batched = batch_parameters(flat_params, slot_count)
    encrypted_batches = [ts.ckks_vector(context, batch) for batch in batched]
    clients_encrypted_batches.append(encrypted_batches)

print("num of clients encrypted_batches ",len(clients_encrypted_batches))
print("num of batches for a client",len(clients_encrypted_batches[0]))

# =============================== SERVER ======================================= #
# zip(*) groups the i-th ciphertext of every client; sum() adds them homomorphically,
# giving one encrypted sum per batch position.
sumed_encrypted_batches = [
    sum(batch) for batch in zip(*clients_encrypted_batches)
]
print("batches after sum",len(sumed_encrypted_batches))

# Decrypt, re-join the batches into one flat vector and divide by the client count (average).
decrypted_params = np.concatenate([batch.decrypt() for batch in sumed_encrypted_batches]) / num_clients
print("decrypted_batches",len(decrypted_params))

# Catch batching/padding mistakes before reshaping: exactly one value per model parameter.
expected_param_count = sum(int(np.prod(shape)) for shape in Global_shapes)
assert decrypted_params.size == expected_param_count, (decrypted_params.size, expected_param_count)

reshaped_batch = reshape_params(decrypted_params, Global_shapes)
print("reshaped_batch",flatten_params(reshaped_batch)[1])
set_parameters(global_model, reshaped_batch)

degree: 2048, slot_count: 1024, coeff_modulus: [20, 20, 14]
reshaped_batch [(6, 3, 5, 5), (6,), (16, 6, 5, 5), (16,), (120, 400), (120,), (84, 120), (84,), (10, 84), (10,)]
num of clients encrypted_batches  10
num of batches for a client 61
batches after sum 61
decrypted_batches 62006
reshaped_batch [(6, 3, 5, 5), (6,), (16, 6, 5, 5), (16,), (120, 400), (120,), (84, 120), (84,), (10, 84), (10,)]


In [9]:
# Decrypt the server's summed ciphertexts WITHOUT dividing by num_clients,
# i.e. the slot-wise sum of all clients' parameters.
decrypted_batches = np.concatenate(
    [summed.decrypt() for summed in sumed_encrypted_batches]
)
print("decrypted_batches", len(decrypted_batches))

decrypted_batches 62006
