In [None]:
$ pip install flwr

In [None]:
$ pip install mxnet

**Flower Client**

In [None]:
import flwr as fl

import numpy as np

import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet.gluon import nn
from mxnet import autograd as ag
import mxnet.ndarray as F

In [None]:
# Context list handed to gluon.utils.split_and_load by train() and test():
# a single GPU when mx.test_utils.list_gpus() reports any, otherwise the CPU.
DEVICE = [mx.gpu()] if mx.test_utils.list_gpus() else [mx.cpu()]

In [None]:
def train(net, train_data, epoch):
    """Train ``net`` with SGD for ``epoch`` passes over ``train_data``.

    Parameters
    ----------
    net
        Callable model exposing ``collect_params()``; called on each data
        shard to produce logits.
    train_data
        Resettable batch iterator. Each batch must expose ``data[0]`` and
        ``label[0]``.
    epoch : int
        Number of passes over ``train_data``. A non-positive value performs
        no training.

    Returns
    -------
    tuple
        ``(metrics, num_examples)`` where ``metrics`` is the list of
        ``(name, value)`` pairs from the composite accuracy / cross-entropy
        metric and ``num_examples`` is the number of samples processed in the
        last epoch (0 if no epoch ran).

    Notes
    -----
    The metrics are not reset between epochs, so the reported values
    accumulate over all epochs run by this call.
    """
    # A single trainer: the original built an extra one (lr 0.03) that was
    # immediately overwritten, so 0.01 has always been the effective rate.
    trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.01})
    metrics = mx.metric.CompositeEvalMetric()
    for child_metric in (mx.metric.Accuracy(), mx.metric.CrossEntropy()):
        metrics.add(child_metric)
    softmax_cross_entropy_loss = gluon.loss.SoftmaxCrossEntropyLoss()

    # Bound before the loop so that ``epoch <= 0`` returns cleanly instead of
    # raising UnboundLocalError at the return statement.
    num_examples = 0
    trainings_metric = metrics.get_name_value()

    for i in range(epoch):
        train_data.reset()
        num_examples = 0
        for batch in train_data:
            data = gluon.utils.split_and_load(
                batch.data[0], ctx_list=DEVICE, batch_axis=0
            )
            label = gluon.utils.split_and_load(
                batch.label[0], ctx_list=DEVICE, batch_axis=0
            )
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    loss = softmax_cross_entropy_loss(z, y)
                    loss.backward()
                    # The metrics are fed probabilities, not raw logits.
                    outputs.append(z.softmax())
                    num_examples += len(x)
            metrics.update(label, outputs)
            # Step size is normalised by the full batch size.
            trainer.step(batch.data[0].shape[0])
        trainings_metric = metrics.get_name_value()
        print("Accuracy & loss at epoch %d: %s" % (i, trainings_metric))
    return trainings_metric, num_examples

In [None]:
def test(net, val_data):
    """Evaluate ``net`` on ``val_data`` without updating any parameters.

    Parameters
    ----------
    net
        Callable model; invoked on each data shard to produce logits.
    val_data
        Resettable batch iterator whose batches expose ``data[0]`` and
        ``label[0]``.

    Returns
    -------
    tuple
        ``(metrics, num_examples)``: the ``(name, value)`` pairs of the
        composite accuracy / cross-entropy metric and the number of samples
        that were evaluated.
    """
    metrics = mx.metric.CompositeEvalMetric()
    metrics.add(mx.metric.Accuracy())
    metrics.add(mx.metric.CrossEntropy())

    val_data.reset()
    seen = 0
    for batch in val_data:
        shards = gluon.utils.split_and_load(
            batch.data[0], ctx_list=DEVICE, batch_axis=0
        )
        targets = gluon.utils.split_and_load(
            batch.label[0], ctx_list=DEVICE, batch_axis=0
        )
        # Softmax turns the logits into the probabilities the metrics consume.
        predictions = [net(shard).softmax() for shard in shards]
        seen += sum(len(shard) for shard in shards)
        metrics.update(targets, predictions)
    return metrics.get_name_value(), seen

In [None]:
def main():
    """Load the data iterators and build an initialised three-layer MLP.

    Returns
    -------
    tuple
        ``(model, train_data, val_data)``. Earlier versions returned nothing,
        which left the objects unreachable once the function finished.
    """

    def build_model():
        # Dense 256 -> 64 -> 10 network with ReLU on the hidden layers.
        net = nn.Sequential()
        net.add(nn.Dense(256, activation="relu"))
        net.add(nn.Dense(64, activation="relu"))
        net.add(nn.Dense(10))
        # Initialise on the same context(s) that train()/test() load data
        # onto; the default context would not match when DEVICE is a GPU.
        net.collect_params().initialize(ctx=DEVICE)
        return net

    # NOTE(review): load_data is not defined in any visible cell; confirm it
    # is provided before this cell runs. It must return (train_data, val_data).
    train_data, val_data = load_data()

    model = build_model()
    # One dummy forward pass with 784 input features; presumably this is what
    # lets Gluon infer the layer shapes before parameters are read or written.
    init = nd.random.uniform(shape=(2, 784), ctx=DEVICE[0])
    model(init)
    return model, train_data, val_data


# MNISTClient reads these three names at module level, so bind them here.
model, train_data, val_data = main()

In [None]:
class MNISTClient(fl.client.NumPyClient):
    """Flower NumPy client that trains and evaluates the module-level model.

    The methods read ``model``, ``train_data`` and ``val_data`` from module
    scope, so those names must be bound before the client is used. Only
    parameters whose names match ``".*weight"`` are exchanged with the server.
    """

    def get_parameters(self, config):
        """Return the model's weight tensors as a list of NumPy arrays."""
        return [
            param.data().asnumpy()
            for param in model.collect_params(".*weight").values()
        ]

    def set_parameters(self, parameters):
        """Copy ``parameters`` into the model's weight tensors, in order.

        ``ParameterDict.setattr(name, value)`` only sets a Python attribute
        called ``name`` on every Parameter; it never touches the tensor data.
        ``Parameter.set_data`` is what actually overwrites the values.
        """
        weights = model.collect_params(".*weight")
        for param, value in zip(weights.values(), parameters):
            param.set_data(nd.array(value, dtype=param.dtype))

    def fit(self, parameters, config):
        """Load the given weights, train for two epochs, and report back.

        Returns the updated weights, the number of training examples, and a
        dict with the training ``accuracy`` and ``loss``.
        """
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = train(model, train_data, epoch=2)
        results = {"accuracy": float(accuracy[1]), "loss": float(loss[1])}
        return self.get_parameters(config={}), num_examples, results

    def evaluate(self, parameters, config):
        """Load the given weights and evaluate on the validation data.

        Returns the loss, the number of evaluated examples, and a dict with
        the evaluation ``accuracy``.
        """
        self.set_parameters(parameters)
        [accuracy, loss], num_examples = test(model, val_data)
        print("Evaluation accuracy & loss", accuracy, loss)
        # Report the evaluated sample count, not the iterator's batch size.
        return float(loss[1]), num_examples, {"accuracy": float(accuracy[1])}

In [None]:
# Connect to the Flower server on this machine. 0.0.0.0 is a bind wildcard,
# not a portable destination address, so the loopback address is used instead.
# NOTE(review): the server cell appears later in this notebook; the server
# presumably has to be started from a separate process before this runs.
fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=MNISTClient())

**Flower Server**

In [None]:
import flwr as fl

# Start a Flower server configured for three rounds; every other
# start_server option is left at its default.
fl.server.start_server(
    config=fl.server.ServerConfig(num_rounds=3),
)