In [1]:
!pip install torch torchvision numpy matplotlib psutil seaborn 

[0m

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import numpy as np
import time
from openfhe import *
import psutil
import logging

In [3]:
class CNN(nn.Module):
    """Small two-stage convolutional classifier for single-channel 28x28 inputs.

    The attribute names (conv1, conv2, pool, fc1, fc2) double as the
    state_dict keys used when ``cnn.pt`` is loaded with ``load_state_dict``,
    so they must not be renamed.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 3x3 "same" convolutions; forward() follows each with ReLU + 2x2 max-pool.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # Classifier head. 64 * 7 * 7 assumes a 28x28 input halved twice by pooling.
        self.fc1 = nn.Linear(in_features=64 * 7 * 7, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, 10) unnormalised logits."""
        # 1x28x28 -> 32x14x14 -> 64x7x7
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)  # (batch, 64*7*7)
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)

In [4]:
# Configure logging (ERROR and above only).
logging.basicConfig(level=logging.ERROR)
log = logging.getLogger(__name__)

# Fraction of parameter *tensors* to encrypt (e.g. 0.12 for 12%, 0.5 for 50%, 1.0 for 100%).
# encrypt_parameters applies it as int(len(params) * ratio), so it counts tensors,
# not individual weights.
ENCRYPTION_RATIO = 1.0

# Commented-out alternatives that used to live here (untrained torchvision models):
# vgg16, alexnet, resnet50, mobilenet_v2, resnet18. The checkpoint below only fits CNN.
model = CNN()
# cnn.pt is loaded as a state_dict (tensors only). weights_only=True refuses to
# unpickle arbitrary objects, so a tampered checkpoint cannot execute code on load.
model.load_state_dict(
    torch.load("./cnn.pt", map_location=torch.device("cpu"), weights_only=True)
)

params = list(model.parameters())
# Plaintext model size in bytes: element count times bytes per element (4 for float32).
raw_size = sum(p.numel() * p.element_size() for p in params)

def setup_ckks_context(ring_dim=1 << 12, mult_depth=4):
    """Build an OpenFHE CKKS crypto context with multiparty support enabled.

    Args:
        ring_dim: Ring dimension N; must be a positive power of two. The
            default (4096) reproduces the original hard-coded value.
        mult_depth: Multiplicative depth budget passed to
            ``SetMultiplicativeDepth`` (default 4, as before).

    Returns:
        A CryptoContext with the PKE, KEYSWITCH, LEVELEDSHE and MULTIPARTY
        features enabled.

    Raises:
        ValueError: if ``ring_dim`` is not a positive power of two.

    Warning:
        The security level is ``HEStd_NotSet``, so OpenFHE does not enforce a
        standard security level for the chosen ring dimension. With the default
        N = 4096 and this modulus chain, the context should be treated as
        insecure and used for benchmarking only.
    """
    if ring_dim <= 0 or ring_dim & (ring_dim - 1):
        raise ValueError(f"ring_dim must be a positive power of two, got {ring_dim}")

    parameters = CCParamsCKKSRNS()
    parameters.SetSecretKeyDist(SecretKeyDist.UNIFORM_TERNARY)
    parameters.SetSecurityLevel(SecurityLevel.HEStd_NotSet)
    parameters.SetRingDim(ring_dim)

    # Scaling technique and modulus sizes depend on the native integer width of the OpenFHE build.
    if get_native_int() == 128:
        rescale_tech, dcrt_bits, first_mod = ScalingTechnique.FIXEDAUTO, 78, 89
    else:
        rescale_tech, dcrt_bits, first_mod = ScalingTechnique.FLEXIBLEAUTO, 59, 60

    parameters.SetScalingModSize(dcrt_bits)
    parameters.SetScalingTechnique(rescale_tech)
    parameters.SetFirstModSize(first_mod)
    parameters.SetMultiplicativeDepth(mult_depth)

    cc = GenCryptoContext(parameters)
    for feature in (
        PKESchemeFeature.PKE,
        PKESchemeFeature.KEYSWITCH,
        PKESchemeFeature.LEVELEDSHE,
        PKESchemeFeature.MULTIPARTY,
    ):
        cc.Enable(feature)
    return cc


class MultiPartyHE:
    """CKKS wrapper that simulates a multi-party key setup in a single process.

    NOTE(review): all party keys live in this one object, and the joint key pair
    comes from ``MultipartyKeyGen(private_keys)``. OpenFHE documents that helper
    as debugging-only. The joint secret key is then used directly for
    EvalMultKeyGen and for decryption. That is acceptable for benchmarking, but
    it is not a real threshold protocol: no single party should ever hold the
    joint secret.
    """

    def __init__(self):
        self.cc = setup_ckks_context()
        self.keyPairs = []            # per-party key pairs; index 0 is the lead party
        self.joint_key_pair = None    # set by generate_keys()
        self.joint_public_key = None  # set by generate_keys()
        self.ring_dim = self.cc.GetRingDimension()
        # CKKS packs ring_dim / 2 values into one ciphertext.
        self.num_slots = self.ring_dim // 2

    def generate_keys(self, num_parties=2):
        """Create one key pair per party plus a joint key pair.

        Calling this again discards the previous keys instead of mixing them in.

        Args:
            num_parties: Number of simulated parties (>= 1).

        Returns:
            The joint public key, also stored on ``self.joint_public_key``.

        Raises:
            ValueError: if ``num_parties`` < 1.
        """
        if num_parties < 1:
            raise ValueError(f"num_parties must be >= 1, got {num_parties}")
        lead = self.cc.KeyGen()
        # Reset so repeated calls do not keep deriving from a stale lead key.
        self.keyPairs = [lead]
        for _ in range(1, num_parties):
            # Star topology: every extra party derives from the lead party's public key
            # (makeSparse=False, fresh=True).
            self.keyPairs.append(self.cc.MultipartyKeyGen(lead.publicKey, False, True))
        private_keys = [kp.secretKey for kp in self.keyPairs]
        self.joint_key_pair = self.cc.MultipartyKeyGen(private_keys)
        self.joint_public_key = self.joint_key_pair.publicKey
        self.cc.EvalMultKeyGen(self.joint_key_pair.secretKey)
        return self.joint_public_key

    def encrypt(self, data, level=1):
        """Encrypt a scalar or a 1-D batch of reals under the joint public key.

        Args:
            data: A list (used as-is), a tuple or NumPy array (flattened to
                floats), or a scalar (wrapped as a one-element list).
            level: Level passed to ``MakeCKKSPackedPlaintext`` (scale degree 1).

        Returns:
            The ciphertext produced by ``cc.Encrypt``.

        Raises:
            RuntimeError: if ``generate_keys()`` has not been called.
            ValueError: if more than ``num_slots`` values are supplied.
        """
        if self.joint_public_key is None:
            raise RuntimeError("generate_keys() must be called before encrypt()")
        if isinstance(data, list):
            values = data
        elif isinstance(data, (tuple, np.ndarray)):
            values = [float(v) for v in np.ravel(data)]
        else:
            values = [float(data)]
        if len(values) > self.num_slots:
            raise ValueError(
                f"cannot pack {len(values)} values into {self.num_slots} CKKS slots"
            )
        plaintext = self.cc.MakeCKKSPackedPlaintext(values, 1, level)
        plaintext.SetLength(len(values))
        return self.cc.Encrypt(self.joint_public_key, plaintext)

    def decrypt(self, ciphertext):
        """Decrypt ``ciphertext`` with the joint secret key and return the plaintext.

        Raises:
            RuntimeError: if ``generate_keys()`` has not been called.
        """
        if self.joint_key_pair is None:
            raise RuntimeError("generate_keys() must be called before decrypt()")
        return self.cc.Decrypt(ciphertext, self.joint_key_pair.secretKey)

    def add(self, cipher1, cipher2):
        """Homomorphic addition of two ciphertexts (``cc.EvalAdd``)."""
        return self.cc.EvalAdd(cipher1, cipher2)

    def multiply(self, cipher1, cipher2):
        """Homomorphic multiplication of two ciphertexts (``cc.EvalMult``)."""
        return self.cc.EvalMult(cipher1, cipher2)


def encrypt_parameters(params, encrypt_ratio=1.0, cpu_usage_list=None, he=None, dcrt_bits=59):
    """Encrypt the leading fraction of ``params`` chunk by chunk.

    The first ``int(len(params) * encrypt_ratio)`` tensors are flattened and
    encrypted in chunks of ``he.num_slots`` values. That count is clamped to
    ``[0, len(params)]``. The remaining tensors are kept as plain Python lists
    of numbers.

    Args:
        params: Sequence of torch tensors, e.g. ``list(model.parameters())``.
        encrypt_ratio: Fraction of *tensors* (not elements) to encrypt.
        cpu_usage_list: Optional list. One ``{"time", "usage"}`` sample is
            appended per encrypted chunk.
        he: Object providing ``num_slots``, ``ring_dim`` and ``encrypt(chunk)``.
            Defaults to the module-level ``mhe``.
        dcrt_bits: Scaling-modulus bit size used by the size estimate. The
            default 59 matches setup_ckks_context on 64-bit builds; that
            function uses 78 on 128-bit builds.

    Returns:
        ``(encrypted_params, encrypted_size, encrypted_part_size, raw_part_size)``.
        ``encrypted_params`` holds a list of ciphertexts for each encrypted
        tensor and a list of numbers for each raw tensor. ``raw_part_size`` is
        the byte size of the raw tensors. The two encrypted sizes are the same
        heuristic estimate (see the NOTE below) and are always equal.
    """
    if he is None:
        he = mhe  # module-level MultiPartyHE instance
    total_params = len(params)
    # Clamp so a negative or >1 ratio cannot make the index test misbehave.
    num_to_encrypt = max(0, min(total_params, int(total_params * encrypt_ratio)))
    max_size = he.num_slots
    # NOTE(review): ring_dim * dcrt_bits * 2 looks like a *bit* count for a single
    # RNS limb of the two ciphertext polynomials, yet the driver divides it by
    # 1024**2 and reports MB. It is kept as-is for comparability with earlier runs.
    # TODO: replace it with the measured serialized ciphertext size.
    ciphertext_size_approx = he.ring_dim * dcrt_bits * 2

    encrypted_params = []
    encrypted_part_size = 0
    raw_part_size = 0
    for i, param in enumerate(params):
        param_data = param.detach().cpu().numpy().flatten().tolist()
        if i < num_to_encrypt:
            encrypted_chunks = []
            for start in range(0, len(param_data), max_size):
                chunk = param_data[start:start + max_size]
                start_chunk_time = time.time()
                encrypted_chunks.append(he.encrypt(chunk))
                end_chunk_time = time.time()
                if cpu_usage_list is not None:
                    # cpu_percent() without an interval reports usage since the
                    # previous call (the very first call returns 0.0).
                    cpu_usage_list.append(
                        {
                            "time": (start_chunk_time + end_chunk_time) / 2,
                            "usage": psutil.cpu_percent(),
                        }
                    )
            encrypted_params.append(encrypted_chunks)
            encrypted_part_size += len(encrypted_chunks) * ciphertext_size_approx
        else:
            encrypted_params.append(param_data)
            raw_part_size += param.numel() * param.element_size()

    # Every encrypted chunk is costed identically, so the whole-model encrypted
    # size is exactly the encrypted-part size.
    encrypted_size = encrypted_part_size
    return encrypted_params, encrypted_size, encrypted_part_size, raw_part_size

def decrypt_parameters(encrypted_params, he=None):
    """Decrypt every ciphertext chunk in ``encrypted_params`` and time it.

    Encrypted entries are recognised by their content: lists of ciphertext
    objects, as opposed to the plain number lists used for raw tensors. The
    function does not recompute ``int(len(params) * ENCRYPTION_RATIO)`` from
    globals, so it stays correct even if encryption used a different ratio.
    The decrypted plaintexts are discarded.

    Args:
        encrypted_params: First element of the tuple returned by encrypt_parameters.
        he: Object providing ``decrypt(ciphertext)``. Defaults to the
            module-level ``mhe``.

    Returns:
        Elapsed time in seconds (float) spent decrypting.
    """
    if he is None:
        he = mhe  # module-level MultiPartyHE instance
    start = time.perf_counter()  # monotonic clock: correct for measuring durations
    for entry in encrypted_params:
        for chunk in entry:
            if isinstance(chunk, (int, float)):
                break  # raw (unencrypted) tensor: plain numbers, nothing to decrypt
            he.decrypt(chunk)
    return time.perf_counter() - start

# Driver: set up the simulated two-party context, encrypt the selected share of
# parameter tensors, decrypt them again, then report sizes and timings.
mhe = MultiPartyHE()
joint_public_key = mhe.generate_keys(num_parties=2)

cpu_usage_encryption = []

# Wall-clock timestamps are kept so they stay comparable with the per-chunk
# "time" samples stored in cpu_usage_encryption.
start_time = time.time()
(encrypted_params, encrypted_size, encrypted_part_size, raw_part_size) = encrypt_parameters(
    params,
    encrypt_ratio=ENCRYPTION_RATIO,
    cpu_usage_list=cpu_usage_encryption,
)
end_time = time.time()
partial_encryption_time = end_time - start_time

decryption_time = decrypt_parameters(encrypted_params)
total_execution_time = partial_encryption_time + decryption_time

# Byte counts -> MB (1 MB = 1024**2 bytes here).
raw_size_mb, encrypted_size_mb, encrypted_part_size_mb, raw_part_size_mb = (
    n_bytes / (1024**2)
    for n_bytes in (raw_size, encrypted_size, encrypted_part_size, raw_part_size)
)
partially_encrypted_total_mb = encrypted_part_size_mb + raw_part_size_mb

print(
    "\n".join(
        [
            f"Raw model size: {raw_size_mb:.2f} MB",
            f"Partial encryption time ({ENCRYPTION_RATIO*100:.0f}%): {partial_encryption_time:.2f} seconds",
            f"Encrypted model size : {encrypted_size_mb:.2f} MB",
            f"Partially encrypted model size ({ENCRYPTION_RATIO*100:.0f}% encrypted + {(1-ENCRYPTION_RATIO)*100:.0f}% raw) - Total: {partially_encrypted_total_mb:.2f} MB",
            f"  - Encrypted part size ({ENCRYPTION_RATIO*100:.0f}% of parameters): {encrypted_part_size_mb:.2f} MB",
            f"  - Raw part size ({(1-ENCRYPTION_RATIO)*100:.0f}% of parameters): {raw_part_size_mb:.2f} MB",
            f"Decryption time: {decryption_time:.2f} seconds",
            f"Total execution time: {total_execution_time:.2f} seconds",
        ]
    )
)

Raw model size: 1.61 MB
Partial encryption time (50%): 0.05 seconds
Encrypted model size : 5.53 MB
Partially encrypted model size (50% encrypted + 50% raw) - Total: 7.07 MB
  - Encrypted part size (50% of parameters): 5.53 MB
  - Raw part size (50% of parameters): 1.54 MB
Decryption time: 0.03 seconds
Total execution time: 0.08 seconds
