In [39]:
import os
import glob
import pickle
import time
from math import log2
from itertools import cycle, product
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LogNorm
%matplotlib inline

from plot1 import *

# Settings

In [50]:
# Root directory of the data logs, and the directory that plots are written to
LOG_DIR, PLOT_DIR = "../logs/logs_torch", "plots_torch"


class Args:
    """Experiment configuration for gathering logs, picking hyperparameters and plotting.

    Class-level attributes hold the shared settings; edit them here, or override
    them on an instance. ``as_dict`` and ``experiment_str`` are recomputed on every
    access, so changing e.g. ``args.metric`` after construction is reflected in both.
    """

    # Loss functions under study (the `loss` choice below is one of these)
    LOSSES = ("cross_entropy",)
    # The following should be the same as the one used in run_experiment.py
    DATASETS = ("mnist", "cifar-10")
    OPTIMIZERS = ("SARAH", "L-SVRG", "Adam")
    MAX_EPOCHS = 50  # Use 2xT, where T is the value used in run_experiment.py
    # These are the metrics collected in the data logs
    METRICS = ("loss", "gradnorm", "error")
    # These are aggregators for comparing multi-seed runs
    AGGS = ("mean", "median")
    # These are the logs columns: effective passes + metrics + walltime
    LOG_COLS = ["ep", "loss", "gradnorm", "error", "time"]
    # Positions in the raw logs of each LOG_COLS entry, in the same order.
    # NOTE(review): position 4 is never read — confirm this is intended.
    DATA_INDICES = [0, 1, 2, 3, 5]
    # Fail fast (at class definition) if the two lists drift apart.
    assert len(DATA_INDICES) == len(LOG_COLS), "DATA_INDICES must align with LOG_COLS"
    # These are the hyperparameters of interest
    ARG_COLS = ["lr", "alpha", "beta2", "precond"]

    # Choose index column, loss, metric, and aggregation method
    idx = "ep"
    loss = "cross_entropy"
    metric = "error"
    agg = "mean"
    # Downsample this number of effective passes by averaging them
    avg_downsample = 2
    # Logs will be filtered for this setting when applicable (USE EXACT STRING VALUE AS IN FILENAME).
    filter_args = {
        "beta1": '0.0',
    }
    # Ignore all runs containing 'any' of these hyperparams.
    ignore_args = {
        "alpha": (1e-11,),
        "weight_decay": (0.1,),
    }
    # Force remove log files that are empty
    remove_empty_file = False

    def __init__(self, log_dir, plot_dir) -> None:
        """Store the log/plot directories and make sure the plot directory exists.

        Args:
            log_dir: root directory of the data logs.
            plot_dir: directory for the generated plots; created if missing.
        """
        self.log_dir = log_dir
        self.plots_dir = plot_dir
        os.makedirs(self.plots_dir, exist_ok=True)
        # Per-instance copies: mutating one instance's filters must not leak into
        # the class defaults (and hence into every other Args instance).
        self.filter_args = dict(self.filter_args)
        self.ignore_args = dict(self.ignore_args)

    @property
    def as_dict(self) -> dict:
        """Settings identifying this experiment: idx, loss, metric, then the filters."""
        return dict(idx=self.idx, loss=self.loss, metric=self.metric, **self.filter_args)

    @property
    def experiment_str(self) -> str:
        """``key_value`` pairs of ``as_dict`` joined by underscores."""
        return "_".join(f"{k}_{v}" for k, v in self.as_dict.items())

    def __repr__(self) -> str:
        experiment_repr = ", ".join(f"{k}={v}" for k, v in self.as_dict.items())
        return f"Args({experiment_repr})"


# Configuration object passed to every helper call in the cells below
args = Args(plot_dir=PLOT_DIR, log_dir=LOG_DIR)
args

Args(idx=ep, loss=cross_entropy, metric=error, beta1=0.0)

# Gathering data and finding best hyperparameters for each (optimizer, dataset) combination

In [51]:
# Gather the run logs described by `args` (presumably from args.log_dir).
# The printed summary suggests the result is keyed by (dataset, optimizer);
# TODO confirm against plot1.create_experiments_dataframe.
all_dfs = create_experiments_dataframe(args)

Data frame lengths:
('mnist', 'SARAH') -> 4167 data rows -> 83 runs
('mnist', 'L-SVRG') -> 4068 data rows -> 81 runs
('mnist', 'Adam') -> 936 data rows -> 18 runs
('cifar-10', 'SARAH') -> 4193 data rows -> 83 runs
('cifar-10', 'L-SVRG') -> 4210 data rows -> 84 runs
('cifar-10', 'Adam') -> 936 data rows -> 18 runs
Took about 2.41 seconds to gather all these data.


## Get best hyperparams

In [52]:
# Two results: `best_dfs` (used by plot_best_perfs) and `best_dfs_fixed_args`
# (used by plot_best_perfs_given_fixed_arg). The selection criterion presumably
# uses args.metric / args.agg — confirm in plot1.find_all_best_hyperparams.
best_dfs, best_dfs_fixed_args = find_all_best_hyperparams(args, all_dfs)

Finding best hyperparams for ('mnist', 'SARAH')
Finding best hyperparams for ('mnist', 'L-SVRG')
Finding best hyperparams for ('mnist', 'Adam')
Finding best hyperparams for ('cifar-10', 'SARAH')
Finding best hyperparams for ('cifar-10', 'L-SVRG')
Finding best hyperparams for ('cifar-10', 'Adam')


# Plotting

## Plot best performance of each optimizer on each dataset

In [53]:
# Plot the best-hyperparameter runs selected above for each (dataset, optimizer) pair.
plot_best_perfs(args, best_dfs)

Plotting lines for ('mnist', 'SARAH')...
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Plotting lines for ('cifar-10', 'SARAH')...
Plotting lines for ('cifar-10', 'L-SVRG')...
Plotting lines for ('cifar-10', 'Adam')...
Took about 7.75 seconds to create this plot.


## Plot best performance given a fixed value of one hyperparameter ($\alpha$, $\beta$, or $\eta$)

In [54]:
# Produces several figures (one per printed "Took about ..." line); presumably
# one per hyperparameter held fixed — confirm in plot1.plot_best_perfs_given_fixed_arg.
plot_best_perfs_given_fixed_arg(args, best_dfs_fixed_args)

Plotting lines for ('mnist', 'SARAH')...
Plotting lines for ('cifar-10', 'SARAH')...
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('cifar-10', 'L-SVRG')...
Took about 0.40 seconds to create this plot.
Plotting lines for ('mnist', 'SARAH')...
Plotting lines for ('cifar-10', 'SARAH')...
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('cifar-10', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Plotting lines for ('cifar-10', 'Adam')...
Took about 0.62 seconds to create this plot.
Plotting lines for ('mnist', 'SARAH')...
Plotting lines for ('cifar-10', 'SARAH')...
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('cifar-10', 'L-SVRG')...
Plotting lines for ('mnist', 'Adam')...
Plotting lines for ('cifar-10', 'Adam')...
Took about 0.58 seconds to create this plot.
Plotting lines for ('mnist', 'SARAH')...
Plotting lines for ('cifar-10', 'SARAH')...
Plotting lines for ('mnist', 'L-SVRG')...
Plotting lines for ('cifar-10', 'L-SVRG')...
Took about 