In [1]:
import sys
sys.path.append('../')

import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F 
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

from typing import List, Tuple
import random
from uuid import uuid4
import numpy as np
import pandas as pd
import copy

import syft as sy

from util import Client, Server

hook = sy.TorchHook(torch)

## Dataset Class

In [2]:
class Dataset(torch.utils.data.Dataset):
    'Characterizes a dataset for PyTorch'
    def __init__(self, ids, data, labels, *args, **kwargs):
        'Initialization'
        super().__init__(*args, **kwargs)
        self.ids = ids
        self.data = data
        self.labels = labels

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.ids)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Load data and get label
        uuid = self.ids[index]
        if self.data is not None:
            X = self.data[index]
        else:
            X = None
        if self.labels is not None:
            y = self.labels[index]
        else:
            y = None
        return uuid, X, y
    
    def get_ids(self) -> List[str]:
        """Return a list of the ids of this dataset."""
        return [str(_) for _ in self.ids]
    
    def sort_by_ids(self):
        """Sort the dataset by IDs in ascending order"""
        ids = self.get_ids()
        sorted_idxs = np.argsort(ids)

        if self.data is not None:
            self.data = self.data[sorted_idxs] 

        if self.labels is not None:
            self.labels = self.labels[sorted_idxs]

        self.ids = self.ids[sorted_idxs]

In [3]:
def partition_dataset(
    dataset: Dataset,
    keep_order: bool = False,
) -> Tuple[Dataset, Dataset]:
    'Vertically partition a torch dataset in two'
    partition1 = copy.deepcopy(dataset)
    partition2 = copy.deepcopy(dataset)
    
    # p1 has all features, p2 has all labels
    partition1.labels = None
    partition2.data = None
    
    # disorder indexing
    idxs1 = np.arange(len(partition1)) 
    idxs2 = np.arange(len(partition2))
    
    if not keep_order:
        np.random.shuffle(idxs1)
        np.random.shuffle(idxs2)
        
    partition1.data = partition1.data[idxs1]
    partition1.ids = partition1.ids[idxs1]

    partition2.labels = partition2.labels[idxs2]
    partition2.ids = partition2.ids[idxs2]
    
    return partition1, partition2

## Partitioned Data Class

In [4]:
class VerticalDataLoader:
    def __init__(self, dataset, *args, **kwargs):
        # Split datasets
        self.partition1, self.partition2 = partition_dataset(dataset)
        
    def __len__(self):
        return len(self.partition1)
    
    def drop_non_intersection(self, intersection: List[int]):
        """Remove elements and ids in the datasets that are not in the intersection."""
        self.partition1.data = self.partition1.data[intersection]
        self.partition1.ids = self.partition1.ids[intersection]

        self.partition2.labels = self.partition2.labels[intersection]
        self.partition2.ids = self.partition2.ids[intersection]
        
    def sort_by_ids(self) -> None:
        """Sort each dataset by ids"""
        self.partition1.sort_by_ids()
        self.partition2.sort_by_ids()

## Load Data

In [5]:
# Parameters
params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}

max_epochs = 100

# Dataset
ids = np.array([uuid4() for i in range(10)])
partition = torch.randn((10, 30))
labels = torch.randint(0, 2, (10,))

# Generator
data = Dataset(ids, partition, labels)
dataloader = VerticalDataLoader(data)

## Implement PSI and order the datasets accordingly

In [6]:
if dataloader.partition1.ids[:].all() != dataloader.partition2.ids[:].all():
    print("Patitioned data is disordered")
    
# Compute private set intersection
client_items = dataloader.partition1.get_ids()
server_items = dataloader.partition2.get_ids()

client = Client(client_items)
server = Server(server_items)

setup, response = server.process_request(client.request, len(client_items))
intersection = client.compute_intersection(setup, response)

# Order data
dataloader.drop_non_intersection(intersection)
dataloader.sort_by_ids()

if dataloader.partition1.ids[:].all() == dataloader.partition2.ids[:].all():
    print("Patitioned data is aligned")

Patitioned data is aligned


In [7]:
class Parser:
    def __init__(self):
        self.epochs = 10
        self.lr = 0.01
        self.seed = 0
        self.input_size = 30 # 30 dimensions
        self.hidden_sizes = [64, 16, 4] # can be altered
        self.output_size = 2 # 0 or 1
    
args = Parser()
torch.manual_seed(args.seed)

<torch._C.Generator at 0x7feabca1b0b0>

In [19]:
class Net():
    def __init__(self, models, optimizers):
        super().__init__()
        self.models = models
        self.optimizers = optimizers
        
        self.data = []
        self.remote_tensors = []
        
        print(self.models)
        print('---------------------------')
        print(self.optimizers)
        
    def forward(self, x):
        data = []
        remote_tensors = []

#         data.append(self.models[0](x))
        print(data)
        return x
    
    def backward(self):
        
        return
    
    def zero_grads(self):
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        for opt in self.optimizers:
            opt.step()
        

In [20]:
# create workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = [alice, bob]

# create models
models = [
    nn.Sequential(
        nn.Linear(args.input_size, args.hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(args.hidden_sizes[0], args.hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(args.hidden_sizes[1], args.hidden_sizes[2]),
        nn.ReLU(),
    ),
    nn.Sequential(nn.Linear(args.hidden_sizes[2], args.output_size), nn.LogSoftmax(dim=1)),
]

# init optimizers
optimizers = [optim.SGD(model.parameters(), lr=args.lr,) for model in models]

# send models to each working node
for model, worker in zip(models, workers):
    model.send(worker)
    
# init splitNN
splitNN = Net(models, optimizers)

In [21]:
def train(features, labels, network):
    
    #1) Zero our grads
    network.zero_grads()
    
    #2) Make a prediction
    pred = network.forward(features)
    
    #3) Figure out how much we missed by
    criterion = nn.MSELoss()
    loss = criterion(pred, labels)
    
    #4) Backprop the loss on the end layer
    loss.backward()
    
    #5) Feed Gradients backward through the network
    network.backward()
    
    #6) Change the weights
    network.step()
    
    return loss, pred

In [32]:
for epoch in range(args.epochs):
    running_loss = 0
    correct_preds = 0
    total_preds = 0
    
#     for (ids1, features, labels_are_none), (ids2, features_are_none, label) in zip(dataloader.partition1, dataloader.partition2):
        # format data
    features = dataloader.partition1.data.send(models[0].location)
#         features = features.view(len(features[0]), -1)
    labels = dataloader.partition2.labels.send(models[-1].location)
#         labels = labels.view(len(labels), 1)

    print(features, labels)

    # training
    loss, preds = train(features, labels, splitNN)

    # Collect statistics
    running_loss += loss.get()
    correct_preds += preds.max(1)[1].eq(labels).sum().get().item()
    total_preds += preds.get().size(0)

    print(f"Epoch {i} - Training loss: {running_loss/len(dataloader):.3f} - Accuracy: {100*correct_preds/total_preds:.3f}")
        
        

(Wrapper)>[PointerTensor | me:93128013115 -> alice:87220430681] (Wrapper)>[PointerTensor | me:53112192320 -> bob:30425425363]
[(Wrapper)>[PointerTensor | me:2543323839 -> alice:16020958267]]


AttributeError: 'PointerTensor' object has no attribute 'size'