In [21]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F

import syft as sy

# PySyft release range this notebook was written against; sy.requires reports whether
# the installed syft version satisfies it.
SYFT_VERSION: str = ">=0.8.2.b0,<0.9"
sy.requires(SYFT_VERSION)

✅ The installed version of syft==0.8.7 matches the requirement >=0.8.2b0 and the requirement <0.9


In [22]:
# Transformations for the CIFAR10 data: convert images to tensors, then normalise each
# RGB channel with mean 0.5 / std 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# Load CIFAR10 training data (downloaded into ./data on first run)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

# Cut the shuffled training set into NUM_SHARDS equally sized batches; the first batch
# becomes worker 1's local data. A seeded generator makes the shuffle (and therefore
# the shard each worker receives) reproducible across runs.
NUM_SHARDS = 20
SHUFFLE_SEED = 42
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=len(trainset) // NUM_SHARDS,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Launch a local dev datasite and log in as its admin.
# NOTE: these are the Syft dev-mode default admin credentials — local experiments only.
server = sy.orchestra.launch(name="test-datasite-1", dev_mode=True, reset=True)
datasite_client = server.login(email="info@openmined.org", password="changethis")

# Register worker 1 as a user on the datasite and log in with that account.
# Only one worker is set up here; further workers would follow the same pattern.
WORKER1_EMAIL = "w1@student.tuwien.ac.at"
WORKER1_PASSWORD = "abc123"
datasite_client.register(
    name="Worker 1",
    email=WORKER1_EMAIL,
    password=WORKER1_PASSWORD,
    password_verify=WORKER1_PASSWORD,
    institution="TU Wien",
    website="https://www.tuwien.ac.at/",
)
worker1 = server.login(email=WORKER1_EMAIL, password=WORKER1_PASSWORD)

# Upload worker 1's shard (the first [images, labels] batch) to the datasite.
# NOTE(review): the data is sent through the admin client, not through `worker1`.
data_batches = iter(trainloader)
train1 = sy.ActionObject.from_obj(next(data_batches))
train_data_worker1 = train1.send(datasite_client)
display(train_data_worker1.id)

Files already downloaded and verified


Logged into <test-datasite-1: High side Datasite> as <info@openmined.org>


Logged into <test-datasite-1: High side Datasite> as <w1@student.tuwien.ac.at>


<UID: ab67e453d9c54acea581b87537c2d106>

In [23]:
class CNN(nn.Module):
    """Small convolutional classifier for 3x32x32 images with 10 output classes.

    Two conv -> ReLU -> 2x2 max-pool stages feed three fully connected layers.
    `forward` returns raw logits (no softmax is applied).
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 3 -> 6 -> 16 channels with 5x5 kernels; one pooling
        # module is shared by both stages.
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Classifier head. 16 * 5 * 5 is the flattened feature size for 32x32 inputs
        # (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5).
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        for conv_layer in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv_layer(x)))
        flat = x.view(-1, 16 * 5 * 5)
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)

# Seed torch's global RNG so the initial global model weights are reproducible.
MODEL_INIT_SEED = 0
torch.manual_seed(MODEL_INIT_SEED)
model = CNN()
weights = model.state_dict()

# Wrap the initial weights in a Syft ActionObject and upload them to the datasite.
w = sy.ActionObject.from_obj(weights)
display(type(w.syft_action_data), w.id)  # sanity check: the payload is still the state_dict
weight_datasite_obj1 = w.send(datasite_client)
display(weight_datasite_obj1.id)

collections.OrderedDict

<UID: f704cda451014c87bdf435b2f16f7ecd>

<UID: f704cda451014c87bdf435b2f16f7ecd>

In [24]:
@sy.syft_function(
    input_policy=sy.ExactMatch(weights=weight_datasite_obj1.id, data=train_data_worker1.id),
    output_policy=sy.SingleExecutionExactOutput(),
)
def train_cnn_epoch_w1(weights, data):
    """Train a CNN initialised from `weights` for one pass over worker 1's shard.

    Syft executes this function's source on the datasite, apart from the notebook
    namespace. Everything it needs (imports, the model class, optimizer, loss) is
    therefore defined inside the body instead of being read from notebook globals.

    Args:
        weights: state_dict of the global model (uploaded as `weight_datasite_obj1`).
        data: the single `[inputs, labels]` batch uploaded as `train_data_worker1`;
            presumably an (N, 3, 32, 32) image tensor and an (N,) tensor of class
            indices — TODO confirm if the upload format changes.

    Returns:
        The state_dict of the locally trained model.
    """
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Copy of the notebook's CNN architecture (same layer names, so `weights` loads
    # cleanly). It must live inside this function so the shipped code is self-contained.
    class CNN(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.pool = nn.MaxPool2d(2, 2)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16 * 5 * 5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, 10)

        def forward(self, x):
            x = self.pool(F.relu(self.conv1(x)))
            x = self.pool(F.relu(self.conv2(x)))
            x = x.view(-1, 16 * 5 * 5)
            x = F.relu(self.fc1(x))
            x = F.relu(self.fc2(x))
            return self.fc3(x)

    batch_size = 64  # mini-batch size for the local pass over the shard
    learning_rate = 0.001
    momentum = 0.9

    model = CNN()
    model.load_state_dict(weights)
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    # `data` is one [inputs, labels] pair, not an iterable of (inputs, labels) batches.
    # Unpack it and split it into mini-batches explicitly.
    inputs, labels = data
    for batch_inputs, batch_labels in zip(torch.split(inputs, batch_size),
                                          torch.split(labels, batch_size)):
        optimizer.zero_grad()
        loss = criterion(model(batch_inputs), batch_labels)
        loss.backward()
        optimizer.step()
    return model.state_dict()

# Optimizer (SGD with momentum) and classification loss for the notebook-side `model`.
# NOTE(review): Syft functions run on the datasite in their own namespace, so code
# executed there cannot rely on these notebook globals.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def average_weights(weights1, weights2):
    """Return the element-wise mean of two model state_dicts.

    Keys are taken from ``weights1``, in its iteration order; ``weights2`` must contain
    each of them, otherwise a KeyError is raised. Both tensors are moved to CPU first,
    so weights living on different devices can be combined.

    Returns:
        A plain ``dict`` mapping each key to ``(weights1[key] + weights2[key]) / 2``.
    """
    return {
        name: (weights1[name].cpu() + weights2[name].cpu()) / 2
        for name in weights1
    }


In [25]:
# Run worker 1's training function (Syft executes it on a server) and fetch the result.
display(weight_datasite_obj1.id, train_data_worker1.id)
pointer_w1 = train_cnn_epoch_w1(weights=weight_datasite_obj1, data=train_data_worker1)
weights_worker1 = pointer_w1.get()

# Only worker 1 has been trained so far, so nothing is aggregated yet. Once a second
# worker's weights exist, `average_weights(weights_worker1, <other weights>)` combines them.
if hasattr(weights_worker1, "items"):
    # Summarise parameter names and shapes instead of dumping every tensor value.
    display({name: tuple(tensor.shape) for name, tensor in weights_worker1.items()})
else:
    # Not a state_dict (e.g. an error object returned instead) — show it as-is.
    display(weights_worker1)

<UID: f704cda451014c87bdf435b2f16f7ecd>

<UID: ab67e453d9c54acea581b87537c2d106>

SyftInfo: Closing the server after time_alive=300 (the default value)


Logged into <ephemeral_server_train_cnn_epoch_w1_7968: High side Datasite> as <info@openmined.org>


AttributeError: 'str' object has no attribute 'id'

In [24]:
# Tear down the dev datasite launched at the start of the notebook; only servers whose
# type is "python" are landed explicitly here.
server_kind = server.server_type.value
if server_kind == "python":
    server.land()
