# Federated Keras MNIST Tutorial
## Using low-level Python API

In [1]:
# Autoreload is intentionally left disabled: enabling it has been observed to
# cause a pickling error when the federated experiment is started.
# %load_ext autoreload
# %autoreload 2

In [2]:
# Install dependencies if not already installed
!pip install tensorflow==2.3.1
!pip install scikit-image
!pip install cloudpickle





### Describe the model and optimizer

In [3]:
import numpy as np
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras import backend as K
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, Flatten, Dense
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

In [4]:
"""
A keras model
"""
# A small fully connected Keras classifier for flattened MNIST images.
feature_shape = 784  # input features per sample (28 * 28 pixels, flattened)
classes = 10  # number of output classes (digits 0-9)

model = Sequential()
# Derive the input shape from `feature_shape` so the first layer cannot drift
# from the declared feature count.
model.add(Dense(64, input_shape=(feature_shape,), activation='relu'))
model.add(Dense(64, activation='relu'))
model.add(Dense(classes, activation='softmax'))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

### Prepare data

We ask the user to keep all test data in the `data/` folder under the workspace, as it will not be sent to collaborators.

In [5]:
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

In [6]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Fraction of the MNIST training set held out for validation.
VALID_PERCENT = 0.3

# Index of the first validation sample; earlier samples are used for training.
split_on = int((1 - VALID_PERCENT) * len(X_train))

# One-hot encode the training labels once and slice both splits from it.
y_train_onehot = to_categorical(y_train)

train_images = X_train[:split_on]
train_labels = y_train_onehot[:split_on]

valid_images = X_train[split_on:]
valid_labels = y_train_onehot[split_on:]

test_images = X_test
test_labels = to_categorical(y_test)

def preprocess(images, n_features=784):
    """Scale pixel values and flatten images for the dense model.

    Args:
        images: array of images whose total size is a multiple of
            ``n_features`` (for MNIST, ``(N, 28, 28)`` pixel arrays).
        n_features: length of each flattened sample; must match the model's
            input layer. Defaults to 784 (28 * 28).

    Returns:
        Array of shape ``(N, n_features)`` holding ``pixel / 255 - 0.5``
        (i.e. values in ``[-0.5, 0.5]`` for 0..255 input).
    """
    # Normalize
    scaled = images / 255 - 0.5
    # Flatten each image into one feature vector
    return scaled.reshape((-1, n_features))

# Preprocess every split the same way so all inputs match the model's
# flattened (N, 784) input. These assignments overwrite the raw arrays, so
# re-run the whole data cell rather than only these lines, otherwise the
# normalization would be applied twice.
train_images = preprocess(train_images)
valid_images = preprocess(valid_images)
test_images = preprocess(test_images)

## Describing the FL experiment

In [7]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register model

In [8]:
# Dotted import path of the Keras framework adapter plugin.
framework_adapter = (
    'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
)
# Wrap the compiled Keras model together with its framework adapter.
MI = ModelInterface(framework_plugin=framework_adapter, model=model)

### Register dataset

We extract the user's dataset class implementation.
Is this convenient?
What if the dataset is not a class?

In [9]:
class FedDataset(DataInterface):
    """
    MNIST arrays wrapped for the federated experiment.

    Every collaborator receives the same full arrays; ``_delayed_init`` then
    keeps only the shard selected by its ``data_path``. The registered tasks
    read ``X_train``/``y_train``/``batch_size`` and ``X_valid``/``y_valid``
    from the object returned by the loader getters, so those getters return
    the dataset instance itself.

    The set of initialization parameters for the FedDataset can be customized.
    """
    def __init__(self, train_images, train_labels, valid_images, valid_labels, **kwargs):
        self.X_train = train_images
        self.y_train = train_labels
        self.X_valid = valid_images
        self.y_valid = valid_labels
        # Batch size used by the tasks / batch generator when none is given.
        self.batch_size = 32
        # Extra user parameters; stored but not used by this class.
        self.kwargs = kwargs

    def _delayed_init(self, data_path='1,1'):
        """
        Shard the dataset for this collaborator node.

        For this example every collaborator has the same dataset at the same
        path, so ``data_path`` only carries the sharding coordinates.

        Args:
            data_path: ``'<rank>,<world_size>'`` with a 1-based rank,
                e.g. ``'2,2'`` for the second of two collaborators.

        Raises:
            ValueError: if ``data_path`` is not two comma-separated integers
                or the rank is outside ``1..world_size``.
        """
        try:
            rank, world_size = (int(part) for part in data_path.split(','))
        except ValueError as err:
            raise ValueError(
                f"data_path must be '<rank>,<world_size>', got {data_path!r}") from err
        # With the slicing used in `_do_sharding`, a rank of 0 or above
        # world_size would silently select a wrong or overlapping shard.
        if not 1 <= rank <= world_size:
            raise ValueError(f'rank must be in 1..{world_size}, got {rank}')
        self.rank, self.world_size = rank, world_size

        # Do the actual sharding
        self._do_sharding(self.rank, self.world_size)

    def _do_sharding(self, rank, world_size):
        """
        Keep every ``world_size``-th sample, starting at index ``rank - 1``.

        This relies on the attributes set in ``__init__``, i.e. it is coupled
        to this dataset's implementation.
        """
        shard = slice(rank - 1, None, world_size)
        self.X_train = self.X_train[shard]
        self.y_train = self.y_train[shard]
        self.X_valid = self.X_valid[shard]
        self.y_valid = self.y_valid[shard]

    def get_train_loader(self, batch_size=None):
        """
        Get training data loader.

        The ``train`` task reads ``X_train``, ``y_train`` and ``batch_size``
        from the loader, so the (sharded) dataset itself is returned.

        Args:
            batch_size: unused; kept for interface compatibility.

        Returns:
            FedDataset: this dataset instance.
        """
        return self

    def get_valid_loader(self, batch_size=None):
        """
        Get validation data loader.

        The ``validate`` task reads ``X_valid`` and ``y_valid`` from the
        loader, so the (sharded) dataset itself is returned.

        Args:
            batch_size: unused; kept for interface compatibility.

        Returns:
            FedDataset: this dataset instance.
        """
        return self

    def get_train_data_size(self):
        """
        Get total number of training samples.
        Returns:
            int: number of training samples
        """
        return self.X_train.shape[0]

    def get_valid_data_size(self):
        """
        Get total number of validation samples.
        Returns:
            int: number of validation samples
        """
        return self.X_valid.shape[0]

    @staticmethod
    def _batch_generator(X, y, idxs, batch_size, num_batches):
        """
        Generate batch of data.
        Args:
            X: input data
            y: label data
            idxs: The index of the dataset
            batch_size: The batch size for the data loader
            num_batches: The number of batches
        Yields:
            tuple: input data, label data
        """
        for i in range(num_batches):
            a = i * batch_size
            b = a + batch_size
            yield X[idxs[a:b]], y[idxs[a:b]]

    def _get_batch_generator(self, X, y, batch_size):
        """
        Return a generator over shuffled mini-batches of ``(X, y)``.
        Args:
            X: input data
            y: label data
            batch_size: The batch size for the data loader; defaults to
                ``self.batch_size`` when ``None``.
        """
        if batch_size is None:
            batch_size = self.batch_size

        # shuffle data indices
        idxs = np.random.permutation(np.arange(X.shape[0]))

        # compute the number of batches (the last one may be smaller)
        num_batches = int(np.ceil(X.shape[0] / batch_size))

        # build the generator and return it
        return self._batch_generator(X, y, idxs, batch_size, num_batches)


fed_dataset = FedDataset(train_images=train_images,
                         train_labels=train_labels,
                         valid_images=valid_images,
                         valid_labels=valid_labels)

### Register tasks

In [10]:
TI = TaskInterface()

# Registering these kwargs is not strictly required; they are serialized into
# the plan's task settings (they appear under `plan.config['tasks']['train']`).
@TI.add_kwargs(**{'batch_size': 32})
@TI.register_fl_task(model='model', data_loader='train_loader')
@TI.send_model()
def train(model, train_loader, batch_size=None):
    """
    Train ``model`` for one epoch on the collaborator's training shard.

    Args:
        model: compiled Keras model.
        train_loader: loader object exposing ``X_train``, ``y_train`` and
            ``batch_size``.
        batch_size: mini-batch size (registered above as 32). When ``None``,
            ``train_loader.batch_size`` is used instead.

    Returns:
        dict: mean value of each metric in ``model.metrics_names`` over the epoch.
    """
    # Honor the `batch_size` kwarg when given; otherwise use the loader default.
    if batch_size is None:
        batch_size = train_loader.batch_size

    history = model.fit(train_loader.X_train,
                        train_loader.y_train,
                        batch_size=batch_size,
                        epochs=1,
                        verbose=0)

    return {str(metric_name): np.mean(history.history[metric_name])
            for metric_name in model.metrics_names}


@TI.register_fl_task(model='model', data_loader='val_loader')
def validate(model, val_loader):
    """
    Evaluate ``model`` on the collaborator's validation shard.

    Args:
        model: compiled Keras model.
        val_loader: loader object exposing ``X_valid``, ``y_valid`` and
            ``batch_size``.

    Returns:
        dict: ``{'accuracy': <validation accuracy>}``.
    """
    vals = model.evaluate(val_loader.X_valid,
                          val_loader.y_valid,
                          batch_size=val_loader.batch_size,
                          verbose=0)
    # `vals` is [loss, accuracy] because the model was compiled with
    # metrics=['accuracy'].
    return {'accuracy': np.mean(vals[1])}

    
    


## Time to start a federated learning experiment

In [11]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation
# The central node FQDN is given explicitly here; presumably it is detected
# automatically when the argument is omitted (TODO confirm).
# NOTE: TLS is disabled, so gRPC traffic is unencrypted; use for local runs only.
federation = Federation(central_node_fqdn='localhost', disable_tls=True)

# Each collaborator's data_path is '<rank>,<world_size>', consumed by
# FedDataset to pick its shard.
col_data_paths = dict(one='1,2', two='2,2')
federation.register_collaborators(col_data_paths=col_data_paths)

In [12]:
# Create an experiment in the federation, serialized with cloudpickle.
fl_experiment = FLExperiment(
    federation=federation,
    serializer_plugin='openfl.plugins.interface_serializer.cloudpickle_serializer.Cloudpickle_Serializer',
)

In [13]:
# NOTE: with %autoreload enabled, starting the experiment raised a pickling error.
fl_experiment.start_experiment(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    rounds_to_train=5,
    opt_treatment='CONTINUE_GLOBAL',
)

tried to remove tensor: __opt_state_needed not present in the tensor dict
tried to remove tensor: __opt_state_needed not present in the tensor dict
gRPC is running on insecure channel with TLS disabled.


In [14]:
# Inspect the task settings recorded in the generated plan.
fl_experiment.plan.config['tasks']

{'defaults': None,
 'settings': {},
 'train': {'function': 'train', 'kwargs': {'batch_size': 32}},
 'validate': {'function': 'validate', 'kwargs': {}}}

In [15]:
# Tasks assigned to the first task group of the plan's assigner.
fl_experiment.plan.config['assigner']['settings']['task_groups'][0]['tasks']

['train', 'validate']

In [16]:
# Build the serializer plugin component named in the plan's API-layer settings;
# the resulting serializer object is displayed as the cell output.
fl_experiment.plan.Build(fl_experiment.plan.config['api_layer']['required_plugin_components']['serializer_plugin'], {})

<openfl.plugins.interface_serializer.cloudpickle_serializer.Cloudpickle_Serializer at 0x7f3b040769d0>

In [17]:
# Full plan configuration; its output is large because it includes the initial
# model weights (`initial_tensor_dict`).
fl_experiment.plan.config

{'aggregator': {'template': 'openfl.component.Aggregator',
  'settings': {'db_store_rounds': 1,
   'init_state_path': 'save/init.pbuf',
   'best_state_path': 'save/best.pbuf',
   'last_state_path': 'save/last.pbuf',
   'rounds_to_train': 5,
   'aggregator_uuid': 'aggregator_plan.yaml_1538be85',
   'federation_uuid': 'plan.yaml_1538be85',
   'authorized_cols': ['one', 'two'],
   'assigner': <openfl.component.assigner.random_grouped_assigner.RandomGroupedAssigner at 0x7f3cec1685d0>,
   'defaults': 'plan/defaults/aggregator.yaml',
   'initial_tensor_dict': {'dense/kernel:0': array([[-0.06193923,  0.06393183, -0.01594066, ...,  0.00283346,
            -0.0739767 , -0.02184607],
           [-0.05921082,  0.03525515, -0.04930659, ...,  0.08030752,
             0.00076302,  0.04539744],
           [-0.03069666,  0.02656408, -0.0732119 , ...,  0.07358757,
             0.06032515,  0.05494713],
           ...,
           [ 0.03223971,  0.03537741, -0.05729705, ..., -0.06715562,
            -0.0

In [18]:
# FIXME(review): this cell fails. `yaml` is never imported (NameError in the
# output), and `yaml.safe_dump` also needs the object to serialize as its first
# argument. Decide what should be dumped, or remove the cell.
yaml.safe_dump()

NameError: name 'yaml' is not defined

In [None]:
import dill
# Serialize the task interface `TI` to tasks.pkl with dill (recursive mode).
with open('./tasks.pkl', 'wb') as tasks_file:
    dill.dump(TI, tasks_file, recurse=True)

In [None]:
import numpy as np
arr = np.arange(0,10)
# FIXME(review): `test_task` is not defined anywhere in this notebook, so this
# line raises NameError; define it or remove this cell.
test_task(arr)

In [None]:
import dill
# Serialize the model interface `MI` to model.pkl with dill (recursive mode).
with open('./model.pkl', 'wb') as model_file:
    dill.dump(MI, model_file, recurse=True)

In [None]:
# FIXME(review): `UNet` is not defined anywhere in this notebook, so this cell
# raises NameError; remove it or define the class.
UNet.__module__

In [None]:
import torch
# FIXME(review): `model_unet` is not defined anywhere in this notebook, and torch
# is not among the installed dependencies, so this cell fails; remove it or
# define the model.
a = torch.rand(1,3,128,128)
model_unet.forward(a)

In [None]:
import dill
# Serialize the federated dataset `fed_dataset` to dataloader.pkl with dill.
with open('./dataloader.pkl', 'wb') as dataset_file:
    dill.dump(fed_dataset, dataset_file, recurse=True)

In [None]:
# Write one '<collaborator>,<data_path>' line per collaborator to data.yaml.
# Reuse the mapping registered with the federation so the two cannot drift apart.
with open('./data.yaml', 'w') as data_file:
    for col_name, data_path in col_data_paths.items():
        data_file.write(f'{col_name},{data_path}\n')

In [None]:
# Minimal stand-in task contract with a non-None 'optimizer' entry.
task_contract = {'optimizer': 1}

In [None]:
# NOTE(review): the flag is True when an optimizer IS present, which is what a
# training task would have; confirm this polarity is intended for "validation".
validation = task_contract['optimizer'] is not None
validation