In [None]:
import os
import glob
import pickle
import numpy as np
import scipy.stats
import matplotlib.pyplot as plt
from matplotlib.patches import Circle, RegularPolygon
from matplotlib.path import Path
from matplotlib.projections.polar import PolarAxes
from matplotlib.projections import register_projection
from matplotlib.spines import Spine
from matplotlib.transforms import Affine2D
from mpl_toolkits.axes_grid1.inset_locator import inset_axes

# Larger default font so labels stay legible in the exported SVG figures.
plt.rcParams["font.size"] = 15


# Aggregation methods; each one has its own results directory under ./{dataset}/.
methods = "FedAvg FedMed FedProx qFedAvg".split()
# Local-round settings, zero-padded to match the ./{dataset}/{method}/{local_round}/ directory names.
local_rounds = [f"{count:02d}" for count in (1, 5, 10, 25)]

In [None]:
def plot_accuracy(dataset, rounds=50):
    """Plot mean ± std test-accuracy curves per method and save them as SVG.

    One figure is produced for every pair of experiment suffix (model and data
    partition, see ``exp_matches``) and local-round setting from the
    module-level ``local_rounds``. For each method in the module-level
    ``methods``, the first ``*.pkl`` file (in sorted order) under
    ``./{dataset}/{method}/{local_round}/`` is loaded. Every experiment in it
    whose name ends with the suffix adds one mean curve and a ±1
    standard-deviation band computed across its runs.

    Args:
        dataset: Name of the dataset results directory, e.g. "MNIST", "CIFAR"
            or "Shakespeare". Accuracies for "Shakespeare" are multiplied by
            100 (presumably they are stored as fractions there — TODO confirm).
        rounds: Number of global communication rounds shown on the x-axis.
            Shorter runs are padded with their last value and longer runs are
            truncated, so every run has exactly ``rounds`` points.

    Side effects:
        Prints the pickle files found and the matched experiment names.
        Writes
        ``./results/{dataset}/Accuracy_Profile/{IID|Non_IID}/{experiment}_{local_round}.svg``,
        creating missing directories, and displays each non-empty figure.
    """
    exp_matches = [
        " CNN on IID",
        " CNN on Non IID",
        " MLP on IID",
        " MLP on Non IID",
        " LSTM on IID",
        " LSTM on Non IID",
    ]
    x = np.arange(rounds)

    for exp_match in exp_matches:
        for local_round in local_rounds:
            fig, ax = plt.subplots(figsize=[8, 6])
            # Ordered set of output paths. The figure is saved once, after all
            # methods are drawn, rather than once per matched experiment.
            save_paths = {}

            for method in methods:
                # sorted() makes the choice of pickle_files[0] deterministic.
                pickle_files = sorted(glob.glob(f"./{dataset}/{method}/{local_round}/*.pkl"))
                print(pickle_files)
                if not pickle_files:
                    continue

                # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
                with open(pickle_files[0], "rb") as file:
                    log_dict = pickle.load(file)

                for experiment, logs in log_dict.items():
                    if not experiment.endswith(exp_match):
                        continue
                    print(experiment)
                    is_iid = "Non_IID" if "Non IID" in experiment else "IID"

                    # Pad with the last value or truncate to exactly `rounds` points,
                    # without mutating the loaded log. Empty runs carry no data and are skipped.
                    runs = [
                        list(profile[:rounds]) + [profile[-1]] * (rounds - len(profile))
                        for profile in logs["test_accuracy"]
                        if len(profile) > 0
                    ]
                    if not runs:
                        continue

                    accuracy_runs = np.array(runs, dtype=float)
                    if dataset == "Shakespeare":
                        accuracy_runs = accuracy_runs * 100

                    mean_accuracy = accuracy_runs.mean(axis=0)
                    std_accuracy = accuracy_runs.std(axis=0)

                    ax.plot(x, mean_accuracy, label=f"{method}")
                    ax.fill_between(
                        x,
                        mean_accuracy - std_accuracy,
                        mean_accuracy + std_accuracy,
                        alpha=0.5,
                    )
                    save_paths[
                        f"./results/{dataset}/Accuracy_Profile/{is_iid}/{experiment}_{local_round}.svg"
                    ] = None

            if not save_paths:
                # No method had an experiment matching this suffix: skip the empty figure.
                plt.close(fig)
                continue

            ax.grid(False)
            ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True)
            ax.set_xlabel("Global Communication Rounds")
            ax.set_ylabel("Test Accuracy")
            fig.tight_layout()

            for path in save_paths:
                os.makedirs(os.path.dirname(path), exist_ok=True)
                fig.savefig(path, format="svg", dpi=1000)

            plt.show()
            # Release the figure explicitly so figures do not pile up when show()
            # leaves them open (e.g. non-interactive backends).
            plt.close(fig)

In [None]:
def plot_accuracy_stacked_error_bar_plot(dataset):
    """Plot final-round test-accuracy error bars per algorithm and save as SVG.

    For every local-round setting in the module-level ``local_rounds``, the
    first ``*.pkl`` file (in sorted order) of each method in ``methods`` is
    loaded from ``./{dataset}/{method}/{local_round}/``. For every experiment,
    the last recorded accuracy of each run is collected. One figure per data
    partition (" on IID" / " on Non IID") then shows, per "{method} {model}"
    tick:

    * a thick black bar for mean ± 1 std of those final accuracies, and
    * a thin black bar for their min–max range.

    Args:
        dataset: Name of the dataset results directory, e.g. "MNIST", "CIFAR"
            or "Shakespeare". Accuracies for "Shakespeare" are multiplied by
            100 (presumably they are stored as fractions there — TODO confirm).

    Side effects:
        Prints the pickle files found and every experiment name. Writes
        ``./results/{dataset}/Stacked_Error_Bar/{IID|Non_IID}/{dataset}{exp_match}_{local_round}.svg``,
        creating missing directories, and displays each non-empty figure.
    """
    exp_matches = [" on IID", " on Non IID"]

    for local_round in local_rounds:
        # Final-accuracy statistics for THIS local-round setting only, so values
        # from a previous setting cannot leak into this one's figures.
        # Maps "{method}-{experiment}" -> (mean, std, max, min), in load order.
        final_stats = {}

        for method in methods:
            # sorted() makes the choice of pickle_files[0] deterministic.
            pickle_files = sorted(glob.glob(f"./{dataset}/{method}/{local_round}/*.pkl"))
            print(pickle_files)
            if not pickle_files:
                continue

            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(pickle_files[0], "rb") as file:
                log_dict = pickle.load(file)

            for experiment, logs in log_dict.items():
                print(experiment)
                # Padding a run with its last value never changes its final value,
                # so each run's last accuracy is all that is needed. This also
                # tolerates runs of unequal length; empty runs are skipped.
                finals = np.array(
                    [profile[-1] for profile in logs["test_accuracy"] if len(profile) > 0],
                    dtype=float,
                )
                if finals.size == 0:
                    continue
                if dataset == "Shakespeare":
                    finals = finals * 100
                final_stats[f"{method}-{experiment}"] = (
                    finals.mean(),
                    finals.std(),
                    finals.max(),
                    finals.min(),
                )

        # Plot once per partition, after every method has been loaded.
        for exp_match in exp_matches:
            is_iid = "Non_IID" if "Non IID" in exp_match else "IID"

            selected = [(key, stats) for key, stats in final_stats.items() if key.endswith(exp_match)]
            if not selected:
                continue

            # e.g. "FedAvg-MNIST CNN on IID" -> "FedAvg CNN"
            # (assumes experiment names start with the dataset name — TODO confirm).
            labels = [key.replace(f"-{dataset}", "").replace(exp_match, "") for key, _ in selected]
            mean, std, high, low = (np.array(column) for column in zip(*(stats for _, stats in selected)))
            x = np.arange(len(labels))

            fig, ax = plt.subplots(figsize=[8, 6])
            ax.grid(True)
            ax.errorbar(x, mean, std, fmt="ok", lw=3)
            ax.errorbar(x, mean, [mean - low, high - mean], fmt=".k", ecolor="black", lw=1)

            ax.set_xticks(x)
            ax.set_xticklabels(labels, rotation=45, ha="right")
            ax.set_xlabel("Algorithms")
            ax.set_ylabel("Test Accuracy")
            ax.margins(0.1)
            fig.tight_layout()

            path = f"./results/{dataset}/Stacked_Error_Bar/{is_iid}/{dataset}{exp_match}_{local_round}.svg"
            os.makedirs(os.path.dirname(path), exist_ok=True)
            fig.savefig(path, format="svg", dpi=1000)

            plt.show()
            # Release the figure explicitly so figures do not pile up.
            plt.close(fig)

In [None]:
def plot_accuracy_stacked_error_bar_multiple_plot(dataset):
    """Overlay final-round accuracy error bars for all local-round settings.

    For every local-round setting in the module-level ``local_rounds``, the
    first ``*.pkl`` file (in sorted order) of each method in ``methods`` is
    loaded from ``./{dataset}/{method}/{local_round}/``. Each run's last
    recorded accuracy is collected per experiment. One figure per data
    partition (" on IID" / " on Non IID") then draws, per local-round setting
    and in that setting's colour:

    * a thick bar for mean ± 1 std, labelled with the integer local-round count
      in the legend, and
    * a thin bar for the min–max range of the final accuracies.

    Points are positioned by their "{method} {model}" tick, so settings stay
    aligned even if some method/model is missing for one of them.

    Args:
        dataset: Name of the dataset results directory, e.g. "MNIST", "CIFAR"
            or "Shakespeare". Accuracies for "Shakespeare" are multiplied by
            100 (presumably they are stored as fractions there — TODO confirm).

    Side effects:
        Prints the pickle files found and every experiment name. Each non-empty
        figure is written to both
        ``./results/{dataset}/Stacked_Error_Bar_Multiple/{IID|Non_IID}/{dataset}{exp_match}.svg``
        and ``../results/Local_Rounds/{dataset}/{IID|Non_IID}.svg`` (creating
        missing directories), then displayed.
    """
    fmt_dict = {"25": "ok", "10": "ob", "05": "og", "01": "or"}
    efmt_dict = {"25": ".k", "10": ".b", "05": ".g", "01": ".r"}
    ecolor_dict = {"25": "black", "10": "blue", "05": "green", "01": "red"}
    exp_matches = [" on IID", " on Non IID"]

    # final_stats[local_round]["{method}-{experiment}"] = (mean, std, max, min)
    final_stats = {}
    for local_round in local_rounds:
        round_stats = final_stats.setdefault(local_round, {})
        for method in methods:
            # sorted() makes the choice of pickle_files[0] deterministic.
            pickle_files = sorted(glob.glob(f"./{dataset}/{method}/{local_round}/*.pkl"))
            print(pickle_files)
            if not pickle_files:
                continue

            # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
            with open(pickle_files[0], "rb") as file:
                log_dict = pickle.load(file)

            for experiment, logs in log_dict.items():
                print(experiment)
                # Padding a run with its last value never changes its final value,
                # so each run's last accuracy is all that is needed. This also
                # tolerates runs of unequal length; empty runs are skipped.
                finals = np.array(
                    [profile[-1] for profile in logs["test_accuracy"] if len(profile) > 0],
                    dtype=float,
                )
                if finals.size == 0:
                    continue
                if dataset == "Shakespeare":
                    finals = finals * 100
                round_stats[f"{method}-{experiment}"] = (
                    finals.mean(),
                    finals.std(),
                    finals.max(),
                    finals.min(),
                )

    for exp_match in exp_matches:
        is_iid = "Non_IID" if "Non IID" in exp_match else "IID"

        # (tick label, stats) pairs per local-round setting for this partition.
        # e.g. "FedAvg-MNIST CNN on IID" -> "FedAvg CNN"
        # (assumes experiment names start with the dataset name — TODO confirm).
        series = {
            local_round: [
                (key.replace(f"-{dataset}", "").replace(exp_match, ""), stats)
                for key, stats in final_stats[local_round].items()
                if exp_match in key
            ]
            for local_round in local_rounds
        }

        # One shared x position per tick label, in first-seen order, so every
        # local-round setting plots each algorithm at the same tick.
        positions = {}
        for pairs in series.values():
            for label, _ in pairs:
                positions.setdefault(label, len(positions))
        if not positions:
            continue

        fig, ax = plt.subplots(figsize=[8, 6])
        ax.grid(True)

        # Exactly one pair of error-bar artists per local-round setting.
        for local_round, pairs in series.items():
            if not pairs:
                continue
            x = np.array([positions[label] for label, _ in pairs])
            mean, std, high, low = (np.array(column) for column in zip(*(stats for _, stats in pairs)))
            ax.errorbar(
                x,
                mean,
                std,
                fmt=fmt_dict.get(local_round, "o"),
                label=str(int(local_round)),
                lw=3,
            )
            ax.errorbar(
                x,
                mean,
                [mean - low, high - mean],
                fmt=efmt_dict.get(local_round, "."),
                ecolor=ecolor_dict.get(local_round),
                lw=1,
            )

        ax.set_xticks(np.arange(len(positions)))
        ax.set_xticklabels(list(positions), rotation=45, ha="right")
        ax.set_xlabel("Algorithms")
        ax.set_ylabel("Test Accuracy")
        ax.margins(0.1)
        ax.legend(loc="upper center", bbox_to_anchor=(0.5, 1.15), ncol=4, fancybox=True)
        fig.tight_layout()

        for path in (
            f"./results/{dataset}/Stacked_Error_Bar_Multiple/{is_iid}/{dataset}{exp_match}.svg",
            f"../results/Local_Rounds/{dataset}/{is_iid}.svg",
        ):
            os.makedirs(os.path.dirname(path), exist_ok=True)
            fig.savefig(path, format="svg", dpi=1000)

        plt.show()
        # Release the figure explicitly so figures do not pile up.
        plt.close(fig)

In [None]:
# Accuracy-vs-round curves for each dataset.
for dataset_name in ("MNIST", "CIFAR", "Shakespeare"):
    plot_accuracy(dataset_name)

In [None]:
# Final-accuracy error bars, one figure per local-round setting and partition.
for dataset_name in ("MNIST", "CIFAR", "Shakespeare"):
    plot_accuracy_stacked_error_bar_plot(dataset_name)

In [None]:
# Final-accuracy error bars with all local-round settings overlaid per partition.
for dataset_name in ("MNIST", "CIFAR", "Shakespeare"):
    plot_accuracy_stacked_error_bar_multiple_plot(dataset_name)