<a href="https://colab.research.google.com/github/walexi/pytorch_challenge/blob/master/demo_application.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Encrypted ML in Health Care
In this quick demo, we show the distributed training of a the Graph Attention Networks (GATConv from Veličković et al.: [Graph Attention Networks](https://arxiv.org/abs/1710.10903) (ICLR 2018)) on a dataset (to simulate private health data(1))




---



*1 The protein-protein interaction networks from the `"Predicting
    Multicellular Function through Multi-layer Tissue Networks"
    <https://arxiv.org/abs/1707.04638>`_ paper, containing positional gene
    sets, motif gene sets and immunological signatures as features (50 in
    total) and gene ontology sets as labels (121 in total).*

## Imports and training configuration

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import time
import os.path as osp
from sklearn.metrics import f1_score


In [0]:
!pip install --verbose --no-cache-dir torch-scatter
!pip install --verbose --no-cache-dir torch-sparse
!pip install --verbose --no-cache-dir torch-cluster
!pip install --verbose --no-cache-dir torch-spline-conv
!pip install torch-geometric

  x86_64-linux-gnu-g++ -pthread -shared -Wl,-O1 -Wl,-Bsymbolic-functions -Wl,-Bsymbolic-functions -Wl,-z,relro -Wl,-Bsymbolic-functions -Wl,-z,relro -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 build/temp.linux-x86_64-3.6/cpu/graclus.o -o build/lib.linux-x86_64-3.6/torch_cluster/graclus_cpu.cpython-36m-x86_64-linux-gnu.so
  building 'torch_cluster.grid_cpu' extension
  x86_64-linux-gnu-gcc -pthread -DNDEBUG -g -fwrapv -O2 -Wall -g -fstack-protector-strong -Wformat -Werror=format-security -Wdate-time -D_FORTIFY_SOURCE=2 -fPIC -I/usr/local/lib/python3.6/dist-packages/torch/include -I/usr/local/lib/python3.6/dist-packages/torch/include/torch/csrc/api/include -I/usr/local/lib/python3.6/dist-packages/torch/include/TH -I/usr/local/lib/python3.6/dist-packages/torch/include/THC -I/usr/include/python3.6m -c cpu/grid.cpp -o build/temp.linux-x86_64-3.6/cpu/grid.o -DTORCH_API_INCLUDE_EXTENSION_H -DTORCH_EXTENSION_NAME=grid_cpu -D_GLIBCXX_USE_CXX11_AB

In [0]:
!pip install syft

This class describes all the hyper-parameters for the training. Note that they are all public here.

In [0]:
class Arguments():
    def __init__(self):
        self.batch_size = 64
        self.test_batch_size = 64
        self.epochs = 10
        self.lr = 0.02
        self.seed = 1
        self.log_interval = 1 # Log info at each batch
        self.precision_fractional = 3

args = Arguments()

_ = torch.manual_seed(args.seed)

Here are PySyft imports. We connect to two remote workers that be call `alice` and `bob` and request another worker called the `crypto_provider` who gives all the crypto primitives we may need.

In [0]:
import syft as sy  # import the Pysyft library
hook = sy.TorchHook(torch)  # hook PyTorch to add extra functionalities like Federated and Encrypted Learning

# simulation functions
def connect_to_workers(n_workers):
    return [
        sy.VirtualWorker(hook, id=f"worker{i+1}")
        for i in range(n_workers)
    ]
def connect_to_crypto_provider():
    return sy.VirtualWorker(hook, id="crypto_provider")

workers = connect_to_workers(n_workers=2)
crypto_provider = connect_to_crypto_provider()



## Getting access and secret share data

Here we're using a utility function which simulates the following behaviour: we assume the dataset is distributed in parts each of which is held by one of our workers. The workers then split their data in batches and secret share their data between each others. The final object returned is an iterable on these secret shared batches, that we call the **private data loader**. Note that during the process the local worker (so us) never had access to the data.

We obtain as usual a training and testing private dataset, and both the inputs and labels are secret shared.

In [0]:
from torch_geometric.datasets import PPI
from torch_geometric.data import DataLoader
from torch_geometric.nn import GATConv

# We don't use the whole dataset for efficiency purpose, but feel free to increase these numbers
n_train_items = 640
n_test_items = 640

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'PPI')
train_dataset = PPI(path, split='train')
val_dataset = PPI(path, split='val')
test_dataset = PPI(path, split='test')
# val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)


def get_private_data_loaders(precision_fractional, workers, crypto_provider):
    
    def one_hot_of(index_tensor):
        """
        Transform to one hot tensor
        
        Example:
            [0, 3, 9]
            =>
            [[1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
             [0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]]
            
        """
        onehot_tensor = torch.zeros(*index_tensor.shape, 10) # 10 classes for MNIST
        onehot_tensor = onehot_tensor.scatter(1, index_tensor.view(-1, 1), 1)
        return onehot_tensor
        
    def secret_share(tensor):
        """
        Transform to fixed precision and secret share a tensor
        """
        return (
            tensor
            .fix_precision(precision_fractional=precision_fractional)
            .share(*workers, crypto_provider=crypto_provider, requires_grad=True)
        )
    
    transformation = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    
    private_train_loader = [
        (secret_share(data), secret_share(one_hot_of(target)))
        for i, (data, target) in enumerate(train_loader)
        if i < n_train_items / args.batch_size
    ]
    
    test_loader = DataLoader(test_dataset, batch_size=args.test_batch_size, shuffle=False)
    
    private_test_loader = [
        (secret_share(data), secret_share(target.float()))
        for i, (data, target) in enumerate(test_loader)
        if i < n_test_items / args.test_batch_size
    ]
    
    return private_train_loader, private_test_loader
    
    
private_train_loader, private_test_loader = get_private_data_loaders(
    precision_fractional=args.precision_fractional,
    workers=workers,
    crypto_provider=crypto_provider
)

## Model specification

Here is the model that we will use, it's a rather simple one but [it has proved to perform reasonably well on MNIST](https://towardsdatascience.com/handwritten-digit-mnist-pytorch-977b5338e627)

In [0]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GATConv(train_dataset.num_features, 256, heads=4)
        self.lin1 = nn.Linear(train_dataset.num_features, 4 * 256)
        self.conv2 = GATConv(4 * 256, 256, heads=4)
        self.lin2 = nn.Linear(4 * 256, 4 * 256)
        self.conv3 = GATConv(4 * 256, train_dataset.num_classes, heads=6, concat=False)
        self.lin3 = nn.Linear(4 * 256, train_dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index) + self.lin1(x))
        x = F.elu(self.conv2(x, edge_index) + self.lin2(x))
        x = self.conv3(x, edge_index) + self.lin3(x)
        return x


## Training and testing functions

The training is done almost as usual, the real difference is that we can't use losses like negative log-likelihood (`F.nll_loss` in PyTorch) because it's quite complicated to reproduce these functions with SMPC. Instead, we use a simpler Mean Square Error loss.

In [0]:
def train(args, model, private_train_loader, loss_op, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(private_train_loader): # <-- now it is a private dataset
        start_time = time.time()
        num_graphs = data.num_graphs
        data.batch = None        
        optimizer.zero_grad()
        
        output = model(data)
        
        # loss = F.nll_loss(output, target)  <-- not possible here
        batch_size = output.shape[0]
        loss = ((output - target)**2).sum().refresh()/batch_size
        
        loss = loss_op(model(data.x, data.edge_index), data.y)
        total_loss += loss.item() * num_graphs

        loss.backward()
        
        optimizer.step()

        if batch_idx % args.log_interval == 0:
            loss = loss.get().float_precision()
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tTime: {:.3f}s'.format(
                epoch, batch_idx * args.batch_size, len(private_train_loader) * args.batch_size,
                100. * batch_idx / len(private_train_loader), total_loss / len(train_loader.dataset), time.time() - start_time))
            

The test function 

In [0]:
def test(args, model, private_test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in private_test_loader:
            start_time = time.time()
            
            output = model(data)
            pred = output.argmax(dim=1)
            correct += pred.eq(target.view_as(pred)).sum()

    correct = correct.get().float_precision()
    print('\nTest set: Accuracy: {}/{} ({:.0f}%)\n'.format(
        correct.item(), len(private_test_loader)* args.test_batch_size,
        100. * correct.item() / (len(private_test_loader) * args.test_batch_size)))

### Let's launch the training !

A few notes about what's happening here. First, we secret share all the model parameters across our workers. Second, we convert optimizer's hyperparameters to fixed precision. Note that we don't need to secret share them because they are public in our context, but as secret shared values live in finite fields we still need to move them in finite fields using using `.fix_precision`, in order to perform consistently operations like the weight update $W \leftarrow W - \alpha * \Delta W$.

In [0]:
model = Net()
model = model.fix_precision().share(*workers, crypto_provider=crypto_provider, requires_grad=True)

loss_op = nn.BCEWithLogitsLoss()
loss_op = loss_op.fix_precision()

optimizer = optim.Adam(model.parameters(), lr=0.005)
optimizer = optimizer.fix_precision() 

for epoch in range(1, args.epochs + 1):
    train(args, model, private_train_loader, optimizer, epoch)
    test(args, model, private_test_loader)


Test set: Accuracy: 300.0/640 (47%)


Test set: Accuracy: 462.0/640 (72%)


Test set: Accuracy: 497.0/640 (78%)


Test set: Accuracy: 517.0/640 (81%)


Test set: Accuracy: 529.0/640 (83%)


Test set: Accuracy: 539.0/640 (84%)


Test set: Accuracy: 544.0/640 (85%)


Test set: Accuracy: 553.0/640 (86%)

