In [None]:
# Number of client models collected before each averaging step.
CLIENT_COUNT = 3
# Number of averaging rounds the server is meant to run.
ROUNDS = 10

# Endpoint bound by the ROUTER socket: clients send b"New" handshakes and
# their trained models here.
SERVER_ADDRESS = "tcp://*:5555"
# Endpoint bound by the PUB socket: averaged models are broadcast here.
PUBLISHING_ADDRESS = "tcp://*:5557"

# Original (misspelled) names, kept as aliases so existing references work.
SERVER_ADRESS = SERVER_ADDRESS
PUBLISHING_ADRESS = PUBLISHING_ADDRESS

In [None]:
import zmq
import time
import zmq
import pickle
import math
import base64

# Helpers

def convert_size(size_bytes):
    """Return *size_bytes* as a human-readable string in 1024-based units.

    The value is scaled to the largest unit whose size does not exceed
    ``abs(size_bytes)`` (plain bytes for anything under 1 KB), capped at YB,
    and rounded to two decimals, e.g. ``1536`` -> ``"1.5 KB"``.
    ``0`` yields ``"0B"``. Negative and fractional sizes are accepted.
    """
    if size_bytes == 0:
        return "0B"
    size_name = ("B", "KB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB")
    # Choose the unit by repeated division rather than math.log: log() raises
    # on negatives, gives a negative index (-> "YB") for sizes below 1, is
    # subject to float rounding at exact powers of 1024, and would index past
    # size_name for sizes beyond YB.
    magnitude = abs(size_bytes)
    i = 0
    while magnitude >= 1024 and i < len(size_name) - 1:
        magnitude /= 1024
        i += 1
    s = round(size_bytes / math.pow(1024, i), 2)
    return "%s %s" % (s, size_name[i])

def _format_elapsed(start, end):
    """Format the interval ``end - start`` (in seconds) as ``HH:MM:SS.ss``."""
    # Round to the displayed precision *before* splitting into h/m/s, so e.g.
    # 59.996 s prints as "00:01:00.00" rather than "00:00:60.00".
    total = round(end - start, 2)
    hours, rem = divmod(total, 3600)
    minutes, seconds = divmod(rem, 60)
    return "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds)


def elapsed_time_total(start, end):
    """Print the total training time between two ``time.time()`` stamps."""
    print("Total Training Time: " + _format_elapsed(start, end))


def elapsed_time_avg(start, end):
    """Print the model-averaging time between two ``time.time()`` stamps."""
    print("Averaging overhead: " + _format_elapsed(start, end))

def write_data(file_name, data):
    """Write *data* to *file_name* in binary mode.

    Bytes-like payloads (``bytes``, ``bytearray``, ``memoryview``) are stored
    base64-encoded so that :func:`read_data` can decode them back. Any other
    object is passed to ``f.write`` unchanged.
    """
    if isinstance(data, (bytes, bytearray, memoryview)):
        # bytes to base64 -- read_data() always base64-decodes, so every
        # bytes-like payload must be encoded here for the round trip to work.
        data = base64.b64encode(data)

    with open(file_name, 'wb') as f:
        f.write(data)


def read_data(file_name):
    """Return the bytes that :func:`write_data` stored in *file_name*.

    The file contents are base64-decoded; ``base64.b64decode`` may raise
    ``binascii.Error`` (e.g. on incorrect padding) if the file was not
    written as base64.
    """
    with open(file_name, "rb") as f:
        data = f.read()
    # base64 to bytes
    return base64.b64decode(data)

In [None]:
from model_def import SimpleNN

In [None]:
def _average_models(models):
    """Return a new ``SimpleNN`` whose parameters are the mean of *models*' parameters.

    Only ``parameters()`` are averaged; anything else a ``SimpleNN`` carries
    (e.g. registered buffers, if it has any) keeps the value of a freshly
    constructed instance. Assumes every entry in *models* has the same
    parameter layout as ``SimpleNN`` -- TODO confirm for all clients.
    """
    averaged = SimpleNN()
    for avg_param in averaged.parameters():
        avg_param.data *= 0
    for client_model in models:
        for param, avg_param in zip(client_model.parameters(), averaged.parameters()):
            avg_param.data += param.data / len(models)
    return averaged


context = zmq.Context()
socket = context.socket(zmq.ROUTER)
socket.bind(SERVER_ADRESS)
pub_socket = context.socket(zmq.PUB)
pub_socket.bind(PUBLISHING_ADRESS)
print("Context created")


start_total = time.time()

print("The server is running now!")

# Initial model handed to every client that sends the b"New" handshake.
model = SimpleNN()
final_model = None

all_models = []

round_index = 0
try:
    # Run until ROUNDS averaging rounds have completed. Counting messages
    # (ROUNDS * CLIENT_COUNT) stops too early: each b"New" handshake also
    # consumes an iteration, so the last round(s) were never averaged.
    while round_index < ROUNDS:
        identifier, message = socket.recv_multipart()
        print(f"Received request from {identifier}")

        if message == b"New":
            print("New client connected, sending model")
            # NOTE(review): a client joining after round 0 still receives the
            # initial model, not final_model -- confirm this is intended.
            toSend = pickle.dumps(model)
            socket.send_multipart([identifier, toSend])
            print("Model sent")
            continue

        print("Received model from client")
        # SECURITY: pickle.loads can execute arbitrary code embedded in the
        # payload; only acceptable if every peer reaching SERVER_ADRESS is trusted.
        received_model = pickle.loads(message)
        all_models.append(received_model)

        print(f"Have models: {len(all_models)}/{CLIENT_COUNT}")
        if len(all_models) < CLIENT_COUNT:
            continue

        print(f"Averaging models ({round_index})")
        start_avg = time.time()

        # Averaging lives in a helper so its loop variable cannot clobber the
        # global ``model`` that is sent to new clients.
        averaged_model = _average_models(all_models)
        all_models = []

        end_avg = time.time()
        elapsed_time_avg(start_avg, end_avg)

        final_model = averaged_model

        print(f"Sending averaged model ({round_index})")
        toSend = pickle.dumps(averaged_model)
        pub_socket.send(toSend)
        print(f"Averaged model sent ({round_index})")
        round_index += 1

    end_total = time.time()
finally:
    # Release the bound ports so the cell can be re-run. A bounded linger
    # lets the last published model flush without risking a hang in term().
    socket.close(linger=1000)
    pub_socket.close(linger=1000)
    context.term()

elapsed_time_total(start_total, end_total)