Imports

In [61]:
import json
import os
import subprocess

import ezkl
import numpy as np
import onnxruntime as ort
import torchvision

Configuration

In [62]:
# Filesystem layout: ONNX model, MNIST data, and every generated ezkl artifact.
MODEL_DIR = "./model"
DATA_DIR = "./data"
PROOF_DIR = "./proof"
model_id = "model"

# Optimization target handed to ezkl.calibrate_settings: "resources" or "accuracy".
EZKL_OPTIMIZATION_GOAL = "resources"

# Run arguments passed to ezkl.gen_settings: model inputs and outputs are
# public, model parameters use the "fixed" visibility.
py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public"
py_run_args.output_visibility = "public"
py_run_args.param_visibility = "fixed"

In [63]:
# The exported model is read from MODEL_DIR.
model_path = f"{MODEL_DIR}/{model_id}.onnx"

# Every ezkl artifact is written to PROOF_DIR and named "<model_id>_<suffix>".
artifact_prefix = f"{PROOF_DIR}/{model_id}"
settings_path = f"{artifact_prefix}_settings.json"
data_path = f"{artifact_prefix}_data.json"
circuit_path = f"{artifact_prefix}_circuit.ezkl"
compiled_model_path = f"{artifact_prefix}_network.compiled"
srs_path = f"{artifact_prefix}_kzg.srs"
vk_path = f"{artifact_prefix}_vk.key"
pk_path = f"{artifact_prefix}_pk.key"
witness_path = f"{artifact_prefix}_witness.json"
proof_path = f"{artifact_prefix}_proof.pf"

Setup & Helper Functions

In [64]:
# Create the artifact directory up front; an existing directory is not an error.
os.makedirs(name=PROOF_DIR, exist_ok=True)

In [65]:
def fetch_test_data(download=False):
    """Load the MNIST test split from DATA_DIR and return its first image.

    Args:
        download: forwarded to ``torchvision.datasets.MNIST``. The default
            (False) keeps the original behavior: the dataset must already be
            under DATA_DIR. Pass True to let torchvision fetch it, e.g. when
            running the notebook on a fresh machine.

    Returns:
        The first test image as a numpy array with a leading batch axis,
        as produced by ``fetch_first_image``.
    """
    test_data = torchvision.datasets.MNIST(
        root=DATA_DIR,
        train=False,
        download=download,
        transform=torchvision.transforms.ToTensor(),
    )
    return fetch_first_image(test_data)

def fetch_first_image(data):
    """Return the image of ``data[0]`` as a numpy array with a new leading axis.

    ``data[0]`` must unpack into an ``(image, label)`` pair whose image has a
    ``.numpy()`` method (e.g. a torch tensor). The label is discarded.
    """
    first_image, _label = data[0]
    return np.expand_dims(first_image.numpy(), axis=0)
    
    
def generate_example_model_output(model_path, data_path, input_name=None):
    """Run the ONNX model on one MNIST test image and save input/output as JSON.

    Args:
        model_path: path of the ONNX model executed with onnxruntime.
        data_path: destination JSON file. It receives the keys
            ``input_shapes``, ``input_data`` and ``output_data``.
        input_name: name of the graph input to feed. When None (default), the
            first input declared by the model is used, instead of assuming
            the exporter named it ``'input'``.

    Returns:
        ``data_path``, unchanged.
    """
    ort_session = ort.InferenceSession(model_path)
    if input_name is None:
        input_name = ort_session.get_inputs()[0].name
    input_data = fetch_test_data()
    output_data = ort_session.run(None, {input_name: input_data})
    payload = dict(
        input_shapes=[input_data.shape],
        input_data=[input_data.reshape([-1]).tolist()],
        # one flattened list per model output
        output_data=[o.reshape([-1]).tolist() for o in output_data],
    )
    with open(data_path, 'w') as f:
        json.dump(payload, f)
    return data_path

Generate Proof

In [66]:
# Generate the circuit settings file from the ONNX model using py_run_args.
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
assert res, f"ezkl.gen_settings failed for {model_path}"

In [67]:
# Calibrate the settings file with a sample input/output pair.
# generate_example_model_output writes that pair to data_path and returns the path.
data_path = generate_example_model_output(model_path, data_path)
res = ezkl.calibrate_settings(data_path, model_path, settings_path, EZKL_OPTIMIZATION_GOAL)
assert res, "ezkl.calibrate_settings failed"

Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 7 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 11 columns

In [68]:
# Compile the model with the calibrated settings into compiled_model_path.
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res, "ezkl.compile_circuit failed"

In [70]:
# Fetch the SRS for these settings through the ezkl CLI. The author found that
# the Python binding call `ezkl.get_srs(srs_path, settings_path)` did not work.
# NOTE(review): newer ezkl releases may make get_srs async and/or put
# settings_path first; check the installed version before switching back.
#
# An argument list avoids shell interpolation. check=True raises
# CalledProcessError on a non-zero exit, so an SRS file left over from an
# earlier run cannot hide a failed download.
subprocess.run(
    ["ezkl", "get-srs", "--srs-path", srs_path, "-S", settings_path],
    check=True,
)
assert os.path.isfile(srs_path), f"SRS file not found at {srs_path}"

[1;34m[[0m[1;34m*[0m[1;34m][0m [[95m2024-03-13 20:41:44[0m, ezkl] - [1;37m
[1;37m | [0m 
[1;37m | [0m         ███████╗███████╗██╗  ██╗██╗
[1;37m | [0m         ██╔════╝╚══███╔╝██║ ██╔╝██║
[1;37m | [0m         █████╗    ███╔╝ █████╔╝ ██║
[1;37m | [0m         ██╔══╝   ███╔╝  ██╔═██╗ ██║
[1;37m | [0m         ███████╗███████╗██║  ██╗███████╗
[1;37m | [0m         ╚══════╝╚══════╝╚═╝  ╚═╝╚══════╝
[1;37m | [0m 
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m         Easy Zero Knowledge for the Laconic.
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m 
[1;37m | [0m         [0m
[*] [2024-03-13 20:41:44, ezkl::execute] - SRS already exists at that path
[*] [2024-03-13 20:41:44, ezkl::execute] - read 134217988 bytes from SRS file (vector of len = 134217988)
[… output truncated …]

In [71]:
# Generate the verifying key (vk_path) and proving key (pk_path) for the compiled circuit.
res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)
assert res, "ezkl.setup failed"
assert os.path.isfile(vk_path), f"verifying key not written to {vk_path}"
assert os.path.isfile(pk_path), f"proving key not written to {pk_path}"

In [57]:
# Generate the witness for the sample data written during calibration.
res = ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path), f"witness not written to {witness_path}"

In [59]:
# Produce the proof with the ezkl CLI rather than the Python binding
# `ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, srs_path, "single")`.
# check=True raises CalledProcessError on a non-zero exit, so a proof file left
# over from an earlier run cannot hide a failed prove step.
subprocess.run(
    [
        "ezkl", "prove",
        "-M", compiled_model_path,
        "--witness", witness_path,
        f"--pk-path={pk_path}",
        f"--proof-path={proof_path}",
        f"--srs-path={srs_path}",
    ],
    check=True,
)
assert os.path.isfile(proof_path), f"proof not written to {proof_path}"

[1;34m[[0m[1;34m*[0m[1;34m][0m [[95m2024-03-13 20:14:21[0m, ezkl] - [1;37m
[1;37m | [0m 
[1;37m | [0m         ███████╗███████╗██╗  ██╗██╗
[1;37m | [0m         ██╔════╝╚══███╔╝██║ ██╔╝██║
[1;37m | [0m         █████╗    ███╔╝ █████╔╝ ██║
[1;37m | [0m         ██╔══╝   ███╔╝  ██╔═██╗ ██║
[1;37m | [0m         ███████╗███████╗██║  ██╗███████╗
[1;37m | [0m         ╚══════╝╚══════╝╚═╝  ╚═╝╚══════╝
[1;37m | [0m 
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m         Easy Zero Knowledge for the Lyrical.
[1;37m | [0m         -----------------------------------------------------------
[1;37m | [0m 
[1;37m | [0m         [0m
[*] [2024-03-13 20:14:21, ezkl::pfsys::srs] - loading srs from "./proof/model_kzg.srs"
[*] [2024-03-13 20:14:21, ezkl::execute] - downsizing params to 20 logrows
[… output truncated …]

Verification

In [60]:
# Verify the proof against the settings, verifying key and SRS.
res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)
assert res, f"ezkl.verify rejected the proof at {proof_path}"