# SplitFed Model Optimization using Game Theoretic Approaches

In this notebook we aim to optimize SplitFed ([arXiv:2004.12088](https://arxiv.org/abs/2004.12088)), a combination of Split Learning and Federated Learning ([arXiv:1810.06060](https://arxiv.org/abs/1810.06060), [arXiv:1812.00564](https://arxiv.org/abs/1812.00564)), using game-theoretic approaches. Specifically, we look at how to balance the number of model layers trained on each client device against computation overhead, communication overhead, and inference performance.

In [3]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # https://stackoverflow.com/a/64438413

In [4]:
from __future__ import annotations
import glob
import inspect
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns
import sys
import tensorflow as tf
import tensorflow.keras as keras

In [5]:
sns.set() # Use seaborn themes.

## Environment Setup

This section contains code that configures output path locations, the random seed, and logging.

In [6]:
# Set random seeds.
SEED = 0
tf.random.set_seed(SEED) # Only this works on ARC (since tensorflow==2.4).

In [7]:
# Setup logging (useful for ARC systems).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Must be lowest of all handlers listed below.
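# Clear handlers attached to this logger. Iterating over `logger.handlers` avoids
# an IndexError: `hasHandlers()` also returns True when only a parent logger has handlers.
for handler in list(logger.handlers):
    logger.removeHandler(handler)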

# Custom log formatting (not applied to the STDOUT handler below).
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

# Log to STDOUT (uses default formatting).
sh = logging.StreamHandler(stream=sys.stdout)
sh.setLevel(logging.INFO)
logger.addHandler(sh)

# Set TensorFlow logging level.
tf.get_logger().setLevel('ERROR') # 'INFO'

In [8]:
# List all GPUs visible to TensorFlow.
gpus = tf.config.list_physical_devices('GPU')
logger.info(f"Num GPUs Available: {len(gpus)}")
for gpu in gpus:
    logger.info(f"Name: {gpu.name}, Type: {gpu.device_type}")

Num GPUs Available: 0


## Split Model Architecture

Split Learning requires dividing a base model into client and server sub-models for training and evaluation. Several configurations for doing this are described in [arXiv:1812.00564](https://arxiv.org/abs/1812.00564). This implementation uses the simpler _vanilla_ configuration, which has a single forward/backward propagation pipeline: the client model has a single input and the server holds the data labels. In the forward pass, data propagates through the client model, and the client's outputs are passed to the server, where the loss is computed. In the backward pass, the server computes the gradients and backpropagates them through its own model. It then sends the gradients at the cut layer to the client, where backpropagation continues back to the client input layer.
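
As a rough illustration of the forward and backward passes described above, the following sketch runs a single vanilla split-learning training step using two `tf.GradientTape` contexts. The toy `client` and `server` models, the SGD optimizers, the random batch, and the `split_train_step` helper are all assumptions made for this example and are not part of the notebook's pipeline. In a real deployment the cut-layer activations and their gradients would be sent over the network rather than kept in one process.

```python
import tensorflow as tf
from tensorflow import keras

tf.random.set_seed(0)

# Hypothetical toy sub-models: the client output width (8) is the server input width.
client = keras.Sequential([keras.Input(shape=(10,)), keras.layers.Dense(8, activation="relu")])
server = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
opt_client = keras.optimizers.SGD(learning_rate=0.1)
opt_server = keras.optimizers.SGD(learning_rate=0.1)

# Random stand-in batch: 32 samples, 10 features, 4 classes.
x = tf.random.normal((32, 10))
y = tf.random.uniform((32,), maxval=4, dtype=tf.int32)


def split_train_step(x, y):
    # Forward pass on the client: produce the cut-layer activations.
    with tf.GradientTape() as client_tape:
        activations = client(x, training=True)

    # Forward pass on the server, which holds the labels and computes the loss.
    with tf.GradientTape() as server_tape:
        server_tape.watch(activations)
        logits = server(activations, training=True)
        loss = loss_fn(y, logits)

    # Backward pass on the server: gradients for its own weights, plus the
    # gradient with respect to the cut-layer activations (sent to the client).
    server_vars = server.trainable_variables
    grads = server_tape.gradient(loss, server_vars + [activations])
    server_grads, activation_grad = grads[:-1], grads[-1]
    opt_server.apply_gradients(zip(server_grads, server_vars))

    # Backward pass on the client: continue backpropagation from the cut layer.
    client_vars = client.trainable_variables
    client_grads = client_tape.gradient(
        activations, client_vars, output_gradients=activation_grad
    )
    opt_client.apply_gradients(zip(client_grads, client_vars))
    return loss


loss = split_train_step(x, y)
print(f"loss after one step: {float(loss):.4f}")
```

Only the cut-layer activations (forward) and their gradient (backward) cross the client/server boundary in this sketch; the labels and the loss stay on the server side.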

In [22]:
def split_model(
    base_model: keras.models.Model,
    cut_layer_key: int|str,
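    ) -> tuple[keras.models.Model, keras.models.Model]:
    """Split `base_model` at the cut layer, given by index or by name.

    Returns `(model_server, model_client)`; note that the server model comes first.
    The cut layer's output is both the client's output and the server's input.
    """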
    # Extract client-side input/output layers from base model.
    inp_client = base_model.input
    if isinstance(cut_layer_key, int):
        out_client = base_model.get_layer(index=cut_layer_key).output
    else:
        out_client = base_model.get_layer(name=cut_layer_key).output

    # Extract server-side output layer.
    out_server = base_model.output

    # Build client/server models.
    model_client = keras.models.Model(inputs=inp_client, outputs=out_client)
    model_server = keras.models.Model(inputs=out_client, outputs=out_server)
    return model_server, model_client



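# Example: split a three-layer dense model after `layer2`.
inp = keras.Input(shape=(10,)) # Trailing comma makes this a 1-tuple rather than the int 10.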
x = keras.layers.Dense(2, activation="relu", name="layer1")(inp)
x = keras.layers.Dense(3, activation="relu", name="layer2")(x)
x = keras.layers.Dense(4, name="layer3")(x)
model = keras.Model(inputs=inp, outputs=x)
s, c = split_model(model, 'layer2')
c.summary()
s.summary()

Model: "model_15"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_11 (InputLayer)       [(None, 10)]              0         
                                                                 
 layer1 (Dense)              (None, 2)                 22        
                                                                 
 layer2 (Dense)              (None, 3)                 9         
                                                                 
Total params: 31
Trainable params: 31
Non-trainable params: 0
_________________________________________________________________
Model: "model_16"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_12 (InputLayer)       [(None, 3)]               0         
                                                                 
 layer3 (Dense)              (None, 4)      