### Imports

In [18]:
import os
import h5py

import socket
import struct
import pickle
import sys

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Reshape, Flatten, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import initializers

from threading import Thread
from threading import Lock


import time

from tqdm import tqdm

import copy

### Required socket functions

In [19]:
def send_msg(sock, msg):
    """Pickle ``msg`` and send it over ``sock`` as a single length-prefixed frame.

    Frame layout: a 4-byte unsigned big-endian (network order) length header,
    followed by the pickled payload.

    Returns:
        The number of payload bytes sent; the 4-byte header is not counted.
    """
    payload = pickle.dumps(msg)
    payload_len = len(payload)
    header = struct.pack('>I', payload_len)
    sock.sendall(header + payload)
    return payload_len

def recv_msg(sock):
    """Receive one length-prefixed, pickled message written by ``send_msg``.

    Reads the 4-byte big-endian length header, then exactly that many payload
    bytes, and unpickles the payload.

    Returns:
        A ``(message, payload_length)`` tuple, or ``None`` if the peer closed
        the connection before a complete frame (header *or* payload) arrived.

    Security:
        ``pickle.loads`` can execute arbitrary code while unpickling; only use
        this with trusted peers.
    """
    # read message length and unpack it into an integer
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    # read the message data
    payload = recvall(sock, msglen)
    if payload is None:
        # Connection dropped mid-frame: report it like EOF on the header
        # instead of letting pickle.loads(None) raise a TypeError.
        return None
    return pickle.loads(payload), msglen

def recvall(sock, n):
    """Read exactly ``n`` bytes from ``sock``.

    Each ``recv`` asks only for the bytes still missing, so it never reads
    past the end of the current frame.

    Returns:
        The ``n`` bytes as a ``bytes`` object, or ``None`` if the peer closes
        the connection (``recv`` returns an empty chunk) before ``n`` bytes
        have arrived.
    """
    # A bytearray grows in place. ``bytes += chunk`` would copy the whole
    # buffer on every recv, which is quadratic for multi-MB weight payloads.
    buf = bytearray()
    while len(buf) < n:
        packet = sock.recv(n - len(buf))
        if not packet:
            return None
        buf.extend(packet)
    return bytes(buf)

### FedAvg

In [20]:
def average_weights(w, datasize):
    """Return the FedAvg (data-size-weighted) average of client model weights.

    Args:
        w: One entry per client. Each entry is a list of per-layer weight
            arrays (e.g. as returned by ``Model.get_weights()``); all clients
            are expected to share the same layer shapes.
        datasize: One weight per client (the client's reported dataset size),
            in the same order as ``w``.

    Returns:
        A new list of numpy arrays, one per layer, where layer ``j`` is
        ``sum_i(datasize[i] * w[i][j]) / sum(datasize)``. Each result keeps the
        dtype of client 0's corresponding layer. ``w`` is NOT modified; the
        previous version scaled the caller's arrays in place.

    Raises:
        ValueError: if ``w`` is empty, ``w`` and ``datasize`` differ in length,
            or the data sizes sum to zero (the average would be inf/NaN).
    """
    if len(w) != len(datasize):
        raise ValueError(
            "got weights for {} clients but {} data sizes".format(len(w), len(datasize)))
    if not w:
        raise ValueError("no client weights to average")
    total = float(sum(datasize))
    if total == 0:
        raise ValueError("total data size is zero; cannot compute a weighted average")

    sizes = [float(size) for size in datasize]
    averaged = []
    # zip(*w) walks layer by layer: `layers` holds layer j from every client.
    for layers in zip(*w):
        reference = np.asarray(layers[0])
        # Accumulate in float64 to limit rounding, then restore the layer dtype.
        weighted_sum = sum(
            np.asarray(layer, dtype=np.float64) * size
            for layer, size in zip(layers, sizes)
        )
        averaged.append((weighted_sum / total).astype(reference.dtype))
    return averaged


### Accept client connections and run the per-client training threads

In [21]:
def run_thread(func, num_user):
    """Accept ``num_user`` clients on the listening socket ``s`` and serve each
    one in its own thread running ``func(client_index, num_user, conn)``.

    Each accepted connection is stored in ``clientsoclist`` at its index. The
    global ``start_time`` is set once every client has been accepted, and the
    elapsed wall-clock time is printed after all worker threads finish.
    """
    global clientsoclist
    global start_time

    workers = []
    for client_idx in range(num_user):
        client_conn, client_addr = s.accept()
        print('Conntected with', client_addr)
        # remember the socket so the global model can be broadcast to it later
        clientsoclist[client_idx] = client_conn
        worker = Thread(target=func, args=(client_idx, num_user, client_conn))
        workers.append(worker)
        worker.start()

    print("timmer start!")
    start_time = time.time()
    for worker in workers:
        worker.join()
    finished_at = time.time()
    print("TrainingTime: {} sec".format(finished_at - start_time))

In [22]:
def train(userid, train_dataset_size, num_users, client_conn):
    """Run the server side of ``rounds`` federated rounds for one client.

    Runs in one thread per client. Each round:
      1. Under ``lock``: if every client has reported (``weight_count ==
         num_users``), the first thread to get here sends ``global_weights``
         to every socket in ``clientsoclist`` and resets ``weight_count``.
      2. Block until this client's updated weights arrive and store them in
         ``weights_list[userid]``.
      3. Under ``lock``: count the report. The last client to report for the
         round replaces ``global_weights`` with the data-size-weighted average.

    Args:
        userid: index of this client in ``clientsoclist`` / ``weights_list``.
        train_dataset_size: this client's reported dataset size. It is not read
            here; the averaging uses the shared ``datasetsize`` list.
        num_users: total number of participating clients.
        client_conn: socket connected to this client.

    Raises:
        ConnectionError: if the client disconnects before sending its weights.
    """
    global global_weights
    global weight_count

    for r in range(rounds):
        with lock:
            # Everyone reported in the previous step (the dataset-size handshake
            # in `receive` for round 1, the previous round's weights later), so
            # broadcast the current global model once and reset the counter.
            if weight_count == num_users:
                for i, conn in enumerate(clientsoclist):
                    datasize = send_msg(conn, global_weights)
                    total_sendsize_list.append(datasize)
                    client_sendsize_list[i].append(datasize)
                    train_sendsize_list.append(datasize)
                weight_count = 0

        reply = recv_msg(client_conn)
        if reply is None:
            # recv_msg signals a closed connection with None. Fail with a clear
            # message instead of a TypeError from tuple unpacking. Other client
            # threads are not notified, so the run will likely stall.
            raise ConnectionError(
                "User{} disconnected before sending its round {} weights".format(userid, r + 1))
        client_weights, datasize = reply
        total_receivesize_list.append(datasize)
        client_receivesize_list[userid].append(datasize)
        train_receivesize_list.append(datasize)

        weights_list[userid] = client_weights
        print("User" + str(userid) + "'s Round " + str(r + 1) + " is done")
        with lock:
            weight_count += 1
            if weight_count == num_users:
                # Last client of the round aggregates (FedAvg).
                global_weights = average_weights(weights_list, datasetsize)
                #train_discriminator(epochs, batch_size)  # disabled; see the commented-out cell
    

In [23]:
def receive(userid, num_users, conn):
    """Per-client thread entry point: handshake, then run ``train``.

    Sends the training configuration (``rounds``, client id, ``local_epoch``)
    to the client and waits for the client's dataset size. Records it in
    ``datasetsize[userid]`` and counts the client in ``weight_count``, so the
    first global-weight broadcast in ``train`` fires once every client has
    reported. Then runs ``train`` for this client on the same socket.

    Raises:
        ConnectionError: if the client disconnects before reporting its
            dataset size.
    """
    global weight_count

    msg = {
        'rounds': rounds,
        'client_id': userid,
        'local_epoch': local_epoch
    }

    datasize = send_msg(conn, msg)    # send the training configuration
    total_sendsize_list.append(datasize)
    client_sendsize_list[userid].append(datasize)

    # The client replies with the size of its training data. The original note
    # called it "total_batch"; TODO confirm whether it is a sample or batch
    # count. Either way it is used as this client's FedAvg weight.
    reply = recv_msg(conn)
    if reply is None:
        # recv_msg signals a closed connection with None; fail clearly instead
        # of raising a TypeError from tuple unpacking.
        raise ConnectionError(
            "User{} disconnected before reporting its dataset size".format(userid))
    train_dataset_size, datasize = reply
    total_receivesize_list.append(datasize)
    client_receivesize_list[userid].append(datasize)

    with lock:
        datasetsize[userid] = train_dataset_size
        weight_count += 1

    train(userid, train_dataset_size, num_users, conn)

In [24]:
# # Train only the GAN discriminator
# def train_discriminator(epochs=1, batch_size=128):
#     batch_count = x_train.shape[0] // batch_size

#     for e in range(epochs):
#         for _ in range(batch_count):
#             image_batch = x_train[np.random.randint(0, x_train.shape[0], size=batch_size)]
#             noise = np.random.normal(0, 1, size=[batch_size, latent_dim])
#             generated_images = generator.predict(noise)

#             X = np.concatenate([image_batch, generated_images])
#             y_dis = np.zeros(2 * batch_size)
#             y_dis[:batch_size] = 0.9  # Smoothed (soft) labels for more stable training

#             # Treina o discriminador
#             d_loss = discriminator.train_on_batch(X, y_dis)

#         print(f'Época {e+1}/{epochs}, Discriminador Loss: {d_loss}')

### Init

In [25]:
# Federated-training configuration (`rounds` and `local_epoch` are sent to each
# client during the handshake in `receive`).
rounds = 3
local_epoch = 30
users = 2  # number of clients

latent_dim = 100  # length of the generator's input noise vector

# Weight initializer shared by all GAN layers.
initializer = initializers.RandomNormal(mean=0.0, stddev=0.02)

# Build the generator: latent_dim noise -> 784 tanh outputs in [-1, 1]
# (presumably flattened 28x28 MNIST images; confirm against the clients).
generator = Sequential()
generator.add(Dense(256, input_dim=latent_dim, kernel_initializer=initializer))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))
generator.add(Dense(512, kernel_initializer=initializer))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))
generator.add(Dense(1024, kernel_initializer=initializer))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))
generator.add(Dense(784, activation='tanh', kernel_initializer=initializer))
# Use `learning_rate` like the discriminator below; `lr` is a deprecated alias
# that newer Keras releases no longer accept.
generator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# Build the discriminator: 784 inputs -> a single sigmoid real/fake score.
discriminator = Sequential()
discriminator.add(Dense(1024, input_dim=784, kernel_initializer=initializer))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(512, kernel_initializer=initializer))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(256, kernel_initializer=initializer))
discriminator.add(LeakyReLU(alpha=0.2))
discriminator.add(Dense(1, activation='sigmoid', kernel_initializer=initializer))
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))


# ---- Shared server state, read and written by the per-client threads ----
clientsoclist = [0]*users  # accepted client sockets, filled in by run_thread

start_time = 0    # set by run_thread once every client is connected
weight_count = 0  # clients that have reported in the current step; guarded by `lock`

# The global model starts from the freshly initialised generator weights.
global_weights = copy.deepcopy(generator.get_weights())

datasetsize = [0]*users   # per-client dataset size, used as the FedAvg weight
weights_list = [0]*users  # latest weights received from each client

lock = Lock()

# Communication accounting: pickled payload sizes in bytes, one entry per message.
total_sendsize_list = []
total_receivesize_list = []

client_sendsize_list = [[] for _ in range(users)]
client_receivesize_list = [[] for _ in range(users)]

train_sendsize_list = []
train_receivesize_list = []



In [26]:
# Address the server listens on. The hostname may resolve to a loopback
# address (it printed 127.0.1.1 here); in that case only clients on this
# machine can connect.
host, port = socket.gethostbyname(socket.gethostname()), 10080
print(host)

127.0.1.1


In [27]:
# TCP listening socket for the federated clients.
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
# Allow rebinding the port right after a kernel restart, while connections from
# the previous run are still in TIME_WAIT ("Address already in use" otherwise).
# Restricted to POSIX: on Windows, SO_REUSEADDR also lets a socket take over a
# port that is actively in use.
if os.name == "posix":
    s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
s.bind((host, port))
s.listen(5)

In [28]:
run_thread(func=receive, num_user=users)  # blocks until every client finishes all rounds

Conntected with ('127.0.0.1', 38612)
Conntected with ('127.0.0.1', 44200)
timmer start!
User1's Round 1 is done
User0's Round 1 is done
User1's Round 2 is done
User0's Round 2 is done
User0's Round 3 is done
User1's Round 3 is done
TrainingTime: 682.3190906047821 sec


In [29]:
# Wall-clock time from the start of training until THIS cell ran. Unlike the
# figure printed by run_thread, it also includes any pause before this cell.
end_time = time.time()
print(f"TrainingTime: {end_time - start_time} sec")

TrainingTime: 682.3379547595978 sec


### Print total communication overhead (bytes sent and received)

In [30]:
def print_size_total(header: str, label: str, sizes) -> None:
    """Print ``header``, then ``"<label>: <sum of sizes> bytes"``, then a blank block.

    ``sizes`` are the per-message payload lengths recorded by the server. They
    come from ``send_msg`` / ``recv_msg``, so the 4-byte frame header of each
    message is not included.
    """
    print(header)
    print("{}: {} bytes".format(label, sum(sizes)))
    print('\n')


print('\n')
print_size_total('---total_sendsize_list---', "total_sendsize size", total_sendsize_list)
print_size_total('---total_receivesize_list---', "total receive sizes", total_receivesize_list)

for i in range(users):
    print_size_total('---client_sendsize_list(user{})---'.format(i),
                     "total client_sendsizes(user{})".format(i),
                     client_sendsize_list[i])
    print_size_total('---client_receivesize_list(user{})---'.format(i),
                     "total client_receive sizes(user{})".format(i),
                     client_receivesize_list[i])

print_size_total('---train_sendsize_list---', "total train_sendsizes", train_sendsize_list)
print_size_total('---train_receivesize_list---', "total train_receivesizes", train_receivesize_list)




---total_sendsize_list---
total_sendsize size: 35850378 bytes


---total_receivesize_list---
total receive sizes: 35849598 bytes


---client_sendsize_list(user0)---
total client_sendsizes(user0): 17925189 bytes


---client_receivesize_list(user0)---
total client_receive sizes(user0): 17924799 bytes


---client_sendsize_list(user1)---
total client_sendsizes(user1): 17925189 bytes


---client_receivesize_list(user1)---
total client_receive sizes(user1): 17924799 bytes


---train_sendsize_list---
total train_sendsizes: 35850264 bytes


---train_receivesize_list---
total train_receivesizes: 35849568 bytes


