# Federated FLAX/JAX CIFAR10 Tutorial
Uses the TensorFlow Datasets (`TFDS`) API as the data loader.

In [None]:
!pip install ml_collections flax -q

## Imports

`TF_FORCE_GPU_ALLOW_GROWTH=true` - TensorFlow starts by allocating very little GPU memory and extends its allocation as the program needs more, instead of reserving most of the GPU memory up front.

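`XLA_PYTHON_CLIENT_PREALLOCATE=false` - Disables JAX's GPU memory preallocation. JAX instead allocates GPU memory as needed, which can reduce overall memory usage.

Both variables are set only when `DEFAULT_DEVICE` selects the GPU. In CPU mode, the setup cell below instead restricts JAX to the CPU backend (`JAX_PLATFORMS=cpu`) and hides all GPUs from TensorFlow (`CUDA_VISIBLE_DEVICES=-1`).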

In [None]:
DEFAULT_DEVICE = 'cpu'  # 'cpu' or 'GPU'; read by the environment-setup cell below

In [None]:
import os

# This cell runs before TensorFlow and JAX are imported below, which is
# presumably required for these environment variables to take effect.
# DEFAULT_DEVICE is compared case-insensitively so 'cpu'/'CPU' and
# 'gpu'/'GPU' are all accepted.
if DEFAULT_DEVICE.lower() == 'cpu':
    os.environ['JAX_PLATFORMS'] = 'cpu'  # Force XLA (JAX) to use the CPU backend
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # Hide all GPUs so TF uses the CPU
elif DEFAULT_DEVICE.lower() == 'gpu':
    # Allocate GPU memory on demand instead of preallocating it.
    os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
    os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
    os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'  # Disable oneDNN custom operations
# Any other value leaves the environment untouched (framework defaults apply).

In [None]:
import tensorflow as tf
print(f'TensorFlow {tf.__version__}')
import warnings
# Silence every warning issued through Python's `warnings` module from here on
# (keeps tutorial output short, but also hides library deprecation notices).
warnings.filterwarnings("ignore")

In [None]:
from flax import linen as nn
from flax.metrics import tensorboard
from flax.training import train_state
import jax
import jax.numpy as jnp
import logging
import ml_collections
import optax
import tensorflow_datasets as tfds
from tensorflow.keras.utils import Progbar
from dataclasses import field


## Connect to the Federation

Start `Director` and `Envoy` before proceeding with this cell. 

This cell connects this notebook to the Federation.

In [None]:
from openfl.interface.interactive_api.federation import Federation

# Connection settings. `client_id` must be the same identifier that was used
# in the signed certificate.
client_id = 'api'
cert_dir = 'cert'  # not passed to Federation below (tls=False)
director_node_fqdn = 'localhost'
director_port = 50055

# Connect this notebook to the Director, with TLS disabled.
federation = Federation(
    client_id=client_id,
    director_node_fqdn=director_node_fqdn,
    director_port=director_port,
    tls=False,
)

## Query Datasets from Shard Registry

In [None]:
# Fetch the shard registry from the Director and display it.
shard_registry = federation.get_shard_registry()
shard_registry

In [None]:
# Request a dummy shard descriptor holding information about the federated
# dataset, then look at the shapes of its first training example.
dummy_shard_desc = federation.get_dummy_shard_descriptor(size=10)
dummy_shard_dataset = dummy_shard_desc.get_dataset('train')
sample, target = dummy_shard_dataset[0]
f"Sample shape: {sample.shape}, target shape: {target.shape}"

In [None]:
def get_config():
    """Return the default hyperparameters as an `ml_collections.ConfigDict`.

    Fields:
      learning_rate, momentum -- SGD optimizer settings
      batch_size              -- mini-batch size for local training
      num_epochs              -- NOTE(review): not read anywhere in this notebook
      rounds_to_train         -- number of federated rounds
    """
    return ml_collections.ConfigDict({
        'learning_rate': 0.01,
        'momentum': 0.9,
        'batch_size': 128,
        'num_epochs': 10,
        'rounds_to_train': 3,
    })

In [None]:
config = get_config()  # hyperparameters shared by the model, tasks and experiment

## Describe the FL experiment

In [None]:
from openfl.interface.interactive_api.experiment import TaskInterface
from openfl.interface.interactive_api.experiment import ModelInterface
from openfl.interface.interactive_api.experiment import FLExperiment

### Register model

In [None]:
# Define model
class CNN(nn.Module):
    """A simple CNN: three conv/ReLU/avg-pool stages (32, 64, 128 filters),
    a 256-unit ReLU dense layer, and a 10-unit output layer (logits)."""

    @nn.compact
    def __call__(self, x):
        # Feature extractor. Submodules are created in the same order as the
        # unrolled form, so Flax still auto-names them Conv_0, Conv_1, Conv_2.
        for num_filters in (32, 64, 128):
            x = nn.Conv(features=num_filters, kernel_size=(3, 3))(x)
            x = nn.relu(x)
            x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))

        # Classifier head
        x = x.reshape((x.shape[0], -1))  # flatten all but the leading (batch) axis
        x = nn.Dense(features=256)(x)
        x = nn.relu(x)
        return nn.Dense(features=10)(x)


In [None]:
class CustomTrainState(train_state.TrainState):
    """`train_state.TrainState` subclass with an in-place `update_state` method.

    `apply_gradients` returns a new, immutable state. The state registered
    with the ModelInterface is kept up to date by copying the trained values
    into its mutable containers (see `update_state`) instead of rebinding it.

    Attributes
    ----------
    opt_vars : tuple of str
        Names of the attributes of ``opt_state[0]`` that hold per-parameter
        optimizer variables. Declared as a *static* (non-pytree) field: these
        are plain strings, and `jax.jit` cannot trace string leaves when the
        state is passed to `apply_model` / `update_model`.
    """
    # metadata {'pytree_node': False} is exactly what
    # `flax.struct.field(pytree_node=False)` produces; static fields must be
    # hashable, hence a tuple (see `create`).
    opt_vars: tuple = field(default_factory=tuple, metadata={'pytree_node': False})

    @classmethod
    def create(cls, *, opt_vars=(), **kwargs):
        """Build the initial state, normalising `opt_vars` to a hashable tuple.

        All other keyword arguments are forwarded to `TrainState.create`.
        """
        return super().create(opt_vars=tuple(opt_vars), **kwargs)

    def update_state(self, new_state: train_state.TrainState) -> None:
        """Copy parameters and optimizer variables from `new_state` in place.

        Parameters
        ----------
        new_state : train_state.TrainState
            State returned by the training loop (with gradients applied).

        Returns
        -------
        None

        Notes
        -----
        Only the top-level entries of ``params`` and of each ``opt_vars``
        attribute of ``opt_state[0]`` are copied (``dict.update``); other
        fields such as ``step`` are left unchanged. Assumes these containers
        are mutable dicts (`create_train_state` unfreezes params).
        """
        # Update params
        self.params.update(new_state.params)

        # Update optimizer state variables (presumably e.g. momentum traces)
        for var in self.opt_vars:
            opt_var_dict = getattr(self.opt_state[0], var)
            new_opt_var_dict = getattr(new_state.opt_state[0], var)
            opt_var_dict.update(new_opt_var_dict)
                


In [None]:
def _get_opt_vars(x):
    return False if x.startswith('_') or x in ['index', 'count'] else True

def create_train_state(rng, config, input_shape=(1, 32, 32, 3)):
    """Create the initial `CustomTrainState` for the CNN.

    Parameters
    ----------
    rng : jax PRNG key
        Key used to initialise the model parameters.
    config : ml_collections.ConfigDict
        Must provide `learning_rate` and `momentum` for the SGD optimizer.
    input_shape : tuple of int, optional
        Shape of the dummy batch used to initialise the model; defaults to one
        32x32x3 image.

    Returns
    -------
    CustomTrainState
        State holding randomly initialised params, the optimizer and its state.
    """
    # Works for both the FrozenDict returned by older Flax and the plain dict
    # returned by Flax >= 0.7.1 (which has no `.unfreeze()` method).
    from flax.core import unfreeze

    cnn = CNN()
    # `update_state` mutates params in place, so they must be a mutable dict.
    params = unfreeze(cnn.init(rng, jnp.ones(input_shape))['params'])  # Random parameters
    tx = optax.sgd(config.learning_rate, config.momentum)  # Optimizer
    # Names of the per-parameter optimizer variables in the first state element;
    # a tuple so it can be stored as a hashable (static) field.
    opt_vars = tuple(filter(_get_opt_vars, dir(tx.init(params)[0])))
    return CustomTrainState.create(apply_fn=cnn.apply, params=params, tx=tx, opt_vars=opt_vars)

In [None]:
# Seed the pseudo-random number generator (PRNG) and split off a key for
# parameter initialisation.
rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

# Build the initial TrainState: random parameters plus optimizer state, so
# gradients can be applied conveniently during training.
state = create_train_state(init_rng, config)

# Register the state with OpenFL through the Flax framework adapter.
framework_adapter = 'openfl.plugins.frameworks_adapters.flax_adapter.FrameworkAdapterPlugin'
MI = ModelInterface(
    model=state,
    optimizer=None,
    framework_plugin=framework_adapter,
)

### Register dataset

In [None]:
from openfl.interface.interactive_api.experiment import DataInterface

class CIFAR10FedDataset(DataInterface):
    """OpenFL DataInterface that hands a CIFAR-10 shard to the FL tasks.

    The tasks in this notebook index every split with 'image' and 'label'
    (see `train_epoch` and `validate`), so each split is assumed to be a
    mapping of per-sample arrays. NOTE(review): confirm against the envoy's
    shard descriptor.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @property
    def shard_descriptor(self):
        """Shard descriptor assigned by the Envoy."""
        return self._shard_descriptor

    @shard_descriptor.setter
    def shard_descriptor(self, shard_descriptor):
        """
        Describe per-collaborator procedures or sharding.

        This method will be called during a collaborator initialization.
        Local shard_descriptor will be set by Envoy.
        """
        self._shard_descriptor = shard_descriptor

        # Splits are passed to the tasks unbatched; `train_epoch` shuffles and
        # batches them itself. Check cifar10_shard_descriptor.py for details.
        self.train_set = shard_descriptor.get_split('train')
        self.valid_set = shard_descriptor.get_split('valid')

    def get_train_loader(self):
        """Output of this method will be provided to tasks with optimizer in contract"""
        return self.train_set

    def get_valid_loader(self):
        """Output of this method will be provided to tasks without optimizer in contract"""
        return self.valid_set

    def get_train_data_size(self) -> int:
        """Number of training samples, reported for aggregation.

        Counts the entries of the 'image' array (the same count `train_epoch`
        uses); `len()` of a mapping split would count its keys instead.
        """
        return len(self.train_set['image'])

    def get_valid_data_size(self) -> int:
        """Number of validation samples, reported for aggregation."""
        return len(self.valid_set['image'])

### Create CIFAR10 federated dataset

In [None]:
fed_dataset = CIFAR10FedDataset()  # each Envoy later sets its local shard_descriptor

## Define and register FL tasks

In [None]:
@jax.jit
def apply_model(state, images, labels):
    """Forward and backward pass on a single batch.

    Returns
    -------
    grads : gradients of the mean softmax cross-entropy w.r.t. `state.params`
    loss : that mean loss
    accuracy : fraction of samples whose argmax prediction equals the label
    """
    # Integer class ids -> one-hot targets (10 = number of classes in the dataset).
    targets = jax.nn.one_hot(labels, 10)

    def compute_loss(params):
        predictions = state.apply_fn({'params': params}, images)
        per_example = optax.softmax_cross_entropy(logits=predictions, labels=targets)
        return jnp.mean(per_example), predictions

    (loss, logits), grads = jax.value_and_grad(compute_loss, has_aux=True)(state.params)
    is_correct = jnp.argmax(logits, -1) == labels
    return grads, loss, jnp.mean(is_correct)


@jax.jit
def update_model(state, grads):
    """Apply `grads` to `state` and return the resulting new state.

    The input state is not modified; callers must use the returned value.
    """
    new_state = state.apply_gradients(grads=grads)
    return new_state

In [None]:
def train_epoch(state, train_ds, batch_size, rng):
    """Run one shuffled pass over `train_ds` in mini-batches.

    Parameters
    ----------
    state : train state to start from (rebound locally, never mutated)
    train_ds : mapping with 'image' and 'label' arrays indexed by sample
    batch_size : samples per step; the trailing incomplete batch is dropped
    rng : PRNG key that determines the shuffle order

    Returns
    -------
    (state, train_loss, train_accuracy) : the updated state and the per-step
    loss / accuracy averaged over the epoch, as Python floats.
    """
    num_samples = len(train_ds['image'])
    num_steps = num_samples // batch_size
    progress = Progbar(num_steps)

    # Shuffle sample indices, drop the incomplete tail batch, and lay the
    # remaining indices out as one row per step.
    shuffled = jax.random.permutation(rng, num_samples)
    batch_indices = shuffled[:num_steps * batch_size].reshape((num_steps, batch_size))

    losses, accuracies = [], []
    for step, idx in enumerate(batch_indices, start=1):
        images = train_ds['image'][idx, ...]
        labels = train_ds['label'][idx, ...]
        grads, loss, accuracy = apply_model(state, images, labels)
        # Jitted functions cannot mutate their inputs: rebind to the new state.
        state = update_model(state, grads)
        losses.append(loss)
        accuracies.append(accuracy)
        progress.update(step, values={'epoch loss': loss, 'epoch accuracy': accuracy}.items())

    return state, jnp.array(losses).mean().item(), jnp.array(accuracies).mean().item()

In [None]:
TI = TaskInterface()


@TI.register_fl_task(model='state', data_loader='dataset', optimizer='optimizer', device='device')
def train(state, dataset, optimizer, device, loss_fn=None, warmup=False):
    """Train `state` for one local epoch on this collaborator's shard.

    Returns the mean training accuracy of the epoch as a Python float.
    """
    # NOTE(review): `init_rng` is the same global key on every call, so the
    # batch order is identical in every round; config.num_epochs is not used
    # (exactly one epoch runs per call).
    new_state, train_loss, train_accuracy = train_epoch(state, dataset, config.batch_size, init_rng)
    # Update `model` parameters registered in ModelInterface with the
    # `new_state` parameters (Flax states are immutable, so copy in place).
    state.update_state(new_state)
    return {'train_acc': train_accuracy,}


@TI.register_fl_task(model='state', data_loader='dataset', device='device')
def validate(state, dataset, device):
    """Evaluate `state` on the whole validation split in a single batch.

    Returns the accuracy as a Python float, consistent with `train`.
    """
    # apply_model also computes gradients and loss; both are discarded here.
    _, _, val_accuracy = apply_model(state, dataset['image'], dataset['label'])
    return {'validation_accuracy': float(val_accuracy),}

## Start federated learning experiment

In [None]:
# Create an experiment in the federation.
experiment_name = 'cifar10_experiment'
fl_experiment = FLExperiment(
    federation=federation,
    experiment_name=experiment_name,
)

In [None]:
# Zip the workspace and Python requirements so they can be transferred to the
# collaborator nodes, and start the experiment.
fl_experiment.start(
    model_provider=MI,
    task_keeper=TI,
    data_loader=fed_dataset,
    rounds_to_train=config.rounds_to_train,
    opt_treatment='CONTINUE_GLOBAL',
    device_assignment_policy='CUDA_PREFERRED',
)


In [None]:
fl_experiment.stream_metrics()  # stream the experiment's metrics into this notebook