In [1]:
import os
from pathlib import Path
from multiprocessing import Process
import syft as sy
from syft.workers.websocket_client import WebsocketClientWorker
import torch
import argparse
import logging

In [3]:
path, filename = os.path.split(os.getcwd())
log_file = os.path.join('.', 'log_files', '{}.log'.format(filename))
logging.basicConfig(filename=log_file,  format="%(asctime)s: %(message)s", datefmt='%Y-%m-%d %H:%M:%S', level=logging.INFO)

In [4]:
hook = sy.TorchHook(torch)

In [5]:
def exec_pipeline(socket_pipe):
    a = torch.tensor([1, 2, 3]).tag("a")
    b = torch.tensor([3, 2, 1]).tag("b")
    logging.info("CLIENT - Created local tensors:")
    logging.info("CLIENT - A: [{}]".format(' '.join(map(str, a.numpy()))))
    logging.info("CLIENT - B: [{}]".format(' '.join(map(str, b.numpy()))))

    a_at_server = a.send(socket_pipe)
    b_at_server = b.send(socket_pipe)
    logging.info("CLIENT - Sent tensors to remote node")

    c_at_server = a_at_server + b_at_server
    logging.info("CLIENT - Computed sum on remote node")

    c_at_local = c_at_server.get()
    logging.info("CLIENT - Get the result to local, removing from remote")
    logging.info("CLIENT - Final result: [{}]".format(' '.join(map(str, c_at_local.numpy()))))

In [6]:
def start_proc(participant, kwargs):
    """ helper function for spinning up a websocket participant """

    def target():
        socket_pipe = participant(**kwargs)
        logging.info("CLIENT - Connected")
        exec_pipeline(socket_pipe)

    p = Process(target=target)
    p.start()

    return p


In [7]:
kwargs = {
        "id": "fed",
        "host": "0.0.0.0",
        "port": "8769",
        "hook": hook,
        "verbose": False,
         }

In [8]:
logging.info("CLIENT - Started in host:{}, port:{}".format(kwargs['host'], kwargs['port']))
logging.info("CLIENT - Local mode")
start_proc(WebsocketClientWorker, kwargs)

<Process(Process-1, started)>