# Split Learning (SL) Allocation Algorithm

## The algorithm

In [1]:
# Reminder of what each state needs:
# state0_SL() needs () and returns (clients, servers, paths)
# state1_SL() needs (clients, servers, paths) and returns (paths)
# OR (clients, servers, paths)
# state2_SL() needs (paths) and returns (best_path)

def SL_allocation_alg(event_processors, event_links, event_clients):
    """Run the Split Learning allocation state machine for a single event.

    The individual state functions communicate through module-level globals,
    so this entry point publishes its arguments as ``processors``, ``links``
    and ``min_clients`` and resets ``state`` / ``previous_state`` before
    driving the states until a terminal state (>= 3) is reached.

    Args:
        event_processors: Processors available for this event.
        event_links: Links available for this event.
        event_clients: Minimum number of clients the allocation must use.

    Returns:
        A five-element list ``[status, server, clients, train_time, runtime]``:
        ``["Success", <best path server>, <best path clients>,
        <best path training time>, runtime]`` when the final state is 3,
        ``["Unfeasible Type A", None, None, None, runtime]`` for state 4 and
        ``["Unfeasible Type B", None, None, None, runtime]`` for state 5.
        ``runtime`` is the elapsed time of the state loop in seconds.
    """
    # Set global variables so all states run correctly
    global processors, links, state, previous_state, min_clients
    processors = event_processors
    links = event_links
    min_clients = event_clients
    state = 0
    previous_state = 0

    # Results handed from one state to the next
    result_state0 = None
    result_state1 = None
    result_state2 = None

    # perf_counter() is monotonic, so the measured runtime cannot be distorted
    # by system clock adjustments the way a time.time() difference can.
    start = time.perf_counter()

    while state < 3:

        if state == 0:
            result_state0 = state0_SL()

        elif state == 1:
            # First visit: feed state1_SL() with the output of state0_SL()
            if previous_state == 0:
                result_state1 = state1_SL(*result_state0)

            # Repeat visit: feed state1_SL() with its own previous output
            elif previous_state == 1:
                result_state1 = state1_SL(*result_state1)

        elif state == 2:
            result_state2 = state2_SL(result_state1)

    runtime = time.perf_counter() - start

    # Return a specific result depending on the final state at the end of the algorithm
    if state == 3:
        return ["Success", result_state2[0], result_state2[1], result_state2[2], runtime]
    elif state == 4:
        return ["Unfeasible Type A", None, None, None, runtime]
    elif state == 5:
        return ["Unfeasible Type B", None, None, None, runtime]

## Each individual state

In [2]:
# State 0: sort the global processors into candidate servers and candidate clients
def state0_SL():
    """Classify the global ``processors`` into candidate servers and clients.

    Reads the module-level ``processors``, ``total_batch_size``, ``max_paths``
    and ``min_clients``, and updates the global ``state`` / ``previous_state``.

    Returns:
        ``(clients, servers, paths)`` (``paths`` is a new empty list) with
        ``state`` set to 1 when at least one server and at least
        ``min_clients`` clients were found. Otherwise ``None``, with ``state``
        set to 4 (no server available) or 5 (too few clients).
    """
    global state, previous_state
    state, previous_state = 0, 0

    # Lists handed over to the following states
    servers, clients, paths = [], [], []

    for candidate in processors:
        free_memory = candidate.residual

        # Anything able to host the server part is also listed as a client;
        # the overlap is resolved in a later state.
        # NOTE(review): "<=" lets the list grow to max_paths + 1 entries before
        # the cap applies - confirm whether "<" was intended.
        can_serve = free_memory >= (candidate.M_server + total_batch_size)
        if can_serve and len(servers) <= max_paths:
            servers.append(candidate)
            clients.append(candidate)
            continue

        # Client-only candidates, capped with the same heuristic.
        can_be_client = free_memory >= (candidate.M_client + total_batch_size)
        if can_be_client and len(clients) <= max_paths:
            clients.append(candidate)

    # Without any server the allocation cannot proceed
    if not servers:
        state = 4
        return None

    # Not enough clients to satisfy the requested minimum
    if len(clients) < min_clients:
        state = 5
        return None

    # Both lists are ready: hand over to state 1 (previous_state stays 0)
    state = 1
    return (clients, servers, paths)

In [3]:
def state1_SL(clients, servers, paths):
    """Evaluate one candidate server and record its training-time path.

    Takes the first entry of ``servers`` as the current server, computes the
    full-epoch time of every other candidate client against it, keeps the
    ``min_clients`` fastest clients and appends
    ``[server name, client names, training time]`` to ``paths``. A server that
    is left with fewer than ``min_clients`` other clients cannot form a path
    and is skipped.

    Reads the module-level ``links``, ``min_clients``, ``total_batches_SL``,
    ``epochs`` and ``time_factor`` plus the ``find_in_list`` helper, and
    updates the global ``state`` / ``previous_state``.

    Args:
        clients: Candidate client processors (not modified).
        servers: Servers still to be evaluated; the evaluated server is
            removed from this list in place.
        paths: Paths computed so far; extended in place.

    Returns:
        ``(clients, servers, paths)`` while servers remain (``state`` stays 1).
        Otherwise just ``paths``, with ``state`` set to 2, or to 5 when no
        server had enough clients to build a single path.
    """
    global previous_state
    global state

    # The server under evaluation is always the head of the pending list
    current_server = servers[0]

    # The current server cannot be one of its own clients. Build a filtered
    # copy instead of removing entries from a list while iterating over it.
    local_clients = [client for client in clients
                     if client.name != current_server.name]

    # state0_SL() counts a server as a client too, so after excluding the
    # current server there may be too few clients left for this path.
    if len(local_clients) >= min_clients:

        T_proc_server = current_server.G_server / current_server.power

        # Full epoch (FP + BP) time for every client-current_server combo
        T_client_serv = []
        for client in local_clients:

            # Obtain bandwidth (link) information
            bandwidth_fetch = find_in_list(links, "link_" + str(client.ID) + str(client.ID)).value
            bandwidth_server = find_in_list(links, "link_" + str(client.ID) + str(current_server.ID)).value

            T_fetch = client.D_client_in / bandwidth_fetch
            T_proc_client = client.G_client / client.power
            T_transf = client.D_client_out / bandwidth_server

            # Both FP AND BP (presumably BP is weighted 1.5x the FP cost, and
            # the transfer happens once in each direction - confirm)
            T_client_serv_value = T_fetch + (1+1.5) * T_proc_client + (1+1.5) * T_proc_server + 2 * T_transf

            T_client_serv.append((client, T_client_serv_value))

        # Keep the min_clients clients with the fastest full epoch times
        clients_to_use = sorted(T_client_serv, key=lambda entry: entry[1])[:min_clients]

        # Time to pass the output weights along the chain of chosen clients;
        # stays zero when there is a single client (no consecutive pairs).
        T_transf_weights = 0
        for (current_client, _), (next_client, _) in zip(clients_to_use, clients_to_use[1:]):
            weights_link = find_in_list(links, "link_"
                                        + str(current_client.ID)
                                        + str(next_client.ID))
            T_transf_weights = T_transf_weights + current_client.D_weights / weights_link.value

        total_full_epoch_time = sum(epoch_time for _, epoch_time in clients_to_use)
        clients_to_use_names = [client.name for client, _ in clients_to_use]

        # Training time of ONE epoch for this path
        total_training_time = total_full_epoch_time * total_batches_SL + T_transf_weights

        # Scale to ALL epochs and convert with time_factor (seconds to hours,
        # per the original notes)
        paths.append([current_server.name,
                      clients_to_use_names,
                      (total_training_time * epochs) / time_factor])

    # This server is done: drop it from the caller's pending list
    servers.remove(current_server)

    previous_state = 1

    # More servers pending: run this state again with the updated lists
    if servers:
        state = 1
        return (clients, servers, paths)

    # Every server was skipped for lack of clients: report it the same way
    # state0_SL() reports "not enough clients" instead of failing in state 2.
    if not paths:
        state = 5
        return paths

    # All servers evaluated: move on to the final state
    state = 2
    return paths

In [4]:
# def state2_SL(clients, aggregators, T_proc, T_transf, T_agg, paths):
def state2_SL(paths):
    """Select the path with the shortest training time and finish the run.

    Args:
        paths: List of finalized paths, each one a list of
            ``[server, client list, train time]``.

    Returns:
        The path whose third element (train time) is the smallest; on ties
        the earliest such path in ``paths`` is returned.

    Sets the global ``previous_state`` to 2 and ``state`` to 3.
    """
    global previous_state, state

    # Compare candidates on their train time only (index 2 of each path)
    fastest = min(paths, key=lambda path: path[2])

    # Program is done
    previous_state, state = 2, 3
    return fastest