# Semi-honest, multi-point scenario

In this scenario, we assume Bob has multiple data points to contribute to Alice's ML model. Now Alice is trying to value the dataset as a whole, judging on the diversity, uncertainty of the datasets as well as the current model's performance on the dataset. 

# Part 0: The setup

In [1]:
#First, we define Alice's model M. We assume a simple CNN model.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim
import matplotlib.pyplot as plt
import os

#Don't use GPU for now
os.environ["CUDA_VISIBLE_DEVICES"] = ""

class LeNet(nn.Sequential):
    """
    Adaptation of LeNet that uses ReLU activations
    """

    # network architecture:
    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.act = nn.Softmax()
        

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.act(x)
        return x
    
model = LeNet()

os.makedirs('data', exist_ok=True)
torch.save(model.state_dict(), 'data/model.pth')
torch.save(model, 'data/alice_model.pth')

#Next, we define the data loader for CIFAR-10 dataset.
import torchvision
import random
import torchvision.transforms as transforms
import numpy as np

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=False,transform=transform,download=True)


# Randomly select 100 images as Bob's dataset
indices = random.sample(range(len(trainset)), 100)
selected_images = np.array([trainset[i][0].numpy() for i in indices])
selected_labels = np.array([trainset[i][1]  for i in indices])

# Save images and labels separately
# torch.save(selected_images, 'data/selected_images.pth')
# torch.save(selected_labels, 'data/selected_labels.pth')


Files already downloaded and verified


# Part 1: Clustering

Before submitting points to Alice for evaluation, Bob needs to select a subset of representative data points. To do this, we recommend using K-means clustering to select a diverse set of points where K is defined by the number of data points Alice wishs to check. Bob can select a data point closest to the centroid of each cluster. It is ultimately up to Bob to decide which points to submit, even if they are not ideal so we do not need to securely compute this step.

We further make an enhancement to pure K-means selection by trying to select the most uncertain points in each cluster. As determining the uncertainty requires model inference, we define a computing budget B which is the number of points Bob and Alice can afford to evaluate. Bob can then strategically select some points in each cluster to calculate its uncertainty, and submit the points with the highest uncertainty  to Alice. 


In [3]:
#First, we run the Kmeans clustering algorithm locally on Bob's device
from sklearn.cluster import KMeans

# Set the number of clusters
K = 10

# Reshape the images to be a 2D array (each image is flattened)
flattened_images = selected_images.reshape(selected_images.shape[0], -1)

# Perform K-means clustering
kmeans = KMeans(n_clusters=K, random_state=0).fit(flattened_images)

# Get the cluster labels
cluster_labels = kmeans.labels_

# Get the cluster centers
cluster_centers = kmeans.cluster_centers_

print("Cluster centers shape:", cluster_centers.shape)

Cluster centers shape: (10, 3072)


After clustering, we loop through each cluster and select the point with the highest uncertainty. 

In [5]:
#First, we define the uncertainty function with Crypten. 
import crypten
import crypten.mpc as mpc
import torch

crypten.init()
torch.set_num_threads(1)
ALICE = 0
BOB = 1
# Save the plaintext toy data
BOB_INPUT_PATH = 'data/bob_dp.pth'
BOB_LABEL_PATH = 'data/bob_label.pth'
ALICE_INPUT_PATH = 'data/alice_model.pth'
MPC_OUTPUT_PATH = 'data/mpc_output.pth'
INPUT_SIZE = (1,3,32,32)
crypten.common.serial.register_safe_class(LeNet)

@mpc.run_multiprocess(world_size=2)
def uncertainty_mpc():
    #Load alice model
    model = crypten.load_from_party(ALICE_INPUT_PATH, src=ALICE)
    dummy_input = torch.empty(INPUT_SIZE)
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=ALICE)
    #Load Bob's data point
    data = crypten.load_from_party(BOB_INPUT_PATH, src=BOB)
    data = data.unsqueeze(0)
    #Run model inference
    private_model.eval()
    result = private_model(data)
    maxval = result.max()
    mask_max = (result >= maxval)
    n_max = mask_max.sum()
    cond = n_max > 1
    result_masked = result - (mask_max * crypten.cryptensor(2))
    second_max_val = result_masked.max()
    diff_raw = maxval - second_max_val
    diff_final = -diff_raw * (crypten.cryptensor(1) - cond)
    # result = data.get_plain_text()[0]
    crypten.save_from_party(diff_final.get_plain_text(),MPC_OUTPUT_PATH,src=BOB)
    # return data

  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


[None, None]

In [6]:
TOTAL_BUDGET = 30 #This means maximum of 3 queries per cluster
budget = TOTAL_BUDGET // K
# Loop through each cluster
points_to_submit = []
labels_to_submit = []
for cluster_idx in range(K):
    # Get the indices of points in the current cluster
    cluster_points_indices = np.where(cluster_labels == cluster_idx)[0]
    
    # Get the points in the current cluster
    cluster_points = flattened_images[cluster_points_indices]
    
    # Calculate the distance of each point to the cluster center
    distances = np.linalg.norm(cluster_points - cluster_centers[cluster_idx], axis=1)
    
    # Sort the points by distance (from closest to furthest)
    sorted_indices = np.argsort(distances)
    
    used_budget = 0
    max_uncertainty = -999
    best_point = None
    best_label = None
    for idx in sorted_indices:
        if used_budget >= budget:
            if best_point is None:
                best_point = cluster_points[idx].reshape(3,32,32) #Simply choose the point closest if no point can be queried.
                best_label = selected_labels[cluster_points_indices[idx]]
            break
        print(f"Point index: {cluster_points_indices[idx]}, Distance: {distances[idx]}")
        point = cluster_points[idx].reshape(3,32,32)
        label = selected_labels[cluster_points_indices[idx]]
        #Reshape the point back to input shape
        point_tensor = torch.tensor(point)
        torch.save(point_tensor, BOB_INPUT_PATH)
        uncertainty_mpc()
        answer = torch.load(MPC_OUTPUT_PATH).numpy().item()
        if answer > max_uncertainty:
            max_uncertainty = answer
            best_point = point
            best_label = label
        used_budget += 1
    points_to_submit.append(best_point)    
    labels_to_submit.append(best_label)
assert len(points_to_submit) == K

Point index: 54, Distance: 15.268499374389648


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 68, Distance: 15.268500328063965


  answer = torch.load(MPC_OUTPUT_PATH).numpy().item()
  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 35, Distance: 17.567258834838867


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 34, Distance: 17.69978904724121


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 83, Distance: 17.711179733276367


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 74, Distance: 12.884759902954102


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 40, Distance: 13.395118713378906


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 61, Distance: 13.67123794555664


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 47, Distance: 13.629413604736328


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 33, Distance: 13.630913734436035


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 53, Distance: 15.057271957397461


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 58, Distance: 12.117508888244629


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 63, Distance: 17.469074249267578


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 0, Distance: 17.98807716369629


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 27, Distance: 11.064971923828125


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 95, Distance: 12.358888626098633


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 97, Distance: 17.00782585144043


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))


Point index: 88, Distance: 16.52522087097168


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 44, Distance: 17.36598014831543


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 93, Distance: 18.34111785888672


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 64, Distance: 18.631057739257812


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 4, Distance: 19.063385009765625


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 6, Distance: 21.14727783203125


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 18, Distance: 15.904213905334473


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 79, Distance: 15.904213905334473


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 45, Distance: 16.16516876220703


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 59, Distance: 18.01510238647461


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


Point index: 84, Distance: 18.15633773803711


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))
  result = load_closure(f, **kwargs)


In [7]:
bob_rep_points = torch.tensor(np.array(points_to_submit))
torch.save(bob_rep_points, BOB_INPUT_PATH)
#Turn the labels into one-hot encoding
num_classes = 10
labels_one_hot = torch.zeros(len(labels_to_submit), num_classes)
labels_one_hot.scatter_(1, torch.tensor(labels_to_submit).unsqueeze(1), 1)
torch.save(labels_one_hot, BOB_LABEL_PATH)

#Normal calculation of pairwise distance
bs = bob_rep_points.shape[0]
images_flat = bob_rep_points.view(bs, -1)
std_per_feature = images_flat.std(dim=0)
mean_std = std_per_feature.mean()
print(mean_std)
# pw_dist = torch.pdist(images_flat,p=2)
# apd = pw_dist.mean()
# print("apd1", apd)

# x_expanded = images_flat.unsqueeze(1)
# y_expanded = images_flat.unsqueeze(0)
# diff = x_expanded - y_expanded
# dist_matrix = diff.norm(p=2,dim=2)
# print(dist_matrix)
# mask_upper = torch.triu(torch.ones(bs,bs), diagonal=1)
# pairwise = dist_matrix * mask_upper
# apd = pairwise.sum() / mask_upper.sum()
# print("apd2", apd)
# ata_flattened = data.view(n, -1)
#     x_expanded = data_flattened.unsqueeze(1)
#     y_expanded = data_flattened.unsqueeze(0)
#     # crypten.print("Y expanded shape: ", y_expanded.get_plain_text().shape)
    
#     #Calculate the pairwise distance
#     diff = x_expanded - y_expanded
    
#Normalize by max distance between 0 and 1
# shape = bob_rep_points.shape[1:]
# all_zeros = torch.zeros(shape)
# all_ones = torch.ones(shape)
# max_dist = torch.norm(all_zeros - all_ones)
# print(apd, max_dist, apd/max_dist)


tensor(0.4620)


After clustering, Bob and Alices together computes an overall valuation score for the dataset:
$$v = \alpha_{1}s_1 + \alpha_2 s_2 + \alpha_3 s_3$$
Where:  
$s_1$ is the diversity score of the dataset, calculated as the average distance between each pair of points in the representative dataset.  
$s_2$ is the uncertainty score of the dataset, calculated as the average uncertainty of the points in the representative dataset.  
$s_3$ is the model performance score of the dataset, calculated as the average model loss of the points in the representative dataset.

In [12]:
alpha_1 = 0.3
alpha_2 = 0.3
alpha_3 = 0.4

#Now we define the function to calculate the valuation.
@mpc.run_multiprocess(world_size=2)
def data_valuation():
    #Load Bob's data point
    data = crypten.load_from_party(BOB_INPUT_PATH, src=BOB)
    label = crypten.load_from_party(BOB_LABEL_PATH, src=BOB)
    #Number of vectors
    n = data.size(0)
    #Flatten the data
    # crypten.print("Data: ", data.get_plain_text())
    data_flattened = data.view(n, -1)
    mean_per_feature = data_flattened.mean(dim=0)
    mean_sq_per_feature = (data_flattened * data_flattened).mean(dim=0)
    var_per_feature = mean_sq_per_feature - mean_per_feature * mean_per_feature
    
    #Calculate the diversity score
    std_per_feature = var_per_feature.sqrt()
    diversity_score = std_per_feature.mean()
    
    crypten.print("Diversity score: ", diversity_score.get_plain_text())
    # crypten.print("Max pairwise distance: ", max_pairwise.get_plain_text())

    #Calculate the Uncertainty score next
    #Load alice model
    model = crypten.load_from_party(ALICE_INPUT_PATH, src=ALICE)
    dummy_input = torch.empty(INPUT_SIZE)
    private_model = crypten.nn.from_pytorch(model, dummy_input)
    private_model.encrypt(src=ALICE)
    
    #Run model inference
    private_model.eval()
    result = private_model(data)
    entropy = -result * result.log()
    uncertainty_score = entropy.sum() / n
    crypten.print("Uncertainty score: ", uncertainty_score.get_plain_text())

    #Calculate model loss
    loss = crypten.nn.CrossEntropyLoss()
    output = private_model(data)
    loss_score = loss(output,label)
    crypten.print("Loss score: ", loss_score.get_plain_text())
    
    #Total score
    total_score = alpha_1 * diversity_score + alpha_2 * uncertainty_score + alpha_3 * loss_score
    crypten.print("Final valuation is encrypted:", crypten.is_encrypted_tensor(total_score))
    crypten.print("Final Valuation:", total_score.get_plain_text())
    
    
data_valuation()

  result = load_closure(f, **kwargs)


Diversity score:  tensor(0.4378)


  result = load_closure(f, **kwargs)
  return forward_call(*args, **kwargs)
  return forward_call(*args, **kwargs)
  param = torch.from_numpy(numpy_helper.to_array(node))
  param = torch.from_numpy(numpy_helper.to_array(node))


Uncertainty score:  tensor(2.3114)
Loss score:  tensor(2.3150)
Final valuation is encrypted: True
Final Valuation: tensor(1.7507)


[None, None]