In [None]:


import torch
import ezkl
import json
import subprocess
from pathlib import Path


class Passthrough(torch.nn.Module):
    """Identity network: ``forward`` hands back its argument unchanged.

    The module registers no parameters or buffers. ``input_size`` is accepted
    so existing call sites keep working, but it is neither stored nor used.
    """

    def __init__(self, input_size=100):
        super(Passthrough, self).__init__()
        # input_size is intentionally ignored (see class docstring).

    def forward(self, x):
        """Return ``x`` itself — the very same object, no copy is made."""
        return x

def generate_random_data(size=100, min_val=1, max_val=10):
    """Return a list of ``size`` Python floats spread uniformly between ``min_val`` and ``max_val``.

    Every value comes from its own ``torch.rand(1)`` draw, so the output follows
    the global torch RNG state (call ``torch.manual_seed`` for reproducibility).
    """
    spread = max_val - min_val

    def _draw():
        unit = torch.rand(1).item()  # one uniform sample in [0, 1)
        return min_val + spread * unit

    return [_draw() for _ in range(size)]

def save_json(data, filename):
    """Serialize ``data`` as JSON into ``filename``, replacing any existing file."""
    with open(filename, mode="w") as out_file:
        json.dump(data, fp=out_file)

async def run_ezkl_pipeline():
    """Run the ezkl flow end to end: settings, SRS, compile, witness, setup, prove, verify.

    All ezkl calls except ``prove`` use their default file arguments.
    NOTE(review): this presumably means ``network.onnx`` / ``input.json`` in and
    ``settings.json`` etc. out; confirm against the installed ezkl version.
    The proof is written to ``proof.json``. The witness returned by
    ``gen_witness`` is printed.

    Calibration data (if any was written) is not used here: settings come
    straight from the fixed run arguments below.

    Raises:
        RuntimeError: if settings generation, circuit compilation, setup or
            verification reports failure through a falsy return value.
    """

    def _require(ok, step):
        # ezkl's examples check these calls' return values; a falsy result must
        # stop the pipeline instead of being silently ignored.
        if not ok:
            raise RuntimeError(f"ezkl {step} failed")

    run_args = ezkl.PyRunArgs()
    run_args.input_visibility = "public"
    run_args.output_visibility = "public"
    run_args.param_visibility = "fixed"  # the passthrough model has no parameters
    run_args.input_scale = 19
    run_args.param_scale = 19
    run_args.logrows = 8

    _require(ezkl.gen_settings(py_run_args=run_args), "gen_settings")
    await ezkl.get_srs()
    _require(ezkl.compile_circuit(), "compile_circuit")
    res = ezkl.gen_witness()
    print(res)
    _require(ezkl.setup(), "setup")
    ezkl.prove(proof_path="proof.json")
    _require(ezkl.verify(), "verify")

def verify_proof_matches_input():
    """Check that the proof's public input instances equal the quantized ``input.json`` values.

    Loads ``settings.json``, ``input.json`` and ``proof.json`` from the working
    directory. It flattens ``input_data``, quantizes each value with
    ``ezkl.float_to_felt`` at the first model input scale, and compares each
    result with the instance at the same position in ``proof["instances"][0]``.

    Raises:
        AssertionError: if the proof carries fewer instances than there are
            input values, or if any quantized input differs from its instance.
        IndexError: if an input position lies beyond the groups listed in
            ``settings["model_instance_shapes"]``.
    """
    import math

    def load(path):
        with open(path) as fh:
            return json.load(fh)

    settings = load("settings.json")
    inputs = load("input.json")
    proof = load("proof.json")

    # NOTE(review): a single scale is applied to every input value, which fits a
    # single-input model; multi-input models may need per-input scales.
    input_scale = settings["model_input_scales"][0]
    model_shapes = settings["model_instance_shapes"]

    flat_inputs = [x for arr in inputs["input_data"] for x in arr]
    scaled_inputs = [ezkl.float_to_felt(x, input_scale, ezkl.PyInputType.F32) for x in flat_inputs]
    proof_instances = proof["instances"][0]

    # zip() below would silently stop early if the proof had too few instances.
    if len(proof_instances) < len(scaled_inputs):
        raise AssertionError(
            f"Proof has {len(proof_instances)} instances but there are {len(scaled_inputs)} input values"
        )

    def locate(i):
        """Return (group index, offset of ``i`` inside that group) for flat position ``i``."""
        start = 0
        for idx, shape in enumerate(model_shapes):
            end = start + math.prod(shape)  # element count, for a shape of any rank
            if i < end:
                return idx, i - start
            start = end
        raise IndexError("Index out of bounds")

    for i, (scaled, instance) in enumerate(zip(scaled_inputs, proof_instances)):
        group_idx, offset = locate(i)
        if scaled == instance:
            continue

        # Diagnostics are only needed on failure.
        descaled_input = ezkl.felt_to_float(scaled, input_scale)
        descaled_instance = ezkl.felt_to_float(instance, input_scale)
        # Index by the position within the group, not i % length; assumes each
        # rescaled_inputs entry holds one group's flattened values — confirm.
        pretty_value = proof["pretty_public_inputs"]["rescaled_inputs"][group_idx][offset]
        raise AssertionError(
            f"Input mismatch at index {i}: {scaled} != {instance} "
            f"({descaled_input} != {descaled_instance} OG {flat_inputs[i]} PRETTY {pretty_value})"
        )

model = Passthrough()
torch.onnx.export(model, torch.randn(1, 100), "network.onnx")

input_data = {"input_data": [generate_random_data()]}
save_json(input_data, "input.json")
save_json({"input_data": [generate_random_data()]}, "calibration.json")

await run_ezkl_pipeline()
verify_proof_matches_input()



{'inputs': [['a5c7080000000000000000000000000000000000000000000000000000000000', 'b09c1c0000000000000000000000000000000000000000000000000000000000', '29fe2e0000000000000000000000000000000000000000000000000000000000', '5d7e1a0000000000000000000000000000000000000000000000000000000000', 'f3ed390000000000000000000000000000000000000000000000000000000000', '93bf370000000000000000000000000000000000000000000000000000000000', '5973130000000000000000000000000000000000000000000000000000000000', 'f760370000000000000000000000000000000000000000000000000000000000', 'f79b1b0000000000000000000000000000000000000000000000000000000000', '2dee360000000000000000000000000000000000000000000000000000000000', 'f062370000000000000000000000000000000000000000000000000000000000', '5392270000000000000000000000000000000000000000000000000000000000', '2e64270000000000000000000000000000000000000000000000000000000000', 'd2ee1f0000000000000000000000000000000000000000000000000000000000', '1c194f0000000000000000000000000000

  from .autonotebook import tqdm as notebook_tqdm
