In [25]:
%matplotlib inline

import os
from warnings import warn
from typing import Dict, List

import matplotlib.pyplot as plt
import scnn
from experiment_utils import configs, utils, experiments, files
from experiment_utils.plotting import defaults
from experiment_utils.plotting.plot_grid import plot_grid
from experiment_utils.plotting.plot_cell import make_convergence_plot

from scaffold.runners import run_experiment

In [2]:
# Shared hyper-parameters referenced by the configurations below.
max_iters = 10_000  # iteration budget: FISTA `max_iters` and torch `max_epochs`
gated_width = 2_500  # number of sampled sign patterns for the gated ReLU models
relu_width = 500  # width `p` of the ReLU hidden layer in the non-convex baseline
arrangement_seed = 779  # seed for the sign-pattern sampler
lambda_to_try = [1e-7]  # regularization strengths passed as `lambda`

In [3]:
# Configurations for experiments.
#
# NOTE(review): list-valued entries are presumably expanded into a grid of
# individual experiments by `configs.expand_config_list` (which would explain why
# `hidden_layers` is wrapped in an extra list) — confirm against experiment_utils.

# Non-convex gated ReLU network trained with torch optimizers.
TorchGated = {
    "name": "torch_mlp_l1",
    "hidden_layers": [
        [
            {
                "name": "gated_relu",
                "sign_patterns": {
                    "name": "sampler",
                    "n_samples": gated_width,
                    "seed": arrangement_seed,
                },
            },
        ]
    ],
    "regularizer": {
        "name": "l2",
        "lambda": lambda_to_try,
    },
}

# Non-convex ReLU network with a hidden layer of width `relu_width`.
TorchReLU = {
    "name": "torch_mlp_l1",
    "hidden_layers": [
        [
            {
                "name": "relu",
                "p": relu_width,
            },
        ]
    ],
    "regularizer": {
        "name": "l2",
        "lambda": lambda_to_try,
    },
}

# Adam and SGD swept over several step sizes, with a step-decay schedule
# (halve the step size every 100 steps).
TorchOptim = {
    "name": ["torch_adam", "torch_sgd"],
    "step_size": [0.001, 0.01, 0.1, 1, 5, 10],
    "batch_size": 0.1,
    "max_epochs": max_iters,
    "term_criterion": {"name": "grad_norm", "tol": 1e-10},
    "scheduler": {"name": "step", "step_length": 100, "decay": 0.5},
    "metric_freq": 10,
}


# FISTA with a backtracking line search and a group-L1 proximal operator.
FISTA = {
    "name": ["fista"],
    "ls_cond": {"name": "quadratic_bound"},
    "backtrack_fn": {"name": "backtrack", "beta": 0.8},
    "step_size_update": [
        {
            "name": "lassplore",
            "alpha": 1.25,
            "threshold": 5,
        },
    ],
    "init_step_size": 0.1,
    "term_criterion": {"name": "grad_norm", "tol": 1e-8},
    "prox": {"name": "group_l1"},
    "max_iters": max_iters,
    "metric_freq": 10,
    "restart_rule": "gradient_mapping",
}


# Convex gated ReLU model with group-L1 regularization. Its sign patterns are
# sampled with the same width and seed as TorchGated.
GReLU_GL1 = {
    "name": "convex_mlp",
    "kernel": "einsum",
    "sign_patterns": {
        "name": "sampler",
        "n_samples": gated_width,
        "seed": arrangement_seed,
    },
    "regularizer": {
        "name": "group_l1",
        "lambda": lambda_to_try,
    },
    "c": 10,
    "initializer": {"name": "zero"},
}

data = {
    "name": "mnist",
    "transforms": [["to_tensor", "normalize", "flatten"]],
    "use_valid": False,
}

# Metrics recorded during optimization and at the end of training; presumably
# grouped as (train, test, model) metrics — confirm against experiment_utils.
# Named `experiment_metrics` rather than `metrics` so that re-running this cell
# cannot overwrite the plot-column list `metrics` used when loading results.
experiment_metrics = (
    ["objective", "grad_norm", "nc_accuracy"],
    ["nc_accuracy"],
    ["group_sparsity"],
)

final_metrics = (
    ["objective", "accuracy", "squared_error", "grad_norm"],
    ["nc_accuracy", "squared_error"],
    [
        "group_sparsity",
    ],
)

ConvexGated = {
    "method": FISTA,
    "model": GReLU_GL1,
    "data": data,
    "metrics": experiment_metrics,
    "final_metrics": final_metrics,
    "seed": 778,
    "repeat": list(range(1)),
    "backend": "torch",
    "device": "cuda",
    "dtype": "float32",
}

NonConvex = {
    "method": TorchOptim,
    "model": [TorchReLU, TorchGated],
    "data": data,
    "metrics": experiment_metrics,
    "final_metrics": final_metrics,
    "seed": 778,
    "repeat": list(range(1)),
    "backend": "torch",
    "device": "cuda",
    "dtype": "float32",
}

# Experiment groups; each key is also the results sub-directory name.
EXPERIMENTS: Dict[str, List] = {
    "mnist_timing_convex": [ConvexGated],
    "mnist_timing_non_convex": [NonConvex],
}


In [4]:
# Runner options.
verbose, debug = True, False  # debug=True re-raises experiment failures instead of logging them
log_file = None  # forwarded to utils.get_logger
force_rerun = False  # forwarded to experiments.run_or_load

# Locations of results and datasets, relative to the working directory.
base_results_dir, data_dir = "results", "data"

In [5]:
# Run (or presumably load cached results for) every experiment group,
# writing results under `base_results_dir/<exp_id>`.
for exp_id, config in EXPERIMENTS.items():
    logger = utils.get_logger(exp_id, verbose, debug, log_file)
    logger.warning(f"\n\n====== Running {exp_id} ======\n")

    # Presumably expands list-valued config entries into one dict per combination.
    experiment_list = configs.expand_config_list(config)
    logger.warning("Starting experiments.")

    results_dir = os.path.join(base_results_dir, exp_id)
    print(results_dir)

    for i, exp_dict in enumerate(experiment_list):
        logger.warning(f"Running Experiment: {i+1}/{len(experiment_list)}.")
        logger.info(f"Method: {exp_dict['method']['name']}")
        # Isolate failures so a single bad configuration does not abort the sweep.
        try:
            experiments.run_or_load(
                logger,
                exp_dict,
                run_experiment,
                data_dir,
                results_dir,
                force_rerun,
            )
        except Exception as e:
            if debug:
                raise
            message = f"Exception {e} encountered while running experiment with configuration {exp_dict}."
            # `exception` logs at ERROR level and also records the traceback.
            logger.exception(message)
            # Surface the failure in the notebook output as well.
            warn(message)

    logger.warning("Experiments complete.")



INFO:mnist_timing_convex:Method: fista
INFO:mnist_timing_convex:Loading results.


INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non_convex:Loading results.
INFO:mnist_timing_non_convex:Method: torch_adam
INFO:mnist_timing_non

results/mnist_timing_convex
results/mnist_timing_non_convex


In [11]:
# Grouping keys used when loading and plotting the results.
row_key = ("data", "name")  # one plot row per dataset name
repeat_key = ("repeat",)  # results are averaged over this key
metrics = ["train_objective", "test_nc_accuracy"]  # one plot column per metric

# Legend key identifying each line in the plot.
def line_key_gen(exp_dict):
    """Return the legend key for one experiment.

    Torch optimizers are disambiguated by step size (for example
    ``"torch_adam_0.1"``); every other method is keyed by its name alone.
    """
    method = exp_dict["method"]
    name = method["name"]
    if "torch" not in name:
        return name
    return f"{name}_{method['step_size']}"

# Keep only gated ReLU models and, for each torch optimizer, its chosen step size.
def filter_fn(exp_config):
    """Decide whether an experiment's results should be plotted.

    Torch optimizers are kept only at their chosen step size (Adam: 0.1,
    SGD: 10); other methods pass this check unconditionally. Models that
    define ``hidden_layers`` must additionally have a gated ReLU first layer.
    """
    model = exp_config["model"]
    method = exp_config["method"]

    keep = True
    for optimizer, chosen_step_size in (("torch_adam", 0.1), ("torch_sgd", 10)):
        if method["name"] == optimizer:
            keep = method["step_size"] == chosen_step_size
            break

    if "hidden_layers" not in model:
        return keep
    return keep and model["hidden_layers"][0]["name"] == "gated_relu"

In [14]:
# Load convex and non-convex results into one metric grid. The configs and
# their results directories are derived from the same ID list, so they stay
# aligned (configs of group k are read from results directory k).
timing_exp_ids = ["mnist_timing_convex", "mnist_timing_non_convex"]
config_list = [config for exp_id in timing_exp_ids for config in EXPERIMENTS[exp_id]]
results_dirs = [os.path.join(base_results_dir, exp_id) for exp_id in timing_exp_ids]

# load data
metric_grid = files.load_and_clean_experiments(
    config_list,
    results_dirs,
    metrics=metrics,
    row_key=row_key,
    line_key=line_key_gen,
    repeat_key=repeat_key,
    metric_fn=utils.quantile_metrics,
    keep=[],
    remove=[],
    filter_fn=filter_fn,
    transform_fn=None,
    processing_fns=[],
    x_key="time",
)

In [21]:
# Plotting: axis labels, titles, axis limits, and per-line styles.
time_label = "Time (s)"  # SI symbol for seconds is lowercase "s"
figure_labels = {
    "x_labels": {
        "train_objective": time_label,
        "test_nc_accuracy": time_label,
    },
    "y_labels": {},
    "col_titles": {
        "train_objective": "Training Objective",
        "test_nc_accuracy": "Test Accuracy",
    },
    "row_titles": {},
}

# (x-limits, y-limits) per metric; None leaves that axis unconstrained.
limits = {
    "train_objective": ([0, 1000], None),
    "test_nc_accuracy": ([0, 1000], [0.8, 1.01]),
}


line_colors = [
    "#000000",
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#8c564b",
    "#17becf",
    "#556B2F",
    "#FFFF00",
    "#191970",
]


def _line_style(color, label, marker):
    """Return the shared style for one convergence curve (fresh dict per call)."""
    return {
        "c": color,
        "label": label,
        "linewidth": 3,
        "marker": marker,
        "markevery": 0.1,
        "markersize": 8,
    }


# Keys must match the strings produced by `line_key_gen`.
line_kwargs = {
    "fista": _line_style(line_colors[0], "Convex", "v"),
    "torch_adam_0.1": _line_style(line_colors[1], "Adam", "D"),
    "torch_sgd_10": _line_style(line_colors[2], "SGD", "X"),
}

log_scale = {
    "train_objective": "log-linear",
    "train_squared_error": "log-linear",
    "test_squared_error": "log-linear",
    "train_grad_norm": "log-linear",
    "time": "linear-linear",
    "train_constraint_gaps": "log-linear",
}


# Work on a copy: assigning into `defaults.DEFAULT_SETTINGS` directly would
# mutate the shared module-level dict and leak these overrides into any later
# use of the defaults in this kernel.
settings = dict(defaults.DEFAULT_SETTINGS)
settings.update(
    {
        "y_labels": "every_col",
        "x_labels": "bottom_row",
        "legend_cols": 4,
        "show_legend": True,
    }
)

In [22]:
# Output directory for saved figures; create it up front so saving the PDF
# does not fail because the directory is missing.
plot_dest = "figures"
os.makedirs(plot_dest, exist_ok=True)

In [29]:
# `plot_grid` returns a 2-tuple: the Figure and a dict of axes keyed by
# (row, metric). Unpacking into a single name would raise ValueError.
fig, axes = plot_grid(
    plot_fn=make_convergence_plot,
    results=metric_grid,
    figure_labels=figure_labels,
    line_kwargs=line_kwargs,
    limits=limits,
    log_scale=log_scale,
    base_dir=os.path.join(
        plot_dest, "mnist_timing.pdf"
    ),
    settings=settings,
)

(<Figure size 864x432 with 2 Axes>, {('mnist', 'train_objective'): <AxesSubplot:title={'center':'Training Objective'}, xlabel='Time (S)'>, ('mnist', 'test_nc_accuracy'): <AxesSubplot:title={'center':'Test Accuracy'}, xlabel='Time (S)'>})
