# Federated TensorFlow MNIST Tutorial
## Using low-level Python API

# Long-Living entities update

* We may now have a Director running on another machine.
* We use Federation API to communicate with Director.
* Federation object should hold a Director's client (for user service)
* Keep in mind that several API instances may be connected to one Director.


* We do not think for now how we start a Director.
* But it knows the data shape and target shape for the DataScience problem in the Federation.
* Director holds the list of connected envoys, we do not need to specify it anymore.
* Director and Envoys are responsible for encrypting connections, we do not need to worry about certs.


* Yet we MUST have a cert to communicate with the Director.
* We MUST know the FQDN of a Director.
* Director communicates data and target shape to the Federation interface object.


* Experiment API may use this info to construct a dummy dataset and a `shard descriptor` stub.

## Connect to the Federation

In [1]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate
client_id = 'api'
cert_dir = 'cert'
director_node_fqdn = 'localhost'
# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = f'{cert_dir}/root_ca.crt'
# api_certificate = f'{cert_dir}/{client_id}.crt'
# api_private_key = f'{cert_dir}/{client_id}.key'

# federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051',
#                        cert_chain=cert_chain, api_cert=api_certificate, api_private_key=api_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn=director_node_fqdn, director_port='50051', tls=False)


In [2]:
# Display the target shape exposed by the federation object
federation.target_shape

['1']

In [3]:
# Fetch the shard registry from the federation and display it
shard_registry = federation.get_shard_registry()
shard_registry

{'env_one': {'shard_info': node_info {
    name: "env_one"
  }
  shard_description: "Mnist dataset, shard number 1 out of 2"
  sample_shape: "784"
  target_shape: "1",
  'is_online': True,
  'is_experiment_running': False,
  'last_updated': '2021-09-22 17:57:13',
  'current_time': '2021-09-22 17:57:23',
  'valid_duration': seconds: 120}}

In [4]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Index the first entry and unpack it into a (sample, target) pair
sample, target = dummy_shard_desc[0]

## Describing the FL experiment

In [6]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

## Register model

In [8]:
from layers import create_model, optimizer
# Import path of the framework adapter plugin handed to ModelInterface (the Keras adapter here)
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
model = create_model()
# Bundle the model, the optimizer and the adapter path into a ModelInterface
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)

### Register dataset

In [9]:
class FedDataset(DataInterface):
    """Data interface that wraps train/validation arrays as batched `tf.data` datasets.

    The arrays given to the constructor are kept on the instance and can later
    be reduced to a single shard by `_delayed_init`.
    """

    def __init__(self, x_train, y_train, x_valid, y_valid, **kwargs):
        """Store the arrays and build the initial (unsharded) datasets.

        Args:
            x_train: Training samples.
            y_train: Training targets.
            x_valid: Validation samples.
            y_valid: Validation targets.
            **kwargs: Extra settings; must contain `batch_size`.

        Raises:
            KeyError: If `batch_size` is not passed as a keyword argument.
        """
        # Bind the constructor arguments themselves. The previous code read
        # `X_train` / `X_valid`, which are not parameters and only resolved
        # when same-named globals happened to exist at call time.
        self.X_train = x_train
        self.y_train = y_train
        self.X_valid = x_valid
        self.y_valid = y_valid
        self.batch_size = kwargs['batch_size']
        self.kwargs = kwargs
        self._setup_datasets()

    def _setup_datasets(self):
        """(Re)build the shuffled, batched datasets from the arrays currently held."""
        self.train_dataset = tf.data.Dataset.from_tensor_slices((self.X_train, self.y_train))
        self.train_dataset = self.train_dataset.shuffle(buffer_size=1024).batch(self.batch_size)
        self.valid_dataset = tf.data.Dataset.from_tensor_slices((self.X_valid, self.y_valid))
        self.valid_dataset = self.valid_dataset.shuffle(buffer_size=1024).batch(self.batch_size)

    def _delayed_init(self, data_path='1,1'):
        """Select this node's shard based on `data_path`.

        Args:
            data_path: String of the form ``"<rank>,<world_size>"``.
        """
        # With the next command the local dataset will be loaded on the collaborator node
        # For this example we have the same dataset on the same path, and we will shard it
        # So we use `data_path` information for this purpose.
        self.rank, self.world_size = [int(part) for part in data_path.split(',')]

        # Do the actual sharding
        self._do_sharding(self.rank, self.world_size)

    def _do_sharding(self, rank, world_size):
        """Keep every `world_size`-th element starting at index `rank - 1`.

        The stored arrays are overwritten with the slice, so calling this again
        shards the already-sharded data.
        """
        self.X_train = self.X_train[rank - 1::world_size]
        self.y_train = self.y_train[rank - 1::world_size]
        self.X_valid = self.X_valid[rank - 1::world_size]
        self.y_valid = self.y_valid[rank - 1::world_size]
        self._setup_datasets()

    def get_train_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        return self.train_dataset

    def get_valid_loader(self, **kwargs):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        return self.valid_dataset

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        return len(self.X_train)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        return len(self.X_valid)


In [12]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
# Load MNIST and flatten every image into a 784-element vector
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = x_train.reshape((-1, 784))
x_test = x_test.reshape((-1, 784))

# Hold out the last 10000 training examples for validation
X_train, X_valid = x_train[:-10000], x_train[-10000:]
y_train, y_valid = y_train[:-10000], y_train[-10000:]

fed_dataset = FedDataset(X_train, y_train, X_valid, y_valid, batch_size=64)

## Define and register FL tasks

In [13]:
# Task interface object; the decorators below register the FL tasks on it
TI = TaskInterface()

import time
from layers import train_acc_metric, val_acc_metric, loss_fn

@TI.register_fl_task(model='model', data_loader='train_dataset',
                     device='device', optimizer='optimizer')
def train(model, train_dataset, optimizer, device, loss_fn=loss_fn, warmup=False):
    """Run one pass over `train_dataset` and return the training accuracy.

    Args:
        model: Model called on each batch with `training=True`.
        train_dataset: Iterable yielding `(x_batch, y_batch)` pairs.
        optimizer: Optimizer used to apply the computed gradients.
        device: Not used inside this function.
        loss_fn: Callable invoked as `loss_fn(y_batch, logits)`.
        warmup: When true, stop after the first batch.

    Returns:
        Dict with a single `'train_acc'` entry holding the metric result.
    """
    start_time = time.time()

    for batch_idx, (features, labels) in enumerate(train_dataset):
        # Forward pass is recorded on the tape so gradients can be derived from it.
        with tf.GradientTape() as tape:
            predictions = model(features, training=True)
            batch_loss = loss_fn(labels, predictions)
        weights = model.trainable_weights
        gradients = tape.gradient(batch_loss, weights)
        optimizer.apply_gradients(zip(gradients, model.trainable_weights))

        # Accumulate the running accuracy.
        train_acc_metric.update_state(labels, predictions)

        # Progress report every 200 batches.
        if not batch_idx % 200:
            loss_report = "Training loss (for one batch) at step %d: %.4f" % (
                batch_idx, float(batch_loss))
            print(loss_report)
            print("Seen so far: %d samples" % ((batch_idx + 1) * 64))
        if warmup:
            break

    # Read out the accuracy for this pass, report it, then clear the metric.
    train_acc = train_acc_metric.result()
    print("Training acc over epoch: %.4f" % (float(train_acc),))
    train_acc_metric.reset_states()

    return {'train_acc': train_acc}


@TI.register_fl_task(model='model', data_loader='val_dataset', device='device')
def validate(model, val_dataset, device):
    """Evaluate `model` on `val_dataset` and return the validation accuracy.

    Args:
        model: Model called on each batch with `training=False`.
        val_dataset: Iterable yielding `(x_batch, y_batch)` pairs.
        device: Not used inside this function.

    Returns:
        Dict with a single `'validation_accuracy'` entry holding the metric result.
    """
    for features, labels in val_dataset:
        predictions = model(features, training=False)
        # Fold this batch into the running validation accuracy.
        val_acc_metric.update_state(labels, predictions)

    # Read the metric before clearing it.
    val_acc = val_acc_metric.result()
    val_acc_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))

    return {'validation_accuracy': val_acc}

### Perform model warm up

The model warmup is necessary to initialize the weights when using TensorFlow's GradientTape.

In [14]:
# Warm up: with warmup=True the training loop stops after the first batch
train(model, fed_dataset.get_train_loader(), optimizer, 'cpu', warmup=True)

# Make a copy of the model for later comparison
# NOTE(review): per the Keras docs, clone_model builds new layers with newly
# initialized weights rather than copying weight values — confirm this is the
# intended baseline for the comparison.
initial_model = tf.keras.models.clone_model(model)

Training loss (for one batch) at step 0: 118.8006
Seen so far: 64 samples
Training acc over epoch: 0.0781


## Time to start a federated learning experiment

In [15]:
# create an experimnet in federation
# Create an experiment in the federation
experiment_name = 'mnist_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [17]:
# The following command zips the workspace and python requirements to be transfered to collaborator nodes
# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                   task_keeper=TI,
                   data_loader=fed_dataset,
                   rounds_to_train=5,
                   opt_treatment='CONTINUE_GLOBAL')

InternalError: Tensorflow type 21 not convertible to numpy dtype.