In [1]:
import torch as th
from torch import nn, optim
import syft as sy

In [None]:
hook = sy.TorchHook(th)

In [9]:
bob = sy.VirtualWorker(hook, id="bob")
alice = sy.VirtualWorker(hook, id="alice")

In [3]:
#A Toy Dataset
data = th.tensor([[1.,1],[0,1],[1,0],[0,0]], requires_grad = True)
target = th.tensor([[1.],[1], [0], [0]], requires_grad=True)

In [4]:
#A Toy Model
model = nn.Linear(2,1)

In [5]:
opt = optim.SGD(params=model.parameters(), lr=0.1)

In [6]:
def train(iterations=20):
    for iter in range(iterations):
        opt.zero_grad()
        
        pred = model(data)
        
        loss = ((pred - target)**2).sum()
        
        loss.backward()
        
        opt.step()
        
        print(loss.data)
        
        
train()

tensor(3.3852)
tensor(0.6510)
tensor(0.3045)
tensor(0.1879)
tensor(0.1209)
tensor(0.0783)
tensor(0.0509)
tensor(0.0332)
tensor(0.0217)
tensor(0.0142)
tensor(0.0094)
tensor(0.0062)
tensor(0.0041)
tensor(0.0027)
tensor(0.0018)
tensor(0.0013)
tensor(0.0009)
tensor(0.0006)
tensor(0.0004)
tensor(0.0003)


In [7]:
data_bob = data[0:2].send(bob)
target_bob = target[0:2].send(bob)

In [10]:
data_alice = data[0:2].send(alice)
target_alice = target[0:2].send(alice)

In [11]:
datasets = [(data_bob, target_bob), (data_alice, target_alice)]

In [12]:
def train(iterations=20):
    
    model = nn.Linear(2,1)
    opt = optim.SGD(params=model.parameters(), lr=0.1)

    for iter in range(iterations):
        for _data, _target in datasets:

            #send model to the data
            model = model.send(_data.location)

            #Do normal training
            opt.zero_grad()

            pred = model(_data)

            loss = ((pred - _target)**2).sum()

            loss.backward()

            opt.step()

            #Get smarter model back
            model = model.get()

            print(loss.get())

In [13]:
train()

tensor(2.5742, requires_grad=True)
tensor(0.3015, requires_grad=True)
tensor(0.2368, requires_grad=True)
tensor(0.1970, requires_grad=True)
tensor(0.1640, requires_grad=True)
tensor(0.1365, requires_grad=True)
tensor(0.1136, requires_grad=True)
tensor(0.0945, requires_grad=True)
tensor(0.0787, requires_grad=True)
tensor(0.0655, requires_grad=True)
tensor(0.0545, requires_grad=True)
tensor(0.0454, requires_grad=True)
tensor(0.0378, requires_grad=True)
tensor(0.0314, requires_grad=True)
tensor(0.0262, requires_grad=True)
tensor(0.0218, requires_grad=True)
tensor(0.0181, requires_grad=True)
tensor(0.0151, requires_grad=True)
tensor(0.0126, requires_grad=True)
tensor(0.0104, requires_grad=True)
tensor(0.0087, requires_grad=True)
tensor(0.0072, requires_grad=True)
tensor(0.0060, requires_grad=True)
tensor(0.0050, requires_grad=True)
tensor(0.0042, requires_grad=True)
tensor(0.0035, requires_grad=True)
tensor(0.0029, requires_grad=True)
tensor(0.0024, requires_grad=True)
tensor(0.0020, requi