In [27]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline
MARKERS = (',', '+', '.', 'o', '*', "D")

# Settings

In [28]:
# Data logs root directory
LOG_DIR = "logs10"

# The following should be the same as the one used in run_experiment.py
DATASETS = ("a9a", "w8a", "rcv1", "real-sim",)
OPTIMIZERS = ("SGD", "Adam", "SARAH", "L-SVRG")
T = 100  # Use 2xT used in run_experiment.py

# These are the metrics collected in the data logs
METRICS = ("loss", "gradnorm", "error")
METRIC = "gradnorm"  # default metric

# These are aggregators for comparing multi-seed runs
AGGS = ("mean", "median")
AGG = "mean"  # default aggregator

# Downsample this number of effective passes by averaging them
AVG_DOWNSAMPLE = 5

# These are the logs columns: effective passes + metrics
LOG_COLS = ["ep"] + list(METRICS)

# These are the hyperparameters of interest
ARG_COLS = ["lr", "BS", "precond", "alpha"]

# Plots will be generated for this hyperparams/args setting.
# 'corrupt' should be the suffix of the dataset as a string.
FILTER_ARGS = {
    #"corrupt": "",
    "corrupt": "(-3,3)",
    "weight_decay": 0,
}

# Ignore all runs containing 'any' of these hyperparams.
IGNORE_ARGS = {
    "alpha": [1e-9],
    #"alpha": [1e-1, 1e-3, 1e-7, 1e-9],
    "BS": [2048],
    "gamma": [2**-16, 2**-18, 2**-20],
}

### Utility functions for loading data

In [29]:
def ignore(args_dict):
    return any(args_dict[arg] in IGNORE_ARGS[arg]
               for arg in IGNORE_ARGS.keys() if arg in args_dict)


def loaddata(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    return data


def contain_dict(dict1, dict2):
    return all(dict1[k] == v for k, v in dict2.items() if k in dict1)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [30]:
def unpack_args(fname):
    """
    Recover all args given file path.
    """
    args = {}
    # unpack path
    dirname, logname = os.path.split(fname)
    logdir, args["dataset"] = os.path.split(dirname)
    # parse args
    args["optimizer"], argstr = logname.split("(")
    argstr, _ = argstr.split(")")  # remove ').pkl'
    args_dict = {k:v for k,v in [s.split("=") for s in argstr.split(",")]}

    # Extract args
    if args["dataset"][-1] == ")":
        args["corrupt"] = args["dataset"][args["dataset"].index("("):]
    else:
        # It is very unlikely that the original dataset name will end with ')'
        args["corrupt"] = "none"

    if "seed" in args_dict:
        args["seed"] = int(args_dict["seed"])
    else:
        args["seed"] = 0

    args["BS"] = int(args_dict["BS"])
    args["lr"] = float(args_dict["lr"])
    if "weight_decay" in args_dict:
        args["weight_decay"] = float(args_dict["weight_decay"])
    else:
        args["weight_decay"] = 0
    if "lr_decay" in args_dict:
        args["lr_decay"] = float(args_dict["lr_decay"])
    else:
        args["lr_decay"] = 0
    if "p" in args_dict:
        args["p"] = args_dict["p"]
    if "precond" in args_dict:
        args["precond"] = args_dict["precond"]
        args["beta2"] = float(args_dict["beta2"])
        args["alpha"] = float(args_dict["alpha"])
    else:
        args["precond"] = "none"
        args["alpha"] = "none"

    return args


def get_logs(logdir, dataset, optimizer, **filter_args):
    """
    Return all logs in 'logdir' containing the filter hyperparams.
    Dataset name should contain feature scaling, if any
    e.g. 'dataset' or 'dataset(k_min,k_max)'.
    
    Returns the data in the log file and its arguments/hyperparams.
    """
    remove_empty_data = False
    # Add
    if "corrupt" in filter_args:
        dataset = f"{dataset}{filter_args['corrupt']}"
    else:
        # No setting specified, use wildcard to match all suffixes
        dataset += "*"
    # Find all files matching this pattern
    for fname in glob.glob(f"{logdir}/{dataset}/{optimizer}(*).pkl"):
        exp_args = unpack_args(fname)
        # Skip if filter_args do not match args of this file
        if not contain_dict(exp_args, filter_args):
            continue
        # Load data
        data = loaddata(fname)
        # Handle empty data files
        if len(data) == 0:
            print(fname, "has no data!")
            """
            if "y" == input("Remove empty log files in the future? (y/n)"):
                remove_empty_data = True
            if remove_empty_data:
                try:
                    print("Removing", fname)
                    os.remove(fname)
                except OSError as e:
                    print ("Error: %s - %s." % (e.filename, e.strerror))
            """
            continue
        # @XXX: hack to correct wrong initial ep>0 for L-SVRG
        ep0 = data[0,0]
        if ep0 > 0.:
            data[:,0] -= ep0
        yield data, exp_args

        
# Gather data
all_dfs = {}
start_time = time.time()
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = pd.DataFrame()
    # Get all log data given the experiment and filter args
    for data, args in get_logs(LOG_DIR, *exp, **FILTER_ARGS):
        if ignore(args):
            continue
        # Get experiment log data
        df = pd.DataFrame(data[:, :4], columns=LOG_COLS)
        # Get args of interest
        for col in ARG_COLS:
            df[col] = args[col]
        # Downsample by averaging metrics every AVG_DOWNSAMPLE epoch.
        df["ep"] = np.ceil(df["ep"] / AVG_DOWNSAMPLE) * AVG_DOWNSAMPLE
        df = df.groupby(["ep"] + ARG_COLS).mean().reset_index()
        # Get data up to the prespecified epoch T
        df = df[df["ep"] <= T]
        # @TODO: is this efficient?
        exp_df = exp_df.append(df, ignore_index=True)
    # Record all runs of exp in a single dataframe
    all_dfs[exp] = exp_df

    if len(exp_df) == 0:
        print("No log data found for this experiment!")
        print("- Experiment:", exp)
        print("- filter_args:", FILTER_ARGS)
        continue
    # @TODO: Would it be better to plot these on the go instead of storing them?
data_gather_time = time.time() - start_time

In [31]:
print(f"Data frame lengths:")
for exp, df in all_dfs.items():
    print(f"{exp} -> {len(df)} data rows -> {len(df) // T} runs")
print(f"Took about {data_gather_time:.2f} seconds to gather all these data.")

Data frame lengths:
('a9a', 'SGD') -> 8316 data rows -> 83 runs
('a9a', 'Adam') -> 2079 data rows -> 20 runs
('a9a', 'SARAH') -> 8185 data rows -> 81 runs
('a9a', 'L-SVRG') -> 8316 data rows -> 83 runs
('w8a', 'SGD') -> 8295 data rows -> 82 runs
('w8a', 'Adam') -> 2079 data rows -> 20 runs
('w8a', 'SARAH') -> 8314 data rows -> 83 runs
('w8a', 'L-SVRG') -> 8316 data rows -> 83 runs
('rcv1', 'SGD') -> 8316 data rows -> 83 runs
('rcv1', 'Adam') -> 2079 data rows -> 20 runs
('rcv1', 'SARAH') -> 7392 data rows -> 73 runs
('rcv1', 'L-SVRG') -> 7392 data rows -> 73 runs
('real-sim', 'SGD') -> 7392 data rows -> 73 runs
('real-sim', 'Adam') -> 1848 data rows -> 18 runs
('real-sim', 'SARAH') -> 7392 data rows -> 73 runs
('real-sim', 'L-SVRG') -> 7392 data rows -> 73 runs
Took about 39.52 seconds to gather all these data.


In [32]:
for i, (exp, df) in enumerate(all_dfs.items()):
    if i == 3: break
    print(exp)
    display(df)

('a9a', 'SGD')


Unnamed: 0,ep,lr,BS,precond,alpha,loss,gradnorm,error
0,0.0,0.250000,128,none,none,0.693147,1529.465824,0.500000
1,5.0,0.250000,128,none,none,113.010573,664.831308,0.256550
2,10.0,0.250000,128,none,none,122.158834,699.929352,0.261304
3,15.0,0.250000,128,none,none,95.213317,830.147837,0.258952
4,20.0,0.250000,128,none,none,98.567754,701.136119,0.255896
...,...,...,...,...,...,...,...,...
8311,80.0,0.015625,128,hutchinson,0.1,0.344543,0.510958,0.161522
8312,85.0,0.015625,128,hutchinson,0.1,0.344926,0.597405,0.161130
8313,90.0,0.015625,128,hutchinson,0.1,0.346016,0.913790,0.161501
8314,95.0,0.015625,128,hutchinson,0.1,0.343733,0.631168,0.160822


('a9a', 'Adam')


Unnamed: 0,ep,lr,BS,precond,alpha,loss,gradnorm,error
0,0.0,0.015625,128,none,none,0.693147,1529.465824,0.500000
1,5.0,0.015625,128,none,none,0.435452,70.799062,0.187790
2,10.0,0.015625,128,none,none,0.398886,67.818147,0.176218
3,15.0,0.015625,128,none,none,0.404162,88.323849,0.176607
4,20.0,0.015625,128,none,none,0.393879,54.472134,0.172675
...,...,...,...,...,...,...,...,...
2074,80.0,0.062500,128,none,none,0.583387,140.991985,0.188742
2075,85.0,0.062500,128,none,none,0.570388,125.247746,0.188475
2076,90.0,0.062500,128,none,none,0.544347,117.218403,0.186919
2077,95.0,0.062500,128,none,none,0.589378,173.168942,0.190157


('a9a', 'SARAH')


Unnamed: 0,ep,lr,BS,precond,alpha,loss,gradnorm,error
0,0.0,4.0,128,none,none,0.693147,840.928096,0.500000
1,5.0,4.0,128,none,none,170038.319924,757.545785,0.305833
2,10.0,4.0,128,none,none,23089.423052,612.062111,0.323217
3,15.0,4.0,128,none,none,28242.598214,553.153684,0.285293
4,20.0,4.0,128,none,none,17606.615604,286.914670,0.291054
...,...,...,...,...,...,...,...,...
8180,75.0,4.0,128,none,none,7896.344815,994.151625,0.346751
8181,80.0,4.0,128,none,none,7383.777889,208.110324,0.271292
8182,85.0,4.0,128,none,none,3352.979973,144.781995,0.248543
8183,90.0,4.0,128,none,none,3634.551796,109.708327,0.228397


## Get best hyperparams

In [33]:
best_dfs = {}
best_dfs_with_precond = {}
best_dfs_without_precond = {}
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = all_dfs[exp]
    # Get last metrics/performance  @TODO: is this good?
    max_ep = exp_df.groupby(ARG_COLS, sort=False)["ep"].transform(max)
    perf = exp_df[exp_df["ep"] == max_ep].drop("ep", axis=1)
    # Find the minimum aggregate metric (based on mean, median, etc.)
    def find_best_hyperparams(perf):
        if AGG == "mean":
            agg_perf = perf.groupby(ARG_COLS).mean()
        elif AGG == "median":
            agg_perf = perf.groupby(ARG_COLS).median()
        # Get the aggregated perf that minimizes the chosen metric
        min_agg_perf = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]]
        return min_agg_perf.index
    # Get the data associated with the args of the min aggregated metric
    exp_df = exp_df.set_index(ARG_COLS)
    best_dfs[exp] = exp_df.loc[find_best_hyperparams(perf)]
    best_dfs_with_precond[exp] = exp_df.loc[find_best_hyperparams(perf[perf["precond"] == "hutchinson"])]
    best_dfs_without_precond[exp] = exp_df.loc[find_best_hyperparams(perf[perf["precond"] == "none"])]

In [34]:
print("Best hyperparams for each optimizer on each dataset given the following setting:")
print(FILTER_ARGS)
print()
for exp, df in best_dfs.items():
    print(exp)
    for arg, val in zip(ARG_COLS, df.index[0]):
        if arg == "lr":
            val = "2**" + str(int(log2(val)))
        print(f"- {arg} = {val}")
    print()

Best hyperparams for each optimizer on each dataset given the following setting:
{'corrupt': '(-3,3)', 'weight_decay': 0}

('a9a', 'SGD')
- lr = 2**-10
- BS = 128
- precond = hutchinson
- alpha = 0.1

('a9a', 'Adam')
- lr = 2**-16
- BS = 128
- precond = none
- alpha = none

('a9a', 'SARAH')
- lr = 2**-6
- BS = 128
- precond = hutchinson
- alpha = 0.1

('a9a', 'L-SVRG')
- lr = 2**-4
- BS = 128
- precond = hutchinson
- alpha = 0.1

('w8a', 'SGD')
- lr = 2**-12
- BS = 128
- precond = hutchinson
- alpha = 0.001

('w8a', 'Adam')
- lr = 2**-16
- BS = 128
- precond = none
- alpha = none

('w8a', 'SARAH')
- lr = 2**-8
- BS = 128
- precond = hutchinson
- alpha = 0.1

('w8a', 'L-SVRG')
- lr = 2**-6
- BS = 128
- precond = hutchinson
- alpha = 0.1

('rcv1', 'SGD')
- lr = 2**-10
- BS = 128
- precond = hutchinson
- alpha = 0.001

('rcv1', 'Adam')
- lr = 2**-14
- BS = 128
- precond = none
- alpha = none

('rcv1', 'SARAH')
- lr = 2**-8
- BS = 128
- precond = hutchinson
- alpha = 0.1

('rcv1', 'L-SVRG'

# Plotting

## Plot gradnorm per lr

In [35]:
print("Types")
for col in df.columns:
    print(col, df[col].dtypes)

Types
ep float64
loss float64
gradnorm float64
error float64


In [36]:
print("Learning rates:")
for exp, df in all_dfs.items():
    display(set("2**"+str(int(log2(lr))) for lr in df["lr"]))
    break

Learning rates:


{'2**-10',
 '2**-12',
 '2**-14',
 '2**-16',
 '2**-2',
 '2**-4',
 '2**-6',
 '2**-8',
 '2**0',
 '2**2',
 '2**4'}

In [37]:
print("Range")
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

Range
ep: (0.0, 100.0)
lr: (1.52587890625e-05, 16.0)
BS: (128, 128)
loss: (0.33428567783858903, 190231207821.8668)
gradnorm: (0.007554754362071744, 7014.391853635776)
error: (0.15514598241249758, 0.5)


In [38]:
# Set inf values to nan and recheck range
VERYBIGNUMBER = 10**20
df[df == float("inf")] = np.nan
df[df[["loss","gradnorm"]] > VERYBIGNUMBER] = np.nan
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

ep: (0.0, 100.0)
lr: (1.52587890625e-05, 16.0)
BS: (128, 128)
loss: (0.33428567783858903, 190231207821.8668)
gradnorm: (0.007554754362071744, 7014.391853635776)
error: (0.15514598241249758, 0.5)


In [39]:
print("Check for NaNs in each column for each df.")
for col in df.columns:
    print(col)
    for exp, df in all_dfs.items():
        print(f"- {exp}: {df[col].isna().sum()}")

Check for NaNs in each column for each df.
ep
- ('a9a', 'SGD'): 0
- ('a9a', 'Adam'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'Adam'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'Adam'): 0
- ('rcv1', 'SARAH'): 0
- ('rcv1', 'L-SVRG'): 0
- ('real-sim', 'SGD'): 0
- ('real-sim', 'Adam'): 0
- ('real-sim', 'SARAH'): 0
- ('real-sim', 'L-SVRG'): 0
lr
- ('a9a', 'SGD'): 0
- ('a9a', 'Adam'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'Adam'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'Adam'): 0
- ('rcv1', 'SARAH'): 0
- ('rcv1', 'L-SVRG'): 0
- ('real-sim', 'SGD'): 0
- ('real-sim', 'Adam'): 0
- ('real-sim', 'SARAH'): 0
- ('real-sim', 'L-SVRG'): 0
BS
- ('a9a', 'SGD'): 0
- ('a9a', 'Adam'): 0
- ('a9a', 'SARAH'): 0
- ('a9a', 'L-SVRG'): 0
- ('w8a', 'SGD'): 0
- ('w8a', 'Adam'): 0
- ('w8a', 'SARAH'): 0
- ('w8a', 'L-SVRG'): 0
- ('rcv1', 'SGD'): 0
- ('rcv1', 'Adam'):

In [40]:
start_time = time.time()
# Plot data for all optim, datasets, and args
fig, axes = plt.subplots(len(OPTIMIZERS), len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5 * len(OPTIMIZERS))
title = r"$||\nabla F(w_t)||^2$ per $\eta$"
plt.suptitle(title)
for i, optimizer in enumerate(OPTIMIZERS):
    for j, dataset in enumerate(DATASETS):
        exp = (dataset, optimizer)
        if exp not in all_dfs:
            continue
        exp_df = all_dfs[exp]
        axes[i,j].set_title(rf"$\tt{optimizer}({dataset})$")
        # axes[i,j].set_title(rf"{optimizer}({dataset})")
        # avoid silly problem of inconsistent style across axes
        exp_df["alpha"] = exp_df["alpha"].astype(str)
        exp_df = exp_df.sort_values("alpha", ascending=False)
        print(f"Plotting lines for {exp}...")
        sns.lineplot(ax=axes[i,j], x="ep", y="gradnorm",
                     hue="lr", hue_norm=LogNorm(), palette="vlag",
                     #size="BS", sizes=(1, 2),  # @XXX
                     style="alpha",
                     # plot only for alpha == none and alpha == 0.001 to reduce clutter
                     data=exp_df[(exp_df["alpha"] == "none") | (exp_df["alpha"] == "0.001")])
        axes[i,j].set(yscale="log")
        axes[i,j].set_ylabel(r"$||\nabla F(w_t)||^2$")
        # Set an upper limit since we seem to have crazy values for some runs @TODO: remove those runs?
        axes[i,j].set_ylim(top=exp_df[exp_df["ep"] == 0]["gradnorm"].max())
fig.tight_layout()

# Create a string out of filter args and save figure
filter_args_str = ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())
plt.savefig(f"plots/learning_rates({filter_args_str}).pdf")
plt.close()
plot_per_lr_time = time.time() - start_time
print(f"Took about {plot_per_lr_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SGD')...
Plotting lines for ('w8a', 'SGD')...
Plotting lines for ('rcv1', 'SGD')...
Plotting lines for ('real-sim', 'SGD')...
Plotting lines for ('a9a', 'Adam')...
Plotting lines for ('w8a', 'Adam')...
Plotting lines for ('rcv1', 'Adam')...
Plotting lines for ('real-sim', 'Adam')...
Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('w8a', 'SARAH')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('real-sim', 'SARAH')...
Plotting lines for ('a9a', 'L-SVRG')...
Plotting lines for ('w8a', 'L-SVRG')...
Plotting lines for ('rcv1', 'L-SVRG')...
Plotting lines for ('real-sim', 'L-SVRG')...
Took about 159.08 seconds to create this plot.


## Plot best performance of each optimizer on each dataset

In [41]:
start_time = time.time()
# Plot 3 rows each one showing some performance metric,
# where the columns are the dataset on which the optim is run.
fig, axes = plt.subplots(3, len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5 * 3)
plt.suptitle(rf"Best Performances")
for j, dataset in enumerate(DATASETS):
    for optimizer in OPTIMIZERS:
        exp = (dataset, optimizer)
        if exp not in best_dfs:
            continue
        # Get hyperparams of best performance of 'optimizer' on 'dataset'
        args = {k:v for k,v in zip(best_dfs[exp].index.names, best_dfs[exp].index[0])}
        exp_df = best_dfs[exp].reset_index()
        # Show power of lr as 2^lr_pow
        lr_pow = round(log2(args['lr']))
        sublabel = rf"$\eta = 2^{{{lr_pow}}}$, $\alpha={args['alpha']}$"
        label = rf"{optimizer}({sublabel})"
        print(f"Plotting lines for {exp}...")
        sns.lineplot(x="ep", y="loss", label=label, ax=axes[0,j], data=exp_df)
        sns.lineplot(x="ep", y="gradnorm", label=label, ax=axes[1,j], data=exp_df)
        sns.lineplot(x="ep", y="error", label=label, ax=axes[2,j], data=exp_df)
    # Loss
    axes[0,j].set_title(dataset)
    axes[0,j].set_ylabel(r"$F(w_t)$")
    axes[0,j].set_xlabel("Effective Passes")
    axes[0,j].legend()
    # Gradnorm
    axes[1,j].set(yscale="log")
    axes[1,j].set_title(dataset)
    axes[1,j].set_ylabel(r"$||\nabla F(w_t)||^2$")
    axes[1,j].set_xlabel("Effective Passes")
    axes[1,j].legend()
    # Error
    axes[2,j].set(yscale="log")
    axes[2,j].set_title(dataset)
    axes[2,j].set_ylabel("Error")
    axes[2,j].set_xlabel("Effective Passes")
    axes[2,j].legend()
    # Set an upper limit since we seem to have crazy values for some runs @TODO: remove those runs?
    # axes[0,j].set_ylim(top=exp_df[exp_df["ep"] == 0]["loss"].max()*1.1,
                       # bottom=exp_df["loss"].min()*0.9)
    axes[1,j].set_ylim(top=exp_df[exp_df["ep"] == 0]["gradnorm"].max())
fig.tight_layout()

# Create a string out of filter args and save figure
filter_args_str = ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())
plt.savefig(f"plots/optimizers({filter_args_str}).pdf")
plt.close()
plot_best_time = time.time() - start_time
print(f"Took about {plot_best_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SGD')...
Plotting lines for ('a9a', 'Adam')...
Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('a9a', 'L-SVRG')...
Plotting lines for ('w8a', 'SGD')...
Plotting lines for ('w8a', 'Adam')...
Plotting lines for ('w8a', 'SARAH')...
Plotting lines for ('w8a', 'L-SVRG')...
Plotting lines for ('rcv1', 'SGD')...
Plotting lines for ('rcv1', 'Adam')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('rcv1', 'L-SVRG')...
Plotting lines for ('real-sim', 'SGD')...
Plotting lines for ('real-sim', 'Adam')...
Plotting lines for ('real-sim', 'SARAH')...
Plotting lines for ('real-sim', 'L-SVRG')...
Took about 27.13 seconds to create this plot.


# Generate plots comparing preconditioning vs none

In [44]:
start_time = time.time()
fig, axes = plt.subplots(3, len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5 * 3)
plt.suptitle(rf"Top performance with preconditioning vs. without")
for j, dataset in enumerate(DATASETS):
    optim_df = pd.DataFrame()
    for optimizer in OPTIMIZERS:
        exp = (dataset, optimizer)
        if exp not in best_dfs_with_precond or exp not in best_dfs_without_precond:
            continue
        # Put both dfs together and mark them with the optimizer's name.
        # (They already have 'precond' set accordingly.)
        exp_df = best_dfs_without_precond[exp].append(best_dfs_with_precond[exp])
        exp_df["optimizer"] = optimizer
        optim_df = optim_df.append(exp_df)
    # reset index and combine precond with gamma
    print(f"Plotting lines for {dataset}...")
    optim_df = optim_df.reset_index()
    sns.lineplot(x="ep", y="loss", hue="optimizer", style="precond", ax=axes[0,j], data=optim_df)
    sns.lineplot(x="ep", y="gradnorm", hue="optimizer", style="precond", ax=axes[1,j], data=optim_df)
    sns.lineplot(x="ep", y="error", hue="optimizer", style="precond", ax=axes[2,j], data=optim_df)
    # Loss
    axes[0,j].set_title(dataset)
    axes[0,j].set_ylabel(r"$F(w_t)$")
    axes[0,j].set_xlabel("Effective Passes")
    # Gradnorm
    axes[1,j].set(yscale="log")
    axes[1,j].set_title(dataset)
    axes[1,j].set_ylabel(r"$||\nabla F(w_t)||^2$")
    axes[1,j].set_xlabel("Effective Passes")
    # Error
    axes[2,j].set(yscale="log")
    axes[2,j].set_title(dataset)
    axes[2,j].set_ylabel("Error")
    axes[2,j].set_xlabel("Effective Passes")
fig.tight_layout()

filter_args_str = ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())
plt.savefig(f"plots/compare_optimizers({filter_args_str}).pdf")
plt.close()
plot_compare_time = time.time() - start_time
print(f"Took about {plot_compare_time:.2f} seconds to create this plot.")

Plotting lines for a9a...
Plotting lines for w8a...
Plotting lines for rcv1...
Plotting lines for real-sim...
Took about 50.53 seconds to create this plot.


In [58]:
def display_best_performances(best_data, show_alpha=True):
    for dataset in DATASETS:
        for optimizer in OPTIMIZERS:
            # Extract best performance metrics for each experiment
            exp = (dataset, optimizer)
            if exp not in best_data:
                continue
            if len(best_data[exp].index) == 0:
                # Not applicable to exp, likely because it does not have 'precond'
                continue
            args = {k:v for k,v in zip(best_data[exp].index.names, best_data[exp].index[0])}
            exp_df = best_data[exp].reset_index()
            loss = exp_df["loss"].iloc[-1]
            gradnorm = exp_df["gradnorm"].iloc[-1]
            error = exp_df["error"].iloc[-1]
            # Print report
            print(f"{exp}:"
                  f"\tlr = 2^{str(round(log2(args['lr'])))}," + \
                  (f"\talpha = {args['alpha']}," if show_alpha else "") + \
                  f"\tloss = {loss:5f},"
                  f"\tgradnorm = {gradnorm:5f},"
                  f"\terror = {error:5f}")
    print()


print(f"Best hyperparameters using {METRIC} metric WITHOUT preconditoning:")
display_best_performances(best_dfs_without_precond, show_alpha=False)
print(f"Best hyperparameters using {METRIC} metric WITH preconditoning:")
display_best_performances(best_dfs_with_precond, show_alpha=True)

Best hyperparameters using gradnorm metric WITHOUT preconditoning:
('a9a', 'SGD'):	lr = 2^-16,	loss = 0.425640,	gradnorm = 0.345316,	error = 0.201449
('a9a', 'Adam'):	lr = 2^-16,	loss = 0.384308,	gradnorm = 0.154609,	error = 0.179367
('a9a', 'SARAH'):	lr = 2^-10,	loss = 0.400594,	gradnorm = 0.000804,	error = 0.187749
('a9a', 'L-SVRG'):	lr = 2^-14,	loss = 0.484843,	gradnorm = 0.015331,	error = 0.220827
('w8a', 'SGD'):	lr = 2^-14,	loss = 0.197371,	gradnorm = 0.016864,	error = 0.062160
('w8a', 'Adam'):	lr = 2^-16,	loss = 0.162432,	gradnorm = 0.004597,	error = 0.061150
('w8a', 'SARAH'):	lr = 2^-10,	loss = 0.168572,	gradnorm = 0.000344,	error = 0.063900
('w8a', 'L-SVRG'):	lr = 2^-10,	loss = 0.179330,	gradnorm = 0.001024,	error = 0.062263
('rcv1', 'SGD'):	lr = 2^-12,	loss = 0.170175,	gradnorm = 0.035876,	error = 0.057769
('rcv1', 'Adam'):	lr = 2^-14,	loss = 0.056230,	gradnorm = 0.002812,	error = 0.015777
('rcv1', 'SARAH'):	lr = 2^-8,	loss = 0.105435,	gradnorm = 0.000867,	error = 0.034058
('r