In [None]:
import json
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np 
from matplotlib.patches import PathPatch
from matplotlib.path import Path
import seaborn as sns

# Set this to True to render all figure text with LaTeX (requires a working
# LaTeX installation). The paper-figure cells further down set it to True
# themselves before plotting.
plt.rcParams['text.usetex'] = False

def draw_error_band(ax, x, y, err, **kwargs):
    """
    Draw an error band on a matplotlib Axes object. Taken from: https://matplotlib.org/stable/gallery/lines_bars_and_markers/curve_error_band.html

    Every point of the curve is offset by ``err`` along the curve's local
    normal, on both sides, and the resulting outline is added as one patch.

    Parameters:
    - ax (matplotlib.axes.Axes): The Axes object to draw the error band on.
    - x (array-like): The x-coordinates of the data points (at least two points).
    - y (array-like): The y-coordinates of the data points, same length as x.
    - err (float or array-like): The error magnitude, i.e. the half-width of the band.
      An array gives a per-point width and must broadcast against x.
    - **kwargs: Additional keyword arguments to be passed to the PathPatch constructor.

    Returns:
    None
    """
    # Accept plain lists / pandas Series as documented: the slice arithmetic
    # below needs positional ndarrays (lists cannot be subtracted and a Series
    # would align on its index instead of on position).
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    err = np.asarray(err, dtype=float)

    # Tangent direction: central differences inside, one-sided at both ends.
    dx = np.concatenate([[x[1] - x[0]], x[2:] - x[:-2], [x[-1] - x[-2]]])
    dy = np.concatenate([[y[1] - y[0]], y[2:] - y[:-2], [y[-1] - y[-2]]])
    length = np.hypot(dx, dy)
    # A repeated point has a zero-length tangent and therefore no normal;
    # give it a zero offset instead of producing NaN vertices from 0/0.
    length[length == 0] = 1.0
    # Unit normal: the tangent rotated by -90 degrees.
    nx = dy / length
    ny = -dx / length

    # end points of errors
    xp = x + nx * err
    yp = y + ny * err
    xn = x - nx * err
    yn = y - ny * err

    # Outline: the "+" side forwards, then the "-" side backwards.
    vertices = np.block([[xp, xn[::-1]], [yp, yn[::-1]]]).T
    codes = np.full(len(vertices), Path.LINETO)
    # Path codes as in the referenced matplotlib gallery example.
    codes[0] = codes[len(xp)] = Path.MOVETO
    path = Path(vertices, codes)
    ax.add_patch(PathPatch(path, **kwargs))

def _fill_nan_forward(values, first_default):
    """Return a float copy of ``values`` with each NaN replaced by its (already filled) predecessor.

    A leading NaN has no predecessor and is replaced by ``first_default``.
    """
    filled = np.array(values, dtype=float)
    for i in range(len(filled)):
        if np.isnan(filled[i]):
            filled[i] = filled[i - 1] if i > 0 else first_default
    return filled


def plot_key(df, key, markers, colors, ax, show_legend=True, x_scale=1):
    """
    Plot a key from a DataFrame on a given axis.

    The "small" and "big" models are drawn as black crosses at x=0 and
    x=1*x_scale (with error bars when both have a positive std). Every
    (calibration, train_method) combination of the "RE" model is drawn as a
    line over its "p" values with a shaded +/- std band. Only rows whose
    "batch" column equals True are used. Missing means are forward-filled
    (a leading missing mean uses the smallest observed mean); missing stds
    are forward-filled with a leading 0.

    Parameters:
    - df (pandas.DataFrame): The DataFrame containing the data, shaped like the output of
      ``read_json``: ``df[key]`` has "mean" and "std" sub-columns.
    - key (str): The key/column name to plot.
    - markers (list): List of marker styles for each plot.
    - colors (list): List of colors for each plot. Colors are cycled per
      (calibration, train_method) combination; a combination keeps its color
      even if an earlier combination has no rows.
    - ax (matplotlib.axes.Axes): The axis to plot on.
    - show_legend (bool): Whether to attach legend labels to the lines. Default is True.
    - x_scale (int): Scaling factor for the x-axis. Default is 1. This is useful to align the values if multiple x-axis are used

    Returns:
    - legend_entries (list): One legend entry per plotted combination, in plotting order.
      Combinations without any rows are skipped.
    """
    batched = df["batch"] == True  # noqa: E712 -- element-wise comparison
    small = df[(df["model"] == "small") & batched][key]
    big = df[(df["model"] == "big") & batched][key]
    ysmall, ysmall_std = small["mean"].values[0], small["std"].values[0]
    ybig, ybig_std = big["mean"].values[0], big["std"].values[0]

    if ysmall_std > 0 and ybig_std > 0:
        ax.errorbar([0, 1*x_scale], [ysmall, ybig], yerr=[ysmall_std, ybig_std], c="k", fmt='x')
    else:
        ax.scatter([0, 1*x_scale], [ysmall, ybig], c="k", marker='x')

    # Everything that must fit inside the y-limits. The std is NaN when a
    # configuration was run only once; treat it as 0 here so a single NaN does
    # not prevent the limits from being set.
    small_err = 0.0 if np.isnan(ysmall_std) else ysmall_std
    big_err = 0.0 if np.isnan(ybig_std) else ybig_std
    lows = [ysmall - small_err, ybig - big_err]
    highs = [ysmall + small_err, ybig + big_err]

    plot_number = 0
    legend_entries = []
    re_rows = df[(df["model"] == "RE") & batched]

    for c in df["calibration"].dropna().unique():
        for tm in df["train_method"].dropna().unique():
            color, marker = colors[plot_number], markers[plot_number]
            # Advance the color cycle even for empty combinations so every
            # combination keeps a stable color across subplots.
            plot_number = (plot_number + 1) % len(colors)

            rows = re_rows[(re_rows["train_method"] == tm) & (re_rows["calibration"] == c)]
            if rows.empty:
                # Nothing to draw (np.amax/np.amin on empty data would raise).
                continue

            label = f"{tm} {'no calibration' if not c else 'calibrated'}"
            legend_entries.append(label)

            x = rows["p"].to_numpy(dtype=float) * x_scale
            y_raw = rows[key]["mean"].to_numpy(dtype=float)
            observed = y_raw[~np.isnan(y_raw)]
            y = _fill_nan_forward(y_raw, np.amin(observed) if observed.size else 0)

            # For p_per_batch the band shows the (run-averaged) spread across
            # batches instead of the spread across runs.
            if key == "p_per_batch":
                err_column = rows["p_per_batch_std"]["mean"]
            else:
                err_column = rows[key]["std"]
            yerr = _fill_nan_forward(err_column.to_numpy(dtype=float), 0)

            if show_legend:
                ax.plot(x, y, color=color, marker=marker, label=label, markersize=4)
            else:
                ax.plot(x, y, color=color, marker=marker, markersize=4)

            ax.fill_between(x, y - yerr, y + yerr, color=color, alpha=0.2)
            lows.extend(y - yerr)
            highs.extend(y + yerr)

    lows = np.asarray(lows, dtype=float)
    highs = np.asarray(highs, dtype=float)
    lows = lows[np.isfinite(lows)]
    highs = highs[np.isfinite(highs)]
    if lows.size and highs.size:
        ymin, ymax = lows.min(), highs.max()
        # Pad by 1% of the magnitude so a negative lower limit is widened, not shrunk.
        ax.set_ylim([ymin - abs(ymin) * 0.01, ymax + abs(ymax) * 0.01])
    return legend_entries

def merge_lists(x):
    """
    Flatten one level of nesting by concatenating the sublists of ``x`` in order.

    Args:
        x (list): A list of lists (any iterable of iterables works).

    Returns:
        list: A new list with every element of every sublist, in their original order.
    """
    return [element for chunk in x for element in chunk]

def plot_distribution(dff, ax, cal_unique, train_unique, markers, colors, max_points=10_000, rng=None):
    """
    Plot the distribution of the usage of the big model for various training methods.

    For every (calibration, train_method) combination of the "RE" model (rows
    with "batch" == True), all per-batch usage values are drawn as a seaborn
    strip plot with one category per p, and the per-p mean is drawn as a line
    on top. Finally a dashed black reference line is drawn over the x ticks,
    scaled so that the largest tick maps to 1.

    Args:
        dff (DataFrame): The input DataFrame containing the data, e.g. as returned by
            ``read_grouped``: each "p_per_batch" cell holds a sequence of values.
        ax (Axes): The matplotlib Axes object to plot on.
        cal_unique (list): List of unique calibration values in dff. Usually this is just [True, False]
        train_unique (list): List of unique training methods. Usually, this is just ["virtual-labels", "confidence"]
        markers (list): List of markers for each plot.
        colors (list): List of colors for each plot. Cycled per combination; a
            combination keeps its color even if an earlier one has no rows.
        max_points (int): At most this many points are drawn per combination;
            larger sets are randomly subsampled without replacement. Default 10_000.
        rng (numpy.random.Generator or RandomState, optional): Source of randomness
            for the subsampling, for reproducible figures. None (default) uses the
            global ``np.random`` state.

    Returns:
        None
    """
    sampler = np.random if rng is None else rng
    # Filter the RE rows once instead of re-scanning dff for every p.
    re_rows = dff[(dff["model"] == "RE") & (dff["batch"] == True)]  # noqa: E712
    plot_number = 0

    for c in cal_unique:
        for tm in train_unique:
            color, marker = colors[plot_number], markers[plot_number]
            plot_number = (plot_number + 1) % len(colors)

            rows = re_rows[(re_rows["train_method"] == tm) & (re_rows["calibration"] == c)]
            if rows.empty:
                # Nothing to draw; an empty mean line plotted against the
                # existing ticks would raise a length-mismatch error.
                continue

            all_x, all_y, ymean = [], [], []
            for pi, per_batch in zip(rows["p"], rows["p_per_batch"]):
                per_batch = list(per_batch)
                all_x.extend([pi] * len(per_batch))
                all_y.extend(per_batch)
                ymean.append(np.mean(per_batch))

            # Subsample large sets to keep the strip plot fast and readable.
            if len(all_x) > max_points:
                idx = sampler.choice(len(all_x), size=max_points, replace=False)
            else:
                idx = np.arange(len(all_x))

            all_x = np.asarray(all_x)[idx]
            all_y = np.asarray(all_y)[idx]
            sns.stripplot(x=all_x, y=all_y, jitter=0.2, size=1, ax=ax, color=color)
            # The mean line is drawn against the tick positions the strip plot
            # created (assumes one tick per p value -- confirm if p sets differ).
            ax.plot(ax.get_xticks(), ymean, c=color, marker=marker, zorder=2, markersize=4)

    ticks = ax.get_xticks()
    top = max(ticks) if len(ticks) else 0
    # Skip the reference line when it would divide by zero (single tick at 0).
    if top > 0:
        ax.plot(ticks, [t / top for t in ticks], c="k", linestyle="--", alpha=0.5, zorder=3)

def read_json(json_path):
    """
    Reads a JSON file containing metrics data and returns a pandas DataFrame.

    Each record is reduced to one row (per-batch lists of time/power are
    averaged), "p_per_batch" is split into its mean and its std, and
    "failed_batches" is 1 when the mean per-batch usage exceeds "p". The rows
    are then grouped by model/rejector/p/batch/train_method/calibration (NaN
    keys kept as their own groups) and aggregated to mean and std. Power
    columns are divided by 1000 after aggregation.

    Parameters:
    json_path (str): The path to the JSON file.

    Returns:
    df (pandas.DataFrame): The DataFrame containing the metrics data, with
        two-level columns such as ("accuracy", "mean") and ("accuracy", "std").
    keys (list): The list of keys used for aggregation.

    """
    # Use a context manager so the file handle is closed deterministically.
    with open(json_path, "r", encoding="utf-8") as fh:
        metrics = json.load(fh)

    df_metrics = [{
        "model":m["model"],
        "rejector":m["rejector"],
        "run":m["run"],
        "batch":m["batch"],
        "p":m["p"],
        "p_per_batch":m["p_per_batch"],
        "train_method":m["train_method"],
        "calibration":m["calibration"],
        "f1 macro":m["f1 macro"],
        "f1 micro":m["f1 micro"],
        "accuracy":m["accuracy"],
        "time":np.mean(m["time"]),
        "power":np.mean(m["power_per_batch"]),
        "poweravg":np.mean(m["poweravg_per_batch"])
    } for m in metrics]

    df = pd.DataFrame(df_metrics)

    df["p_per_batch_std"] = df["p_per_batch"].apply(lambda x: np.std(x))
    df["p_per_batch"] = df["p_per_batch"].apply(lambda x: np.mean(x))

    # A run "fails" when its mean per-batch usage exceeds the budget p.
    # Comparisons against a NaN p evaluate to False, i.e. 0.
    df["failed_batches"] = (df["p_per_batch"] > df["p"]).astype(int)
    agg = {
        "time":['mean','std'],
        "f1 macro":['mean','std'],
        "f1 micro":['mean','std'],
        "accuracy":['mean','std'],
        "power":['mean','std'],
        "poweravg":['mean','std'],
        "p_per_batch":['mean','std'],
        "p_per_batch_std":['mean','std'],
        "failed_batches":['mean','std'],
    }

    # dropna=False keeps rows whose grouping keys contain NaN as their own groups.
    df = df.groupby(["model", "rejector", "p", "batch", "train_method", "calibration"], dropna=False).agg(agg)
    df.reset_index(inplace=True)
    df["power"] = df["power"] / 1000
    df["poweravg"] = df["poweravg"] / 1000
    return df, agg.keys()

def read_grouped(json_path):
    """
    Read and process a JSON file containing metrics data and returns a dataframe containing the usage of the big model per batch.

    Records are grouped by model/rejector/p/batch/train_method/calibration and
    the "p_per_batch" lists of all records in a group are concatenated with
    ``merge_lists``. Unlike ``read_json``, the groupby uses pandas' default
    ``dropna=True``, so records with a NaN/None grouping key are dropped.

    Parameters:
    json_path (str): The path to the JSON file.

    Returns:
    pandas.DataFrame: A DataFrame containing the processed metrics data, i.e. a dataframe containing the usage of the big model per batch.
    """
    # Use a context manager so the file handle is closed deterministically.
    with open(json_path, "r", encoding="utf-8") as fh:
        metrics = json.load(fh)

    df = pd.DataFrame(metrics)
    dff = df.groupby(["model", "rejector", "p", "batch", "train_method", "calibration"]).agg({
        'p_per_batch': merge_lists,
    }).reset_index()

    return dff

In [None]:
# Read all files for plotting. os.listdir order is arbitrary, so it is sorted to
# keep the plot order reproducible.
all_files = [
    os.path.join(folder, f)
    for folder in ("rf", "dt3")
    for f in sorted(os.listdir(folder))
    if f.endswith(".json")
] + ["cifar100.json", "imagenet.json"]

colors = ['#66c2a5','#fc8d62','#8da0cb','#e78ac3']
markers = ["s", "*", "x", "D"]

# Generate individual plots for accuracy / time / power consumption. These plot the raw data and are not used in the paper directly.
for jpath in all_files:
    df, agg_keys = read_json(jpath)
    name = jpath.split(".")[0]

    for m in df["rejector"].dropna().unique():
        # Keep this rejector's rows plus every row without a rejector
        # (presumably the small/big baselines that plot_key looks up).
        dff = df[(df["rejector"] == m) | (df["rejector"].isna())]
        for key in ["accuracy", "time", "power"]:
            plot_key(dff, key, markers, colors, plt.gca(), show_legend=True)
            if key == "accuracy":
                plt.title(f"Test accuracy on {name}")
                plt.ylabel("Accuracy")
            elif key in ("poweravg", "power"):
                plt.title(f"Average power consumption per batch on {name}")
                plt.ylabel("Watt")
            else:
                plt.title(f"{key} on {name}")
            plt.xlabel("p")
            plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.12), shadow=True, ncol=2)
            plt.tight_layout(rect=[0, 0, 1, 0.95])
            plt.show()

In [None]:
colors = ['#66c2a5','#fc8d62','#8da0cb','#e78ac3']
markers = ["s", "*", "x", "D"]
dff = read_grouped("backup/imagenet.json")

# Generate individual plots of the usage of the big model for each batch and method, one figure per rejector. These plot the raw data and are not used in the paper directly.
for rejector in dff["rejector"].dropna().unique():
    # Restrict to this rejector; without the filter every figure would show the
    # rows of all rejectors mixed together.
    subset = dff[dff["rejector"] == rejector]
    plot_distribution(subset, plt.gca(), subset["calibration"].dropna().unique(), subset["train_method"].dropna().unique(), markers, colors)

    plt.xlabel("p")
    # Math mode is required for \widehat; a raw string keeps the backslash intact.
    plt.ylabel(r"$\widehat p$ per batch")
    plt.title(f"Distribution of used budget per batch ({rejector})")
    plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
plt.rcParams['text.usetex'] = True

fig, axs = plt.subplots(3, 2, sharex=True)
colors = ['#66c2a5','#fc8d62','#8da0cb','#e78ac3']
markers = ["s", "*", "x", "D"]

# Generate the 3x2 plots of CIFAR100 and ImageNet. This is Fig. 2 in the paper.
# Rows: test accuracy, power, per-batch usage of the big model; one column per dataset.
for i, jpath in enumerate(["cifar100.json", "imagenet.json"]):
    df, _ = read_json(jpath)
    df_distribution = read_grouped(jpath)

    # Only the first column registers legend labels, so fig.legend lists each method once.
    plot_key(df, "accuracy", markers, colors, axs[0, i], show_legend=i == 0, x_scale=10)
    plot_key(df, "power", markers, colors, axs[1, i], show_legend=False, x_scale=10)
    # plot_key(df, "p_per_batch", markers, colors, axs[2,i], show_legend = False)
    plot_distribution(df_distribution, axs[2, i], df["calibration"].dropna().unique(), df["train_method"].dropna().unique(), markers, colors)

# Row labels on the left column only. Raw strings: "\%" is not a valid Python
# escape, but LaTeX needs the backslash.
row_labels = [r"Test accuracy [\%]", "Power [W]", r"$\widehat p$ per batch [\%]"]
for ax, label in zip(axs[:, 0], row_labels):
    ax.set_ylabel(label)
    ax.yaxis.set_label_coords(-0.2, 0.5)

# Set titles for each column
axs[0, 0].set_title('CIFAR100')
axs[0, 1].set_title('ImageNet')

# Add a legend to the figure, below the axes
fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0), shadow=True, ncol=2)

# Adjust the layout; bbox_inches='tight' below keeps the legend in the saved PDF
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the figure to a PDF
with PdfPages('cifar100_imagenet.pdf') as pdf:
    pdf.savefig(fig, bbox_inches='tight')

plt.show()

In [None]:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
plt.rcParams['text.usetex'] = True

fig, axs = plt.subplots(3, 2, sharex=True)
colors = ['#66c2a5','#fc8d62','#8da0cb','#e78ac3']
markers = ["s", "*", "x", "D"]
# Sorted so the subplot layout does not depend on os.listdir order.
all_files = sorted(os.path.join("dt3", f) for f in os.listdir("dt3") if f.endswith(".json"))

# Generate the 3x2 plots for the UCI datasets. This is Fig. 3 in the paper.
# Datasets fill the left column top to bottom, then the right column; the grid
# holds six datasets (any further ones are drawn again into the right column).
for i, jpath in enumerate(all_files):
    df, _ = read_json(jpath)

    row, col = (i, 0) if i < 3 else (i % 3, 1)
    ax = axs[row, col]
    plot_key(df, "accuracy", markers, colors, ax, show_legend=i == 0)
    ax.set_title(os.path.splitext(os.path.basename(jpath))[0])

# Row labels on the left column. Raw string: "\%" is not a valid Python
# escape, but LaTeX needs the backslash.
if all_files:
    for ax in axs[:, 0]:
        ax.set_ylabel(r"Test accuracy [\%]")
        ax.yaxis.set_label_coords(-0.2, 0.5)

# Add a legend to the figure, below the axes
fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0), shadow=True, ncol=2)

# Adjust the layout; bbox_inches='tight' below keeps the legend in the saved PDF
plt.tight_layout(rect=[0, 0, 1, 0.95])

# Save the figure to a PDF
with PdfPages('dt3.pdf') as pdf:
    pdf.savefig(fig, bbox_inches='tight')

plt.show()

In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.backends.backend_pdf import PdfPages
plt.rcParams['text.usetex'] = True

colors = ['#66c2a5','#fc8d62','#8da0cb','#e78ac3']
markers = ["s", "*", "x", "D"]

# Raw strings: "\%" is not a valid Python escape, but LaTeX needs the backslash.
row_labels = [r"Test accuracy [\%]", "Power [W]", r"$\widehat p$ per batch [\%]"]

# Generate the 3x2 plots for each UCI dataset individually. Also generate plots for the ablation study using Random Forest. These plots can be found in the appendix.
for base in ["rf", "dt3"]:
    # Sorted so the pairing of datasets per figure is reproducible.
    all_files = sorted(os.path.join(base, f) for f in os.listdir(base) if f.endswith(".json"))

    # Two datasets (columns) per figure.
    for j in range(0, len(all_files), 2):
        fig, axs = plt.subplots(3, 2, sharex=True)

        files = all_files[j:j+2]
        for i, jpath in enumerate(files):
            df, _ = read_json(jpath)
            df_distribution = read_grouped(jpath)

            plot_key(df, "accuracy", markers, colors, axs[0, i], show_legend=i == 0, x_scale=10)
            plot_key(df, "power", markers, colors, axs[1, i], show_legend=False, x_scale=10)
            # plot_key(df, "p_per_batch", markers, colors, axs[2,i], show_legend = False)
            plot_distribution(df_distribution, axs[2, i], df["calibration"].dropna().unique(), df["train_method"].dropna().unique(), markers, colors)

            # Column title: dataset file name without directory and extension.
            axs[0, i].set_title(os.path.basename(jpath).split(".")[0])

        for ax, label in zip(axs[:, 0], row_labels):
            ax.set_ylabel(label)
            ax.yaxis.set_label_coords(-0.2, 0.5)

        # With an odd number of files the last figure has an empty right
        # column; hide it instead of failing on a missing second file.
        for ax in axs[:, len(files):].flat:
            ax.set_visible(False)

        # Add a legend to the figure, below the axes
        fig.legend(loc='upper center', bbox_to_anchor=(0.5, 0), shadow=True, ncol=2)

        # Adjust the layout; bbox_inches='tight' below keeps the legend in the saved PDF
        plt.tight_layout(rect=[0, 0, 1, 0.95])

        # Save the figure to a PDF
        with PdfPages(f'block{j}_{base}.pdf') as pdf:
            pdf.savefig(fig, bbox_inches='tight')

        plt.show()
        # Release the figure so the loop does not accumulate open figures.
        plt.close(fig)