Imports

In [9]:
import json
import re
import ezkl
import os
import onnxruntime as ort
import torchvision
import numpy as np

Constants

In [10]:
# Directory layout: ONNX shards are read from MODEL_DIR, the MNIST test set
# from DATA_DIR, and every generated artifact is written to PROOF_DIR.
MODEL_DIR = "./model"
DATA_DIR = "./data"
PROOF_DIR = "./proof"

# Shard files are looked up as f"{MODEL_DIR}/{MODEL_ID}_shard_<n>.onnx".
MODEL_ID = "model_2"

# Calibration target handed to ezkl.calibrate_settings.
# Either resources or accuracy
EZKL_OPTIMIZATION_GOAL = "resources"

# Run arguments shared by every ezkl.gen_settings call in this notebook.
PY_RUN_ARGS = ezkl.PyRunArgs()
PY_RUN_ARGS.param_visibility = "fixed"
PY_RUN_ARGS.input_visibility = "public"
PY_RUN_ARGS.output_visibility = "public"

Setup & Helper Functions

In [11]:
# Make sure the artifact directory exists before any proof step writes to it;
# exist_ok=True keeps this cell safe to re-run.
os.makedirs(PROOF_DIR, exist_ok=True)

def fetch_test_data():
    """Return the first image of the MNIST test split as a batched NumPy array.

    The dataset is read from ``DATA_DIR`` with ``download=False``, so it must
    already be present on disk.
    """
    to_tensor = torchvision.transforms.ToTensor()
    mnist_test = torchvision.datasets.MNIST(
        root=DATA_DIR,
        train=False,
        download=False,
        transform=to_tensor,
    )
    return fetch_first_image(mnist_test)

def fetch_first_image(data):
    """Return the image of the first ``(image, label)`` sample with a leading axis added.

    Args:
        data: Indexable collection whose first element is an ``(image, label)``
            pair; ``image`` must expose ``.numpy()``.

    Returns:
        The image as a NumPy array with a new axis of length 1 in front.
    """
    first_image, _label = data[0]
    return first_image.numpy()[np.newaxis, :]
    
    
def generate_example_model_output(model_path, data_path, input_data=None):
    """Run one ONNX shard and dump its input/output as a JSON data file.

    Args:
        model_path: Path of the ONNX model to execute with onnxruntime.
        data_path: Destination of the JSON file holding ``input_shapes``,
            ``input_data`` and ``output_data`` (the latter two flattened).
        input_data: Output of the previous shard. May be left as ``None`` for
            the first shard, in which case the first MNIST test image is used.

    Returns:
        The first output array of the model run; this is what has to be passed
        as ``input_data`` when calling this function for the next shard.
    """
    session = ort.InferenceSession(model_path)
    if input_data is None:
        input_data = fetch_test_data()

    # The model is fed through an input named 'input'.
    outputs = session.run(None, {'input': input_data})

    payload = {
        "input_shapes": [input_data.shape],
        "input_data": [input_data.reshape([-1]).tolist()],
        "output_data": [out.reshape([-1]).tolist() for out in outputs],
    }
    with open(data_path, 'w') as handle:
        json.dump(payload, handle)

    return outputs[0]
    
    
def get_number_of_shards(model_id, model_dir=None):
    """Return the highest shard index available for ``model_id``.

    Despite its name, this returns the highest index found rather than a
    count: with zero-based, contiguous shard files the number of shards is the
    returned value plus one.

    Args:
        model_id: Model identifier used as the file-name prefix.
        model_dir: Directory to scan. Defaults to the notebook-level
            ``MODEL_DIR`` when omitted.

    Returns:
        The largest ``<n>`` among files named exactly
        ``{model_id}_shard_<n>.onnx``, or ``-1`` if no such file exists.
    """
    if model_dir is None:
        model_dir = MODEL_DIR

    # Raw string avoids the invalid "\d" escape; re.escape keeps special
    # characters in model_id literal; the dot before "onnx" is escaped.
    pattern = re.compile(rf"{re.escape(model_id)}_shard_(\d+)\.onnx")

    # fullmatch rejects look-alikes such as "..._shard_3.onnx.bak".
    matches = (pattern.fullmatch(file_name) for file_name in os.listdir(model_dir))
    return max((int(m.group(1)) for m in matches if m), default=-1)
        
        
    

In [12]:
def generate_proof(shard_id, previous_shard_output):
    """Run the full ezkl pipeline for one model shard and produce its proof.

    Steps, in order: generate settings, calibrate them against example data,
    compile the circuit, fetch the SRS, set up the keys, generate the witness,
    and prove. All artifacts are written under ``PROOF_DIR`` with the prefix
    ``{MODEL_ID}_shard_{shard_id}``.

    Args:
        shard_id: Index of the shard to prove.
        previous_shard_output: Output of the preceding shard, used as this
            shard's input; ``None`` for the first shard.

    Returns:
        Tuple ``(shard_output, proof_path, settings_path, vk_path, srs_path)``
        where ``shard_output`` is this shard's output for the next shard.

    Raises:
        AssertionError: If a checked ezkl step does not return ``True`` or an
            expected artifact file is missing afterwards.
    """
    model_path = f"{MODEL_DIR}/{MODEL_ID}_shard_{shard_id}.onnx"
    settings_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_settings.json"
    data_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_data.json"
    compiled_model_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_network.compiled"
    srs_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_kzg.srs"
    vk_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_vk.key"
    pk_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_pk.key"
    witness_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_witness.json"
    proof_path = f"{PROOF_DIR}/{MODEL_ID}_shard_{shard_id}_proof.pf"

    # 1. Settings file for this shard.
    ok = ezkl.gen_settings(model=model_path, output=settings_path, py_run_args=PY_RUN_ARGS)
    assert ok == True

    # 2. Calibration. Running the shard writes the example data to data_path;
    #    only the shard's output has to be kept in memory for the next shard.
    shard_output = generate_example_model_output(model_path, data_path, previous_shard_output)
    ok = ezkl.calibrate_settings(
        data=data_path,
        model=model_path,
        settings=settings_path,
        target=EZKL_OPTIMIZATION_GOAL,
    )
    assert ok == True

    # 3. Compile the model into a circuit.
    ok = ezkl.compile_circuit(
        model=model_path,
        compiled_circuit=compiled_model_path,
        settings_path=settings_path,
    )
    assert ok == True

    # 4. (Down)load the SRS string.
    ok = ezkl.get_srs(settings_path=settings_path, logrows=None, srs_path=srs_path)
    assert ok == True
    assert os.path.isfile(srs_path)

    # 5. Key setup: both the verification and proving key must exist afterwards.
    ok = ezkl.setup(
        model=compiled_model_path,
        vk_path=vk_path,
        pk_path=pk_path,
        srs_path=srs_path,
    )
    assert ok == True
    for key_path in (vk_path, pk_path):
        assert os.path.isfile(key_path)

    # 6. Witness generation; success is judged by the output file only.
    ezkl.gen_witness(data=data_path, model=compiled_model_path, output=witness_path)
    assert os.path.isfile(witness_path)

    # 7. ZK-SNARK for the execution of the model; again checked via the file.
    ezkl.prove(
        witness=witness_path,
        model=compiled_model_path,
        pk_path=pk_path,
        proof_path=proof_path,
        srs_path=srs_path,
    )
    assert os.path.isfile(proof_path)

    # The shard's intermediate output is needed for the subsequent proof.
    return shard_output, proof_path, settings_path, vk_path, srs_path
    
    
def verify_proof(proof_path, settings_path, vk_path, srs_path):
    """Verify a generated proof with ezkl and fail loudly if it does not check out.

    Args:
        proof_path: Path of the proof file to verify.
        settings_path: Path of the settings file used for the proof.
        vk_path: Path of the verification key.
        srs_path: Path of the SRS file.

    Raises:
        AssertionError: If ``ezkl.verify`` does not return ``True``.
    """
    verified = ezkl.verify(
        proof_path=proof_path,
        settings_path=settings_path,
        vk_path=vk_path,
        srs_path=srs_path,
    )
    assert verified == True
    
    

Prove Shard(s)

In [13]:
# get_number_of_shards returns the highest zero-based shard index, so the
# actual number of shards is one more than that. Compute it once and do not
# touch it inside the loop, otherwise the summary below reports a wrong count.
num_shards = get_number_of_shards(MODEL_ID) + 1
previous_shard_output = None
for shard_id in range(num_shards):
    # Each shard is proven on the output of the shard before it.
    previous_shard_output, proof_path, settings_path, vk_path, srs_path = generate_proof(
        shard_id=shard_id,
        previous_shard_output=previous_shard_output
    )
    print(f"Proof of Shard {shard_id} generated at {proof_path}")

    # Added to ensure proofs are correct
    verify_proof(proof_path, settings_path, vk_path, srs_path)
    print(f"Proof of Shard {shard_id} verified")

print(f"Completed processing for {num_shards} shards")



 <------------- Numerical Fidelity Report (input_scale: 13, param_scale: 13, scale_input_multiplier: 10) ------------->

+------------------+--------------+----------------+-----------------+-----------------+------------------+----------------+---------------+---------------------+--------------------+------------------------+
| mean_error       | median_error | max_error      | min_error       | mean_abs_error  | median_abs_error | max_abs_error  | min_abs_error | mean_squared_error  | mean_percent_error | mean_abs_percent_error |
+------------------+--------------+----------------+-----------------+-----------------+------------------+----------------+---------------+---------------------+--------------------+------------------------+
| 0.00000012489642 | 0            | 0.000060796738 | -0.000058874488 | 0.0000044188305 | 0                | 0.000060796738 | 0             | 0.00000000017286579 | NaN                | NaN                    |
+------------------+--------------+------

Proof of Shard 0 generated at ./proof/model_2_shard_0_proof.pf
Proof of Shard 0 verified


Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 2 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 4 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 8 columns for non-linearity table.
Using 3 columns for non-linearity table.
Using 6 columns for non-linearity table.
Using 11 columns

Proof of Shard 1 generated at ./proof/model_2_shard_1_proof.pf
Proof of Shard 1 verified
