# Split Neural Network (SplitNN)

Traditionally, PySyft has been used to facilitate federated learning. However, we can also leverage the tools included in this framework to implement distributed neural networks. 

### What is a SplitNN?

<img src="images/anatomy.png" width="50%">

The training of a neural network (NN) is 'split' across a chain of multiple hosts. Each segment in the chain is a self-contained NN that feeds into the next segment. The host with the training data holds both the beginning segment and the end segment of the network. Intermediate segments of the chain are held by the other participating hosts.

### Training Process

The SplitNN network is assembled as a chain of NNs, each feeding into the next. The data subject has both the beginning and the end of this chain.

<img src="images/training.png" width="80%">

When forward propagation commences, the data subject propagates the x values forward through the segment at the start of the chain and sends the resulting activation signals to the next intermediate host. This host feeds the received activation signal forward through its own segment and on to the next link in the chain. This continues until the end of the chain is reached. The data subject then receives an activation signal, forward propagates it through the end segment, and computes the loss using their y values.

They then backpropagate, sending the gradients of the activation signals they received to the previous host in the chain. This host computes its own gradients and sends the gradient of the activation signal it received further back. Eventually, the data subject receives the gradients of the activation signal from their start segment and computes that segment's gradients.

The NNs in the chain then update their weights and biases, and the next epoch commences. When a host has finished training, they pass the start and end segments to the next host with data to train on.
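The process above can be sketched in plain PyTorch, without PySyft or any real network transport. This is a minimal illustration under stated assumptions, not the implementation used later in this tutorial: the three segments, their sizes, the random toy data and the `train_step` helper are all hypothetical, and "sending" is simulated by detaching a tensor so that no computation graph crosses a host boundary.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)

# Hypothetical three-segment chain: the data subject holds `start` and `end`,
# an intermediate host holds `middle`.
start = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
middle = nn.Sequential(nn.Linear(3, 3), nn.Sigmoid())
end = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
segments = [start, middle, end]
optimizers = [optim.SGD(s.parameters(), lr=0.2) for s in segments]

# Assumed toy data: random binary features, label = fourth feature.
x = torch.randint(0, 2, (16, 4)).float()
y = x[:, 3:].clone()

def train_step(x, y):
    for opt in optimizers:
        opt.zero_grad()

    # Forward: each hop receives a detached copy of the previous output,
    # so every host only holds the graph of its own segment.
    received, produced = [], []
    signal = x
    for segment in segments:
        signal = signal.detach().requires_grad_()
        received.append(signal)
        signal = segment(signal)
        produced.append(signal)

    # The data subject computes the loss on the output of the end segment.
    loss = ((produced[-1] - y) ** 2).sum()
    loss.backward()

    # Backward: each host hands the gradient of the signal it received
    # to the host that produced that signal.
    for i in range(len(segments) - 2, -1, -1):
        produced[i].backward(gradient=received[i + 1].grad)

    # Every segment updates its own weights and biases.
    for opt in optimizers:
        opt.step()
    return loss.item()

losses = [train_step(x, y) for _ in range(50)]
print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
```

Only two kinds of tensor cross a segment boundary here: the activation signal on the way forward and its gradient on the way back.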

<img src="images/BatchFlow.png" width="40%">


### Why use a SplitNN?

The SplitNN has been shown to provide a dramatic reduction in the computational burden of training while maintaining higher accuracy when training over a large number of clients [[1](https://arxiv.org/abs/1812.00564)]. In the figure below, the blue line denotes distributed deep learning using SplitNN, the red line denotes federated learning (FL) and the green line denotes Large Batch Stochastic Gradient Descent (LBSGD).

<img src="images/AccuracyvsFlops.png" width="60%">

<img src="images/computation.png" width="40%">
 
Table 1 shows the computational resources consumed when training on CIFAR 10 with VGG. For the SplitNN these are a fraction of the resources required by FL and LBSGD. Table 2 shows the bandwidth usage when training on CIFAR 100 with ResNet. Federated learning is less bandwidth intensive with fewer than 100 clients. However, the SplitNN outperforms the other approaches as the number of clients grows [[1](https://arxiv.org/abs/1812.00564)].

<img src="images/bandwidth.png" width="40%">

Using this technique, nobody knows the input data and labels apart from the data subject. All that is sent or received between nodes is activation signals during forward propagation and their corresponding gradients during backpropagation. Entropy can be added to the activation signals by adding layers to the model segments. The entropy of the training data could potentially be measured to arrive at the number of layers needed to adequately hide the original values in the start and end segments.

During this process, no hosts involved in the learning process have a full picture of the network. As a result there is very little risk of the model being stolen by participating hosts. Models could only be fully recovered by malicious participants if they were to collude with every other host. 

### Advantages

- The accuracy should be almost identical to a non-split version of the same model, trained locally. 
- Models and data can be homomorphically encrypted for added security at the cost of added computation.
- Model is distributed, meaning all segment holders must consent in order to aggregate the model at the end of training.
- The scalability of this approach, in terms of both network and computational resources, could make this a valid alternative to FL and LBSGD, particularly on low-power devices.
- Could be an effective mechanism for both horizontal and vertical data distributions.
- As computational cost is already quite low, the proportionate cost of homomorphic encryption is also minimised.
- Only activation signal gradients are sent and received, meaning that malicious actors cannot use gradients of model parameters to reverse engineer the original values.

### Constraints

- This is a new technique with little surrounding literature; a large amount of comparison and evaluation is still to be done.
- This approach requires all hosts to remain online during the entire learning process.
    - This makes the approach less feasible for hand-held devices.
- Less established in toolkits than FL and LBSGD.
- While most aspects of the learning process are anonymised, the intermediary hosts need to know the location of the hosts ahead of and behind them in order to send and receive data during learning. Ideally the approach would provide anonymity to participants.
- This approach is less secure with small groups training but becomes more secure as the number of model segments increases

### Tutorial 

The SplitNN has the capacity to be a significant contribution to the growing ecosystem of privacy-preserving learning methodologies. This tutorial has three purposes:

- To explain in as clear terms as possible what is going on during the training of a SplitNN.
- To provide a working example of a SplitNN learning on arbitrary sets of training data, model segments and data providers.
- To provide an implementation of this technique on the PySyft framework so that this can be further validated against other techniques. 

Authors:
- Adam Hall - Twitter: [@AJH4LL](https://twitter.com/AJH4LL)


# Section 1.1 - A Toy NN Example

<img src="images/wholeNetwork.png" width="60%">


We will begin by training a normal model to benchmark our SplitNN against. This will have exactly the same specification as the SplitNN, only not distributed. To make the two as similar as possible, we will use the same random seed for initialisation. The toy dataset consists of all 16 combinations of four binary features. As the `targets` tensor in the code below shows, the y value is 1 exactly when the fourth feature is 1.

<img src="images/benchmarkExample.png" width="60%">

We will create three identical datasets to simulate the batches for each data owner.

<img src="images/identicalDatasets.png" width="40%">
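As a rough sketch, the same toy data can also be generated programmatically rather than typed out by hand. This assumes the labelling rule implied by the `targets` tensor in the cell below (the label equals the fourth feature). The row order produced by `itertools.product` differs from the hand-written tensor, which does not matter for full-batch training.

```python
import itertools
import torch

# All 16 combinations of four binary features.
rows = list(itertools.product([0.0, 1.0], repeat=4))
data = torch.tensor(rows)        # shape (16, 4)

# Assumed labelling rule: y is the fourth feature, kept as a column vector.
targets = data[:, 3:].clone()    # shape (16, 1)

# Three identical copies, one per simulated data owner.
datasets = [(data.clone(), targets.clone()) for _ in range(3)]
print(data.shape, targets.shape, int(targets.sum()))
```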




In [1]:
import torch
from torch import nn
from torch import optim
#from torchviz import make_dot, make_dot_from_trace
from torch.autograd import Variable
import time

In [2]:
torch.manual_seed(1)

# A Toy Dataset
data = torch.tensor([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0],[1,1,0,0],[1,0,1,0],[0,1,1,0],[1,1,1,0],[0,0,0,1],[1,0,0,1],[0,1,0,1],[0,0,1,1],[1,1,0,1],[1,0,1,1],[0,1,1,1],[1,1,1,1.]])
targets = torch.tensor([[0],[0],[0],[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[1],[1],[1.]])

# Create 3 copies of the dataset
datasets = [
    (data.clone(), targets.clone())
    for i in range(3)
]

# One Model
model = nn.Sequential(
            nn.Linear(4, 3),
            nn.Tanh(),
            nn.Linear(3, 3),
            nn.Sigmoid(),
            nn.Linear(3, 3),
            nn.Sigmoid(),
            nn.Linear(3, 2),
            nn.Tanh(),
            nn.Linear(2, 1),
            nn.Sigmoid()
        )

## Model Summary

Here we see the parameter values for our model and the computation graph.

<img src="images/wholeNetwork.png" width="60%">

### Model Parameters and Computation Graph

In [3]:
# for name, param in model.named_parameters():
#     if param.requires_grad:
#         print(name, param.data)
# make_dot(model(x), params=dict(model.named_parameters()))
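If `torchviz` is not installed, a parameter summary can still be printed with PyTorch alone. The sketch below rebuilds the benchmark model so that it runs on its own; the formatting of the printout is just one possible choice.

```python
import torch
from torch import nn

torch.manual_seed(1)
model = nn.Sequential(
    nn.Linear(4, 3), nn.Tanh(),
    nn.Linear(3, 3), nn.Sigmoid(),
    nn.Linear(3, 3), nn.Sigmoid(),
    nn.Linear(3, 2), nn.Tanh(),
    nn.Linear(2, 1), nn.Sigmoid(),
)

# List each weight and bias tensor with its shape and element count.
total = 0
for name, param in model.named_parameters():
    print(f"{name:10s} shape={tuple(param.shape)} count={param.numel()}")
    total += param.numel()
print("total trainable parameters:", total)
```

The five linear layers contribute 15, 12, 12, 8 and 3 parameters respectively, so the whole network has 50 trainable values.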

## Training Logic

<img src="images/benchmarkExample.png" width="60%">

In [4]:
def train(x, y):
    # Training logic
    
    epochs = 300
    lr = 0.2
    
    opt = optim.SGD(params=model.parameters(),lr=lr)
    start_time = time.time()
    
    for epoch in range(epochs):

        # 1) erase previous gradients (if they exist)
        opt.zero_grad()

        # 2) make a prediction
        pred = model(x)

        # 3) calculate how much we missed
        loss = ((pred - y)**2).sum()

        # 4) figure out which weights caused us to miss
        loss.backward()

        # 5) change those weights
        opt.step()

        # 6) print our progress every 30 epochs
        if epoch % 30 == 0:
            print(f"Epoch: {epoch}/{epochs} \tLoss: ", "{:.4f}\tRuntime: {:.2f}s".format(loss.data, time.time() - start_time))

In [5]:
for dataset in datasets:
    print("\nNEXT BATCH")
    data, targets = dataset
    train(data, targets)

print("\nFinal Predictions:\n", torch.t(model(data)).data)



NEXT BATCH
Epoch: 0/300 	Loss:  3.9986	Runtime: 0.00s
Epoch: 30/300 	Loss:  3.9964	Runtime: 0.02s
Epoch: 60/300 	Loss:  3.9931	Runtime: 0.04s
Epoch: 90/300 	Loss:  3.9816	Runtime: 0.06s
Epoch: 120/300 	Loss:  3.8692	Runtime: 0.08s
Epoch: 150/300 	Loss:  0.0908	Runtime: 0.10s
Epoch: 180/300 	Loss:  0.0213	Runtime: 0.12s
Epoch: 210/300 	Loss:  0.0115	Runtime: 0.14s
Epoch: 240/300 	Loss:  0.0078	Runtime: 0.16s
Epoch: 270/300 	Loss:  0.0058	Runtime: 0.18s

NEXT BATCH
Epoch: 0/300 	Loss:  0.0046	Runtime: 0.00s
Epoch: 30/300 	Loss:  0.0039	Runtime: 0.02s
Epoch: 60/300 	Loss:  0.0033	Runtime: 0.05s
Epoch: 90/300 	Loss:  0.0029	Runtime: 0.07s
Epoch: 120/300 	Loss:  0.0025	Runtime: 0.09s
Epoch: 150/300 	Loss:  0.0023	Runtime: 0.11s
Epoch: 180/300 	Loss:  0.0021	Runtime: 0.13s
Epoch: 210/300 	Loss:  0.0019	Runtime: 0.15s
Epoch: 240/300 	Loss:  0.0017	Runtime: 0.17s
Epoch: 270/300 	Loss:  0.0016	Runtime: 0.19s

NEXT BATCH
Epoch: 0/300 	Loss:  0.0015	Runtime: 0.00s
Epoch: 30/300 	Loss:  0.0014	Ru

# Section 2.1 - A Distributed Training Example

We will train a SplitNN model that has been distributed to three different hosts. One host, Alice, is the data subject. Alice has the labelled data and will also be the custodian of the network start and end segments. Claire and Bob are worker hosts. They will feed the activation signals from the start of the chain forward until they reach Alice's end segment. They will do the reverse with gradients in the backpropagation process.

## Section 2.1.1 - Set up environmental variables

Here we will import our required libraries and initialise our model segments and data. We will need:

<img src="images/distributed.png" width="50%">

- A dummy distributed dataset
- 5 model segments
- 3 Virtual Workers

In [6]:
import time

import torch
from torch import nn
from torch import optim
import syft as sy
hook = sy.TorchHook(torch)

#from torchviz import make_dot, make_dot_from_trace
from torch.autograd import Variable

In [7]:
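# Give nn.Sequential a `location` property: the worker that holds the
# weights of the segment's first layer (a segment is sent to one worker as a whole).
@property
def location(self):
    first_layer = self[0]
    weight = first_layer.weight[0]
    return weight.location

nn.Sequential.location = location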

In [8]:
# A Toy Dataset
x = torch.tensor([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0],[1,1,0,0],[1,0,1,0],[0,1,1,0],[1,1,1,0],[0,0,0,1],[1,0,0,1],[0,1,0,1],[0,0,1,1],[1,1,0,1],[1,0,1,1],[0,1,1,1],[1,1,1,1.]])
y = torch.tensor([[0],[0],[0],[0],[0],[0],[0],[0],[1],[1],[1],[1],[1],[1],[1],[1.]])

torch.manual_seed(1)

# Define 5 chained models
models = [
    nn.Sequential(
        nn.Linear(4, 3),
        nn.Tanh()
    ),
    nn.Sequential(
        nn.Linear(3, 3),
        nn.Sigmoid()
    ),
    nn.Sequential(
        nn.Linear(3, 3),
        nn.Sigmoid()
    ),
    nn.Sequential(
        nn.Linear(3, 2),
        nn.Tanh()
    ),
    nn.Sequential(
        nn.Linear(2, 1),
        nn.Sigmoid()
    )
]

# create some workers
alice = sy.VirtualWorker(hook, id="alice")
bob = sy.VirtualWorker(hook, id="bob")
claire = sy.VirtualWorker(hook, id="claire")
workers = alice, bob, claire

The final predictions of the benchmark model are shown above. We can compare these with the output of the same network once it has been 'split'.
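For that comparison to be fair, both networks should start from the same weights. The following plain-PyTorch check (no PySyft needed) suggests that seeding with the same value before building the single model and before building the five segments gives identical initial parameters, because the linear layers are created in the same order. The `layer_specs` helper is a hypothetical convenience introduced only for this sketch.

```python
import torch
from torch import nn

def layer_specs():
    # (in_features, out_features, activation) for the five segments
    return [(4, 3, nn.Tanh), (3, 3, nn.Sigmoid), (3, 3, nn.Sigmoid),
            (3, 2, nn.Tanh), (2, 1, nn.Sigmoid)]

# The single benchmark model.
torch.manual_seed(1)
whole = nn.Sequential(*[m for i, o, act in layer_specs()
                        for m in (nn.Linear(i, o), act())])

# The same layers, grouped into five chained segments.
torch.manual_seed(1)
segments = [nn.Sequential(nn.Linear(i, o), act()) for i, o, act in layer_specs()]

x = torch.tensor([[0., 0., 0., 0.], [1., 1., 1., 1.]])
out = x
for segment in segments:
    out = segment(out)

same = torch.allclose(whole(x), out)
print("chained segments match the single model:", same)
```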

## Section 2.1.2 - Send Variables to Starting Locations

In this example, Alice is the worker with the data and labels. Bob and Claire are intermediary hosts in the chain. Alice has the start and end model segments (and, in the code below, the first intermediary segment as well). Bob and Claire have the remaining intermediary segments.

We send the models and data to their respective hosts and store the pointers in associative arrays: the Model Chain (MC) and the xy Chain (xyC). These contain the locations of the data, but no actual values. These are the only necessary parameters for coordinating this learning process. A summary of this is seen below.

<img src="images/Parameters.png" width="50%">

In this experiment, the models and data are initialised locally and then distributed out.

In [9]:
# Send Model Segments to starting locations
model_locations = [alice, alice, bob, claire, alice]

for model, location in zip(models, model_locations):
    model.send(location)


# Create a remote copy of the dataset for each worker
datasets = [
    sy.BaseDataset(x.send(worker), y.send(worker))
    for worker in (alice, bob, claire)
]

## Section 2.1.3 - Forward Propagation

We will need to define the logic of forward and backward propagation.

Forward propagation feeds the input data into Alice's segment at the beginning of the chain. Alice then sends her activation signal to the location of the next model in the chain. This model propagates the activation and sends it onward to the location of the next segment. The signal will eventually reach Alice's end segment, which performs the prediction. We store pointers to the activations of each layer using the Activation Chain (AC). This allows us to retrieve the values when processing gradients. When the activations have fully propagated through the MC, the method returns the resultant AC for use in the backpropagation function.

<img src="images/activationchain.png" width="50%">

In [10]:
def forward(models, x):
    # Pointers to the input and output of every segment (the Activation Chain)
    inputs = []
    outputs = []

    # First segment: provide x as input
    inputs.append(x)
    outputs.append(models[0](x))
    # Copy the activation signal and send it to the host of the next segment
    next_input = outputs[-1].copy().get().send(models[1].location)

    # Intermediate segments: propagate and pass the signal onward
    for i in range(1, len(models)-1):
        inputs.append(next_input)
        outputs.append(models[i](next_input))
        next_input = outputs[-1].copy().get().send(models[i+1].location)

    # Last segment: don't move the result to another location
    inputs.append(next_input)
    outputs.append(models[-1](next_input))

    return inputs, outputs

## Section 2.1.4 - Backward Propagation

The backpropagation function takes the MC, xyC and AC as input parameters.

<img src="images/backpropParams.png" width="80%">


First the backpropagation algorithm computes the loss on Alice's prediction. We use the sum of squared errors as our loss function.

<img src="images/loss.png" width="100%">
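The choice between the sum and the mean of the squared errors only changes the scale of the gradients. The sketch below, on made-up predictions and labels, checks that the gradient of the summed loss is n times the gradient of the mean loss, where n is the number of predictions. With a fixed learning rate, dividing by n therefore makes every update n times smaller, which is one plausible reason why training appears slower with the mean.

```python
import torch

torch.manual_seed(0)

# Made-up predictions and binary labels, purely for illustration.
pred = torch.rand(16, 1, requires_grad=True)
y = torch.randint(0, 2, (16, 1)).float()

sse = ((pred - y) ** 2).sum()    # sum of squared errors
mse = ((pred - y) ** 2).mean()   # mean squared error

(grad_sse,) = torch.autograd.grad(sse, pred)
(grad_mse,) = torch.autograd.grad(mse, pred)

n = pred.numel()
scaled_match = torch.allclose(grad_sse, n * grad_mse)
print("grad(SSE) == n * grad(MSE):", scaled_match)
```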

We then calculate the gradients for the parameters of the end segment using the chain rule.

<img src="images/chainRule.png" width="40%">

This is done automatically for the layers within each segment, but we have to recalculate a loss for each model segment during the backpropagation phase.

<img src="images/intermediateLoss.png" width="80%">




Each segment feeds the gradients of its input activation signal back to the segment behind it and updates its weights with respect to these gradients. The receiving segment computes its own loss by taking the matrix product of its original activation signal (transposed) and the gradient it has just received. The sum of the result is used to feed the error back down the line. After each segment is complete, the optimiser for that segment updates its weights. The process is repeated until the segment at the beginning of the chain is reached and Alice updates the weights of her beginning segment.
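To see how a received gradient can stand in for the true loss, the following plain-PyTorch sketch compares two ways of computing the gradients of a first segment. The two-segment model, its sizes and the random data are assumptions made for this example only. By the chain rule, summing the element-wise product of a segment's activation and the gradient it received gives a surrogate whose parameter gradients equal those of ordinary end-to-end backpropagation. For comparison, the sketch also evaluates the matrix-product form used in `backward` in this section. That form agrees with the element-wise one only when the activation has a single column, because the matrix product also sums cross terms between different columns.

```python
import torch
from torch import nn

torch.manual_seed(0)
seg1 = nn.Sequential(nn.Linear(4, 3), nn.Tanh())
seg2 = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
x = torch.randint(0, 2, (16, 4)).float()
y = x[:, 3:].clone()
params = list(seg1.parameters())

# Reference: gradients of seg1 from ordinary end-to-end backpropagation.
loss = ((seg2(seg1(x)) - y) ** 2).sum()
reference = torch.autograd.grad(loss, params)

# Split version: seg2 only ever sees a detached copy of the activation.
a = seg1(x)
a_remote = a.detach().requires_grad_()
split_loss = ((seg2(a_remote) - y) ** 2).sum()
(g,) = torch.autograd.grad(split_loss, a_remote)  # gradient sent back to seg1

# Surrogate loss on seg1's side: element-wise product, then sum.
surrogate = (a * g).sum()
split = torch.autograd.grad(surrogate, params, retain_graph=True)

# Matrix-product surrogate, as in `backward`, for comparison.
matmul_surrogate = torch.matmul(a.t(), g).sum()
matmul = torch.autograd.grad(matmul_surrogate, params)

for ref, s, m in zip(reference, split, matmul):
    print("element-wise matches reference:", torch.allclose(ref, s, atol=1e-6),
          "| max abs difference of matmul form:", (ref - m).abs().max().item())
```

Calling `a.backward(gradient=g)` would be an equivalent way to apply the received gradient without building a surrogate loss at all.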

In [11]:
def backward(models, optimizers, segment_inputs, segment_outputs, dataset):
    data, targets = dataset.data, dataset.targets

    # Destroy pre-existing gradients of the end segment
    optimizers[-1].zero_grad()

    # Loss is the sum of squared errors (not the mean). Dividing by n would
    # only rescale the gradients by 1/n, which shrinks every update at a
    # fixed learning rate.
    loss = ((segment_outputs[-1] - targets)**2).sum()

    # Compute gradients
    loss.backward()

    # End segment sends the gradient of its input activation back to the segment behind
    input_gradient = segment_inputs[-1].grad.clone().get().send(models[-2].location)

    # End segment updates weights
    optimizers[-1].step()

    # Intermediate segments: repeat the same operations, moving backward
    for i in range(len(models)-1, 1, -1):
        optimizers[i-1].zero_grad()
        intermediate_loss = torch.matmul(torch.t(segment_outputs[i-1]), input_gradient).sum()
        intermediate_loss.backward()
        input_gradient = segment_inputs[i-1].grad.clone().get().send(models[i-2].location)
        optimizers[i-1].step()

    # First segment: same again, but its input is the real input data,
    # so no gradient needs to be sent any further back
    optimizers[0].zero_grad()
    intermediate_loss = torch.matmul(torch.t(segment_outputs[0]), input_gradient).sum()
    intermediate_loss.backward()
    optimizers[0].step()

    return segment_outputs[-1], loss

## Section 2.1.5 - Run Training Logic

Now we will run the training process for 300 epochs per data owner, printing our progress every 30 epochs. The start and end segments of the model are passed from one data owner to the next once each owner has trained on their batch.

<img src="images/BatchFlow.png" width="40%">


In [12]:
def splitNN_train(models, xyChain):

    # Variables for performance metrics and training
    start_time = time.time()
    epochs = 300
    lr = 0.2

    # Create an optimiser for each segment, linked to that segment's parameters
    optimizers = [
        optim.SGD(params=model.parameters(),lr=lr)
        for model in models
    ]

    # `workers` is the global tuple of virtual workers defined earlier
    for i, local_worker in enumerate(workers):

        # Begin work on current data subject
        dataset = xyChain[i]
        
        print('*', dataset.location.id, models[0].location.id)
        
        for epoch in range(epochs):
            # Forward propagate through network until final layer is reached
            segment_inputs, segment_outputs = forward(models, dataset.data)

            # Backward propagate
            predictions, loss = backward(models, optimizers, segment_inputs, segment_outputs, dataset)

            if epoch % 30 == 0:
                print(f"Epoch: {epoch}/{epochs} \tLoss: ", "{:.4f}\tRuntime: {:.2f}s".format(loss.get().data, time.time() - start_time))
        
        # If we are not at the end of the data owner chain, send the perimeter segments to the next data owner
        if i < len(workers)-1:
            next_owner = xyChain[i+1].location
            models[0].get().send(next_owner)
            models[-1].get().send(next_owner)

            print("\nNEXT DATA OWNER\n")
            print("MODEL CHAIN LOCATIONS")
            for model in models:
                print(model.location.id)
            print("\n")

    # Send models back to researcher
    for model in models:
        model.get()

    # Perform predictions with updated weights
    out = torch.tensor([[0,0,0,0],[1,0,0,0],[0,1,0,0],[0,0,1,0],[1,1,0,0],[1,0,1,0],[0,1,1,0],[1,1,1,0],[0,0,0,1],[1,0,0,1],[0,1,0,1],[0,0,1,1],[1,1,0,1],[1,0,1,1],[0,1,1,1],[1,1,1,1.]])
    for model in models:
        out = model(out)

    print("\n\nFinal Predictions:", torch.t(out).data)
    

In [13]:
splitNN_train(models, datasets)

* alice alice
Epoch: 0/300 	Loss:  3.9986	Runtime: 0.02s
Epoch: 30/300 	Loss:  3.9976	Runtime: 0.70s
Epoch: 60/300 	Loss:  3.9973	Runtime: 1.39s
Epoch: 90/300 	Loss:  3.9970	Runtime: 2.05s
Epoch: 120/300 	Loss:  3.9966	Runtime: 2.74s
Epoch: 150/300 	Loss:  3.9963	Runtime: 3.42s
Epoch: 180/300 	Loss:  3.9959	Runtime: 4.12s
Epoch: 210/300 	Loss:  3.9955	Runtime: 4.88s
Epoch: 240/300 	Loss:  3.9951	Runtime: 5.70s
Epoch: 270/300 	Loss:  3.9946	Runtime: 6.52s

NEXT DATA OWNER

MODEL CHAIN LOCATIONS
bob
alice
bob
claire
bob


* bob bob
Epoch: 0/300 	Loss:  3.9940	Runtime: 7.36s
Epoch: 30/300 	Loss:  3.9934	Runtime: 8.15s
Epoch: 60/300 	Loss:  3.9926	Runtime: 8.97s
Epoch: 90/300 	Loss:  3.9917	Runtime: 9.80s
Epoch: 120/300 	Loss:  3.9905	Runtime: 10.70s
Epoch: 150/300 	Loss:  3.9889	Runtime: 11.52s
Epoch: 180/300 	Loss:  3.9868	Runtime: 12.31s
Epoch: 210/300 	Loss:  3.9836	Runtime: 13.11s
Epoch: 240/300 	Loss:  3.9774	Runtime: 13.91s
Epoch: 270/300 	Loss:  3.9534	Runtime: 14.70s

NEXT DATA OW

# Results

- The SplitNN trains more slowly than the centralised network.
    - This is expected due to the current processing redundancies and network latencies.
- The SplitNN takes more epochs to train correctly.
    - This could be to do with the way that loss is transferred down the line: the activation signals and corresponding gradients are multiplied as matrices and summed, and the result acts as a loss for computing the gradients of intermediate segments.
    - The same random seed is used when generating both models in order to make these as similar as possible.

# Future work

- The present approach stops anyone from having knowledge of the NN segment behind them, with one exception. When Alice passes segments 1 and 5 to Bob, Alice has prior knowledge of segment 1. When Bob then passes his x values through it and sends the resultant activation values to segment 2 (Alice), Alice knows both the model and the resultant activation signal. Alice also knows the error being fed back, so she can maintain knowledge of the model as it is updated. With knowledge of the model and the activation signals, Alice could reverse engineer Bob's x values.
    - This could potentially be solved by homomorphically encrypting the x values being fed into the model. Alice would then not know the activation signal; however, a robust approach for this is yet to be established.
- A tokenisation infrastructure or masked DNS service could be implemented to provide anonymity to hosts during training. Ideally the different 'chains' involved here could be written to a smart contract and be publicly available information.
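The first point can be made concrete with a small sketch in plain PyTorch. It assumes, as in the toy dataset, that inputs are binary, and it assumes the observer holds an exact copy of the first segment. Under those assumptions the observer can simply try every possible input and keep the one whose activation is closest to the signal received. The variable names and the three example inputs are hypothetical.

```python
import itertools
import torch
from torch import nn

torch.manual_seed(1)
# The first segment, assumed to be known to the observer.
first_segment = nn.Sequential(nn.Linear(4, 3), nn.Tanh())

# The data owner's private inputs.
private_x = torch.tensor([[1., 0., 1., 1.], [0., 1., 0., 0.], [1., 1., 1., 1.]])

with torch.no_grad():
    # Activation signals the observer receives.
    observed = first_segment(private_x)

    # Try every possible binary input and keep the closest match.
    candidates = torch.tensor(list(itertools.product([0.0, 1.0], repeat=4)))
    candidate_signals = first_segment(candidates)
    distances = (observed[:, None, :] - candidate_signals[None, :, :]).abs().sum(dim=2)
    recovered = candidates[distances.argmin(dim=1)]

print("inputs recovered exactly:", torch.equal(recovered, private_x))
```

With real-valued or high-dimensional inputs a brute-force search like this is no longer possible, but the sketch shows why knowing both the segment and its output weakens the privacy of the input.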


### TODO:
- Figure out whether we need to reconstruct the computation graph at each layer
    - I am unclear as to whether I can work with a computation graph that is distributed. For this reason I recreate the computation graph by pushing the activation signals of An-1 through again and recomputing the error with that. This means that we are doing around twice the computation, which is a bad thing. Hopefully there is a workaround!
- Implement .move() instead of .get().send()
    - This should move data directly between owners