# Federated Next Word Prediction with Director example

In [None]:
# install requirements
!pip install -r requirements.txt

In [None]:
import numpy as np

import os
# Disable GPUs because TensorFlow does not support CUDA 11 (per the original note).
# Setting CUDA_VISIBLE_DEVICES to '-1' hides every CUDA device from this process.
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Connect to the Federation

In [None]:
# Create a federation
from openfl.interface.interactive_api.federation import Federation

# Please use the same identifier that was used in the signed certificate.
# It is defined once here and reused by whichever Federation(...) call is active below.
client_id = 'frontend'

# 1) Run with API layer - Director mTLS
# To enable mTLS, the user must provide the CA root chain and a signed key pair to the federation interface
# cert_chain = 'cert/root_ca.crt'
# API_certificate = 'cert/frontend.crt'
# API_private_key = 'cert/frontend.key'

# federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=True,
#                        cert_chain=cert_chain, api_cert=API_certificate, api_private_key=API_private_key)

# --------------------------------------------------------------------------------------------------------------------

# 2) Run with TLS disabled (trusted environment)
# Federation can also determine local fqdn automatically
federation = Federation(client_id=client_id, director_node_fqdn='localhost', director_port='50051', tls=False)

In [None]:
# Ask the federation for its shard registry; the bare name on the last line
# makes the notebook display the returned object as the cell output.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Display the target shape exposed by the federation object
federation.target_shape

In [None]:
# First, request a dummy_shard_desc that holds information about the federated dataset
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
# Take the first (sample, target) pair of the dummy dataset.
# NOTE(review): dataset_type is passed as an empty string — confirm the dummy descriptor ignores it.
sample, target = dummy_shard_desc.get_dataset(dataset_type='')[0]

## Creating a FL experiment using Interactive API

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface, DataInterface, ModelInterface, FLExperiment

### Register dataset

In [None]:
from tensorflow.keras.utils import Sequence

class DataGenerator(Sequence):
    """Serve a sliceable dataset as consecutive batches.

    Args:
        dataset: object supporting ``len()`` and slice indexing; each slice
            is returned unchanged as one batch.
        batch_size: maximum number of items per batch.
    """

    def __init__(self, dataset, batch_size):
        self.dataset = dataset
        self.batch_size = batch_size
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches, counting a trailing partial batch."""
        # Ceiling division: flooring here silently dropped the last
        # `len(dataset) % batch_size` items and produced zero batches for a
        # dataset smaller than one batch.
        return -(-len(self.dataset) // self.batch_size)

    def __getitem__(self, index):
        """Return batch number `index`; the last batch may be shorter."""
        start = index * self.batch_size
        return self.dataset[start:start + self.batch_size]

# Now you can implement your data loaders using dummy_shard_desc
class NextWordSD(DataInterface):
    """Data interface that builds train/validation loaders from a shard descriptor.

    Args:
        train_val_split: value forwarded to ``shard_descriptor.get_dataset``
            as its second argument for both the 'train' and 'val' datasets.
        **kwargs: forwarded to ``DataInterface``; ``train_bs`` and ``valid_bs``
            are read back from ``self.kwargs`` to choose the batch sizes.
    """

    # Batch sizes used when the matching keyword is missing or falsy
    DEFAULT_TRAIN_BS = 64
    DEFAULT_VALID_BS = 512

    def __init__(self, train_val_split=0.8, **kwargs):
        super().__init__(**kwargs)
        self.train_val_split = train_val_split

    @property
    def shard_descriptor(self):
        """Return the shard descriptor previously assigned through the setter."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

    def __getitem__(self, index):
        """Delegate indexing to the shard descriptor."""
        return self.shard_descriptor[index]

    def __len__(self):
        """Delegate length to the shard descriptor."""
        return len(self.shard_descriptor)

    def _load_dataset(self, dataset_type):
        """Fetch the `dataset_type` dataset ('train' or 'val') from the shard descriptor."""
        return self.shard_descriptor.get_dataset(dataset_type, self.train_val_split)

    def get_train_loader(self):
        """
        Output of this method will be provided to tasks with optimizer in contract
        """
        # `.get(...) or default` covers both a missing key (previously a
        # KeyError) and a falsy value such as None or 0.
        batch_size = self.kwargs.get('train_bs') or self.DEFAULT_TRAIN_BS

        self.train_dataset = self._load_dataset('train')
        return DataGenerator(self.train_dataset, batch_size=batch_size)

    def get_valid_loader(self):
        """
        Output of this method will be provided to tasks without optimizer in contract
        """
        batch_size = self.kwargs.get('valid_bs') or self.DEFAULT_VALID_BS

        self.val_dataset = self._load_dataset('val')
        return DataGenerator(self.val_dataset, batch_size=batch_size)

    def get_train_data_size(self):
        """
        Information for aggregation
        """
        # Load on demand so this works even when called before get_train_loader()
        if getattr(self, 'train_dataset', None) is None:
            self.train_dataset = self._load_dataset('train')
        return len(self.train_dataset)

    def get_valid_data_size(self):
        """
        Information for aggregation
        """
        # Load on demand so this works even when called before get_valid_loader()
        if getattr(self, 'val_dataset', None) is None:
            self.val_dataset = self._load_dataset('val')
        return len(self.val_dataset)


### Describe a model and optimizer
#### Sequential API

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import LSTM, Dense
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import TopKCategoricalAccuracy
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras.models import Sequential

# Layer stack: two LSTM layers (the first one returns the full sequence so the
# second can consume it), a tanh Dense layer and a softmax output layer.
# NOTE(review): 10719 is presumably the vocabulary size — confirm against the shard descriptor.
model = Sequential([
    LSTM(1000, return_sequences=True),
    LSTM(1000),
    Dense(1000, activation='tanh'),
    Dense(10719, activation='softmax'),
])

# Training objects shared by the `train` and `validate` tasks
optimizer = Adam(learning_rate=0.001)
loss_fn = CategoricalCrossentropy()

# Separate top-10 accuracy trackers for training and validation
train_acc_metric = TopKCategoricalAccuracy(k=10)
val_acc_metric = TopKCategoricalAccuracy(k=10)

# Build the weights eagerly from an explicit input shape.
# NOTE(review): 3 and 96 presumably are the sequence length and the
# per-token feature size — confirm against the shard descriptor.
batch_size = 64
model.build(input_shape=[batch_size, 3, 96])

In [None]:
# Data interface instance used for the experiment; the batch sizes are passed
# as keyword arguments and read back by the loader getters.
fed_dataset = NextWordSD(train_bs=64, valid_bs=512)

### Define and register FL tasks

In [None]:
TI = TaskInterface()


# https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit
@TI.register_fl_task(model='model', data_loader='train_loader', device='device', optimizer='optimizer')
def train(model, train_loader, device, optimizer):
    """Run one pass over `train_loader`, updating `model` with `optimizer`.

    Uses the module-level `loss_fn` and `train_acc_metric` objects.

    Args:
        model: callable model exposing `trainable_variables`.
        train_loader: iterable yielding `(x_batch, y_batch)` pairs.
        device: unused by this task body.
        optimizer: optimizer whose `apply_gradients` performs the weight update.

    Returns:
        dict with 'train_acc' (metric result accumulated over the pass) and
        'loss' (loss of the last processed batch, or 0.0 if the loader
        yielded no batches).
    """
    # Bound up front so an empty loader returns a value instead of raising
    # UnboundLocalError at the return statement.
    loss = tf.constant(0.0)

    # Iterate over the batches of the dataset.
    for x_batch_train, y_batch_train in train_loader:

        y = tf.convert_to_tensor(y_batch_train)
        with tf.GradientTape() as tape:
            y_pred = model(x_batch_train, training=True)  # Forward pass
            # Compute the loss value with the module-level loss function
            loss = loss_fn(y, y_pred)

        # Compute gradients
        trainable_vars = model.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        # Update weights
        optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        train_acc_metric.update_state(y, y_pred)

    # Read the accumulated accuracy, then reset the metric for the next pass
    train_acc = train_acc_metric.result()
    train_acc_metric.reset_states()
    return {'train_acc': train_acc, 'loss': loss}


@TI.register_fl_task(model='model', data_loader='val_loader', device='device')
def validate(model, val_loader, device=''):
    """Score `model` on every batch of `val_loader`.

    Feeds each batch through the model in inference mode, accumulates the
    module-level `val_acc_metric`, and resets that metric before returning.

    Returns:
        dict with a single 'validation_accuracy' entry.
    """
    for features, labels in val_loader:
        expected = tf.convert_to_tensor(labels)
        # Inference-mode forward pass, then fold the batch into the metric
        predicted = model(features, training=False)
        val_acc_metric.update_state(expected, predicted)

    accuracy = val_acc_metric.result()
    val_acc_metric.reset_states()
    return {'validation_accuracy': accuracy}


#### Register model

In [None]:
from copy import deepcopy

# Dotted path of the framework adapter plugin handed to ModelInterface (the Keras adapter)
framework_adapter = 'openfl.plugins.frameworks_adapters.keras_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(model=model, optimizer=optimizer, framework_plugin=framework_adapter)
# Save the initial model state as a deep copy, so the untrained model can be validated later
initial_model = deepcopy(model)

## Time to start a federated learning experiment

In [None]:
# Create an experiment in the federation
experiment_name = 'word_prediction_test_experiment'
fl_experiment = FLExperiment(federation=federation, experiment_name=experiment_name)

In [None]:
# NOTE: using the IPython autoreload extension with this notebook resulted in a pickling error

# The following command zips the workspace and python requirements to be transferred to collaborator nodes
fl_experiment.start(model_provider=MI, 
                    task_keeper=TI,
                    data_loader=fed_dataset,
                    rounds_to_train=20,
                    opt_treatment='RESET')

In [None]:
# If the user wants to stop the IPython session, then reconnect and check how the experiment is going:
# fl_experiment.restore_experiment_state(MI)

# Stream the experiment's metrics into the cell output
fl_experiment.stream_metrics()

## Testing the best model

In [None]:
!pip install -r ../envoy/sd_requirements.txt

In [None]:
import sys
# Put ../envoy on the module search path (at position 1) so modules located there can be imported
sys.path.insert(1, '../envoy')

In [None]:
from shard_descriptor import NextWordShardDescriptor

# https://www.gutenberg.org/files/2892/2892-h/2892-h.htm
# Rebuild the data interface locally and attach a shard descriptor by hand
# (no Envoy is involved in this cell).
# NOTE(review): train_val_split=0 presumably puts the whole text into the validation split — confirm in the shard descriptor.
fed_dataset = NextWordSD(train_bs=64, valid_bs=512, train_val_split=0)
fed_dataset.shard_descriptor = NextWordShardDescriptor(title='Irish Fairy Tales', author='James Stephens')

In [None]:
# Fetch the best model from the experiment
best_model = fl_experiment.get_best_model()

# Remove the experiment data from the director
fl_experiment.remove_experiment_data()

# Validate the initial (untrained) model on the local validation loader
validate(initial_model, fed_dataset.get_valid_loader())

In [None]:
# Validate the trained model on the same local validation loader
validate(best_model, fed_dataset.get_valid_loader())