## Load in Dependencies and declare split model

In [1]:
import syft as sy
import torch
import torchvision

In [2]:
class SyNet1(sy.Module):
    def __init__(self, torch_ref):
        super(SyNet1, self).__init__(torch_ref=torch_ref)
        self.conv1 = self.torch_ref.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = self.torch_ref.nn.Conv2d(32, 64, 3, 1) 

    def forward(self, x):
        x = self.conv1(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.conv2(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.torch_ref.nn.functional.max_pool2d(x, 2)
        output = self.torch_ref.nn.functional.log_softmax(x, dim=1)
        return output
    

In [3]:
class SyNet2(sy.Module):
    def __init__(self, torch_ref):
        super(SyNet2, self).__init__(torch_ref=torch_ref)
        self.dropout1 = self.torch_ref.nn.Dropout2d(0.25)
        self.dropout2 = self.torch_ref.nn.Dropout2d(0.5)
        self.fc1 = self.torch_ref.nn.Linear(9216, 128)
        self.fc2 = self.torch_ref.nn.Linear(128, 10)

    def forward(self, x):
        x = self.dropout1(x)
        x = self.torch_ref.flatten(x, 1)
        x = self.fc1(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = self.torch_ref.nn.functional.log_softmax(x, dim=1)
        return output

### Declare our transforms and starting parameters 

In [4]:
# we need some transforms for the MNIST data set
local_transform_1 = torchvision.transforms.ToTensor()  # this converts PIL images to Tensors
local_transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset

# compose our transforms
local_transforms = torchvision.transforms.Compose([local_transform_1, local_transform_2])

args = {
    "batch_size": 64,
    "test_batch_size": 1000,
    "epochs": 14,
    "lr": 1.0,
    "gamma": 0.7,
    "no_cuda": False,
    "dry_run": False,
    "seed": 42, # the meaning of life
    "log_interval": 10,
    "save_model": True,
}

### Split our sample data items up

In [6]:
test_kwargs = {
    "batch_size": args["test_batch_size"],
}

test_data = torchvision.datasets.MNIST('../data', train=False, download=True, transform=local_transforms)
test_loader = torch.utils.data.DataLoader(test_data,**test_kwargs)

images = []
labels = []

for image, label in test_loader:
    images = image
    labels = label

### Declare local models

In [7]:
model1 = SyNet1(torch)
model2 = SyNet2(torch)

> Creating local model
> Creating local model


In [8]:
## Perform prediction with first model half
activation = model1.forward(images)

## Clone the output, detach the clone from the current computation graph
remote_activation = activation.clone().detach()

## Make sure that our activation clone will accomodate gradients
remote_activation.requires_grad=True

## perform the rest of the prediciton 
prediction = model2.forward(remote_activation)

In [9]:
## Perform backprop on first half of the model
loss = torch.nn.functional.nll_loss(prediction, labels)
loss.backward(retain_graph=True)

In [14]:
## Check if gradients have backpropped to remote activation
remote_activation.grad[0]

tensor([[[ 0.0000e+00, -0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
          -0.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 0.0000e+00, -0.0000e+00,  0.0000e+00,  ..., -0.0000e+00,
          -0.0000e+00, -0.0000e+00],
         [-0.0000e+00, -0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00, -0.0000e+00],
         [ 0.0000e+00,  0.0000e+00, -0.0000e+00,  ...,  0.0000e+00,
          -0.0000e+00, -0.0000e+00]],

        [[-1.7958e-06,  3.7914e-08, -1.7888e-06,  ..., -9.2197e-07,
           3.0884e-07,  3.3414e-07],
         [ 1.4294e-06, -4.8165e-07,  1.9906e-06,  ...,  1.8743e-07,
          -5.9748e-07, -1.1580e-06],
         [ 6.5884e-07, -1.3260e-06, -2.6586e-07,  ...,  6.5281e-07,
          -1.7911e-06,  6.3171e-07],
         ...,
         [ 2.0335e-06, -9

In [11]:
## Check if gradients have backpropped to final layer of model 1 (Should be empty)
model1.parameters()[0].grad

In [12]:
## Compute jacobian backprop use remote activation grads
activation.backward(remote_activation.grad)

In [16]:
## Check if gradients have backpropped to final layer of model 1 (Should be full)
model1.parameters()[0].grad

tensor([[[[ 1.5766e-03,  1.6416e-03,  1.1055e-03],
          [ 6.8521e-04,  3.2545e-05, -2.9522e-05],
          [-7.1609e-04, -5.7486e-04, -4.7383e-04]]],


        [[[-6.5282e-04,  9.3117e-04,  4.1995e-04],
          [ 1.1060e-04,  1.6374e-03,  1.1962e-03],
          [ 1.1581e-03,  2.1012e-03,  2.3780e-03]]],


        [[[ 2.1236e-03,  3.9780e-03,  5.4800e-03],
          [ 5.1397e-03,  6.8246e-03,  7.6834e-03],
          [ 6.8627e-03,  8.1989e-03,  7.4044e-03]]],


        [[[ 7.6919e-03,  7.9978e-03,  6.7736e-03],
          [ 8.7192e-03,  8.8712e-03,  6.2176e-03],
          [ 7.2188e-03,  7.0097e-03,  4.3711e-03]]],


        [[[-8.6447e-04, -6.0188e-05,  2.8168e-04],
          [-9.1491e-04, -1.0318e-03, -1.7026e-04],
          [-7.8739e-04, -8.9623e-04,  1.0096e-04]]],


        [[[-4.6456e-03, -2.9234e-03, -2.8406e-03],
          [-3.7162e-03, -2.7176e-03, -4.7130e-03],
          [-6.5337e-03, -6.4974e-03, -7.6757e-03]]],


        [[[ 9.2172e-03,  4.4644e-03,  1.0688e-03],
       