# Federated Coordination Set Method: Non-Linear Regression Toy Example

In [None]:
import ipyparallel as ipp
# Number of engines requested for the cluster.
n = 4
# Launch an ipyparallel cluster whose engines are started through MPI and
# connect a client to it in a single synchronous call.
rc = ipp.Cluster(engines="mpi", n=n).start_and_connect_sync()
# View addressing every engine of the client at once.
view = rc[:]
# Last expression of the cell: shows the ids of the connected engines.
rc.ids

In [None]:
%%px --block
# MPI initialization, library imports and sanity checks on all engines
from mpi4py import MPI
import numpy as np
import time
import tensorflow as tf
from comm_weights import unflatten_weights, flatten_weights
import os
import matplotlib.pyplot as plt
import copy
%matplotlib qt
# Fix NumPy's global RNG so runs are repeatable.
# NOTE(review): the seed does not depend on the MPI rank, so every engine that
# runs this cell draws the same random numbers — confirm this is intended.
np.random.seed(132)

# Convenience aliases for the world communicator and its basic properties;
# `rank` and `size` are read as globals by code in later cells.
mpi = MPI.COMM_WORLD
bcast = mpi.bcast
barrier = mpi.barrier
rank = mpi.rank
size = mpi.size
# Sanity check: each engine reports its rank and the communicator size.
print("MPI rank: %i/%i" % (mpi.rank,mpi.size))

## Helper Functions

In [None]:
 %%px --block
def synthetic_data2d(n, alpha):
    """Generate ``n`` noisy samples of the 2-D toy regression target.

    Inputs are drawn uniformly from [-2*pi, 2*pi) in both coordinates and the
    target is ``sin(cos(x1) + noise) + exp(cos(x0) + noise)``, where each noise
    term is standard normal noise scaled by ``alpha``.

    Args:
        n: number of samples to draw.
        alpha: scale applied to the Gaussian noise terms.

    Returns:
        Tuple ``(x, y)`` of float32 arrays with shapes ``(n, 2)`` and ``(n,)``.
    """
    # The draw order (sine noise, exponent noise, inputs) is kept fixed so
    # that a seeded global RNG yields the same samples.
    sine_noise = np.random.normal(size=n) * alpha
    exp_noise = np.random.normal(size=n) * alpha

    bound = 2 * np.pi
    points = np.random.uniform(-bound, bound, size=(n, 2))
    first, second = points[:, 0], points[:, 1]

    sine_term = np.sin(np.cos(second) + sine_noise)
    exp_term = np.exp(np.cos(first) + exp_noise)
    targets = sine_term + exp_term

    return points.astype(np.float32), targets.astype(np.float32)

In [None]:
 %%px --block 
# Implement Custom Loss Function
@tf.function
def consensus_loss(y_true, y_pred, z, l2):
    """Local mean-squared error plus a weighted consensus penalty.

    Args:
        y_true: ground-truth targets for the batch; must hold the same number
            of elements as ``y_pred``.
        y_pred: model predictions for the batch.
        z: consensus targets the predictions are pulled towards; must
            broadcast against ``y_pred``.
        l2: weight of the consensus term.

    Returns:
        Scalar tensor ``mean((y_true - y_pred)^2) + l2 * sum((z - y_pred)^2)``.
    """
    # Give the labels the same shape as the predictions before subtracting.
    # Labels of shape (batch,) against predictions of shape (batch, 1) would
    # otherwise broadcast to (batch, batch) and the "MSE" would average over
    # every label/prediction pair instead of matching pairs.
    y_true = tf.reshape(y_true, tf.shape(y_pred))

    # local error: mean of the squared residuals
    local_mse = tf.reduce_mean(tf.square(y_true - y_pred))

    # consensus error: weighted *sum* (not mean) of squared deviations from z
    consensus_penalty = l2 * tf.reduce_sum(tf.square(z - y_pred))

    return local_mse + consensus_penalty

In [None]:
%%px --block
def get_model_architecture(model):
    """Return the shape and element count of every weight array of ``model``.

    Args:
        model: object exposing ``get_weights()``, which returns a sequence of
            arrays that each provide ``.shape`` and ``.size``.

    Returns:
        Tuple ``(layer_shapes, layer_sizes)``: two lists, in the order given
        by ``get_weights()``, holding each array's shape and its total number
        of elements.
    """
    weights = model.get_weights()
    layer_shapes = [w.shape for w in weights]
    layer_sizes = [w.size for w in weights]
    return layer_shapes, layer_sizes

In [None]:
 %%px --block
def model_sync(model, layer_shapes, layer_sizes, size):
    """Replace the model's weights with their average over all MPI ranks.

    This is a collective operation: it calls ``Allreduce`` on
    ``MPI.COMM_WORLD``.

    Args:
        model: object exposing ``get_weights()`` / ``set_weights()``.
        layer_shapes: per-array shapes handed to ``unflatten_weights``.
        layer_sizes: per-array element counts handed to ``unflatten_weights``.
        size: divisor applied to the summed weights (the number of workers).
    """
    # Pack the local weights into a single send buffer.
    # (presumably a flat NumPy array — see comm_weights.flatten_weights)
    local_flat = flatten_weights(model.get_weights())
    summed_flat = np.zeros_like(local_flat)

    # Element-wise sum of the send buffers across the communicator.
    MPI.COMM_WORLD.Allreduce(local_flat, summed_flat, op=MPI.SUM)

    # Sum -> mean, then load the averaged weights back into the model.
    mean_flat = summed_flat / size
    model.set_weights(unflatten_weights(mean_flat, layer_shapes, layer_sizes))

In [None]:
%%px --block
def average_models(model, local_update, layer_shapes, layer_sizes):
    """Set the model's weights to the mean of its current and given weights.

    Args:
        model: object exposing ``get_weights()`` / ``set_weights()``; its
            current weights are one of the two operands.
        local_update: second set of weights, in the same layout as
            ``model.get_weights()``.
        layer_shapes: per-array shapes handed to ``unflatten_weights``.
        layer_sizes: per-array element counts handed to ``unflatten_weights``.
    """
    current_flat = flatten_weights(model.get_weights())
    update_flat = flatten_weights(local_update)

    # Unweighted element-wise mean of the two flattened weight sets.
    blended_flat = np.average([current_flat, update_flat], axis=0)

    model.set_weights(unflatten_weights(blended_flat, layer_shapes, layer_sizes))

In [None]:
 %%px --block
# TODO: experiment further with the learning-rate schedule below
def set_learning_rate(optimizer, epoch):
    """Placeholder learning-rate schedule; currently keeps the rate unchanged.

    Args:
        optimizer: object exposing a readable and writable ``lr`` attribute.
        epoch: current epoch index; not used by the active code.

    Returns:
        None. ``optimizer.lr`` is re-assigned to its own current value.
    """
    # Active behaviour: write the current rate back unchanged.
    optimizer.lr = optimizer.lr

    # Disabled candidate schedules, kept for experimentation.
    #
    # Schedule A:
    #     if epoch >= 175:
    #         optimizer.lr = 0.005
    #     elif epoch >= 350:
    #         optimizer.lr = 0.001
    # NOTE(review): as written, the `epoch >= 350` branch can never run because
    # `epoch >= 175` already matches; the tests would need to be swapped.
    #
    # Schedule B:
    #     if epoch <= 30:
    #         optimizer.lr = 0.0025
    #     if 30 < epoch <= 100:
    #         optimizer.lr = 0.0015
    #     elif 100 < epoch <= 200:
    #         optimizer.lr = 0.001
    #     elif 200 < epoch <= 300:
    #         optimizer.lr = 0.0005
    #     elif 300 < epoch <= 400:
    #         optimizer.lr = 0.00045
    #     elif 400 < epoch <= 450:
    #         optimizer.lr = 0.00005
    #     else:
    #         optimizer.lr = 0.00001
    # NOTE(review): the first `if` is separate from the chain that follows, so
    # for epoch <= 30 the final `else` would overwrite 0.0025 with 0.00001.

In [None]:
%%px --block
# 2d mesh example
def mesh_grid(N):
    """Build an N x N evaluation grid and the noise-free target on it.

    The grid spans [-2*pi, 2*pi] in both coordinates. The target is
    ``sin(cos(y)) + exp(cos(x))`` evaluated at every grid point. The returned
    coordinates are rescaled with the grid's overall minimum and maximum, so
    they lie in [0, 1].

    Args:
        N: number of grid points per axis.

    Returns:
        Tuple ``(z, X)``: ``z`` has shape ``(N*N,)`` and holds the target
        values; ``X`` has shape ``(N*N, 2)`` and holds the rescaled
        ``(x, y)`` coordinates. Row ``k`` of ``X`` corresponds to ``z[k]``.
    """
    axis = np.linspace(-2 * np.pi, 2 * np.pi, N)
    xv, yv = np.meshgrid(axis, axis)

    # Evaluate the target on the whole grid at once; flatten() walks the grid
    # row by row, which is the same order used for the coordinates below.
    z = (np.sin(np.cos(yv)) + np.exp(np.cos(xv))).flatten()

    X = np.vstack((xv.flatten(), yv.flatten())).transpose()
    # Min-max rescale with a single global minimum/maximum over both columns.
    lowest = np.min(X)
    highest = np.max(X)
    X = (X - lowest) / (highest - lowest)
    return z, X

In [None]:
%%px --block
# Model evaluation and plots
def model_results(model, test_dataset):
    """Evaluate the model on the test set and plot its fit on a 50 x 50 grid.

    Reads the globals ``lossF`` and ``optimizer`` to compile the model before
    calling ``evaluate``. Two figures are shown: predicted vs. actual values
    against the identity line, and a 3-D scatter of both surfaces.

    Args:
        model: model exposing ``compile``, ``evaluate`` and ``predict``.
        test_dataset: dataset passed to ``model.evaluate``.
    """
    # Compile with the notebook-level loss/optimizer, then score the test set.
    model.compile(loss=lossF, optimizer=optimizer)
    model.evaluate(test_dataset)

    # Reference grid with noise-free targets, plus the model's predictions.
    grid_side = 50
    truth, grid = mesh_grid(grid_side)
    predicted = model.predict(grid).flatten()

    # Common axis range padded by 0.5, and points for the identity line.
    lower = min(np.min(truth), np.min(predicted)) - 0.5
    upper = max(np.max(truth), np.max(predicted)) + 0.5
    diagonal = np.linspace(lower, upper, 100)

    # Figure 1: predictions against the true values.
    plt.figure(1)
    plt.scatter(truth, predicted, color='blue')
    plt.xlim(lower, upper)
    plt.ylim(lower, upper)
    plt.plot(diagonal, diagonal, color='black')
    plt.title('Actual vs. Predicted Values')
    plt.show()

    # Figure 2: both surfaces over the rescaled grid coordinates.
    grid_x, grid_y = grid[:, 0], grid[:, 1]
    plt.figure(2)
    surface_axes = plt.axes(projection='3d')
    surface_axes.scatter(grid_x, grid_y, truth, color='black', alpha=0.1, label='Actual Fit')
    surface_axes.scatter(grid_x, grid_y, predicted, color='red', label='Predicted Fit')
    surface_axes.legend(loc='upper right')
    plt.title('Predicted vs. Actual Fit')
    plt.show()

## Training Loop

In [None]:
%%px --block
def train(model, lossF, optimizer, train_dataset, coordination_dataset, epochs, coord_batch_size, batches,
          layer_shapes, layer_sizes, l2):
    """Run the coordination-set training loop for ``epochs`` epochs.

    Each epoch:
      1. calls ``set_learning_rate``;
      2. predicts on every coordination batch and averages those predictions
         over all MPI ranks (Allreduce sum divided by the global ``size``);
      3. snapshots the weights, trains on ``train_dataset`` with ``lossF``
         and snapshots the resulting "local" weights;
      4. restores the starting weights and trains on the coordination set
         with ``consensus_loss`` against the averaged predictions;
      5. sets the weights to the average of the consensus-trained and the
         locally trained weights via ``average_models``.

    Reads the globals ``size`` and ``rank``; they are not parameters. Both
    training phases apply gradients through the same ``optimizer`` instance.

    Args:
        model: model callable as ``model(x, training=True)`` that exposes
            ``get_weights``, ``set_weights``, ``trainable_weights`` and
            ``trainable_variables``.
        lossF: local loss, called as ``lossF(y_true=..., y_pred=...)``.
        optimizer: optimizer exposing ``apply_gradients``.
        train_dataset: iterable of ``(data, target)`` batches.
        coordination_dataset: iterable of ``(c_data, c_target)`` batches; each
            batch must yield exactly ``coord_batch_size`` predictions.
        epochs: number of epochs to run.
        coord_batch_size: number of rows of the prediction buffers.
        batches: number of columns of the prediction buffers (one column per
            coordination batch).
        layer_shapes: forwarded to ``average_models``.
        layer_sizes: forwarded to ``average_models``.
        l2: weight of the consensus term, forwarded to ``consensus_loss``.

    Returns:
        None. ``model`` is updated in place; rank 0 prints the training loss
        every 10 epochs.
    """
    # Tracks the MSE of the local training phase only (see update_state below).
    loss_metric = tf.keras.metrics.MeanSquaredError()
    for epoch in range(epochs):

        # Adjust learning rate
        set_learning_rate(optimizer, epoch)

        # Forward Pass of Coordination Set (get z)
        # Column c of each buffer holds the predictions for coordination batch c.
        send_predicted = np.zeros((coord_batch_size, batches), dtype=np.float32)
        recv_avg_pred = np.zeros((coord_batch_size, batches), dtype=np.float32)
        # loss = np.zeros(batches, dtype=np.float64)
        for c_batch_idx, (c_data, c_target) in enumerate(coordination_dataset):
            pred = model(c_data, training=True)
            send_predicted[:, c_batch_idx] = pred.numpy().flatten()

        # Communication Process Here
        # Sum the predictions over all ranks, then divide by the global `size`
        # to obtain the mean prediction z.
        MPI.COMM_WORLD.Allreduce(send_predicted, recv_avg_pred, op=MPI.SUM)
        recv_avg_pred = recv_avg_pred / size

        # save initial model
        start_model = copy.deepcopy(model.get_weights())

        # Local Training
        for batch_idx, (data, target) in enumerate(train_dataset):
            with tf.GradientTape() as tape:
                y_p = model(data, training=True)
                loss_val = lossF(y_true=target, y_pred=y_p)
            grads = tape.gradient(loss_val, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))
            loss_metric.update_state(target, y_p)

        # save model after local update
        local_model = copy.deepcopy(model.get_weights())

        # reset model weights
        # Consensus training starts again from the epoch's initial weights,
        # not from the locally trained ones.
        model.set_weights(start_model)
        # Consensus Training
        for c_batch_idx, (c_data, c_target) in enumerate(coordination_dataset):
            with tf.GradientTape() as tape:
                c_yp = model(c_data, training=True)
                # z is the matching column of averaged predictions as a column vector.
                loss_val = consensus_loss(y_true=c_target, y_pred=c_yp,
                                           z=recv_avg_pred[:, c_batch_idx].reshape(coord_batch_size, 1),
                                           l2=l2)

            grads = tape.gradient(loss_val, model.trainable_weights)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # update model weights
        # The model now holds the consensus-trained weights; average them with
        # the locally trained snapshot.
        average_models(model, local_model, layer_shapes, layer_sizes)

        # Only rank 0 reports, every 10th epoch; the metric is cleared every epoch.
        if rank == 0 and epoch % 10 == 0:
            print('(Rank %d) Training Loss for Epoch %d: %0.4f' % (rank, epoch, loss_metric.result()))
        loss_metric.reset_states()

## Data Initialization

In [None]:
 %%px --block
# Hyper-parameters
n = 1000              # number of synthetic samples generated below
alpha = 0.05          # noise scale passed to synthetic_data2d
epochs = 500          # number of training epochs
learning_rate = 0.01  # initial optimizer learning rate
l2 = 0.1              # weight of the consensus term in consensus_loss

# 2d example
# Alternative kept for reference: split the n samples across the workers.
# X, Y = synthetic_data2d(int(n/size), alpha)
# NOTE(review): if NumPy is seeded identically on every engine, each engine
# generates the same n samples here — confirm this is intended.
X, Y = synthetic_data2d(n, alpha)

# Rescale data between 0 and 1
# A single global min/max over both input columns (train and test rows alike).
data_max = np.max(X)
data_min = np.min(X)
X = (X - data_min) / (data_max - data_min)
# Split up data
# The first 80% of the rows are used for training, the remainder for testing.
train_split = 0.8
batch_size = 64
num_data = len(Y)
train_x = X[0:int(num_data * train_split), :]
train_y = Y[0:int(num_data * train_split)]
test_x = X[int(num_data * train_split):, :]
test_y = Y[int(num_data * train_split):]
# convert to tensors
train_dataset = tf.data.Dataset.from_tensor_slices((train_x, train_y))
test_dataset = tf.data.Dataset.from_tensor_slices((test_x, test_y))
# shuffle and batch
# The shuffle buffer covers the whole training split.
train_dataset = train_dataset.shuffle(int(num_data * train_split)).batch(batch_size)
test_dataset = test_dataset.batch(batch_size)

# Coordination set construction
coord_size = 160
# c_batch_size = 16
c_batch_size = 160
# With the values above this evaluates to a single coordination batch.
c_num_batches = int(coord_size/c_batch_size)

# Both input columns are the same linspace, so every coordination point lies
# on the diagonal x0 == x1; the targets are the noise-free function values.
true_x = np.tile(np.linspace(-2*np.pi, 2*np.pi, coord_size), (2, 1)).transpose()
true_y = np.sin(np.cos(true_x[:, 1])) + np.exp(np.cos(true_x[:, 0]))
true_x = true_x.astype(np.float32)
true_y = true_y.astype(np.float32)
# NOTE(review): the coordination inputs are rescaled with their own min/max,
# while X above uses the min/max of the random samples; the two scalings may
# differ slightly — confirm this is acceptable.
coord_max = np.max(true_x)
coord_min = np.min(true_x)
true_x = (true_x - coord_min) / (coord_max - coord_min)
coordination_dataset = tf.data.Dataset.from_tensor_slices((true_x, true_y))
coordination_dataset = coordination_dataset.batch(c_batch_size)

## Model Initialization

In [None]:
%%px --block
# Initialize Model
# Fully connected network: 2 inputs -> 10 -> 64 -> 64 -> 64 -> 10 -> 1 output.
# Hidden layers use ReLU; the output layer is given no activation argument.
model = tf.keras.Sequential()
model.add(tf.keras.layers.Dense(10, activation='relu', input_shape=(2,)))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(64, activation='relu'))
model.add(tf.keras.layers.Dense(10, activation='relu'))
model.add(tf.keras.layers.Dense(1))

# get model architecture
# Shapes and element counts of the weight arrays, used later when weights are
# flattened and restored.
layer_shapes, layer_sizes = get_model_architecture(model)

# Initialize Local Loss Function
lossF = tf.keras.losses.MeanSquaredError()

# Initialize Optimizer
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
# Alternative optimizer kept for reference:
# optimizer = tf.keras.optimizers.SGD(learning_rate=learning_rate)


In [None]:
%%px --block
# Sync model weights
# Replace each engine's initial weights with the average over all ranks, so
# every engine starts training from the same model.
model_sync(model, layer_shapes, layer_sizes, size)
# Wait until every rank has reached this point before continuing.
MPI.COMM_WORLD.Barrier()

In [None]:
%%px --block
# Run the training loop on every engine; c_batch_size and c_num_batches give
# the row and column counts of the coordination prediction buffers.
train(model, lossF, optimizer, train_dataset, coordination_dataset, epochs, c_batch_size, c_num_batches, 
     layer_shapes, layer_sizes, l2)

In [16]:
%%px --block
# Evaluate and plot on rank 0 only; the other engines skip this step.
if mpi.rank == 0:
    model_results(model, test_dataset)

%px:   0%|          | 0/4 [00:00<?, ?tasks/s]