# Plot grid search

In [None]:
# Experiment folder holding every log file used for this plot, and the dataset
# tag to look for in the log file names.
# Uncomment exactly one (exp_id, dataset) pair to choose what gets plotted.
exp_id, dataset = 'test_synthetic_iid', 'synthetic_iid'
# exp_id, dataset = 'test_synthetic_0_0', 'synthetic_0_0'
# exp_id, dataset = 'test_synthetic_0.5_0.5', 'synthetic_0.5_0.5'
# exp_id, dataset = 'test_synthetic_1_1', 'synthetic_1_1'
# exp_id, dataset = 'test_FEMNIST', 'FEMNIST'

# Directory containing the CSV logs of the selected experiment.
log_folder = "../logs/{}".format(exp_id)

# Algorithm identifiers (matched as substrings of log file names) and the
# legend label shown for each, in the same order.
algs = ['fedavg', 'fedprox', 'fedpd', 'feddr']
legend_list = ['FedAvg', 'FedProx', 'FedPD', 'FedDR']

# Dataset tag -> human-readable name used as the figure title.
dataname_dict = {
    'synthetic_iid': 'synthetic-iid',
    'synthetic_0_0': 'synthetic-(0,0)',
    'synthetic_0.5_0.5': 'synthetic-(0.5,0.5)',
    'synthetic_1_1': 'synthetic-(1,1)',
    'FEMNIST': 'FEMNIST',
}

# Import libraries and some support functions

In [None]:
import os
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import matplotlib.pylab as pylab
# Global matplotlib styling: all axis labels, axis titles and tick labels share
# one 12pt size; legend, figure title and figure size are set individually.
params = dict.fromkeys(
    ('axes.labelsize', 'axes.titlesize', 'xtick.labelsize', 'ytick.labelsize'),
    '12',
)
params.update({
    'legend.fontsize': '14',
    'figure.figsize': (15, 5),
    'figure.titlesize': '16',
})
pylab.rcParams.update(params)

def get_file(log_folder, alg_name, dataset):
    """Return the names of files in *log_folder* mentioning *alg_name* and *dataset*.

    Matching is plain substring containment on the file name. The result is
    sorted by name because ``os.listdir`` returns entries in arbitrary order;
    without sorting, a caller that takes ``[0]`` could pick a different file
    on different runs.

    Args:
        log_folder: directory to scan (not recursive).
        alg_name: substring identifying the algorithm, e.g. ``'fedavg'``.
        dataset: substring identifying the dataset, e.g. ``'synthetic_iid'``.

    Returns:
        Sorted list of matching entry names (bare names, not joined paths);
        empty if nothing matches.
    """
    return sorted(f for f in os.listdir(log_folder) if alg_name in f and dataset in f)

def read_csv(file_name):
    """Load the CSV log at *file_name* into a pandas DataFrame.

    Args:
        file_name: path to the CSV file.

    Returns:
        The DataFrame produced by ``pd.read_csv``.

    Raises:
        ValueError: if *file_name* does not exist.
    """
    if not os.path.exists(file_name):
        # Report the path that was actually requested.
        raise ValueError('File not exists: {}'.format(file_name))
    return pd.read_csv(file_name)

def plot_results(alg_data, legend_list=None, figsize=(16,3.5), title=None, 
                     lstyle=None, freq=None, show_xlabel=True, use_bytes=False, 
                     use_rel_loss=False, plot_log=False, xlim=None, ylim=None):
    """Plot TrainLoss, TrainAcc and TestAcc curves side by side.

    One line per entry of ``alg_data`` is drawn on each of three subplots,
    against either ``ComRound`` or ``NumBytes``.

    Args:
        alg_data: non-empty list of DataFrames, one per algorithm, with columns
            ``ComRound`` (or ``NumBytes`` when ``use_bytes``), ``TrainLoss``,
            ``TrainAcc`` and ``TestAcc``. The DataFrames are not modified.
        legend_list: legend labels in ``alg_data`` order; defaults to the
            indices ``0 .. len(alg_data) - 1``.
        figsize: figure size passed to ``plt.subplots``.
        title: optional figure super-title.
        lstyle: matplotlib format strings, one per algorithm; defaults to '-'.
        freq: ``markevery`` value per algorithm; defaults to every
            ``len // 10``-th point (about ten markers per curve, and every
            point for curves shorter than 10).
        show_xlabel: whether to label the x axes.
        use_bytes: plot against ``NumBytes`` instead of ``ComRound``.
        use_rel_loss: plot ``(TrainLoss - m) / |m|``, where ``m`` is the smallest
            TrainLoss over all algorithms, instead of the raw loss.
        plot_log: use a log-scale y axis for the loss subplot only.
        xlim: x limits applied to all three subplots; by default the range runs
            from slightly below 0 to the largest x value over all algorithms.
        ylim: sequence of y limits for (loss, train-acc, test-acc) subplots;
            ``None`` entries leave that subplot's limits unchanged.

    Returns:
        The matplotlib Figure (``plt.show()`` has already been called).

    Raises:
        ValueError: if ``alg_data`` is empty.
    """
    if len(alg_data) == 0:
        raise ValueError('alg_data must contain at least one data frame')

    # Create three subplots side by side: loss, train accuracy, test accuracy.
    f, (ax1, ax2, ax3) = plt.subplots(1, 3, sharey=False, figsize=figsize)
    axes = (ax1, ax2, ax3)

    n_alg = len(alg_data)
    x_col = 'NumBytes' if use_bytes else 'ComRound'
    x_label = '# Bytes' if use_bytes else '# Comm. Rounds'

    if lstyle is None:
        lstyle = ['-'] * n_alg
    if freq is None:
        # markevery=0 is an invalid slice step, so never go below 1.
        freq = [max(1, np.size(data[x_col]) // 10) for data in alg_data]

    # Use the longest run over all algorithms so no curve is clipped on the right.
    max_x = max(np.max(data[x_col]) for data in alg_data)

    if use_rel_loss:
        min_loss = min(np.min(data.TrainLoss) for data in alg_data)
        # Computed locally so the caller's DataFrames are left untouched.
        # NOTE: if min_loss == 0 this divides by zero (inf/nan values).
        losses = [(data.TrainLoss - min_loss) / np.abs(min_loss) for data in alg_data]
    else:
        losses = [data.TrainLoss for data in alg_data]

    plot_loss = ax1.semilogy if plot_log else ax1.plot
    for i, (data, loss) in enumerate(zip(alg_data, losses)):
        style = dict(linewidth=2, markevery=freq[i], markersize=8)
        x = data[x_col]
        plot_loss(x, loss, lstyle[i], **style)
        ax2.plot(x, data.TrainAcc, lstyle[i], **style)
        ax3.plot(x, data.TestAcc, lstyle[i], **style)

    if show_xlabel:
        for ax in axes:
            ax.set_xlabel(x_label)
    for ax, y_label in zip(axes, ('TrainLoss', 'TrainAcc', 'TestAcc')):
        ax.set_ylabel(y_label)
        ax.grid(axis='y')
    if title is not None:
        f.suptitle(title)

    if xlim is None:
        # Pad the left edge by max_x/50 so points at x=0 sit off the axis border.
        div = 50
        xlim = [-max_x / div, max_x]
    for ax in axes:
        ax.set_xlim(xlim)

    if ylim is not None:
        for ax, lim in zip(axes, ylim):
            if lim is not None:
                ax.set_ylim(lim)

    if legend_list is None:
        legend_list = list(range(n_alg))
    ax3.legend(legend_list, loc='lower right', borderaxespad=0.)
    plt.show()

    return f

# Read data

In [None]:
# get data: load one log file per algorithm for the selected dataset
data_list = []
file_name_list = []
for alg_name in algs:
    matches = get_file(log_folder, alg_name, dataset)
    if not matches:
        # Fail with context instead of an anonymous IndexError on matches[0].
        raise FileNotFoundError(
            'No log file for {} on {} in {}'.format(alg_name, dataset, log_folder))
    # If several files match, the first one returned by get_file is used.
    file_name = matches[0]
    file_name_list.append(file_name)
    data_list.append(read_csv(os.path.join(log_folder, file_name)))
file_name_list

# Plot in terms of communication rounds

In [None]:
# Compare all algorithms against communication rounds; one line/marker
# style and marker spacing per algorithm, in `algs` order.
fig = plot_results(
    data_list,
    legend_list=legend_list,
    title=dataname_dict[dataset],
    lstyle=['-d', '--s', '-^', '--o'],
    freq=[30, 40, 30, 40],
)

# Plot in terms of number of bytes

In [None]:
# Same comparison, with the x axis taken from the NumBytes column.
fig = plot_results(
    data_list,
    legend_list=legend_list,
    title=dataname_dict[dataset],
    lstyle=['-d', '--s', '-^', '--o'],
    freq=[30, 40, 30, 40],
    use_bytes=True,
)