In [88]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline

# Settings

In [99]:
# Data logs root directory
LOG_DIR = "logs_fine"

# Loss function: either "logistic" regression, or "nonlinear" least squares
LOSSES = ("logistic", "nonlinear")
LOSS = "logistic"
# Adjust logdir
LOG_DIR = os.path.join(LOG_DIR, LOSS)

# The following should be the same as the one used in run_experiment.py
DATASETS = ("a9a", "a9a(-3,3)", "rcv1", "rcv1(-3,3)")
OPTIMIZERS = ("SARAH",)
T = 100  # Use 2xT used in run_experiment.py

# These are the metrics collected in the data logs
METRICS = ("loss", "gradnorm", "error")
METRIC = "error"  # choose metric

# These are aggregators for comparing multi-seed runs
AGGS = ("mean", "median")
AGG = "mean"  # choose aggregator

# Downsample this number of effective passes by averaging them
AVG_DOWNSAMPLE = 5

# These are the logs columns: effective passes + metrics
LOG_COLS = ["ep"] + list(METRICS)

# These are the hyperparameters of interest
ARG_COLS = ["lr", "alpha", "beta2"]

# Plots will be generated for this hyperparams/args setting.
# 'corrupt' should be the scale/suffix of the dataset as a string or 'none'.
FILTER_ARGS = {}
SETTINGS_STR = f"loss={LOSS},metric={METRIC}{',' if len(FILTER_ARGS) > 0 else ''}" + \
               ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())

# Ignore all runs containing 'any' of these hyperparams.
IGNORE_ARGS = {}#{"beta2": [None]}

# Force remove log files that are empty
FORCE_REMOVE_EMPTY_DATA = False

### Utility functions for loading data

In [90]:
def ignore(args_dict):
    return any(args_dict[arg] in IGNORE_ARGS[arg]
               for arg in IGNORE_ARGS.keys() if arg in args_dict)


def loaddata(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    return data


def contain_dict(dict1, dict2):
    return all(dict1[k] == v for k, v in dict2.items() if k in dict1)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [93]:
REMOVE_EMPTY_DATA = False or FORCE_REMOVE_EMPTY_DATA


def unpack_args(fname):
    """
    Recover all args given file path.
    """
    args = {}
    # unpack path
    dirname, logname = os.path.split(fname)
    logdir, args["dataset"] = os.path.split(dirname)
    # parse args
    args["optimizer"], argstr = logname.split("(")
    argstr, _ = argstr.split(")")  # remove ').pkl'
    args_dict = {k:v for k,v in [s.split("=") for s in argstr.split(",")]}

    # Extract args
    if args["dataset"][-1] == ")":
        args["corrupt"] = args["dataset"][args["dataset"].index("("):]
    else:
        # It is very unlikely that the original dataset name will end with ')'
        args["corrupt"] = "none"

    if "seed" in args_dict:
        args["seed"] = args_dict["seed"]
    else:
        args["seed"] = 0

    args["BS"] = args_dict["BS"]
    args["lr"] = args_dict["lr"]
    if "precond" in args_dict:
        args["precond"] = args_dict["precond"].lower()
        args["beta2"] = args_dict["beta2"].lower()
        args["alpha"] = args_dict["alpha"].lower()
    else:
        args["precond"] = "none"
        args["alpha"] = "none"
        args["beta2"] = "none"

    return args


def get_logs(logdir, dataset, optimizer, **filter_args):
    """
    Return all logs in 'logdir' containing the filter hyperparams.
    Dataset name should contain feature scaling, if any
    e.g. 'dataset' or 'dataset(k_min,k_max)'.
    
    Returns the data in the log file and its arguments/hyperparams.
    """
    global REMOVE_EMPTY_DATA
    # Add
    if "corrupt" in filter_args and filter_args['corrupt'] != "none":
        # Add scale suffix to specify dataset    
        dataset += filter_args['corrupt']
    else:
        # No setting specified, use wildcard to match all suffixes
        dataset += "*"
    # Find all files matching this pattern
    for fname in glob.glob(f"{logdir}/{dataset}/{optimizer}(*).pkl"):
        exp_args = unpack_args(fname)
        # Skip if filter_args do not match args of this file
        if not contain_dict(exp_args, filter_args):
            continue
        # Load data
        data = loaddata(fname)
        # Handle empty data files
        if len(data) == 0:
            print(fname, "has no data!")
            if not REMOVE_EMPTY_DATA:
                if "y" == input("Remove empty log files in the future without asking? (y/n)"):
                    print("Will remove without asking.")
                    REMOVE_EMPTY_DATA = True
                else:
                    print("Will ask again before removing.")
            else:
                try:
                    print("Removing", fname)
                    os.remove(fname)
                except OSError as e:
                    print ("Error: %s - %s." % (e.filename, e.strerror))
            continue
        # @XXX: hack to correct wrong initial ep>0 for L-SVRG
        ep0 = data[0,0]
        if ep0 > 0.:
            data[:,0] -= ep0
        yield data, exp_args


# Gather data
all_dfs = {}
start_time = time.time()
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = pd.DataFrame()
    # Get all log data given the experiment and filter args
    for data, args in get_logs(LOG_DIR, *exp, **FILTER_ARGS):
        if ignore(args):
            continue
        # Get experiment log data
        df = pd.DataFrame(data[:, :4], columns=LOG_COLS)
        # Get args of interest
        for col in ARG_COLS:
            df[col] = args[col]
        # Downsample by averaging metrics every AVG_DOWNSAMPLE epoch.
        df["ep"] = np.ceil(df["ep"] / AVG_DOWNSAMPLE) * AVG_DOWNSAMPLE
        df = df.groupby(["ep"] + ARG_COLS).mean().reset_index()
        # Get data up to the prespecified epoch T
        df = df[df["ep"] <= T]
        # @TODO: is this efficient?
        exp_df = exp_df.append(df, ignore_index=True)
    # Record all runs of exp in a single dataframe
    all_dfs[exp] = exp_df

    if len(exp_df) == 0:
        print("No log data found for this experiment!")
        print("- Experiment:", exp)
        print("- filter_args:", FILTER_ARGS)
        continue
data_gather_time = time.time() - start_time
print(f"Data frame lengths:")
for exp, df in all_dfs.items():
    print(f"{exp} -> {len(df)} data rows -> {len(df) // T} runs")
print(f"Took about {data_gather_time:.2f} seconds to gather all these data.")

logs_fine/logistic/a9a/SARAH(seed=0,BS=128,lr=1.52587890625e-05,precond=hutchinson,beta2=avg,alpha=1e-09).pkl has no data!


Remove empty log files in the future without asking? (y/n) y


Will remove without asking.
Data frame lengths:
('a9a', 'SARAH') -> 125015 data rows -> 1250 runs
('a9a(-3,3)', 'SARAH') -> 62052 data rows -> 620 runs
('rcv1', 'SARAH') -> 126126 data rows -> 1261 runs
('rcv1(-3,3)', 'SARAH') -> 63063 data rows -> 630 runs
Took about 80.81 seconds to gather all these data.


In [94]:
for i, (exp, df) in enumerate(all_dfs.items()):
    if i == 3: break
    print(exp)
    display(df)

('a9a', 'SARAH')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,0.0625,0.1,0.6,6.931472e-01,0.032853,0.500000
1,5.0,0.0625,0.1,0.6,4.121946e-01,0.002434,0.186834
2,10.0,0.0625,0.1,0.6,3.488606e-01,0.000042,0.161359
3,15.0,0.0625,0.1,0.6,3.383863e-01,0.000016,0.158175
4,20.0,0.0625,0.1,0.6,3.338007e-01,0.000008,0.156812
...,...,...,...,...,...,...,...
125010,70.0,1,1e-11,0.3,2.430087e+15,3402.207322,0.331658
125011,75.0,1,1e-11,0.3,7.512257e+14,3504.809660,0.291622
125012,80.0,1,1e-11,0.3,7.117048e+14,1238.717116,0.243299
125013,85.0,1,1e-11,0.3,7.211711e+14,853.065245,0.211348


('a9a(-3,3)', 'SARAH')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,0.0625,0.1,0.6,6.931472e-01,5089.893187,0.500000
1,5.0,0.0625,0.1,0.6,1.589830e+03,769.509470,0.212319
2,10.0,0.0625,0.1,0.6,8.723444e+03,4424.594719,0.300699
3,15.0,0.0625,0.1,0.6,3.783912e+03,503.284495,0.235460
4,20.0,0.0625,0.1,0.6,2.988904e+03,1231.502111,0.236289
...,...,...,...,...,...,...,...
62047,70.0,1,1e-11,0.3,2.430087e+15,3402.207322,0.331658
62048,75.0,1,1e-11,0.3,7.512257e+14,3504.809660,0.291622
62049,80.0,1,1e-11,0.3,7.117048e+14,1238.717116,0.243299
62050,85.0,1,1e-11,0.3,7.211711e+14,853.065245,0.211348


('rcv1', 'SARAH')


Unnamed: 0,ep,lr,alpha,beta2,loss,gradnorm,error
0,0.0,0.0625,0.1,0.6,6.931472e-01,36.643960,0.500000
1,5.0,0.0625,0.1,0.6,1.109013e+03,21.246745,0.299025
2,10.0,0.0625,0.1,0.6,6.093965e+02,5.993189,0.207456
3,15.0,0.0625,0.1,0.6,6.062551e+02,7.156348,0.188306
4,20.0,0.0625,0.1,0.6,4.839257e+02,5.331601,0.163196
...,...,...,...,...,...,...,...
126121,80.0,1,1e-11,0.3,3.240055e+08,0.000002,0.022780
126122,85.0,1,1e-11,0.3,3.086169e+08,0.000003,0.023131
126123,90.0,1,1e-11,0.3,2.696809e+08,0.000001,0.018305
126124,95.0,1,1e-11,0.3,2.608952e+08,0.000001,0.019196


## Get best hyperparams

In [100]:
for exp, df in all_dfs.items():
    if exp[1] == "Adam": continue
    alphas = set([] if "alpha" not in ARG_COLS else df["alpha"])
    betas = set([] if "beta2" not in ARG_COLS else df["beta2"])
    break

best_dfs = {}
best_dfs_alpha = {}
best_dfs_beta = {}
for exp in product(DATASETS, OPTIMIZERS):
    print("Finding best hyperparams for", exp)
    best_dfs_alpha[exp] = {}
    best_dfs_beta[exp] = {}
    # Get last metrics/performance (supposed to be epoch-smoothed for better results)
    exp_df = all_dfs[exp]
    if len(exp_df) == 0:
        continue
    max_ep = exp_df.groupby(ARG_COLS, sort=False)["ep"].transform(max)
    perf = exp_df[exp_df["ep"] == max_ep].drop("ep", axis=1)
    # Find the minimum aggregate metric (based on mean, median, etc.)
    def find_best_perf(perf):
        if AGG == "mean":
            agg_perf = perf.groupby(ARG_COLS).mean()
        elif AGG == "median":
            agg_perf = perf.groupby(ARG_COLS).median()
        # Get the aggregated perf that minimizes the chosen metric
        min_agg_perf = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]]
        return min_agg_perf.sort_index().iloc[[0]]
    # Find the best performances
    best_dfs[exp] = find_best_perf(perf)
    for alpha in alphas:
        best_dfs_alpha[exp][alpha] = find_best_perf(perf[perf["alpha"] == alpha])
    for beta in betas:
        best_dfs_beta[exp][beta] = find_best_perf(perf[perf["beta2"] == beta])

Finding best hyperparams for ('a9a', 'SARAH')
Finding best hyperparams for ('a9a(-3,3)', 'SARAH')
Finding best hyperparams for ('rcv1', 'SARAH')
Finding best hyperparams for ('rcv1(-3,3)', 'SARAH')


In [101]:
print("Best hyperparams for each optimizer on each dataset given the following setting:")
print(FILTER_ARGS)
print()
for exp, df in best_dfs.items():
    print(exp)
    for arg, val in zip(ARG_COLS, df.index[0]):
        if arg == "lr":
            val = "2**" + str(int(log2(float(val))))
        print(f"- {arg} = {val}")
    print()
    display(df)
    print()

Best hyperparams for each optimizer on each dataset given the following setting:
{}

('a9a', 'SARAH')
- lr = 2**-4
- alpha = 0.001
- beta2 = 0.9



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.0625,0.001,0.9,0.322671,6.040132e-11,0.151058



('a9a(-3,3)', 'SARAH')
- lr = 2**-2
- alpha = 1e-09
- beta2 = 0.99999



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.25,1e-09,0.99999,7.187007,0.007288,0.1517



('rcv1', 'SARAH')
- lr = 2**-10
- alpha = 1e-05
- beta2 = avg



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.0009765625,1e-05,avg,422.173962,0.008904,0.005741



('rcv1(-3,3)', 'SARAH')
- lr = 2**-10
- alpha = 0.001
- beta2 = 0.999



Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,loss,gradnorm,error
lr,alpha,beta2,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
0.0009765625,0.001,0.999,0.046534,0.001897,0.005426





# Plotting

In [102]:
y_greek = {
    "loss": r"$P(w_t)$",
    "gradnorm": r"$||\nabla P(w_t)||^2$",
    "error": "error"
}
best_dfs_mode = {
    "alphas": best_dfs_alpha,
    "betas": best_dfs_beta,
}
mode_greek = {
    "alphas": r"$\alpha$",
    "betas": r"$\beta$",
    "lrs": r"$\eta$"
}

In [103]:
y = METRIC

start_time = time.time()
# Plot 3 rows each one showing some performance metric,
# where the columns are the dataset on which the optim is run.
fig, axes = plt.subplots(1, len(DATASETS))
fig.set_size_inches(5 * len(DATASETS), 5)
# plt.suptitle(rf"title")
for j, dataset in enumerate(DATASETS):
    for optimizer in OPTIMIZERS:
        exp = (dataset, optimizer)
        if exp not in best_dfs_alpha or exp not in best_dfs_beta:
            continue
        # Get hyperparams of best performance of 'optimizer' on 'dataset'
        # beta_0 = best_dfs_beta[exp]["none"][["gradnorm", "error"]].astype(float)
        beta_avg = best_dfs_beta[exp]["avg"][["gradnorm", "error"]].astype(float)
        beta_df = pd.DataFrame()
        alpha_df = pd.DataFrame()
        for k, v in best_dfs_beta[exp].items():
            if k == "none": continue
            if k == "avg": continue
            row = v.reset_index()[["beta2", "gradnorm", "error"]].astype(float)
            row["1-beta2"] = 1 - row["beta2"]
            beta_df = beta_df.append(row, ignore_index=True)
        for k, v in best_dfs_alpha[exp].items():
            row = v.reset_index()[["alpha", "gradnorm", "error"]].astype(float)
            alpha_df = alpha_df.append(row, ignore_index=True)

        print(f"Plotting lines for {exp}...")
        sns.lineplot(x="alpha", y=y, label=r"$\alpha$", color="tab:blue", ax=axes[j], data=alpha_df, err_style="bars")
        sns.scatterplot(x="alpha", y=y, color="tab:blue", ax=axes[j], data=alpha_df)
        sns.lineplot(x="1-beta2", y=y, label=r"$1-\beta$", color="tab:orange", ax=axes[j], data=beta_df, err_style="bars")
        sns.scatterplot(x="1-beta2", y=y, color="tab:orange", ax=axes[j], data=beta_df)
        # axes[j].axhline(y=beta_0[y].item(), label=r"$1-1/(t+1)$", color="tab:orange", linestyle='--')
        axes[j].axhline(y=beta_avg[y].item(), label=r"$1-1/(t+1)$", color="tab:red", linestyle='--')
        axes[j].set_xscale("log")
        axes[j].set_yscale("log")
        axes[j].set_title(fr"$\tt {optimizer}({dataset})$")
        axes[j].set_ylabel(y_greek[y])
        axes[j].set_xlabel("Parameter")
        axes[j].legend() 
fig.tight_layout()

# Create a string out of filter args and save figure
plt.savefig(f"plots/compare_{y}({SETTINGS_STR}).pdf")
plt.close()
plot_best_time = time.time() - start_time
print(f"Took about {plot_best_time:.2f} seconds to create this plot.")

Plotting lines for ('a9a', 'SARAH')...
Plotting lines for ('a9a(-3,3)', 'SARAH')...
Plotting lines for ('rcv1', 'SARAH')...
Plotting lines for ('rcv1(-3,3)', 'SARAH')...
Took about 0.98 seconds to create this plot.
