# TensorFlow!

In [1]:
import tensorflow
import tensorflow.keras.models as models
import tensorflow.keras.layers as layers
import tensorflow.keras.utils as utils
import tensorflow.keras.optimizers as optimizers

# Each client gets its own model object, both restored from the same saved
# global model file, so they start from identical weights.
client1_model, client2_model = (
    tensorflow.keras.models.load_model('model_weight/global_model.h5')
    for _ in range(2)
)


# It's training time!

In [2]:
import numpy
import tensorflow.keras.callbacks as callbacks


def get_dataset(path='dataset.npz'):
    """Load the ``b`` and ``v`` arrays from an ``.npz`` archive and normalise ``v``.

    Download dataset.npz from the link in Readme.md.

    Args:
        path: Location of the archive. Defaults to ``dataset.npz`` in the
            current working directory.

    Returns:
        A tuple ``(b, v)``. ``b`` is returned exactly as stored. ``v`` is
        divided by its largest absolute value, halved and shifted by 0.5 so
        that it lies in the range 0 - 1, and is cast to float32.
    """
    # The context manager closes the archive's file handle once both arrays
    # have been read; a bare numpy.load() would leave it open.
    with numpy.load(path) as container:
        b, v = container['b'], container['v']

    peak = abs(v).max()
    if peak == 0:
        # An all-zero v would otherwise divide 0 by 0 and produce NaNs;
        # with a unit scale every zero maps to the mid-point 0.5.
        peak = 1

    v = numpy.asarray(v / peak / 2 + 0.5, dtype=numpy.float32)  # normalization (0 - 1)
    return b, v

# Load the dataset once for both clients: the `b` array is used as the model
# input (x) and the normalised `v` array as the training target (y).
x_train, y_train = get_dataset()

In [3]:
# Split dataset for two clients
# pip install -U scikit-learn
from sklearn.model_selection import train_test_split

# Give each client one half of the samples; the fixed random_state makes the
# partition reproducible across runs.
client1_x, client2_x, client1_y, client2y = train_test_split(
    x_train,
    y_train,
    test_size=0.5,
    random_state=100,
)

In [None]:
"""
Training client 1 model
"""
from tensorflow.keras.callbacks import ModelCheckpoint

# Store client 1's checkpoints under /tmp/checkpoint/, alongside client 2's
# (/tmp/checkpoint/f1-2/). The previous '/checkpoint/f1-1/' pointed at the
# filesystem root, which is normally not writable by a notebook user.
checkpoint_filepath = '/tmp/checkpoint/f1-1/'

# save_best_only=True keeps only the best model seen so far, judged by
# ModelCheckpoint's default monitored quantity (val_loss), which is available
# here because fit() is given a validation_split.
model_checkpointing_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_best_only=True,
)

# Train on client 1's half of the data, holding out 10% for validation.
# ReduceLROnPlateau lowers the learning rate after 10 epochs without an
# improvement in training loss; EarlyStopping ends training after 15 epochs
# in which the training loss improves by less than 1e-4.
client1_model.fit(client1_x, client1_y,
          batch_size=2048,
          epochs=100,
          verbose=1,
          validation_split=0.1,
          callbacks=[callbacks.ReduceLROnPlateau(monitor='loss', patience=10),
                     callbacks.EarlyStopping(monitor='loss', patience=15, min_delta=1e-4),
                     model_checkpointing_callback])


Epoch 1/100

In [None]:
"""
Training client 2 model
"""
# Checkpoints for client 2 go to their own directory; only the best model
# seen during training is kept.
checkpoint_filepath = '/tmp/checkpoint/f1-2/'
model_checkpointing_callback = ModelCheckpoint(
    save_best_only=True,
    filepath=checkpoint_filepath,
)

# Fit client 2's model on its half of the data, with 10% held out for
# validation. The learning rate is reduced when the training loss plateaus,
# and training stops early once the loss no longer improves.
client2_model.fit(
    client2_x,
    client2y,
    epochs=100,
    batch_size=2048,
    validation_split=0.1,
    verbose=1,
    callbacks=[
        callbacks.ReduceLROnPlateau(patience=10, monitor='loss'),
        callbacks.EarlyStopping(min_delta=1e-4, patience=15, monitor='loss'),
        model_checkpointing_callback,
    ],
)


In [None]:
"""
Aggregate models
"""

# Federated averaging with two equally weighted clients: each tensor of the
# aggregated model is the mean of the corresponding client tensors.
client1_weights = client1_model.get_weights()
client2_weights = client2_model.get_weights()

# Average tensor by tensor. get_weights() returns a list of arrays with
# different shapes; packing that list into one numpy object array is fragile
# (depending on the layer shapes numpy may raise a broadcasting error or build
# a multi-dimensional array instead of a flat list of tensors).
aggregated_weights = [
    1/2 * weights_1 + 1/2 * weights_2
    for weights_1, weights_2 in zip(client1_weights, client2_weights)
]

client1_model.set_weights(aggregated_weights)

# Persist the model that actually holds the averaged weights. client2_model
# is never given the aggregate, so saving it would write client 2's local
# weights to the "federated" file.
client1_model.save_weights("model_weight/federated1_weights", save_format="h5")
