In [289]:
import sys
sys.path.append('../')

import torch
from torchvision import datasets, transforms
from torch import nn, optim
import torch.nn.functional as F 
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.utils.data._utils.collate import default_collate

from typing import List, Tuple
import random
from uuid import uuid4, UUID
import numpy as np
import pandas as pd
import copy

import syft as sy

from util import Client, Server

# Create the PySyft hook for torch. `hook` is handed to the VirtualWorker
# constructors when the workers are created later in this notebook.
hook = sy.TorchHook(torch)



## Dataset Class

In [290]:
class VerticalDataset(Dataset):
    """Dataset of per-sample IDs with optional features and optional targets.

    Either ``data`` or ``targets`` may be ``None``; a missing part is simply
    left out of the tuples returned by ``__getitem__``.
    """

    def __init__(self, ids, data, targets, *args, **kwargs):
        """Store the ID, feature and target containers.

        Args:
            ids: indexable container of sample identifiers
            data: indexable container of features, or None
            targets: indexable container of targets, or None
            *args, **kwargs: forwarded unchanged to ``Dataset.__init__``
        """
        super().__init__(*args, **kwargs)
        self.ids = ids
        self.data = data
        self.targets = targets

    def __len__(self):
        """Return the number of samples, taken from the ID container."""
        return len(self.ids)

    def __getitem__(self, index):
        """Return ``(id, features, target)`` for ``index``, skipping ``None`` parts."""
        sample_id = self.ids[index]
        features = None if self.data is None else self.data[index]
        target = None if self.targets is None else self.targets[index]
        return tuple(
            part for part in (sample_id, features, target) if part is not None
        )

    def get_ids(self) -> List[str]:
        """Return every ID converted to ``str``, in the current dataset order."""
        return list(map(str, self.ids))

    def sort_by_ids(self):
        """Reorder ids, data and targets in place by ascending string ID."""
        order = np.argsort(self.get_ids())

        if self.data is not None:
            self.data = self.data[order]

        if self.targets is not None:
            self.targets = self.targets[order]

        self.ids = self.ids[order]

In [291]:
def add_ids(cls):
    """Class decorator that attaches a unique ID to every sample of a dataset.

    Args:
        cls: dataset class whose instances expose ``data`` and ``targets``
            attributes (either of which may be ``None``)

    Returns:
        A subclass of ``cls`` named ``VerticalDataset``. Each instance gets an
        ``ids`` array holding one freshly generated UUID per sample, and
        ``__getitem__`` returns ``(id, data, target)`` with ``None`` parts
        left out.
    """

    class VerticalDataset(cls):
        """``cls`` extended with per-sample UUIDs and ID-based sorting."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

            # The length is only known once the wrapped class is initialised.
            fresh_ids = [uuid4() for _ in range(len(self))]
            self.ids = np.array(fresh_ids)

        def __getitem__(self, index):
            sample_id = self.ids[index]
            features = None if self.data is None else self.data[index]
            target = None if self.targets is None else self.targets[index]
            return tuple(
                part for part in (sample_id, features, target) if part is not None
            )

        def __len__(self):
            # Fall back to the targets when the feature side has been removed.
            if self.data is None:
                return len(self.targets)
            return self.data.size(0)

        def get_ids(self) -> List[str]:
            """Return every ID converted to ``str``, in the current order."""
            return list(map(str, self.ids))

        def sort_by_ids(self):
            """Reorder ids, data and targets in place by ascending string ID."""
            order = np.argsort(self.get_ids())

            if self.data is not None:
                self.data = self.data[order]

            if self.targets is not None:
                self.targets = self.targets[order]

            self.ids = self.ids[order]

    return VerticalDataset

In [292]:
def partition_dataset(
    dataset: Dataset,
    keep_order: bool = False,
) -> Tuple[Dataset, Dataset]:
    """Vertically split a dataset into a feature-only and a target-only copy.

    Args:
        dataset: dataset exposing ``data``, ``targets`` and ``ids`` attributes
            and supporting ``len()``; it is deep-copied, not modified
        keep_order: when False (default) each copy is shuffled independently
            with ``np.random.shuffle``; when True both keep the input order

    Returns:
        ``(features_partition, targets_partition)``: the first has
        ``targets = None``, the second has ``data = None``; both keep ``ids``.
    """
    feature_part, target_part = copy.deepcopy(dataset), copy.deepcopy(dataset)

    # Strip the half that each party must not see.
    feature_part.targets = None
    target_part.data = None

    feature_order = np.arange(len(feature_part))
    target_order = np.arange(len(target_part))

    if not keep_order:
        # Two independent permutations, so the partitions end up misaligned.
        for order in (feature_order, target_order):
            np.random.shuffle(order)

    feature_part.data = feature_part.data[feature_order]
    feature_part.ids = feature_part.ids[feature_order]

    target_part.targets = target_part.targets[target_order]
    target_part.ids = target_part.ids[target_order]

    return feature_part, target_part

In [293]:
def id_collate_fn(batch: Tuple) -> List:
    """Collate per-sample tuples into one batched object per field.

    The stock collate function cannot batch ``UUID`` objects, so any field
    whose first element is a ``UUID`` is converted to a tuple of strings
    before being handed to ``default_collate``.

    Args:
        batch: sequence of per-sample tuples, as returned by the dataset's
            ``__getitem__`` for each index in the batch

    Returns:
        list: one collated entry per field, in the field order of the samples
    """

    def _collate_field(field):
        # UUIDs are stringified first; every other field is passed as-is.
        if isinstance(field[0], UUID):
            field = tuple(str(entry) for entry in field)
        return default_collate(field)

    # zip(*batch) regroups the samples field by field.
    return [_collate_field(column) for column in zip(*batch)]

In [294]:
class SinglePartitionDataLoader(DataLoader):
    """DataLoader for a single vertically-partitioned dataset.

    Behaves like ``DataLoader`` except that batching always goes through
    ``id_collate_fn``, which stringifies UUID fields before collating.
    """

    def __init__(self, *args, **kwargs):
        """Build the loader, forwarding every argument to ``DataLoader``."""
        super().__init__(*args, **kwargs)

        # Assigned unconditionally after construction, so it replaces any
        # ``collate_fn`` the caller may have passed in ``kwargs``.
        self.collate_fn = id_collate_fn

## Partitioned Data Class

In [295]:
class VerticalDataLoader:
    """Pair of loaders over the feature partition and the target partition.

    ``dataloader1`` serves the feature-only partition and ``dataloader2`` the
    target-only partition produced by ``partition_dataset``.
    """

    def __init__(self, dataset, *args, **kwargs):
        """Partition ``dataset`` and wrap each half in its own loader.

        Extra positional and keyword arguments are passed to both
        ``SinglePartitionDataLoader`` instances.
        """
        partitions = partition_dataset(dataset)
        self.partition1, self.partition2 = partitions

        # Sanity check: each side holds only its own half of the data.
        assert self.partition1.targets is None
        assert self.partition2.data is None

        self.dataloader1, self.dataloader2 = [
            SinglePartitionDataLoader(part, *args, **kwargs)
            for part in (self.partition1, self.partition2)
        ]

    def __len__(self):
        """Return the mean (floored) of the two loaders' lengths."""
        combined = len(self.dataloader1) + len(self.dataloader2)
        return combined // 2

    def __iter__(self):
        """Iterate over ``(batch_from_loader1, batch_from_loader2)`` pairs."""
        return zip(self.dataloader1, self.dataloader2)

    def drop_non_intersection(self, intersection: List[int]):
        """Keep only the entries selected by ``intersection`` in both datasets.

        ``intersection`` is used as an index into the data/targets and ids
        containers of each partition.
        """
        feature_ds = self.dataloader1.dataset
        feature_ds.data = feature_ds.data[intersection]
        feature_ds.ids = feature_ds.ids[intersection]

        target_ds = self.dataloader2.dataset
        target_ds.targets = target_ds.targets[intersection]
        target_ds.ids = target_ds.ids[intersection]

    def sort_by_ids(self) -> None:
        """Sort the dataset behind each loader by its ids."""
        for loader in (self.dataloader1, self.dataloader2):
            loader.dataset.sort_by_ids()

## Load Data

In [296]:
# DataLoader keyword arguments.
# NOTE(review): `params` is not passed to VerticalDataLoader below, so the
# loaders in this notebook are built with DataLoader's defaults.
params = {'batch_size': 1,
          'shuffle': True,
          'num_workers': 6}

# NOTE(review): the training loop further down reads args.epochs, not this value.
max_epochs = 100

# Synthetic dataset: 10 samples, each with a UUID, 30 random features and a 0/1 label
ids = np.array([uuid4() for i in range(10)])
features = torch.randn((10, 30))
labels = torch.randint(0, 2, (10,))

# Wrap the arrays in a VerticalDataset and split it into two partition loaders
data = VerticalDataset(ids, features, labels)
# Earlier attempt, kept for reference (see the TypeError output below this cell):
# data = add_ids(Dataset(features, labels))
dataloader = VerticalDataLoader(data)

TypeError: Dataset() takes no arguments

In [None]:
# Attributes of the feature-side partition (its `targets` were set to None)
print(vars(dataloader.dataloader1.dataset))

# Attributes of the target-side partition (its `data` was set to None)
print(vars(dataloader.dataloader2.dataset))

## Compute the private set intersection (PSI) and order the datasets accordingly

In [None]:
# Compare the two ID arrays element by element. Reducing each side with
# `.all()` first would only compare two reductions, not the actual orderings.
if not np.array_equal(
    dataloader.dataloader1.dataset.ids, dataloader.dataloader2.dataset.ids
):
    print("Partitioned data is disordered")

# Compute private set intersection
client_items = dataloader.dataloader1.dataset.get_ids()
server_items = dataloader.dataloader2.dataset.get_ids()

client = Client(client_items)
server = Server(server_items)

setup, response = server.process_request(client.request, len(client_items))
intersection = client.compute_intersection(setup, response)

# Order data
# NOTE(review): the same `intersection` is used to index BOTH partitions, which
# were shuffled independently; presumably it holds positions in the client's
# item list. Confirm against util.Client.compute_intersection.
dataloader.drop_non_intersection(intersection)
dataloader.sort_by_ids()

# After sorting, both partitions must list exactly the same IDs in the same order.
if np.array_equal(
    dataloader.dataloader1.dataset.ids, dataloader.dataloader2.dataset.ids
):
    print("Partitioned data is aligned")

In [None]:
class Parser:
    """Hyper-parameter container for the split-network experiment.

    Calling ``Parser()`` with no arguments yields the notebook's original
    settings; any value can be overridden by keyword.

    Args:
        epochs: number of training epochs
        lr: learning rate passed to the optimizers
        seed: seed for ``torch.manual_seed``
        input_size: number of input features per sample
        hidden_sizes: widths of the hidden layers; defaults to ``[64, 16, 4]``
        output_size: number of output classes
    """

    def __init__(
        self,
        epochs=10,
        lr=0.01,
        seed=0,
        input_size=30,
        hidden_sizes=None,
        output_size=2,
    ):
        self.epochs = epochs
        self.lr = lr
        self.seed = seed
        self.input_size = input_size  # 30 dimensions by default
        # Build a fresh list per instance so configurations never share state.
        self.hidden_sizes = [64, 16, 4] if hidden_sizes is None else list(hidden_sizes)
        self.output_size = output_size  # 0 or 1 by default
    
# Instantiate the configuration and seed torch's RNG with it.
# NOTE(review): the synthetic data and the partition shuffle (which uses numpy's
# RNG) are produced in earlier cells, before this seed is set, so this call
# does not make those steps reproducible.
args = Parser()
torch.manual_seed(args.seed)

In [None]:
class Net:
    """Pass-through stand-in exposing the same methods as ``SplitNN``.

    ``forward`` returns its input untouched and ``backward`` does nothing;
    only the optimizer helpers act on the objects given to the constructor.
    """

    def __init__(self, models, optimizers):
        """Keep references to the models and optimizers; start with empty caches."""
        self.models, self.optimizers = models, optimizers

        self.data, self.remote_tensors = [], []

    def forward(self, x):
        """Return ``x`` unchanged (no model is applied)."""
        return x

    def backward(self):
        """No-op placeholder; returns ``None``."""
        return None

    def zero_grads(self):
        """Call ``zero_grad()`` on every optimizer, in order."""
        for optimizer in self.optimizers:
            optimizer.zero_grad()

    def step(self):
        """Call ``step()`` on every optimizer, in order."""
        for optimizer in self.optimizers:
            optimizer.step()
        

In [None]:
class SplitNN:
    """Chain of model segments trained with split learning.

    Each segment's output is detached, moved to the next segment's location
    if the two differ, and re-enabled for gradients before being fed to the
    next segment. Gradients are handed back across those cut points in
    ``backward``.

    Args:
        models: ordered sequence of model segments; each must be callable and
            expose a ``location`` attribute
        optimizers: optimizers to drive in ``zero_grads`` and ``step``
    """

    def __init__(self, models, optimizers):
        self.models = models
        self.optimizers = optimizers

        # Per-segment outputs and the detached tensors passed between segments;
        # both are refreshed on every forward pass.
        self.data = []
        self.remote_tensors = []

    def forward(self, x):
        """Run ``x`` through every segment and return the last segment's output.

        Stores the segment outputs in ``self.data`` and the detached hand-over
        tensors in ``self.remote_tensors`` for use by ``backward``.
        """
        data = []
        remote_tensors = []

        # Use the instance's own segments, not a global `models` variable.
        last = len(self.models) - 1
        inputs = x
        for i, model in enumerate(self.models):
            output = model(inputs)
            data.append(output)

            if i == last:
                break

            # Cut the graph here; relocate only when the next segment lives elsewhere.
            next_location = self.models[i + 1].location
            handover = output.detach()
            if output.location != next_location:
                handover = handover.move(next_location)
            remote_tensors.append(handover.requires_grad_())
            inputs = remote_tensors[-1]

        self.data = data
        self.remote_tensors = remote_tensors

        return data[-1]

    def backward(self):
        """Propagate gradients from each hand-over tensor into the preceding segment.

        Must be called after the loss has been back-propagated through the
        final segment, so that the hand-over tensors carry ``.grad``.
        """
        for i in range(len(self.models) - 2, -1, -1):
            grads = self.remote_tensors[i].grad.copy()
            if self.remote_tensors[i].location != self.data[i].location:
                grads = grads.move(self.data[i].location)

            self.data[i].backward(grads)

    def zero_grads(self):
        """Call ``zero_grad()`` on every optimizer."""
        for opt in self.optimizers:
            opt.zero_grad()

    def step(self):
        """Call ``step()`` on every optimizer."""
        for opt in self.optimizers:
            opt.step()

In [None]:
# create workers: one virtual PySyft worker per model segment
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
workers = [alice, bob]

# create models: the first segment maps the input features down to the last
# hidden width; the second segment maps that to log-probabilities per class
models = [
    nn.Sequential(
        nn.Linear(args.input_size, args.hidden_sizes[0]),
        nn.ReLU(),
        nn.Linear(args.hidden_sizes[0], args.hidden_sizes[1]),
        nn.ReLU(),
        nn.Linear(args.hidden_sizes[1], args.hidden_sizes[2]),
        nn.ReLU(),
    ),
    nn.Sequential(nn.Linear(args.hidden_sizes[2], args.output_size), nn.LogSoftmax(dim=1)),
]

# init optimizers: one SGD optimizer per segment
optimizers = [optim.SGD(model.parameters(), lr=args.lr,) for model in models]

# send models to each working node
for model, worker in zip(models, workers):
    model.send(worker)

# init splitNN
# Use the real split network; the pass-through `Net` stub returns its input
# unchanged, so no segment would ever be applied or trained.
splitNN = SplitNN(models, optimizers)

In [None]:
def train(features, labels, network):
    """Run one optimisation step of ``network`` on a single batch.

    Args:
        features: input batch passed to ``network.forward``
        labels: class indices for the batch
        network: object exposing ``zero_grads()``, ``forward(x)``,
            ``backward()`` and ``step()`` (e.g. ``SplitNN``); its output is
            expected to be per-class log-probabilities

    Returns:
        tuple: ``(loss, pred)`` -- the batch loss and the network output
    """
    # 1) Zero our grads
    network.zero_grads()

    # 2) Make a prediction
    pred = network.forward(features)

    # 3) Figure out how much we missed by.
    # The final segment ends in LogSoftmax and the labels are class indices, so
    # negative log-likelihood is the matching criterion (MSELoss would compare
    # a (batch, classes) float tensor with a (batch,) integer tensor).
    criterion = nn.NLLLoss()
    loss = criterion(pred, labels)

    # 4) Backprop the loss on the end layer
    loss.backward()

    # 5) Feed gradients backward through the network
    network.backward()

    # 6) Change the weights
    network.step()

    return loss, pred

In [None]:
for epoch in range(args.epochs):
    running_loss = 0
    correct_preds = 0
    total_preds = 0

    for (ids1, features), (ids2, labels) in dataloader:
        # format data: features go to the first segment's worker, labels to
        # the last segment's worker, where the loss is computed
        features = features.send(models[0].location)
        features = features.view(features.shape[0], -1)
        labels = labels.send(models[-1].location)

        # training
        loss, preds = train(features, labels, splitNN)

        # Collect statistics
        running_loss += loss.get()
        correct_preds += preds.max(1)[1].eq(labels).sum().get().item()
        total_preds += preds.get().size(0)

    # Report with the loop variable `epoch` (an undefined `i` raised NameError here).
    print(f"Epoch {epoch} - Training loss: {running_loss/len(dataloader):.3f} - Accuracy: {100*correct_preds/total_preds:.3f}")
        
        