In [None]:
import pickle

def l2norm(list1, list2):
    """Relative L2 error of ``list2`` with respect to the reference ``list1``.

    Computes ``||list1 - list2||_2 / ||list1||_2``.

    Args:
        list1: Reference values (the caller passes its ``exact`` values here).
        list2: Values to compare against the reference; must have the same
            length as ``list1``.

    Returns:
        The relative L2 distance.

    Raises:
        ValueError: If the two sequences have different lengths. Without this
            check ``zip`` would silently drop the unmatched entries and report
            a misleadingly small error.

    A reference with zero norm is not handled. It divides by zero, which raises
    ZeroDivisionError for plain Python numbers.
    """
    if len(list1) != len(list2):
        raise ValueError(
            f"length mismatch: {len(list1)} reference values vs {len(list2)} estimates"
        )
    diff_norm = sum((ref - est) ** 2 for ref, est in zip(list1, list2)) ** 0.5
    ref_norm = sum(ref ** 2 for ref in list1) ** 0.5
    return diff_norm / ref_norm

# data[cnum][label] holds one tuple per algorithm, keyed by its figure label:
#   cnum in {3, 6, 10}:  (running time, relative L2 error vs. the "Comb-Shapley" values)
#   cnum in {20, 100}:   (time, zero_err, same_err) as stored in the scalability results
# Index 0 is the running time in both cases and index 1 is an error, so the
# plotting cells can read them uniformly.
#
# NOTE(review): pickle.load can execute arbitrary code from the file. Only load
# result files you generated yourself.
data = {}
# Rename algorithms from their result-file names to the labels used in the figures.
# Raw string: "\l" in a plain literal is an invalid escape sequence
# (DeprecationWarning, and SyntaxWarning on Python 3.12+).
name = {"$lambda$-MR": r"$\lambda$-MR", "Light Sampling": "IPSS", "Comb-Shapley": "MC-Shapley"}


def _load_result(path):
    """Unpickle and return the experiment result stored at ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f)


for cnum in [3, 6, 10]:
    result = _load_result(f"./expres/linear_model_{cnum}_emnist_same.res")
    # The "Comb-Shapley" values are the reference every algorithm is compared with.
    exact = result["Comb-Shapley"][0]
    data[cnum] = {
        name.get(alg, alg): (alg_result[1], l2norm(exact, alg_result[0]))
        for alg, alg_result in result.items()
    }

for cnum in [20, 100]:
    result = _load_result(f"../sources/scalability/expres/linear_model_{cnum}_emnist_same.res")
    data[cnum] = {
        name.get(alg, alg): (alg_result["time"], alg_result["zero_err"], alg_result["same_err"])
        for alg, alg_result in result.items()
    }

data

In [None]:
# matplotlib.pyplot rather than the discouraged matplotlib.pylab namespace.
import matplotlib.pyplot as plt

# Global figure style shared by the plots below: ggplot look, and Times New Roman
# for both regular text and mathtext.
plt.style.use("ggplot")
plt.rcParams["font.family"] = "Times New Roman"
plt.rcParams["mathtext.fontset"] = "custom"
plt.rcParams["mathtext.rm"] = "Times New Roman"

# Base font size. The plotting cells scale it (x1.6 for the y label, x0.8 for legends).
fontsize = 30

# Parallel lists: algorithm display name -> bar colour -> bar hatch pattern.
# Raw string: "\l" in a plain literal is an invalid escape sequence.
algs = ["Perm-Shapley", "MC-Shapley", "DIG-FL", "Extended-GTB", "CC-Shapley", "Extended-TMC", "GTG-Shapley", "OR", r"$\lambda$-MR", "IPSS"]
colors = ["#FFFF00", "#8EBA42", "#8EBA42", "#FFB5B8", "#E24A33", "#348ABD", "#1F77b4", "#FF7F0E", "#2CA02C", "#D62728"]
# The name `hatchs` is kept because the plotting cells look it up by that name.
hatchs = ["", "", "xx", "..", "//", "\\\\", "//", "**", "++", "oo"]
assert len(algs) == len(colors) == len(hatchs), "algs/colors/hatchs must be parallel lists"

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Running time of each algorithm, one bar group per client count.
cnums = [3, 6, 10, 20, 100]
# Hand-tuned group centres on the x axis. The 100-client group sits only 0.7
# after the 20-client one.
x = [0, 1, 2, 3, 3.7]
w = 0.08  # width of a single bar

# Positions in `algs` / `colors` / `hatchs` to draw: all ten algorithms for
# 3/6/10 clients, and a four-algorithm subset for 20/100 clients.
index1 = list(range(10))
index2 = [3, 4, 5, 9]

plt.figure(figsize=(18, 8))

for k, cnum in enumerate(cnums):
    selected = index2 if cnum > 10 else index1
    # Shift bars left by half the group size. For an even count the group ends
    # up centred half a bar to the left of x[k].
    shift = len(selected) // 2
    for slot, idx in enumerate(selected):
        alg = algs[idx]
        plt.bar(
            x[k] + (slot - shift) * w,
            data[cnum][alg][0],        # index 0 = running time
            width=w,
            # Only the first group (which shows every algorithm) is labelled,
            # so each algorithm appears once in the legend.
            label=None if k else alg,
            color=colors[idx],
            hatch=hatchs[idx],
            edgecolor="k",
            linewidth=1,
        )

# Dashed divider between the 3/6/10-client groups and the 20/100-client groups.
plt.axvline(x=2.57, linestyle='--', color="black")

plt.yscale("log")
# Fixed upper limit; presumably leaves headroom for the upper-left legend.
plt.ylim(top=1e7)
plt.ylabel("Running Time (s)", fontsize=fontsize * 1.6, color="black")
tick_labels = [f"Clients #{cnum}" for cnum in cnums]
plt.xticks(x, tick_labels, fontsize=fontsize, color="black")
plt.yticks(fontsize=fontsize, color="black")
plt.legend(ncol=2, loc="upper left", fontsize=fontsize * 0.8, columnspacing=0.2, handletextpad=0.2, handlelength=2)

plt.tight_layout()
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Relative error of each algorithm, one bar group per client count.
cnums = [3, 6, 10, 20, 100]
# Hand-tuned group centres. The 20/100-client groups are spaced wider because
# each holds two sub-groups of bars.
x = [0, 1, 2, 3.3, 4.6]
w = 0.1  # width of a single bar

# 3/6/10 clients: every algorithm except Perm-Shapley and MC-Shapley (indices 0, 1).
# MC-Shapley is the reference the relative errors are measured against.
index1 = [2, 3, 4, 5, 6, 7, 8, 9]
# 20/100 clients: a four-algorithm subset.
index2 = [3, 4, 5, 9]

plt.figure(figsize=(18, 8))

for k, cnum in enumerate(cnums):
    results = data[cnum]
    if cnum > 10:
        # Two sub-groups per client count: the left one shows the no-free-rider
        # error (index 1) and the right one the symmetric-fairness error (index 2).
        for i, idx in enumerate(index2):
            alg = algs[idx]
            bar_style = dict(width=w, label=alg if k == 0 else None, color=colors[idx],
                             hatch=hatchs[idx], edgecolor="k", linewidth=1)
            zero_err = results[alg][1]
            # Errors below 1e-3 would fall under the y-axis floor set below, so
            # they are drawn at 2e-3 instead.
            # NOTE(review): such a bar's height is a placeholder, not the measured value.
            # The symmetric-fairness bars are not clamped the same way. A value
            # of 0 there would presumably not render on the log axis.
            plt.bar(x[k] + (i - len(index2) - 0.5) * w, 2e-3 if zero_err < 1e-3 else zero_err, **bar_style)
            plt.bar(x[k] + (i + 1.5) * w, results[alg][2], **bar_style)
    else:
        half = len(index1) // 2
        for i, idx in enumerate(index1):
            alg = algs[idx]
            plt.bar(x[k] + (i - half) * w, results[alg][1], width=w, label=alg if k == 0 else None,
                    color=colors[idx], hatch=hatchs[idx], edgecolor="k", linewidth=1)

# Dashed divider between the 3/6/10-client groups and the 20/100-client groups.
plt.axvline(x=2.57, linestyle='--', color="black")

# Captions for the two sub-groups of the 20- and 100-client groups.
# Columns: group centre, offset in bar widths, y position, caption. Positions are hand-tuned.
captions = [
    (x[3], -4, 0.3, 'no-free-rider'),
    (x[4], -4, 0.3, 'no-free-rider'),
    (x[3], 2, 6, 'symmetric-fairness'),
    (x[4], 3, 6, 'symmetric-fairness'),
]
for centre, offset, y_pos, caption in captions:
    plt.text(centre + offset * w, y_pos, caption, ha='center', va='bottom', fontsize=fontsize)

plt.yscale("log")
# Floor just above 1e-3. Presumably this keeps the 1e-3 gridline off the bottom
# edge while the 2e-3 placeholder bars stay visible. Confirm intent.
plt.ylim(bottom=1e-3 + 1e-5)
plt.ylabel("Relative Error", fontsize=fontsize * 1.6, color="black")
plt.xticks(x, [f"Clients #{cnum}" for cnum in cnums], fontsize=fontsize, color="black")
plt.yticks(fontsize=fontsize, color="black")
plt.legend(ncol=2, loc="upper left", fontsize=fontsize * 0.8, columnspacing=0.2, handletextpad=0.2, handlelength=2)

plt.tight_layout()
plt.show()