# Proof splitting

Here we show how to split a larger circuit into multiple smaller proofs. This is useful if you want to distribute proving across multiple machines, or if you want to reduce peak memory usage by proving each part separately.

We show how to do this in two cases:
- intermediate calculations can be public (i.e. they do not need to be kept secret), so we can stitch the circuits together using instances
- intermediate calculations need to be kept secret (but not blinded!), so we use the low-overhead KZG commitment scheme detailed [here](https://blog.ezkl.xyz/post/commits/) to stitch the circuits together.


First we install the dependencies (when running in Colab), import them, and optionally enable more informative logging.

In [None]:
# check if notebook is in colab
try:
    # this import only succeeds inside colab
    import google.colab
    import subprocess
    import sys
    # install ezkl and onnx into the colab runtime
    subprocess.check_call([sys.executable, "-m", "pip", "install", "ezkl"])
    subprocess.check_call([sys.executable, "-m", "pip", "install", "onnx"])

# rely on local installation of ezkl if the notebook is not in colab
except ImportError:
    pass

from torch import nn
import ezkl
import os
import json
import logging

# uncomment for more descriptive logging 
# FORMAT = '%(levelname)s %(name)s %(asctime)-15s %(filename)s:%(lineno)d %(message)s'
# logging.basicConfig(format=FORMAT)
# logging.getLogger().setLevel(logging.INFO)


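Now we define our model. It is a humble model with just two conv layers, each followed by a $ReLU$ non-linearity, but it is a model nonetheless. The `split_1` method runs only the first conv + $ReLU$ block; we use it to produce the input for the second split.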
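It helps to know the tensor shapes that will flow across the split. The sketch below applies PyTorch's documented `Conv2d` output-size formula (no padding, dilation 1) to the 8x8 input created below; `conv2d_out` is a hypothetical helper, not part of ezkl or torch.

```python
def conv2d_out(size: int, kernel: int, stride: int, padding: int = 0) -> int:
    """Spatial output size of a Conv2d layer (PyTorch's formula with dilation = 1)."""
    return (size + 2 * padding - kernel) // stride + 1

# the model input is (batch=1, channels=3, 8, 8); ReLU does not change shapes
h1 = conv2d_out(8, kernel=2, stride=4)   # after conv1 + ReLU (the split point)
h2 = conv2d_out(h1, kernel=2, stride=4)  # after conv2 + ReLU (the model output)
print((1, 1, h1, h1), (1, 1, h2, h2))   # → (1, 1, 2, 2) (1, 1, 1, 1)
```

So the second split receives only four intermediate values, which is what gets passed (publicly or via commitments) between the two proofs.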

In [None]:
import torch
# Defines the model
# we got convs, we got relu, 
# What else could one want ????

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=1, kernel_size=2, stride=4)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=1, kernel_size=2, stride=4)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)

        return x
    
    def split_1(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        return x


circuit = MyModel()

# this is where you'd train your model




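We omit training for the purposes of this demonstration; we've marked where training would happen in the cell above.
Now we export the model to ONNX and create the corresponding (randomly generated) input files: `input_0.json` for the first split and `input_1.json`, holding the intermediate activations, for the second split.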

You can replace the random `x` with real data if you so wish. 
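Each input file written in the next cell stores one flattened (row-major) list per model input under `input_data`, which is what `.reshape([-1]).tolist()` produces. A dependency-free sketch of the same format, assuming a small hypothetical tensor given as nested Python lists instead of a torch tensor:

```python
import json

def flatten(nested):
    """Flatten an arbitrarily nested list in row-major (C) order, like reshape([-1])."""
    if not isinstance(nested, list):
        return [nested]
    out = []
    for item in nested:
        out.extend(flatten(item))
    return out

# hypothetical 1x1x2x2 tensor written as nested lists
tensor = [[[[0.1, 0.2], [0.3, 0.4]]]]
payload = {"input_data": [flatten(tensor)]}
print(json.dumps(payload))  # → {"input_data": [[0.1, 0.2, 0.3, 0.4]]}
```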

In [None]:
x = torch.rand(1, *[3, 8, 8], requires_grad=True)

# Flips the neural net into inference mode
circuit.eval()

# Export the model
torch.onnx.export(circuit,                   # model being run
                  x,                         # model input (or a tuple for multiple inputs)
                  "network.onnx",            # where to save the model (can be a file or file-like object)
                  export_params=True,        # store the trained parameter weights inside the model file
                  opset_version=10,          # the ONNX version to export the model to
                  do_constant_folding=True,  # whether to execute constant folding for optimization
                  input_names=['input'],     # the model's input names
                  output_names=['output'],   # the model's output names
                  dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                'output': {0: 'batch_size'}})

# input for the first split: the flattened model input
data_path = os.path.join(os.getcwd(), "input_0.json")
data = dict(input_data=[x.detach().numpy().reshape([-1]).tolist()])
with open(data_path, 'w') as f:
    json.dump(data, f)

# input for the second split: the flattened output of conv1 + ReLU
inter_1 = circuit.split_1(x)
data_path = os.path.join(os.getcwd(), "input_1.json")
data = dict(input_data=[inter_1.detach().numpy().reshape([-1]).tolist()])
with open(data_path, 'w') as f:
    json.dump(data, f)



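Now we split the exported ONNX graph into two parts at the output of the first $ReLU$ (the tensor `/relu/Relu_output_0`). The first part contains the first conv layer and its $ReLU$; the second part contains the rest of the model.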
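`onnx.utils.extract_model` carves out the subgraph that lies between the given input and output tensor names. A simplified sketch of that idea on a hand-written toy graph: the node, weight, and tensor names are hypothetical (loosely modeled on what `torch.onnx.export` produces for `MyModel`), and the real utility also handles initializers, value info, and model checking.

```python
def extract_nodes(nodes, input_names, output_names):
    """Names of the nodes needed to compute output_names starting from input_names.

    nodes: list of (name, inputs, outputs) tuples in topological order.
    Walks backwards from the requested outputs and stops at the requested inputs.
    """
    producer = {out: node for node in nodes for out in node[2]}
    needed, stack = set(), list(output_names)
    while stack:
        tensor = stack.pop()
        if tensor in input_names or tensor not in producer:
            continue  # a cut point, or a graph input / weight with no producing node
        name, node_inputs, _ = producer[tensor]
        if name not in needed:
            needed.add(name)
            stack.extend(node_inputs)
    return [node[0] for node in nodes if node[0] in needed]

# hypothetical graph for MyModel; real exported names may differ
graph = [
    ("/conv1/Conv", ["input", "conv1.weight", "conv1.bias"], ["/conv1/Conv_output_0"]),
    ("/relu/Relu", ["/conv1/Conv_output_0"], ["/relu/Relu_output_0"]),
    ("/conv2/Conv", ["/relu/Relu_output_0", "conv2.weight", "conv2.bias"], ["/conv2/Conv_output_0"]),
    ("/relu_1/Relu", ["/conv2/Conv_output_0"], ["output"]),
]
print(extract_nodes(graph, ["input"], ["/relu/Relu_output_0"]))
# → ['/conv1/Conv', '/relu/Relu']
print(extract_nodes(graph, ["/relu/Relu_output_0"], ["output"]))
# → ['/conv2/Conv', '/relu_1/Relu']
```

Because both cuts use the same boundary tensor, the two subgraphs cover the full model with no overlap.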

In [None]:
import onnx

input_path = "network.onnx"
output_path = "network_split_0.onnx"
input_names = ["input"]
output_names = ["/relu/Relu_output_0"]
# first model
onnx.utils.extract_model(input_path, output_path, input_names, output_names)

In [None]:
import onnx

input_path = "network.onnx"
output_path = "network_split_1.onnx"
input_names = ["/relu/Relu_output_0"]
output_names = ["output"]
# second model
onnx.utils.extract_model(input_path, output_path, input_names, output_names)

### Public intermediate calculations

This is where the magic happens. We define our `PyRunArgs` object, which contains the visibility parameters for our model.
- `input_visibility` defines the visibility of the model inputs
- `param_visibility` defines the visibility of the model weights, constants, and other parameters
- `output_visibility` defines the visibility of the model outputs

There are currently 5 visibility settings:
- `public`: known to both the verifier and the prover (a subtle nuance is that this may not be the case for model parameters, but until we have more rigorous theoretical results we don't want to make strong claims about this).
- `private`: known only to the prover
- `hashed`: the hash pre-image is known to the prover, and both the prover and the verifier know the hash. The prover proves that they know the pre-image of the hash.
- `encrypted`: the non-encrypted element and the secret key used for decryption are known to the prover. The prover and the verifier know the encrypted element, the public key used to encrypt, and the hash of the decryption key. The prover proves that they know the pre-image of the hashed decryption key and that this key can in fact decrypt the encrypted message.
- `kzgcommit`: an unblinded advice column that generates a KZG commitment. This doesn't appear in the instances of the circuit and must instead be modified directly within the proof bytes.
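The relation behind `hashed` visibility can be illustrated outside a circuit. In this sketch SHA-256 is a stand-in chosen for convenience (ezkl hashes inside the circuit with Poseidon, so the real hash differs), and `check_preimage` is a hypothetical helper expressing the relation the proof attests to. In the real protocol the verifier never sees the pre-image; it only checks a proof that the prover knows one.

```python
import hashlib
import json

def commit_hash(values):
    """Hash a list of numbers; SHA-256 stands in for the in-circuit Poseidon hash."""
    return hashlib.sha256(json.dumps(values).encode()).hexdigest()

def check_preimage(claimed, expected_hash):
    """The relation a `hashed` proof attests to: claimed hashes to expected_hash."""
    return commit_hash(claimed) == expected_hash

# prover side: knows the pre-image (e.g. the model input) and publishes only its hash
secret_input = [3, 1, 4, 1]
public_hash = commit_hash(secret_input)

print(check_preimage(secret_input, public_hash))   # → True
print(check_preimage([3, 1, 4, 2], public_hash))   # → False
```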

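Here we create the following setup:
- `input_visibility`: "public"
- `param_visibility`: "fixed" (a parameter-only setting, not listed above, that bakes the weights into the circuit as constants)
- `output_visibility`: "public"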


In [None]:
import ezkl

srs_path = os.path.join('kzg.srs')

run_args = ezkl.PyRunArgs()
run_args.input_visibility = "public"
run_args.param_visibility = "fixed"
run_args.output_visibility = "public"
run_args.variables = [("batch_size", 1)]  # fix the dynamic batch axis to 1
run_args.input_scale = 1                   # fixed-point scale for the inputs
run_args.param_scale = 1                   # fixed-point scale for the weights
run_args.logrows = 13                      # the circuit has 2**logrows rows


Now we generate a settings file. This file instantiates the parameters that determine the circuit's shape, size, etc. Because of the way we represent nonlinearities in the circuit (using Halo2's [lookup tables](https://zcash.github.io/halo2/design/proving-system/lookup.html)), it is often best to _calibrate_ this settings file, as some data can fall out of range of these lookups.

You can pass a calibration dataset that is representative of the real inputs you might see if and when you deploy the prover. Here we simply reuse the randomly generated `input_0.json` and `input_1.json` files as dummy calibration data.
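To see why calibration matters, recall that values are encoded as fixed-point integers before they reach a lookup table. This sketch assumes the encoding `round(x * 2**scale)`, treating a scale as a power-of-two exponent as `input_scale` / `param_scale` are here, and a hypothetical lookup range of [-64, 64]; ezkl derives its actual ranges from the settings and the calibration data.

```python
def quantize(x: float, scale: int) -> int:
    """Fixed-point encode x with `scale` fractional bits: round(x * 2**scale)."""
    return round(x * 2 ** scale)

def out_of_range(values, scale, lookup_range):
    """Quantized values that fall outside an inclusive (lo, hi) lookup range."""
    lo, hi = lookup_range
    return [q for q in (quantize(v, scale) for v in values) if not lo <= q <= hi]

calibration = [0.2, 1.7, -3.9, 40.0]  # dummy calibration data
print(out_of_range(calibration, scale=1, lookup_range=(-64, 64)))  # → [80]
```

A single outlier like `40.0` is enough to overflow the table, which is why calibrating on representative data before fixing the settings is worthwhile.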

In [None]:
# generate and calibrate the settings for each submodel

async def circuit_gen_settings(i):
    # file names
    model_path = os.path.join('network_split_'+str(i)+'.onnx')
    settings_path = os.path.join('settings_split_'+str(i)+'.json')
    data_path =  os.path.join('input_'+str(i)+'.json')

    # generate settings for the current model
    res = ezkl.gen_settings(model_path, settings_path, py_run_args=run_args)

    res = ezkl.calibrate_settings(data_path, model_path, settings_path, "resources", scales=[run_args.input_scale])

for i in range(2):
    await circuit_gen_settings(i)


As we use Halo2 with KZG commitments, we need an SRS string from (preferably) a multi-party trusted setup ceremony. For an overview of the procedures for such a ceremony, check out [this page](https://blog.ethereum.org/2023/01/16/announcing-kzg-ceremony). The `get_srs` command retrieves a correctly sized SRS from [here](https://github.com/han0110/halo2-kzg-srs); below we size it using the largest `logrows` across the two split circuits.

These SRS were generated with [this](https://github.com/privacy-scaling-explorations/perpetualpowersoftau) ceremony. 
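A circuit with `logrows = k` has 2^k rows and needs an SRS of at least that size, which is why the next cell finds the largest `logrows` among the split settings and fetches one SRS of that size. A minimal sketch of that sizing arithmetic; the settings dictionaries are hypothetical, and real provers reserve some extra rows (for blinding, for example), so `min_logrows` is only a lower bound.

```python
def min_logrows(num_rows: int) -> int:
    """Smallest k with 2**k >= num_rows (ignores rows a real prover reserves)."""
    return max(1, (num_rows - 1).bit_length())

def shared_logrows(settings_list):
    """Largest logrows across split settings shaped like settings["run_args"]["logrows"]."""
    return max(s["run_args"]["logrows"] for s in settings_list)

split_settings = [{"run_args": {"logrows": 11}}, {"run_args": {"logrows": 9}}]
k = shared_logrows(split_settings)
print(k, 2 ** k)          # → 11 2048
print(min_logrows(1500))  # → 11
```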

In [None]:

def get_max_logrows():
    # find the largest logrows across the calibrated split settings
    max_logrows = 0
    for i in range(2):
        settings_path = os.path.join('settings_split_'+str(i)+'.json')
        with open(settings_path) as f:
            new_settings = json.load(f)
        if new_settings["run_args"]['logrows'] > max_logrows:
            max_logrows = new_settings["run_args"]['logrows']
    return max_logrows

def circuit_compiled_model(i, max_logrows):
    settings_path = os.path.join('settings_split_'+str(i)+'.json')
    model_path = os.path.join('network_split_'+str(i)+'.onnx')
    compiled_model_path = os.path.join('network_split_'+str(i)+'.compiled')

    # give every split the same logrows so they share one circuit size and SRS
    with open(settings_path) as f:
        settings = json.load(f)
    settings["run_args"]['logrows'] = max_logrows
    # save it
    with open(settings_path, 'w') as f:
        json.dump(settings, f)

    # compile the circuit
    res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
    assert res == True

max_logrows = get_max_logrows()
ezkl.get_srs(srs_path, logrows=max_logrows)
for i in range(2):
    circuit_compiled_model(i, max_logrows)

In [None]:
def setup_model(i):
    compiled_model_path = os.path.join('network_split_'+str(i)+'.compiled')
    pk_path = os.path.join('test_split_'+str(i)+'.pk')
    vk_path = os.path.join('test_split_'+str(i)+'.vk')
    # HERE WE SETUP THE CIRCUIT PARAMS
    # WE GOT KEYS
    # WE GOT CIRCUIT PARAMETERS
    # EVERYTHING ANYONE HAS EVER NEEDED FOR ZK
    res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
        srs_path,
    )

    assert res == True
    assert os.path.isfile(vk_path)
    assert os.path.isfile(pk_path)

    print("Setup model "+str(i)+" done")

for i in range(2):
    setup_model(i)

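We now need to generate the witness for each split circuit. It contains the model outputs (and any hashes or commitments) produced when feeding the previously generated `input_0.json` through the first split. For the second split, the input file is overwritten with the first split's witness outputs, so both circuits work on the same intermediate values.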
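The glue between the splits at this stage is that split i's `input_data` is split i-1's witness `outputs`. A minimal sketch of that hand-off on in-memory dictionaries; `next_input_from_witness` is a hypothetical helper, and the witness contents are placeholders (a real witness stores quantized field elements plus extra keys such as `processed_outputs`).

```python
import json

def next_input_from_witness(prev_witness: dict) -> dict:
    """Build the next split's input file contents from the previous split's witness."""
    # the previous split's outputs become the next split's input_data, unchanged
    return {"input_data": prev_witness["outputs"]}

# hypothetical witness for split 0 with placeholder values
witness_0 = {"inputs": [[1, 2, 3]], "outputs": [[5, 0, 7, 3]]}
data_1 = next_input_from_witness(witness_0)
print(json.dumps(data_1))  # → {"input_data": [[5, 0, 7, 3]]}
```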

In [None]:

def witness_gen_model(i):
    # generate witnesses in sequence
    data_path = os.path.join('input_'+str(i)+'.json')
    witness_path = os.path.join('witness_split_'+str(i)+'.json')
    compiled_model_path = os.path.join('network_split_'+str(i)+'.compiled')
    vk_path = os.path.join('test_split_'+str(i)+'.vk')

    if i > 0:
        # feed the previous split's outputs in as this split's input
        prev_witness_path = os.path.join('witness_split_'+str(i-1)+'.json')
        with open(prev_witness_path, 'r') as f:
            witness = json.load(f)
        data = dict(input_data=witness['outputs'])
        # Serialize data into file:
        with open(data_path, 'w') as f:
            json.dump(data, f)

    res = ezkl.gen_witness(data_path, compiled_model_path, witness_path, vk_path, srs_path)

for i in range(2):
    witness_gen_model(i)

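Now we generate a proof for each split with the proving key created during setup, and check it with the matching verifying key. As the name suggests, the proving key is needed for ... proving and the verifying key is needed for ... verifying. For every split after the first, we copy the previous split's `processed_outputs` into this split's `processed_inputs` and call `swap_proof_commitments` before verifying, so the proof is checked against the commitment produced by the previous split.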
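`prove_model` relies on adjacent splits agreeing at their boundary. A minimal sketch of that boundary check on witness dictionaries: for public visibility it compares split i's `outputs` with split i+1's `inputs`, and for `kzgcommit` it compares `processed_outputs` with `processed_inputs`, the same fields the cell copies across. The `inputs` key, the helper name, and the placeholder values are assumptions for illustration; in the real flow the verifier enforces this through the proofs rather than by reading witnesses.

```python
def stitch_mismatches(witnesses, committed):
    """Return indices i where split i's boundary values differ from split i+1's.

    committed=False compares "outputs" / "inputs" (public instances);
    committed=True compares "processed_outputs" / "processed_inputs" (kzgcommit).
    """
    if committed:
        out_key, in_key = "processed_outputs", "processed_inputs"
    else:
        out_key, in_key = "outputs", "inputs"
    return [
        i
        for i in range(len(witnesses) - 1)
        if witnesses[i][out_key] != witnesses[i + 1][in_key]
    ]

# hypothetical witness dictionaries with placeholder values
w0 = {"outputs": [[5, 0, 7, 3]], "processed_outputs": "commitment-A"}
w1 = {"inputs": [[5, 0, 7, 3]], "processed_inputs": "commitment-A"}
print(stitch_mismatches([w0, w1], committed=False))  # → []
w1_bad = {"inputs": [[5, 0, 7, 4]], "processed_inputs": "commitment-B"}
print(stitch_mismatches([w0, w1_bad], committed=True))  # → [0]
```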

In [None]:
# GENERATE A PROOF

def prove_model(i):
    proof_path = os.path.join('proof_split_'+str(i)+'.json')
    witness_path = os.path.join('witness_split_'+str(i)+'.json')
    compiled_model_path = os.path.join('network_split_'+str(i)+'.compiled')
    pk_path = os.path.join('test_split_'+str(i)+'.pk')
    vk_path = os.path.join('test_split_'+str(i)+'.vk')
    settings_path = os.path.join('settings_split_'+str(i)+'.json')

    res = ezkl.prove(
        witness_path,
        compiled_model_path,
        pk_path,
        proof_path,
        srs_path,
        "for-aggr",  # proof type suitable for later aggregation
    )

    print(res)
    assert os.path.isfile(proof_path)

    if i > 0:
        # swap the proof commitments if we are not the first model:
        # this split's inputs must use the previous split's output commitments
        prev_witness_path = os.path.join('witness_split_'+str(i-1)+'.json')
        with open(prev_witness_path, 'r') as f:
            prev_witness = json.load(f)

        with open(witness_path, 'r') as f:
            witness = json.load(f)

        print(prev_witness["processed_outputs"])
        print(witness["processed_inputs"])

        witness["processed_inputs"] = prev_witness["processed_outputs"]

        # now save the witness
        with open(witness_path, "w") as f:
            json.dump(witness, f)

        res = ezkl.swap_proof_commitments(proof_path, witness_path)

    # Verify the proof
    res = ezkl.verify(
        proof_path,
        settings_path,
        vk_path,
        srs_path,
    )

    assert res == True
    print("verified")

for i in range(2):
    prove_model(i)

### KZG commitment intermediate calculations

This time the visibility parameters are:
- `input_visibility`: "kzgcommit"
- `param_visibility`: "fixed"
- `output_visibility`: "kzgcommit"
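Why must the commitments be unblinded? Stitching works by requiring the commitment to split i's outputs to equal the commitment to split i+1's inputs, which is only possible if committing is deterministic. The toy below uses a linear commitment with the same shape as a Pedersen vector commitment, but over plain integers mod a prime as a stand-in for KZG; it offers no security and is purely illustrative, and `commit`, `G`, `H` are hypothetical names.

```python
import random

P = 2**61 - 1  # a Mersenne prime; toy modulus, not a real curve group
G = [random.Random(seed).randrange(1, P) for seed in range(4)]  # fixed public "generators"
H = random.Random(99).randrange(1, P)                           # generator for blinding

def commit(values, blind=0):
    """Toy linear commitment: sum(v_i * G_i) + blind * H mod P (insecure)."""
    return (sum(v * g for v, g in zip(values, G)) + blind * H) % P

boundary = [5, 0, 7, 3]  # hypothetical intermediate values shared by both splits

# unblinded: both proofs derive the same commitment, so they can be matched
assert commit(boundary) == commit(boundary)

# blinded: each proof picks fresh randomness, so the commitments no longer line up
r1, r2 = 12345, 67890
print(commit(boundary, r1) == commit(boundary, r2))  # → False
```

The blinded commitments differ by `(r1 - r2) * H mod P`, which is never zero here, so a verifier could not tell that both proofs refer to the same intermediate values.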

In [None]:
import ezkl

run_args = ezkl.PyRunArgs()
run_args.input_visibility = "kzgcommit"
run_args.param_visibility = "fixed"
run_args.output_visibility = "kzgcommit"
run_args.variables = [("batch_size", 1)]
run_args.input_scale = 1
run_args.param_scale = 1
run_args.logrows = 9



In [None]:
for i in range(2):
    await circuit_gen_settings(i)

In [None]:
for i in range(2):
    circuit_compiled_model(i, get_max_logrows())

In [None]:
for i in range(2): 
    setup_model(i)

In [None]:
for i in range(2):
    witness_gen_model(i)

In [None]:
for i in range(2):
    prove_model(i)

You can also aggregate the split proofs into a single proof, which is useful if you want to verify on chain at a lower cost. Here we only mock-aggregate the proofs to save time; see the other notebooks for how to aggregate in full!

In [None]:
# now mock aggregate the proofs
proofs = []
for i in range(2):
    proof_path = os.path.join('proof_split_'+str(i)+'.json')
    proofs.append(proof_path)

ezkl.mock_aggregate(proofs, logrows=23, split_proofs = True)