In [1]:
#Set up
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt
import os

#First, a slightly modified version of the model
#that returns the negated difference between the top 2 probabilities of the output.
class LeNetZKP(nn.Sequential):
    """
    LeNet-style CNN with ReLU activations that outputs a confidence margin.

    The 10 logits are turned into softmax probabilities and, for every
    sample, the negated gap between the largest and the second-largest
    probability is returned (a value in [-1, 0]).

    The class derives from ``nn.Sequential`` but overrides ``forward``, so
    the registered layers are wired explicitly below instead of being
    chained in registration order.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: two 5x5 convolutions sharing one 2x2 max-pool.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head: 16*5*5 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # Softmax over the class dimension.
        self.act = nn.Softmax(dim=1)

    def forward(self, x):
        """
        Compute the negated top-2 probability gap for a batch of images.

        Args:
            x: batch of 3-channel images; 32x32 inputs produce the
               16*5*5 features per sample that ``fc1`` expects.

        Returns:
            1-D tensor with one entry per row of the flattened features:
            ``-(p_top1 - p_top2)``.
        """
        features = x
        # conv -> ReLU -> pool, twice.
        for conv in (self.conv1, self.conv2):
            features = self.pool(F.relu(conv(features)))

        # Flatten to (rows, 400) for the fully connected layers.
        hidden = features.view(-1, 16 * 5 * 5)
        for dense in (self.fc1, self.fc2):
            hidden = F.relu(dense(hidden))

        probs = self.act(self.fc3(hidden))

        # topk returns values sorted in descending order along the last dim.
        top2_values, _ = torch.topk(probs, 2)
        best, runner_up = top2_values[:, 0], top2_values[:, 1]
        margin = best - runner_up
        return -margin

# Seed PyTorch's RNG so that the randomly initialised weights and the random
# probe input are identical on every Restart & Run All.
torch.manual_seed(0)

# The network is used with its initial (random) weights in the cells shown.
modelZK = LeNetZKP()
# Random probe input of shape (1, 3, 32, 32); it is also reused as the
# example input for the ONNX export further down.
x = torch.randn(1, 3, 32, 32)
# Smoke test: one forward pass, printing the negated top-2 probability gap.
y = modelZK(x)
print(y)

tensor([-0.0056], grad_fn=<NegBackward0>)


In [2]:
import torchvision
import random
import torchvision.transforms as transforms
import numpy as np

# Preprocessing applied to every image returned by the dataset:
# ToTensor followed by a per-channel Normalize with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# NOTE: despite its name, `trainset` is loaded with train=False, i.e. it is the
# CIFAR-10 *test* split. The name is kept because later cells refer to it.
trainset = torchvision.datasets.CIFAR10(root='./data', train=False,transform=transform,download=True)


# Randomly select 100 images as Bob's dataset.
# A dedicated, seeded generator keeps the selection reproducible across runs
# (later cells index into it by position) without touching the global RNG.
indices = random.Random(0).sample(range(len(trainset)), 100)
# Fetch every sample once so the transform runs a single time per image.
samples = [trainset[i] for i in indices]
selected_images = np.array([image.numpy() for image, _ in samples])
selected_labels = np.array([label for _, label in samples])

Files already downloaded and verified


In [3]:
#Locations of the circuit artifacts, all kept inside the data directory
(model_path, compiled_model_path, pk_path, vk_path, settings_path) = (
    os.path.join('data', file_name)
    for file_name in ('network.onnx', 'network.compiled', 'test.pk', 'test.vk', 'settings.json')
)

#Locations of the per-proof JSON files
(witness_path, data_path, output_path, label_path) = (
    os.path.join('data', file_name)
    for file_name in ('witness.json', 'input.json', 'output.json', 'label.json')
)

#Export the model to ONNX, tracing it with the example input x
torch.onnx.export(
    modelZK,                       # model being exported
    x,                             # example input used for the export
    model_path,                    # destination of the ONNX file
    input_names=['input'],         # name of the graph input
    output_names=['output'],       # name of the graph output
    dynamic_axes={                 # axis 0 of input and output is a variable batch size
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
    opset_version=10,              # ONNX opset to target
    export_params=True,            # store the parameter weights inside the model file
    do_constant_folding=True,      # fold constants as an export-time optimization
)

In [4]:
#Settings, calibration
import ezkl
import json 
py_run_args = ezkl.PyRunArgs()
py_run_args.input_visibility = "public" #Bob can see this
py_run_args.output_visibility = "public" #This hash is given to Bob
py_run_args.param_visibility = "private" 

res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
assert res

cal_path = os.path.join('data',"calibration.json")

#Alice should use some real data to calibrate the model, here we use random data
data_array = (trainset.data[:10]).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file:
json.dump(data, open(cal_path, 'w'))

await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res 

# srs path - This actually requires a trusted setup.
res = await ezkl.get_srs(settings_path)
res = ezkl.setup(
        compiled_model_path,
        vk_path,
        pk_path,
    )

assert res
assert os.path.isfile(vk_path)
assert os.path.isfile(pk_path)
assert os.path.isfile(settings_path)

[tensor] decomposition error: integer 611118674 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer 390866514 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -611378416 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -1821847314 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -600577992 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -974516168 is too large to be represented by base 16384 and n 2
forward pass failed

In [16]:
# Take one of Bob's images and add a leading batch dimension.
point = selected_images[8]
point_tensor = torch.tensor(point).unsqueeze(0)
# Flatten the input into the {"input_data": [[...]]} layout stored at data_path.
data_array = point_tensor.detach().numpy().reshape([-1]).tolist()
data = dict(input_data = [data_array])
# A context manager guarantees the file is flushed and closed before
# gen_witness reads it (a bare open() relies on garbage collection for that).
with open(data_path, 'w') as data_file:
    json.dump(data, data_file)
# Reference result from PyTorch: the negated top-2 probability gap for this image.
modelZK.eval()
output = modelZK(point_tensor)
print(output)

tensor([-0.0034], grad_fn=<NegBackward0>)


In [17]:
# File the generated proof is written to.
proof_path = os.path.join('data','test.pf')

# Step 1: build the witness for the input stored at data_path.
res = ezkl.gen_witness(
    data_path,
    compiled_model_path,
    witness_path,
)
assert res

# Step 2: produce a proof from the witness with the proving key.
res = ezkl.prove(
    witness_path,
    compiled_model_path,
    pk_path,
    proof_path,
    "single",
)
assert res

# Step 3: check the proof against the settings and the verification key.
res = ezkl.verify(
    proof_path,
    settings_path,
    vk_path,
)
assert res