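## Hash set membership demo

This notebook uses ezkl to prove that a private, hashed input belongs to a fixed public set. The model outputs the product of the differences between the (hashed) input and every set member. That product is zero exactly when the input is a member, so the verifier only needs to check that the public output is 0.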

In [2]:
# Install ezkl and onnx only when running inside Google Colab; elsewhere the
# notebook relies on a local installation. Only the import probe is guarded:
# a bare `except:` here used to swallow *any* error, including a failing
# `pip install`, which would then surface later as a confusing ImportError.
try:
    import google.colab  # noqa: F401  (presence check only)
    IN_COLAB = True
except ImportError:
    IN_COLAB = False

if IN_COLAB:
    import subprocess
    import sys

    # Use the kernel's own interpreter so packages land in the right env.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])


# here we create and (potentially train a model)

# make sure you have the dependencies required here already installed
from torch import nn
import ezkl
import os
import json
import torch


class MyModel(nn.Module):
    """Set-membership test expressed as a product of differences.

    ``forward`` returns ``prod_i (x - s_i)`` over the members ``s_i`` of a
    constant set. The product is zero when ``x`` equals one of the members
    and non-zero otherwise.
    """

    # Set used when no explicit ``values`` are given.
    # NOTE(review): 1.0 is absent from this set -- confirm that is intended.
    DEFAULT_SET = (0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)

    def __init__(self, values=None):
        """Build the model around a constant set.

        Args:
            values: optional iterable of numbers forming the set. Defaults to
                ``DEFAULT_SET``, the set this demo was written for. An empty
                set makes every input a non-member (the empty product is 1).
        """
        super(MyModel, self).__init__()
        if values is None:
            values = self.DEFAULT_SET
        # Stored as an nn.Parameter so it is exported with the model weights
        # (export_params=True). The notebook makes parameters public via
        # run_args.param_visibility, which fixes the set being checked.
        self.set = torch.nn.Parameter(torch.tensor([float(v) for v in values]))

    def forward(self, x):
        """Return the membership product for each row of ``x``.

        Assumes ``x`` is (batch, 1), as built in the export cell. Then
        ``x - self.set`` broadcasts to (batch, len(set)), and the product over
        dim 1 has shape (batch,). A row's output is 0 iff its value is in the set.
        """
        diff = x - self.set
        return torch.prod(diff, dim=1)


# Build the membership model. The set is fixed at construction time, so this
# demo skips the training step entirely.
circuit = MyModel()




In [3]:
# Every artifact path is built relative to ARTIFACT_DIR. The default "" keeps
# the bare filenames (os.path.join("", name) == name), i.e. files land in the
# current working directory. Set it to e.g. "artifacts" to keep the working
# directory clean; the directory is created on demand.
ARTIFACT_DIR = ""
if ARTIFACT_DIR:
    os.makedirs(ARTIFACT_DIR, exist_ok=True)

model_path = os.path.join(ARTIFACT_DIR, 'network.onnx')
compiled_model_path = os.path.join(ARTIFACT_DIR, 'network.compiled')
pk_path = os.path.join(ARTIFACT_DIR, 'test.pk')
vk_path = os.path.join(ARTIFACT_DIR, 'test.vk')
settings_path = os.path.join(ARTIFACT_DIR, 'settings.json')
srs_path = os.path.join(ARTIFACT_DIR, 'kzg.srs')
witness_path = os.path.join(ARTIFACT_DIR, 'witness.json')
data_path = os.path.join(ARTIFACT_DIR, 'input.json')

In [4]:
# Record the installed PyTorch version in the notebook output, since ONNX
# export behaviour can differ between releases.
print(f"{torch.__version__}")

2.0.1


In [5]:


# After training, export to onnx (network.onnx) and create a data file (input.json).
# hash(0) = 0x00000000, so this will be a member of the set
# NOTE(review): the hash claim above is the original author's. 0.0 is itself
# an element of MyModel's set -- confirm the hashed value for the ezkl version in use.
# One sample with one feature, i.e. shape (batch_size=1, 1), all zeros.
x = torch.zeros(1, 1, requires_grad=True)

# Flip the model into inference mode before tracing it for export.
circuit.eval()

# Export the model.
torch.onnx.export(
    circuit,                   # model being run
    x,                         # model input (or a tuple for multiple inputs)
    model_path,                # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=14,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                  'output': {0: 'batch_size'}},
)

# input_data holds one flattened list per model input (presumably the format
# ezkl.gen_witness expects -- confirm for the ezkl version in use).
data_array = x.detach().numpy().reshape([-1]).tolist()
data = dict(input_data=[data_array])

# Serialize data into file. The context manager closes (and flushes) the
# handle deterministically, before a later cell hands input.json to ezkl.
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)


verbose: False, log level: Level.ERROR



In [6]:
# Configure how ezkl treats each tensor of the circuit, then write settings.json.
run_args = ezkl.PyRunArgs()
# "hashed/private": per the author, the input is hashed, the hash is not
# visible to the verifier, and it is fed into the computational graph instead
run_args.input_visibility = "hashed/private"
# model parameters (i.e. MyModel's set) are public, which fixes the set that
# membership is checked against
run_args.param_visibility = "public"
# the output is public so the verifier can see it -- set membership fails if it is not = 0
run_args.output_visibility = "public"
# bind the symbolic "batch_size" axis (declared via dynamic_axes in the ONNX export) to 1
run_args.variables = [("batch_size", 1)]
# a very large rebase multiplier, intended so that the scale is never rebased
run_args.scale_rebase_multiplier = 1000
# logrows: presumably log2 of the circuit's row count (2**11 rows) -- confirm in ezkl docs
run_args.logrows = 11

#  this creates the following sequence of ops:
# 1. hash the input -> poseidon(x)
# 2. compute the set difference -> poseidon(x) - set
# 3. compute the product of the set difference -> prod(poseidon(x) - set)


# TODO: Dictionary outputs
# Write the circuit settings derived from the ONNX model and run_args to settings_path.
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True


In [7]:
# Compile the ONNX model together with settings.json into a circuit artifact.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

In [8]:
# Fetch the structured reference string (SRS) into srs_path, using the circuit
# settings in settings_path (presumably to choose the SRS size from logrows).
res = ezkl.get_srs(srs_path, settings_path)
# Every other pipeline step is checked. Make sure the SRS actually landed on
# disk before setup/prove/verify try to read it.
assert os.path.isfile(srs_path)

In [9]:
# Run the compiled circuit on input.json and write the resulting witness file.
res = ezkl.gen_witness(
    data_path,
    compiled_model_path,
    witness_path,
)
assert os.path.isfile(witness_path)

In [10]:

# Circuit setup: from the compiled circuit and the SRS, produce the verifying
# key (vk_path) and the proving key (pk_path).
res = ezkl.setup(compiled_model_path, vk_path, pk_path, srs_path)
assert res == True

# Both keys and the settings file must now be on disk for proving/verifying.
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

inputs: [Value { inner: Tensor { inner: [Value(Value { inner: Some(0x0000000000000000000000000000000000000000000000000000000000000000) })], dims: [1], scale: None, visibility: None }, dims: [1], scale: 1 }]
inputs: [Value { inner: Tensor { inner: [Value(Value { inner: Some(0x0000000000000000000000000000000000000000000000000000000000000000) })], dims: [1], scale: None, visibility: None }, dims: [1], scale: 1 }]


In [11]:
# Generate a proof for the witness using the proving key and the SRS.
# "evm" and "single" are passed to ezkl.prove unchanged (presumably the
# transcript type and the proving strategy -- confirm for the ezkl version in use).
proof_path = 'test.pf'

res = ezkl.prove(
    witness_path, compiled_model_path, pk_path, proof_path, srs_path,
    "evm", "single",
)

print(res)
assert os.path.isfile(proof_path)

inputs: [Value { inner: Tensor { inner: [Value(Value { inner: Some(0x0000000000000000000000000000000000000000000000000000000000000000) })], dims: [1], scale: None, visibility: None }, dims: [1], scale: 1 }]
{'instances': [[[0, 0, 0, 0]], [[0, 0, 0, 0]]], 'proof': '2133dee3b23dbe540ae13073c4d880ea1eac284160ac350ab6750684ae5c34e12db979956b4fe3dc1a1811ae65f4e23e2677042c860ccfb4400aef6c2637c5bd0425ce2f65456f05f3f1efed5f04c04a6354c1689f416661ec47734d68a984150f91e778876837f1a8fb1cb372ae78bff195e52dc576436274eab1c9820c15ca07b44d21ddfd4d11566f918b8f15549014a6b61edf1b3e9d7ae9de4654d3b5ad2a2501b5d5b9995d8a009d7c679ac563136ee4fdea1a697f0fe4bc6e97e20a6c05f74a2c659e4e01f4705992eb38b5320d43eceeca8aeea6c781927f5119564a1c863704e1d0e0e757b5b8cb72216b5db6ddbcd643d50bbf165b26b6f91069db208e7cb610a986fc1bc5da2f618070feefed6da97e0fda50f88168efec4033f62f0dd93f4984048b8654de0cd5ccac61c035b9bd6229fa723b035980d42e7f44101b51fc98b7de81ff580bdaf4bbb70e3057ceac997c0bdc61de5f2f54190e111d77145195139b6393e6f5dd23116e8

In [12]:
# Check the proof against the circuit settings, the verifying key and the SRS.
res = ezkl.verify(proof_path, settings_path, vk_path, srs_path)
assert res == True
print("verified")

verified
