In [17]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline

# Settings

In [18]:
# Data logs root directory
LOG_DIR = "logs_alphabeta"

# Loss function: either "logistic" regression, or "nonlinear" least squares
LOSS = "nonlinear"  # (logistic, nonlinear)
LOG_DIR = os.path.join(LOG_DIR, LOSS)

# The following should be the same as the one used in run_experiment.py
DATASETS = ("a9a", "w8a", "rcv1", "real-sim",)
OPTIMIZERS = ("SGD", "SARAH", "L-SVRG", "Adam")
T = 100  # Use 2xT used in run_experiment.py

# These are the metrics collected in the data logs
METRICS = ("loss", "gradnorm", "error")
METRIC = "error"  # choose metric

# These are aggregators for comparing multi-seed runs
AGGS = ("mean", "median")
AGG = "mean"  # choose aggregator

# Downsample this number of effective passes by averaging them
AVG_DOWNSAMPLE = 5

# These are the logs columns: effective passes + metrics
LOG_COLS = ["ep"] + list(METRICS)

# These are the hyperparameters of interest
ARG_COLS = ["lr", "alpha", "beta2"]

# Plots will be generated for this hyperparams/args setting.
# 'corrupt' should be the scale/suffix of the dataset as a string or 'none'.
FILTER_ARGS = {
    "corrupt": "none",
    "weight_decay": 0,
}

# Ignore all runs containing 'any' of these hyperparams.
IGNORE_ARGS = {}

### Utility functions for loading data

In [19]:
def ignore(args_dict):
    return any(args_dict[arg] in IGNORE_ARGS[arg]
               for arg in IGNORE_ARGS.keys() if arg in args_dict)


def loaddata(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    return data


def contain_dict(dict1, dict2):
    return all(dict1[k] == v for k, v in dict2.items() if k in dict1)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [20]:
def unpack_args(fname):
    """
    Recover all args given file path.
    """
    args = {}
    # unpack path
    dirname, logname = os.path.split(fname)
    logdir, args["dataset"] = os.path.split(dirname)
    # parse args
    args["optimizer"], argstr = logname.split("(")
    argstr, _ = argstr.split(")")  # remove ').pkl'
    args_dict = {k:v for k,v in [s.split("=") for s in argstr.split(",")]}

    # Extract args
    if args["dataset"][-1] == ")":
        args["corrupt"] = args["dataset"][args["dataset"].index("("):]
    else:
        # It is very unlikely that the original dataset name will end with ')'
        args["corrupt"] = "none"

    if "seed" in args_dict:
        args["seed"] = int(args_dict["seed"])
    else:
        args["seed"] = 0

    args["BS"] = int(args_dict["BS"])
    args["lr"] = float(args_dict["lr"])
    if "weight_decay" in args_dict:
        args["weight_decay"] = float(args_dict["weight_decay"])
    else:
        args["weight_decay"] = 0
    if "lr_decay" in args_dict:
        args["lr_decay"] = float(args_dict["lr_decay"])
    else:
        args["lr_decay"] = 0
    if "p" in args_dict:
        args["p"] = args_dict["p"]
    if "precond" in args_dict:
        args["precond"] = args_dict["precond"]
        args["beta2"] = float(args_dict["beta2"])
        args["alpha"] = float(args_dict["alpha"])
    else:
        args["precond"] = "none"
        args["alpha"] = "none"
        args["beta2"] = "none"
    
    if args["optimizer"] == "Adam":
        args["beta2"] = float(args_dict["beta2"])
        args["alpha"] = 1e-8

    return args


def get_logs(logdir, dataset, optimizer, **filter_args):
    """
    Return all logs in 'logdir' containing the filter hyperparams.
    Dataset name should contain feature scaling, if any
    e.g. 'dataset' or 'dataset(k_min,k_max)'.
    
    Returns the data in the log file and its arguments/hyperparams.
    """
    remove_empty_data = False
    # Add
    if "corrupt" in filter_args and filter_args['corrupt'] != "none":
        # Add scale suffix to specify dataset    
        dataset += filter_args['corrupt']
    else:
        # No setting specified, use wildcard to match all suffixes
        dataset += "*"
    # Find all files matching this pattern
    for fname in glob.glob(f"{logdir}/{dataset}/{optimizer}(*).pkl"):
        exp_args = unpack_args(fname)
        # Skip if filter_args do not match args of this file
        if not contain_dict(exp_args, filter_args):
            continue
        # Load data
        data = loaddata(fname)
        # Handle empty data files
        if len(data) == 0:
            print(fname, "has no data!")
            """
            if "y" == input("Remove empty log files in the future? (y/n)"):
                remove_empty_data = True
            if remove_empty_data:
                try:
                    print("Removing", fname)
                    os.remove(fname)
                except OSError as e:
                    print ("Error: %s - %s." % (e.filename, e.strerror))
            """
            continue
        # @XXX: hack to correct wrong initial ep>0 for L-SVRG
        ep0 = data[0,0]
        if ep0 > 0.:
            data[:,0] -= ep0
        yield data, exp_args

        
# Gather data
all_dfs = {}
start_time = time.time()
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = pd.DataFrame()
    # Get all log data given the experiment and filter args
    for data, args in get_logs(LOG_DIR, *exp, **FILTER_ARGS):
        if ignore(args):
            continue
        # Get experiment log data
        df = pd.DataFrame(data[:, :4], columns=LOG_COLS)
        # Get args of interest
        for col in ARG_COLS:
            df[col] = args[col]
        # Downsample by averaging metrics every AVG_DOWNSAMPLE epoch.
        df["ep"] = np.ceil(df["ep"] / AVG_DOWNSAMPLE) * AVG_DOWNSAMPLE
        df = df.groupby(["ep"] + ARG_COLS).mean().reset_index()
        # Get data up to the prespecified epoch T
        df = df[df["ep"] <= T]
        # @TODO: is this efficient?
        exp_df = exp_df.append(df, ignore_index=True)
    # Record all runs of exp in a single dataframe
    all_dfs[exp] = exp_df

    if len(exp_df) == 0:
        print("No log data found for this experiment!")
        print("- Experiment:", exp)
        print("- filter_args:", FILTER_ARGS)
        continue
data_gather_time = time.time() - start_time
print(f"Data frame lengths:")
for exp, df in all_dfs.items():
    print(f"{exp} -> {len(df)} data rows -> {len(df) // T} runs")
print(f"Took about {data_gather_time:.2f} seconds to gather all these data.")

Data frame lengths:
('a9a', 'SGD') -> 27720 data rows -> 277 runs
('a9a', 'SARAH') -> 27720 data rows -> 277 runs
('a9a', 'L-SVRG') -> 27720 data rows -> 277 runs
('a9a', 'Adam') -> 9240 data rows -> 92 runs
('w8a', 'SGD') -> 27720 data rows -> 277 runs
('w8a', 'SARAH') -> 27720 data rows -> 277 runs
('w8a', 'L-SVRG') -> 27720 data rows -> 277 runs
('w8a', 'Adam') -> 9240 data rows -> 92 runs
('rcv1', 'SGD') -> 27720 data rows -> 277 runs
('rcv1', 'SARAH') -> 27720 data rows -> 277 runs
('rcv1', 'L-SVRG') -> 27720 data rows -> 277 runs
('rcv1', 'Adam') -> 9240 data rows -> 92 runs
('real-sim', 'SGD') -> 27720 data rows -> 277 runs
('real-sim', 'SARAH') -> 27720 data rows -> 277 runs
('real-sim', 'L-SVRG') -> 27720 data rows -> 277 runs
('real-sim', 'Adam') -> 9240 data rows -> 92 runs
Took about 50.02 seconds to gather all these data.


In [21]:
for i, (exp, df) in enumerate(all_dfs.items()):
    if i == 3: break
    print(exp)
    display(df)

('a9a', 'SGD')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,0.250000,1.000000e-03,0.990,1.768381,1.181963e-01,0.500000
1,5.0,0.250000,1.000000e-03,0.990,1.000000,7.298895e-20,0.240810
2,10.0,0.250000,1.000000e-03,0.990,1.000000,7.299940e-20,0.240810
3,15.0,0.250000,1.000000e-03,0.990,1.000000,7.301153e-20,0.240810
4,20.0,0.250000,1.000000e-03,0.990,1.000000,7.302079e-20,0.240810
...,...,...,...,...,...,...,...
27715,80.0,0.003906,1.000000e-07,0.995,0.965299,1.260563e-08,0.183585
27716,85.0,0.003906,1.000000e-07,0.995,0.964982,1.461944e-08,0.181776
27717,90.0,0.003906,1.000000e-07,0.995,0.965115,8.293785e-09,0.184632
27718,95.0,0.003906,1.000000e-07,0.995,0.965051,1.032599e-08,0.181048


('a9a', 'SARAH')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,0.0625,1.000000e-07,0.995,1.768381,1.181963e-01,0.500000
1,5.0,0.0625,1.000000e-07,0.995,1.907918,2.176345e-05,0.478682
2,10.0,0.0625,1.000000e-07,0.995,2.851762,2.211585e-07,0.713700
3,15.0,0.0625,1.000000e-07,0.995,2.961502,1.999023e-11,0.739992
4,20.0,0.0625,1.000000e-07,0.995,2.969745,1.028276e-10,0.742105
...,...,...,...,...,...,...,...
27715,80.0,0.0625,1.000000e-07,0.995,2.672123,1.257463e-37,0.647615
27716,85.0,0.0625,1.000000e-07,0.995,2.672123,1.257463e-37,0.647615
27717,90.0,0.0625,1.000000e-07,0.995,2.672123,1.257463e-37,0.647615
27718,95.0,0.0625,1.000000e-07,0.995,2.672123,1.257463e-37,0.647615


('a9a', 'L-SVRG')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,16.0000,1.000000e-07,0.95,1.768381,1.181963e-01,0.500000
1,5.0,16.0000,1.000000e-07,0.95,1.544670,9.056915e-11,0.387043
2,10.0,16.0000,1.000000e-07,0.95,1.461319,2.633319e-150,0.360370
3,15.0,16.0000,1.000000e-07,0.95,1.461319,2.633319e-150,0.360370
4,20.0,16.0000,1.000000e-07,0.95,1.461319,2.633319e-150,0.360370
...,...,...,...,...,...,...,...
27715,80.0,0.0625,1.000000e-01,0.99,0.963899,5.313822e-06,0.208206
27716,85.0,0.0625,1.000000e-01,0.99,0.962462,4.405531e-06,0.206733
27717,90.0,0.0625,1.000000e-01,0.99,0.961552,3.881815e-06,0.205424
27718,95.0,0.0625,1.000000e-01,0.99,0.960650,3.406000e-06,0.204335


## Get best hyperparams

In [22]:
for exp, df in all_dfs.items():
    if exp[1] == "Adam": continue
    alphas = set([] if "alpha" not in ARG_COLS else df["alpha"])
    betas = set([] if "beta2" not in ARG_COLS else df["beta2"])
    lrs = set([] if "lr" not in ARG_COLS else df["lr"])
    preconds = set([] if "precond" not in ARG_COLS else df["precond"])
    break

best_dfs = {}
best_dfs_alpha = {}
best_dfs_beta = {}
best_dfs_lr = {}
best_dfs_precond = {}
for exp in product(DATASETS, OPTIMIZERS):
    print("Finding best hyperparams for", exp)
    best_dfs_alpha[exp] = {}
    best_dfs_beta[exp] = {}
    best_dfs_lr[exp] = {}
    best_dfs_precond[exp] = {}
    exp_df = all_dfs[exp]
    # Get last metrics/performance (supposed to be epoch-smoothed for better results)
    max_ep = exp_df.groupby(ARG_COLS, sort=False)["ep"].transform(max)
    perf = exp_df[exp_df["ep"] == max_ep].drop("ep", axis=1)
    # Find the minimum aggregate metric (based on mean, median, etc.)
    def find_best_hyperparams(perf):
        if AGG == "mean":
            agg_perf = perf.groupby(ARG_COLS).mean()
        elif AGG == "median":
            agg_perf = perf.groupby(ARG_COLS).median()
        # Get the aggregated perf that minimizes the chosen metric
        min_agg_perf = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]]
        return min_agg_perf.index
    # Get the data associated with the args of the min aggregated metric
    exp_df = exp_df.set_index(ARG_COLS)
    best_dfs[exp] = exp_df.loc[find_best_hyperparams(perf)]
    for alpha in alphas:
        best_dfs_alpha[exp][alpha] = exp_df.loc[find_best_hyperparams(perf[perf["alpha"] == alpha])]
    for beta in betas:
        best_dfs_beta[exp][beta] = exp_df.loc[find_best_hyperparams(perf[perf["beta2"] == beta])]
    for lr in lrs:
        best_dfs_lr[exp][lr] = exp_df.loc[find_best_hyperparams(perf[perf["lr"] == lr])]
    for precond in preconds:
        best_dfs_precond[exp][precond] = exp_df.loc[find_best_hyperparams(perf[perf["precond"] == precond])]

Finding best hyperparams for ('a9a', 'SGD')
Finding best hyperparams for ('a9a', 'SARAH')
Finding best hyperparams for ('a9a', 'L-SVRG')
Finding best hyperparams for ('a9a', 'Adam')
Finding best hyperparams for ('w8a', 'SGD')
Finding best hyperparams for ('w8a', 'SARAH')
Finding best hyperparams for ('w8a', 'L-SVRG')
Finding best hyperparams for ('w8a', 'Adam')
Finding best hyperparams for ('rcv1', 'SGD')
Finding best hyperparams for ('rcv1', 'SARAH')
Finding best hyperparams for ('rcv1', 'L-SVRG')
Finding best hyperparams for ('rcv1', 'Adam')
Finding best hyperparams for ('real-sim', 'SGD')
Finding best hyperparams for ('real-sim', 'SARAH')
Finding best hyperparams for ('real-sim', 'L-SVRG')
Finding best hyperparams for ('real-sim', 'Adam')


In [23]:
print("Best hyperparams for each optimizer on each dataset given the following setting:")
print(FILTER_ARGS)
print()
for exp, df in best_dfs.items():
    print(exp)
    for arg, val in zip(ARG_COLS, df.index[0]):
        if arg == "lr":
            val = "2**" + str(int(log2(val)))
        print(f"- {arg} = {val}")
    print()

Best hyperparams for each optimizer on each dataset given the following setting:
{'corrupt': 'none', 'weight_decay': 0}

('a9a', 'SGD')
- lr = 2**-8
- alpha = 1e-07
- beta2 = 0.95

('a9a', 'SARAH')
- lr = 2**-10
- alpha = 1e-07
- beta2 = 0.995

('a9a', 'L-SVRG')
- lr = 2**-10
- alpha = 1e-07
- beta2 = 0.99

('a9a', 'Adam')
- lr = 2**-2
- alpha = 1e-08
- beta2 = 0.99

('w8a', 'SGD')
- lr = 2**-10
- alpha = 1e-07
- beta2 = 0.999

('w8a', 'SARAH')
- lr = 2**-12
- alpha = 1e-07
- beta2 = 0.995

('w8a', 'L-SVRG')
- lr = 2**-10
- alpha = 1e-07
- beta2 = 0.995

('w8a', 'Adam')
- lr = 2**0
- alpha = 1e-08
- beta2 = 0.95

('rcv1', 'SGD')
- lr = 2**-12
- alpha = 1e-07
- beta2 = 0.995

('rcv1', 'SARAH')
- lr = 2**-14
- alpha = 1e-07
- beta2 = 0.95

('rcv1', 'L-SVRG')
- lr = 2**-12
- alpha = 1e-07
- beta2 = 0.995

('rcv1', 'Adam')
- lr = 2**-4
- alpha = 1e-08
- beta2 = 0.99

('real-sim', 'SGD')
- lr = 2**-12
- alpha = 1e-07
- beta2 = 0.999

('real-sim', 'SARAH')
- lr = 2**-14
- alpha = 1e-07
- bet

# Plotting

## Plot gradnorm per lr

In [24]:
print("Types")
for col in df.columns:
    print(col, df[col].dtypes)

Types
ep float64
loss float64
gradnorm float64
error float64


In [25]:
print("Learning rates:")
for exp, df in all_dfs.items():
    display(set("2**"+str(int(log2(lr))) for lr in df["lr"]))
    break

Learning rates:


{'2**-10',
 '2**-12',
 '2**-14',
 '2**-16',
 '2**-2',
 '2**-4',
 '2**-6',
 '2**-8',
 '2**0',
 '2**2',
 '2**4'}

In [26]:
print("Range")
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

Range
ep: (0.0, 100.0)
lr: (1.52587890625e-05, 16.0)
alpha: (1e-07, 0.1)
beta2: (0.95, 0.999)
loss: (0.9451468799491136, 3.036761770215903)
gradnorm: (0.0, 0.11819628026871622)
error: (0.16376826059805696, 0.7591904425539757)


In [27]:
# Set inf values to nan and recheck range
VERYBIGNUMBER = 10**10
df[df == float("inf")] = np.nan
df[df[["loss","gradnorm"]] > VERYBIGNUMBER] = np.nan
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

ep: (0.0, 100.0)
lr: (1.52587890625e-05, 16.0)
alpha: (1e-07, 0.1)
beta2: (0.95, 0.999)
loss: (0.9451468799491136, 3.036761770215903)
gradnorm: (0.0, 0.11819628026871622)
error: (0.16376826059805696, 0.7591904425539757)


In [28]:
print("Check for NaNs in each column for each df.")
for col in df.columns:
    print(col)
    for exp, df in all_dfs.items():
        print(f"- {exp}: {df[col].isna().sum()}")

Check for NaNs in each column for each df.
ep
- ('a9a', 'SGD'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('a9a', 'Adam'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('w8a', 'Adam'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'SARAH'): 0
- ('rcv1', 'L-SVRG'): 0
- ('rcv1', 'Adam'): 0
- ('real-sim', 'SGD'): 0
- ('real-sim', 'SARAH'): 0
- ('real-sim', 'L-SVRG'): 0
- ('real-sim', 'Adam'): 0
lr
- ('a9a', 'SGD'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('a9a', 'Adam'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('w8a', 'Adam'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'SARAH'): 0
- ('rcv1', 'L-SVRG'): 0
- ('rcv1', 'Adam'): 0
- ('real-sim', 'SGD'): 0
- ('real-sim', 'SARAH'): 0
- ('real-sim', 'L-SVRG'): 0
- ('real-sim', 'Adam'): 0
alpha
- ('a9a', 'SGD'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('a9a', 'Adam'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('w8a', 'Adam'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'SARA

## Plot best performance of each optimizer on each dataset

In [None]:
start_time = time.time()
y_greek = {"loss": r"$P(w_t)$", "gradnorm": r"$||\nabla P(w_t)||^2$", "error": "Error"}

# Plot 3 rows each one showing some performance metric,
# where the columns are the dataset on which the optim is run.
fig, axes = plt.subplots(3, len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5 * 3)
plt.suptitle(rf"Best Performances w.r.t. {y_greek[METRIC]}")
for j, dataset in enumerate(DATASETS):
    for optimizer in OPTIMIZERS:
        exp = (dataset, optimizer)
        if exp not in best_dfs:
            continue
        # Get hyperparams of best performance of 'optimizer' on 'dataset'
        args = {k:v for k,v in zip(best_dfs[exp].index.names, best_dfs[exp].index[0])}
        exp_df = best_dfs[exp].reset_index()
        # Show power of lr as 2^lr_pow
        lr_pow = round(log2(args['lr']))
        if optimizer == "Adam":
            sublabel = rf"$\eta = 2^{{{lr_pow}}}$, $\beta_1=0.9$, $\beta_2={args['beta2']}$"
        else:
            sublabel = rf"$\eta = 2^{{{lr_pow}}}$, $\alpha={args['alpha']}$, $\beta={args['beta2']}$"
        label = rf"{optimizer}({sublabel})"
        print(f"Plotting lines for {exp}...")
        sns.lineplot(x="ep", y="loss", label=label, ax=axes[0,j], data=exp_df)
        sns.lineplot(x="ep", y="gradnorm", label=label, ax=axes[1,j], data=exp_df)
        sns.lineplot(x="ep", y="error", label=label, ax=axes[2,j], data=exp_df)
    # Loss
    axes[0,j].set_title(dataset)
    axes[0,j].set_ylabel(r"$P(w_t)$")
    axes[0,j].set_xlabel("Effective Passes")
    axes[0,j].legend()
    # Gradnorm
    axes[1,j].set(yscale="log")
    axes[1,j].set_title(dataset)
    axes[1,j].set_ylabel(r"$||\nabla P(w_t)||^2$")
    axes[1,j].set_xlabel("Effective Passes")
    axes[1,j].legend()
    # Error
    axes[2,j].set(yscale="log")
    axes[2,j].set_title(dataset)
    axes[2,j].set_ylabel("Error")
    axes[2,j].set_xlabel("Effective Passes")
    axes[2,j].legend()
fig.tight_layout()

# Create a string out of filter args and save figure
filter_args_str = ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())
plt.savefig(f"plots/best_{METRIC}_overall({filter_args_str}).pdf")
plt.close()
plot_best_time = time.time() - start_time
print(f"Took about {plot_best_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SGD')...
Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('a9a', 'L-SVRG')...
Plotting lines for ('a9a', 'Adam')...
Plotting lines for ('w8a', 'SGD')...
Plotting lines for ('w8a', 'SARAH')...
Plotting lines for ('w8a', 'L-SVRG')...
Plotting lines for ('w8a', 'Adam')...
Plotting lines for ('rcv1', 'SGD')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('rcv1', 'L-SVRG')...
Plotting lines for ('rcv1', 'Adam')...
Plotting lines for ('real-sim', 'SGD')...
Plotting lines for ('real-sim', 'SARAH')...
Plotting lines for ('real-sim', 'L-SVRG')...
Plotting lines for ('real-sim', 'Adam')...
Took about 15.05 seconds to create this plot.


# Plot best performance given a fixed value of either $\alpha$, $\beta$, or $\eta$

In [None]:
best_dfs_mode = {"alphas": best_dfs_alpha, "betas": best_dfs_beta, "lrs": best_dfs_lr}
mode_greek = {"alphas": r"$\alpha$", "betas": r"$\beta$", "lrs": r"$\eta$"}

for y in ("error", "gradnorm"):
    for mode in ("alphas", "betas", "lrs"):
        valid_optimizers = [opt for opt in OPTIMIZERS if not (mode == "alphas" and opt == "Adam")]

        start_time = time.time()
        # Plot data for all optim, datasets, and args
        fig, axes = plt.subplots(len(valid_optimizers), len(DATASETS))
        fig.set_size_inches(5 * len(DATASETS), 5 * (len(valid_optimizers)))
        title = rf"Best {y_greek[y]} given {mode_greek[mode]}"
        plt.suptitle(title)
        for i, optimizer in enumerate(valid_optimizers):
            for j, dataset in enumerate(DATASETS):
                exp = (dataset, optimizer)
                if exp not in best_dfs_lr:
                    continue
                exp_df = pd.concat(list(best_dfs_mode[mode][exp].values())).reset_index()
                if len(exp_df) == 0:
                    continue
                axes[i,j].set_title(rf"$\tt {optimizer}({dataset})$")
                # avoid silly problem of inconsistent style across axes
                exp_df["beta2"] = exp_df["beta2"].astype(str)
                exp_df = exp_df.sort_values("alpha", ascending=False)
                exp_df = exp_df.sort_values("beta2", ascending=True)
                print(f"Plotting lines for {exp}...")
                sns.lineplot(ax=axes[i,j], x="ep", y=y,
                             hue="lr", hue_norm=LogNorm(), palette="vlag",
                             size="beta2", style="alpha", data=exp_df)
                axes[i,j].set_ylabel(rf"{y_greek[y]}")
                axes[i,j].set(yscale="log")  # @TODO: always set log scale?
        fig.tight_layout()

        # Create a string out of filter args and save figure
        filter_args_str = ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())
        plt.savefig(f"plots/best_{y}_given_{mode}({filter_args_str}).pdf")
        plt.close()
        plot_time = time.time() - start_time
        print(f"Took about {plot_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SGD')...
Plotting lines for ('w8a', 'SGD')...
Plotting lines for ('rcv1', 'SGD')...
Plotting lines for ('real-sim', 'SGD')...
Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('w8a', 'SARAH')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('real-sim', 'SARAH')...
Plotting lines for ('a9a', 'L-SVRG')...
Plotting lines for ('w8a', 'L-SVRG')...
Plotting lines for ('rcv1', 'L-SVRG')...
Plotting lines for ('real-sim', 'L-SVRG')...
Took about 22.41 seconds to create this plot.
Plotting lines for ('a9a', 'SGD')...
Plotting lines for ('w8a', 'SGD')...
Plotting lines for ('rcv1', 'SGD')...
Plotting lines for ('real-sim', 'SGD')...
Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('w8a', 'SARAH')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('real-sim', 'SARAH')...
Plotting lines for ('a9a', 'L-SVRG')...
Plotting lines for ('w8a', 'L-SVRG')...
Plotting lines for ('rcv1', 'L-SVRG')...
Plotting lines for ('real-sim', 'L-