In [1]:
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import LSTM, Dense
import numpy as np
import socket
import struct
import pickle
from threading import Thread, Lock
import time
import os
import h5py

In [2]:
def create_lstm_model():
    model = Sequential([
        LSTM(200, input_shape=(10, 1), activation='relu'),  # Adjust input_shape based on your dataset
        Dense(100, activation='relu'),
        Dense(10)
    ])
    model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])
    return model

In [3]:
def send_msg(sock, msg):
    msg = pickle.dumps(msg)
    msg = struct.pack('>I', len(msg)) + msg
    sock.sendall(msg)

def recv_msg(sock):
    raw_msglen = recvall(sock, 4)
    if not raw_msglen:
        return None
    msglen = struct.unpack('>I', raw_msglen)[0]
    return pickle.loads(recvall(sock, msglen))

def recvall(sock, n):
    data = bytearray()
    while len(data) < n:
        packet = sock.recv(n - len(data))
        if not packet:
            return None
        data.extend(packet)
    return data


In [4]:
def federated_average(models_weights):
    """Compute the federated average of the models' weights."""
    # Ensure there are models to average
    if not models_weights:
        raise ValueError("No model weights to average")

    # Initialize average weights as the structure of the first model's weights, filled with zeros
    average_weights = [np.zeros_like(weights) for weights in models_weights[0]]

    # Sum up all model weights
    for weights in models_weights:
        for i, weight in enumerate(weights):
            average_weights[i] += weight

    # Divide by the number of models to get the average
    for i in range(len(average_weights)):
        average_weights[i] /= len(models_weights)

    return average_weights


def client_handler(connection, client_id, global_model):
    # Send global model weights
    weights = global_model.get_weights()
    send_msg(connection, weights)
    
    # Receive updated weights from client
    updated_weights = recv_msg(connection)
    print(f"Received updated weights from client {client_id}")
    
    connection.close()
    return updated_weights

def run_server(host='localhost', port=12345, clients=1):
    server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server_socket.bind((host, port))
    server_socket.listen(clients)
    print(f"Server listening on {host}:{port}")

    client_threads = []
    client_models = []
    global_model = create_lstm_model()
    
    for i in range(clients):
        connection, address = server_socket.accept()
        print(f"Connected to client {address}")
        client_thread = Thread(target=client_handler, args=(connection, i, global_model))
        client_threads.append(client_thread)
        client_thread.start()

    for thread in client_threads:
        thread.join()

    # Assuming client_handler returns and appends updated weights to client_models
    global_weights = federated_average(client_models)
    global_model.set_weights(global_weights)
    print("Updated global model with averaged weights")

    # Optionally, save the global model
    global_model.save('global_model.h5')

run_server()

Server listening on localhost:12345
Connected to client ('127.0.0.1', 54493)
Received updated weights from client 0


ValueError: No model weights to average