In [1]:
import sail
import numpy as np
import torch
import torch.nn as nn
import time

In [2]:
vm1 = sail.connect("127.0.0.1", 7001, "marine@terran.com", "sailpassword")
vm2 = sail.connect("127.0.0.1", 7002, "marine@terran.com", "sailpassword")
vmagg = sail.connect("127.0.0.1", 7003, "marine@terran.com", "sailpassword")
sublist = [vm1, vm2]

In [3]:
fn_getPartyA = sail.registerfn("fn_getPartyA.py", 1, 0, 0, 1)[0]
fn_getPartyB = sail.registerfn("fn_getPartyB.py", 1, 0, 0, 1)[0]
fn_getGrad = sail.registerfn("fn_getGrad.py", 0, 1, 1, 0)[0]
fn_updateGrad = sail.registerfn("fn_updateGradients.py", 1, 1, 0, 1)[0]
fn_agg = sail.registerfn("fn_agg.py", 1, 0, 1, 0)[0]
fn_test = sail.registerfn("fn_test.py", 1, 1, 1, 0)[0]
fn_mae = sail.resgisterfn("fn_mae.py", 2, 0, 1, 0)[0]
subfn = [fn_getGrad, fn_updateGrad, fn_test, fn_mae]

In [4]:
subfn

['2B26A10FEA954ACA800851910D07993F',
 'FD923229E05E473B9D9FB28EB49504B3',
 '2FBB442164AA456183E543433D134C2F']

In [5]:
for vm in sublist:
    for fn in subfn:
        sail.pushfn(vm, fn)
sail.pushfn(vm1, fn_getPartyA)
sail.pushfn(vm2, fn_getPartyB)
sail.pushfn(vmagg, fn_agg)

In [6]:
model = nn.Linear(3, 2)
parties = []

In [7]:
jobid1 = sail.newguid()
sail.pushdata(vm1, jobid1, fn_getPartyA, [model], [], "/home/jjj/playground")
sail.execjob(vm1, fn_getPartyA, jobid1)
result = sail.pulldata(vm1, jobid1, fn_getPartyA, "/home/jjj/playground")

In [8]:
result

[[], ['CD97826EE1C844B78C62476715A6565F61E669A4347E42C8A4E87823F2CC73F8']]

In [9]:
parties.append(result[1][0])

In [10]:
jobid2 = sail.newguid()
sail.pushdata(vm2, jobid2, fn_getPartyB, [model], [], "/home/jjj/playground")
sail.execjob(vm2, fn_getPartyB, jobid2)
result = sail.pulldata(vm2, jobid2, fn_getPartyB, "/home/jjj/playground")
parties.append(result[1][0])

In [11]:
num_epochs = 100

In [12]:
for epoch in range(num_epochs):
    if(epoch%9==0):
          print("processing round: "+str(epoch+1))
    gradlist = []
    time.sleep(0.1)
    for i in range(len(sublist)):
        jobid3 = sail.newguid()
        sail.pushdata(sublist[i], jobid3, fn_getGrad, [], [parties[i]], "/home/jjj/playground")
        sail.execjob(sublist[i], fn_getGrad, jobid3)
        result = sail.pulldata(sublist[i], jobid3, fn_getGrad, "/home/jjj/playground")
        gradlist.append(result[0][0])
    
    jobid4 = sail.newguid()
    sail.pushdata(vmagg, jobid4, fn_agg, [gradlist], [], "/home/jjj/playground")
    sail.execjob(vmagg, fn_agg, jobid4)
    result = sail.pulldata(vmagg, jobid4, fn_agg, "/home/jjj/playground")
    newgrad = result[0][0]
    
    for i in range(len(sublist)):
        jobid5 = sail.newguid()
        sail.pushdata(sublist[i], jobid5, fn_updateGrad, [newgrad], [parties[i]], "/home/jjj/playground")
        sail.execjob(sublist[i], fn_updateGrad, jobid5)
        result = sail.pulldata(sublist[i], jobid5, fn_updateGrad, "/home/jjj/playground")
        parties[i] = result[1][0]

processing round: 1
processing round: 10
processing round: 19
processing round: 28
processing round: 37
processing round: 46
processing round: 55
processing round: 64
processing round: 73
processing round: 82
processing round: 91
processing round: 100


In [13]:
all_inputs = np.array([[73, 67, 43], 
                   [91, 88, 64], 
                   [87, 134, 58], 
                   [102, 43, 37], 
                   [69, 96, 70]], dtype='float32')
all_targets = np.array([[56, 70], 
                    [81, 101], 
                    [119, 133], 
                    [22, 37], 
                    [103, 119]], dtype='float32')
all_inputs = torch.from_numpy(all_inputs)
all_targets = torch.from_numpy(all_targets)

In [14]:
preds = []
for i in range(len(sublist)):
    jobid6 = sail.newguid()
    sail.pushdata(sublist[i], jobid6, fn_test, [all_inputs], [parties[i]], "/home/jjj/playground")
    sail.execjob(sublist[i], fn_test, jobid6)
    result = sail.pulldata(sublist[i], jobid6, fn_test, "/home/jjj/playground")
    preds.append(result[0][0])

In [15]:
preds

[tensor([[ 58.4522,  73.4913],
         [ 82.6104,  99.8480],
         [115.3775, 128.1072],
         [ 29.1866,  56.8975],
         [ 97.8328, 106.1374]], requires_grad=True),
 tensor([[ 58.4522,  73.4913],
         [ 82.6104,  99.8480],
         [115.3775, 128.1072],
         [ 29.1866,  56.8975],
         [ 97.8328, 106.1374]], requires_grad=True)]

In [16]:
all_targets

tensor([[ 56.,  70.],
        [ 81., 101.],
        [119., 133.],
        [ 22.,  37.],
        [103., 119.]])

In [None]:
mae_errs = []
for i in range(len(sublist)):
    jobid7 = sail.newguid()
    sail.pushdata(sublist[i], jobid7, fn_mae, [preds[i], all_targets], [], "/home/jjj/playground")
    sail.execjob(sublist[i], fn_mae, jobid7)
    result = sail.pulldata(sublist[i], jobid7, fn_mae, "/home/jjj/playground")
    mae_errs.append(result[0][0])

In [None]:
mae_errs