In [1]:
SYFT_VERSION = ">=0.8.2.b0,<0.9"
package_string = f'"syft{SYFT_VERSION}"'
# %pip install {package_string} -q

In [8]:
# third party
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np

# syft absolute
import syft as sy

sy.requires(SYFT_VERSION)

✅ The installed version of syft==0.8.7 matches the requirement >=0.8.2b0 and the requirement <0.9


In [3]:
server = sy.orchestra.launch(name="test-datasite-3", dev_mode=True, reset=True)

In [4]:
datasite_client = server.login(email="info@openmined.org", password="changethis")

Logged into <test-datasite-3: High side Datasite> as <info@openmined.org>


In [5]:
# Set the random seed for reproducibility
torch.manual_seed(42)

<torch._C.Generator at 0x7f1623ff77f0>

In [6]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

Files already downloaded and verified


In [11]:


# Split the dataset between two workers
trainloader = torch.utils.data.DataLoader(trainset, batch_size=len(trainset)//20, shuffle=True)
# Generate random data
data_batches = iter(trainloader)
train_data = next(data_batches)

In [27]:
train_data

[tensor([[[[ 0.1373,  0.1216,  0.1294,  ...,  0.1294,  0.1294,  0.1216],
           [ 0.1451,  0.1294,  0.1373,  ...,  0.1373,  0.1373,  0.1294],
           [ 0.1451,  0.1294,  0.1373,  ...,  0.1373,  0.1373,  0.1294],
           ...,
           [ 0.1294,  0.1137,  0.1216,  ...,  0.1451,  0.1451,  0.1373],
           [ 0.1294,  0.1137,  0.1216,  ...,  0.1373,  0.1373,  0.1373],
           [ 0.1216,  0.1059,  0.1137,  ...,  0.1216,  0.1294,  0.1373]],
 
          [[ 0.4980,  0.4824,  0.4824,  ...,  0.4902,  0.4902,  0.4824],
           [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4902],
           [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4902],
           ...,
           [ 0.5216,  0.4980,  0.5059,  ...,  0.5137,  0.5137,  0.4980],
           [ 0.5216,  0.4980,  0.5059,  ...,  0.5137,  0.5059,  0.4980],
           [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4980]],
 
          [[ 0.9686,  0.9373,  0.9451,  ...,  0.9294,  0.9294,  0.9373],
           [ 

In [26]:
for inputs, labels in train_data:
    print(inputs.shape, labels.shape)

ValueError: too many values to unpack (expected 2)

In [7]:
assert torch.round(train_data.sum()) == 1557

In [12]:
train = sy.ActionObject.from_obj(train_data)

In [14]:
train_datasite_obj = train.send(datasite_client)
type(train_datasite_obj)

syft.service.action.action_object.AnyActionObject

In [15]:
train_datasite_obj


**Pointer**

[tensor([[[[ 0.1373,  0.1216,  0.1294,  ...,  0.1294,  0.1294,  0.1216],
          [ 0.1451,  0.1294,  0.1373,  ...,  0.1373,  0.1373,  0.1294],
          [ 0.1451,  0.1294,  0.1373,  ...,  0.1373,  0.1373,  0.1294],
          ...,
          [ 0.1294,  0.1137,  0.1216,  ...,  0.1451,  0.1451,  0.1373],
          [ 0.1294,  0.1137,  0.1216,  ...,  0.1373,  0.1373,  0.1373],
          [ 0.1216,  0.1059,  0.1137,  ...,  0.1216,  0.1294,  0.1373]],

         [[ 0.4980,  0.4824,  0.4824,  ...,  0.4902,  0.4902,  0.4824],
          [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4902],
          [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4902],
          ...,
          [ 0.5216,  0.4980,  0.5059,  ...,  0.5137,  0.5137,  0.4980],
          [ 0.5216,  0.4980,  0.5059,  ...,  0.5137,  0.5059,  0.4980],
          [ 0.5137,  0.4902,  0.4980,  ...,  0.5059,  0.5059,  0.4980]],

         [[ 0.9686,  0.9373,  0.9451,  ...,  0.9294,  0.9294,  0.9373],
          [ 0.9843,  0.9529,  0.9608,  ...,  0.9451,  0.9451,  0.9529],
          [ 0.9843,  0.9529,  0.9608,  ...,  0.9451,  0.9451,  0.9529],
          ...,
          [ 0.9843,  0.9529,  0.9608,  ...,  0.9529,  0.9529,  0.9608],
          [ 0.9765,  0.9529,  0.9608,  ...,  0.9529,  0.9529,  0.9608],
          [ 0.9765,  0.9451,  0.9529,  ...,  0.9608,  0.9608,  0.9608]]],


        [[[-0.3647, -0.1294,  0.1529,  ..., -0.5686, -0.5451,  0.0039],
          [-0.2157, -0.0824,  0.1451,  ..., -0.6078, -0.6235, -0.4353],
          [-0.0510, -0.0275,  0.0196,  ..., -0.6941, -0.6392, -0.5216],
          ...,
          [-0.8275, -0.8275, -0.8667,  ..., -0.7569, -0.7804, -0.8039],
          [-0.7647, -0.7804, -0.8196,  ..., -0.4824, -0.4902, -0.4510],
          [-0.5294, -0.5294, -0.5451,  ..., -0.2627, -0.3020, -0.2706]],

         [[-0.2784, -0.0275,  0.2706,  ..., -0.5216, -0.4980,  0.0431],
          [-0.0902,  0.0431,  0.2706,  ..., -0.5686, -0.5608, -0.3647],
          [ 0.0824,  0.0980,  0.1373,  ..., -0.6627, -0.5922, -0.4745],
          ...,
          [-0.8431, -0.8431, -0.8745,  ..., -0.7647, -0.7961, -0.8196],
          [-0.7725, -0.7804, -0.8196,  ..., -0.4588, -0.4667, -0.4353],
          [-0.5137, -0.5216, -0.5294,  ..., -0.2078, -0.2471, -0.2235]],

         [[-0.3176, -0.0510,  0.2392,  ..., -0.5451, -0.5059,  0.1373],
          [-0.1843, -0.0510,  0.1765,  ..., -0.5686, -0.6078, -0.3569],
          [-0.0431,  0.0196,  0.1137,  ..., -0.6784, -0.6314, -0.4980],
          ...,
          [-0.7569, -0.7725, -0.8118,  ..., -0.6784, -0.7098, -0.7412],
          [-0.6941, -0.7098, -0.7490,  ..., -0.2706, -0.2706, -0.2078],
          [-0.3333, -0.3412, -0.3490,  ...,  0.1059,  0.0667,  0.1059]]],


        [[[ 0.0745,  0.0745,  0.0824,  ..., -0.0196, -0.0745, -0.1216],
          [ 0.1373,  0.0902,  0.0902,  ...,  0.0431, -0.0118, -0.0824],
          [ 0.1529,  0.1137,  0.1373,  ..., -0.0039,  0.0039, -0.0902],
          ...,
          [-0.2549, -0.0510, -0.3333,  ..., -0.0980, -0.0980,  0.0118],
          [-0.2549, -0.1529, -0.3412,  ...,  0.0039, -0.1373, -0.0588],
          [-0.3725, -0.2471,  0.0039,  ...,  0.0353, -0.1294, -0.0667]],

         [[ 0.0824,  0.0902,  0.1059,  ...,  0.0745,  0.0118, -0.0431],
          [ 0.1451,  0.1059,  0.1137,  ...,  0.1373,  0.0824,  0.0039],
          [ 0.1608,  0.1294,  0.1608,  ...,  0.0902,  0.0980, -0.0039],
          ...,
          [-0.0745,  0.3725,  0.2863,  ..., -0.0824, -0.0980,  0.0118],
          [-0.0745,  0.1765, -0.0745,  ...,  0.0118, -0.1294, -0.0667],
          [-0.2471, -0.0667,  0.0667,  ...,  0.0510, -0.1294, -0.0667]],

         [[-0.4824, -0.4980, -0.5137,  ..., -0.4824, -0.5294, -0.5529],
          [-0.4118, -0.4824, -0.4980,  ..., -0.4431, -0.4667, -0.5216],
          [-0.4039, -0.4510, -0.4588,  ..., -0.5059, -0.4824, -0.5451],
          ...,
          [-0.2392,  0.5686,  0.5137,  ..., -0.6314, -0.6157, -0.4824],
          [-0.1294,  0.3647, -0.1529,  ..., -0.5216, -0.6392, -0.5373],
          [-0.4824, -0.2078, -0.4118,  ..., -0.4824, -0.6314, -0.5373]]],


        ...,


        [[[-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          ...,
          [-0.9922, -0.9922, -0.9843,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -1.0000, -0.9922, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -1.0000, -0.9922, -0.9922]],

         [[-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9765, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9765, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          ...,
          [-0.9922, -0.9922, -0.9843,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9843, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9843, -0.9843, -0.9922]],

         [[-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9765,  ..., -1.0000, -0.9765, -0.9922],
          ...,
          [-0.9922, -0.9922, -0.9843,  ..., -1.0000, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9843, -0.9843, -0.9922],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9843, -0.9922, -0.9922]]],


        [[[ 0.9843,  1.0000,  0.9922,  ...,  0.9686,  0.9843,  0.9608],
          [ 0.9451,  0.9765,  0.9686,  ...,  0.9373,  0.9529,  0.9294],
          [ 0.9451,  0.9765,  0.9765,  ...,  0.9373,  0.9529,  0.9294],
          ...,
          [ 0.5529,  0.6549,  0.6549,  ...,  0.7176,  0.7412,  0.7020],
          [ 0.6235,  0.6941,  0.6784,  ...,  0.4275,  0.4588,  0.4667],
          [ 0.6549,  0.7020,  0.6549,  ..., -0.2706, -0.2706, -0.2863]],

         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.9765,  0.9922,  0.9686,  ...,  0.9765,  0.9922,  0.9686],
          [ 0.9843,  0.9922,  0.9608,  ...,  0.9765,  0.9922,  0.9686],
          ...,
          [ 0.2235,  0.2471,  0.2392,  ...,  0.7647,  0.7882,  0.7569],
          [ 0.2784,  0.3020,  0.3020,  ...,  0.4980,  0.5216,  0.5294],
          [ 0.3490,  0.3804,  0.3569,  ..., -0.2549, -0.2627, -0.2706]],

         [[ 1.0000,  1.0000,  1.0000,  ...,  1.0000,  1.0000,  1.0000],
          [ 0.9608,  0.9843,  0.9843,  ...,  0.9922,  1.0000,  0.9843],
          [ 0.9529,  0.9922,  0.9922,  ...,  0.9922,  1.0000,  0.9843],
          ...,
          [-0.2706, -0.2235, -0.2549,  ...,  0.8118,  0.8353,  0.7961],
          [-0.2314, -0.2471, -0.3098,  ...,  0.6314,  0.6549,  0.6627],
          [-0.1843, -0.2235, -0.3020,  ...,  0.2314,  0.2314,  0.2157]]],


        [[[-0.9686, -0.9686, -0.9686,  ..., -0.8745, -0.8902, -0.8980],
          [-0.9608, -0.9608, -0.9608,  ..., -0.8667, -0.8745, -0.8902],
          [-0.9529, -0.9529, -0.9529,  ..., -0.8588, -0.8667, -0.8824],
          ...,
          [-0.4510, -0.6471, -0.8745,  ..., -0.1294, -0.1216, -0.0824],
          [-0.4745, -0.6314, -0.8353,  ..., -0.1373, -0.0745, -0.0667],
          [-0.5922, -0.7255, -0.7176,  ..., -0.0902, -0.0824, -0.0824]],

         [[-0.9608, -0.9608, -0.9608,  ..., -0.9373, -0.9529, -0.9608],
          [-0.9529, -0.9529, -0.9529,  ..., -0.9294, -0.9373, -0.9529],
          [-0.9451, -0.9451, -0.9451,  ..., -0.9216, -0.9294, -0.9451],
          ...,
          [-0.3804, -0.5608, -0.7961,  ..., -0.3098, -0.3020, -0.2627],
          [-0.4039, -0.5529, -0.7569,  ..., -0.3098, -0.2471, -0.2471],
          [-0.5059, -0.6471, -0.6314,  ..., -0.2706, -0.2549, -0.2549]],

         [[-1.0000, -1.0000, -1.0000,  ..., -0.9608, -0.9765, -0.9843],
          [-0.9922, -0.9922, -0.9922,  ..., -0.9529, -0.9608, -0.9765],
          [-0.9843, -0.9843, -0.9843,  ..., -0.9451, -0.9529, -0.9686],
          ...,
          [-0.3569, -0.5843, -0.7961,  ..., -0.5529, -0.5451, -0.4902],
          [-0.3804, -0.5765, -0.7647,  ..., -0.5529, -0.4902, -0.4745],
          [-0.4745, -0.6471, -0.6235,  ..., -0.4902, -0.4824, -0.4667]]]]), tensor([2, 1, 0,  ..., 4, 3, 3])]


In [12]:
assert torch.round(train_datasite_obj.syft_action_data.sum()) == 1557

In [16]:
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

model = CNN()
model

CNN(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [17]:
weights = model.state_dict()

In [18]:
assert isinstance(weights, dict)

In [19]:
w = sy.ActionObject.from_obj(weights)

In [20]:
type(w.syft_action_data), w.id

(collections.OrderedDict, <UID: dc8cb9c102a84536a9e534dadaad47b6>)

In [21]:
weight_datasite_obj = w.send(datasite_client)

In [24]:
@sy.syft_function(
    input_policy=sy.ExactMatch(
        weights=weight_datasite_obj.id, data=train_datasite_obj.id
    ),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_cnn(weights, data):
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class CNN(nn.Module):
        def __init__(self):
            super(CNN, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    model = CNN()
    # Load model weights
    model.load_state_dict(weights)
    model.train()

    # Training logic (simplified for example)
    for inputs, labels in data:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()

optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

In [25]:
pointer = train_cnn(weights=weight_datasite_obj, data=train_datasite_obj)
output = pointer.get()

SyftInfo: Closing the server after time_alive=300 (the default value)


Logged into <ephemeral_server_train_cnn_9961: High side Datasite> as <info@openmined.org>


Approving request on change train_cnn for datasite ephemeral_server_train_cnn_9961
SyftInfo: Landing the ephmeral server...


21/07/24 18:23:19 EXCEPTION LOG:

Encountered while executing train_cnn:
Traceback (most recent call last):
  File "/opt/conda/envs/myenv/lib/python3.11/site-packages/syft/service/code/user_code.py", line 1867, in execute_byte_code
    result = eval(evil_string, _globals, _locals)  # nosec
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<string>", line 1, in <module>
  File "<string>", line 36, in user_func_train_cnn_d631739d7c7e0de853d9cf07ed46d5252963b2fe338b56a5931e2adb1dedf9ec_e6fbabfe0838aa02a9f5742e0bfe5a393dce44850d5ada16e124c38d64991aec
  File "<string>", line 29, in train_cnn
ValueError: too many values to unpack (expected 2)

    26      model.load_state_dict(weights)
    27      model.train()
--> 28      for inputs, labels in data:
    29          optimizer.zero_grad()



SyftInfo: Server Landed!


In [21]:
assert torch.allclose(torch.sum(output), torch.tensor(1.3907))

In [22]:
request = datasite_client.code.request_code_execution(train_mlp)
request

In [23]:
request.approve()

Approving request on change train_mlp for datasite test-datasite-2


In [24]:
datasite_client._api = None
_ = datasite_client.api

In [25]:
result_ptr = datasite_client.code.train_mlp(weights=w.id, data=train.id)

In [26]:
result = result_ptr.get()

In [27]:
assert torch.allclose(torch.sum(result), torch.tensor(1.3907))

In [28]:
if server.server_type.value == "python":
    server.land()

In [29]:
server.land()