# Federated Keras MNIST Tutorial

In [None]:
# Install TensorFlow and openfl if they are not installed yet.
# NOTE(review): `~/repos/openfl-fork` is a local checkout path on the original
# author's machine -- point it at your own openfl source tree (or an installed
# openfl release) before running.
# The MNIST data itself is not installed by pip; it is downloaded later by
# `mnist.load_data()`.
!pip install tensorflow==2.3.1 ~/repos/openfl-fork

# Alternatively, you could use the intel-tensorflow build:
# !pip install intel-tensorflow==2.3.0

In [None]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

import openfl.native as fx
from openfl.federated import FederatedModel,FederatedDataSet
# Execute tf.function-decorated code eagerly instead of as compiled graphs
# (slower, but easier to debug and inspect).
tf.config.run_functions_eagerly(True)
# Seed TensorFlow's and NumPy's global RNGs for reproducibility. NumPy's RNG
# drives the shuffling and Dirichlet sampling in `split_dirichlet` below.
tf.random.set_seed(0)
np.random.seed(0)

In [None]:
def test_intel_tensorflow():
    """
    Report the installed TensorFlow version and whether Intel (DNNL/MKL)
    optimizations are enabled in this build.

    Prints its findings and returns None. The MKL check relies on private
    TensorFlow internals; if they are unavailable, "unknown" is reported
    instead of raising.
    """
    import tensorflow as tf

    print("We are using Tensorflow version {}".format(tf.__version__))

    major_version = int(tf.__version__.split(".")[0])
    try:
        if major_version >= 2:
            # Private module; it may not exist in every TensorFlow release.
            from tensorflow.python import _pywrap_util_port
            mkl_enabled = _pywrap_util_port.IsMklEnabled()
        else:
            # TF 1.x exposes the same check through pywrap_tensorflow.
            mkl_enabled = tf.pywrap_tensorflow.IsMklEnabled()
    except (ImportError, AttributeError):
        # Purely informational check: don't abort the notebook over it.
        mkl_enabled = "unknown (check unavailable in this TensorFlow build)"

    print("Intel-optimizations (DNNL) enabled:", mkl_enabled)


test_intel_tensorflow()

After importing the required packages, the next step is setting up our openfl workspace. To do this, simply run the `fx.init()` command as follows:

In [None]:
# Set up the default openfl workspace, logging, etc. The argument presumably
# names the workspace template whose plan defaults are used -- confirm.
# NOTE(review): the model built below is fully connected (Dense layers only),
# despite the "cnn" in the template name.
fx.init('keras_cnn_mnist')

Now we are ready to define the dataset and model to perform federated learning on. The dataset should be composed of numpy arrays. We start with a simple fully connected model that is trained on the MNIST dataset, which is split across collaborators in an unbalanced (Dirichlet-sampled, label-skewed) way.

In [None]:
# Load MNIST and hold out the tail of the training set as validation data.

# Fraction of training images reserved for validation; must be > 0.0 so the
# validation split is non-empty.
VALID_PERCENT = 0.3

(X_train, y_train), (X_test, y_test) = mnist.load_data()

# First validation index: samples before it train, samples from it on validate.
split_on = int(len(X_train) * (1 - VALID_PERCENT))

# One-hot encode the training labels once and slice that matrix for both splits.
y_train_onehot = to_categorical(y_train)

train_images, valid_images = X_train[:split_on], X_train[split_on:]
train_labels, valid_labels = y_train_onehot[:split_on], y_train_onehot[split_on:]

test_images = X_test
test_labels = to_categorical(y_test)

def preprocess(images):
    """Return *images* rescaled and flattened into rows of 784 values.

    Every value is divided by 255 and shifted by -0.5, so 8-bit pixels
    (0..255, as in MNIST) land in [-0.5, 0.5]. Each image must contain
    784 (= 28 * 28) values.
    """
    centred = images / 255 - 0.5
    return centred.reshape(-1, 784)

# Apply the same normalise-and-flatten transform to every split.
train_images, valid_images, test_images = (
    preprocess(split) for split in (train_images, valid_images, test_images)
)

# Width of one flattened sample, and the number of digit classes.
feature_shape = train_images.shape[1]
classes = 10

class UnbalancedFederatedDataset(FederatedDataSet):
    """FederatedDataSet whose split gives collaborators unequal, label-skewed shards.

    The samples of each class are divided among collaborators according to
    proportions drawn from a symmetric Dirichlet distribution, so collaborators
    end up with different label mixes and different amounts of data.
    """

    def split(self, num_collaborators, shuffle=True, equally=False):
        """Split the training and validation data into ``num_collaborators`` shards.

        Args:
            num_collaborators: number of shards (collaborators) to produce.
            shuffle: kept for interface compatibility; unused here (the
                Dirichlet split always shuffles within each class).
            equally: kept for interface compatibility; unused here.

        Returns:
            A list of ``num_collaborators`` ``FederatedDataSet`` objects, one
            per collaborator, sharing this dataset's batch size and class count.
        """
        # Training indices are drawn before validation indices, so the global
        # RNG stream is consumed in the same order as before.
        train_idx = self.split_dirichlet(self.y_train, num_collaborators)
        valid_idx = self.split_dirichlet(self.y_valid, num_collaborators)

        # Shards have different sizes, so each one is indexed directly rather
        # than stacked with np.array(): stacking ragged arrays needs
        # dtype=object and raises ValueError on NumPy >= 1.24.
        return [
            FederatedDataSet(
                self.X_train[t_idx],
                self.y_train[t_idx],
                self.X_valid[v_idx],
                self.y_valid[v_idx],
                batch_size=self.batch_size,
                num_classes=self.num_classes,
            )
            for t_idx, v_idx in zip(train_idx, valid_idx)
        ]

    def split_dirichlet(self, labels, num_collaborators, alpha=0.5, min_size_required=10):
        """Partition sample indices among collaborators with Dirichlet label skew.

        Args:
            labels: one-hot label matrix; ``argmax`` along axis 1 gives the class.
            num_collaborators: number of index lists to produce.
            alpha: concentration of the symmetric Dirichlet distribution;
                smaller values give more skewed per-class proportions.
            min_size_required: the whole partition is redrawn until every
                collaborator holds at least this many samples.

        Returns:
            A list of ``num_collaborators`` lists of integer indices into ``labels``.

        Raises:
            ValueError: if there are fewer than
                ``min_size_required * num_collaborators`` samples. The retry
                loop could then never terminate.
        """
        n = len(labels)
        if n < min_size_required * num_collaborators:
            raise ValueError(
                f"split_dirichlet needs at least {min_size_required} samples per "
                f"collaborator ({min_size_required * num_collaborators} in total), "
                f"but got {n}"
            )

        # Class of every sample, computed once rather than per class per retry.
        class_ids = np.argmax(labels, axis=1)
        # A collaborator that already holds this many samples gets no more.
        fair_share = n / num_collaborators

        min_size = 0
        while min_size < min_size_required:
            idx_batch = [[] for _ in range(num_collaborators)]
            for k in range(self.num_classes):
                idx_k = np.where(class_ids == k)[0]
                np.random.shuffle(idx_k)
                proportions = np.random.dirichlet(np.repeat(alpha, num_collaborators))
                # Balance: zero the share of collaborators already at their fair share.
                proportions = np.array([
                    p * (len(idx_j) < fair_share)
                    for p, idx_j in zip(proportions, idx_batch)
                ])
                proportions = proportions / proportions.sum()
                # Cumulative proportions become split points within this class.
                split_points = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1]
                idx_batch = [
                    idx_j + part.tolist()
                    for idx_j, part in zip(idx_batch, np.split(idx_k, split_points))
                ]
            min_size = min(len(idx_j) for idx_j in idx_batch)
        return idx_batch


fl_data = UnbalancedFederatedDataset(
    train_images, train_labels, valid_images, valid_labels,
    batch_size=32, num_classes=classes,
)

In [None]:
from tensorflow.python.ops import standard_ops


@keras.utils.register_keras_serializable()
class FedProxOptimizer(keras.optimizers.Optimizer):
    """SGD with the FedProx proximal term.

    Reference: Li et al., "Federated Optimization in Heterogeneous Networks",
    https://arxiv.org/abs/1812.06127

    Each step applies ``var -= lr * (grad + mu * (var - vstar))``. Here
    ``vstar`` is a per-variable optimizer slot holding the anchor weights that
    the proximal term pulls toward.

    NOTE(review): ``add_slot`` creates ``vstar`` zero-initialised, and nothing
    in this class copies the global model weights into it. FedProx behaviour
    therefore depends on the surrounding framework populating that slot;
    otherwise the term acts as plain L2 decay toward zero. Confirm.

    Written against the TF 2.x ``OptimizerV2`` hooks (``_set_hyper``,
    ``_resource_apply_dense``, ...).
    """

    def __init__(self, learning_rate=0.01, mu=0.01, name='FedProxOptimizer', **kwargs):
        """Create the optimizer.

        Args:
            learning_rate: step size.
            mu: weight of the proximal term; ``mu=0`` reduces the step to plain SGD.
            name: optimizer name.
            **kwargs: forwarded to ``keras.optimizers.Optimizer``.
        """
        super().__init__(name=name, **kwargs)

        self._set_hyper("learning_rate", learning_rate)
        self._set_hyper("mu", mu)

        # Tensor forms of the hyperparameters, refreshed by _prepare().
        self._lr_t = None
        self._mu_t = None

    def _prepare(self, var_list):
        # Called by the base Optimizer before gradients are applied.
        # Materialises the hyperparameters as tensors. Returns None: the apply
        # methods below take no `apply_state` argument.
        self._lr_t = tf.convert_to_tensor(self._get_hyper('learning_rate'), name="lr")
        self._mu_t = tf.convert_to_tensor(self._get_hyper('mu'), name="mu")

    def _create_slots(self, var_list):
        # One "vstar" slot per variable: the anchor of the proximal term.
        for v in var_list:
            self.add_slot(v, "vstar")

    def _resource_apply_dense(self, grad, var):
        lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
        mu_t = tf.cast(self._mu_t, var.dtype.base_dtype)
        vstar = self.get_slot(var, "vstar")

        # FedProx step: gradient plus mu times the distance from the anchor.
        var_update = var.assign_sub(lr_t * (grad + mu_t * (var - vstar)))

        return tf.group(var_update)

    def _apply_sparse_shared(self, grad, var, indices, scatter_add):
        """Apply the FedProx step to the rows of ``var`` selected by ``indices``.

        Only the touched rows receive the proximal correction. The ``vstar``
        slot is only read, never written, so the anchor weights stay intact.

        Args:
            grad: gradient values for the selected rows.
            var: the variable being updated.
            indices: row indices of ``var`` that ``grad`` refers to.
            scatter_add: callable ``(variable, indices, updates)`` that adds
                ``updates`` into those rows of ``variable``.
        """
        lr_t = tf.cast(self._lr_t, var.dtype.base_dtype)
        mu_t = tf.cast(self._mu_t, var.dtype.base_dtype)
        vstar = self.get_slot(var, "vstar")

        var_rows = tf.gather(var, indices)
        vstar_rows = tf.gather(vstar, indices)
        delta = -lr_t * (grad + mu_t * (var_rows - vstar_rows))
        var_update = scatter_add(var, indices, delta)

        return tf.group(var_update)

    def _resource_apply_sparse(self, grad, var, indices=None):
        # The base OptimizerV2 calls this as (values, var, indices) after
        # summing duplicate indices. For backward compatibility, an
        # IndexedSlices `grad` with `indices` omitted is still accepted.
        if indices is None:
            grad, indices = grad.values, grad.indices
        return self._apply_sparse_shared(
            grad, var, indices,
            lambda x, i, v: standard_ops.scatter_add(x, i, v))

    def get_config(self):
        # Serialise under the names __init__ accepts, so that
        # FedProxOptimizer(**config) round-trips. (The base from_config in
        # TF 2.x also maps a legacy "lr" key to "learning_rate", so older
        # saved configs still load.)
        base_config = super(FedProxOptimizer, self).get_config()
        return {
            **base_config,
            "learning_rate": self._serialize_hyperparameter("learning_rate"),
            "mu": self._serialize_hyperparameter("mu"),
        }

In [None]:
def build_model(feature_shape, classes, mu=1e-1):
    """Build and compile the fully connected MNIST classifier.

    Args:
        feature_shape: shape of one input sample, e.g. ``(784,)``. A bare
            integer width such as ``784`` is also accepted.
        classes: number of output classes (width of the softmax layer).
        mu: FedProx proximal-term weight passed to ``FedProxOptimizer``. The
            default matches the previously hard-coded ``1e-1``.

    Returns:
        A compiled ``Sequential`` model using categorical cross-entropy loss
        and an accuracy metric.
    """
    import numbers

    # Keras needs `input_shape` as a tuple. Also accept a plain int width, such
    # as the notebook's own `feature_shape = train_images.shape[1]`.
    if isinstance(feature_shape, numbers.Integral):
        input_shape = (int(feature_shape),)
    else:
        input_shape = tuple(feature_shape)

    model = Sequential([
        Dense(64, input_shape=input_shape, activation='relu'),
        Dense(64, activation='relu'),
        Dense(classes, activation='softmax'),
    ])

    model.compile(
        optimizer=FedProxOptimizer(mu=mu),
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model

In [None]:
# Wrap the model-building function and the unbalanced dataset into an openfl
# FederatedModel. build_model is passed uncalled, as a factory that openfl
# presumably invokes itself to construct the model(s).
fl_model = FederatedModel(build_model, data_loader=fl_data)

The `FederatedModel` object is a wrapper around your Keras, TensorFlow or PyTorch model that makes it compatible with openfl. It provides built-in federated training and validation functions that we will see used below. Using its `setup` function, collaborator models and datasets can be automatically defined for the experiment.

In [None]:
# Split the data across 10 collaborators, then key each collaborator's model
# by name ('col0', 'col1', ...) for fx.run_experiment.
collaborator_models = fl_model.setup(num_collaborators=10)
collaborators = {f'col{i}': col_model for i, col_model in enumerate(collaborator_models)}

In [None]:
# Sizes of the full training/validation sets before they are split.
print(f'Original training data size: {len(train_images)}')
print(f'Original validation data size: {len(valid_images)}\n')

# Shard sizes for the first two collaborators. With the Dirichlet split they
# typically differ from each other and from an even share.
for ordinal, position in (('one', 0), ('two', 1)):
    shard = collaborator_models[position].data_loader
    print(f"Collaborator {ordinal}'s training data size: {len(shard.X_train)}")
    print(f"Collaborator {ordinal}'s validation data size: {len(shard.X_valid)}\n")

We can see the current plan values by running the `fx.get_plan()` function.

In [None]:
# Show the current plan; any of these values can be overridden when the
# experiment is run.
import json

current_plan = fx.get_plan()
print(json.dumps(current_plan, sort_keys=True, indent=4))

Now we are ready to run our experiment. If we want to pass in custom plan settings, we can easily do that with the `override_config` parameter.

In [None]:
# Run the federated experiment for 5 rounds, overriding the plan's
# `aggregator.settings.rounds_to_train`. Returns the trained FederatedModel.
final_fl_model = fx.run_experiment(collaborators,override_config={'aggregator.settings.rounds_to_train':5})

In [None]:
# Save the final model in its native format, then reload it with plain Keras.
# The custom optimizer is presumably resolvable at load time because
# FedProxOptimizer is registered with @keras.utils.register_keras_serializable().
final_fl_model.save_native('final_model')
model = tf.keras.models.load_model('./final_model')

In [None]:
# Evaluate the reloaded model on the held-out MNIST test set (one-hot labels).
# Expected to report loss and accuracy, per the metrics compiled in build_model.
model.evaluate(test_images,test_labels)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

fedprox_keras_mnist_9=[
    0.07246376574039459,
0.02415458858013153,
0.08060453087091446,
0.07828282564878464,
0.08010470867156982,
0.005873139947652817,
0.09040793776512146,
0.015303430147469044,
0.07728459686040878,
0.059869591146707535,
0.0805152952671051,
0.25764894485473633,
0.21284635365009308,
0.4530302882194519,
0.6570680737495422,
0.222004696726799,
0.26405733823776245,
0.05540896952152252,
0.44490861892700195,
0.18020154535770416,
0.15780998766422272,
0.4444444477558136,
0.509655773639679,
0.6025252342224121,
0.7282722592353821,
0.26037588715553284,
0.4994487464427948,
0.22744064033031464,
0.6214098930358887,
0.29934796690940857,
0.2415459007024765,
0.5813204646110535,
0.7065491080284119,
0.6823232173919678,
0.7837696075439453,
0.3132341504096985,
0.6730981469154358,
0.4849604368209839,
0.7503916621208191,
0.3698873817920685,
0.3035426735877991,
0.6103059649467468,
0.772879958152771,
0.7161616086959839,
0.8544502854347229,
0.3535630404949188,
0.766262412071228,
0.5783641338348389,
0.8161879777908325,
0.3977474868297577,
]
fedprox_keras_mnist_5 = [0.07246376574039459,
0.02415458858013153,
0.08060453087091446,
0.07828282564878464,
0.08010470867156982,
0.005873139947652817,
0.09040793776512146,
0.015303430147469044,
0.07728459686040878,
0.059869591146707535,
0.0805152952671051,
0.25764894485473633,
0.21200671792030334,
0.4530302882194519,
0.6560209393501282,
0.222004696726799,
0.26405733823776245,
0.05540896952152252,
0.44438642263412476,
0.18020154535770416,
0.15780998766422272,
0.4444444477558136,
0.5092359185218811,
0.6015151739120483,
0.7282722592353821,
0.2607674300670624,
0.4972436726093292,
0.22849604487419128,
0.6214098930358887,
0.2987551987171173,
0.24315619468688965,
0.5813204646110535,
0.7078085541725159,
0.6828283071517944,
0.7842931747436523,
0.3132341504096985,
0.6736493706703186,
0.4875989556312561,
0.7519582509994507,
0.3698873817920685,
0.3027375340461731,
0.6103059649467468,
0.7737195491790771,
0.715656578540802,
0.8554973602294922,
0.35277995467185974,
0.7668136954307556,
0.5778363943099976,
0.8177545666694641,
0.39715471863746643,
]
fedprox_keras_mnist_0 = [0.07246376574039459,
0.02415458858013153,
0.08060453087091446,
0.07828282564878464,
0.08010470867156982,
0.005873139947652817,
0.09040793776512146,
0.015303430147469044,
0.07728459686040878,
0.059869591146707535,
0.0805152952671051,
0.25764894485473633,
0.21200671792030334,
0.4530302882194519,
0.6565445065498352,
0.222004696726799,
0.26405733823776245,
0.05540896952152252,
0.44438642263412476,
0.18020154535770416,
0.15780998766422272,
0.44605475664138794,
0.510075569152832,
0.6020202040672302,
0.727748692035675,
0.2607674300670624,
0.498346209526062,
0.23007915914058685,
0.6214098930358887,
0.2987551987171173,
0.24235104024410248,
0.5813204646110535,
0.7057095170021057,
0.6823232173919678,
0.7842931747436523,
0.31284260749816895,
0.6719955801963806,
0.48601582646369934,
0.7503916621208191,
0.3692946135997772,
0.3027375340461731,
0.6119162440299988,
0.7749789953231812,
0.7161616086959839,
0.8554973602294922,
0.3531714975833893,
0.766262412071228,
0.577308714389801,
0.8177545666694641,
0.3977474868297577,]
fedprox_keras_mnist_2 = [
    0.07246376574039459,
0.02415458858013153,
0.08060453087091446,
0.07828282564878464,
0.08010470867156982,
0.005873139947652817,
0.09040793776512146,
0.015303430147469044,
0.07728459686040878,
0.059869591146707535,
0.07729468494653702,
0.25764894485473633,
0.2086481899023056,
0.4515151381492615,
0.6575916409492493,
0.222004696726799,
0.2596471905708313,
0.0548812672495842,
0.44281983375549316,
0.18079431354999542,
0.15458936989307404,
0.4412238299846649,
0.49958017468452454,
0.5944444537162781,
0.7230366468429565,
0.2599843442440033,
0.4807056188583374,
0.22005276381969452,
0.6146214008331299,
0.29401305317878723,
0.23268921673297882,
0.5748792290687561,
0.6952140927314758,
0.6752524971961975,
0.7790575623512268,
0.31205952167510986,
0.652701199054718,
0.4765171408653259,
0.7342036366462708,
0.3580320179462433,
0.2946859896183014,
0.6054750680923462,
0.7611250877380371,
0.7085858583450317,
0.8476439714431763,
0.3484729826450348,
0.7519294619560242,
0.5604221820831299,
0.8052219152450562,
0.3900415003299713,
]
fedprox_keras_mnist_1 = [
    0.07246376574039459,
0.02415458858013153,
0.08060453087091446,
0.07828282564878464,
0.08010470867156982,
0.005873139947652817,
0.09040793776512146,
0.015303430147469044,
0.07728459686040878,
0.059869591146707535,
0.06763284653425217,
0.249597430229187,
0.17800167202949524,
0.43282827734947205,
0.6612565517425537,
0.222004696726799,
0.21885336935520172,
0.0480211079120636,
0.4281984269618988,
0.17427386343479156,
0.10225442796945572,
0.34621578454971313,
0.3467674255371094,
0.5272727012634277,
0.6958115100860596,
0.23688332736492157,
0.3450937271118164,
0.10395778715610504,
0.5039164423942566,
0.22762300074100494,
0.14895330369472504,
0.483091801404953,
0.507556676864624,
0.5747475028038025,
0.7036648988723755,
0.2760375738143921,
0.42833516001701355,
0.29445910453796387,
0.550391674041748,
0.2708950936794281,
0.15056361258029938,
0.4911433160305023,
0.5457598567008972,
0.5974747538566589,
0.7125654220581055,
0.27760374546051025,
0.47629547119140625,
0.26596304774284363,
0.580156683921814,
0.2898636758327484,
]

# The fedprox_keras_mnist_<k> lists above appear to be accuracies recorded from
# earlier runs, one run per proximal weight (see the labels below). Each list
# holds 50 values, presumably one per collaborator per round
# (10 collaborators x 5 rounds) -- confirm before reusing the numbers.
NUM_ROUNDS = 5

# (legend label, recorded values, extra plot kwargs), in the original plotting
# order, so every curve keeps its colour from the default colour cycle.
runs = [
    ('FedAvg', fedprox_keras_mnist_0, {'linestyle': '--'}),
    ('FedProx (mu=1e-9)', fedprox_keras_mnist_9, {}),
    ('FedProx (mu=1e-5)', fedprox_keras_mnist_5, {}),
    ('FedProx (mu=1e-2)', fedprox_keras_mnist_2, {}),
    ('FedProx (mu=1e-1)', fedprox_keras_mnist_1, {}),
]

plt.figure(figsize=(9, 6), dpi=150)
plt.title('Keras MNIST unbalanced split')
for label, accuracies, plot_kwargs in runs:
    # Unweighted mean of the recorded values within each round.
    round_means = [np.mean(chunk) for chunk in np.array_split(accuracies, NUM_ROUNDS)]
    plt.plot(round_means, label=label, **plot_kwargs)

plt.xlabel('Round')
plt.ylabel('Mean accuracy')
plt.legend()
plt.xticks(range(NUM_ROUNDS))
plt.show()