## Implementation of IncFL Algorithm 

### Load and prepare the dataset for the experiment

In [63]:
import nest_asyncio
nest_asyncio.apply()

In [64]:
import collections

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

#np.random.seed(0)

#tff.federated_computation(lambda: 'Hello, World!')()

In [65]:
# Load the federated EMNIST data through TFF's simulation dataset loader.
# The call returns a pair that is unpacked into the train and test splits;
# later cells read per-client datasets from `emnist_train` and a pooled
# dataset from `emnist_test`.
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()

In [66]:
# Number of client ids (taken from the sorted id list) used for training.
NUM_CLIENTS = 10
# Number of examples grouped into one batch by `preprocess`.
BATCH_SIZE = 32
# Number of passes over the client's dataset in the local-update loop of
# `client_update` (the "tau" rounds).
numTAU = 40

def preprocess(dataset):
    """Batch a dataset and convert each batch into a (features, label) tuple.

    Each element of `dataset` is expected to expose the keys 'pixels' and
    'label'. After batching by BATCH_SIZE, pixels are reshaped to
    [batch, 784] and labels to [batch, 1].
    """

    def to_features_and_labels(example):
        """Flatten one batch of images and turn its labels into a column."""
        flat_pixels = tf.reshape(example['pixels'], [-1, 784])
        column_labels = tf.reshape(example['label'], [-1, 1])
        return (flat_pixels, column_labels)

    batched = dataset.batch(BATCH_SIZE)
    return batched.map(to_features_and_labels)

In [67]:
# Pick the first NUM_CLIENTS ids in sorted order so the selection is stable
# from run to run, then build one preprocessed dataset per selected client.
client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(client_id))
    for client_id in client_ids
]

### Prepare the model

In [68]:
def create_keras_model():
    """Build the classifier: 784 inputs -> Dense(10) -> Softmax.

    The dense kernel uses a GlorotNormal initializer with a fixed seed (0),
    so every call requests the same kernel initialisation.
    """
    inputs = tf.keras.layers.Input(shape=(784,))
    dense = tf.keras.layers.Dense(
        10,
        kernel_initializer=tf.keras.initializers.GlorotNormal(seed=0),
    )
    softmax = tf.keras.layers.Softmax()
    return tf.keras.models.Sequential([inputs, dense, softmax])

In [69]:
def model_fn():
    """Wrap a fresh Keras model as a TFF learning model.

    The input spec is taken from the first client's preprocessed dataset.
    The loss is sparse categorical cross-entropy (the model already ends in
    a softmax layer) and the tracked metric is sparse categorical accuracy.
    """
    return tff.learning.from_keras_model(
        create_keras_model(),
        input_spec=federated_train_data[0].element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

### TensorFlow Blocks

In [70]:
def _mean_batch_loss(model, dataset):
    """Return the model's loss averaged over the batches of `dataset`.

    Only forward passes are run (training=False); no variables are updated.
    An empty dataset yields 0.0 instead of a division by zero.
    """
    total_loss = 0.0
    num_batches = 0.0
    for batch in dataset:
        total_loss += model.forward_pass(batch, False).loss
        num_batches += 1.0
    return tf.math.divide_no_nan(total_loss, num_batches)


def _run_local_sgd(model, dataset, optimizer, num_epochs):
    """Run `num_epochs` passes of mini-batch training over `dataset`.

    The model's trainable variables are updated in place by `optimizer`.
    `num_epochs` must be a Python int (the loop is unrolled while tracing).
    """
    weights = model.trainable_variables
    for _ in range(num_epochs):
        for batch in dataset:
            with tf.GradientTape() as tape:
                outputs = model.forward_pass(batch)
            grads = tape.gradient(outputs.loss, weights)
            optimizer.apply_gradients(zip(grads, weights))


@tf.function
def client_update(model, dataset, server_weights, client_optimizer,
                  solo_epochs=100):
    """Train on one client's dataset and compute its IncFL aggregation weight.

    Steps:
      1. Load `server_weights` into the model and measure the mean batch loss
         (L1, the loss of the global model on this client).
      2. Train the model for `solo_epochs` passes over the client data and
         measure the mean batch loss again (L2, the loss of the client's
         solo-trained model).
      3. aggregation = sigmoid(L1 - L2).
      4. Reset the model to `server_weights` and run `numTAU` local passes to
         produce the weights that are sent back to the server.

    Fixes relative to the previous version:
      * The L1 average started its batch counter at 1, so it divided by
        (number of batches + 1).
      * L2 was the loss of a single mini-batch, while L1 was a dataset
        average; both are now dataset averages.
      * `client_weights` and `agg_weights` were the same variable list, so the
        tau local passes continued from the solo-trained weights rather than
        from the server weights. The model is now reset before step 4.

    Args:
      model: TFF learning model whose trainable variables are updated in place.
      dataset: the client's batched dataset.
      server_weights: structure matching `model.trainable_variables`.
      client_optimizer: optimizer used for both the solo and the local passes.
      solo_epochs: Python int, number of passes used to obtain the solo model
        (100 by default, the value previously hard-coded).

    Returns:
      A tuple `(client_weights, aggregation)`: the model's trainable variables
      after the tau local passes and the scalar aggregation weight.
    """
    model_vars = model.trainable_variables

    # L1: loss of the current server model on this client's data.
    tf.nest.map_structure(lambda var, value: var.assign(value),
                          model_vars, server_weights)
    server_loss = _mean_batch_loss(model, dataset)

    # L2: loss of the model obtained by training on this client alone.
    _run_local_sgd(model, dataset, client_optimizer, solo_epochs)
    solo_loss = _mean_batch_loss(model, dataset)

    # Aggregation weight: sigmoid of the loss gap between the two models.
    aggregation = tf.keras.activations.sigmoid(server_loss - solo_loss)

    # The solo training above modified the shared variables, so restore the
    # server weights before running the tau local passes.
    tf.nest.map_structure(lambda var, value: var.assign(value),
                          model_vars, server_weights)
    _run_local_sgd(model, dataset, client_optimizer, numTAU)

    return model_vars, aggregation

In [71]:
@tf.function
def server_update(
    model,
    client_weights,
    server_optimizer
):
    """Apply the aggregated client weights to the model through an optimizer.

    `client_weights` is passed to `server_optimizer.apply_gradients` in the
    gradient position, paired element-wise with the model's trainable
    variables. This is NOT a plain average/assignment of client weights: with
    a plain SGD optimizer each variable becomes
    `variable - learning_rate * client_weight`.

    NOTE(review): because the client weights are treated as gradients, the
    sign of the update is controlled by the optimizer's learning rate, and
    the starting point is whatever values `model` holds when this is called.
    Confirm this matches the intended server rule.

    Alternatives tried earlier and abandoned: adding the client weights with
    `assign_add`, and scaling the client weights by a small constant
    (+/-0.01, -0.0001) before adding them.

    Args:
      model: model whose trainable variables are updated in place.
      client_weights: structure matching `model.trainable_variables`.
      server_optimizer: optimizer that performs the update.

    Returns:
      The model's trainable variables after the update.
    """
    model_weights = model.trainable_variables

    pseudo_gradients = tf.nest.flatten(client_weights)
    target_variables = tf.nest.flatten(model_weights)
    server_optimizer.apply_gradients(
        list(zip(pseudo_gradients, target_variables)))

    return model_weights

### TF Core blocks

In [72]:
@tff.tf_computation
def server_init():
    """Build a fresh model and return its trainable variables.

    The result is used as the initial server weights.
    """
    initial_model = model_fn()
    return initial_model.trainable_variables

In [73]:
@tff.federated_computation
def initialize_fn():
    """Place the initial model weights on the server."""
    initial_weights = server_init()
    return tff.federated_value(initial_weights, tff.SERVER)

In [74]:
## Turn client_update into a tff.tf_computation that accepts
# a client's dataset and the server weights,
# and outputs the updated client weights.
# The computation's argument types must be declared up front, so a throwaway
# model is built here only to read its input spec and derive the dataset type.

whimsy_model = model_fn()
tf_dataset_type = tff.SequenceType(whimsy_model.input_spec)

In [75]:
str(tf_dataset_type)

'<float32[?,784],int32[?,1]>*'

In [76]:
# TFF type of the model weights, read from the result type of `server_init`.
model_weights_type = server_init.type_signature.result 

In [77]:
str(model_weights_type)

'<float32[784,10],float32[10]>'

In [78]:
@tff.tf_computation
def agg_weights_init():
    """Return the initial aggregation weight, the scalar 0.0.

    This computation mainly exists so that its result type can be read to
    declare the aggregation-weight type.
    """
    return 0.0

In [79]:
@tff.federated_computation
def agg_weights_fn():
    """Place the initial aggregation weight on the clients.

    NOTE(review): nothing in the visible notebook calls this computation.
    """
    initial_agg_weight = agg_weights_init()
    return tff.federated_value(initial_agg_weight, tff.CLIENTS)


In [80]:
# TFF type of a single aggregation weight, read from the result type of
# `agg_weights_init`.
agg_weights_type = agg_weights_init.type_signature.result

In [81]:
str(agg_weights_type)

'float32'

In [82]:
## tff.tf_computation wrapper around client_update
@tff.tf_computation(tf_dataset_type, model_weights_type)
def client_update_fn(tf_dataset, server_weights):
    """Run `client_update` for one client with a fresh model and optimizer.

    Builds a new model and an SGD optimizer (learning rate 0.01), then returns
    what `client_update` returns: the updated client weights and the client's
    aggregation weight.
    """
    local_model = model_fn()
    sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
    return client_update(local_model, tf_dataset, server_weights, sgd)

In [83]:
## tff.tf_computation wrapper around server_update
@tff.tf_computation(model_weights_type, agg_weights_type)
def server_update_fn(client_weights, agg_weights):
    """Run `server_update` with a fresh model and a fixed-rate SGD optimizer.

    The learning rate is the constant -0.5. The adaptive rate
    `1 / (agg_weights + 2.5)` is currently disabled, so `agg_weights` is
    accepted but not used.

    NOTE(review): the model is created anew by `model_fn()` on every call and
    the previous server weights are not passed in, so the update starts from
    freshly initialised variables each round. The negative learning rate
    offsets the subtraction performed by SGD, because `server_update` feeds
    the client weights in as gradients. Confirm this is the intended rule.
    """
    server_model = model_fn()
    optimizer = tf.keras.optimizers.SGD(learning_rate=-0.5)
    return server_update(server_model, client_weights, optimizer)

In [84]:
## Next we define a tff.federated_computation that wraps the two
## tf_computation functions created above.
## That function accepts two federated values:
## 1. the server weights, placed at the server
## 2. the client datasets, placed at the clients
## The federated types for those values (and for the per-client aggregation
## weights) are declared here.

federated_server_type = tff.FederatedType(model_weights_type, tff.SERVER)
federated_dataset_type = tff.FederatedType(tf_dataset_type, tff.CLIENTS)
federated_agg_weights_type =  tff.FederatedType(agg_weights_type, tff.CLIENTS)

In [85]:
'''
1. Refer to the build federated average code
2. Get the weighted server code working - this means you would have implemented the weighted fedavg algorithm
3. The next thing you need to edit is the client update function
    a. In this fucntion you need to calculate the 2 loses
    b. one of the client weights and the other on the server model
    c. calculate the differnce of these weights (sigmoid(F(w) - F(w^)))
    d. Send the new updated weights along with these aggregated w eights to the server
    {Try your own variation where you save the calculation of the loss * updated weight and then send it to the server}
4. Update the server model: Calculate the new learning rate and the coefficient of the wights
'''

@tff.federated_computation(federated_server_type, federated_dataset_type)
def next_fn(server_weights, federated_dataset):
    """Run one federated round and return the new server weights.

    1. Broadcast the server weights to the clients.
    2. Each client runs `client_update_fn`, producing updated weights and an
       aggregation weight.
    3. The client weights are averaged using the aggregation weights as the
       averaging weights; the aggregation weights themselves are averaged too.
    4. The server applies `server_update_fn` to both averages.
    """
    # Step 1: send the global model weights to every client.
    broadcast_weights = tff.federated_broadcast(server_weights)

    # Step 2: per-client training; each client returns a pair of
    # (updated weights, aggregation weight).
    client_results = tff.federated_map(
        client_update_fn, (federated_dataset, broadcast_weights))
    client_weights, agg_weights = client_results

    # Step 3: aggregate. The aggregation weights act as the multipliers in
    # the weighted mean of the client weights.
    weighted_mean_weights = tff.federated_mean(client_weights, agg_weights)
    mean_agg_weight = tff.federated_mean(agg_weights)

    # Step 4: server-side update.
    # NOTE(review): passing `client_weights` straight to the server update
    # was reported to fail; presumably because it is still a CLIENTS-placed
    # value, whereas the federated_mean results are used as server inputs.
    return tff.federated_map(
        server_update_fn, (weighted_mean_weights, mean_agg_weight))

In [86]:
# Bundle the initialisation and round computations into a single iterative
# process; it is driven below through `.initialize()` and `.next(...)`.
federated_algorithm = tff.templates.IterativeProcess(
    initialize_fn=initialize_fn,
    next_fn=next_fn
)

In [87]:
str(federated_algorithm.initialize.type_signature)

'( -> <float32[784,10],float32[10]>@SERVER)'

In [88]:
str(federated_algorithm.next.type_signature)

'(<server_weights=<float32[784,10],float32[10]>@SERVER,federated_dataset={<float32[?,784],int32[?,1]>*}@CLIENTS> -> <float32[784,10],float32[10]>@SERVER)'

In [89]:
# Centralised evaluation set: pool the test examples of all clients into one
# dataset and apply the same preprocessing as the training data.
central_emnist_test = preprocess(
    emnist_test.create_tf_dataset_from_all_clients())

In [90]:
def evaluate(server_state):
    """Evaluate the given server weights on the centralised test set.

    A fresh Keras model is compiled with sparse categorical cross-entropy and
    sparse categorical accuracy, loaded with `server_state`, and evaluated on
    `central_emnist_test`. Nothing is returned; the results are whatever
    Keras reports while evaluating.
    """
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    eval_model = create_keras_model()
    eval_model.compile(loss=loss_fn, metrics=[accuracy])
    eval_model.set_weights(server_state)
    eval_model.evaluate(central_emnist_test)

In [91]:
# Baseline: evaluate the initial server weights before any training round.
server_state = federated_algorithm.initialize()
evaluate(server_state)



In [None]:
# Run 10 federated rounds, feeding each round's server state into the next.
# The loop variable is named `round_num` so the builtin `round` is not
# shadowed for the rest of the notebook session.
for round_num in range(10):
    server_state = federated_algorithm.next(server_state, federated_train_data)

In [None]:
# Evaluate the server weights obtained after the training rounds.
evaluate(server_state)