# SplitFed Model Optimization using Game Theoretic Approaches

In this notebook we aim to optimize SplitFed ([arXiv:2004.12088](https://arxiv.org/abs/2004.12088)), a combination of Split Learning and Federated Learning ([arXiv:1810.06060](https://arxiv.org/abs/1810.06060), [arXiv:1812.00564](https://arxiv.org/abs/1812.00564)), using game-theoretic approaches. Specifically, we study how to balance the number of model layers trained on each client device against computation overhead, communication overhead, and inference performance.

In [2]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # https://stackoverflow.com/a/64438413

In [3]:
from __future__ import annotations
import copy
import glob
import inspect
import json
import logging
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pathlib import Path
import seaborn as sns
import sys
import tensorflow as tf
import tensorflow.keras as keras
import tqdm
from typing import Any, Callable

In [4]:
sns.set() # Use seaborn themes.

## Environment Setup

This section contains code that sets the random seed, configures logging, and lists the GPUs visible to TensorFlow.

In [5]:
# Set random seeds.
SEED = 0
tf.random.set_seed(SEED) # Only this works on ARC (since tensorflow==2.4).

In [6]:
# Setup logging (useful for ARC systems).
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG) # Must be lowest of all handlers listed below.
while logger.handlers: logger.removeHandler(logger.handlers[0]) # Clear all existing handlers on this logger (`hasHandlers()` also checks ancestors, which could raise an IndexError here).

# Custom log formatting.
formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s')

# Log to STDOUT (uses default formatting).
sh = logging.StreamHandler(stream=sys.stdout)
sh.setLevel(logging.INFO)
logger.addHandler(sh)

# Set Tensorflow logging level.
tf.get_logger().setLevel('ERROR') # 'INFO'

In [7]:
# List all GPUs visible to TensorFlow.
gpus = tf.config.list_physical_devices('GPU')
logger.info(f"Num GPUs Available: {len(gpus)}")
for gpu in gpus:
    logger.info(f"Name: {gpu.name}, Type: {gpu.device_type}")

Num GPUs Available: 0


## Split Model Architecture

To do Split Learning, a base model must be divided into client/server sub-models for training and evaluation. There are several configuration approaches for doing this, as described in [arXiv:1812.00564](https://arxiv.org/abs/1812.00564). In this implementation, we focus on the simpler _vanilla_ configuration, which uses a single forward/backward propagation pipeline: the client model has a single input and the server holds the data labels. In the forward pass, data propagates through the client model, and its outputs (the cut-layer activations) are passed to the server, where the loss is computed. In the backward pass, the server computes the gradients and backpropagates them through its model, then sends the gradients with respect to the cut-layer activations to the client, where backpropagation continues down to the client input layer.
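The vanilla pipeline above can be sketched in plain NumPy, assuming a toy two-layer linear model with a mean-squared-error loss (not the MNIST classifier used later). The helper names `client_forward`, `server_step`, and `client_backward` are hypothetical. The only messages exchanged are the cut-layer activations (client to server) and the gradient of the loss with respect to those activations (server to client); the last lines check that the client gradient matches the one computed on the unsplit model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data: 8 samples, 5 features, scalar regression targets.
x = rng.normal(size=(8, 5))
y = rng.normal(size=(8, 1))

# Client holds the first layer, server holds the second (both linear, no bias).
W_client = rng.normal(size=(5, 3))
W_server = rng.normal(size=(3, 1))


def client_forward(x, W_c):
    """Client forward pass: returns the cut-layer activations."""
    return x @ W_c


def server_step(a, y, W_s):
    """Server forward and backward pass with a mean-squared-error loss.

    Returns the loss, the server weight gradient, and dL/da, the gradient of the
    loss with respect to the cut-layer activations (the only thing sent back).
    """
    y_hat = a @ W_s
    err = y_hat - y
    loss = np.mean(err ** 2)
    d_yhat = 2.0 * err / err.shape[0]   # dL/dy_hat
    grad_W_s = a.T @ d_yhat             # dL/dW_s
    grad_a = d_yhat @ W_s.T             # dL/da, sent to the client
    return loss, grad_W_s, grad_a


def client_backward(x, grad_a):
    """Client backward pass: chain rule through its own layer only."""
    return x.T @ grad_a                 # dL/dW_c


# One vanilla split-learning step.
a = client_forward(x, W_client)                        # client -> server: activations
loss, grad_W_server, grad_a = server_step(a, y, W_server)
grad_W_client = client_backward(x, grad_a)             # server -> client: dL/da

# Reference: the same client gradient computed on the unsplit model x @ W_c @ W_s.
err_full = x @ W_client @ W_server - y
grad_W_client_full = x.T @ (2.0 * err_full / len(x)) @ W_server.T
print(np.allclose(grad_W_client, grad_W_client_full))  # True
```

In a real deployment the two halves run on different machines, so `a` and `grad_a` are the per-batch communication cost; their size depends only on the cut-layer width, not on how many layers each side holds.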

In [8]:
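def split_model(
    base_model: keras.models.Model,
    cut_layer_key: int|str,
    ) -> tuple[keras.models.Model, keras.models.Model]:
    """Splits a functional Keras model into server and client sub-models.

    Args:
        base_model (keras.models.Model): Functional model to split.
        cut_layer_key (int|str): Index (into `base_model.layers`, where the input layer is index 0)
            or name of the last layer kept on the client.

    Returns:
        tuple[keras.models.Model, keras.models.Model]: The `(server, client)` sub-models.
    """

    # Extract client-side input/output layers from base model.
    inp_client = base_model.input
    if isinstance(cut_layer_key, int):
        out_client = base_model.get_layer(index=cut_layer_key).output
    else:
        out_client = base_model.get_layer(name=cut_layer_key).output

    # Extract server-side output layer.
    out_server = base_model.output

    # Build client/server models.
    model_client = keras.models.Model(inputs=inp_client, outputs=out_client)
    model_server = keras.models.Model(inputs=out_client, outputs=out_server)
    return model_server, model_client


# Example: split a small three-layer model after `layer2`.
inp = keras.Input(shape=(10,))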
x = keras.layers.Dense(2, activation="relu", name="layer1")(inp)
x = keras.layers.Dense(3, activation="relu", name="layer2")(x)
x = keras.layers.Dense(4, name="layer3")(x)
model = keras.Model(inputs=inp, outputs=x)
s, c = split_model(model, 'layer2')
c.summary()
s.summary()

Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 10)]              0         
                                                                 
 layer1 (Dense)              (None, 2)                 22        
                                                                 
 layer2 (Dense)              (None, 3)                 9         
                                                                 
Total params: 31
Trainable params: 31
Non-trainable params: 0
_________________________________________________________________
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 3)]               0         
                                                                 
 layer3 (Dense)              (None, 4)        
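One way to relate the cut point to the fraction `d` used in the game-theoretic model below is to measure the share of trainable parameters that stay on the client. This is an assumption: the notebook does not fix how `d` is measured, and per-layer compute (FLOPs) would be an equally reasonable choice. For the toy model above, cutting after `layer2` leaves 22 + 9 = 31 of the 22 + 9 + 16 = 47 parameters on the client. The helpers below are hypothetical and only handle plain Dense stacks.

```python
def dense_param_counts(widths):
    """Parameter count (weights + biases) of each Dense layer in an MLP.

    `widths` lists the input width followed by each Dense layer's output width.
    """
    return [n_in * n_out + n_out for n_in, n_out in zip(widths[:-1], widths[1:])]


def client_fraction(widths, n_client_layers):
    """Fraction of parameters on the client when the first `n_client_layers`
    Dense layers are kept client-side (a hypothetical proxy for `d`)."""
    counts = dense_param_counts(widths)
    return sum(counts[:n_client_layers]) / sum(counts)


# Toy model above: Input(10) -> layer1(2) -> layer2(3) -> layer3(4).
widths = [10, 2, 3, 4]
print(dense_param_counts(widths))                   # [22, 9, 16]
print(client_fraction(widths, n_client_layers=2))   # 31/47, about 0.66
```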

## Federated Training Using Split Model

In [24]:
def split_train_step(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    x: tf.Tensor,
    y: tf.Tensor,
    ) -> dict[str, tf.Tensor]:
    """Split learning training step.

    Runs a single training step for the given server and client models.

    Note that the current implementation uses a single `tf.GradientTape` instance to
    reduce code complexity, so it is for simulation purposes only. True distributed
    learning would require a separate `tf.GradientTape` instance for each model: the
    server sends the gradient of the loss with respect to the cut-layer activations to
    the client, which backpropagates it through its own model as a vector-Jacobian
    product (e.g. via the `output_gradients` argument of `tf.GradientTape.gradient`).

    Args:
        model_server (keras.models.Model): Server model (compiled with optimizer and loss).
        model_client (keras.models.Model): Client model (compiled with optimizer and loss).
        x (tf.Tensor): Batched training input.
        y (tf.Tensor): Batched training targets.

    Returns:
        dict[str, tf.Tensor]: Dictionary of server model metrics after the current training step.
    """

    # For this simulation we use a single persistent GradientTape to keep the
    # codebase simple. A true distributed environment would require a separate
    # GradientTape on the server and on each client.
    with tf.GradientTape(persistent=True) as tape:

        ###### Client forward pass ######
        out_client = model_client(x, training=True)

        ###### Server forward pass ######
        out_server = model_server(out_client, training=True)

        # Compute the loss (this also updates the server's loss metric).
        loss = model_server.compiled_loss(
            y_true=y,
            y_pred=out_server,
            regularization_losses=model_server.losses,
        )

    ###### Server backward pass ######
    # Compute gradients outside the tape context so these ops are not recorded.
    grad_server = tape.gradient(loss, model_server.trainable_variables)
    # Update server weights.
    model_server.optimizer.apply_gradients(zip(grad_server, model_server.trainable_variables))
    # Update server metrics.
    model_server.compiled_metrics.update_state(
        y_true=y,
        y_pred=out_server,
    )

    ###### Client backward pass ######
    # The tape carries the gradient through the cut layer into the client model.
    grad_client = tape.gradient(loss, model_client.trainable_variables)
    # Update local client weights.
    model_client.optimizer.apply_gradients(zip(grad_client, model_client.trainable_variables))
    # No need to update client metrics since labels are on the server.

    # Release the resources held by the persistent tape.
    del tape

    # Return dictionary of server metrics (including loss).
    return {m.name: m.result() for m in model_server.metrics}


def split_test_step(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    x: tf.Tensor,
    y: tf.Tensor,
    ) -> dict[str, tf.Tensor]:
    """Split learning validation/test step.

    Runs a single validation/test step for the given server and client models.

    Args:
        model_server (keras.models.Model): Server model (compiled with optimizer and loss).
        model_client (keras.models.Model): Client model (compiled with optimizer and loss).
        x (tf.Tensor): Batched validation/test input.
        y (tf.Tensor): Batched validation/test targets.

    Returns:
        dict[str, tf.Tensor]: Dictionary of server model metrics after the current validation/test step.
    """

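    ###### Client forward pass ######
    out_client = model_client(x, training=False)

    ###### Server forward pass ######
    out_server = model_server(out_client, training=False)
    # Compute the loss so the server's loss metric reflects this batch
    # (otherwise `val_loss` would just report the stale training loss).
    model_server.compiled_loss(
        y_true=y,
        y_pred=out_server,
        regularization_losses=model_server.losses,
    )
    # Update server metrics.
    model_server.compiled_metrics.update_state(
        y_true=y,
        y_pred=out_server,
    )

    # Return dictionary of server metrics (including loss), prefixed with `val_`.
    return {f"val_{m.name}": m.result() for m in model_server.metrics}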


def fed_avg(
    model_weights: dict[str, list[tf.Tensor]],
    dist: dict[str, float],
    ) -> list[tf.Tensor]:
    """Weighted average of model layer parameters.

    Args:
        model_weights (dict[str, list[tf.Tensor]]): Dictionary of model weight lists.
        dist (dict[str, float]): Distribution for weighted averaging.

    Returns:
        list[tf.Tensor]: List of averaged weight tensors for each layer of the model.
    """

    # Scale the weights using the given distribution.
    model_weights_scaled = [
        [dist[key] * layer for layer in weights] 
        for key, weights in model_weights.items()
    ]

    # Average the weights.
    avg_weights = []
    for weight_tup in zip(*model_weights_scaled):
        avg_weights.append(
            tf.math.reduce_sum(weight_tup, axis=0)
        )
    return avg_weights


###
# Vanilla SplitLearning configuration only.
###
def train_splitfed(
    model_server: keras.models.Model,
    model_client: keras.models.Model,
    model_builder_server: Callable[[keras.models.Model], keras.models.Model],
    model_builder_client: Callable[[keras.models.Model], keras.models.Model],
    client_data: dict[int|str, tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]], # Dictionary of client data, where values are tuple(train, val, test) subsets (assumes already batched). The length of the dictionary determines the number of clients.
    n_rounds: int, # Number of global communication rounds.
    n_epochs: int, # Number of local client training epochs.
    ) -> tuple[keras.models.Model, keras.models.Model]:
    # Determine number of clients.
    n_clients: int = len(client_data)

    ########## Main Server ###############
    # Build initial server model.
    model_server = model_builder_server(model_server)

    # Copy of global server weight parameters.
    global_weights_server = copy.deepcopy(model_server.get_weights())
    ######################################

    ########## Federated Server ##########
    # Build initial client model.
    model_client = model_builder_client(model_client)

    # Copy of global client weight parameters.
    global_weights_client = copy.deepcopy(model_client.get_weights())
    #######################################

    # Global training loop.
    # Communication rounds between server <--> clients.
    for round in range(n_rounds):
        # Preserve server weights for each client update.
        all_server_weights: dict[int|str, list[np.ndarray]] = {}

        # Train each client model.
        # This could be done in parallel, but here we do it
        # sequentially for ease of development.
        all_client_weights: dict[int|str, list[np.ndarray]] = {}
        all_client_data_records_train: dict[int|str, int] = {}
        for client, (train_dataset, val_dataset, test_dataset) in client_data.items():

            # Reset server model so that weights are fresh during synchronous updates.
            model_server.set_weights(global_weights_server)

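            # Synchronize the global client weights into a separate local client model.
            # `model_builder_client` compiles its argument in place, so clone first to
            # avoid reusing (and recompiling) the global client model object.
            model_client_local = model_builder_client(keras.models.clone_model(model_client))
            model_client_local.set_weights(global_weights_client)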

            # Train the current model for the desired number of epochs.
            all_client_data_records_train[client] = 0 # Initialize record count.
            for epoch in range(n_epochs):

                # Training loop.
                with tqdm.tqdm(train_dataset, unit='batch') as pbar:
                    for step, (x_train_batch, y_train_batch) in enumerate(pbar):
                        pbar.set_description(f"[round {round+1}/{n_rounds}, client {client}, epoch {epoch+1}/{n_epochs}] train")

                        # Run a single training step.
                        metrics_train = split_train_step(
                            model_server=model_server,
                            model_client=model_client_local,
                            x=x_train_batch,
                            y=y_train_batch,
                        )

                        # Add the number of samples in this batch to the client's record count.
                        all_client_data_records_train[client] += x_train_batch.shape[0]

                        # Update progress bar with metrics.
                        pbar.set_postfix({k:v.numpy() for k,v in metrics_train.items()})

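                # Reset server metrics so validation metrics exclude training batches.
                model_server.reset_metrics()

                # Validation loop.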
                with tqdm.tqdm(val_dataset, unit='batch') as pbar:
                    for x_val_batch, y_val_batch in pbar:
                        pbar.set_description(f"[round {round+1}/{n_rounds}, client {client}, epoch {epoch+1}/{n_epochs}] val")

                        # Run a single validation step.
                        metrics_val = split_test_step(
                            model_server=model_server,
                            model_client=model_client_local,
                            x=x_val_batch,
                            y=y_val_batch,
                        )

                        # Update progress bar with metrics.
                        pbar.set_postfix({k:v.numpy() for k,v in metrics_val.items()})

                # Reset train/val metrics.
                model_client.reset_metrics()
                model_server.reset_metrics()

            # Create a copy of this client's model weights and preserve for future aggregation.
            all_client_weights[client] = copy.deepcopy(model_client_local.get_weights())

            # Create a copy of the current server weights.
            all_server_weights[client] = copy.deepcopy(model_server.get_weights())

        # Count total number of data records across all clients.
        total_data_records = float(sum(all_client_data_records_train.values()))

        # Now perform federated averaging for weights of all clients.
        # To do this, first compute the distribution for the weighted average.
        # Then perform federated averaging weight aggregation.
        dist = {
            client: float(count)/total_data_records
            for client, count in all_client_data_records_train.items()
        }
        global_weights_client = fed_avg(model_weights=all_client_weights, dist=dist)

        # Also average server weights for each client update.
        global_weights_server = fed_avg(model_weights=all_server_weights, dist=dist)

    # Load the final global weights for the server and client.
    model_server.set_weights(global_weights_server)
    model_client.set_weights(global_weights_client)

    # Return server and client models.
    return model_server, model_client
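For reference, the aggregation that `fed_avg` performs with the sample-count distribution built in `train_splitfed` can be written in plain NumPy. This is a minimal sketch with a hypothetical `fed_avg_np` helper; unlike the TensorFlow version it returns NumPy arrays, which `Model.set_weights` also accepts.

```python
import numpy as np


def fed_avg_np(weights_per_client, n_samples_per_client):
    """Sample-weighted average of per-client weight lists (FedAvg aggregation).

    weights_per_client: dict mapping client id -> list of np.ndarray (one per weight tensor).
    n_samples_per_client: dict mapping client id -> number of training samples seen.
    """
    total = float(sum(n_samples_per_client.values()))
    clients = list(weights_per_client)
    n_tensors = len(weights_per_client[clients[0]])
    averaged = []
    for i in range(n_tensors):
        # Weighted sum of tensor i across clients; the weights sum to 1.
        averaged.append(sum(
            (n_samples_per_client[c] / total) * weights_per_client[c][i]
            for c in clients
        ))
    return averaged


# Two clients with a one-layer "model" (a weight matrix and a bias vector).
w = {
    0: [np.ones((2, 2)), np.zeros(2)],
    1: [3 * np.ones((2, 2)), np.ones(2)],
}
n = {0: 100, 1: 300}
avg = fed_avg_np(w, n)
print(avg[0])  # every entry is 0.25*1 + 0.75*3 = 2.5
print(avg[1])  # every entry is 0.25*0 + 0.75*1 = 0.75
```

Because the same distribution is also applied to the per-client server weights, clients that contributed more samples pull both halves of the split model further toward their local updates.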

In [26]:
# Load MNIST and flatten each 28x28 image into a 784-dimensional vector.
batch_size = 64
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()
x_train = np.reshape(x_train, (-1, 784))
x_test = np.reshape(x_test, (-1, 784))

# Reserve the last 5,000 training samples for testing and the 5,000 before them for validation
# (this replaces the original MNIST test split).
x_test = x_train[-5000:]
y_test = y_train[-5000:]
x_val = x_train[-10000:-5000]
y_val = y_train[-10000:-5000]
x_train = x_train[:-10000]
y_train = y_train[:-10000]

# Prepare the training dataset.
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# Prepare the validation dataset.
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_dataset = val_dataset.batch(batch_size)

# Prepare the testing dataset.
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_dataset = test_dataset.batch(batch_size)

def compile_model(model: keras.models.Model):
    model.compile(
        optimizer='adam',
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['acc'],
        )
    return model

def build_model():
    inputs = keras.Input(shape=(784,), name="digits")
    x1 = keras.layers.Dense(64, activation="relu", name='dense0')(inputs)
    x2 = keras.layers.Dense(64, activation="relu", name='dense1')(x1)
    outputs = keras.layers.Dense(10, name="predictions")(x2)
    model = keras.Model(inputs=inputs, outputs=outputs)
    server, client = split_model(model, 'dense0')
    return server, client

print(f"{tf.data.experimental.cardinality(train_dataset).numpy()=}")
print(f"{tf.data.experimental.cardinality(val_dataset).numpy()=}")
print(f"{tf.data.experimental.cardinality(test_dataset).numpy()=}")


# THIS IS NAIVE, NEEDS TO BE CORRECTED: every client currently receives the
# full training set instead of its own partition.
# Build client datasets.
n_clients = 2
client_data = {
    c: (train_dataset, val_dataset, test_dataset)
    for c in range(n_clients)
}


server, client = build_model()

server_trained, client_trained = train_splitfed(
    model_server=server,
    model_client=client,
    model_builder_server=compile_model,
    model_builder_client=compile_model,
    client_data=client_data, # Dictionary of client data, where values are tuple(train, val, test) subsets (assumes already batched). The length of the dictionary determines the number of clients.
    n_rounds=1, # Number of global communication rounds.
    n_epochs=1, # Number of local client training epochs.
)

client.summary()
server.summary()

tf.data.experimental.cardinality(train_dataset).numpy()=782
tf.data.experimental.cardinality(val_dataset).numpy()=79
tf.data.experimental.cardinality(test_dataset).numpy()=79


[round 1/1, client 0, epoch 1/1] train: 100%|██████████| 782/782 [00:11<00:00, 67.03batch/s, loss=2.27, acc=0.823]
[round 1/1, client 0, epoch 1/1] val: 100%|██████████| 79/79 [00:00<00:00, 204.58batch/s, val_loss=2.27, val_acc=0.827]
[round 1/1, client 1, epoch 1/1] train: 100%|██████████| 782/782 [00:11<00:00, 67.87batch/s, loss=2, acc=0.822]   
[round 1/1, client 1, epoch 1/1] val: 100%|██████████| 79/79 [00:00<00:00, 196.56batch/s, val_loss=2, val_acc=0.827]


Model: "model_31"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 digits (InputLayer)         [(None, 784)]             0         
                                                                 
 dense0 (Dense)              (None, 64)                50240     
                                                                 
Total params: 50,240
Trainable params: 50,240
Non-trainable params: 0
_________________________________________________________________
Model: "model_32"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_12 (InputLayer)       [(None, 64)]              0         
                                                                 
 dense1 (Dense)              (None, 64)                4160      
                                                                 
 predictions (Dense)         (None, 
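The client datasets above are flagged as naive because every client receives the full training set rather than its own partition. One possible correction, sketched here as an assumption rather than the notebook's eventual method, is to give each client a disjoint IID shard of the training indices; `shard_indices` is a hypothetical helper, and non-IID partitions (e.g. by label) would be an equally valid experimental choice.

```python
import numpy as np


def shard_indices(n_samples, n_clients, seed=0):
    """Split sample indices into disjoint, roughly equal IID shards (one per client)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    return np.array_split(perm, n_clients)


# Hypothetical use with the arrays defined above (requires TensorFlow):
#   shards = shard_indices(len(x_train), n_clients)
#   client_data = {
#       c: (tf.data.Dataset.from_tensor_slices((x_train[idx], y_train[idx]))
#               .shuffle(buffer_size=1024).batch(batch_size),
#           val_dataset, test_dataset)
#       for c, idx in enumerate(shards)
#   }

# The notebook keeps 50,000 training samples after the validation/test split.
shards = shard_indices(50_000, 2)
print([len(s) for s in shards])  # [25000, 25000]
```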

## Game Theory Definitions

In [1]:
import numpy as np

In [18]:
def compute_time_forward_prop(
    d: float, # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: float, # Data size for client `k`.
    beta: float, # Amount of computational complexity of forward propagation.
    F_dk: float, # The computing resource of device `k`.
    ) -> float:
    """Computes `T_{d,k}^{F}`, the computing time for forward propagation of device `k`.

    Mathematically, this is:
        `T_{d,k}^{F} = (d * beta * W * D) / F_{d,k}`

    Args:
        d (float): Portion of network that client device `k` will train.
        W (float): Total size of neural network to be trained.
        D (float): Data size for client `k`.
        beta (float): Amount of computational complexity of forward propagation.
        F_dk (float): The computing resource of device `k`, called `F_{d,k}`.

    Returns:
        float: `T_{d,k}^{F}`
    """
    return (d*beta*W*D)/F_dk

def compute_time_backward_prop(
    d: float, # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: float, # Data size for client `k`.
    beta: float, # Amount of computational complexity of forward propagation.
    F_dk: float, # The computing resource of device `k`.
    ) -> float:
    """Computes `T_{d,k}^{B}`, the computing time for backward propagation of device `k`.

    Mathematically, this is:
        `T_{d,k}^{B} = (d * (1-beta) * W * D) / F_{d,k}`

    Args:
        d (float): Portion of network that client device `k` will train.
        W (float): Total size of neural network to be trained.
        D (float): Data size for client `k`.
        beta (float): Amount of computational complexity of forward propagation.
        F_dk (float): The computing resource of device `k`, called `F_{d,k}`.

    Returns:
        float: `T_{d,k}^{B}`
    """
    return (d*(1-beta)*W*D)/F_dk

def compute_time_client(
    d: float, # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: float, # Data size for client `k`.
    beta: float, # Amount of computational complexity of forward propagation.
    F_dk: float, # The computing resource of device `k`.
    ) -> float:
    """Computes `T_{d,k}`, the total computing time of device `k`.

    Mathematically, this is:
        `T_{d,k} = T_{d,k}^{F} + T_{d,k}^{B}`

    Args:
        d (float): Portion of network that client device `k` will train.
        W (float): Total size of neural network to be trained.
        D (float): Data size for client `k`.
        beta (float): Amount of computational complexity of forward propagation.
        F_dk (float): The computing resource of device `k`, called `F_{d,k}`.

    Returns:
        float: `T_{d,k}`
    """
    T_F = compute_time_forward_prop(
        d=d,
        W=W,
        D=D,
        beta=beta,
        F_dk=F_dk,
    )
    T_B = compute_time_backward_prop(
        d=d,
        W=W,
        D=D,
        beta=beta,
        F_dk=F_dk,
    )
    return T_F + T_B

def compute_time_server(
    d: float, # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: float, # Data size for client `k`.
    F_s: float, # The computing resource of the server.
    ) -> float:
    """Computes `T_{s}`, the computing time of the server `s` for a single client.

    Mathematically, this is:
        `T_{s} = ((1 - d) * W * D) / F_{s}`

    Args:
        d (float): Portion of network that the client device will train.
        W (float): Total size of neural network to be trained.
        D (float): Data size for client `k`.
        F_s (float): The computing resource of server `s`.

    Returns:
        float: `T_{s}`
    """
    return ((1-d)*W*D)/F_s

def compute_time_global_epoch(
    d: list[float], # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: list[float], # Data size for each client `k`.
    beta: float, # Amount of computational complexity of forward propagation.
    F_dk: list[float], # The computing resource of each client `k`.
    F_s: float, # The computing resource of the server.
    ) -> float:
    """Computes `T_{g}`, the total computing time of one global epoch with 1 server and a subset of `k` clients.

    Mathematically, this is:
        `T_{g} = \max_{k}{T_{d,k}^{F}} + \max_{k}{T_{d,k}^{B}} + \sum_{i=1}^{k}{T_{s}}`

    Args:
        d (list[float]): Portions of network that each client device `k` will train.
        W (float): Total size of neural network to be trained.
        D (list[float]): Data size.
        beta (float): Amount of computational complexity forward propagation.
        F_dk (list[float]): The computing resource of device `k`, called `F_{d,k}`.
        F_s (float): The computing resource of server `s`.

    Returns:
        float: `T_{g}`
    """
    # Convert to numpy arrays for vectorization.
    d = np.array(d)
    D = np.array(D)
    F_dk = np.array(F_dk)

    # Ensure all are same length.
    assert len(d) == len(D) == len(F_dk)

    # Compute maximum client compute time.
    max_T_dk = np.max(compute_time_client(
        d=d,
        W=W,
        D=D,
        beta=beta,
        F_dk=F_dk,
    ))

    # Compute total server compute time.
    sum_T_s = np.sum(compute_time_server(
        d=d,
        W=W,
        D=D,
        F_s=F_s,
    ))

    # Compute total time for one global epoch.
    return max_T_dk + sum_T_s

def utility_client(
    d: float, # Portion of network that client device `k` will train.
    W: float, # Total size of neural network to be trained.
    D: float, # Data size for client `k`.
    F_dk: float, # The computing resource of device `k`.
    K: float, # ?
    C_k: float, # ?
    lam: float, # Discount factor.
    ) -> float:
    """Computes `U_{d,k}`, the utility of client `k`.

    Mathematically, this is:
        `U_{d,k} = C_{k} * F_{d,k} - d * W * D * k * (F_{d,k}^2) + lam * \log_{2}{1 + d}`

    Args:
        d (float): Portion of network that client device `k` will train.
        W (float): Total size of neural network to be trained.
        D (float): Data size for client `k`.
        beta (float): Amount of computational complexity forward propagation.
        F_dk (float): The computing resource of device `k`, called `F_{d,k}`.
        K (float): ?
        C_k (float): ?
        lam (float): Discount factor.

    Returns:
        float: `U_{d,k}`
    """
    reward_server = C_k * F_dk
    energy_consume_train = d * W * D * K * (F_dk**2.0)
    reward_privacy = lam * np.log2(1 + d)
    return reward_server - energy_consume_train + reward_privacy
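Because `U_{d,k}` is concave in `d` when `lam > 0` (the energy term is linear in `d` and `log2(1 + d)` is concave), a client's best response for `d`, holding `F_{d,k}` fixed, has a closed form. Setting `dU/dd = -W * D * K * F_{d,k}^2 + lam / ((1 + d) * ln 2)` to zero gives `d* = lam / (ln 2 * W * D * K * F_{d,k}^2) - 1`, which is then clipped to `[0, 1]`. The sketch below assumes all parameters are positive and treats `d` as continuous; `best_response_d` and `utility` are hypothetical helpers and the parameter values are illustrative only.

```python
import numpy as np


def best_response_d(W, D, F_dk, K, lam):
    """Maximizer of U_{d,k} over d in [0, 1], holding F_{d,k} fixed.

    U is concave in d for lam > 0, so clipping the stationary point
    1 + d = lam / (ln 2 * W * D * K * F_dk**2) to [0, 1] gives the constrained maximizer.
    """
    d_star = lam / (np.log(2.0) * W * D * K * F_dk**2) - 1.0
    return float(np.clip(d_star, 0.0, 1.0))


def utility(d, W, D, F_dk, K, C_k, lam):
    """Same expression as `utility_client` above (vectorized over d)."""
    return C_k * F_dk - d * W * D * K * F_dk**2 + lam * np.log2(1.0 + d)


# Check the closed form against a brute-force grid search for one parameter set.
params = dict(W=1.0, D=1.0, F_dk=1.0, K=1.0, lam=1.0)
d_grid = np.linspace(0.0, 1.0, 10_001)
u_grid = utility(d_grid, C_k=0.5, **params)
print(best_response_d(**params), d_grid[np.argmax(u_grid)])  # both about 0.443
```

Note that `C_k * F_{d,k}` does not depend on `d`, so it shifts the utility but not the best response.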

In [19]:
x = compute_time_global_epoch(
    d=np.arange(1,100), # Portion of network that client device `k` will train.
    W=1.0, # Total size of neural network to be trained.
    D=np.arange(1,100), # Data size.
    beta=1.0, # Amount of computational complexity of forward propagation.
    F_dk=np.arange(1,100), # The computing resource of device `k`.
    F_s=1.0, # The computing resource of the server.
)
print(x)

-323301.0
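The negative result above comes from the inputs rather than the formulas: `d = np.arange(1, 100)` puts `d` outside `[0, 1]`, so `1 - d` and hence every server time `T_s` is non-positive, while `beta = 1.0` makes the backward time zero. The client term contributes `max_k T_{d,k} = 99` and the server terms sum to `-323,400`, giving `-323,301`. Below is a minimal self-contained sketch with inputs in the intended range; the values are illustrative, not measurements, and `client_times` / `global_epoch_time` are hypothetical re-implementations with range checks. It also confirms that the code's `max_k(T_F + T_B)` equals the docstring's `max_k T_F + max_k T_B`, which holds because both terms are proportional to `d * W * D / F_{d,k}` with the fixed factors `beta` and `1 - beta` (for `0 <= beta <= 1`).

```python
import numpy as np


def client_times(d, W, D, beta, F_dk):
    """Per-client forward and backward times T_{d,k}^F and T_{d,k}^B."""
    d, D, F_dk = (np.asarray(v, dtype=float) for v in (d, D, F_dk))
    if np.any((d < 0.0) | (d > 1.0)):
        raise ValueError("d must lie in [0, 1]")
    if not 0.0 <= beta <= 1.0:
        raise ValueError("beta must lie in [0, 1]")
    T_F = d * beta * W * D / F_dk
    T_B = d * (1.0 - beta) * W * D / F_dk
    return T_F, T_B


def global_epoch_time(d, W, D, beta, F_dk, F_s):
    """T_g = max_k (T_F + T_B) + sum_k T_s, as in `compute_time_global_epoch`."""
    T_F, T_B = client_times(d, W, D, beta, F_dk)
    T_s = (1.0 - np.asarray(d, dtype=float)) * W * np.asarray(D, dtype=float) / F_s
    return float(np.max(T_F + T_B) + np.sum(T_s))


# Three hypothetical clients keeping 20%, 50%, and 80% of the network on-device.
args = dict(d=[0.2, 0.5, 0.8], W=1.0, D=[100.0, 200.0, 50.0], beta=1.0 / 3.0, F_dk=[1.0, 2.0, 0.5])
T_F, T_B = client_times(**args)
T_g = global_epoch_time(**args, F_s=10.0)
print(T_g)  # slowest client 80 plus server times 8 + 10 + 1, i.e. about 99
```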
