# Import necessary libraries

In [35]:
# Standard library
import os

# Pandas and plotting tools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

# Plotting (Box Plots)

In [36]:
%%capture plots # To save RAM

# --- Paths -------------------------------------------------------------------
# Folder holding the simulator CSV results, and folder receiving the PNG plots.
input_directory = "./Results/V5_3/"
output_directory = './Results/Boxplots_V5_3/'

# --- Variables to sweep ------------------------------------------------------
# Must match the lists used in RUN_simulator.ipynb. The full set would be:
#   ['G_base', 'C', 'fraction_server', 'R', 'B', 'client_amount',
#    'D_clients_in', 'D_clients_out', 'D_weights']
X_axis_variables = ['G_base', 'C', 'client_amount', 'B']
# Independent copy, so editing one list does not silently change the other.
Rand_variables = list(X_axis_variables)

# --- Tick values -------------------------------------------------------------
# For every variable, the value associated with each tick index.
X_dict = {
    'G_base': [630, 10000, 20000, 40000, 60480],
    'C': [7365, 11000, 13000, 15000, 19500],
    'fraction_server': [0.51, 0.65, 0.75, 0.85, 0.95],
    'R': [64, 100, 150, 250, 320],
    'B': [1.25, 5, 7.5, 10, 12.5],
    'client_amount': [1, 3, 4, 5, 7],
    'D_clients_in': [0.1, 0.4, 0.6, 0.8, 1.0],
    'D_clients_out': [1, 4, 6, 8, 10],
}
# D_weights uses the same tick values as D_clients_out (stored as its own list).
X_dict['D_weights'] = list(X_dict['D_clients_out'])

# Number of ticks per variable. Every entry of X_dict is expected to have
# exactly this many values.
total_x_ticks = len(X_dict['G_base'])

# --- Plot appearance ---------------------------------------------------------
font_size_axes, font_size_ticks, font_size_legend = 40, 35, 30
color_simulator = 'blue'
line_width = 1.5

# Pre-process and plot
#
# For every ordered pair (X, Rand) of distinct variables and every tick index
# i, read "<input_directory>X_<X>_Rand_<Rand>_<i>.csv", draw one box per
# architecture, annotate the figure with the value X takes at that tick, and
# save it as "<output_directory>X_<X>_Rand_<Rand>_<i>.png".

# Architectures, in the order their boxes appear on the X axis.
architectures = ['PL', 'FL', 'SL', 'PSL', 'FSL']

# Text shown in the annotation box for the parameter held constant in a
# figure: variable name -> (text before the value, unit after the value).
# Variables without an entry (e.g. 'D_clients_in') get no annotation box.
constant_labels = {
    'G_base': ("NN Size = ", " GFLOPs"),
    'M_base': ("NN Size in Memory = ", " GB"),
    'C': ("Processor Power = ", " GFLOP/s"),
    'fraction_server': ("Fraction of NN on Server = ", ""),
    'R': ("Processor Residual Memory = ", " GB"),
    'B': ("Network Bandwidth = ", " GB/s"),
    'client_amount': ("Number of Clients / Split = ", ""),
    # NOTE(review): this label repeats the variable name and a second " = ";
    # kept exactly as it was, but it looks like a copy-paste leftover.
    'D_clients_out': ("Size of Intermediate Results = D_clients_out = ", " GB"),
    'D_weights': ("Size of Weights Matrix = ", " GB"),
}

# Default position of the annotation box, in axes-relative coordinates.
text_box_dim_x = 0.98
text_box_dim_y = 0.95

# One line style shared by every element of the box plot.
line_style = dict(linestyle='-', linewidth=line_width, color=color_simulator)

# savefig() does not create missing folders, so make sure the target exists
# before spending time on the first figure.
os.makedirs(output_directory, exist_ok=True)

for X in tqdm(X_axis_variables):
    for Rand in Rand_variables:
        # A variable is never paired with itself.
        if X == Rand:
            continue

        for i in range(total_x_ticks):
            # Value of X at this tick; shown in the annotation box.
            x_value = X_dict[X][i]

            ################################# PRE-PROCESSING #################################
            file_name = 'X_' + X + '_Rand_' + Rand + '_' + str(i)
            file_path = input_directory + file_name + ".csv"

            # Import the dataframe and drop its first column.
            results_df = pd.read_csv(file_path)
            results_df = results_df.drop(results_df.columns[0], axis=1)

            # Remove all "None feasible" rows. The explicit copy makes the
            # writes below act on an independent frame instead of a slice of
            # results_df (avoids pandas' SettingWithCopyWarning).
            results_filtered = results_df[
                results_df['Best Architecture'] != "None feasible"
            ].copy()

            # Blank out the runtime of every architecture whose outcome is
            # not 'Success'.
            # NOTE(review): this masks the '<arch> Runtime [h]' columns, but
            # the box plot below draws '<arch> Train Time [h]'. Confirm
            # against the CSV schema whether the Train Time columns should be
            # masked instead.
            for arch in architectures:
                failed = results_filtered[arch + ' Outcome'] != 'Success'
                results_filtered.loc[failed, arch + ' Runtime [h]'] = None

            ################################# PLOTTING #################################
            # Open the figure first: DataFrame.boxplot() draws on the current
            # axes, so the labels set through plt.* end up on the same plot.
            fig = plt.figure(figsize=(15, 7.5))
            plt.ylabel("Best Train Time [h]", size=font_size_axes)
            plt.xlabel("Distributed Training Architecture", size=font_size_axes)

            # Pandas' boxplot() is used because plt.boxplot() does not deal
            # nicely with NaN values.
            ax = results_filtered.boxplot(
                column=[arch + ' Train Time [h]' for arch in architectures],
                grid=False,
                fontsize=font_size_ticks,
                showfliers=False,
                boxprops=dict(line_style),
                flierprops=dict(line_style),
                medianprops=dict(line_style),
                whiskerprops=dict(line_style),
                capprops=dict(line_style),
                return_type='axes',
            )

            # Annotation box with the value of the constant parameter X.
            label = constant_labels.get(X)
            if label is not None:
                prefix, unit = label
                if X == "client_amount" and x_value == 1:
                    # Special case: for a single client the box is moved to
                    # the bottom of the axes and left-aligned.
                    pos_x, pos_y, h_align = 0.47, 0.05, 'left'
                else:
                    pos_x, pos_y, h_align = text_box_dim_x, text_box_dim_y, 'right'
                plt.text(pos_x, pos_y,
                         prefix + str(round(x_value, 2)) + unit,
                         bbox=dict(facecolor='white', alpha=0.5),
                         horizontalalignment=h_align,
                         verticalalignment='center',
                         transform=ax.transAxes,
                         fontsize=font_size_legend)

            plt.xticks([1, 2, 3, 4, 5], architectures)
            plt.savefig(fname=output_directory + file_name + '.png', bbox_inches='tight')

            # Release the figure right after saving; otherwise every figure
            # created in the loop stays in memory until the cell finishes.
            plt.close(fig)

print("DONE CREATING PLOTS!")