In [1]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline

# Settings

In [2]:
# Data logs root directory
LOG_DIR = "logs_torch"

# Loss function: either "logistic" regression, or nonlinear least squares ('nllsq')
LOSSES = ("logistic", "nllsq")
LOSS = "cross_entropy"
# Adjust logdir
LOG_DIR = os.path.join(LOG_DIR, LOSS)

# The following should be the same as the one used in run_experiment.py
DATASETS = ("mnist",)
dataset = "mnist"
OPTIMIZERS = ("L-SVRG", "Adam")
T = 25  # Use 2xT used in run_experiment.py

# These are the metrics collected in the data logs
METRICS = ("loss", "gradnorm", "error")
METRIC = "error"  # choose metric

# These are aggregators for comparing multi-seed runs
AGGS = ("mean", "median")
AGG = "mean"  # choose aggregator

# Downsample this number of effective passes by averaging them
AVG_DOWNSAMPLE = 2

# These are the logs columns: effective passes + metrics
LOG_COLS = ["ep"] + list(METRICS)

# These are the hyperparameters of interest
ARG_COLS = ["lr", "alpha", "beta2", "precond"]

# Plots will be generated for this hyperparams/args setting.
# 'corrupt' should be the scale/suffix of the dataset as a string or 'none'.
FILTER_ARGS = {
}
SETTINGS_STR = f"loss={LOSS},metric={METRIC}," + \
               ",".join(f"{k}={v}" for k,v in FILTER_ARGS.items())

# Ignore all runs containing 'any' of these hyperparams.
IGNORE_ARGS = {
    "alpha": (1e-11,),
    "weight_decay": (0.1,),
}

# Force remove log files that are empty
FORCE_REMOVE_EMPTY_DATA = False

# Aspect ratio and height of subplots
ASPECT = 4. / 3.
HEIGHT = 3.
HEIGHT_LARGE = 4.
LEGEND_FONTSIZE = "x-small"
LEGEND_LOC = "upper right"

### Utility functions for loading data

In [3]:
def ignore(args_dict):
    return any(args_dict[arg] in map(str, IGNORE_ARGS[arg])
               for arg in IGNORE_ARGS.keys() if arg in args_dict)


def loaddata(fname):
    with open(fname, 'rb') as f:
        data = pickle.load(f)
    return data


def contain_dict(dict1, dict2):
    return all(dict1[k] == v for k, v in dict2.items() if k in dict1)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [4]:
REMOVE_EMPTY_DATA = False or FORCE_REMOVE_EMPTY_DATA


def unpack_args(fname):
    """
    Recover all args given file path.
    """
    args = {}
    # unpack path
    dirname, logname = os.path.split(fname)
    logdir, args["dataset"] = os.path.split(dirname)
    # parse args
    args["optimizer"], argstr = logname.split("(")
    argstr, _ = argstr.split(")")  # remove ').pkl'
    args_dict = {k:v for k,v in [s.split("=") for s in argstr.split(",")]}

    # Extract args
    if args["dataset"][-1] == ")":
        args["corrupt"] = args["dataset"][args["dataset"].index("("):]
    else:
        # It is very unlikely that the original dataset name will end with ')'
        args["corrupt"] = "none"

    if "seed" in args_dict:
        args["seed"] = args_dict["seed"]
    else:
        args["seed"] = '0'

    args["BS"] = args_dict["batch_size"]
    args["lr"] = args_dict["lr"]
    if "weight_decay" in args_dict:
        args["weight_decay"] = args_dict["weight_decay"]
    else:
        args["weight_decay"] = '0'
    if "lr_decay" in args_dict:
        args["lr_decay"] = args_dict["lr_decay"]
    else:
        args["lr_decay"] = '0'
    if "p" in args_dict:
        args["p"] = args_dict["p"]
    if "precond" in args_dict:
        args["precond"] = args_dict["precond"]
        args["beta2"] = args_dict["beta2"]
        args["alpha"] = args_dict["alpha"]
    else:
        args["precond"] = "none"
        args["alpha"] = "none"
        args["beta2"] = "none"

    if args["optimizer"] == "Adam":
        args["beta1"] = args_dict["beta1"] if "beta1" in args_dict else '0.9'
        args["beta2"] = args_dict["beta2"]
        args["eps"] = args_dict["eps"] if "eps" in args_dict else '1e-8'

    return args


def get_logs(logdir, dataset, optimizer, **filter_args):
    """
    Return all logs in 'logdir' containing the filter hyperparams.
    Dataset name should contain feature scaling, if any
    e.g. 'dataset' or 'dataset(k_min,k_max)'.
    
    Returns the data in the log file and its arguments/hyperparams.
    """
    global REMOVE_EMPTY_DATA
    # Add
    if "corrupt" in filter_args and filter_args['corrupt'] != "none":
        # Add scale suffix to specify dataset    
        dataset += filter_args['corrupt']
    else:
        # No setting specified, use wildcard to match all suffixes
        dataset += "*"
    # Find all files matching this pattern
    for fname in glob.glob(f"{logdir}/{dataset}/{optimizer}(*).pkl"):
        exp_args = unpack_args(fname)
        # Skip if filter_args do not match args of this file
        if not contain_dict(exp_args, filter_args):
            continue
        # Load data
        data = loaddata(fname)
        # Handle empty data files
        if len(data) == 0:
            print(fname, "has no data!")
            if not REMOVE_EMPTY_DATA:
                if "y" == input("Remove empty log files in the future without asking? (y/n)"):
                    print("Will remove without asking.")
                    REMOVE_EMPTY_DATA = True
                else:
                    print("Will ask again before removing.")
            else:
                try:
                    print("Removing", fname)
                    os.remove(fname)
                except OSError as e:
                    print ("Error: %s - %s." % (e.filename, e.strerror))
            continue
        # @XXX: hack to correct wrong initial ep>0 for L-SVRG
        data = np.array(data)
        ep0 = data[0,0]
        if ep0 > 0.:
            data[:,0] -= ep0
        yield data, exp_args

        
# Gather data
all_dfs = {}
start_time = time.time()
for exp in product(DATASETS, OPTIMIZERS):
    exp_df = pd.DataFrame()
    # Get all log data given the experiment and filter args
    for data, args in get_logs(LOG_DIR, *exp, **FILTER_ARGS):
        if ignore(args):
            continue
        # Get experiment log data
        df = pd.DataFrame(data[:, :4], columns=LOG_COLS)
        # Get args of interest
        for col in ARG_COLS:
            df[col] = args[col]
        # Downsample by averaging metrics every AVG_DOWNSAMPLE epoch.
        df["ep"] = np.ceil(df["ep"] / AVG_DOWNSAMPLE) * AVG_DOWNSAMPLE
        df = df.groupby(["ep"] + ARG_COLS).mean().reset_index()
        # Get data up to the prespecified epoch T
        df = df[df["ep"] <= T]
        # @TODO: is this efficient?
        exp_df = exp_df.append(df, ignore_index=True)
    # Record all runs of exp in a single dataframe
    all_dfs[exp] = exp_df

    if len(exp_df) == 0:
        print("No log data found for this experiment!")
        print("- Experiment:", exp)
        print("- filter_args:", FILTER_ARGS)
        continue
data_gather_time = time.time() - start_time
print(f"Data frame lengths:")
for exp, df in all_dfs.items():
    print(f"{exp} -> {len(df)} data rows -> {len(df) // T} runs")
print(f"Took about {data_gather_time:.2f} seconds to gather all these data.")

logs_torch/cross_entropy/mnist/L-SVRG(seed=2,batch_size=128,lr=0.0625,p=0.999,precond=hutchinson,beta2=avg,alpha=0.001,warmup=100).pkl has no data!


Remove empty log files in the future without asking? (y/n) y


Will remove without asking.
logs_torch/cross_entropy/mnist/L-SVRG(seed=0,batch_size=128,lr=0.0625,p=0.999,precond=hutchinson,beta2=avg,alpha=super,warmup=100).pkl has no data!
Removing logs_torch/cross_entropy/mnist/L-SVRG(seed=0,batch_size=128,lr=0.0625,p=0.999,precond=hutchinson,beta2=avg,alpha=super,warmup=100).pkl
Data frame lengths:
('mnist', 'L-SVRG') -> 6382 data rows -> 255 runs
('mnist', 'Adam') -> 1378 data rows -> 55 runs
Took about 28.28 seconds to gather all these data.


In [5]:
for i, (exp, df) in enumerate(all_dfs.items()):
    if i == 3: break
    print(exp)
    display(df)

('mnist', 'L-SVRG')


Unnamed: 0,ep,lr,alpha,beta2,precond,loss,gradnorm,error
0,0.0,6.103515625e-05,0.001,0.99,hutchinson,2.303674,0.015701,0.899200
1,2.0,6.103515625e-05,0.001,0.99,hutchinson,2.297054,0.017216,0.890300
2,4.0,6.103515625e-05,0.001,0.99,hutchinson,1.187880,1.103850,0.329450
3,6.0,6.103515625e-05,0.001,0.99,hutchinson,0.453627,0.456129,0.125820
4,8.0,6.103515625e-05,0.001,0.99,hutchinson,0.341007,0.254002,0.099100
...,...,...,...,...,...,...,...,...
6377,16.0,0.015625,1e-07,0.99,hutchinson,5216.137327,0.921152,0.910800
6378,18.0,0.015625,1e-07,0.99,hutchinson,3748.250835,0.418313,0.905743
6379,20.0,0.015625,1e-07,0.99,hutchinson,2757.735963,0.365667,0.899000
6380,22.0,0.015625,1e-07,0.99,hutchinson,1977.116441,0.241248,0.899911


('mnist', 'Adam')


Unnamed: 0,ep,lr,alpha,beta2,precond,loss,gradnorm,error
0,0.0,0.015625,none,0.99,none,2.303674,0.015701,0.89920
1,2.0,0.015625,none,0.99,none,0.137447,1.365845,0.04301
2,4.0,0.015625,none,0.99,none,0.117942,1.311841,0.03209
3,6.0,0.015625,none,0.99,none,0.128118,1.331316,0.03217
4,8.0,0.015625,none,0.99,none,0.124661,1.147403,0.02900
...,...,...,...,...,...,...,...,...
1373,16.0,0.0009765625,none,0.99,none,0.037298,0.528722,0.00934
1374,18.0,0.0009765625,none,0.99,none,0.036511,0.123733,0.00888
1375,20.0,0.0009765625,none,0.99,none,0.047172,1.565055,0.00989
1376,22.0,0.0009765625,none,0.99,none,0.042677,0.567044,0.00926


## Get best hyperparams

In [6]:
preconds = [] if "precond" not in ARG_COLS else ["none", "hutchinson"]
for exp, df in all_dfs.items():
    if exp[1] == "Adam": continue
    alphas = set([] if "alpha" not in ARG_COLS else df["alpha"])
    betas = set([] if "beta2" not in ARG_COLS else df["beta2"])
    lrs = set([] if "lr" not in ARG_COLS else df["lr"])
    break

best_dfs = {}
best_dfs_alpha = {}
best_dfs_beta = {}
best_dfs_lr = {}
best_dfs_precond = {}
for exp in product(DATASETS, OPTIMIZERS):
    print("Finding best hyperparams for", exp)
    best_dfs_alpha[exp] = {}
    best_dfs_beta[exp] = {}
    best_dfs_lr[exp] = {}
    best_dfs_precond[exp] = {}
    # Get last metrics/performance (supposed to be epoch-smoothed for better results)
    exp_df = all_dfs[exp]
    max_ep = exp_df.groupby(ARG_COLS, sort=False)["ep"].transform(max)
    perf = exp_df[exp_df["ep"] == max_ep].drop("ep", axis=1)
    # Find the minimum aggregate metric (based on mean, median, etc.)
    def find_best_hyperparams(perf):
        if AGG == "mean":
            agg_perf = perf.groupby(ARG_COLS).mean()
        elif AGG == "median":
            agg_perf = perf.groupby(ARG_COLS).median()
        # Get the aggregated perf that minimizes the chosen metric
        min_agg_perf = agg_perf[agg_perf[METRIC] == agg_perf.min()[METRIC]]
        return min_agg_perf.index
    # Get the data associated with the args of the min aggregated metric
    exp_df = exp_df.set_index(ARG_COLS)
    best_dfs[exp] = exp_df.loc[find_best_hyperparams(perf)]
    for alpha in alphas:
        best_dfs_alpha[exp][alpha] = exp_df.loc[find_best_hyperparams(perf[perf["alpha"] == alpha])]
    for beta in betas:
        best_dfs_beta[exp][beta] = exp_df.loc[find_best_hyperparams(perf[perf["beta2"] == beta])]
    for lr in lrs:
        best_dfs_lr[exp][lr] = exp_df.loc[find_best_hyperparams(perf[perf["lr"] == lr])]
    for precond in preconds:
        best_dfs_precond[exp][precond] = exp_df.loc[find_best_hyperparams(perf[perf["precond"] == precond])]

Finding best hyperparams for ('mnist', 'L-SVRG')
Finding best hyperparams for ('mnist', 'Adam')


In [7]:
print("Best hyperparams for each optimizer on each dataset given the following setting:")
print(FILTER_ARGS)
print()
for exp, df in best_dfs.items():
    print(exp)
    for arg, val in zip(ARG_COLS, df.index[0]):
        if arg == "lr":
            val = "2**" + str(int(log2(float(val))))
        print(f"- {arg} = {val}")
    print()

Best hyperparams for each optimizer on each dataset given the following setting:
{}

('mnist', 'L-SVRG')
- lr = 2**-6
- alpha = 0.1
- beta2 = 0.99
- precond = hutchinson

('mnist', 'Adam')
- lr = 2**-10
- alpha = none
- beta2 = 0.99
- precond = none



# Plotting

In [8]:
print("Types")
for col in df.columns:
    print(col, df[col].dtypes)

Types
ep float64
loss float64
gradnorm float64
error float64


In [9]:
print("Learning rates:")
for exp, df in all_dfs.items():
    display(set("2**"+str(int(log2(float(lr)))) for lr in df["lr"]))
    break

Learning rates:


{'2**-10', '2**-12', '2**-14', '2**-4', '2**-6', '2**-8'}

In [10]:
print("Range")
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

Range
ep: (0.0, 24.0)
loss: (0.02400844679877423, inf)
gradnorm: (2.343360453475422e-05, inf)
error: (0.0071, 0.91315)


In [11]:
# Set inf values to nan and recheck range
VERYBIGNUMBER = 10**10
df[df == float("inf")] = np.nan
df[df[["loss","gradnorm"]] > VERYBIGNUMBER] = np.nan
for col in df.columns:
    if df[col].dtypes != "object":
        print(f"{col}: ({df[col].min():}, {df[col].max()})")

ep: (0.0, 24.0)
loss: (0.02400844679877423, 9562093005.43324)
gradnorm: (2.343360453475422e-05, 8879064985.6)
error: (0.0071, 0.91315)


In [12]:
print("Check for NaNs in each column for each df.")
for col in df.columns:
    print(col)
    for exp, df in all_dfs.items():
        print(f"- {exp}: {df[col].isna().sum()}")

Check for NaNs in each column for each df.
ep
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0
lr
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0
alpha
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0
beta2
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0
precond
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0
loss
- ('mnist', 'L-SVRG'): 367
- ('mnist', 'Adam'): 0
gradnorm
- ('mnist', 'L-SVRG'): 544
- ('mnist', 'Adam'): 0
error
- ('mnist', 'L-SVRG'): 0
- ('mnist', 'Adam'): 0


## Plot best performance of each optimizer on each dataset

In [13]:
plt.rc('legend', fontsize=LEGEND_FONTSIZE, loc=LEGEND_LOC)

In [14]:
start_time = time.time()
# Plot 3 rows each one showing some performance metric,
# where the columns are the dataset on which the optim is run.
fig, axes = plt.subplots(1, 3)
fig.set_size_inches(ASPECT * HEIGHT * 3, HEIGHT * 1)
plt.suptitle(rf"Best performances on MNIST")
for optimizer in OPTIMIZERS:
    exp = (dataset, optimizer)
    if exp not in best_dfs:
        continue
    # Get hyperparams of best performance of 'optimizer' on 'dataset'
    args = {k:v for k,v in zip(best_dfs[exp].index.names, best_dfs[exp].index[0])}
    exp_df = best_dfs[exp].reset_index()
    # Show power of lr as 2^lr_pow
    lr_pow = round(log2(float(args['lr'])))
    if optimizer == "Adam":
        sublabel = rf"$\eta = 2^{{{lr_pow}}}$, $\beta_1=0.9$, $\beta_2={args['beta2']}$"
    else:
        sublabel = rf"$\eta = 2^{{{lr_pow}}}$, $\alpha={args['alpha']}$, $\beta={args['beta2']}$"
    label = rf"{optimizer}({sublabel})"
    print(f"Plotting lines for {exp}...")
    sns.lineplot(x="ep", y="loss", label=label, ax=axes[0], data=exp_df)
    sns.lineplot(x="ep", y="gradnorm", label=label, ax=axes[1], data=exp_df)
    sns.lineplot(x="ep", y="error", label=label, ax=axes[2], data=exp_df)
# Loss
axes[0].set_ylabel(r"$P(w_t)$")
axes[0].set_xlabel("Effective Passes")
axes[0].legend()
# Gradnorm
axes[1].set(yscale="log")
axes[1].set_ylabel(r"$||\nabla P(w_t)||^2$")
axes[1].set_xlabel("Effective Passes")
axes[1].legend()
# Error
axes[2].set(yscale="log")
axes[2].set_ylabel("Error")
axes[2].set_xlabel("Effective Passes")
axes[2].legend()
fig.tight_layout()

# Create a string out of filter args and save figure
plt.savefig(f"plots/perf({SETTINGS_STR}).pdf")
plt.close()
plot_best_time = time.time() - start_time
print(f"Took about {plot_best_time:.2f} seconds to create this plot.")

Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Took about 1.58 seconds to create this plot.


## Plot best performance given a fixed value of either $\alpha$, $\beta$, or $\eta$

In [15]:
y_greek = {
    "loss": r"$P(w_t)$",
    "gradnorm": r"$||\nabla P(w_t)||^2$",
    "error": "error"
}

best_dfs_mode= {
    "alphas": best_dfs_alpha,
    "betas": best_dfs_beta,
    "lrs": best_dfs_lr,
    "precond": best_dfs_precond,
}
mode_greek = {
    "alphas": r"$\alpha$",
    "betas": r"$\beta$",
    "lrs": r"$\eta$"
}

In [16]:
%%time
for y in ("error", "gradnorm"):
    for mode in ("betas", "lrs"):
        valid_optimizers = [opt for opt in OPTIMIZERS if not (mode == "alphas" and opt == "Adam")]

        start_time = time.time()
        # Plot data for all optim, datasets, and args
        fig, axes = plt.subplots(1, len(valid_optimizers))
        fig.set_size_inches(ASPECT * HEIGHT_LARGE * len(valid_optimizers), HEIGHT_LARGE * 1)
        title = rf"Best {y_greek[y]} given {mode_greek[mode]}"
        plt.suptitle(title)
        for i, optimizer in enumerate(valid_optimizers):
            exp = (dataset, optimizer)
            if exp not in best_dfs_mode[mode]:
                continue
            exp_df = pd.concat(best_dfs_mode[mode][exp].values()).reset_index()
            if len(exp_df) == 0:
                continue
            exp_df["lr"] = exp_df["lr"].astype(float)
            exp_df["alpha"] = exp_df["alpha"].astype(str)
            exp_df["beta2"] = exp_df["beta2"].astype(str)
            print(f"Plotting lines for {exp}...")
            if mode == "lrs":
                exp_df = exp_df.sort_values("alpha", ascending=False)  # none is thinest
                exp_df = exp_df.sort_values("beta2", ascending=False)  # none is solid, avg is dashed, etc.
                sns.lineplot(ax=axes[i], x="ep", y=y,
                             hue="lr", hue_norm=LogNorm(), palette="vlag",
                             size="alpha", style="beta2", data=exp_df)
            elif mode == "betas":
                exp_df = exp_df.sort_values("alpha", ascending=True)  # none is blue, etc.
                exp_df = exp_df.sort_values("beta2", ascending=True)  # nums first, to be consistent with Adam
                sns.lineplot(ax=axes[i], x="ep", y=y,
                             hue="beta2", size="lr", size_norm=LogNorm(), style="alpha", data=exp_df)
            elif mode == "alphas":
                exp_df = exp_df.sort_values("alpha", ascending=True)  # none is blue, etc.
                exp_df = exp_df.sort_values("beta2", ascending=False)  # none is solid, avg is dashed, etc.
                sns.lineplot(ax=axes[i], x="ep", y=y,
                             hue="alpha", size="lr", size_norm=LogNorm(), style="beta2", data=exp_df)
            axes[i].set(yscale="log")
            axes[i].set_title(rf"$\tt {optimizer}({dataset})$")
            axes[i].set_ylabel(rf"{y_greek[y]}")
            axes[i].set_xlabel(rf"Effective Passes")
        fig.tight_layout()

        # Create a string out of filter args and save figure
        plt.savefig(f"plots/{y}_given_{mode}({SETTINGS_STR}).pdf")
        plt.close()
        plot_time = time.time() - start_time
        print(f"Took about {plot_time:.2f} seconds to create this plot.")

Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Took about 1.22 seconds to create this plot.
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Took about 2.38 seconds to create this plot.
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Took about 1.26 seconds to create this plot.
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Took about 2.38 seconds to create this plot.
CPU times: user 7.16 s, sys: 83.1 ms, total: 7.24 s
Wall time: 7.24 s
