In [None]:
# version 13.02 (Perfect Knowledge)
# Single run

In [None]:
import gc
import os
import numpy as np
from math import floor, ceil, pi
from time import time
from tqdm import tqdm
from pickle import dump, load
from copy import deepcopy
from scipy.optimize import minimize
from scipy.signal import butter,filtfilt
rng = np.random.RandomState(2345466)
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten
from tensorflow.keras.models import Sequential, load_model, save_model
from tensorflow.keras.callbacks import ModelCheckpoint
from matplotlib import rcParams,pyplot as plt
%matplotlib inline

In [None]:
# Global matplotlib style: bold 12-pt text everywhere, no LaTeX rendering, and a
# 6 x 3.7082 inch default figure (width/height ~ 1.618, the golden ratio).
params = dict([
    ('axes.labelsize', 12),
    ('axes.labelweight', "bold"),
    ('font.size', 12),
    ('font.weight', "bold"),
    ('legend.fontsize', 12),
    ('xtick.labelsize', 12),
    ('ytick.labelsize', 12),
    ('text.usetex', False),
    ('figure.figsize', [6, 3.7082]),
])
rcParams.update(params)

In [None]:
# Target image size used by load_data, and the experiment topology.
IMAGE_W, IMAGE_H = 256, 256
n_clients = 5                    # number of agents / clients
RIS = 3                          # only referenced by the disabled code in get_region
B = [96, 203, 309, 444, 642]     # data available at each node (one entry per client)

In [None]:
def butter_lowpass_filter(data, cutoff, fs, order):
    """Zero-phase low-pass filter ``data`` with a Butterworth filter.

    Args:
        data: signal to smooth (filtered along its last axis).
        cutoff: cut-off frequency, in the same units as ``fs``.
        fs: sampling frequency; the cut-off is normalised by the Nyquist rate ``fs / 2``.
        order: Butterworth filter order.

    Returns:
        The filtered signal. ``filtfilt`` runs the filter forward and then
        backward, so the output has no phase lag.
    """
    nyquist = 0.5 * fs
    numerator, denominator = butter(order, cutoff / nyquist, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)

In [None]:
class GCB:
    """Two-server reward-allocation game played once per communication round.

    Server A has total reward budget ``RA`` and server B has ``RB``;
    ``D = RA - RB`` is A's surplus. Agent criticality ``q`` is proportional to
    data size times region priority (``B * v``), normalised to sum to 1.
    NOTE(review): the name suggests a General Colonel Blotto game; the closed
    forms below are reproduced from the original code, not re-derived.

    Intended call order: ``get_game_values`` -> ``get_reward_allocation`` ->
    ``play_gcb`` -> ``get_utility`` (-> optionally ``plot_results``).
    """

    # Class-level defaults; instances overwrite them as the methods run.
    RA = 0.0       # server A's total reward budget
    RB = 0.0       # server B's total reward budget
    D = 0.0        # RA - RB
    N = 0          # number of agents
    k = 0          # scale factor in the z_i formula and the arctan utility
    v = 0.0        # per-agent region priority
    B = 0.0        # per-agent data size
    q = 0.0        # normalised agent criticality
    rB = 0.0       # server B's per-agent rewards
    zn_star = 0.0  # chosen z value of the last agent
    zi_star = 0.0  # per-agent reward difference added to rB to form rA
    rA = 0.0       # server A's per-agent rewards
    uA = 0.0       # server A's per-agent utility
    uB = 0.0       # server B's per-agent utility
    du = 0.0       # uA - uB

    def __init__(self, RA, RB, N, k, n_epoch):
        """Store budgets and game size.

        Args:
            RA: total reward server A distributes.
            RB: total reward server B distributes.
            N: number of agents.
            k: scale factor used by ``play_gcb`` and ``get_utility``.
            n_epoch: round index; not used (kept for call compatibility).
        """
        self.RA = RA
        self.RB = RB
        self.D = self.RA - self.RB
        self.N = N
        self.k = k

    def get_game_values(self, v, B):
        """Return agent criticality ``q = B * v / sum(B * v)``.

        Args:
            v: per-agent region priority (length N).
            B: per-agent data size (length N); plain lists are accepted.
        """
        self.v = v
        self.B = B
        # np.multiply also works when both arguments are plain lists.
        weights = np.multiply(self.B, self.v)
        self.q = weights / np.sum(weights)
        return self.q

    def get_reward_allocation(self, allocation_probability):
        """Split server B's budget proportionally to ``allocation_probability``.

        Any surplus/deficit against ``RB`` is removed evenly from all N agents
        so the returned rewards sum to ``RB``.
        """
        self.rB = allocation_probability * self.RB
        rB_diff = np.sum(self.rB) - self.RB
        if rB_diff != 0:
            self.rB = self.rB - rB_diff / self.N
        return self.rB

    def _zi_given_zn(self, zn):
        """Return z_i for agents 0..N-2 given the last agent's z_n.

        z_i = sqrt(z_n^2 * q_i / q_N + |q_i - q_N| / (k^2 * q_N)).
        ``zn`` may be a scalar (result shape ``(N-1,)``) or a 1-D array
        (result shape ``(len(zn), N-1)``).
        """
        q = np.asarray(self.q, dtype=float)
        q_last = q[self.N - 1]
        head = q[:self.N - 1]
        ratio = head / q_last
        offset = np.absolute(head - q_last) / ((self.k * self.k) * q_last)
        zn = np.asarray(zn, dtype=float)
        return np.sqrt(np.multiply.outer(zn * zn, ratio) + offset)

    def play_gcb(self, zn, show_graph):
        """Solve for the reward differences z_i* and server A's rewards.

        Evaluates f(z_n) = z_n + sum_{i<N} z_i(z_n) on the grid ``zn`` and takes
        z_n* as the grid point just after the last point where f(z_n) <= D,
        i.e. where f crosses D. For integer ``D`` and increasing f this is the
        same point the original "last index with ceil(f) == D, plus one" rule
        selected, but it also works for non-integer ``D``.

        Args:
            zn: 1-D grid of candidate z_n values (e.g. ``np.linspace(-10, 10, 500)``).
            show_graph: if True, plot f(z_n) and z_i*.

        Returns:
            (rA, zi_star): server A's per-agent rewards ``rB + z_i*`` shifted
            uniformly so they sum to ``RA``, and the per-agent z_i*.

        Raises:
            ValueError: if f(z_n) does not cross D inside the grid.
        """
        zn = np.asarray(zn, dtype=float)
        fzn = zn + self._zi_given_zn(zn).sum(axis=1)

        at_or_below = np.flatnonzero(fzn <= self.D)
        if at_or_below.size == 0 or at_or_below[-1] + 1 >= len(zn):
            raise ValueError(
                'f(zn) does not cross D=%s inside the zn grid; widen the grid' % (self.D,))
        zn_max = at_or_below[-1] + 1
        # First sign change of f - D (guaranteed to exist once the crossing is
        # bracketed); only used to mark the plot.
        zn_min = np.argwhere(np.diff(np.sign(fzn - self.D))).flatten()[0]
        zn_idx = np.array([zn_min, zn_max])

        if show_graph:
            plt.figure()
            plt.plot(zn, fzn, color="midnightblue")
            plt.plot(zn, [self.D]*len(zn), color="red")
            plt.title('Determining $z_n^*$')
            plt.plot(zn[zn_idx], fzn[zn_idx], 'ro')
            plt.xlabel('zn')
            plt.ylabel('f(zn)')
            plt.xticks(np.linspace(min(zn), max(zn), self.N-1));

        self.zn_star = zn[zn_max]

        self.zi_star = np.zeros(self.N)
        self.zi_star[:self.N - 1] = self._zi_given_zn(self.zn_star)
        self.zi_star[self.N - 1] = self.zn_star

        if show_graph:
            plt.figure()
            plt.plot(range(self.N), self.zi_star, color="limegreen")
            plt.scatter(range(self.N), self.zi_star, color="limegreen")
            plt.title('Determining $z_i^*$')
            plt.xlabel('Agent ID')
            plt.ylabel('Rewards Allocation Difference, $z_i^*$')
            plt.xticks(range(self.N), range(1, self.N + 1));

        self.rA = self.rB + self.zi_star
        rA_diff = np.sum(self.rA) - self.RA
        if rA_diff != 0:
            # Spread the budget mismatch evenly so rA sums to RA.
            self.rA = self.rA - rA_diff / self.N
        return self.rA, self.zi_star

    def get_utility(self):
        """Return per-agent utilities ``(uA, uB, uA - uB)``.

        uA = q / pi * arctan(k * z_i*) + 0.5 and uB = 1 - uA.
        """
        self.uA = (self.q / np.pi) * np.arctan(self.k * self.zi_star) + 0.5
        self.uB = 1 - self.uA
        self.du = self.uA - self.uB
        return self.uA, self.uB, self.du

    @staticmethod
    def _save_figure(path, stem):
        """Save the current figure as ``<stem>.eps`` and ``<stem>.png`` in ``path``."""
        for ext in ('eps', 'png'):
            plt.savefig(os.path.join(path, stem + '.' + ext), format=ext,
                        bbox_inches='tight', dpi=1200)

    def _plot_server_pair(self, values_A, values_B, ylabel):
        """Open a figure with server A (blue) and server B (red) values per agent."""
        agents = range(self.N)
        plt.figure()
        plt.plot(agents, values_A, color="blue", marker='o')
        plt.plot(agents, values_B, color="red", marker='o')
        plt.xlabel('Agent ID')
        plt.ylabel(ylabel)
        plt.legend(['Server A', 'Server B'], loc=0, prop={'size': 11})
        plt.xticks(agents, range(1, self.N + 1))

    def plot_results(self):
        """Save four summary figures (EPS and PNG) into ``<cwd>/graphs``.

        1: criticality q per agent; 2: z_i*; 3: rewards of both servers;
        4: utilities of both servers. Figures are left open.
        """
        # os.path.join instead of the original '\graphs' literal: that was an
        # invalid escape sequence and Windows-only. Elsewhere it created a
        # misnamed folder while the figures were written to a missing 'graphs/'.
        path = os.path.join(os.getcwd(), 'graphs')
        os.makedirs(path, exist_ok=True)
        agents = range(self.N)
        agent_ticks = range(1, self.N + 1)

        # plot 1: agent criticality
        plt.figure()
        plt.bar(agents, self.q, color=plt.get_cmap("Paired")(agents))
        plt.xticks(agents, agent_ticks)
        plt.xlabel('Agent ID')
        plt.ylabel('Agent Criticality')
        self._save_figure(path, '1')

        # plot 2: reward allocation difference z_i*
        plt.figure()
        plt.plot(agents, self.zi_star, color="limegreen", marker='o')
        plt.xlabel('Agent ID')
        plt.ylabel('Rewards Allocation Difference')
        plt.xticks(agents, agent_ticks)
        self._save_figure(path, '2')

        # plot 3: rewards of both servers
        self._plot_server_pair(self.rA, self.rB, 'Rewards Allocation')
        self._save_figure(path, '3')

        # plot 4: utilities of both servers
        self._plot_server_pair(self.uA, self.uB, 'Utility of Servers')
        self._save_figure(path, '4')

In [None]:
def get_location():
    """Return fixed agent positions and region centres, both scaled by 100.

    Returns:
        location: (5, 3) array of (x, y, z) for the five active agents
            (agents 1, 3, 5, 7 and 9 of a ten-agent layout), all at z = 0.5.
        centres: (5, 3) array of (x, y, z) region centres, all at z = 0.
    """
    # Disabled agents of the ten-agent layout, as (x, y):
    # 2 (0.20, 0.0), 4 (-1, 0.37), 6 (-0.8, -0.6), 8 (1, -0.4), 10 (0.8, 0.59).
    active_agents = [
        (-0.63, 0.0),    # agent 1
        (-0.42, 0.49),   # agent 3
        (-1.0, -0.04),   # agent 5
        (0.42, -0.49),   # agent 7
        (1.0, 0.04),     # agent 9
    ]
    location = [(x, y, 0.5) for x, y in active_agents]

    half_height = 0.5773502691896257  # |y| of the four outer centres (~ 1/sqrt(3))
    centres = [
        (0.0, 0.0, 0),            # region 1
        (-1, half_height, 0),     # region 2
        (-1, -half_height, 0),    # region 3
        (1, -half_height, 0),     # region 4
        (1, half_height, 0),      # region 5
    ]
    return 100 * np.array(location), 100 * np.array(centres)

In [None]:
# Fixed agent positions and region centres (get_location already scales by 100).
(location, centres) = get_location()

In [None]:
def gen_poly(r,n):
    """Return the vertices of a regular ``n``-gon of radius ``r`` in the z = 0 plane.

    The first vertex lies on the +x axis; the others follow counter-clockwise.

    Args:
        r: circumradius.
        n: number of vertices. ``n = 0`` yields an empty ``(0, 3)`` array
            (the original raised ZeroDivisionError from an unused ``360/n``).

    Returns:
        ``(n, 3)`` float array of ``(x, y, 0)`` vertices.
    """
    theta = pi * np.linspace(0, 360, num=n, endpoint=False) / 180
    return np.column_stack((r * np.cos(theta), r * np.sin(theta), np.zeros(n)))

In [None]:
# Pseudo positions: one point at (0, 0, 150) followed by a regular 10-gon of
# radius 100 in the z = 0 plane.
pseudo_pos = np.vstack(([0, 0, 150], gen_poly(100, 10)))

In [None]:
# Agent positions (blue) and pseudo positions (red), projected onto the x-y plane.
fig, ax = plt.subplots()
ax.scatter(location[:, 0], location[:, 1], color='blue', label='Agents')
ax.scatter(pseudo_pos[:, 0], pseudo_pos[:, 1], color='red', label='Pseudo positions')
ax.set(title='Agent and pseudo positions (x-y plane)', xlabel='x', ylabel='y')
ax.set_aspect('equal')
ax.legend();

In [None]:
def get_region():
    """Return each agent's region priority ``v`` (perfect-knowledge variant).

    This version assumes every agent's region is known exactly, so the
    priorities are a fixed table of five values doubling from 1/16 to 1.
    An earlier variant estimated each agent's region from noisy locations
    (noise statistics read from a ``result_<RIS>`` pickle); it is disabled here.

    Returns:
        A new float64 array ``[0.0625, 0.125, 0.25, 0.5, 1.0]``.
    """
    return np.array([0.0625, 0.125, 0.25, 0.5, 1.0])

In [None]:
def _drain_generator(generator, IMAGE_W, IMAGE_H):
    """Read one full pass (``len(generator)`` batches) of a Keras iterator into arrays.

    The empty float64 seed arrays keep the original float64 result dtype and
    make an empty directory yield empty arrays instead of failing.
    """
    xs = [np.empty((0, IMAGE_W, IMAGE_H, 3))]
    ys = [np.empty((0,))]
    for _ in range(len(generator)):
        X, y = next(generator)
        xs.append(X)
        ys.append(y)
    # One concatenate at the end: concatenating per batch re-copies every image
    # already loaded, which is quadratic in the dataset size.
    return np.concatenate(xs), np.concatenate(ys)


def load_data(IMAGE_W = 256, IMAGE_H = 256):
    """Load every client's train/test images into memory and build the server generator.

    Reads ``data/Train/agent_<i>/train/`` and ``.../test/`` for i = 1..n_clients
    and ``data/Test/server/`` with ``flow_from_directory`` (binary labels,
    pixels rescaled by 1/255).

    Returns:
        trainX, trainY, testX, testY: lists with one array per client; images
            have shape (samples, IMAGE_W, IMAGE_H, 3).
        server: the (not drained) generator over the server test set.
    """
    dataGen = ImageDataGenerator(rescale=1.0/255.0)
    target_size = (IMAGE_W, IMAGE_H)
    train_generators = list()
    test_generators = list()
    for i in range(n_clients):
        agent_dir = 'data/Train/agent_' + str(i + 1)
        train_generators.append(dataGen.flow_from_directory(agent_dir + '/train/', target_size=target_size, class_mode='binary'))
        test_generators.append(dataGen.flow_from_directory(agent_dir + '/test/', target_size=target_size, class_mode='binary'))
    server = dataGen.flow_from_directory('data/Test/server/', target_size=target_size, class_mode='binary')

    trainX, trainY, testX, testY = list(), list(), list(), list()
    for train_gen, test_gen in zip(train_generators, test_generators):
        trX, trY = _drain_generator(train_gen, IMAGE_W, IMAGE_H)
        trainX.append(trX)
        trainY.append(trY)
        teX, teY = _drain_generator(test_gen, IMAGE_W, IMAGE_H)
        testX.append(teX)
        testY.append(teY)
    return trainX, trainY, testX, testY, server

In [None]:
def _split_class(indices, split, rng):
    """Randomly split ``indices`` into (chosen, rest).

    ``floor(split * len(indices))`` entries are chosen, bumped to 1 when that
    is 0, so ``chosen`` is empty only when ``indices`` itself is empty.
    """
    order = rng.permutation(len(indices))
    n_chosen = floor(split * len(indices))
    if n_chosen == 0:
        n_chosen = 1
    # Index ``indices`` with the permutation directly. The original
    # ``indices[order[order[:n]]]`` composed the permutation with itself, which
    # is not uniformly random (e.g. with two samples the same one always won).
    return indices[order[:n_chosen]], indices[order[n_chosen:]]


def proc_data(dataX, dataY, split = 1, rng=None):
    """Randomly split one client's data between server P and server C, per class.

    For each label (0 and 1; train_model reports them as "Fire" and
    "Neutral"), ``floor(split * count)`` samples go to server C, with a minimum
    of one slot. The rest go to server P. Samples with other labels are dropped.
    Each output set is shuffled.

    Args:
        dataX, dataY: samples and labels, aligned on the first axis.
        split: fraction of each class sent to server C.
        rng: object with a ``permutation`` method (e.g. ``np.random.RandomState``)
            for reproducible splits; defaults to the global ``np.random`` state.

    Returns:
        (XP, YP, XC, YC): server P's share followed by server C's share.
    """
    if rng is None:
        rng = np.random
    fire = np.where(dataY == 0)[0]
    neutral = np.where(dataY == 1)[0]
    fire_C, fire_P = _split_class(fire, split, rng)
    neutral_C, neutral_P = _split_class(neutral, split, rng)

    idxC = np.concatenate((neutral_C, fire_C))
    idxC = idxC[rng.permutation(len(idxC))]
    idxP = np.concatenate((neutral_P, fire_P))
    idxP = idxP[rng.permutation(len(idxP))]
    return dataX[idxP], dataY[idxP], dataX[idxC], dataY[idxC]

In [None]:
def _plot_accuracy_curves(curves, labels, colors, title):
    """Draw one 16x8 accuracy-per-round figure with one line per curve."""
    plt.figure(figsize = (16, 8))
    for curve, label, color in zip(curves, labels, colors):
        plt.plot(curve, color = color, label=label)
    plt.title(title, fontsize=20)
    plt.xlabel("Communication round", fontsize=20)
    plt.xticks(fontsize=14)
    plt.ylabel("Accuracy (%)", fontsize=20)
    plt.yticks(fontsize=14)
    plt.legend(fontsize=14)


def plot_history(l, folder='ver_13_02'):
    """Plot per-agent and aggregated server accuracies saved by ``train_model``.

    Draws four figures:
      1. per-agent accuracy at server P (low-pass filtered);
      2. per-agent accuracy at server C (low-pass filtered);
      3. aggregated accuracy of both servers (raw);
      4. aggregated accuracy of both servers (low-pass filtered).

    Args:
        l: results file name inside ``folder``.
        folder: directory holding the results pickle.

    NOTE: ``pickle.load`` can execute arbitrary code; only open result files
    produced locally by ``train_model``.
    """
    with open(os.path.join(folder, l), 'rb') as fp:
        data = load(fp)
    histories = data['histories']
    # Stored accuracies are fractions (train_model prints 100*acc); convert to
    # percent so the values match the "Accuracy (%)" axis label.
    server_acc = 100 * np.asarray(data['shistory'])

    def smooth(series):
        # Zero-phase Butterworth low-pass: cutoff 0.5, fs 2, order 2.
        return butter_lowpass_filter(series, 0.5, 2, 2)

    agent_labels = ['Agent ' + str(i + 1) for i in range(len(histories))]
    agent_colors = plt.get_cmap("Set1")(range(len(histories)))
    for key, server in (('serP_acc', 'P'), ('serC_acc', 'C')):
        _plot_accuracy_curves(
            [smooth(100 * np.asarray(h[key])) for h in histories],
            agent_labels, agent_colors, "Individual accuracy at the server " + server)

    server_labels = ["Server P", "Server C"]
    server_colors = ['blue', 'red']
    _plot_accuracy_curves([server_acc[:, 0], server_acc[:, 1]],
                          server_labels, server_colors, "Aggregated accuracy at the servers")
    _plot_accuracy_curves([smooth(server_acc[:, 0]), smooth(server_acc[:, 1])],
                          server_labels, server_colors, "Aggregated accuracy at the servers")

In [None]:
def create_model(input_shape=None):
    """Build and compile the small binary CNN each agent trains.

    Architecture: Conv2D(8, 3x3, tanh) -> MaxPool(2x2) -> Conv2D(16, 3x3, tanh)
    -> MaxPool(2x2) -> Flatten -> Dense(10, tanh) -> Dense(1, sigmoid),
    compiled with Adam and binary cross-entropy, tracking accuracy.

    Args:
        input_shape: shape of one input image. Defaults to
            ``(IMAGE_W, IMAGE_H, 3)``, the per-image shape load_data produces.
    """
    if input_shape is None:
        input_shape = (IMAGE_W, IMAGE_H, 3)
    model = Sequential()
    model.add(Conv2D(8, (3, 3), activation='tanh', input_shape=input_shape))
    model.add(MaxPooling2D(2, 2))
    model.add(Conv2D(16, (3, 3), activation='tanh'))
    model.add(MaxPooling2D(2, 2))
    model.add(Flatten())
    model.add(Dense(10, activation='tanh'))
    model.add(Dense(1, activation='sigmoid'))

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    return model

In [None]:
def _aggregate_weights(client_weights, coeffs):
    """Return the ``coeffs``-weighted sum of per-client Keras weight lists, layer by layer."""
    aggregated = list()
    for layer_weights in zip(*client_weights):
        stacked = np.array(layer_weights)
        layer_sum = np.zeros_like(stacked[0])
        for i in range(len(client_weights)):
            layer_sum = layer_sum + coeffs[i] * stacked[i]
        aggregated.append(layer_sum)
    return aggregated


def train_model(trainX, trainY, testX, testY, server_gen, file_name=None, out_dir='ver_13_02'):
    """Run one federated-learning experiment with two competing servers, P and C.

    Each communication round:
      1. The GCB game sets per-agent rewards (rP, rC) and utilities (uP, uC).
      2. Each agent sends a fraction ``f = uC / (uP + uC)`` of its training
         data to server C and the rest to server P.
      3. Each agent trains one epoch of its P and C models, starting from the
         previous round's aggregated weights.
      4. Both servers average client weights, weighted by region priority.
    All per-round results are pickled to ``<out_dir>/<file_name>``.

    Args:
        trainX, trainY, testX, testY: per-client lists from ``load_data``.
        server_gen: generator over the central test set, used for evaluation.
        file_name: results file name. Defaults to the module-level
            ``file_name`` set by the driver cell, as in the original.
        out_dir: directory for the results file; created if missing.

    Returns:
        Path of the written results file.
    """
    if file_name is None:
        # Original behaviour: the driver cell sets a module-level ``file_name``.
        # Resolving it now makes a missing name fail before training, not after.
        file_name = globals()['file_name']
    n = n_clients
    RPT = 1000    # server P's total reward budget over all rounds
    RCT = 800     # server C's total reward budget over all rounds
    n_epoch = 20  # number of communication rounds
    k = 10        # GCB scale factor
    # Per-round (rows) x per-agent (columns) game results, saved at the end.
    q = np.zeros((n_epoch, n_clients))
    rC = np.zeros((n_epoch, n_clients))
    zi_star = np.zeros((n_epoch, n_clients))
    rP = np.zeros((n_epoch, n_clients))
    uP = np.zeros((n_epoch, n_clients))
    uC = np.zeros((n_epoch, n_clients))
    du = np.zeros((n_epoch, n_clients))
    f = np.zeros((n_epoch, n_clients))
    # Each agent keeps one model per server across rounds.
    saved_models_P = [None] * n_clients
    saved_models_C = [None] * n_clients
    # Per-agent history; 'loss', 'acc' and 'val_*' come from the server-P fit.
    histories = [{"loss": list(), "acc": list(), "val_loss": list(), "val_acc": list(),
                  "serP_acc": list(), "serC_acc": list()} for _ in range(n)]
    # Aggregated-model accuracy [server P, server C] per round.
    shistory = list()
    # Aggregated weights from the previous round (set at the end of each round).
    WP = WC = None

    for j in range(n_epoch):
        ## GCB: spread each remaining budget evenly over the remaining rounds.
        RP = round(RPT/(n_epoch - j))
        RC = round(RCT/(n_epoch - j))
        gcb = GCB(RP, RC, n_clients, k, j)
        v = get_region()
        q[j, :] = gcb.get_game_values(v, B)
        allocation_probability = q[j, :]/np.sum(q[j, :])
        rC[j, :] = gcb.get_reward_allocation(allocation_probability)
        zn = np.linspace(-10, 10, 500)
        rP[j, :], zi_star[j, :] = gcb.play_gcb(zn, show_graph=False)
        uP[j, :], uC[j, :], du[j, :] = gcb.get_utility()
        RPT = RPT - RP
        RCT = RCT - RC

        ## Agents
        wP = list()
        wC = list()
        AP = np.zeros(n_clients)
        AC = np.zeros(n_clients)
        # Fraction of each agent's training data sent to server C.
        f[j, :] = uC[j, :]/(uP[j, :] + uC[j, :])
        for i in range(n):
            trainXP, trainYP, trainXC, trainYC = proc_data(trainX[i], trainY[i], split = f[j, i])

            # Server P: continue from last round's aggregated weights.
            if j != 0:
                modelP = saved_models_P[i]
                modelP.set_weights(WP)
            else:
                modelP = create_model()
                print("Sample size (Server P): ", len(trainYP)," Fire: ", np.sum(trainYP == 0)," Neutral: ", np.sum(trainYP == 1))
            history = modelP.fit(trainXP, trainYP, validation_data=(testX[i], testY[i]), batch_size = 32, epochs = 1, verbose = 0)
            saved_models_P[i] = modelP
            wP.append(modelP.get_weights())
            # Local model accuracy on the central server test set.
            _, AP[i] = modelP.evaluate(server_gen, verbose=0)

            # Server C: same procedure on the server-C share of the data.
            if j != 0:
                modelC = saved_models_C[i]
                modelC.set_weights(WC)
            else:
                modelC = create_model()
                print("Sample size (Server C): ", len(trainYC)," Fire: ", np.sum(trainYC == 0)," Neutral: ", np.sum(trainYC == 1))
            modelC.fit(trainXC, trainYC, validation_data=(testX[i], testY[i]), batch_size = 32, epochs = 1, verbose = 0)
            saved_models_C[i] = modelC
            wC.append(modelC.get_weights())
            _, AC[i] = modelC.evaluate(server_gen, verbose=0)

            histories[i]['loss'].append(history.history['loss'][0])
            histories[i]['acc'].append(history.history['accuracy'][0])
            histories[i]['val_loss'].append(history.history['val_loss'][0])
            histories[i]['val_acc'].append(history.history['val_accuracy'][0])
            histories[i]['serP_acc'].append(AP[i])
            histories[i]['serC_acc'].append(AC[i])

        ## Servers: average client weights, weighted by region priority.
        uv = v/np.sum(v)
        WP = _aggregate_weights(wP, uv)
        WC = _aggregate_weights(wC, uv)

        # Evaluate the aggregates using the last agent's models as containers.
        # Every agent's model is reset to WP/WC at the start of the next round.
        modelP.set_weights(WP)
        _, accP = modelP.evaluate(server_gen, verbose=0)
        modelC.set_weights(WC)
        _, accC = modelC.evaluate(server_gen, verbose=0)
        print("Comm Round %d: (Server P) %.3f (Server C) %.3f" % (j, 100*accP, 100*accC))
        shistory.append([accP,accC])

    # Save the results. Create the output directory first so a missing folder
    # does not discard the whole run at the very last step.
    data = {'histories':histories,'shistory':shistory, 'B':B, 'f':f, 'q':q, 'rP': rP, 'rC': rC, 'uP': uP, 'uC': uC, 'du': du, 'zi_star': zi_star}
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, file_name)
    with open(path,'wb') as fp:
        dump(data, fp)
    return path

In [None]:
# Load every client's data into memory, plus the server test-set generator.
# Pass the config-cell image size explicitly so changing it there takes effect.
trainX, trainY, testX, testY, server_gen = load_data(IMAGE_W=IMAGE_W, IMAGE_H=IMAGE_H)

In [None]:
# Repeat the whole experiment four times; each run pickles its results under a
# timestamp-based name. train_model picks up this module-level ``file_name``.
for i in tqdm(range(4)):
    file_name = f'ver_13_02_{round(time())}'
    train_model(trainX, trainY, testX, testY, server_gen)

In [None]:
# Plot the curves of the most recent run (file_name from the last loop iteration).
plot_history(l=file_name)