## Hash set membership demo

In [None]:
# Check whether the notebook runs in Google Colab. If so, install the
# dependencies it needs; otherwise rely on a local installation of ezkl,
# onnx and pytest.
try:
    import google.colab  # noqa: F401 -- importable only inside Colab
except ImportError:
    # Not running in Colab: rely on the local installation.
    pass
else:
    # Running in Colab: install the dependencies. Only the Colab probe is
    # guarded, so a failing `pip install` now surfaces as an error instead of
    # being silently swallowed by a bare `except:`.
    import subprocess
    import sys

    for package in ("ezkl", "onnx", "pytest"):
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])

import logging

# Configure the root logger with a detailed format and emit DEBUG-level records.
FORMAT = (
    '%(levelname)s %(name)s %(asctime)-15s '
    '%(filename)s:%(lineno)d %(message)s'
)
logging.basicConfig(format=FORMAT)
logging.root.setLevel(logging.DEBUG)

# here we create and (potentially train a model)

# make sure you have the dependencies required here already installed
from torch import nn
import ezkl
import os
import json
import torch

class MyModel(nn.Module):
    """Set-membership test model.

    ``forward`` multiplies together the differences ``x - y`` along dim 1, so
    the result vanishes when ``x`` matches one of the entries of ``y``. The
    set ``y`` is passed through unchanged as the second output.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y):
        # x broadcasts against every set element; one zero difference zeroes
        # the whole product.
        differences = x - y
        product = differences.prod(dim=1)
        return product, y


circuit = MyModel()

# Train the model as you like here (skipped for brevity)




In [None]:
# Artifact file names, relative to the current working directory.
# (A single-argument os.path.join is a no-op, so plain strings are used.)
model_path = 'network.onnx'
compiled_model_path = 'network.compiled'
pk_path = 'test.pk'
vk_path = 'test.vk'
settings_path = 'settings.json'

witness_path = 'witness.json'
data_path = 'input.json'

In [None]:
# Show which PyTorch build is installed (useful context for ONNX export issues).
print(f"{torch.__version__}")

In [None]:


# Honest example: x = 0.0, which is a member of the set below.
x = torch.zeros(1, 1, requires_grad=True)

# The set to test membership against, defined once and reused for both the
# export tensor and the hashed circuit input.
y_input = [0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
y = torch.tensor(y_input, requires_grad=True).reshape(1, 9)

# Convert each set element to a field element (second argument 7 -- presumably
# the fixed-point scale; confirm against ezkl docs) and Poseidon-hash it.
# Each felt is computed once, printed for inspection, then hashed.
result = []
for e in y_input:
    felt = ezkl.float_to_felt(e, 7)
    print(felt)
    result.append(ezkl.poseidon_hash([felt])[0])

# Flips the neural net into inference mode
circuit.eval()

# Export the model
torch.onnx.export(circuit,                   # model being run
                  (x, y),                    # model input (or a tuple for multiple inputs)
                  model_path,                # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=14,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

# Input 0 is the raw x value; input 1 is the list of hashed set elements.
data_array_x = x.detach().numpy().reshape([-1]).tolist()
data_array_y = result
print(data_array_y)

data = dict(input_data=[data_array_x, data_array_y])

print(data)

# Serialize data into file. `with` guarantees the file is flushed and closed
# before ezkl reads it.
with open(data_path, 'w') as f:
    json.dump(data, f)


In [None]:
run_args = ezkl.PyRunArgs()

# --- Inputs ----------------------------------------------------------------
# "hashed/private": the input is hashed, the hash is not revealed to the
# verifier, and it is fed into the computational graph instead. The trailing
# "/0" presumably selects which input is hashed -- confirm in the ezkl docs.
run_args.input_visibility = "hashed/private/0"
# The inputs are already field elements, so input/output range checks are off.
run_args.ignore_range_check_inputs_outputs = True

# --- Parameters and output ---------------------------------------------------
# Parameters are fixed, pinning the set we check membership against.
run_args.param_visibility = "fixed"
# The output is fixed as well: set membership fails if it is not = 0.
run_args.output_visibility = "fixed"

# --- Shape, scale, size --------------------------------------------------------
run_args.variables = [("batch_size", 1)]
# A multiplier this large means the scale is never rebased.
run_args.scale_rebase_multiplier = 1000
run_args.logrows = 11  # logrows (circuit size parameter)

# The resulting sequence of ops:
#   1. hash the input                     -> poseidon(x)
#   2. compute the set difference         -> poseidon(x) - set
#   3. multiply the differences together  -> prod(poseidon(x) - set)

# TODO: Dictionary outputs
res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)
assert res == True


In [None]:
# Compile the exported ONNX graph, using the generated settings, into a circuit file.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
assert res == True

In [None]:
# srs path
# Obtain the SRS (structured reference string) for the circuit described by
# settings_path. The returned value is stored in `res` but not checked here.
res = await ezkl.get_srs( settings_path)

In [None]:
# now generate the witness file
# Run the compiled circuit on the honest input (data_path) and write the
# resulting witness to witness_path; only the file's existence is checked.

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
# now generate a faulty input + witness file (x input not in the set)

data_path_faulty = os.path.join('input_faulty.json')

witness_path_faulty = os.path.join('witness_faulty.json')

x = torch.ones(1,*[1], requires_grad=True)
y = torch.tensor([0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], requires_grad=True)

y = y.unsqueeze(0)
y = y.reshape(1, 9)

data_array_x = ((x).detach().numpy()).reshape([-1]).tolist()
data_array_y = result
print(data_array_y)

data = dict(input_data = [data_array_x, data_array_y])

print(data)

    # Serialize data into file:
json.dump( data, open(data_path_faulty, 'w' ))

res = await ezkl.gen_witness(data_path_faulty, compiled_model_path, witness_path_faulty)
assert os.path.isfile(witness_path_faulty)

In [None]:
# now generate a truthy input + witness file (x input not in the set)
import random

# Generate a random integer between 1 and 8, inclusive
random_value = random.randint(1, 8)

data_path_truthy = os.path.join('input_truthy.json')

witness_path_truthy = os.path.join('witness_truthy.json')

set = [0.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]

x = torch.tensor([set[random_value]])
y = torch.tensor(set, requires_grad=True)

y = y.unsqueeze(0)
y = y.reshape(1, 9)

x = x.unsqueeze(0)
x = x.reshape(1,1)

data_array_x = ((x).detach().numpy()).reshape([-1]).tolist()
data_array_y = result
print(data_array_y)

data = dict(input_data = [data_array_x, data_array_y])

print(data)

# Serialize data into file:
json.dump( data, open(data_path_truthy, 'w' ))

res = await ezkl.gen_witness(data_path_truthy, compiled_model_path, witness_path_truthy)
assert os.path.isfile(witness_path_truthy)

In [None]:
# Load the honest witness and display it (last expression of the cell).
with open(witness_path, "r") as f:
    witness = json.load(f)
witness

In [None]:

# HERE WE SET UP THE CIRCUIT PARAMS
# WE GOT KEYS
# WE GOT CIRCUIT PARAMETERS
# EVERYTHING ANYONE HAS EVER NEEDED FOR ZK

# We force the output to be 0; this corresponds to the set membership test
# being true, and the output visibility is "fixed". The output is therefore
# fixed and visible to the verifier, but if the input is not in the set the
# output will not be 0 and the verifier will reject.
with open(witness_path, "r") as f:
    witness = json.load(f)
witness["outputs"][0] = ["0000000000000000000000000000000000000000000000000000000000000000"]
# `with` guarantees the overwrite is flushed before the file is re-read below.
with open(witness_path, "w") as f:
    json.dump(witness, f)

# Re-read to confirm the override landed on disk.
with open(witness_path, "r") as f:
    witness = json.load(f)
print(witness["outputs"][0])

res = ezkl.setup(
    compiled_model_path,
    vk_path,
    pk_path,
    witness_path=witness_path,
)

assert res == True
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

In [None]:
# Generate a proof for the honest witness (x = 0.0, a set member).
proof_path = 'test.pf'

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, "single")
print(res)
assert os.path.isfile(proof_path)

In [None]:
# Generate a proof for the faulty witness (x = 1.0, not a set member).
# Proving itself is expected to produce a file; rejection happens at verify time.
proof_path_faulty = 'test_faulty.pf'

res = ezkl.prove(witness_path_faulty, compiled_model_path, pk_path, proof_path_faulty, "single")
print(res)
assert os.path.isfile(proof_path_faulty)

In [None]:
# Generate a proof for the truthy witness (a randomly chosen set member).
proof_path_truthy = 'test_truthy.pf'

res = ezkl.prove(witness_path_truthy, compiled_model_path, pk_path, proof_path_truthy, "single")
print(res)
assert os.path.isfile(proof_path_truthy)

In [None]:
# VERIFY IT
# Both the honest proof and the truthy proof must be accepted, checked in order.
for proof_to_check in (proof_path, proof_path_truthy):
    res = ezkl.verify(proof_to_check, settings_path, vk_path)
    assert res == True

In [None]:
import pytest


def test_verification():
    """Verifying the faulty proof (x not in the set) must raise a halo2 error."""
    expected_error = r'Failed to run verify: \[halo2\] The constraint system is not satisfied'
    with pytest.raises(RuntimeError, match=expected_error):
        ezkl.verify(proof_path_faulty, settings_path, vk_path)


# Run the test function
test_verification()