In [4]:
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import time
import numpy as np

class Parser:
    """Parameters for training"""
    def __init__(self):
        self.epochs = 10
        self.lr = 0.001
        self.test_batch_size = 8
        self.batch_size = 8
        self.log_interval = 10
        self.seed = 1
    
args = Parser()

torch.manual_seed(args.seed)
kwargs = {}

In [5]:
with open('./data/boston_housing.pickle','rb') as f:
    ((X, y), (X_test, y_test)) = pickle.load(f)

X = torch.from_numpy(X).float()
y = torch.from_numpy(y).float()
X_test = torch.from_numpy(X_test).float()
y_test = torch.from_numpy(y_test).float()
# preprocessing
mean = X.mean(0, keepdim=True)
dev = X.std(0, keepdim=True)
mean[:, 3] = 0.
dev[:, 3] = 1.
X = (X - mean) / dev
X_test = (X_test - mean) / dev
train = TensorDataset(X, y)
test = TensorDataset(X_test, y_test)
train_loader = DataLoader(train, batch_size=args.batch_size, shuffle=True, **kwargs)
test_loader = DataLoader(test, batch_size=args.test_batch_size, shuffle=True, **kwargs)

In [6]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(13, 32)
        self.fc2 = nn.Linear(32, 24)
        self.fc4 = nn.Linear(24, 16)
        self.fc3 = nn.Linear(16, 1)

    def forward(self, x):
        x = x.view(-1, 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc4(x))
        x = self.fc3(x)
        return x

In [7]:
import syft as sy
from syft.frameworks.torch.federated import utils
from syft.workers.websocket_client import WebsocketClientWorker

hook = sy.TorchHook(torch)
# bob = sy.VirtualWorker(hook, id="bob")
# alice = sy.VirtualWorker(hook, id="alice")
james = sy.VirtualWorker(hook, id="james")

kwargs_websocket = {"host": "localhost", "hook": hook}
alice = WebsocketClientWorker(id='alice', port=8779, **kwargs_websocket)
bob = WebsocketClientWorker(id='bob', port=8778, **kwargs_websocket)

compute_nodes = [bob, alice]

In [8]:
remote_dataset = (list(), list())

train_distributed_dataset = []

for batch_idx, (data,target) in enumerate(train_loader):
    data = data.send(compute_nodes[batch_idx % len(compute_nodes)])
    target = target.send(compute_nodes[batch_idx % len(compute_nodes)])
    remote_dataset[batch_idx % len(compute_nodes)].append((data, target))

def update(data, target, model, optimizer):
    model.send(data.location)
    optimizer.zero_grad()
    pred = model(data)
    loss = F.mse_loss(pred.view(-1), target)
    loss.backward()
    optimizer.step()
    return model

bobs_model = Net()
alices_model = Net()
model = Net()

bobs_optimizer = optim.SGD(bobs_model.parameters(), lr=args.lr)
alices_optimizer = optim.SGD(alices_model.parameters(), lr=args.lr)

models = [bobs_model, alices_model]
optimizers = [bobs_optimizer, alices_optimizer]

In [9]:
import copy

def train():
    for data_index in range(len(remote_dataset[0])-1):
        
        # update remote models
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(data, target, models[remote_index], optimizers[remote_index])
            
        for model in models:
            model.get()
            
        model = utils.federated_avg({
            "bob": models[0],
            "alice": models[1]
        })

In [10]:
def test():
    models[1].eval()
    test_loss = 0
    for data, target in test_loader:
        output = models[0](data)
        test_loss += F.mse_loss(output.view(-1), target, reduction='sum').item() # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1] # get the index of the max log-probability
        
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}\n'.format(test_loss))

In [11]:
t = time.time()

for epoch in range(args.epochs):
    print(f"Epoch {epoch + 1}")
    train()
    test()

    
total_time = time.time() - t
print('Total', round(total_time, 2), 's')

Epoch 1
Test set: Average loss: 491.0737

Epoch 2
Test set: Average loss: 23.7595

Epoch 3
Test set: Average loss: 20.5366

Epoch 4
Test set: Average loss: 20.3751

Epoch 5
Test set: Average loss: 20.6716

Epoch 6
Test set: Average loss: 21.0246

Epoch 7
Test set: Average loss: 21.5361

Epoch 8
Test set: Average loss: 21.8080

Epoch 9
Test set: Average loss: 22.3076

Epoch 10
Test set: Average loss: 23.0067

Total 62.1 s
