# Malicious scenario

In the malicious scenario, we do not run the model inference inside an MPC protocol because of the very high computational cost. Instead, Alice computes the model inference offline to obtain y' (the prediction on x) and then participates in an MPC protocol to compute the loss. Because Alice may be malicious, we require a zero-knowledge proof (ZKP) that the prediction is correct.

First, Bob will send his data point (without label) to Alice. Alice runs inference on her model and sends the hash of the prediction, with a ZKP, back to Bob. 

# Part 0: The Setup

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import ezkl
import os
import json

In [5]:
class LeNet(nn.Sequential):
    """LeNet-style convolutional classifier using ReLU activations.

    Two convolution + max-pool stages are followed by three fully
    connected layers; the 10 outputs of the last layer are passed through
    a softmax over dimension 1.

    The flatten step hard-codes ``16 * 5 * 5`` features per sample, which
    is what the two conv/pool stages produce for a 3x32x32 input.

    NOTE(review): the class derives from ``nn.Sequential`` but defines its
    own ``forward``, so the sequential chaining of the base class is not
    used; the base is kept unchanged for interface compatibility.
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the same order as they were originally
        # declared, so state_dict keys and module ordering are unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)
        self.act = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax class scores for the input batch ``x``."""
        # Convolutional feature extractor (the same pooling layer is reused).
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))

        # Flatten to (N, 400) for the fully connected head.
        flat = features.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return self.act(logits)


# The network whose inference will be proven.
circuit = LeNet()

# CIFAR-10 loading: convert images to tensors and normalise every channel
# with mean 0.5 and standard deviation 0.5.
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# NOTE: despite the variable name, train=False requests the CIFAR-10 test
# split. The data is expected to already be present under ./data
# (no download is requested here).
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    transform=transform,
)
classes = trainset.classes  # ['airplane', 'automobile', 'bird', ..., 'truck']

# Pick one sample (index 4) as the data point to be valued.
data, lbl = trainset[4]
class_name = classes[lbl]

# One-hot encode the label by selecting a row of the 10x10 identity matrix.
label_eye = torch.eye(10)
y = label_eye[lbl]

# Add a leading batch dimension so the sample can be fed to the network.
x = data.unsqueeze(0)

In [6]:
# Locations of every artifact this notebook reads or writes; all of them
# live in the local `data` directory.
def _artifact(filename):
    """Return the path of *filename* inside the ``data`` directory."""
    return os.path.join('data', filename)


# Model / circuit artifacts
model_path = _artifact('network.onnx')
compiled_model_path = _artifact('network.compiled')
pk_path = _artifact('test.pk')
vk_path = _artifact('test.vk')
settings_path = _artifact('settings.json')

# Per-sample artifacts
witness_path = _artifact('witness.json')
data_path = _artifact('input.json')
output_path = _artifact('output.json')
label_path = _artifact('label.json')

In [7]:
# Put the network into inference mode before running and exporting it.
circuit.eval()

# Alice's plain-text prediction for Bob's data point.
out = circuit(x)
print(out)

# Export the model to ONNX at model_path.
torch.onnx.export(
    circuit,                    # model being run
    x,                          # model input (or a tuple for multiple inputs)
    model_path,                 # where to save the model (file or file-like object)
    export_params=True,         # store the trained parameter weights inside the model file
    opset_version=10,           # the ONNX version to export the model to
    do_constant_folding=True,   # whether to execute constant folding for optimization
    input_names=['input'],      # the model's input names
    output_names=['output'],    # the model's output names
    dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                  'output': {0: 'batch_size'}},
)

# Flattened input sample, wrapped in a one-element list under "input_data".
# NOTE: `data` is rebound here from the image tensor to the JSON payload.
data_array = x.detach().numpy().reshape([-1]).tolist()
data = dict(input_data=[data_array])

# Files are written through context managers so every handle is flushed and
# closed before later cells read them back.
with open(data_path, 'w') as f:
    json.dump(data, f)

# The model output (prediction), kept for reference.
out_array = out.detach().numpy().tolist()
output = dict(output_data=out_array)

with open(output_path, 'w') as f:
    json.dump(output, f)

# The one-hot ground-truth label.
label = y.detach().numpy().tolist()
label = dict(label=label)

with open(label_path, 'w') as f:
    json.dump(label, f)

tensor([[0.1084, 0.0992, 0.0921, 0.0948, 0.1103, 0.1035, 0.0966, 0.0898, 0.0956,
         0.1097]], grad_fn=<SoftmaxBackward0>)


# Part 1: The ZKP

In [8]:
# Visibility of each part of the computation in the generated proof.
py_run_args = ezkl.PyRunArgs()
py_run_args.param_visibility = "private"
py_run_args.output_visibility = "hashed/public"  # Bob receives this hash
py_run_args.input_visibility = "public"  # Bob can see the input

# Write the circuit settings for the exported ONNX model to settings_path.
res = ezkl.gen_settings(
    model_path,
    settings_path,
    py_run_args=py_run_args,
)
assert res == True


In [9]:
import random
import numpy as np
cal_path = os.path.join('data',"calibration.json")

indices = random.sample(range(len(trainset)), 10)
cal_images = np.array([trainset[i][0].numpy() for i in indices])
#Alice should use some real data to calibrate the model, here we use random data
data_array = (cal_images).reshape([-1]).tolist()

data = dict(input_data = [data_array])

# Serialize data into file:
json.dump(data, open(cal_path, 'w'))


await ezkl.calibrate_settings(cal_path, model_path, settings_path, "resources")
res = ezkl.compile_circuit(model_path, compiled_model_path, settings_path)
assert res == True

# srs path - This actually requires a trusted setup.
res = await ezkl.get_srs( settings_path)
res = await ezkl.gen_witness(data_path, compiled_model_path, witness_path) #Put this later
assert os.path.isfile(witness_path)

[tensor] decomposition error: integer 920493118 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -2986058034 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer 756237951 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -4199595594 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer -7495443061 is too large to be represented by base 16384 and n 2
forward pass failed: "failed to forward: [halo2] General synthesis error"
[tensor] decomposition error: integer 571055396 is too large to be represented by base 16384 and n 2
forward pass faile

In [10]:
# Run the ezkl key setup for the compiled circuit; the verification key and
# proving key are expected at vk_path / pk_path afterwards.
res = ezkl.setup(compiled_model_path, vk_path, pk_path)
assert res == True

# Everything needed for proving and verifying must now exist on disk.
for required_path in (vk_path, pk_path, settings_path):
    assert os.path.isfile(required_path)

In [12]:
import time

# GENERATE A PROOF
proof_path = os.path.join('data', 'test.pf')

# Time only the prove call itself. perf_counter() is a monotonic,
# high-resolution clock intended for measuring durations, and stopping the
# timer before printing keeps the (large) proof dump and the file check out
# of the reported figure.
start = time.perf_counter()
res = ezkl.prove(
    witness_path,
    compiled_model_path,
    pk_path,
    proof_path,
    "single",
)
end = time.perf_counter()

print(res)
assert os.path.isfile(proof_path)

print("Time taken to generate proof: ", end-start)

{'instances': [['51f0ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '93f1ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0decffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '89e7ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'c7e5ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0beaffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0cebffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '4cebffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '4cebffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0beaffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '84e2ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'c5e3ffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '0cebffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', '4eedffef93f5e1439170b97948e833285d588181b64550b829a031e1724e6430', 'ceecffef93f5e1439170b97948e8332

In [14]:
# VERIFY IT
import time

# Time only the verify call: perf_counter() is a monotonic clock meant for
# measuring durations, and the timer is stopped before the assert / print so
# they do not count towards the reported figure.
start = time.perf_counter()
res = ezkl.verify(
    proof_path,
    settings_path,
    vk_path,
)
end = time.perf_counter()

assert res == True
print("verified")
print("Time taken to verify proof: ", end-start)

verified
Time taken to verify proof:  0.4413723945617676


# Part 2: Secure MPC to calculate valuation

After Alice has computed $y'$, Alice and Bob can then compute the loss function in a secure MPC protocol.
In a complete protocol, the MPC computation would also need to verify the Poseidon hash of the prediction; we skip this step for now.

In [15]:
# Plain-text reference computation of the squared loss between the model
# prediction and the one-hot ground-truth label.
mpc_pred = out.detach().numpy()[0]
mpc_gt = y.detach().numpy()

# Both vectors must line up element by element; fail loudly otherwise rather
# than silently comparing a truncated prefix.
if mpc_pred.shape != mpc_gt.shape:
    raise ValueError(
        f"prediction shape {mpc_pred.shape} does not match label shape {mpc_gt.shape}"
    )

# Sum over however many classes there are (no hard-coded class count).
# The terms are accumulated sequentially starting from 0, exactly like an
# explicit loop would do.
sq_loss = sum((pred - gt) ** 2 for pred, gt in zip(mpc_pred, mpc_gt))
print(sq_loss)

0.90723324


In [16]:
# Plain-text reference computation of the cross-entropy loss: the negative
# log of the probability the model assigns to the ground-truth class.
import numpy as np

mpc_pred = out.detach().numpy()[0]
mpc_gt = y.detach().numpy()

# The ground-truth vector is one-hot, so argmax recovers the class index.
true_class = np.argmax(mpc_gt)
true_class_prob = mpc_pred[true_class]

cross_entropy_loss = -1 * np.log(true_class_prob)
print(cross_entropy_loss)

2.336796


To run the MPC protocol, we need to use the MP-SPDZ library. MP-SPDZ supports multiple MPC protocols with different assumptions. We will use the MASCOT protocol, which supports malicious security.

In [17]:
# Prepare input data for MPC
player_data_dir = '../../MP-SPDZ/Player-Data'

# exist_ok=True creates the directory only when needed, without a separate
# (race-prone) existence check.
os.makedirs(player_data_dir, exist_ok=True)
p0_path = os.path.join(player_data_dir, 'Input-P0-0')
p1_path = os.path.join(player_data_dir, 'Input-P1-0')

# Player 0's input file receives the ground-truth label vector, player 1's
# the model prediction. Each file holds one line of space-separated values
# with six decimals (every value is followed by a space).
for player_path, values in ((p0_path, mpc_gt), (p1_path, mpc_pred)):
    with open(player_path, 'w') as f:
        f.write(''.join(f'{value:.6f} ' for value in values))
        f.write('\n')

In [19]:
# The MPC programs are provided in ../../MP-SPDZ/Programs/Source/dataval.mpc
# and ../../MP-SPDZ/Programs/Source/dataval_ce.mpc.
# Compile both of them from the MP-SPDZ directory with its compile.py script.
! cd ../../MP-SPDZ/ && ./compile.py dataval
! cd ../../MP-SPDZ/ && ./compile.py dataval_ce

Default bit length for compilation: 64
Default security parameter for compilation: 40
Compiling file Programs/Source/dataval.mpc


Writing to Programs/Bytecode/dataval-TruncPr(1)_47_16-1.bc
Writing to Programs/Schedules/dataval.sch
Writing to Programs/Bytecode/dataval-0.bc
Hash: d581f7c5c17b97640f4f824c5eedff66873a78cb835793dd48ef9ca82a202bfd
Program requires at most:
          10 integer inputs from player 0
          10 integer inputs from player 1
          11 integer opens
          10 integer triples
         870 integer bits
          32 virtual machine rounds
Default bit length for compilation: 64
Default security parameter for compilation: 40
Compiling file Programs/Source/dataval_ce.mpc
Writing to Programs/Bytecode/dataval_ce-TruncPr(1)_47_16-1.bc
Writing to Programs/Bytecode/dataval_ce-log2_fx(1)_31_16-3.bc
Writing to Programs/Bytecode/dataval_ce-FPDiv(1)_31_16-5.bc
Writing to Programs/Schedules/dataval_ce.sch
Writing to Programs/Bytecode/dataval_ce-0.bc
Hash: 9c484e1e427fde1fbd5ef1625cd7bcb41893e2b581946f6784ac95c242435d45
Program requires at most:
          10 integer inputs from player 0
          10 

In [20]:
#MPC for squared loss
import time

start = time.time()
! cd ../../MP-SPDZ/ && Scripts/mascot.sh dataval
end = time.time()
print(f"Time taken for squared loss computation: {end-start}")

Running /home/thomas/secure-data-valuation/MP-SPDZ/Scripts/../mascot-party.x 0 dataval -pn 19776 -h localhost -N 2
Running /home/thomas/secure-data-valuation/MP-SPDZ/Scripts/../mascot-party.x 1 dataval -pn 19776 -h localhost -N 2


Using statistical security parameter 40
Squared loss: 0.907242
The following benchmarks are including preprocessing (offline phase).
Time = 0.278077 seconds 
Data sent = 37.6783 MB in ~159 rounds (party 0 only; use '-v' for more details)
Global data sent = 75.3566 MB (all parties)
This program might benefit from some protocol options.
Consider adding the following at the beginning of your code:
	program.use_edabit(True)
Time taken for squared loss computation: 0.6344389915466309


In [21]:
start = time.time()
! cd ../../MP-SPDZ/ && Scripts/mascot.sh dataval_ce
end = time.time()
print(f"Time taken for cross entropy loss computation: {end-start}")

Running /home/thomas/secure-data-valuation/MP-SPDZ/Scripts/../mascot-party.x 0 dataval_ce -pn 17632 -h localhost -N 2
Running /home/thomas/secure-data-valuation/MP-SPDZ/Scripts/../mascot-party.x 1 dataval_ce -pn 17632 -h localhost -N 2
Using statistical security parameter 40


Cross Entropy loss: 2.33684
The following benchmarks are including preprocessing (offline phase).
Time = 0.66485 seconds 
Data sent = 118.858 MB in ~535 rounds (party 0 only; use '-v' for more details)
Global data sent = 237.716 MB (all parties)
This program might benefit from some protocol options.
Consider adding the following at the beginning of your code:
	program.use_edabit(True)
Time taken for cross entropy loss computation: 0.9749960899353027
