In [8]:
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
import torch.nn as nn
import torch.optim as optim

import syft as sy

SYFT_VERSION = ">=0.8.2.b0,<0.9"
sy.requires(SYFT_VERSION)

✅ The installed version of syft==0.8.7 matches the requirement >=0.8.2b0 and the requirement <0.9


In [9]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # Normalize the images
])

batch_size = 32

trainset = CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)

testset = CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)


Files already downloaded and verified
Files already downloaded and verified


In [10]:
server = sy.orchestra.launch(name="test-datasite-4", dev_mode=True, reset=True)
datasite_client = server.login(email="info@openmined.org", password="changethis")

Logged into <test-datasite-4: High side Datasite> as <info@openmined.org>


In [11]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # Flatten all dimensions except the batch dimension
        x = nn.functional.relu(self.fc1(x))
        x = nn.functional.relu(self.fc2(x))
        x = self.fc3(x)
        return x

net = SimpleCNN()

In [29]:
# third party
from result import Err
from result import Ok

# syft absolute
from syft.client.api import AuthedServiceContext
from syft.client.api import ServerIdentity


class CustomExactMatch(sy.CustomInputPolicy):
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        pass

    def filter_kwargs(self, kwargs, context, code_item_id):
        # stdlib

        try:
            allowed_inputs = self.allowed_ids_only(
                allowed_inputs=self.inputs, kwargs=kwargs, context=context
            )
            results = self.retrieve_from_db(
                code_item_id=code_item_id,
                allowed_inputs=allowed_inputs,
                context=context,
            )
        except Exception as e:
            return Err(str(e))
        return results

    def retrieve_from_db(self, code_item_id, allowed_inputs, context):
        # syft absolute
        from syft import ServerType
        from syft.service.action.action_object import TwinMode

        action_service = context.server.get_service("actionservice")
        code_inputs = {}

        # When we are retrieving the code from the database, we need to use the server's
        # verify key as the credentials. This is because when we approve the code, we
        # we allow the private data to be used only for this specific code.
        # but we are not modifying the permissions of the private data

        root_context = AuthedServiceContext(
            server=context.server, credentials=context.server.verify_key
        )
        if context.server.server_type == ServerType.DATASITE:
            for var_name, arg_id in allowed_inputs.items():
                kwarg_value = action_service._get(
                    context=root_context,
                    uid=arg_id,
                    twin_mode=TwinMode.NONE,
                    has_permission=True,
                )
                if kwarg_value.is_err():
                    return Err(kwarg_value.err())
                code_inputs[var_name] = kwarg_value.ok()
        else:
            raise Exception(
                f"Invalid Server Type for Code Submission:{context.server.server_type}"
            )
        return Ok(code_inputs)

    def allowed_ids_only(
        self,
        allowed_inputs,
        kwargs,
        context,
    ):
        # syft absolute
        from syft import ServerType
        from syft import UID

        if context.server.server_type == ServerType.DATASITE:
            server_identity = ServerIdentity(
                server_name=context.server.name,
                server_id=context.server.id,
                verify_key=context.server.signing_key.verify_key,
            )
            allowed_inputs = allowed_inputs.get(server_identity, {})
        else:
            raise Exception(
                f"Invalid Server Type for Code Submission:{context.server.server_type}"
            )
        filtered_kwargs = {}
        for key in allowed_inputs.keys():
            if key in kwargs:
                value = kwargs[key]
                uid = value
                if not isinstance(uid, UID):
                    uid = getattr(value, "id", None)

                if uid != allowed_inputs[key]:
                    raise Exception(
                        f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
                    )
                filtered_kwargs[key] = value
        return filtered_kwargs

    def _is_valid(
        self,
        context,
        usr_input_kwargs,
        code_item_id,
    ):
        return Ok(True)


def allowed_ids_only(
    self,
    allowed_inputs,
    kwargs,
    context,
):
    # syft absolute
    from syft import ServerType
    from syft import UID
    from syft.client.api import ServerIdentity

    if context.server.server_type == ServerType.DATASITE:
        server_identity = ServerIdentity(
            server_name=context.server.name,
            server_id=context.server.id,
            verify_key=context.server.signing_key.verify_key,
        )
        allowed_inputs = allowed_inputs.get(server_identity, {})
    else:
        raise Exception(
            f"Invalid Server Type for Code Submission:{context.server.server_type}"
        )
    filtered_kwargs = {}
    for key in allowed_inputs.keys():
        if key in kwargs:
            value = kwargs[key]
            uid = value
            if not isinstance(uid, UID):
                uid = getattr(value, "id", None)

            if uid != allowed_inputs[key]:
                raise Exception(
                    f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}"
                )
            filtered_kwargs[key] = value
    return filtered_kwargs


In [44]:
@sy.syft_function_single_use()
def train_batch(*args, **kwargs):
    print("Arguments received:", args)
    print("Keyword arguments received:", kwargs)
    return "test"
    # # Zero the parameter gradients
    # optimizer.zero_grad()

    # # Forward + backward + optimize
    # outputs = net(inputs)
    # loss = criterion(outputs, labels)
    # loss.backward()
    # optimizer.step()

    # return {
    #     "loss": loss.item(),
    #     "gradients": {name: param.grad for name, param in net.named_parameters()}
    # }
    
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

In [45]:
for epoch in range(1):  # Loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs_sy = sy.ActionObject.from_obj(data)
        
        pointer = train_batch(inputs_sy)
        
        request = datasite_client.code.request_code_execution(train_batch)
        request.approve()
        result_ptr = datasite_client.code.train_batch(weights=inputs_sy.id, data=labels_sy.id)
        batch_result = result_ptr.get()

        
        running_loss += batch_result["loss"]


print('Finished Training')

SyftInfo: Closing the server after time_alive=300 (the default value)


Logged into <ephemeral_server_train_batch_3379: High side Datasite> as <info@openmined.org>


Approving request on change train_batch for datasite ephemeral_server_train_batch_3379
SyftInfo: Landing the ephmeral server...
Approving request on change train_batch for datasite test-datasite-4


AttributeError: You have tried accessing `get` on a SyftError with message: Failed to run. 'NoneType' object has no attribute 'items', Traceback (most recent call last):
  File "/opt/conda/envs/myenv/lib/python3.11/site-packages/syft/service/code/user_code_service.py", line 590, in _call
    action_service._user_code_execute(
  File "/opt/conda/envs/myenv/lib/python3.11/site-packages/syft/service/action/action_service.py", line 437, in _user_code_execute
    for key, kwarg_value in filtered_kwargs.items():
                            ^^^^^^^^^^^^^^^^^^^^^
AttributeError: 'NoneType' object has no attribute 'items'


SyftInfo: Server Landed!


In [None]:
for inputs, labels in train_data:
    print(inputs.shape, labels.shape)

ValueError: too many values to unpack (expected 2)

In [6]:
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the network on the 10000 test images: {100 * correct / total:.2f}%')

Accuracy of the network on the 10000 test images: 59.73%
