In [1]:
import torch as th
from torch import nn, optim
import syft as sy

In [2]:
hook = sy.TorchHook(th)

In [3]:
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")
jon = sy.VirtualWorker(hook, id="jon")

In [4]:
#A Toy Dataset
data = th.tensor([[1.,1],[0,1],[1,0],[0,0]], requires_grad = True)
target = th.tensor([[1.],[1], [0], [0]], requires_grad=True)

In [5]:
data_bob = data[0:2].send(bob)
target_bob = target[0:2].send(bob)

In [6]:
data_alice = data[0:2].send(alice)
target_alice = data[0:2].send(alice)

In [7]:
datasets = [(data_bob, target_bob), (data_alice, target_alice)]

In [8]:
def train(iterations=20):
    
    model = nn.Linear(2,1)
    opt = optim.SGD(params=model.parameters(), lr=0.1)
    
    for iter in range(iterations):
        for _data, _target in datasets:
            
            #send model to the data
            model = model.send(_data.location)
            
            #Do normal training
            opt.zero_grad()
            
            pred = model(_data)
            
            loss = ((pred - _target)**2).sum()
            
            loss.backward()
            
            opt.step()
            
            #Get smarter model back
            model = model.get()
            
            print(loss.get())

In [9]:
train()

tensor(10.2006, requires_grad=True)
tensor(0.7412, requires_grad=True)
tensor(0.1211, requires_grad=True)
tensor(0.8294, requires_grad=True)
tensor(0.3117, requires_grad=True)
tensor(0.7673, requires_grad=True)
tensor(0.3064, requires_grad=True)
tensor(0.7407, requires_grad=True)
tensor(0.3169, requires_grad=True)
tensor(0.7217, requires_grad=True)
tensor(0.3245, requires_grad=True)
tensor(0.7088, requires_grad=True)
tensor(0.3308, requires_grad=True)
tensor(0.6997, requires_grad=True)
tensor(0.3357, requires_grad=True)
tensor(0.6933, requires_grad=True)
tensor(0.3396, requires_grad=True)
tensor(0.6887, requires_grad=True)
tensor(0.3426, requires_grad=True)
tensor(0.6854, requires_grad=True)
tensor(0.3448, requires_grad=True)
tensor(0.6830, requires_grad=True)
tensor(0.3466, requires_grad=True)
tensor(0.6812, requires_grad=True)
tensor(0.3479, requires_grad=True)
tensor(0.6799, requires_grad=True)
tensor(0.3489, requires_grad=True)
tensor(0.6789, requires_grad=True)
tensor(0.3497, requ