In [1]:
import torch as th
import syft as sy
from torch import nn, optim

hook = sy.TorchHook(th)

In [2]:
# Create a couple of workers
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
secure_worker = sy.VirtualWorker(hook, id="secure_worker")

In [3]:
bob.add_workers([alice, secure_worker])
alice.add_workers([bob, secure_worker])
secure_worker.add_workers([alice, bob])



<VirtualWorker id:secure_worker #tensors:0>

In [4]:
# A Toy Dataset
data = th.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = th.tensor([[0],[0],[1],[1.]], requires_grad=True)

In [5]:
# get pointers to training data on each worker by
# sending some training data to bob and alice
bobs_data = data[0:2].send(bob)
bobs_target = target[0:2].send(bob)

In [6]:
alices_data = data[2:].send(alice)
alices_target = target[2:].send(alice)

In [7]:
# Initialize a toy model
model = nn.Linear(2,1)

In [8]:
for round_iter in range(10):
    
    bobs_model = model.copy().send(bob)

    alices_model = model.copy().send(alice)
    
    bobs_opt = optim.SGD(params=bobs_model.parameters(), lr=0.1)
    alices_opt = optim.SGD(params=alices_model.parameters(), lr=0.1)

    for i in range(10):

        bobs_opt.zero_grad()

        bobs_pred = bobs_model(bobs_data)

        bobs_loss = ((bobs_pred - bobs_target)**2).sum()

        bobs_loss.backward()

        bobs_opt.step()

        bobs_loss = bobs_loss.get().data
        bobs_loss

        alices_opt.zero_grad()

        alices_pred = alices_model(alices_data)

        alices_loss = ((alices_pred - alices_target)**2).sum()

        alices_loss.backward()

        alices_opt.step()

        alices_loss = alices_loss.get().data

        alices_loss

    alices_model.move(secure_worker)
    bobs_model.move(secure_worker)

    with th.no_grad():
        model.weight.set_(((alices_model.weight.data + bobs_model.weight.data) / 2).get())
        model.bias.set_(((alices_model.bias.data + bobs_model.bias.data) / 2).get())

    secure_worker.clear_objects()
    
    print("Bob:" + str(bobs_loss) + " Alice:" + str(alice_loss))

RuntimeError: set_storage is not allowed on Tensor created from .data or .detach()