In [None]:
import multiprocessing

import numpy as np
import onnxruntime as rt
import torch
import torch.nn as nn
import torch.nn.functional as F
import wfdb
import scipy as sp

# The CrypTen section further down relies on fork-based multiprocessing.
# force=True keeps this cell re-runnable: without it, a second execution in the
# same kernel raises RuntimeError("context has already been set").
multiprocessing.set_start_method("fork", force=True)

# Load data
# wfdb.rdsamp returns a tuple; element 0 is used as the signal array below.
ecg = wfdb.rdsamp("../data/ECG/ath_001")
# Resample along axis 0 to exactly 1000 points.
# NOTE(review): assumes axis 0 is time and axis 1 holds the 12 leads -- confirm.
ecg_resampled = sp.signal.resample(ecg[0], 1000, axis=0)
# Transpose, cast to float32 and flatten row-major into one (1, n) row.
# reshape(1, -1) already yields the leading batch axis, so no expand_dims is needed.
x = np.transpose(ecg_resampled).astype(np.float32).reshape(1, -1)

In [None]:
import warnings
# Suppress any warning whose message starts with the text below (`message` is a
# regex matched against the beginning of the warning text).
# NOTE(review): presumably emitted when a read-only NumPy array is converted to
# a torch tensor somewhere downstream -- confirm which call triggers it.
warnings.filterwarnings("ignore", message="The given NumPy array is not writable")

In [None]:
# Benchmarks
# Per-inference timings, all expressed in microseconds.
# NOTE(review): values look hand-copied from timing runs -- refresh after re-running.
pt_us = 114                      # 0.11 ms
crypten_us = 332 * 1e3           # 332 ms
concrete_us = 33.3 * 1e6         # 33.3 s
tenseal_us = (3*60 + 56) * 1e6   # 3 min 56 s


def _slowdown(elapsed_us):
    """Return how many times larger ``elapsed_us`` is than ``pt_us``, truncated to int."""
    return int(elapsed_us/pt_us)


print(f"CrypTen  is {_slowdown(crypten_us):,} times slower")
print(f"Concrete is {_slowdown(concrete_us):,} times slower")
print(f"TenSEAL  is {_slowdown(tenseal_us):,} times slower")

# Plaintext models

In [None]:
# 3.1M params
class MLP(nn.Module):
    """Three-layer fully connected network: 12000 -> 256 -> 128 -> 71.

    ReLU is applied after the first two linear layers; the last layer's
    output is returned as-is (no final activation).
    """

    def __init__(self):
        super().__init__()
        # Layers are created in forward order (this also fixes the order in
        # which random initial weights are drawn).
        self.fc1 = nn.Linear(12000, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 71)

    def forward(self, x):
        """Map a ``(batch, 12000)`` tensor to a ``(batch, 71)`` tensor."""
        hidden = F.relu(self.fc1(x))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
# 4.6M params (with the default arguments)
class ConvNet(nn.Module):
    """Small 1-D conv net with square activations.

    Pipeline: pointwise ``Conv1d`` -> square -> flatten -> ``Linear`` -> square
    -> ``Linear``. Only additions and multiplications are used in ``forward``.

    Args:
        hidden: Width of the hidden fully connected layer.
        output: Number of output units.
        in_channels: Number of input channels (default 12).
        conv_channels: Number of filters in the pointwise convolution (default 36).
        seq_len: Length of each input sequence (default 1000).

    The defaults reproduce the original fixed architecture for inputs of shape
    ``(n, 12, 1000)``.
    """

    def __init__(self, hidden=128, output=71, in_channels=12, conv_channels=36, seq_len=1000):
        super(ConvNet, self).__init__()
        # Input: (n, in_channels, seq_len)
        self.conv1 = nn.Conv1d(in_channels, conv_channels, kernel_size=1, bias=False)

        # kernel_size=1 with default stride/padding leaves the length unchanged,
        # so after conv1 the tensor is (n, conv_channels, seq_len).
        self.fc1 = nn.Linear(conv_channels * seq_len, hidden)
        self.fc2 = nn.Linear(hidden, output)

    def forward(self, x):
        """Map ``(n, in_channels, seq_len)`` to ``(n, output)``."""
        x = self.conv1(x)
        x = x * x  # square activation
        x = x.view(x.size(0), -1)  # flatten everything except the batch axis
        x = self.fc1(x)
        x = x * x  # square activation
        x = self.fc2(x)
        return x

## Export models

In [None]:
# Fixed seed so the randomly initialised weights (and the random example
# inputs) are the same on every run.
torch.manual_seed(0)


def _export_onnx(model, sample_input, path):
    """Export ``model`` to an ONNX file at ``path``.

    ``sample_input`` is the example tensor handed to ``torch.onnx.export``.
    The graph input/output are named "input"/"output" and axis 0 of both is
    declared dynamic under the name "batch_size".
    """
    torch.onnx.export(
        model,
        sample_input,
        path,
        export_params=True,
        input_names=["input"],
        output_names=["output"],
        opset_version=14,
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        },
        keep_initializers_as_inputs=False,
    )


mlp = MLP()
mlp_input = torch.randn(1, 12000)
# Saves the whole pickled module (not just a state_dict).
torch.save(mlp, "mlp.pt")
_export_onnx(mlp, mlp_input, "mlp.onnx")
# Second copy of the same export, written into the cryptflow directory.
_export_onnx(mlp, mlp_input, "../cryptflow/mlp.onnx")

convnet = ConvNet()
convnet_input = torch.randn(1, 12, 1000)
torch.save(convnet, "convnet.pt")
_export_onnx(convnet, convnet_input, "convnet.onnx")

## ONNX plaintext inference

In [None]:
# Load the exported MLP into ONNX Runtime and run it once on the ECG sample.
mlp_session = rt.InferenceSession("mlp.onnx")
# Only the "output" tensor is requested, so the returned list has one element.
(out_pt,) = mlp_session.run(["output"], {"input": x})
out_pt

In [None]:
%%timeit
# Timed statement: one ONNX Runtime forward pass of the exported MLP on `x`.
mlp_session.run(["output"], {"input": x})

## Concrete ML

In [None]:
import onnx
from concrete.ml.torch.compile import compile_onnx_model

model = onnx.load("mlp.onnx")
# Input set handed to the compiler. It is drawn from a seeded generator so the
# compiled model is the same on every run (the global NumPy RNG is never seeded
# in this notebook).
# NOTE(review): a single standard-normal sample may not reflect the range of the
# real ECG input -- consider using representative data here.
input_set = np.random.default_rng(0).normal(size=(1, 12000))

cml_model = compile_onnx_model(
    model,
    input_set,
    n_bits=6,
    rounding_threshold_bits={"n_bits": 6, "method": "approximate"},
)

In [None]:
# TODO: Need to encrypt input?
# NOTE(review): presumably `fhe="execute"` makes Concrete ML quantize, encrypt,
# run and decrypt internally, so `x` can be passed in the clear -- confirm
# against the Concrete ML documentation.
out_cml = cml_model.forward(x, fhe="execute")
out_cml

In [None]:
%%timeit
# Timed statement: one Concrete ML forward pass on `x` with fhe="execute".
cml_model.forward(x, fhe="execute")

## TenSEAL

In [None]:
import tenseal as ts

class EncMLP:
    """Forward pass of a three-layer MLP written against the ``mm`` / ``+`` /
    ``square_`` interface of the (encrypted) input vector.

    The weights are read from ``model.state_dict()`` and stored transposed so
    that ``x.mm(weight)`` computes ``x @ W.T``. The activations are squares,
    whereas the plaintext ``MLP`` in this notebook uses ReLU, so the two
    networks do not compute the same function.
    """

    def __init__(self, model):
        # Read the state dict once and pull every tensor out of it.
        params = model.state_dict()
        self.fc1_weight = params["fc1.weight"].t()
        self.fc1_bias = params["fc1.bias"]
        self.fc2_weight = params["fc2.weight"].t()
        self.fc2_bias = params["fc2.bias"]
        self.fc3_weight = params["fc3.weight"].t()
        self.fc3_bias = params["fc3.bias"]

    def forward(self, x):
        """Return fc3(square(fc2(square(fc1(x))))) for the given vector ``x``."""
        hidden = x.mm(self.fc1_weight) + self.fc1_bias
        hidden.square_()  # in place, on the freshly computed intermediate
        hidden = hidden.mm(self.fc2_weight) + self.fc2_bias
        hidden.square_()
        return hidden.mm(self.fc3_weight) + self.fc3_bias


# Encryption parameters
# TODO: Figure out how to set these
# NOTE(review): CKKS packs up to poly_modulus_degree / 2 values per ciphertext,
# so 32768 looks like the smallest power of two that fits the 12000-value input
# in one vector -- confirm. EncMLP.forward chains five multiplications
# (3 x mm, 2 x square); the six `bits_scale` primes presumably provide the
# levels for that -- confirm.
bits_scale = 26
context = ts.context(
    ts.SCHEME_TYPE.CKKS,
    poly_modulus_degree=32768,
    coeff_mod_bit_sizes=[31] + [bits_scale] * 6 + [31],
)
context.global_scale = pow(2, bits_scale)
context.generate_galois_keys()  # Required to do ciphertext rotations

# Load model
# "mlp.pt" was written with torch.save(mlp, ...), i.e. it holds a whole pickled
# nn.Module rather than a state_dict. PyTorch versions whose torch.load defaults
# to weights_only=True refuse to unpickle that, so ask for full unpickling
# explicitly. This executes pickle code: only load files you created yourself.
model = torch.load("mlp.pt", weights_only=False)
ts_model = EncMLP(model)

# Encrypt input
# x has shape (1, n); x[0] is the flat vector that gets encrypted.
x_ts_enc = ts.ckks_vector(context, x[0])

In [None]:
# Run the forward pass on the encrypted vector, decrypt the result and shape
# it as a single (1, n) row.
out_ts_enc = ts_model.forward(x_ts_enc)
out_ts = np.asarray(out_ts_enc.decrypt()).reshape(1, -1)
out_ts

In [None]:
%%timeit
# Timed statement: one EncMLP forward pass on the already-encrypted input
# (encryption and decryption are outside the timed statement).
ts_model.forward(x_ts_enc)

## CrypTen

In [None]:
import crypten
import crypten.mpc as mpc
import torch
import torch.nn as nn
import torch.nn.functional as F

crypten.init()
# Disables OpenMP threads -- needed by @mpc.run_multiprocess which uses fork
torch.set_num_threads(1)


# Party ranks, used as `src` when encrypting the input and the model in run().
CLIENT = 0
SERVER = 1

crypten.common.serial.register_safe_class(MLP)

# Freshly initialised MLP written to disk so run() can load it.
model = MLP()
torch.save(model, "mlp_local.pt")

# run() loads the client input from "input.pth", but nothing else in this
# notebook creates that file, so a fresh Restart & Run All would fail there.
# Write the ECG sample `x` (shape (1, 12000), float32) to it here.
torch.save(torch.from_numpy(x), "input.pth")

In [None]:
@mpc.run_multiprocess(world_size=2)
def run():
    """Run one CrypTen inference of the MLP across two processes.

    Reads the model from "mlp_local.pt" and the input from "input.pth",
    encrypts the model with ``src=SERVER`` and the input with ``src=CLIENT``,
    evaluates the encrypted model on the encrypted input and calls
    ``get_plain_text()`` on the result. Nothing is returned.
    """
    template_model = MLP()
    # Example input passed to crypten.nn.from_pytorch; contents are uninitialised.
    trace_input = torch.empty((1, 12000))

    # Encrypt model
    # The file holds a whole pickled module, hence weights_only=False; its
    # weights are copied into the freshly constructed MLP before conversion.
    saved_model = torch.load("mlp_local.pt", weights_only=False)
    template_model.load_state_dict(saved_model.state_dict())
    enc_model = crypten.nn.from_pytorch(template_model, trace_input)
    enc_model.encrypt(src=SERVER)

    # Encrypt data
    client_input = torch.load("input.pth")
    client_input_enc = crypten.cryptensor(client_input, src=CLIENT)

    # Encrypted inference
    enc_model.eval()
    result_enc = enc_model(client_input_enc)
    result = result_enc.get_plain_text()
    # crypten.print(f"Output: {result}")

In [None]:
%%timeit
# Timed statement: one full call of the decorated run(), declared with
# world_size=2, so everything run() does (loading, encrypting, inference) is
# inside the measurement.
run()