In [None]:
import numpy as np
import onnx
import onnxruntime as rt
import scipy as sp
import wfdb

In [None]:
# Diagnosis label strings.
# NOTE(review): presumably one label per classifier output, in output order -- confirm.
DIAGNOSES_LIST = [
    "1AVB", "2AVB", "3AVB", "ABQRS", "AFIB", "AFLT", "ALMI", "AMI",
    "ANEUR", "ASMI", "BIGU", "CLBBB", "CRBBB", "DIG", "EL", "HVOLT",
    "ILBBB", "ILMI", "IMI", "INJAL", "INJAS", "INJIL", "INJIN", "INJLA",
    "INVT", "IPLMI", "IPMI", "IRBBB", "ISCAL", "ISCAN", "ISCAS", "ISCIL",
    "ISCIN", "ISCLA", "ISC_", "IVCD", "LAFB", "LAO/LAE", "LMI", "LNGQT",
    "LOWT", "LPFB", "LPR", "LVH", "LVOLT", "NDT", "NORM", "NST_",
    "NT_", "PAC", "PACE", "PMI", "PRC(S)", "PSVT", "PVC", "QWAVE",
    "RAO/RAE", "RVH", "SARRH", "SBRAD", "SEHYP", "SR", "STACH", "STD_",
    "STE_", "SVARR", "SVTAC", "TAB_", "TRIGU", "VCLVH", "WPW",
]

# Read the ECG record; element 0 of the returned value is what gets resampled below.
ecg = wfdb.rdsamp("data/ECG/ath_001")

# Resample along axis 0 to exactly 1000 samples.
ecg_resampled = sp.signal.resample(ecg[0], num=1000, axis=0)

# Prepend a batch axis and cast to float32 before feeding the ONNX session.
X_test = ecg_resampled[np.newaxis, ...].astype(np.float32)
X_test.shape, X_test[:1]

## Plaintext ONNX

In [None]:
# Run the baseline classifier on the preprocessed ECG with ONNX Runtime.
session = rt.InferenceSession("data/ECG/ecg_classifier_base.onnx")

# Only the "dense" output is requested, so the returned list holds exactly one array.
(pred_plaintext,) = session.run(["dense"], {"input": X_test})
pred_plaintext

In [None]:
# Open the updated classifier in ONNX Runtime.
session = rt.InferenceSession("data/ECG/ecg_classifier.onnx")

# Run the ONNX checker on the same file.
model = onnx.load("data/ECG/ecg_classifier.onnx")
onnx.checker.check_model(model)

# Only the "dense" output is requested, so the returned list holds exactly one array.
(pred_plaintext_updated,) = session.run(["dense"], {"input": X_test})

# Largest element-wise absolute deviation from the baseline model's prediction.
largest_difference = np.abs(pred_plaintext - pred_plaintext_updated).max()
print(f"Largest difference: {largest_difference:.9f}")

## Concrete ML

In [None]:
from concrete.ml.torch.compile import compile_onnx_model

model = onnx.load("data/ECG/ecg_classifier.onnx")

# Draw the input set from a fixed-seed generator so that compilation sees the
# same data on every "Restart & Run All" (the global NumPy RNG is unseeded here).
# NOTE(review): a single uniform-noise sample in [-1, 1) is the only data handed
# to the compiler -- confirm it is representative of real ECG inputs.
input_set = np.random.default_rng(0).uniform(-1, 1, size=(1, 1000, 12))

# Compile the ONNX model with Concrete ML using 8-bit settings and the
# "approximate" rounding method.
fhe_model = compile_onnx_model(
    model,
    input_set,
    n_bits=8,
    rounding_threshold_bits={"n_bits": 8, "method": "approximate"},
)

# Benchmarks
Comparing inference times for a single ECG

## Plaintext

In [None]:
%%timeit
session = rt.InferenceSession("data/ECG/ecg_classifier.onnx")
pred = session.run(["dense"], {"input": X_test})[0]