In [56]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline

# Settings

In [265]:
# Data logs root directory
LOG_DIR = "logs2"

# Loss function: either "logistic" regression, or nonlinear least squares ('nllsq')
LOSSES = ("logistic", "nllsq")
LOSS = "logistic"
# Adjust logdir
LOG_DIR = os.path.join(LOG_DIR, LOSS)

# The following should be the same as the one used in run_experiment.py
DATASETS = ("a9a", "a9a(-3,3)", "rcv1", "rcv1(-3,3)")
OPTIMIZERS = ("SARAH",)
T = 100  # Use 2xT used in run_experiment.py

# These are the metrics collected in the data logs
METRICS = ("loss", "gradnorm", "error")
METRIC = "error"  # choose metric

# These are aggregators for comparing multi-seed runs
AGGS = ("mean", "median")
AGG = "mean"  # choose aggregator

# Downsample this number of effective passes by averaging them
AVG_DOWNSAMPLE = 5

# These are the logs columns: effective passes + metrics
LOG_COLS = ["ep"] + list(METRICS)

# These are the hyperparameters of interest
ARG_COLS = ["seed", "lr", "alpha", "beta2"]

# Plots will be generated for this hyperparams/args setting.
# 'corrupt' should be the scale/suffix of the dataset as a string or 'none'.
FILTER_ARGS = {}
SETTINGS_STR = f"loss={LOSS},metric={METRIC}{',' if len(FILTER_ARGS) > 0 else ''}" + \
               ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())

# Ignore all runs containing 'any' of these hyperparams.
IGNORE_ARGS = {}#{"beta2": [None]}

# Force remove log files that are empty
FORCE_REMOVE_EMPTY_DATA = False

# Aspect ratio and height of subplots
ASPECT = 4. / 3.
HEIGHT = 3.

### Utility functions for loading data

In [258]:
def ignore(args_dict):
    return any(args_dict[arg] in IGNORE_ARGS[arg]
               for arg in IGNORE_ARGS.keys() if arg in args_dict)


def loaddata(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    return data


def contain_dict(dict1, dict2):
    return all(dict1[k] == v for k, v in dict2.items() if k in dict1)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [259]:
REMOVE_EMPTY_DATA = False or FORCE_REMOVE_EMPTY_DATA


def unpack_args(fname):
    """
    Recover all args given file path.
    """
    args = {}
    # unpack path
    dirname, logname = os.path.split(fname)
    logdir, args["dataset"] = os.path.split(dirname)
    # parse args
    args["optimizer"], argstr = logname.split("(")
    argstr, _ = argstr.split(")")  # remove ').pkl'
    args_dict = {k:v for k,v in [s.split("=") for s in argstr.split(",")]}

    # Extract args
    if args["dataset"][-1] == ")":
        args["corrupt"] = args["dataset"][args["dataset"].index("("):]
    else:
        # It is very unlikely that the original dataset name will end with ')'
        args["corrupt"] = "none"

    if "seed" in args_dict:
        args["seed"] = args_dict["seed"]
    else:
        args["seed"] = 0

    args["BS"] = args_dict["BS"]
    args["lr"] = args_dict["lr"]
    if "precond" in args_dict:
        args["precond"] = args_dict["precond"].lower()
        args["beta2"] = args_dict["beta2"].lower()
        args["alpha"] = args_dict["alpha"].lower()
    else:
        args["precond"] = "none"
        args["alpha"] = "none"
        args["beta2"] = "none"

    return args


def get_logs(logdir, dataset, optimizer, **filter_args):
    """
    Return all logs in 'logdir' containing the filter hyperparams.
    Dataset name should contain feature scaling, if any
    e.g. 'dataset' or 'dataset(k_min,k_max)'.
    
    Returns the data in the log file and its arguments/hyperparams.
    """
    global REMOVE_EMPTY_DATA
    # Add
    if "corrupt" in filter_args and filter_args['corrupt'] != "none":
        # Add scale suffix to specify dataset    
        dataset += filter_args['corrupt']
    else:
        # No setting specified, use wildcard to match all suffixes
        dataset += "*"
    # Find all files matching this pattern
    for fname in glob.glob(f"{logdir}/{dataset}/{optimizer}(*).pkl"):
        exp_args = unpack_args(fname)
        # Skip if filter_args do not match args of this file
        if not contain_dict(exp_args, filter_args):
            continue
        # Load data
        data = loaddata(fname)
        # Handle empty data files
        if len(data) == 0:
            print(fname, "has no data!")
            if not REMOVE_EMPTY_DATA:
                if "y" == input("Remove empty log files in the future without asking? (y/n)"):
                    print("Will remove without asking.")
                    REMOVE_EMPTY_DATA = True
                else:
                    print("Will ask again before removing.")
            else:
                try:
                    print("Removing", fname)
                    os.remove(fname)
                except OSError as e:
                    print ("Error: %s - %s." % (e.filename, e.strerror))
            continue
        # @XXX: hack to correct wrong initial ep>0 for L-SVRG
        ep0 = data[0,0]
        if ep0 > 0.:
            data[:,0] -= ep0
        yield data, exp_args


# Gather data
all_dfs = {}
start_time = time.time()
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = pd.DataFrame()
    # Get all log data given the experiment and filter args
    for data, args in get_logs(LOG_DIR, *exp, **FILTER_ARGS):
        if ignore(args):
            continue
        # Get experiment log data
        df = pd.DataFrame(data[:, :4], columns=LOG_COLS)
        # Get args of interest
        for col in ARG_COLS:
            df[col] = args[col]
        # Downsample by averaging metrics every AVG_DOWNSAMPLE epoch.
        df["ep"] = np.ceil(df["ep"] / AVG_DOWNSAMPLE) * AVG_DOWNSAMPLE
        df = df.groupby(["ep"] + ARG_COLS).mean().reset_index()
        # Get data up to the prespecified epoch T
        df = df[df["ep"] <= T]
        # @TODO: is this efficient?
        exp_df = exp_df.append(df, ignore_index=True)
    # Record all runs of exp in a single dataframe
    all_dfs[exp] = exp_df

    if len(exp_df) == 0:
        print("No log data found for this experiment!")
        print("- Experiment:", exp)
        print("- filter_args:", FILTER_ARGS)
        continue
data_gather_time = time.time() - start_time
print(f"Data frame lengths:")
for exp, df in all_dfs.items():
    print(f"{exp} -> {len(df)} data rows -> {len(df) // T} runs")
print(f"Took about {data_gather_time:.2f} seconds to gather all these data.")

logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.999999999,alpha=1e-09).pkl has no data!


Remove empty log files in the future without asking? (y/n) y


Will remove without asking.
logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.999999999,alpha=1e-05).pkl has no data!
Removing logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.999999999,alpha=1e-05).pkl
logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1.0).pkl has no data!
Removing logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1.0).pkl
logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1e-05).pkl has no data!
Removing logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1e-05).pkl
logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1e-07).pkl has no data!
Removing logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,precond=hutchinson,beta2=0.99999999999,alpha=1e-07).pkl
logs2/nllsq/rcv1/SARAH(seed=9,BS=128,lr=0.015625,pre

In [260]:
for i, (exp, df) in enumerate(all_dfs.items()):
    if i == 3: break
    print(exp)
    display(df)

('a9a', 'SARAH')


Unnamed: 0,ep,seed,lr,alpha,beta2,loss,gradnorm,error
0,0.0,8,16,0.1,0.9999999,0.250000,8.213275e-03,0.500000
1,5.0,8,16,0.1,0.9999999,0.240757,4.364626e-09,0.240783
2,10.0,8,16,0.1,0.9999999,0.240779,8.803775e-23,0.240779
3,15.0,8,16,0.1,0.9999999,0.240779,8.803740e-23,0.240779
4,20.0,8,16,0.1,0.9999999,0.240779,8.803707e-23,0.240779
...,...,...,...,...,...,...,...,...
355517,80.0,9,1.52587890625e-05,1e-05,0.999999999,0.153125,1.271243e+01,0.225577
355518,85.0,9,1.52587890625e-05,1e-05,0.999999999,0.151963,1.147413e+01,0.224798
355519,90.0,9,1.52587890625e-05,1e-05,0.999999999,0.150900,1.045582e+01,0.222679
355520,95.0,9,1.52587890625e-05,1e-05,0.999999999,0.149884,9.572498e+00,0.220825


('a9a(-3,3)', 'SARAH')


Unnamed: 0,ep,seed,lr,alpha,beta2,loss,gradnorm,error
0,0.0,8,16,0.1,0.9999999,0.250000,2.102320e+02,0.500000
1,5.0,8,16,0.1,0.9999999,0.703629,3.217698e-06,0.703635
2,10.0,8,16,0.1,0.9999999,0.739566,3.460306e-17,0.739566
3,15.0,8,16,0.1,0.9999999,0.739566,3.456410e-17,0.739566
4,20.0,8,16,0.1,0.9999999,0.739566,3.452570e-17,0.739566
...,...,...,...,...,...,...,...,...
177665,80.0,9,1.52587890625e-05,1e-05,0.999999999,0.153125,1.271243e+01,0.225577
177666,85.0,9,1.52587890625e-05,1e-05,0.999999999,0.151963,1.147413e+01,0.224798
177667,90.0,9,1.52587890625e-05,1e-05,0.999999999,0.150900,1.045582e+01,0.222679
177668,95.0,9,1.52587890625e-05,1e-05,0.999999999,0.149884,9.572498e+00,0.220825


('rcv1', 'SARAH')


Unnamed: 0,ep,seed,lr,alpha,beta2,loss,gradnorm,error
0,0.0,8,16,0.1,0.9999999,0.250000,7.943554e+00,0.500000
1,5.0,8,16,0.1,0.9999999,0.377128,5.606938e-05,0.377253
2,10.0,8,16,0.1,0.9999999,0.397638,4.230986e-08,0.397641
3,15.0,8,16,0.1,0.9999999,0.397688,3.001461e-15,0.397688
4,20.0,8,16,0.1,0.9999999,0.397688,3.274446e-15,0.397688
...,...,...,...,...,...,...,...,...
329611,80.0,9,1.52587890625e-05,1e-05,0.999999999,0.066414,7.973784e-06,0.038621
329612,85.0,9,1.52587890625e-05,1e-05,0.999999999,0.064311,7.197242e-06,0.037933
329613,90.0,9,1.52587890625e-05,1e-05,0.999999999,0.062434,6.546938e-06,0.037153
329614,95.0,9,1.52587890625e-05,1e-05,0.999999999,0.060681,5.975403e-06,0.036602


## Get best hyperparams

In [266]:
for exp, df in all_dfs.items():
    if exp[1] == "Adam": continue
    alphas = set([] if "alpha" not in ARG_COLS else df["alpha"])
    betas = set([] if "beta2" not in ARG_COLS else df["beta2"])
    break

best_dfs = {}
best_dfs_alpha = {}
best_dfs_beta = {}
for exp in product(DATASETS, OPTIMIZERS):
    print("Finding best hyperparams for", exp)
    best_dfs_alpha[exp] = {}
    best_dfs_beta[exp] = {}
    # Get last metrics/performance (supposed to be epoch-smoothed for better results)
    exp_df = all_dfs[exp]
    if len(exp_df) == 0:
        continue

    args_fix = [arg for arg in ARG_COLS if arg != "seed"]
    max_ep = exp_df.groupby(args_fix, sort=False)["ep"].transform(max)
    perf = exp_df[exp_df["ep"] == max_ep].drop("ep", axis=1)
    # Find the minimum aggregate metric (based on mean, median, etc.)
    def find_best_perf(perf):
        if AGG == "mean":
            agg_perf = perf.groupby(args_fix).mean()
        elif AGG == "median":
            agg_perf = perf.groupby(args_fix).median()
        # Get the aggregated perf that minimizes the chosen metric
        min_agg_perf = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]]
        best_perfs_args = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]].index
        best_perfs = perf.set_index(args_fix).loc[best_perfs_args[:1]]
        return best_perfs
    # Find the best performances
    best_dfs[exp] = find_best_perf(perf)
    for alpha in alphas:
        best_dfs_alpha[exp][alpha] = find_best_perf(perf[perf["alpha"] == alpha])
    for beta in betas:
        best_dfs_beta[exp][beta] = find_best_perf(perf[perf["beta2"] == beta])

Finding best hyperparams for ('a9a', 'SARAH')
Finding best hyperparams for ('a9a(-3,3)', 'SARAH')
Finding best hyperparams for ('rcv1', 'SARAH')
Finding best hyperparams for ('rcv1(-3,3)', 'SARAH')


In [267]:
print("Best hyperparams for each optimizer on each dataset given the following setting:")
print(FILTER_ARGS)
print()
for exp, df in best_dfs.items():
    print(exp)
    for arg, val in zip(ARG_COLS, df.index[0]):
        if arg == "lr":
            val = "2**" + str(int(log2(float(val))))
        print(f"- {arg} = {val}")
    print()
    display(df)
    print()

Best hyperparams for each optimizer on each dataset given the following setting:
{}

('a9a', 'SARAH')
- seed = 0.00390625
- lr = 2**-36
- alpha = 0.1



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,seed,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.00390625,1e-11,0.1,8,0.687172,0.0,0.687172
0.00390625,1e-11,0.1,3,0.240779,0.0,0.240779
0.00390625,1e-11,0.1,7,0.247044,0.0,0.247044
0.00390625,1e-11,0.1,5,0.57747,0.0,0.57747
0.00390625,1e-11,0.1,1,0.75919,0.0,0.75919
0.00390625,1e-11,0.1,0,0.75919,0.0,0.75919
0.00390625,1e-11,0.1,4,0.266884,0.0,0.266884
0.00390625,1e-11,0.1,6,0.25678,0.0,0.25678
0.00390625,1e-11,0.1,2,0.640736,0.0,0.640736
0.00390625,1e-11,0.1,9,0.24081,0.0,0.24081



('a9a(-3,3)', 'SARAH')
- seed = 0.00390625
- lr = 2**-29
- alpha = 0.1



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,seed,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.00390625,1e-09,0.1,2,0.240011,0.0,0.240011
0.00390625,1e-09,0.1,6,0.347287,0.0,0.347287
0.00390625,1e-09,0.1,4,0.24081,0.0,0.24081
0.00390625,1e-09,0.1,0,0.664722,0.0,0.664722
0.00390625,1e-09,0.1,1,0.261724,0.0,0.261724
0.00390625,1e-09,0.1,5,0.331132,0.0,0.331132
0.00390625,1e-09,0.1,8,0.265532,0.0,0.265532
0.00390625,1e-09,0.1,7,0.242038,0.0,0.242038
0.00390625,1e-09,0.1,3,0.628021,0.0,0.628021



('rcv1', 'SARAH')
- seed = 0.00390625
- lr = 2**-36
- alpha = 0.1



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,seed,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.00390625,1e-11,0.1,8,0.452524,0.0,0.452524
0.00390625,1e-11,0.1,3,0.44956,0.0,0.44956
0.00390625,1e-11,0.1,7,0.465517,0.0,0.465517
0.00390625,1e-11,0.1,5,0.441063,0.0,0.441063
0.00390625,1e-11,0.1,1,0.563284,0.0,0.563284
0.00390625,1e-11,0.1,0,0.511511,0.0,0.511511
0.00390625,1e-11,0.1,4,0.497283,0.0,0.497283
0.00390625,1e-11,0.1,6,0.516155,0.0,0.516155
0.00390625,1e-11,0.1,2,0.435925,0.0,0.435925
0.00390625,1e-11,0.1,8,0.454896,0.0,0.454896



('rcv1(-3,3)', 'SARAH')
- seed = 0.000244140625
- lr = 2**-29
- alpha = 0.1



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,seed,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1
0.000244140625,1e-09,0.1,0,0.440372,0.0,0.440372
0.000244140625,1e-09,0.1,6,0.519069,0.0,0.519069
0.000244140625,1e-09,0.1,1,0.492787,0.0,0.492787
0.000244140625,1e-09,0.1,7,0.443484,0.0,0.443484
0.000244140625,1e-09,0.1,2,0.497629,0.0,0.497629
0.000244140625,1e-09,0.1,4,0.475546,0.0,0.475546
0.000244140625,1e-09,0.1,3,0.424662,0.0,0.424662
0.000244140625,1e-09,0.1,5,0.424464,0.0,0.424464
0.000244140625,1e-09,0.1,8,0.483648,0.0,0.483648





# Plotting

In [268]:
y_greek = {
    "loss": r"$P(w_t)$",
    "gradnorm": r"$||\nabla P(w_t)||^2$",
    "error": "error"
}
best_dfs_mode = {
    "alphas": best_dfs_alpha,
    "betas": best_dfs_beta,
}
mode_greek = {
    "alphas": r"$\alpha$",
    "betas": r"$\beta$",
    "lrs": r"$\eta$"
}

In [269]:
y = METRIC

start_time = time.time()
# Plot 3 rows each one showing some performance metric,
# where the columns are the dataset on which the optim is run.
fig, axes = plt.subplots(1, len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5)
# plt.suptitle(rf"title")
for j, dataset in enumerate(DATASETS):
    for optimizer in OPTIMIZERS:
        exp = (dataset, optimizer)
        if exp not in best_dfs_alpha or exp not in best_dfs_beta:
            continue
        # Get hyperparams of best performance of 'optimizer' on 'dataset'
        # beta_0 = best_dfs_beta[exp]["none"][["gradnorm", "error"]].astype(float)
        beta_avg = best_dfs_beta[exp]["avg"][["gradnorm", "error"]].astype(float)
        beta_df = pd.DataFrame()
        alpha_df = pd.DataFrame()
        for k, v in best_dfs_beta[exp].items():
            if k == "none": continue
            if k == "avg": continue
            row = v.reset_index()[["beta2", "gradnorm", "error"]].astype(float)
            row["1-beta2"] = 1 - row["beta2"]
            beta_df = beta_df.append(row, ignore_index=True)
        for k, v in best_dfs_alpha[exp].items():
            row = v.reset_index()[["alpha", "gradnorm", "error"]].astype(float)
            alpha_df = alpha_df.append(row, ignore_index=True)

        print(f"Plotting lines for {exp}...")
        sns.lineplot(x="alpha", y=y, label=r"$\alpha$", color="tab:blue",
                     ax=axes[j], data=alpha_df, err_style="bars", marker='o')
        sns.lineplot(x="1-beta2", y=y, label=r"$1-\beta$", color="tab:orange",
                     ax=axes[j], data=beta_df, err_style="bars", marker='o')
        # All of this just to plot the beta avg line with ci interval
        beta_avg_line0 = beta_avg[[y]]
        beta_avg_line0["x"] = 0
        beta_avg_line1 = beta_avg[[y]]
        beta_avg_line1["x"] = 1
        beta_avg_line = beta_avg_line0.append(beta_avg_line1).reset_index()
        sns.lineplot(y=y, x='x', label=r"$1-1/(t+1)$", color="tab:red", ax=axes[j], data=beta_avg_line, linestyle='--')
        axes[j].set_xscale("log")
        axes[j].set_yscale("log")
        axes[j].set_title(fr"$\tt {optimizer}({dataset})$")
        axes[j].set_ylabel(y_greek[y])
        axes[j].set_xlabel("Parameter")
        axes[j].legend() 
fig.tight_layout()

# Create a string out of filter args and save figure
plt.savefig(f"plots/compare_{y}({SETTINGS_STR}).pdf")
plt.close()
plot_best_time = time.time() - start_time
print(f"Took about {plot_best_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('a9a(-3,3)', 'SARAH')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('rcv1(-3,3)', 'SARAH')...
Took about 2.20 seconds to create this plot.
