In [2]:
from phe.paillier import generate_paillier_keypair

In [3]:
import multiprocessing
import time

import numpy as np

from linkefl.crypto import FastPaillier, Paillier
from linkefl.crypto.paillier import encode, fast_cipher_matmul

# shape = 10
# enc_mat_shape = (shape, shape)
# plain_mat_shape = (shape, shape)
enc_mat_shape = (10, 10)
plain_mat_shape = (10, 10)
precision = 0.001
np.random.seed(0)

crypto = Paillier()
enc_matrix = np.random.rand(*enc_mat_shape) * 2 - 1

enc_matrix = np.array(crypto.encrypt_vector(enc_matrix.flatten())).reshape(
    enc_mat_shape
)
plain_matrix = np.random.rand(*plain_mat_shape) * 2 - 1
encode_matrix, encode_mappings = encode(
    plain_matrix, crypto.pub_key, precision=precision
)

start_time = time.time()
res = np.matmul(enc_matrix, plain_matrix)
print(f"plain matmul: {time.time() - start_time}")


plain matmul: 0.12432622909545898


In [4]:
from linkefl.modelzoo import *
from torchinfo import summary

In [5]:
resnet18 = ResNet18(in_channel=3, num_classes=10)
summary(resnet18, input_size=(128, 3, 28, 28))

Layer (type:depth-idx)                   Output Shape              Param #
ResNet                                   [128, 10]                 --
├─Conv2d: 1-1                            [128, 64, 28, 28]         1,728
├─BatchNorm2d: 1-2                       [128, 64, 28, 28]         128
├─Sequential: 1-3                        [128, 64, 28, 28]         --
│    └─BasicBlock: 2-1                   [128, 64, 28, 28]         --
│    │    └─Conv2d: 3-1                  [128, 64, 28, 28]         36,864
│    │    └─BatchNorm2d: 3-2             [128, 64, 28, 28]         128
│    │    └─Conv2d: 3-3                  [128, 64, 28, 28]         36,864
│    │    └─BatchNorm2d: 3-4             [128, 64, 28, 28]         128
│    │    └─Sequential: 3-5              [128, 64, 28, 28]         --
│    └─BasicBlock: 2-2                   [128, 64, 28, 28]         --
│    │    └─Conv2d: 3-6                  [128, 64, 28, 28]         36,864
│    │    └─BatchNorm2d: 3-7             [128, 64, 28, 28]         

In [6]:
## SPNN
a_shape = (128, 64, 28, 28)
b_shape = (128, 64, 28, 28)

start = time.time()
data_a = np.random.rand(128)
data_b = np.random.rand(128)
enc_data_a = np.array(crypto.encrypt_vector(data_a))
enc_data_b = np.array(crypto.encrypt_vector(data_b))
enc_res = enc_data_b + enc_data_a
res = crypto.decrypt_vector(enc_res)
end_time = time.time() 
elapsed_time = (end_time - start_time) * 64 * 28 * 28
print(f"elapsed time: {elapsed_time}")

elapsed time: 65672.21411132812


In [8]:
## SFA
start = time.time()
embedding = np.random.rand(128, 10)
w_shape = (10, 10)
w = np.random.rand(*w_shape)
enc_w = np.array(crypto.encrypt_vector(w.flatten())).reshape(w_shape)
encode_matrix, encode_mappings = encode(
    embedding, crypto.pub_key, precision=precision
)

# start_time = time.time()
res = np.matmul(embedding, enc_w)
print(f"plain matmul: {time.time() - start}")


res_shape = (128, 10)
noise_data = np.random.rand(*res_shape)
noise_res = np.array(crypto.encrypt_vector(noise_data.flatten())).reshape(res_shape)
res = res - noise_res
res = crypto.decrypt_vector(res)
print(f"elapsed time: {time.time() - start}")

plain matmul: 1.6938319206237793


TypeError: Don't know the precision of type <class 'numpy.ndarray'>.

In [None]:
## BlindFL
conv = np.random.rand(3*3*3)
enc_conv = np.array(crypto.encrypt_vector(conv))
data = np.random.rand(3*3*3)
start = time.time()
for i in range(10):
    single_op = (data * enc_conv).sum()
    res = np.random.rand(1) - single_op
    res = crypto.decrypt(res)
single_time = (time.time() - start) / 10

all_time  = single_time * (28*28) * 64 * 4
print(all_time)

658.32744140625
