## Preparation: start the websocket server workers

Each worker is represented by two parts, a local handle (websocket client worker) and the remote instance that holds the data and performs the computations. The remote part is called a websocket server worker.

So first, we need to create the remote workers. For this, you need to run in a terminal (not possible from the notebook):

```bash
python start_websocket_servers.py
```

## Setting up the websocket client workers

We first need to perform the imports and setup some arguments and variables.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
import syft as sy
from syft.workers import WebsocketClientWorker
import torch
from torchvision import datasets, transforms

from syft.frameworks.torch.federated import utils

In [3]:
import run_websocket_client as rwc

In [4]:
args = rwc.define_and_get_arguments(args=[])
use_cuda = args.cuda and torch.cuda.is_available()
torch.manual_seed(args.seed)
device = torch.device("cuda" if use_cuda else "cpu")
print(args)

Namespace(batch_size=64, cuda=False, epochs=2, federate_after_n_batches=50, lr=0.01, save_model=False, seed=1, test_batch_size=1000, use_virtual=False, verbose=False)


Now let's instantiate the websocket client workers, our local access point to the remote workers.
Note that **this step will fail, if the websocket server workers are not running**.

In [5]:
hook = sy.TorchHook(torch)

kwargs_websocket = {"host": "http://ec2-13-233-99-209.ap-south-1.compute.amazonaws.com", "hook": hook, "verbose": args.verbose}
#kwargs_websocket = {"host":"172.20.10.2", "hook": hook, "verbose": args.verbose}
alice = WebsocketClientWorker(id="alice", port=8777, **kwargs_websocket)
bob = WebsocketClientWorker(id="bob", port=8778, **kwargs_websocket)
charlie = WebsocketClientWorker(id="charlie", port=8779, **kwargs_websocket)

workers = [alice, bob, charlie]
print(workers)


WebSocketAddressException: [Errno -2] Name or service not known

In [None]:
federated_train_loader = sy.FederatedDataLoader(
    datasets.MNIST(
        "../data",
        train=True,
        download=True,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ).federate(tuple(workers)),
    batch_size=args.batch_size,
    shuffle=True,
    iter_per_worker=True
)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(
        "../data",
        train=False,
        transform=transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        ),
    ),
    batch_size=args.test_batch_size,
    shuffle=True
)


In [None]:
model = rwc.Net().to(device)
print(model)

In [None]:
def train(model, device, federated_train_loader, lr, federate_after_n_batches):
    model.train()

    nr_batches = federate_after_n_batches

    models = {}
    loss_values = {}

    iter(federated_train_loader)  # initialize iterators
    batches = rwc.get_next_batches(federated_train_loader, nr_batches)
    counter = 0

    while True:
        print("Starting training round, batches [{}, {}]".format(counter, counter + nr_batches))
        data_for_all_workers = True
        for worker in batches:
            curr_batches = batches[worker]
            if curr_batches:
                models[worker], loss_values[worker] = rwc.train_on_batches(
                    worker, curr_batches, model, device, lr
                )
            else:
                data_for_all_workers = False
        counter += nr_batches
        if not data_for_all_workers:
            logger.debug("At least one worker ran out of data, stopping.")
            break

        model = utils.federated_avg(models)
        batches = rwc.get_next_batches(federated_train_loader, nr_batches)
    return model

In [None]:
import logging
FORMAT = "%(asctime)s %(levelname)s %(filename)s(l:%(lineno)d) - %(message)s"
LOG_LEVEL = logging.DEBUG
logging.basicConfig(format=FORMAT, level=LOG_LEVEL)
logger = logging.getLogger("main")

In [None]:
for epoch in range(1, args.epochs + 1):
    print("Starting epoch {}/{}".format(epoch, args.epochs))
    model = train(model, device, federated_train_loader, args.lr, args.federate_after_n_batches)
    rwc.test(model, device, test_loader)