## Set up local torch and torchvision

In [2]:
import syft as sy
import torch
import torchvision

In [3]:
class SyNet1(sy.Module):
    def __init__(self, torch_ref):
        super(SyNet1, self).__init__(torch_ref=torch_ref)
        self.conv1 = self.torch_ref.nn.Conv2d(1, 32, 3, 1)
        self.conv2 = self.torch_ref.nn.Conv2d(32, 64, 3, 1) 

    def forward(self, x):
        x = self.conv1(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.conv2(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.torch_ref.nn.functional.max_pool2d(x, 2)
        output = self.torch_ref.nn.functional.log_softmax(x, dim=1)
        return output
    

In [4]:
class SyNet2(sy.Module):
    def __init__(self, torch_ref):
        super(SyNet2, self).__init__(torch_ref=torch_ref)
        self.dropout1 = self.torch_ref.nn.Dropout2d(0.25)
        self.dropout2 = self.torch_ref.nn.Dropout2d(0.5)
        self.fc1 = self.torch_ref.nn.Linear(9216, 128)
        self.fc2 = self.torch_ref.nn.Linear(128, 10)

    def forward(self, x):
        x = self.dropout1(x)
        x = self.torch_ref.flatten(x, 1)
        x = self.fc1(x)
        x = self.torch_ref.nn.functional.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = self.torch_ref.nn.functional.log_softmax(x, dim=1)
        return output

In [5]:
# we need some transforms for the MNIST data set
local_transform_1 = torchvision.transforms.ToTensor()  # this converts PIL images to Tensors
local_transform_2 = torchvision.transforms.Normalize(0.1307, 0.3081)  # this normalizes the dataset

# compose our transforms
local_transforms = torchvision.transforms.Compose([local_transform_1, local_transform_2])

In [6]:
# Lets define a few settings which are from the original MNIST example command line args
args = {
    "batch_size": 64,
    "test_batch_size": 1000,
    "epochs": 14,
    "lr": 1.0,
    "gamma": 0.7,
    "no_cuda": False,
    "dry_run": False,
    "seed": 42, # the meaning of life
    "log_interval": 10,
    "save_model": True,
}

In [7]:
# we will configure the test set here locally since we want to know if our Data Owner's
# private training dataset will help us reach new SOTA results for our benchmark test set
test_kwargs = {
    "batch_size": args["test_batch_size"],
}

test_data = torchvision.datasets.MNIST('../data', train=False, download=True, transform=local_transforms)
test_loader = torch.utils.data.DataLoader(test_data,**test_kwargs)

images = []
labels = []

for image, label in test_loader:
    images = image
    labels = label

In [16]:
test_loader.dataset.train_data



tensor([[[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        ...,

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]],

        [[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0,

In [7]:
model1 = SyNet1(torch)
model2 = SyNet2(torch)

> Creating local model
> Creating local model


In [8]:
activation = model1.forward(images)

remote_activation = activation.clone()

remote_activation.retain_grad()

prediction = model2.forward(remote_activation)

In [9]:
prediction.shape

torch.Size([1000, 10])

In [10]:
loss = torch.nn.functional.nll_loss(prediction, labels)
loss.backward(retain_graph=True)

In [11]:
remote_activation.grad

tensor([[[[ 6.8394e-06, -8.5468e-07, -2.2038e-06,  ..., -2.4106e-06,
            7.0512e-06,  2.5037e-06],
          [ 1.2837e-06, -2.1526e-06,  7.5765e-06,  ...,  2.6840e-06,
            2.3963e-07, -2.7948e-06],
          [ 2.9843e-06, -2.8461e-06,  2.9022e-06,  ...,  9.3773e-06,
           -5.5580e-07,  3.1974e-06],
          ...,
          [ 4.0675e-06, -3.7466e-06,  1.2139e-06,  ..., -2.7538e-06,
           -2.3363e-06,  1.1123e-05],
          [ 3.6373e-06, -2.6652e-06, -2.0236e-06,  ..., -3.4872e-06,
           -9.3439e-07,  4.3272e-06],
          [ 9.2105e-07, -7.4387e-06, -4.1883e-06,  ...,  1.2428e-06,
            7.4382e-06, -4.7440e-06]],

         [[-3.1219e-06,  1.3506e-06, -8.6934e-06,  ...,  7.9493e-07,
           -3.7253e-08,  6.2875e-06],
          [-1.1274e-05,  2.6723e-06, -2.1631e-06,  ...,  1.6658e-06,
            2.3781e-06,  5.9884e-06],
          [-1.4787e-05,  2.4529e-06, -5.4996e-06,  ..., -1.2375e-06,
           -8.5407e-06,  8.7623e-06],
          ...,
     

In [12]:
activation.backward(remote_activation.grad)