# Weight aggregation in Flower (example-weighted averaging of client weights)

In [34]:
from functools import reduce

import numpy as np

In [35]:
def aggregate(results):
    """Compute the example-weighted average of client model weights.

    Each client's layers are multiplied by the number of examples that client
    trained on. The scaled layers are summed layer-by-layer across clients and
    divided by the total number of examples, so clients with more data pull
    the result towards their own weights.

    Args:
        results: Iterable of ``(weights, num_examples)`` pairs, one per client.
            ``weights`` is iterated to obtain per-layer arrays (iterating a
            single ndarray yields its sub-arrays along the first axis).
            Corresponding layers of different clients must have compatible
            shapes for ``np.add``.

    Returns:
        A list with one averaged array per layer. The list is empty when
        ``results`` is empty.

    Raises:
        ValueError: If ``results`` is non-empty but the clients' example
            counts sum to zero. A weighted mean is undefined in that case;
            dividing would silently produce NaN/inf weights.
    """
    # Materialize once: the pairs are traversed twice below, which would
    # silently see an exhausted iterator if a generator were passed in.
    results = list(results)
    if not results:
        return []

    # Total number of examples used during training across all clients.
    num_examples_total = sum(num_examples for _, num_examples in results)
    if num_examples_total == 0:
        raise ValueError(
            "cannot aggregate: total number of training examples is 0"
        )

    # Scale every layer of each client by that client's number of examples.
    # Summing these across clients gives the numerator of the weighted mean.
    weighted_weights = [
        [layer * num_examples for layer in weights]
        for weights, num_examples in results
    ]

    # zip(*...) regroups the scaled layers by layer index across clients.
    # NOTE: zip stops at the client with the fewest layers.
    return [
        reduce(np.add, layer_updates) / num_examples_total
        for layer_updates in zip(*weighted_weights)
    ]

## Example 1
### Both clients have the same number of examples

In [41]:
# Client 1: a (1, 3, 3) float array filled with 10, trained on 5000 examples.
weights = np.full((1, 3, 3), 10.0)
num_examples = 5000
result_1 = weights, num_examples

# Client 2: a (1, 3, 3) float array filled with 5, trained on the same
# 5000 examples as client 1.
weights = np.full((1, 3, 3), 5.0)
num_examples = 5000
result_2 = weights, num_examples

# The per-client (weights, num_examples) pairs handed to aggregate().
results = [result_1, result_2]
results

[(array([[[10., 10., 10.],
          [10., 10., 10.],
          [10., 10., 10.]]]),
  5000),
 (array([[[5., 5., 5.],
          [5., 5., 5.],
          [5., 5., 5.]]]),
  5000)]

In [42]:
# Both clients trained on 5000 examples, so the weighted average reduces to
# the plain mean (10 + 5) / 2 = 7.5 for every entry.
print(aggregate(results))  # what we expected

[array([[7.5, 7.5, 7.5],
       [7.5, 7.5, 7.5],
       [7.5, 7.5, 7.5]])]


## Example 2
### The two clients have different numbers of examples

In [43]:
# Client 1: a (1, 3, 3) float array filled with 10, trained on 5000 examples.
weights = np.full((1, 3, 3), 10.0)
num_examples = 5000
result_1 = weights, num_examples

# Client 2: a (1, 3, 3) float array filled with 5, trained on only
# 1000 examples, a fifth of client 1's data.
weights = np.full((1, 3, 3), 5.0)
num_examples = 1000
result_2 = weights, num_examples

# The per-client (weights, num_examples) pairs handed to aggregate().
results = [result_1, result_2]
results

[(array([[[10., 10., 10.],
          [10., 10., 10.],
          [10., 10., 10.]]]),
  5000),
 (array([[[5., 5., 5.],
          [5., 5., 5.],
          [5., 5., 5.]]]),
  1000)]

In [44]:
# Weighted average: (10 * 5000 + 5 * 1000) / (5000 + 1000) ≈ 9.17 per entry.
print(aggregate(results))  # the result is pulled towards the first client, which trained on 5x more examples

[array([[9.16666667, 9.16666667, 9.16666667],
       [9.16666667, 9.16666667, 9.16666667],
       [9.16666667, 9.16666667, 9.16666667]])]
