In [1]:
# Version range of syft this notebook targets. It feeds the optional install
# line below and is checked at runtime via `sy.requires(SYFT_VERSION)`.
SYFT_VERSION = ">=0.8.2.b0,<0.9"
package_string = '"syft' + SYFT_VERSION + '"'
# %pip install {package_string} -q

In [2]:
# stdlib
import os

# third party
import torch
import torch.nn as nn
import torch.nn.functional as F

# syft absolute
import syft as sy

# Verify the installed syft satisfies SYFT_VERSION (the check's result is printed below).
sy.requires(SYFT_VERSION)

✅ The installed version of syft==0.8.7 matches the requirement >=0.8.2b0 and the requirement <0.9


In [3]:
# Launch the datasite server used throughout this notebook.
# NOTE(review): reset=True presumably starts from a clean state and dev_mode=True
# enables syft's development mode — confirm against the orchestra docs.
server = sy.orchestra.launch(name="test-datasite-2", dev_mode=True, reset=True)

In [4]:
# Log in to the datasite launched above. Credentials can be overridden through
# environment variables instead of being edited in the notebook; the fallbacks
# are the credentials this notebook was originally run with.
datasite_client = server.login(
    email=os.environ.get("SYFT_EMAIL", "info@openmined.org"),
    password=os.environ.get("SYFT_PASSWORD", "changethis"),
)

Logged into <test-datasite-2: High side Datasite> as <info@openmined.org>


In [5]:
# Seed torch's global RNG so the random tensor and the model's initial weights
# created further down are reproducible; the exact-value assertions rely on it.
torch.random.manual_seed(42)

<torch._C.Generator at 0x7f337a1077d0>

In [6]:
# Four random samples of shape (28, 28, 1), uniform in [0, 1), drawn from the
# seeded global RNG.
train_data = torch.rand(4, 28, 28, 1)
train_data.shape

torch.Size([4, 28, 28, 1])

In [7]:
# Sanity check: with seed 42 the generated data sums to ~1557.
assert train_data.sum().round() == 1557

In [8]:
# Wrap the local tensor in a syft ActionObject so it can be sent to the datasite.
train = sy.ActionObject.from_obj(train_data)

In [9]:
# Inspect the wrapped object: underlying data type, its UID, and its shape.
type(train.syft_action_data), train.id, train.shape

(torch.Tensor,
 <UID: 12b1876b2f8742349155b37cbf23f3e9>,
 Pointer:
 torch.Size([4, 28, 28, 1]))

In [10]:
# Upload the wrapped tensor to the datasite; train_mlp's ExactMatch input policy
# is built from the returned object's id.
train_datasite_obj = train.send(datasite_client)
type(train_datasite_obj)

syft.service.action.action_object.AnyActionObject

In [11]:
# Display the datasite object (its repr renders the full tensor).
train_datasite_obj


**Pointer**

tensor([[[[0.8823],
          [0.9150],
          [0.3829],
          ...,
          [0.2695],
          [0.3588],
          [0.1994]],

         [[0.5472],
          [0.0062],
          [0.9516],
          ...,
          [0.9103],
          [0.6440],
          [0.7071]],

         [[0.6581],
          [0.4913],
          [0.8913],
          ...,
          [0.1591],
          [0.7653],
          [0.2979]],

         ...,

         [[0.8029],
          [0.2662],
          [0.2614],
          ...,
          [0.6683],
          [0.6779],
          [0.0837]],

         [[0.0150],
          [0.2406],
          [0.8423],
          ...,
          [0.4931],
          [0.9576],
          [0.1999]],

         [[0.5039],
          [0.7378],
          [0.1548],
          ...,
          [0.3018],
          [0.6301],
          [0.6886]]],


        [[[0.2366],
          [0.0042],
          [0.7617],
          ...,
          [0.1946],
          [0.2539],
          [0.5961]],

         [[0.6356],
          [0.6922],
          [0.7744],
          ...,
          [0.4583],
          [0.6079],
          [0.2258]],

         [[0.6442],
          [0.0118],
          [0.1422],
          ...,
          [0.9184],
          [0.8874],
          [0.6511]],

         ...,

         [[0.3880],
          [0.3186],
          [0.6964],
          ...,
          [0.6831],
          [0.2704],
          [0.9291]],

         [[0.7386],
          [0.8284],
          [0.5660],
          ...,
          [0.3465],
          [0.2419],
          [0.3392]],

         [[0.3217],
          [0.9783],
          [0.6918],
          ...,
          [0.3515],
          [0.4982],
          [0.6605]]],


        [[[0.4890],
          [0.5231],
          [0.5633],
          ...,
          [0.1664],
          [0.6741],
          [0.2001]],

         [[0.3426],
          [0.2617],
          [0.1598],
          ...,
          [0.3733],
          [0.7579],
          [0.6981]],

         [[0.9136],
          [0.3976],
          [0.7355],
          ...,
          [0.0606],
          [0.4830],
          [0.2778]],

         ...,

         [[0.9142],
          [0.9869],
          [0.4895],
          ...,
          [0.2492],
          [0.5637],
          [0.7931]],

         [[0.7489],
          [0.2792],
          [0.0273],
          ...,
          [0.3172],
          [0.9907],
          [0.9064]],

         [[0.7904],
          [0.5280],
          [0.3381],
          ...,
          [0.3225],
          [0.2363],
          [0.5819]]],


        [[[0.6404],
          [0.8829],
          [0.3870],
          ...,
          [0.5927],
          [0.2921],
          [0.3720]],

         [[0.2418],
          [0.3944],
          [0.0356],
          ...,
          [0.3073],
          [0.8442],
          [0.1894]],

         [[0.0927],
          [0.6091],
          [0.2243],
          ...,
          [0.8426],
          [0.2078],
          [0.8937]],

         ...,

         [[0.9275],
          [0.5446],
          [0.6662],
          ...,
          [0.7440],
          [0.4418],
          [0.2538]],

         [[0.5158],
          [0.5912],
          [0.3605],
          ...,
          [0.8848],
          [0.3569],
          [0.9459]],

         [[0.3884],
          [0.1984],
          [0.2835],
          ...,
          [0.2397],
          [0.4570],
          [0.2640]]]])


In [12]:
# The tensor held by the datasite object must sum to the same value as the
# local data it was created from.
assert train_datasite_obj.syft_action_data.sum().round() == 1557

In [13]:
class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear(784, 128) -> ReLU -> Linear(128, out_dims).

    Args:
        out_dims: number of output features produced per sample.
    """

    def __init__(self, out_dims):
        super().__init__()
        self.out_dims = out_dims
        # Creation order (linear1, then linear2) fixes both the state_dict key
        # order and the order in which the global RNG is consumed for init.
        self.linear1 = nn.Linear(784, 128)
        self.linear2 = nn.Linear(128, out_dims)

    def forward(self, x):
        """Return a (batch, out_dims) tensor for a batch whose per-sample size is 784."""
        batch_size = x.size(0)
        flat = x.view(batch_size, -1)
        hidden = F.relu(self.linear1(flat))
        return self.linear2(hidden)


# Local model instance; its randomly initialised parameters are what gets
# sent to the datasite below.
model = MLP(out_dims=10)
model

MLP(
  (linear1): Linear(in_features=784, out_features=128, bias=True)
  (linear2): Linear(in_features=128, out_features=10, bias=True)
)

In [14]:
# Extract the model parameters; train_mlp reloads them with load_state_dict.
weights = model.state_dict()

In [15]:
# state_dict() returns an OrderedDict, which is a dict subclass.
assert isinstance(weights, dict)

In [16]:
# Wrap the weights in an ActionObject so they can be sent to the datasite.
w = sy.ActionObject.from_obj(weights)

In [17]:
# Inspect the wrapped weights: underlying type and UID.
type(w.syft_action_data), w.id

(collections.OrderedDict, <UID: 8d49199e0a844c4bb460d9449e3e01b7>)

In [18]:
# Upload the weights; train_mlp's ExactMatch input policy is built from this object's id.
weight_datasite_obj = w.send(datasite_client)

In [19]:
# Syft function submitted to the datasite (see request_code_execution below).
# ExactMatch restricts its inputs to the two objects sent above.
# NOTE(review): SingleExecutionExactOutput presumably allows a single execution,
# going by its name — confirm against the syft docs.
@sy.syft_function(
    input_policy=sy.ExactMatch(
        weights=weight_datasite_obj.id, data=train_datasite_obj.id
    ),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_mlp(weights, data):
    # Despite its name, this performs no training: it rebuilds the MLP, loads
    # `weights` (a state_dict) and returns the forward-pass output on `data`.
    # `data` must flatten to 784 features per sample (see linear1).
    #
    # Imports and the model class are repeated inside the function, presumably
    # because only this function's source is executed on the server and the
    # notebook's globals are not available there — TODO confirm.
    # third party
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MLP(nn.Module):
        def __init__(self, out_dims):
            super().__init__()
            self.out_dims = out_dims
            self.linear1 = nn.Linear(784, 128)
            self.linear2 = nn.Linear(128, out_dims)

        def forward(self, x):
            x = x.view(x.size(0), -1)
            x = self.linear1(x)
            x = F.relu(x)
            x = self.linear2(x)
            return x

    # Initialize the model
    model = MLP(out_dims=10)

    # Load weights into the model
    model.load_state_dict(weights)

    # Perform a forward pass
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        output = model(data)

    return output

In [20]:
# Call the syft function directly, before any request has been made. Per the
# log output, this spins up an ephemeral server, auto-approves the code and
# runs it there; `.get()` fetches the resulting tensor.
pointer = train_mlp(weights=weight_datasite_obj, data=train_datasite_obj)
output = pointer.get()

SyftInfo: Closing the server after time_alive=300 (the default value)


Logged into <ephemeral_server_train_mlp_102: High side Datasite> as <info@openmined.org>


Approving request on change train_mlp for datasite ephemeral_server_train_mlp_102
SyftInfo: Landing the ephmeral server...


SyftInfo: Server Landed!


In [21]:
# The reference sum is recorded to 4 decimal places, so compare with an
# absolute tolerance matching that precision. allclose's defaults
# (rtol=1e-5, atol=1e-8) are tighter than the literal's own rounding error.
assert torch.allclose(torch.sum(output), torch.tensor(1.3907), atol=1e-4)

In [22]:
# Submit train_mlp to the datasite as a code-execution request.
request = datasite_client.code.request_code_execution(train_mlp)
request

In [23]:
# Approve the request so train_mlp can run on the datasite.
request.approve()

Approving request on change train_mlp for datasite test-datasite-2


In [24]:
# Drop the client's cached API object and access `.api` again, presumably to
# rebuild it so the newly approved `code.train_mlp` endpoint becomes available.
# NOTE(review): relies on the private `_api` attribute — TODO confirm there is
# no public refresh method.
datasite_client._api = None
_ = datasite_client.api

In [25]:
# Run the approved code on the datasite. Pass the ids of the datasite objects,
# which are exactly the ids train_mlp's ExactMatch input policy was built from.
result_ptr = datasite_client.code.train_mlp(
    weights=weight_datasite_obj.id, data=train_datasite_obj.id
)

In [26]:
# Fetch the execution result from the datasite.
result = result_ptr.get()

In [27]:
# The reference sum is recorded to 4 decimal places, so compare with an
# absolute tolerance matching that precision. allclose's defaults
# (rtol=1e-5, atol=1e-8) are tighter than the literal's own rounding error.
assert torch.allclose(torch.sum(result), torch.tensor(1.3907), atol=1e-4)

In [28]:
# Shut down the server launched at the top, but only when it runs in-process
# ("python" server type). A single guarded call avoids landing the same server
# twice and leaves other server types untouched.
if server.server_type.value == "python":
    server.land()