In [22]:
#!pip install tensorflow==2.3.0 keras==2.3.0 tensorflow-federated==0.17.0
#!pip freeze

In [23]:
import nest_asyncio
nest_asyncio.apply()

In [24]:
import collections
import attr
import functools
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

np.random.seed(0)

In [25]:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

mnist_train, mnist_test = tf.keras.datasets.mnist.load_data() # This dataset is not "E"mnist. Don't confuse!

In [26]:
[(x.dtype, x.shape) for x in mnist_train]

[(dtype('uint8'), (60000, 28, 28)), (dtype('uint8'), (60000,))]

In [27]:
raw_dataset_for_non_iid=list(zip(mnist_train[0], mnist_train[1]))
raw_dataset_for_non_iid.sort(key=lambda el:el[1])

In [28]:
el_size=300
shard_per_client=2
temp_list=[]
temp_federated_train_data_for_non_iid=[]
for idx, el in enumerate(raw_dataset_for_non_iid) :
    temp_list.append(el)
    if (idx+1)%(el_size*shard_per_client)==0 :
        temp_federated_train_data_for_non_iid.append(temp_list)
        temp_list=[]

In [29]:
temp_federated_train_data_tied_with_tf_dataset_for_non_iid=[
    tf.data.Dataset.from_tensor_slices({ 
        "pixels":np.array( tuple(client_dataset_tuple_for_pixels[0] for client_dataset_tuple_for_pixels in client_dataset), dtype=np.float32),
        "label":np.array( tuple(client_dataset_tuple_for_label[1] for client_dataset_tuple_for_label in client_dataset), dtype=np.float32)
    })for client_dataset in temp_federated_train_data_for_non_iid
]

In [30]:
print(type(temp_federated_train_data_tied_with_tf_dataset_for_non_iid[0]))
print(temp_federated_train_data_tied_with_tf_dataset_for_non_iid[0].element_spec)

<class 'tensorflow.python.data.ops.dataset_ops.TensorSliceDataset'>
{'pixels': TensorSpec(shape=(28, 28), dtype=tf.float32, name=None), 'label': TensorSpec(shape=(), dtype=tf.float32, name=None)}


In [31]:
FRACTION=0.2
TOTAL_CLIENTS = len(temp_federated_train_data_tied_with_tf_dataset_for_non_iid)
BATCH_SIZE = 10
SELECTED_CLIENTS = int(TOTAL_CLIENTS*FRACTION)
NUM_EPOCHS = 5 # fixed!

def preprocess(dataset):
  #print(len(dataset))  # you can see report_local_outputs second element(num_examples)

  def batch_format_fn(element):
    return (tf.reshape(element['pixels'], (-1, 28, 28, 1))/255.0, # Normalizes pixel values between 0 and 1.
             tf.reshape(element['label'], (-1, 1)))

  return dataset.repeat(NUM_EPOCHS).batch(BATCH_SIZE).map(batch_format_fn)


#NUM_CLIENTS = 10
#NUM_EPOCHS = 5
#BATCH_SIZE = 20
#SHUFFLE_BUFFER = 100
#PREFETCH_BUFFER = 10

#def preprocess(dataset):

#  def batch_format_fn(element):
#    """Flatten a batch `pixels` and return the features as an `OrderedDict`."""
#    return collections.OrderedDict(
#        x=tf.reshape(element['pixels'], [-1, 784]),
#        y=tf.reshape(element['label'], [-1, 1]))

#  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER).batch(
#      BATCH_SIZE).map(batch_format_fn).prefetch(PREFETCH_BUFFER)

In [32]:
client_ids = np.random.choice(TOTAL_CLIENTS, size=SELECTED_CLIENTS, replace=False)

print("selected_client :", client_ids)

federated_train_data = [preprocess(temp_federated_train_data_tied_with_tf_dataset_for_non_iid[x])
  for x in client_ids
]

selected_client : [26 86  2 55 75 93 16 73 54 95 53 92 78 13  7 30 22 24 33  8]


In [33]:
def create_keras_model():
  return tf.keras.models.Sequential([
        tf.keras.Input(shape=(28, 28, 1)),
        tf.keras.layers.Conv2D(32, kernel_size=(5, 5), activation="relu", padding='same'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        
        tf.keras.layers.Conv2D(64, kernel_size=(5, 5), activation="relu", padding='same'),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2), padding='same'),
        
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        tf.keras.layers.Dense(10, activation="softmax"),
  ])

In [34]:
def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=federated_train_data[0].element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [35]:
@tff.tf_computation
def server_init():
  model = model_fn()
  return model.trainable_variables

@tff.federated_computation
def initialize_fn():
  return tff.federated_value(server_init(), tff.SERVER)

In [36]:
import random

dummy_model = model_fn()
tf_dataset_type = tff.SequenceType(dummy_model.input_spec)
model_weights_type = server_init.type_signature.result
federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)

@tf.function
def client_update(model, dataset, server_weights, client_optimizer):
  """Performs training (using the server model weights) on the client's dataset."""
  # Initialize the client model with the current server weights.
  client_weights = model.trainable_variables
  # Assign the server weights to the client model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        client_weights, server_weights)

  # Use the client_optimizer to update the local model.
  for batch in dataset:
    with tf.GradientTape() as tape:
      # Compute a forward pass on the batch of data
      outputs = model.forward_pass(batch)

    # Compute the corresponding gradient
    grads = tape.gradient(outputs.loss, client_weights)
    grads_and_vars = zip(grads, client_weights)

    # Apply the gradient using a client optimizer.
    client_optimizer.apply_gradients(grads_and_vars)

  return client_weights

@tf.function
def server_update(model, mean_client_weights):
  """Updates the server model weights as the average of the client model weights."""
  model_weights = model.trainable_variables
  # Assign the mean client weights to the server model.
  tf.nest.map_structure(lambda x, y: x.assign(y),
                        model_weights, mean_client_weights)
  return model_weights

@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
  model = model_fn()
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)
  return client_update(model, tf_dataset, server_weights, client_optimizer)

@tff.tf_computation(model_weights_type)
def server_update_fn(mean_client_weights):
  model = model_fn()
  return server_update(model, mean_client_weights)

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
  # Broadcast the server weights to the clients.
  server_weights_at_client = tff.federated_broadcast(server_weights)

  # Each client computes their updated weights.
  client_weights= tff.federated_map(
      client_update_fn, (federated_dataset, server_weights_at_client))

  # The server averages these updates.
  # When you test network traffic size, remove the everything about metrics
  mean_client_weights= tff.federated_mean(client_weights)

  # The server updates its model.
  server_weights = tff.federated_map(server_update_fn, mean_client_weights)
    
  return server_weights

In [37]:
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [38]:
def federated_evaluate(state, fderated_dataset):
  keras_model = create_keras_model()
  keras_model.compile(
      optimizer='adam',#optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]  
  )
  keras_model.set_weights(state)
  loss_acc_per_client_list = [keras_model.evaluate(client_dataset, verbose=0) for client_dataset in fderated_dataset]
  return np.mean(loss_acc_per_client_list, axis=0)

In [39]:
server_state = federated_algorithm.initialize()

NUM_ROUNDS = 200
for round_num in range(1, NUM_ROUNDS+1):
  server_state =federated_algorithm.next(server_state, federated_train_data)
  loss_acc_mean = federated_evaluate(server_state, federated_train_data)
  print(f"round {round_num:4d} - Loss : {loss_acc_mean[0]} , Accuracy : {loss_acc_mean[1]}")


round    1 - Loss : 2.0571262419223784 , Accuracy : 0.225666663213633
round    2 - Loss : 1.943870022892952 , Accuracy : 0.2
round    3 - Loss : 1.8525590240955352 , Accuracy : 0.35850000120699405
round    4 - Loss : 1.8091301381587983 , Accuracy : 0.38608333515003324
round    5 - Loss : 1.7679865777492523 , Accuracy : 0.47158333723200485
round    6 - Loss : 1.7289956033229827 , Accuracy : 0.45183333054883407
round    7 - Loss : 1.6959805727005004 , Accuracy : 0.539916668459773
round    8 - Loss : 1.66426522731781 , Accuracy : 0.44866666654124854
round    9 - Loss : 1.652553105354309 , Accuracy : 0.5219999991357327
round   10 - Loss : 1.6562440365552902 , Accuracy : 0.29941666413797063
round   11 - Loss : 1.8134323090314866 , Accuracy : 0.29800000390969217
round   12 - Loss : 1.825293430685997 , Accuracy : 0.29158333241939544
round   13 - Loss : 2.166898274421692 , Accuracy : 0.3135833352804184
round   14 - Loss : 1.751629814505577 , Accuracy : 0.37675000602612274
round   15 - Loss : 1

KeyboardInterrupt: ignored