# Federated learning with trusted aggregator

In [2]:
import syft as sy
import torch as th
hook = sy.TorchHook(th)
from torch import nn, optim



In [3]:
#create workers
viper = sy.VirtualWorker(hook, id = "viper")
quassi = sy.VirtualWorker(hook, id = "quassi")
secure_worker = sy.VirtualWorker(hook, id = "secure_worker")


In [4]:
# letting each worker know that otherworkers also exists

viper.add_workers([quassi, secure_worker])
quassi.add_workers([viper, secure_worker])
secure_worker.add_workers([viper, quassi])



<VirtualWorker id:secure_worker #objects:0>

In [5]:
# A Toy Dataset
data = th.tensor([[0,0],[0,1],[1,0],[1,1.]], requires_grad=True)
target = th.tensor([[0],[0],[1],[1.]], requires_grad=True)

In [6]:
#send data to viper and quassi

viper_data = data[:2].send(viper)
viper_target = target[:2].send(viper)

quassi_data = data[2:].send(quassi)
quassi_target = target[2:].send(quassi)

In [7]:
#create model
model = nn.Linear(2, 1)


In [8]:
# send copy of model to each worker

viper_model = model.copy().send(viper)
quassi_model = model.copy().send(quassi)

In [10]:
# optimizers
viper_opt = optim.SGD(params=viper_model.parameters(),lr=0.1)
quassi_opt = optim.SGD(params=quassi_model.parameters(),lr=0.1)

In [14]:
#train
for i in range(10):
    viper_opt.zero_grad()
    viper_pred = viper_model(viper_data)
    viper_loss = ((viper_pred - viper_target) ** 2).sum()
    viper_loss.backward()
    viper_opt.step()
    viper_loss = viper_loss.get().data
    
    
    quassi_opt.zero_grad()
    quassi_pred = quassi_model(quassi_data)
    quassi_loss = ((quassi_pred - quassi_target) ** 2).sum()
    quassi_loss.backward()
    quassi_opt.step()
    quassi_loss = quassi_loss.get().data

    print(quassi_loss)




tensor(0.0329)
tensor(0.0274)
tensor(0.0228)
tensor(0.0190)
tensor(0.0158)
tensor(0.0132)
tensor(0.0109)
tensor(0.0091)
tensor(0.0076)
tensor(0.0063)


In [15]:
# send model to secure worker
viper_model.move(secure_worker)
quassi_model.move(secure_worker)

In [16]:
#avg the model
with th.no_grad():

    model.weight.set_(((viper_model.weight.data + quassi_model.weight.data) / 2).get())
    model.bias.set_(((viper_model.bias.data + quassi_model.bias.data) / 2).get())

In [20]:
iteration = 10
worker_iteration = 5

for i in range(iteration):
    # send copy of model to each worker

    viper_model = model.copy().send(viper)
    quassi_model = model.copy().send(quassi)
    
    # optimizers
    viper_opt = optim.SGD(params=viper_model.parameters(),lr=0.1)
    quassi_opt = optim.SGD(params=quassi_model.parameters(),lr=0.1)

    for wi in range(worker_iteration):
        viper_opt.zero_grad()
        viper_pred = viper_model(viper_data)
        viper_loss = ((viper_pred - viper_target) ** 2).sum()
        viper_loss.backward()
        viper_opt.step()
        viper_loss = viper_loss.get().data


        quassi_opt.zero_grad()
        quassi_pred = quassi_model(quassi_data)
        quassi_loss = ((quassi_pred - quassi_target) ** 2).sum()
        quassi_loss.backward()
        quassi_opt.step()
        quassi_loss = quassi_loss.get().data
        
    # send model to secure worker
    viper_model.move(secure_worker)
    quassi_model.move(secure_worker)
    
    #avg the model
    with th.no_grad():

        model.weight.set_(((viper_model.weight.data + quassi_model.weight.data) / 2).get())
        model.bias.set_(((viper_model.bias.data + quassi_model.bias.data) / 2).get())
    
    print(f"Viper: {viper_loss} \t Quassi: {quassi_loss}")

Viper: 4.452786015463062e-06 	 Quassi: 8.837663756366965e-08
Viper: 3.427057208682527e-06 	 Quassi: 6.792845397285419e-08
Viper: 2.6375596462457906e-06 	 Quassi: 5.2282292273275743e-08
Viper: 2.0299207790230867e-06 	 Quassi: 4.024515831702047e-08
Viper: 1.5622615592292277e-06 	 Quassi: 3.095210843184759e-08
Viper: 1.2023338058497757e-06 	 Quassi: 2.380204477958614e-08
Viper: 9.253211601389921e-07 	 Quassi: 1.833794271988154e-08
Viper: 7.121345788618783e-07 	 Quassi: 1.4095792977286692e-08
Viper: 5.480611093844345e-07 	 Quassi: 1.0848509646166349e-08
Viper: 4.217868081468623e-07 	 Quassi: 8.354078318006941e-09


In [21]:
preds = model(data)
loss = ((preds - target) ** 2).sum()

In [22]:
print(preds)
print(target)
print(loss.data)

tensor([[0.0018],
        [0.0015],
        [0.9981],
        [0.9978]], grad_fn=<AddmmBackward>)
tensor([[0.],
        [0.],
        [1.],
        [1.]], requires_grad=True)
tensor(1.3827e-05)
