# Federated TensorFlow CIFAR-10 Tutorial
Using the `tf.data` API

In [None]:
# Install TensorFlow if it is not already available (TF 2.7 or newer is recommended).
# `%pip` installs into the environment of the running kernel, unlike `!pip`.
# %pip install tensorflow==2.8

## Imports

In [None]:
import tensorflow as tf

# Report the installed TensorFlow version so the outputs below can be tied to it.
print(f'TensorFlow {tf.__version__}')

## Connect to the Federation

Start the `Director` and `Envoy` before running this cell.

This cell connects this notebook to the Federation.

In [None]:
from openfl.interface.interactive_api.federation import Federation

# Use the same client identifier that was used in the signed API certificate.
client_id = 'api'
# Directory holding the API certificates.
# NOTE(review): `cert_dir` is not passed to `Federation` below, so it has no
# effect while `tls=False`.
cert_dir = 'cert'
director_node_fqdn = 'localhost'
director_port = 50051

# Create a Federation object pointing at the Director.
# NOTE(review): TLS is disabled (`tls=False`), so the connection is presumably
# unencrypted — suitable for local experiments only.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False
)

## Query Datasets from Shard Registry

In [None]:
# Fetch the shard registry from the Director (presumably one entry per
# connected Envoy). The bare last-line expression renders it below the cell.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# First, request a dummy shard descriptor that holds information about the federated dataset.
# NOTE(review): `size=10` presumably limits the dummy dataset to 10 items — confirm.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
# Inspect one (sample, target) pair to check shapes before defining the model.
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

## Describe the FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register model

In [None]:
# Seed TensorFlow's global RNG so the initial model weights (presumably the
# starting point of the federated model) are reproducible across notebook runs.
SEED = 42
tf.random.set_seed(SEED)

# Define model: two conv / max-pool / batch-norm stages followed by a
# 10-way linear head. The head emits raw logits (activation=None), which
# matches `from_logits=True` in the loss below.
model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.BatchNormalization(),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10, activation=None),
], name='simplecnn')
model.summary()

# Define optimizer
LEARNING_RATE = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Loss and metrics, used by the `train` and `validate` tasks defined later.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()

# Register the model and optimizer with OpenFL through the Keras framework adapter plugin.
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

### Register dataset

In [None]:
from openfl.interface.interactive_api.experiment import DataInterface

class CIFAR10FedDataset(DataInterface):
    """Per-collaborator CIFAR-10 data pipeline built on ``tf.data``.

    Keyword arguments (read back through ``self.kwargs``):
        train_bs (int): training batch size, default 32.
        valid_bs (int): validation batch size, default 32.
        shuffle_buffer (int): if > 0, shuffle the training split with a buffer
            of this many elements before batching (reshuffled on every pass).
            Default 0 keeps the original, unshuffled order.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # shard_descriptor.get_split(...) returns a tf.data.Dataset
        # Check cifar10_shard_descriptor.py for details
        self.train_set = shard_descriptor.get_split('train')
        self.valid_set = shard_descriptor.get_split('valid')

    @staticmethod
    def _dataset_size(dataset) -> int:
        """Return the number of elements in a finite ``tf.data.Dataset``.

        ``len(dataset)`` raises ``TypeError`` whenever TensorFlow cannot infer
        the cardinality statically (e.g. after ``filter``). Here the static
        cardinality is used when known, and the elements are counted otherwise.

        Raises:
            TypeError: if the dataset is infinite (same as ``len``).
        """
        cardinality = int(dataset.cardinality())
        if cardinality >= 0:
            return cardinality
        if cardinality == tf.data.INFINITE_CARDINALITY:
            raise TypeError('The dataset is infinite.')
        # Unknown cardinality: count the elements in one pass.
        return int(dataset.reduce(tf.constant(0, tf.int64), lambda count, _: count + 1))

    def get_train_loader(self):
        """Output of this method will be provided to tasks with optimizer in contract"""
        bs = self.kwargs.get('train_bs', 32)
        train_set = self.train_set
        shuffle_buffer = self.kwargs.get('shuffle_buffer', 0)
        if shuffle_buffer:
            # Shuffle individual examples (not batches) and reshuffle each epoch.
            train_set = train_set.shuffle(shuffle_buffer, reshuffle_each_iteration=True)
        return train_set.batch(bs)

    def get_valid_loader(self):
        """Output of this method will be provided to tasks without optimizer in contract"""
        bs = self.kwargs.get('valid_bs', 32)
        return self.valid_set.batch(bs)

    def get_train_data_size(self) -> int:
        """Information for aggregation: number of training examples in this shard."""
        return self._dataset_size(self.train_set)

    def get_valid_data_size(self) -> int:
        """Information for aggregation: number of validation examples in this shard."""
        return self._dataset_size(self.valid_set)

### Create CIFAR10 federated dataset

In [None]:
# Instantiate the data interface. The keyword arguments are read back as the
# `train_bs` / `valid_bs` batch sizes; the shard descriptor is set later by the Envoy.
fed_dataset = CIFAR10FedDataset(train_bs=64, valid_bs=512)

## Define and register FL tasks

In [None]:
from tensorflow.keras.utils import Progbar

TI = TaskInterface()

@TI.register_fl_task(model='model', data_loader='dataset', optimizer='optimizer', device='device')
def train(model, dataset, optimizer, device, loss_fn=loss_fn, warmup=False):
    """Run one local training epoch and return the epoch's training accuracy.

    Args:
        model: Keras model being trained (called with ``training=True``).
        dataset: batched ``tf.data.Dataset`` of ``(x, y)`` pairs, i.e. the
            output of the data interface's train loader.
        optimizer: optimizer applied to ``model.trainable_weights``.
        device: required by the task contract; unused here.
        loss_fn: loss called as ``loss_fn(y_true, logits)``. The default is the
            notebook-level ``loss_fn``, bound when this function is defined.
        warmup: if True, stop after the first batch.

    Returns:
        dict: ``{'train_acc': <scalar tensor>}``.
    """
    # Start from a clean metric even if a previous call was interrupted mid-epoch.
    train_acc_metric.reset_state()

    # `len(dataset)` raises TypeError when tf.data cannot infer the number of
    # batches (negative cardinality); Progbar accepts `None` for an unknown total.
    num_batches = int(dataset.cardinality())
    target = num_batches if num_batches >= 0 else None
    if warmup and target:
        target = 1  # warm-up only processes the first batch
    pbar = Progbar(target)

    # Iterate over the batches of the dataset.
    for step, (x, y) in enumerate(dataset):
        # Record the forward pass so gradients w.r.t. the weights can be taken.
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            loss_value = loss_fn(y, logits)
        grads = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grads, model.trainable_weights))

        # Update training metric.
        train_acc_metric.update_state(y, logits)
        pbar.update(step + 1, values=[('loss', loss_value), ('acc', train_acc_metric.result())])
        if warmup:
            break

    # Display metrics at the end of the epoch.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))

    # Reset training metrics at the end of each epoch (`reset_state` replaces the
    # deprecated `reset_states` alias).
    train_acc_metric.reset_state()
    return {'train_acc': train_acc}


@TI.register_fl_task(model='model', data_loader='dataset', device='device')
def validate(model, dataset, device):
    """Evaluate ``model`` on ``dataset`` and return its accuracy.

    Args:
        model: Keras model to evaluate (called with ``training=False``).
        dataset: batched ``tf.data.Dataset`` of ``(x, y)`` pairs, i.e. the
            output of the data interface's validation loader.
        device: required by the task contract; unused here.

    Returns:
        dict: ``{'validation_accuracy': <scalar tensor>}``.
    """
    # Start from a clean metric even if a previous call was interrupted.
    val_acc_metric.reset_state()

    # Run a validation loop over every batch.
    for x, y in dataset:
        logits = model(x, training=False)
        # Update val metrics
        val_acc_metric.update_state(y, logits)
    val_acc = val_acc_metric.result()
    # `reset_state` replaces the deprecated `reset_states` alias.
    val_acc_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))

    return {'validation_accuracy': val_acc}

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation.
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# Start the experiment. This command zips the workspace and Python requirements
# so they can be transferred to the collaborator nodes.
# NOTE(review): opt_treatment='CONTINUE_GLOBAL' presumably continues training from
# the aggregated (global) optimizer state each round — confirm in the OpenFL docs.
ROUNDS_TO_TRAIN = 10
fl_experiment.start(model_provider=MI,
                   task_keeper=TI,
                   data_loader=fed_dataset,
                   rounds_to_train=ROUNDS_TO_TRAIN,
                   opt_treatment='CONTINUE_GLOBAL')
# Stream the experiment's metrics into this notebook.
fl_experiment.stream_metrics()