In [None]:
import onnxruntime as rt
import numpy as np
from cryptography import x509
from cryptography.x509.oid import NameOID
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.primitives.ciphers.aead import AESGCM
import datetime
import os

In [None]:
# Prediction on original data: load the ONNX classifier and run it once on a
# synthetic input. Both the input and the result are kept as globals
# (`original_data`, `original_output`) for the later encryption round-trip.
MODEL_PATH = "benchmarks/data/ECG/ecg_classifier.onnx"
INPUT_SHAPE = (1, 1000, 12)  # presumably (batch, samples, leads) — TODO confirm against the model's input spec
SEED = 0  # fixed seed so Restart & Run All reproduces the same input and output

# Since ORT 1.9, builds with several execution providers (e.g. onnxruntime-gpu)
# require `providers` to be given explicitly; pin CPU for portable, repeatable runs.
session = rt.InferenceSession(MODEL_PATH, providers=["CPUExecutionProvider"])
input_name = session.get_inputs()[0].name

rng = np.random.default_rng(SEED)
# Draw float32 directly instead of drawing float64 and casting afterwards.
original_data = rng.standard_normal(INPUT_SHAPE, dtype=np.float32)
original_output = session.run(None, {input_name: original_data})

In [None]:
# Enclave setup: generate an RSA-2048 key pair and a short-lived self-signed
# certificate (subject == issuer == CN "enclave") for its public key.
# The locally generated public key stands in for the one an attestation
# document would provide.
CERT_VALIDITY_DAYS = 10

private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
public_key = private_key.public_key()

subject = issuer = x509.Name([
    x509.NameAttribute(NameOID.COMMON_NAME, "enclave")
])

# Take one timezone-aware timestamp so that both validity bounds are anchored
# to the same instant. `datetime.utcnow()` is deprecated since Python 3.12 and
# returns a naive datetime.
now = datetime.datetime.now(datetime.timezone.utc)

certificate = (
    x509.CertificateBuilder()
    .subject_name(subject)
    .issuer_name(issuer)
    .public_key(public_key)
    .serial_number(x509.random_serial_number())
    .not_valid_before(now)
    .not_valid_after(now + datetime.timedelta(days=CERT_VALIDITY_DAYS))
    .sign(private_key, hashes.SHA256())
)

In [None]:
# Local encryption (hybrid scheme): a fresh 256-bit AES-GCM data key encrypts
# the raw bytes of `original_data`, and that data key is itself wrapped with
# RSA-OAEP (SHA-256 for both the hash and MGF1) under `public_key`.
data_key = AESGCM.generate_key(bit_length=256)
nonce = os.urandom(12)  # 12-byte (96-bit) nonce, fresh for this single message

aesgcm = AESGCM(data_key)
encrypted_data = aesgcm.encrypt(nonce, original_data.tobytes(), None)  # no associated data

# NOTE: in a real deployment, wrap with the public key taken from the
# attestation document instead of the locally generated one.
data_key_enc = public_key.encrypt(
    data_key,
    padding.OAEP(
        label=None,
        algorithm=hashes.SHA256(),
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
    ),
)

# File layout, in order:
#   [4-byte big-endian length of wrapped key][wrapped key][12-byte nonce][AES-GCM output]
# The length prefix lets the reader split off the wrapped key without knowing the RSA size.
with open("encrypted_data.bin", "wb") as f:
    f.writelines([
        len(data_key_enc).to_bytes(4, "big"),
        data_key_enc,
        nonce,
        encrypted_data,
    ])

In [None]:
# Decrypt and run inference: parse encrypted_data.bin, unwrap the AES data key
# with the RSA private key, decrypt the payload and run the model on it.
# Expected layout:
#   [4-byte big-endian key length][wrapped key][12-byte nonce][AES-GCM output]
NONCE_SIZE = 12

with open("encrypted_data.bin", "rb") as f:
    # Validate every fixed-size read: a truncated file would otherwise surface
    # later as an opaque decryption error instead of a clear message.
    header = f.read(4)
    if len(header) != 4:
        raise ValueError("encrypted_data.bin is truncated: missing key-length header")
    key_length = int.from_bytes(header, "big")

    data_key_enc = f.read(key_length)
    if len(data_key_enc) != key_length:
        raise ValueError(
            f"encrypted_data.bin is truncated: expected {key_length} bytes of wrapped key, "
            f"got {len(data_key_enc)}"
        )

    nonce = f.read(NONCE_SIZE)
    if len(nonce) != NONCE_SIZE:
        raise ValueError("encrypted_data.bin is truncated: missing nonce")

    encrypted_data = f.read()

# Must use the same OAEP parameters that were used to wrap the key.
data_key = private_key.decrypt(
    data_key_enc,
    padding.OAEP(
        mgf=padding.MGF1(algorithm=hashes.SHA256()),
        algorithm=hashes.SHA256(),
        label=None
    )
)

aesgcm = AESGCM(data_key)
data = aesgcm.decrypt(nonce, encrypted_data, None)
# Assumes the plaintext is a float32 tensor of shape (1, 1000, 12), matching
# what was encrypted — TODO store shape/dtype in the file instead of hardcoding.
input_data = np.frombuffer(data, dtype=np.float32).reshape(1, 1000, 12)
output = session.run(None, {input_name: input_data})

# A correct decryption round-trip is bit-exact, so compare exactly rather than
# with np.allclose's tolerance.
print("Data integrity check:", np.array_equal(input_data, original_data))