# Federated Split Learning (FSL) Allocation Algorithm

## The algorithm 

In [265]:
def FSL_allocation_alg(event_processors, event_links, event_clients):
    """Run the FSL allocation state machine for a single event.

    The state functions communicate through module-level globals. This
    function therefore first publishes its inputs as ``processors``,
    ``links`` and ``min_clients``. It also resets ``state`` and
    ``previous_state`` to 0.

    State flow:
        0 -> state0_FSL() returns (param_servers, edge_servers, clients)
        1 -> state1_FSL(param_servers, edge_servers, clients) returns best_path
        3 -> success
        4 / 5 / 6 / 7 -> "Unfeasible Type A" / "B" / "C" / "D"

    Args:
        event_processors: Processors available in this event.
        event_links: Link objects that the state functions look up.
        event_clients: Requested number of clients (stored as ``min_clients``).

    Returns:
        For success: ``["Success", *best_path[0:4], runtime]``.
        For an unfeasible outcome: ``[label, None, None, None, None, runtime]``.
        ``runtime`` is the elapsed time in seconds spent in the state machine.
        ``None`` is returned if the machine stops in any other state.

    Raises:
        RuntimeError: If a state function leaves ``state`` unchanged. Without
            this check the loop below would spin forever.
    """
    # Set global variables so all states run correctly.
    global processors, links, state, previous_state, min_clients
    processors = event_processors
    links = event_links
    min_clients = event_clients
    state = 0
    previous_state = 0

    result_state0 = None
    # Holds the result of state1_FSL(), which is needed after the loop.
    result_state1 = None

    # perf_counter() is monotonic, unlike time.time(). The measured runtime
    # therefore cannot be skewed by system clock adjustments.
    start = time.perf_counter()

    while state < 3:
        state_before = state
        if state == 0:
            result_state0 = state0_FSL()
        elif state == 1 and previous_state == 0:
            # Only reached right after state0_FSL() succeeded.
            result_state1 = state1_FSL(*result_state0)
        if state == state_before:
            # No state function ran or advanced the machine; bail out instead
            # of looping forever.
            raise RuntimeError(f"FSL allocation made no progress in state {state}")

    runtime = time.perf_counter() - start

    # Return a specific result depending on the final state of the algorithm.
    if state == 3:
        return ["Success", result_state1[0], result_state1[1], result_state1[2],
                result_state1[3], runtime]

    unfeasible_labels = {
        4: "Unfeasible Type A",
        5: "Unfeasible Type B",
        6: "Unfeasible Type C",
        7: "Unfeasible Type D",
    }
    label = unfeasible_labels.get(state)
    if label is None:
        return None
    return [label, None, None, None, None, runtime]

## Each individual state

In [263]:
def state0_FSL():
    """State 0 of the FSL allocation state machine: build the role candidate lists.

    Reads the module-level globals ``processors``, ``min_clients`` and
    ``total_batch_size``. Each processor is classified as a possible
    parameter server, edge server and/or client based on its ``residual``
    memory. A processor may appear in several lists; overlaps are resolved
    later in state 1.

    Side effects:
        Sets the global ``previous_state`` to 0. Sets the global ``state``
        to the next state:
          * 1 on success;
          * 4 if no processor can act as an edge server;
          * 5 if fewer than ``min_clients`` processors can act as clients;
          * 6 if there are too few processors to split the work
            ``min_clients`` ways.

    Returns:
        ``(param_servers, edge_servers, clients)`` when moving to state 1,
        otherwise ``None``.
    """
    global state, previous_state
    state = 0
    previous_state = 0

    # Heuristic cap so we stop collecting once "just the right amount" of
    # processors per role has been found. Presumably sized as
    # min_clients edges + min_clients clients + 1 parameter server.
    cap = (min_clients * 2) + 1

    param_servers = []
    edge_servers = []
    clients = []
    for processor in processors:
        if processor.residual >= processor.M_agg and len(param_servers) < cap:
            param_servers.append(processor)
        if processor.residual >= (processor.M_server + total_batch_size) and len(edge_servers) < cap:
            edge_servers.append(processor)
        if processor.residual >= (processor.M_client + total_batch_size) and len(clients) < cap:
            clients.append(processor)

    # Edge servers do the heaviest processing and need the most residual
    # memory, so their absence alone makes the allocation impossible.
    if not edge_servers:
        previous_state = 0
        state = 4
        return None

    # Not enough processors can act as clients.
    if len(clients) < min_clients:
        previous_state = 0
        state = 5
        return None

    # A proper split is possible if min_clients <= half of the processors
    # minus 1, OR in the special case of 3 processors with one requested
    # split. The program ends only when NEITHER holds. (The previous
    # "!= 1 and != 3" test applied De Morgan incorrectly. It let e.g.
    # 3 processors with 2 clients, or a single processor, through.)
    n_processors = len(processors)
    general_split_ok = min_clients <= (n_processors / 2 - 1)
    three_processor_single_split = min_clients == 1 and n_processors == 3
    if not (general_split_ok or three_processor_single_split):
        previous_state = 0
        state = 6
        return None

    # Candidate lists are complete; hand over to state 1.
    previous_state = 0
    state = 1
    return (param_servers, edge_servers, clients)

In [1]:
def _link_value(src, dst):
    """Return ``.value`` of the link named ``"link_<src.ID><dst.ID>"`` in the global ``links``.

    NOTE(review): the IDs are concatenated without a separator, which is
    ambiguous once IDs reach two digits (e.g. 1+12 vs 11+2). Confirm the
    link naming scheme.
    """
    return find_in_list(links, "link_" + str(src.ID) + str(dst.ID)).value


def state1_FSL(param_servers, edge_servers, clients):
    """State 1 of the FSL allocation state machine: choose the best path.

    1. Scores every edge-server/client combination whose processors differ.
       The score is an algorithm-only time estimate.
    2. Greedily keeps the best-scoring pairs, never reusing a processor, up
       to ``min_clients`` pairs (at least one pair is always kept).
    3. From the parameter servers not used by any kept pair, picks the one
       with the largest summed link value from the kept edges. At most
       ``max_paths`` servers are evaluated.
    4. Computes the training time. Pairs run in parallel, so only the
       straggler (slowest pair) counts. Aggregation happens once per epoch.

    Reads the module-level globals ``links``, ``min_clients``, ``max_paths``,
    ``total_batches_FSL``, ``epochs``, ``correction_factor_parall`` and
    ``time_factor``, and calls ``find_in_list``.

    Side effects:
        Sets the globals ``previous_state`` to 1 and ``state`` to:
          * 3 on success;
          * 6 if no edge-client pair with distinct processors exists;
          * 7 if no parameter server could be evaluated.

    Returns:
        On success, ``best_path = [param_server_name, edge_server_names,
        client_names, total_time]``. ``total_time`` is divided by
        ``time_factor``; per the original notes, this expresses it in hours
        over all epochs. Returns ``None`` otherwise.
    """
    global previous_state
    global state

    # 1. Score ALL edge-client combinations. alg_time is only used to rank
    #    pairs; it is not the reported training time.
    candidate_pairs = []
    for edge_server in edge_servers:
        for client in clients:
            if edge_server.name == client.name:
                continue
            alg_time = (
                client.D_client_in / _link_value(client, client)
                + client.G_client / client.power
                + client.D_client_out / _link_value(client, edge_server)
                + edge_server.G_server / edge_server.power
            )
            candidate_pairs.append([edge_server, client, alg_time])

    # Without any pair of distinct processors there is nothing to split.
    # Previously this crashed with IndexError below.
    if not candidate_pairs:
        previous_state = 1
        state = 6
        return None

    # 2. Best (smallest) alg_time first. The sort is stable, so ties keep
    #    generation order.
    candidate_pairs.sort(key=lambda pair: pair[2])

    # Greedily take the best pairs so that no processor appears in two pairs,
    # whether as edge or client (e.g. 2-1 and 3-2 may not coexist).
    pairs_to_use = [candidate_pairs[0]]
    used_processors = [candidate_pairs[0][0], candidate_pairs[0][1]]
    for candidate in candidate_pairs:
        if len(pairs_to_use) >= min_clients:
            break
        edge_server, client = candidate[0], candidate[1]
        if edge_server not in used_processors and client not in used_processors:
            pairs_to_use.append(candidate)
            used_processors.extend((edge_server, client))
    # NOTE(review): pairs_to_use can still hold fewer than min_clients pairs
    # (e.g. only one processor fits as an edge server). Confirm whether that
    # should be reported as unfeasible.

    # 3. A processor already chosen as edge or client cannot also be the
    #    parameter server.
    used_names = {pair[0].name for pair in pairs_to_use} | {pair[1].name for pair in pairs_to_use}
    available_param_servers = [server for server in param_servers if server.name not in used_names]

    # Sum, for each candidate server, the link values from every chosen edge.
    # Every edge sends the same D_weights, so only the total bandwidth matters.
    edge_param_Bs = []
    for server in available_param_servers:
        if len(edge_param_Bs) >= max_paths:
            break
        total_B_edge_param = sum(_link_value(pair[0], server) for pair in pairs_to_use)
        edge_param_Bs.append([server, total_B_edge_param])

    # No feasible parameter server, or max_paths <= 0 (which previously made
    # max() raise ValueError).
    if not edge_param_Bs:
        previous_state = 1
        state = 7
        return None

    # max() keeps the first maximum on ties.
    best_param_server = max(edge_param_Bs, key=itemgetter(1))[0]

    # 4. Pairs run in parallel, so one step takes as long as the straggler.
    #    The (1 + 1.5) factors presumably cover the forward pass plus a
    #    backward pass of ~1.5x its cost, and the 2.0 factor a round trip.
    #    TODO confirm.
    T_train_agg = []
    best_edge_servers = []
    best_clients = []
    for edge, client, _alg_time in pairs_to_use:
        T_train_agg.append(
            client.D_client_in / _link_value(client, client)
            + (1 + 1.5) * client.G_client / client.power
            + 2.0 * client.D_client_out / _link_value(client, edge)
            + (1 + 1.5) * edge.G_server / edge.power
        )
        best_edge_servers.append(edge.name)
        best_clients.append(client.name)
    T_train_agg_straggler_time = max(T_train_agg)

    # Aggregation happens ONCE per epoch, so it sits outside the per-batch term.
    # The weight exchange uses the edge of the best-ranked pair.
    first_edge = pairs_to_use[0][0]
    T_train_total = (
        T_train_agg_straggler_time * total_batches_FSL
        + 2.0 * first_edge.D_weights / _link_value(first_edge, best_param_server)
        + best_param_server.G_agg / best_param_server.power
    )

    # Scale one epoch to all epochs. Per the original notes, dividing by
    # time_factor converts seconds to hours.
    parall_factor_FSL = 0.9
    best_path = [
        best_param_server.name,
        best_edge_servers,
        best_clients,
        (T_train_total * epochs * correction_factor_parall * parall_factor_FSL) / time_factor,
    ]

    previous_state = 1
    state = 3
    return best_path