In [1]:
# Keras 3 reads its backend from KERAS_BACKEND, so it has to be set before
# `import keras` runs in the next cell (https://keras.io/getting_started/).
import os

os.environ.update(KERAS_BACKEND="torch")

In [2]:
from dataclasses import dataclass, field
from enum import Enum
import json
from pathlib import Path
from typing import Iterable, Literal, Optional

import ezkl
import keras
import torch

In [3]:
# NOTE(review): assumes the kernel's working directory is a direct child of
# the repository root (the notebook's own folder) — confirm when moving it.
REPO_ROOT = Path.cwd().parent
REPO_ROOT

PosixPath('/home/suller/ezkl')

In [4]:
# Folder for the exported ONNX files.
# NOTE(review): not referenced by the visible cells — `Model` rebuilds this
# path from its own `root` instead.
MODELS_DIR = REPO_ROOT.joinpath("models")
MODELS_DIR

PosixPath('/home/suller/ezkl/models')

## Define Neural Network Models

[Section 5.1](https://www.politesi.polimi.it/retrieve/ab2f9f29-9491-444a-aade-be38b88dc67d/2023_05_Cerioli_01.pdf#section.5.1)

In [5]:
class ModelAttributes(Enum):
    """Benchmark models as ``(file-stem name, input shape without batch dim)``.

    NOTE(review): the CNN shapes are written channels-first ``(C, H, W)``,
    while the Conv2D/AvgPool2D layers built later pass no ``data_format`` and
    so use Keras' configured default — confirm the two agree.
    """

    FNN = ("fnn", (50,))
    SMALL_CNN = ("small_cnn", (1, 10, 10))
    MNIST = ("mnist", (1, 28, 28))
    LENET5 = ("lenet5", (1, 32, 32))
    # TODO: VGG11 = ("vgg11", (224, 224, 3))

    def __init__(self, model_name: str, input_shape: Iterable[int]) -> None:
        # Enum unpacks each member's value tuple into these positional args.
        self.model_name, self.input_shape = model_name, input_shape

In [6]:
@dataclass
class Model:
    """A benchmark network plus every file path the ezkl pipeline uses for it.

    Only ``name`` and ``input_shape`` are required. All paths are derived from
    ``root`` in ``__post_init__``; ``root`` defaults to ``REPO_ROOT`` and may
    be given as a ``str`` or a ``Path``.
    """

    name: str  # file stem shared by the ONNX file, data files and output dir
    input_shape: Iterable[int]  # per-sample shape; callers prepend the batch dim
    model: Optional["keras.Model"] = None  # attached once the network is built
    root: Optional[Path] = None  # normalised to a Path in __post_init__
    onnx_path: Path = field(init=False)
    calibration_data_path: Path = field(init=False)
    inference_data_path: Path = field(init=False)
    output_dir: Path = field(init=False)

    def __post_init__(self) -> None:
        # Coerce to Path so a plain string root does not break the `/` joins.
        self.root = REPO_ROOT if self.root is None else Path(self.root)
        data_dir = self.root / "data"
        self.onnx_path = self.root / "models" / f"{self.name}.onnx"
        self.calibration_data_path = data_dir / "2-calibration" / f"{self.name}.json"
        self.inference_data_path = data_dir / "3-inference" / f"{self.name}.json"
        self.output_dir = self.root / "output" / self.name

In [7]:
# Registry keyed by model name; the Keras networks are attached in later cells.
models = dict(
    (spec.model_name, Model(spec.model_name, spec.input_shape))
    for spec in ModelAttributes
)
models

{'fnn': Model(name='fnn', input_shape=(50,), model=None, root=PosixPath('/home/suller/ezkl'), onnx_path=PosixPath('/home/suller/ezkl/models/fnn.onnx'), calibration_data_path=PosixPath('/home/suller/ezkl/data/2-calibration/fnn.json'), inference_data_path=PosixPath('/home/suller/ezkl/data/3-inference/fnn.json'), output_dir=PosixPath('/home/suller/ezkl/output/fnn')),
 'small_cnn': Model(name='small_cnn', input_shape=(10, 10, 1), model=None, root=PosixPath('/home/suller/ezkl'), onnx_path=PosixPath('/home/suller/ezkl/models/small_cnn.onnx'), calibration_data_path=PosixPath('/home/suller/ezkl/data/2-calibration/small_cnn.json'), inference_data_path=PosixPath('/home/suller/ezkl/data/3-inference/small_cnn.json'), output_dir=PosixPath('/home/suller/ezkl/output/small_cnn')),
 'mnist': Model(name='mnist', input_shape=(28, 28, 1), model=None, root=PosixPath('/home/suller/ezkl'), onnx_path=PosixPath('/home/suller/ezkl/models/mnist.onnx'), calibration_data_path=PosixPath('/home/suller/ezkl/data/2-ca

### Fully Connected Neural Network

In [8]:
# Two Dense layers given no `activation`, stacked on the (50,)-shaped input.
fnn_layers = [
    keras.layers.Input(shape=models["fnn"].input_shape),
    keras.layers.Dense(25),
    keras.layers.Dense(2),
]
fnn = keras.Sequential(fnn_layers)
fnn.compile()
fnn.summary()

models["fnn"].model = fnn

### Small Convolutional Network

In [9]:
# A single 3x3 convolution with 6 filters and no activation.
# NOTE(review): input_shape is (1, 10, 10) (channels-first), but Conv2D is
# built without an explicit `data_format` — confirm the layout Keras uses
# here matches what the ONNX export / ezkl expect.
small_cnn_layers = [
    keras.layers.Input(shape=models["small_cnn"].input_shape),
    keras.layers.Conv2D(filters=6, kernel_size=3),
]
small_cnn = keras.Sequential(small_cnn_layers)
small_cnn.compile()
small_cnn.summary()

models["small_cnn"].model = small_cnn

### Convolutional Neural Network

In [10]:
def polynomial_activation(x, coefficient=10**6):
    """Polynomial activation ``x*x + coefficient*x``.

    Only ``*`` and ``+`` are used, so ``x`` may be a scalar or a tensor
    (applied elementwise).

    Args:
        x: input value or tensor.
        coefficient: weight of the linear term. It defaults to ``10**6`` (the
            original hard-coded value). The model cells also use ``10**15``.

    Returns:
        ``x*x + coefficient*x`` with the same type/shape as ``x``.
    """
    return x*x + coefficient*x

#### MNIST Conv-Net

In [11]:
# Two conv -> polynomial activation -> 2x2 pooling stages, then Flatten and a
# Dense layer with 10 outputs.
#
# Each activation multiplies its result by 4 (the size of the 2x2 pooling
# window), so the following AvgPool2D yields the window *sum* — the Sum
# Pooling employed in the thesis — rather than the mean.
#
# NOTE(review): with the 10**15 coefficient in the second activation the x*x
# term is negligible next to the linear term for moderate |x|; confirm the
# constant against the thesis.
mnist_layers = [
    keras.layers.Input(shape=models["mnist"].input_shape),
    keras.layers.Conv2D(filters=4, kernel_size=3),
    keras.layers.Activation(lambda x: 4*(x*x + (10**6)*x)),
    keras.layers.AvgPool2D(pool_size=2, strides=2),
    keras.layers.Conv2D(filters=8, kernel_size=3),
    keras.layers.Activation(lambda x: 4*(x*x + (10**15)*x)),
    keras.layers.AvgPool2D(pool_size=2, strides=2),
    keras.layers.Flatten(),
    keras.layers.Dense(10),
]
mnist = keras.Sequential(mnist_layers)
mnist.compile()
mnist.summary()

models["mnist"].model = mnist

#### LeNet5

In [12]:
# LeNet-5-style stack: two conv -> polynomial activation -> 2x2 average-pool
# stages, then three Dense layers (120 -> 84 -> 10).
#
# NOTE(review): unlike the MNIST net, these activations are not multiplied
# by 4, so AvgPool2D here averages rather than sums — confirm which pooling
# the thesis intends for LeNet-5.
lenet5_layers = [
    keras.layers.Input(shape=models["lenet5"].input_shape),
    keras.layers.Conv2D(filters=6, kernel_size=5),
    keras.layers.Activation(lambda x: x*x + (10**6)*x),
    keras.layers.AvgPool2D(pool_size=2, strides=2),
    keras.layers.Conv2D(filters=16, kernel_size=5),
    keras.layers.Activation(lambda x: x*x + (10**15)*x),
    keras.layers.AvgPool2D(pool_size=2, strides=2),
    keras.layers.Flatten(),
    keras.layers.Dense(120),
    keras.layers.Dense(84),
    keras.layers.Dense(10),
]
lenet5 = keras.Sequential(lenet5_layers)
lenet5.compile()
lenet5.summary()

models["lenet5"].model = lenet5

#### VGG-11

In [13]:
# TODO

## Export models to ONNX

In [14]:
def export(model: "Model", path: Optional[Path] = None) -> None:
    """Export ``model.model`` to an ONNX file with ``torch.onnx.export``.

    Args:
        model: registry entry whose ``model`` attribute holds the built Keras
            network. It is presumably usable as a torch module because
            KERAS_BACKEND is set to "torch".
        path: destination file (``str`` or ``Path``). It defaults to
            ``model.onnx_path``. Missing parent directories are created.

    Raises:
        ValueError: if no network has been assigned to ``model.model`` yet.
    """
    if model.model is None:
        raise ValueError(
            f"Model {model.name!r} has no network assigned; build it before exporting."
        )
    destination = Path(path or model.onnx_path)
    # The exporter writes straight to this file, so make sure its folder exists.
    destination.parent.mkdir(parents=True, exist_ok=True)
    # Dummy input with batch size 1; the batch axis is declared dynamic below.
    input_ = torch.rand(1, *model.input_shape)
    torch.onnx.export(
        model.model,  # Actual keras.Model object
        input_,
        str(destination),
        export_params=True,
        opset_version=10,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "batch_size"},
            "output": {0: "batch_size"},
        }
    )

In [15]:
# Write every registered network to its default ONNX path.
for entry in models.values():
    export(entry)

  shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
  if channels % kernel_in_channels > 0:


## Generate proofs using `ezkl`

In [22]:
def torch_tensor_to_list(tensor: torch.Tensor) -> list[float]:
    """Flatten ``tensor`` (row-major) into a plain Python list of numbers.

    The tensor is detached from autograd and moved to the CPU first. Tensors
    that require grad or live on a GPU are therefore accepted; going through
    ``.numpy()`` directly raises for CUDA tensors.
    """
    return tensor.detach().cpu().reshape(-1).tolist()

In [23]:
def random_input_data(
    samples: int,
    *shape: int,
    scale: float = 1,
    generator: Optional[torch.Generator] = None,
) -> dict[Literal["input_data"], list[list[float]]]:
    """Build an input-data payload of standard-normal random values.

    Draws a ``(samples, *shape)`` tensor, multiplies it by ``scale`` and
    flattens *all* samples into one list. The result is returned as
    ``{"input_data": [flat_values]}``, which is the structure this notebook
    writes to the files passed to ezkl.

    Args:
        samples: number of samples (leading dimension).
        *shape: per-sample shape, e.g. ``Model.input_shape``.
        scale: multiplier applied to every value.
        generator: optional ``torch.Generator`` for reproducible draws. The
            global torch RNG is used when omitted.
    """
    # No autograd needed: the values are only serialised.
    values = torch.randn(samples, *shape, generator=generator) * scale
    return {"input_data": [torch_tensor_to_list(values)]}

In [None]:
for model in models.values():
    py_run_args = ezkl.PyRunArgs()
    py_run_args.input_visibility = "private"
    py_run_args.output_visibility = "public"
    py_run_args.param_visibility = "fixed"  # private by default

    ezkl.gen_settings(
        model.onnx_path, model.output_dir / "settings.json", py_run_args=py_run_args
    )

    if not (path := model.calibration_data_path).exists():
        with path.open("w") as f:
            json.dump(random_input_data(20, *model.input_shape), f)
    await ezkl.calibrate_settings(
        model.calibration_data_path,
        model.onnx_path,
        model.output_dir / "settings.json",
        "resources",
    )

    ezkl.compile_circuit(
        model.onnx_path,
        model.output_dir / "compiled",
        model.output_dir / "settings.json",
    )

    await ezkl.get_srs(model.output_dir / "settings.json")


    if not (path := model.inference_data_path).exists():
        with path.open("w") as f:
            json.dump(random_input_data(1, *model.input_shape), f)
    await ezkl.gen_witness(
        model.inference_data_path,
        model.output_dir / "compiled",
        model.output_dir / "witness.json",
    )
    ezkl.setup(
        model.output_dir / "compiled",
        model.output_dir / "vk",
        model.output_dir / "pk",
    )

    ezkl.prove(
        model.output_dir / "witness.json",
        model.output_dir / "compiled",
        model.output_dir / "pk",
        model.output_dir / "proof",
        "single",
    )

----

## Generate a proof for a single test model

In [18]:
# Artefact file names for the single-model walkthrough below, all relative to
# the notebook's working directory.
data_path = "input.json"
model_path = "test.onnx"
settings_path = "settings.json"
compiled_model_path = "test.compiled"
pk_path = "test.pk"
vk_path = "test.vk"
witness_path = "witness.json"


In [19]:
# Walk through the pipeline step by step using the small CNN.
test_model = models[ModelAttributes.SMALL_CNN.model_name]
test_model

Model(name='small_cnn', input_shape=(10, 10, 1), model=<Sequential name=sequential_1, built=True>, root=PosixPath('/home/suller/ezkl'), onnx_path=PosixPath('/home/suller/ezkl/models/small_cnn.onnx'), calibration_data_path=PosixPath('/home/suller/ezkl/data/2-calibration/small_cnn.json'), inference_data_path=PosixPath('/home/suller/ezkl/data/3-inference/small_cnn.json'), output_dir=PosixPath('/home/suller/ezkl/output/small_cnn'))

In [20]:
# Export to the walkthrough's local path rather than the model's default one.
export(test_model, Path(model_path))

  shape = tuple(map(lambda x: int(x) if x is not None else None, shape))
  if channels % kernel_in_channels > 0:


In [24]:
# One standard-normal inference sample, in the same payload format (and via
# the same helper) that the batch pipeline uses.
with open(data_path, "w") as f:
    json.dump(random_input_data(1, *test_model.input_shape), f)

In [25]:
# Same visibility choices as the batch pipeline: private input, public output,
# "fixed" params (ezkl's default for params is private).
py_run_args = ezkl.PyRunArgs()
(
    py_run_args.input_visibility,
    py_run_args.output_visibility,
    py_run_args.param_visibility,
) = ("private", "public", "fixed")

res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)

RuntimeError: Failed to generate settings: [graph] a node is has misformed params: data or kernel in wrong format

In [None]:
cal_path = os.path.join("calibration.json")

data_array = (torch.rand(20, *test_model.input_shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file:
json.dump(data, open(cal_path, 'w'))


await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")

In [None]:
# Compile the ONNX graph into a circuit using the calibrated settings.
res = ezkl.compile_circuit(
    model_path,
    compiled_model_path,
    settings_path,
)
res

In [None]:
# srs path
res = await ezkl.get_srs(settings_path)
res

In [None]:
# now generate the witness file

res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

In [None]:
# Key setup: writes the verifying key (vk) and proving key (pk) for the
# compiled circuit.
res = ezkl.setup(
    compiled_model_path,
    vk_path,
    pk_path,
)

assert res == True
assert all(os.path.isfile(p) for p in (vk_path, pk_path, settings_path))

In [None]:
# Prove the witness with the proving key, using ezkl's "single" proof type.
proof_path = "test.pf"

res = ezkl.prove(
    witness_path,
    compiled_model_path,
    pk_path,
    proof_path,
    "single",
)

print(res)
assert Path(proof_path).is_file()

In [None]:
# Verify the proof against the settings and verifying key. Fail loudly: the
# previous `if res:` silently did nothing when verification failed.
res = ezkl.verify(
    proof_path,
    settings_path,
    vk_path,
)

assert res, "proof verification failed"
print("verified")