# SplitFed Model Optimization using Game Theoretic Approaches

In this notebook we aim to optimize SplitFed ([arXiv:2004.12088](https://arxiv.org/abs/2004.12088)), a combination of Split Learning and Federated Learning ([arXiv:1810.06060](https://arxiv.org/abs/1810.06060), [arXiv:1812.00564](https://arxiv.org/abs/1812.00564)), using game-theoretic approaches. Specifically, we look at balancing the number of model layers trained on each client device against computation overhead, communication overhead, and inference performance.

In [3]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # https://stackoverflow.com/a/64438413

In [1]:
from __future__ import annotations
import copy
import glob
import inspect
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns
import sys
import tensorflow as tf
import tensorflow.keras as keras
import tqdm
from typing import Any, Callable

In [5]:
sns.set() # Use seaborn themes.

## Environment Setup

This section contains code that configures output path locations, random seeds, and logging.

In [6]:
# Set random seeds.
SEED = 0
tf.random.set_seed(SEED) # Only this works on ARC (since tensorflow==2.4).

In [7]:
# Setup logging (useful for ARC systems).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Must be lowest of all handlers listed below.
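# Clear all existing handlers on this logger. `hasHandlers()` also checks ancestor
# loggers, so looping on it could index into this logger's empty handler list.
while logger.handlers: logger.removeHandler(logger.handlers[0])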

# Custom log formatting.
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

# Log to STDOUT (uses default formatting).
sh = logging.StreamHandler(stream=sys.stdout)
sh.setLevel(logging.INFO)
logger.addHandler(sh)

# Set TensorFlow logging level.
tf.get_logger().setLevel('ERROR') # 'INFO'
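The `formatter` created above is not attached to any handler, so STDOUT lines use the default format. On batch systems such as ARC it is often convenient to also keep a timestamped log file. The sketch below wraps this cell's setup in a reusable function and adds a standard-library `logging.FileHandler` that uses the same format. The function name `setup_logger` and any log path you pass to it are hypothetical, not part of the original notebook.

```python
import logging
import sys
from pathlib import Path
from typing import Optional

LOG_FORMAT = '%(asctime)s | %(levelname)s | %(message)s'

def setup_logger(name: str, log_file: Optional[Path] = None) -> logging.Logger:
    """Log INFO and above to STDOUT and, optionally, DEBUG and above to a file."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.DEBUG)  # Must be the lowest level of all handlers.

    # Remove (and close) only this logger's own handlers so re-running the
    # cell does not duplicate output. hasHandlers() is not used because it
    # also reports handlers attached to ancestor loggers.
    while logger.handlers:
        handler = logger.handlers[0]
        logger.removeHandler(handler)
        handler.close()

    # STDOUT handler with the default (message-only) format.
    sh = logging.StreamHandler(stream=sys.stdout)
    sh.setLevel(logging.INFO)
    logger.addHandler(sh)

    # Optional file handler with timestamped lines.
    if log_file is not None:
        fh = logging.FileHandler(log_file, mode='w')
        fh.setLevel(logging.DEBUG)
        fh.setFormatter(logging.Formatter(LOG_FORMAT))
        logger.addHandler(fh)
    return logger
```

For example, `logger = setup_logger(__name__, Path('splitfed.log'))` would print INFO messages to the notebook while writing every DEBUG-and-above message, with timestamps, to the (hypothetical) file `splitfed.log`.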

In [8]:
# List all GPUs visible to TensorFlow.
gpus = tf.config.list_physical_devices('GPU')
logger.info(f"Num GPUs Available: {len(gpus)}")
for gpu in gpus:
    logger.info(f"Name: {gpu.name}, Type: {gpu.device_type}")

Num GPUs Available: 0


## Split Model Architecture

To do Split Learning, a base model must be divided into client and server sub-models for training and evaluation. Several configurations for doing this are described in [arXiv:1812.00564](https://arxiv.org/abs/1812.00564). In this implementation, we focus on the simplest, _vanilla_ configuration, which uses a single forward/backward propagation pipeline: the client model has a single input, and the server holds the data labels. In the forward pass, data propagates through the client model, and its outputs (the cut-layer activations) are passed to the server, which completes the forward pass and computes the loss. In the backward pass, the server computes the gradients and backpropagates them through its model; the gradients at the cut layer are then sent to the client, where backpropagation continues down to the client input layer.
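The vanilla pipeline can be traced without any framework. In this minimal NumPy sketch (an illustration only, not the TensorFlow implementation used in this notebook), the client owns one ReLU layer up to the cut, the server owns one linear output layer and the labels, and the only values that cross the network are the cut-layer activations `h` (client to server) and their gradient `g_h` (server to client). The layer sizes, the squared-error loss, the plain gradient-descent update, and all function names are assumptions chosen for brevity.

```python
import numpy as np

def client_forward(W_c: np.ndarray, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Client: a single ReLU layer up to the cut.

    Returns the activations h (sent to the server) and the pre-activations z
    (kept locally for the client's backward pass).
    """
    z = x @ W_c
    return np.maximum(z, 0.0), z

def server_step(W_s: np.ndarray, h: np.ndarray, y: np.ndarray, lr: float):
    """Server: finish the forward pass, compute the loss, and update W_s.

    Returns the loss, the updated W_s, and dL/dh (sent back to the client).
    """
    B = y.shape[0]
    y_hat = h @ W_s
    loss = 0.5 * np.sum((y_hat - y) ** 2) / B   # squared error averaged over the batch
    g_out = (y_hat - y) / B                      # dL/dy_hat
    g_W_s = h.T @ g_out                          # dL/dW_s
    g_h = g_out @ W_s.T                          # dL/dh, using the pre-update W_s
    return loss, W_s - lr * g_W_s, g_h

def client_backward(W_c: np.ndarray, x: np.ndarray, z: np.ndarray,
                    g_h: np.ndarray, lr: float) -> np.ndarray:
    """Client: continue backpropagation from the cut-layer gradient and update W_c."""
    g_z = g_h * (z > 0)                          # gradient through the ReLU
    g_W_c = x.T @ g_z                            # dL/dW_c
    return W_c - lr * g_W_c

def split_step(W_c, W_s, x, y, lr=0.05):
    """One vanilla split-learning step; only h and g_h cross the network."""
    h, z = client_forward(W_c, x)                # client -> server: h
    loss, W_s, g_h = server_step(W_s, h, y, lr)  # server -> client: g_h
    W_c = client_backward(W_c, x, z, g_h, lr)
    return loss, W_c, W_s

# Toy regression problem (hypothetical sizes: 4 inputs, cut width 8, 1 output).
rng = np.random.default_rng(0)
x = rng.normal(size=(32, 4))
y = x @ rng.normal(size=(4, 1))
W_c = 0.5 * rng.normal(size=(4, 8))
W_s = 0.5 * rng.normal(size=(8, 1))
losses = []
for _ in range(200):
    loss, W_c, W_s = split_step(W_c, W_s, x, y)
    losses.append(loss)
```

Because `g_h` is computed before the server applies its update, both parties take a gradient step based on the same forward pass, which is also what the single-tape simulation in `split_train_step` below does.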

In [22]:
def split_model(
    base_model: keras.models.Model,
    cut_layer_key: int|str,
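    ) -> tuple[keras.models.Model, keras.models.Model]:
    """Split a functional Keras model into server and client sub-models.

    Args:
        base_model (keras.models.Model): Functional model to split.
        cut_layer_key (int|str): Index or name of the cut layer. The client
            model ends at this layer's output; the server model starts there.

    Returns:
        tuple[keras.models.Model, keras.models.Model]: Server model and
            client model, in that order. Both share layers (and therefore
            weights) with `base_model`.
    """
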
    # Extract client-side input/output layers from base model.
    inp_client = base_model.input
    if isinstance(cut_layer_key, int):
        out_client = base_model.get_layer(index=cut_layer_key).output
    else:
        out_client = base_model.get_layer(name=cut_layer_key).output

    # Extract server-side output layer.
    out_server = base_model.output

    # Build client/server models.
    model_client = keras.models.Model(inputs=inp_client, outputs=out_client)
    model_server = keras.models.Model(inputs=out_client, outputs=out_server)
    return model_server, model_client



# Example: split a small three-layer model after 'layer2'.
inp = keras.Input(shape=(10,))
x = keras.layers.Dense(2, activation="relu", name="layer1")(inp)
x = keras.layers.Dense(3, activation="relu", name="layer2")(x)
x = keras.layers.Dense(4, name="layer3")(x)
model = keras.Model(inputs=inp, outputs=x)
s, c = split_model(model, 'layer2')
c.summary()
s.summary()

Model: "model_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_11 (InputLayer)       [(None, 10)]              0         
                                                                 
 layer1 (Dense)              (None, 2)                 22        
                                                                 
 layer2 (Dense)              (None, 3)                 9         
                                                                 
Total params: 31
Trainable params: 31
Non-trainable params: 0
_________________________________________________________________
Model: "model_16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_12 (InputLayer)       [(None, 3)]               0         
                                                                 
 layer3 (Dense)              (None, 4)      
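Where the cut is placed determines both how many parameters each side trains (computation) and how much data crosses the network (communication). The sketch below tabulates both for a stack of `Dense` layers, where a layer with `n_in` inputs and `n_out` units has `n_in * n_out + n_out` parameters. It is a back-of-the-envelope estimate under stated assumptions: fully connected layers only, float32 values (4 bytes), and per round only the cut-layer activations and gradients for every training record plus one upload and one download of the client weights per client; label transfer, validation traffic, and protocol overhead are ignored. The helper names `dense_param_split` and `splitfed_round_bytes` and the client sizes are hypothetical.

```python
def dense_param_split(layer_sizes: list[int]) -> list[tuple[int, int, int]]:
    """For a Dense stack with widths [input, h1, ..., output], return
    (cut, client_params, server_params) for every cut point, where cut = k
    means the first k Dense layers run on the client."""
    per_layer = [n_in * n_out + n_out
                 for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
    total = sum(per_layer)
    splits = []
    for k in range(1, len(per_layer)):  # at least one layer on each side
        client = sum(per_layer[:k])
        splits.append((k, client, total - client))
    return splits

def splitfed_round_bytes(
    n_records: list[int],      # training records per client (per epoch)
    cut_width: int,            # values per sample at the cut layer
    client_params: int,        # parameters in the client-side model
    n_epochs: int = 1,
    bytes_per_value: int = 4,  # assumes float32
) -> int:
    """Rough bytes transferred in one SplitFed round (vanilla configuration)."""
    smashed = 2 * cut_width * sum(n_records) * n_epochs  # activations up + gradients down
    weights = 2 * client_params * len(n_records)         # weight upload + download
    return (smashed + weights) * bytes_per_value

# Toy model from above: 10 inputs -> layer1 (2) -> layer2 (3) -> layer3 (4),
# with two hypothetical clients of 100 records each and one local epoch.
sizes = [10, 2, 3, 4]
for cut, client_params, server_params in dense_param_split(sizes):
    cut_width = sizes[cut]  # activations per sample at the cut
    comm = splitfed_round_bytes([100, 100], cut_width, client_params)
    print(f"cut after layer{cut}: client={client_params}, server={server_params}, bytes/round={comm}")
# cut after layer1: client=22, server=25, bytes/round=3552
# cut after layer2: client=31, server=16, bytes/round=5296
```

The client parameter count for the `layer2` cut (31) matches the client model summary above. Even in this tiny model, moving one layer to the client increases both the client's parameter count and, because the cut is wider, the per-round traffic.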

## Federated Training Using Split Model
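After each communication round, the client-side weights (and, in this notebook, the per-client copies of the server-side weights) are combined with federated averaging: each client is weighted by its share of the training records, `n_k / n`, and the weighted layer arrays are summed layer by layer. As a framework-free illustration of what the `fed_avg` function in the next cell computes, here is a NumPy re-implementation, `fed_avg_np` (a hypothetical name), applied to two made-up clients.

```python
import numpy as np

def fed_avg_np(model_weights: dict, dist: dict) -> list:
    """Weighted average of per-layer weight arrays across clients."""
    keys = list(model_weights)
    n_layers = len(model_weights[keys[0]])
    return [
        sum(dist[k] * model_weights[k][i] for k in keys)  # sum_k (n_k / n) * w_k
        for i in range(n_layers)
    ]

# Two hypothetical clients, each holding a one-layer model (kernel, bias).
model_weights = {
    "a": [np.array([[1.0, 2.0]]), np.array([0.0, 0.0])],
    "b": [np.array([[3.0, 6.0]]), np.array([4.0, 8.0])],
}
records = {"a": 30, "b": 10}  # training records seen by each client
total = float(sum(records.values()))
dist = {k: n / total for k, n in records.items()}  # a: 0.75, b: 0.25
avg = fed_avg_np(model_weights, dist)
print(avg[0], avg[1])  # kernel [[1.5, 3.0]], bias [1.0, 2.0]
```

Client `a` saw three times as many records as client `b`, so the averaged weights sit three quarters of the way toward client `a`'s values.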

In [3]:
@tf.function
def split_train_step(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    x: tf.Tensor,
    y: tf.Tensor,
    ) -> dict[str, tf.Tensor]:
    """Compiled Split Learning training step.

    Args:
        model_server (keras.models.Model): Server model (compiled with optimizer and loss).
        model_client (keras.models.Model): Client model (compiled with optimizer and loss).
        x (tf.Tensor): Batched training input.
        y (tf.Tensor): Batched training targets.

    Returns:
        dict[str, tf.Tensor]: Dictionary of server model metrics after the current training step.
    """

    # For this simulation we use a single GradientTape instance to make
    # the codebase simpler. A true distributed environment would require
    # a separate GradientTape instance for the server/client.
    # The tape is persistent because two gradients are taken from it.
    with tf.GradientTape(persistent=True) as tape:

        ###### Client forward pass ######
        out_client = model_client(x, training=True)

        ###### Server forward pass ######
        out_server = model_server(out_client, training=True)

        # Compute the loss inside the tape so that it is differentiable.
        loss = model_server.compiled_loss(
            y_true=y,
            y_pred=out_server,
            regularization_losses=model_server.losses,
        )

    ###### Server backward pass ######
    # Compute server gradients outside the tape context; calling gradient()
    # on a persistent tape inside its context records the gradient ops too.
    grad_server = tape.gradient(loss, model_server.trainable_variables)
    # Update server weights.
    model_server.optimizer.apply_gradients(zip(grad_server, model_server.trainable_variables))
    # Update server metrics.
    model_server.compiled_metrics.update_state(
        y_true=y,
        y_pred=out_server,
    )

    ###### Client backward pass ######
    # Gradients flow back through the cut layer to the client weights.
    grad_client = tape.gradient(loss, model_client.trainable_variables)
    # Update local client weights.
    model_client.optimizer.apply_gradients(zip(grad_client, model_client.trainable_variables))
    # No need to update client metrics since labels are on the server.
    del tape  # Release the resources held by the persistent tape.

    # Return dictionary of server metrics (including loss).
    return {m.name: m.result() for m in model_server.metrics}


@tf.function
def split_test_step(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    x: tf.Tensor,
    y: tf.Tensor,
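    ) -> dict[str, tf.Tensor]:
    """Compiled Split Learning evaluation step.

    Args:
        model_server (keras.models.Model): Server model (compiled with loss and metrics).
        model_client (keras.models.Model): Client model.
        x (tf.Tensor): Batched evaluation input.
        y (tf.Tensor): Batched evaluation targets.

    Returns:
        dict[str, tf.Tensor]: Dictionary of server model metrics after the current evaluation step.
    """

    ###### Client forward pass ######
    out_client = model_client(x, training=False)

    ###### Server forward pass ######
    out_server = model_server(out_client, training=False)
    # Compute the loss so that the server's loss metric is updated as well.
    model_server.compiled_loss(
        y_true=y,
        y_pred=out_server,
        regularization_losses=model_server.losses,
    )
    # Update server metrics.
    model_server.compiled_metrics.update_state(
        y_true=y,
        y_pred=out_server,
    )

    # Return dictionary of server metrics (including loss).
    return {m.name: m.result() for m in model_server.metrics}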


def fed_avg(
    model_weights: dict[str, list[tf.Tensor]],
    dist: dict[str, float],
    ) -> list[tf.Tensor]:
    """Weighted average of model layer parameters.

    Args:
        model_weights (dict[str, list[tf.Tensor]]): Dictionary of model weight lists.
        dist (dict[str, float]): Distribution for weighted averaging.

    Returns:
        list[tf.Tensor]: List of averaged weight tensors for each layer of the model.
    """

    # Scale the weights using the given distribution.
    model_weights_scaled = [
        [dist[key] * layer for layer in weights] 
        for key, weights in model_weights.items()
    ]

    # Average the weights.
    avg_weights = []
    for weight_tup in zip(*model_weights_scaled):
        avg_weights.append(
            tf.math.reduce_sum(weight_tup, axis=0)
        )
    return avg_weights


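###
# Vanilla Split Learning configuration only.
###
def train_splitfed(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    model_builder_server: Callable[[keras.models.Model], keras.models.Model],
    model_builder_client: Callable[[keras.models.Model], keras.models.Model],
    client_data: dict[int|str, tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]],
    n_rounds: int,
    n_epochs: int,
    ):
    """SplitFed training loop (vanilla Split Learning configuration only).

    Args:
        model_server (keras.models.Model): Base server model.
        model_client (keras.models.Model): Base client model.
        model_builder_server (Callable): Builds a compiled server model from the given base model.
        model_builder_client (Callable): Builds a compiled client model from the given base model.
        client_data (dict): Dictionary of client data, where values are (train, val, test)
            subsets (assumed already batched). The length of the dictionary determines
            the number of clients.
        n_rounds (int): Number of global communication rounds.
        n_epochs (int): Number of local client training epochs.
    """
    # Determine number of clients.
    n_clients: int = len(client_data)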

    ########## Main Server ###############
    # Build initial server model.
    model_server = model_builder_server(model_server)

    # Copy of global server weight parameters.
    global_weights_server = copy.deepcopy(model_server.get_weights())
    ######################################

    ########## Federated Server ##########
    # Build initial client model.
    model_client = model_builder_client(model_client)

    # Copy of global client weight parameters.
    global_weights_client = copy.deepcopy(model_client.get_weights())
    #######################################

    # Global training loop.
    # Communication rounds between server <--> clients.
    for round in range(n_rounds):
        # Preserve server weights for each client update.
        all_server_weights: dict[int|str, list[np.ndarray]] = {}

        # Train each client model.
        # This could be done in parallel, but here we do it
        # sequentially for ease of development.
        all_client_weights: dict[int|str, list[np.ndarray]] = {}
        all_client_data_records: dict[int|str, int] = {}
        for client, (train_dataset, val_dataset, test_dataset) in client_data.items():

            # Reset the server model so each client starts from the same global server weights.
            model_server.set_weights(global_weights_server)

            # Synchronize global client model to local client.
            model_client_local = model_builder_client(model_client)
            model_client_local.set_weights(global_weights_client)

            # Train the current model for the desired number of epochs.
            all_client_data_records[client] = 0 # Initialize record count.
            for epoch in range(n_epochs):
                for step, (x_train_batch, y_train_batch) in enumerate(train_dataset):

                    # Run a single training step.
                    metrics_train = split_train_step(
                        model_server=model_server,
                        model_client=model_client_local,
                        x=x_train_batch,
                        y=y_train_batch,
                    )

                    # Add the number of records in the current batch to this client's total.
                    all_client_data_records[client] += x_train_batch.shape[0]

                # Reset server metrics so validation results are not mixed with training results.
                model_server.reset_metrics()

                # Validation loop.
                for x_val_batch, y_val_batch in val_dataset:
                    metrics_val = split_test_step(
                        model_server=model_server,
                        model_client=model_client_local,
                        x=x_val_batch,
                        y=y_val_batch,
                    )

                # Reset metrics before the next epoch.
                model_client_local.reset_metrics()
                model_server.reset_metrics()

            # Create a copy of this client's model weights and preserve for future aggregation.
            all_client_weights[client] = copy.deepcopy(model_client_local.get_weights())

            # Create a copy of the current server weights.
            all_server_weights[client] = copy.deepcopy(model_server.get_weights())

        # Count total number of data records across all clients.
        total_data_records = float(sum(all_client_data_records.values()))

        # Now perform federated averaging of the weights of all clients.
        # First compute the distribution for the weighted average (each
        # client's share of the training records), then aggregate the weights.
        dist = {
            client: float(count)/total_data_records
            for client, count in all_client_data_records.items()
        }
        global_weights_client = fed_avg(model_weights=all_client_weights, dist=dist)

        # Also average server weights for each client update.
        global_weights_server = fed_avg(model_weights=all_server_weights, dist=dist)