# Malicious scenario

First, Bob sends his data point (without its label) to Alice. Alice runs inference with her private model and sends Bob the hash of the prediction, together with a zero-knowledge proof (ZKP) that the hash was computed correctly from Bob's input. Bob verifies the ZKP and then participates in the multi-party computation (MPC) protocol to obtain the model's loss.

In [2]:
from torch import nn
import ezkl
import os
import json
import torch

In [3]:
class MyModel(nn.Module):
    """Small CNN that maps each 1x28x28 image to a single scalar.

    Shapes for an input of shape ``N x 1 x 28 x 28``:

    * conv1 (1 -> 2 channels, 5x5 kernel, stride 2) + ReLU -> ``N x 2 x 12 x 12``
    * conv2 (2 -> 3 channels, 5x5 kernel, stride 2) + ReLU -> ``N x 3 x 4 x 4``
    * flatten                                              -> ``N x 48``
    * d1 (48 -> 48) + ReLU                                 -> ``N x 48``
    * d2 (48 -> 10)                                        -> ``N x 10`` logits
    * mean over the 10 logits                              -> ``N``

    The hard-coded 48 input features of ``d1`` tie the model to single-channel
    28x28 inputs; other spatial sizes fail at ``d1``.
    """

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=2, kernel_size=5, stride=2)
        self.conv2 = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=5, stride=2)

        # ReLU is stateless, so one module instance is reused after every hidden layer.
        self.relu = nn.ReLU()

        # 3 channels * 4 * 4 spatial positions after conv2 = 48 flattened features.
        self.d1 = nn.Linear(48, 48)
        self.d2 = nn.Linear(48, 10)

    def forward(self, x):
        """Return one value per sample: the mean of the 10 logits.

        Args:
            x: tensor of shape ``N x 1 x 28 x 28``.

        Returns:
            Tensor of shape ``(N,)``.
        """
        # N x 1 x 28 x 28 => N x 2 x 12 x 12   ((28 - 5) // 2 + 1 = 12)
        x = self.conv1(x)
        x = self.relu(x)
        # N x 2 x 12 x 12 => N x 3 x 4 x 4     ((12 - 5) // 2 + 1 = 4)
        x = self.conv2(x)
        x = self.relu(x)

        # flatten => N x (3*4*4) = N x 48
        x = x.flatten(start_dim=1)

        # N x 48 => N x 48
        x = self.d1(x)
        x = self.relu(x)

        # logits => N x 10
        logits = self.d2(x)

        # reduce to a single value per sample => N
        logits = torch.mean(logits, dim=1)

        return logits


# Seed the global torch RNG so the randomly initialised (untrained) weights and
# any random data drawn later in the notebook (e.g. calibration data) are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

circuit = MyModel()

# Per-sample input shape (channels, height, width) expected by MyModel.
shape = [1, 28, 28]

# No training happens here: the random weights stand in for Alice's model.
# The model is exported to ONNX (network.onnx) with an input file (input.json)
# in a later cell. Bob's data point is a constant image filled with 3.0;
# a random alternative would be `0.1 * torch.rand(1, *shape)`.
# Gradients are not needed for inference/export, so the input is a plain tensor.
x = torch.full((1, *shape), 3.0)
print(x.shape)

torch.Size([1, 1, 28, 28])


In [4]:
# Path parameters: every artifact produced by the pipeline lives under DATA_DIR.
DATA_DIR = 'data'
# Later cells write into DATA_DIR (torch.onnx.export, json.dump, ezkl) without
# creating it, so make sure it exists on a fresh checkout.
os.makedirs(DATA_DIR, exist_ok=True)

model_path = os.path.join(DATA_DIR, 'network.onnx')              # exported ONNX model
compiled_model_path = os.path.join(DATA_DIR, 'network.compiled')  # ezkl-compiled circuit
pk_path = os.path.join(DATA_DIR, 'test.pk')                      # proving key
vk_path = os.path.join(DATA_DIR, 'test.vk')                      # verification key
settings_path = os.path.join(DATA_DIR, 'settings.json')          # circuit settings

witness_path = os.path.join(DATA_DIR, 'witness.json')
data_path = os.path.join(DATA_DIR, 'input.json')      # Bob's data point
output_path = os.path.join(DATA_DIR, 'output.json')   # PyTorch reference prediction

In [5]:
# Flips the neural net into inference mode before running and exporting it.
circuit.eval()

# Reference prediction computed by PyTorch.
out = circuit(x)
print(out)

# Export the model to ONNX so ezkl can build a circuit from it.
torch.onnx.export(
    circuit,                    # model being run
    x,                          # model input (or a tuple for multiple inputs)
    model_path,                 # where to save the model (can be a file or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=10,           # the ONNX version to export the model to
    do_constant_folding=True,   # whether to execute constant folding for optimization
    input_names=['input'],      # the model's input names
    output_names=['output'],    # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},     # variable length axes
                  'output': {0: 'batch_size'}},
)

# input.json: `input_data` holds one flattened array (the model has a single input).
data_array = x.detach().numpy().reshape([-1]).tolist()
data = dict(input_data=[data_array])

# Serialize data into file; `with` closes the handle deterministically.
with open(data_path, 'w') as fh:
    json.dump(data, fh)

out_array = out.detach().numpy().tolist()
output = dict(output_data=out_array)

# Serialize the reference prediction into file.
with open(output_path, 'w') as fh:
    json.dump(output, fh)

tensor([0.0303], grad_fn=<MeanBackward1>)


In [6]:
# Circuit visibility, one entry per part of the circuit:
#   input  -> "public":        Bob can see this (he supplied the data point)
#   output -> "hashed/public": this hash is given to Bob
#   params -> "private":       Alice's weights are not revealed
visibility = {
    "input_visibility": "public",
    "output_visibility": "hashed/public",
    "param_visibility": "private",
}

py_run_args = ezkl.PyRunArgs()
for attr_name, level in visibility.items():
    setattr(py_run_args, attr_name, level)

# Write the initial circuit settings for the exported ONNX model.
res = ezkl.gen_settings(model_path, settings_path, py_run_args=py_run_args)
assert res == True


In [7]:
cal_path = os.path.join('data',"calibration.json")

#Alice should use some real data to calibrate the model, here we use random data
data_array = (torch.rand(20, *shape, requires_grad=True).detach().numpy()).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file:
json.dump(data, open(cal_path, 'w'))


await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True
# srs path
res = await ezkl.get_srs( settings_path)
res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path)
assert os.path.isfile(witness_path)

[tensor] decomposition error: integer -317645693 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -1281068661 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -5127811340 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -2563294932 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -10260280118 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -30261717452 is too large to be represented by base 16384 and n 2
forward pas

In [8]:
# Key generation from the compiled circuit: verification key -> vk_path,
# proving key -> pk_path.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)
assert res == True

# Every artifact the prove/verify cells rely on must now exist on disk.
for artifact in (vk_path, pk_path, settings_path):
    assert os.path.isfile(artifact)

In [9]:
# GENERATE A PROOF
# Alice proves that the witness (including the hashed prediction) was produced
# by the compiled circuit; "single" is the proof type passed to ezkl.
proof_path = os.path.join('data', 'test.pf')

res = ezkl.prove(witness_path, compiled_model_path, pk_path, proof_path, "single")
print(res)

assert os.path.isfile(proof_path)

{'instances': [['0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000000000000000000000000000000000000', '0060000000000000000000000000000

In [10]:
# VERIFY IT
# Bob's side: the check needs only the proof, the circuit settings and the
# verification key -- Alice's weights (param_visibility "private") are not used.
res = ezkl.verify(proof_path, settings_path, vk_path)
assert res == True
print("verified")

verified
