In [24]:
%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


In [25]:
# import necessary modules

import collections

import numpy as np
import tensorflow as tf
import keras
from keras import layers
import json
import datetime
import pickle
from matplotlib import pyplot as plt

In [4]:
def import_data():
    train_path = "../../data/FMNIST/FCIFAR_alpha10_train.json"
    test_path = "../../data/FMNIST/FCIFAR_alpha10_test.json"
    with open(train_path, "rb") as f:
        train_data = pickle.load(f)
    with open(test_path, "rb") as f:
        test_data = pickle.load(f)
    return train_data, test_data

In [5]:

np.random.seed(12345)

train_data, test_data = import_data()

In [6]:
N = 10
m = 3
K = 10
d = 28 * 28
hidden_unit = 64
user_name = ["f_{0:05d}".format(n) for n in range(N)]

x_test = tf.convert_to_tensor(test_data["user_data"]["test"]['x'])
y_test = tf.convert_to_tensor(test_data["user_data"]["test"]['y'])

In [7]:
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=1)

In [8]:
def client_train(model, x_train, y_train):
    model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.2, 
              callbacks=[tensorboard_callback], steps_per_epoch=1, verbose=0)
    return model

In [73]:
def get_MLPmodel():
    model = keras.Sequential([
        keras.Input(shape=(d, )),
        layers.Dense(64, activation = "relu", use_bias=False),
        layers.Dense(K, use_bias=False),
    ])
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.SGD(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )   
    return model
    

In [9]:
def get_CNNmodel():
    model = keras.Sequential()
    model.add(layers.Conv2D(6, (5, 5), activation='relu', input_shape=(32, 32, 3)))
    model.add(layers.Conv2D(16, (5, 5), activation='relu'))
    model.summary()

    model.add(layers.Flatten())
    model.add(layers.Dense(120))
    model.add(layers.Dense(84, activation='relu'))
    model.add(layers.Dense(10))
    model.summary()

    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        optimizer=keras.optimizers.SGD(),
        metrics=[keras.metrics.SparseCategoricalAccuracy()],
    )   

    return model

In [19]:
# model = get_CNNmodel()
# y = model(x_test)

# dlist = [w.shape for w in model.get_weights()]
# print(dlist)
# w_num = len(dlist)
# print(len(model.layers))

Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_18 (Conv2D)          (None, 28, 28, 6)         456       
                                                                 
 conv2d_19 (Conv2D)          (None, 24, 24, 16)        2416      
                                                                 
Total params: 2872 (11.22 KB)
Trainable params: 2872 (11.22 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Model: "sequential_9"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_18 (Conv2D)          (None, 28, 28, 6)         456       
                                                                 
 conv2d_19 (Conv2D)          (None, 24, 24, 16)        2416      
                                                                

In [26]:

global_model = get_CNNmodel()

w_num = 10
T = 5
for t in range(T):
    w_agg = dict(zip(list(range(w_num)), [
        np.zeros((5, 5, 3, 6)), 
        np.zeros((6, )), 
        np.zeros((5, 5, 6, 16)),
        np.zeros((16, )),
        np.zeros((9216, 120)),
        np.zeros((120, )),
        np.zeros((120, 84)),
        np.zeros((84, )),
        np.zeros((84, 10)),
        np.zeros((10, )),
    ]))
    participants_set = np.random.choice(N, m, replace=False)

    for n in participants_set:
        x_train = tf.convert_to_tensor(train_data["user_data"][user_name[n]]['x'])
        y_train = tf.convert_to_tensor(train_data["user_data"][user_name[n]]['y'])

        local_model = client_train(global_model, x_train, y_train)
        for idx in range(w_num):
            w_agg[idx] += local_model.get_weights()[idx]


    for idx in range(w_num):     
        w_agg[idx] = w_agg[idx]/m   

    idx = 0
    for layer in global_model.layers:
        if layer.name.startswith("flatten"):
            continue
        global_model.get_layer(layer.name).set_weights([w_agg[idx], w_agg[idx + 1]])
        idx += 2
        
    test_scores = global_model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", test_scores[0])
    print("Test accuracy:", test_scores[1])



Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_28 (Conv2D)          (None, 28, 28, 6)         456       
                                                                 
 conv2d_29 (Conv2D)          (None, 24, 24, 16)        2416      
                                                                 
Total params: 2872 (11.22 KB)
Trainable params: 2872 (11.22 KB)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________
Model: "sequential_14"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d_28 (Conv2D)          (None, 28, 28, 6)         456       
                                                                 
 conv2d_29 (Conv2D)          (None, 24, 24, 16)        2416      
                                                              

In [55]:
N = 10
m = 3
K = 10
d = 28 * 28
hidden_unit = 64
user_name = ["f_{0:05d}".format(n) for n in range(N)]

x_test = tf.convert_to_tensor(test_data["user_data"]["test"]['x'])
y_test = tf.convert_to_tensor(test_data["user_data"]["test"]['y'])

(784, 64) (64, 10)
