In [5]:
import tensorflow as tf
import tensorflow_federated as tff
import numpy as np
import matplotlib.pyplot as plt
import wandb
from wandb.integration.keras import WandbCallback
import os

In [6]:
USE_WANDB = True
NUM_CLIENTS = 10
NUM_ROUNDS = 10
BATCH_SIZE = 20

In [7]:
# Load CIFAR10 dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Normalize pixel values
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train, x_test = x_train.astype(np.float32), x_test.astype(np.float32)

# One-hot encode the labels
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)

In [8]:
if USE_WANDB:
    wandb.init(project="federated_learning", group="group_1", name=f"tf_federated_differential_privacy_{NUM_CLIENTS}clients_{NUM_ROUNDS}rounds_{BATCH_SIZE}batchsize")
    
# Create a function that returns a compiled Keras model
def create_keras_model():
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(20, (5, 5), activation='relu', input_shape=(32, 32, 3)),
        tf.keras.layers.MaxPooling2D((2, 2)),

        tf.keras.layers.Conv2D(50, (5, 5), activation='relu'),
        tf.keras.layers.MaxPooling2D((2, 2)),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(500, activation='relu'),
        tf.keras.layers.Dense(10, activation='softmax')
    ])
    return model

# Wrap the Keras model for use with TFF
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.models.from_keras_model(
        keras_model,
        input_spec=(tf.TensorSpec(shape=[None, 32, 32, 3], dtype=tf.float32),
                    tf.TensorSpec(shape=[None, 10], dtype=tf.float32)),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()])

### Step 3: Create Federated Data
# Convert the dataset to a federated dataset
def preprocess(dataset):
    def batch_format_fn(element):
        return (tf.reshape(element['x'], [-1, 32, 32, 3]), tf.reshape(element['y'], [-1, 10]))
    return dataset.batch(BATCH_SIZE).map(batch_format_fn)

# Create a federated dataset
client_data = np.array_split(x_train, NUM_CLIENTS)
client_labels = np.array_split(y_train, NUM_CLIENTS)

federated_train_data = [
    preprocess(tf.data.Dataset.from_tensor_slices({'x': client_data[i], 'y': client_labels[i]}))
    for i in range(NUM_CLIENTS)
]

### Step 4: Define the Federated Learning Process
# Create a federated averaging process
iterative_process = tff.learning.algorithms.build_unweighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.Adam(),
    # client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    # model_aggregator=tff.learning.model_update_aggregator.dp_aggregator(noise_multiplier=0.1, clients_per_round=NUM_CLIENTS),
    # model_aggregator=tff.learning.compression_aggregator()
)

state = iterative_process.initialize()

eval_model = create_keras_model()
eval_model.compile(
loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=[tf.keras.metrics.CategoricalAccuracy()])

full_path = os.path.join("./", "model.keras")
eval_model.save(filepath=full_path)
if USE_WANDB:
    wandb.log_model(path=full_path, name="CIFAR10_CNN")

# Train the model for a few rounds
for round_num in range(1, NUM_ROUNDS + 1):
    state, metrics = iterative_process.next(state, federated_train_data)
    
    model_weights = iterative_process.get_model_weights(state)
    model_weights.assign_weights_to(eval_model)
    eval_metrics = eval_model.evaluate(x_test, y_test, verbose=2)
    
    print(f'Round {round_num}, Metrics={metrics}, Val_Metrics={eval_metrics}')
    
    if USE_WANDB:
        wandb.log({
            'round': round_num,
            'loss': metrics['client_work']['train']['loss'],
            'categorical_accuracy': metrics['client_work']['train']['categorical_accuracy'],
            'val_loss': eval_metrics[0],
            'val_categorical_accuracy': eval_metrics[1]
        })
    
if USE_WANDB:
    wandb.finish()

2024-07-28 22:26:26.645638: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-07-28 22:26:26.645758: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-07-28 22:26:26.658411: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-07-28 22:26:26.658482: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-07-28 22:26:26.662640: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-07-28 22:26:26.662733: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session
2024-07-28 22:26:26.667654: I tensorflow/core/grappler/devices.cc:66] Number of eligible GPUs (core count >= 8, compute capability >= 0.0): 0
2024-07-28 22:26:26.667731: I tensorflow/core/grappler/clusters/single_machine.cc:361] Starting new session


313/313 - 1s - loss: 2.2999 - categorical_accuracy: 0.1019 - 661ms/epoch - 2ms/step
Round 1, Metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.2874), ('loss', 1.9209325), ('num_examples', 50000), ('num_batches', 2500)]))])), ('aggregator', OrderedDict([('zeroing', OrderedDict([('dp_query_metrics', OrderedDict([('clip', 0.1105171)])), ('dp', ())])), ('zeroing_norm', 21.0), ('zeroed_count', 0)])), ('finalizer', OrderedDict([('update_non_finite', 0)]))]), Val_Metrics=[2.299900770187378, 0.10189999639987946]
313/313 - 1s - loss: 2.2987 - categorical_accuracy: 0.1037 - 597ms/epoch - 2ms/step
Round 2, Metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('categorical_accuracy', 0.29096), ('loss', 1.9127438), ('num_examples', 50000), ('num_batches', 2500)]))])), ('aggregator', OrderedDict([('zeroing', OrderedDict([('dp_query_metrics', OrderedDict([('clip', 0.12214029)])), ('dp', ())

0,1
categorical_accuracy,▃▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████▇▇▅▄▄▁
loss,▆▆▅▅▅▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁▁▁▁▁▂▂▅▆▆█
round,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
val_categorical_accuracy,▁▁▁▁▁▂▂▂▄▄▄▅▆▆▆▇▇▇▆▆▆▅▆▆▆▆▆▇█▆▂▅▆▅▅▆▃▃▃▂
val_loss,▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▄▂▂▂▂▂▅▄▃█

0,1
categorical_accuracy,0.22964
loss,2.17778
round,50.0
val_categorical_accuracy,0.1326
val_loss,7.17983


In [9]:
if USE_WANDB:
    wandb.finish()

In [10]:
federated_train_data

[<_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=tf.float32, name=None), TensorSpec(shape=(None, 10), dtype=tf.float32, name=None))>,
 <_MapDataset element_spec=(TensorSpec(shape=(None, 32, 32, 3), dtype=