In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math
import os

In [None]:
def load_results_opt(model_type, dataset, opt):
    """Load every results CSV stored for one optimizer on one dataset.

    Files are read from ``./general_results/<model_type>/<dataset>/<opt>``;
    anything that does not end in ``.csv`` is ignored.

    For 'sgd', 'svrg', 'slbfgs' and 'lkatyusha' the runs are grouped by the
    hyperparameter string embedded in the filename, i.e. the text between the
    first underscore and the ``_seed_`` marker. For 'sketchysgd' every run is
    collected under the single key 'auto'.

    Parameters
    ----------
    model_type : str
        First sub-directory of ``./general_results``.
    dataset : str
        Second sub-directory.
    opt : str
        Optimizer name; also the last sub-directory.

    Returns
    -------
    dict
        Maps a hyperparameter string (or 'auto') to the list of DataFrames
        read for it, ordered by filename.

    Raises
    ------
    ValueError
        If ``opt`` is not one of the optimizers handled here. Previously an
        empty dict was returned, which only failed later during plotting.
    """
    tuned_opts = ['sgd', 'svrg', 'slbfgs', 'lkatyusha']

    results_dir = os.path.join('./general_results', model_type, dataset, opt)
    # os.listdir gives no ordering guarantee; sorting makes the order of the
    # runs (and hence of the floating-point accumulation downstream) stable.
    csv_names = sorted(name for name in os.listdir(results_dir) if name.endswith('.csv'))

    if opt not in tuned_opts and opt != 'sketchysgd':
        raise ValueError(
            f"Unsupported optimizer {opt!r}; expected one of {tuned_opts + ['sketchysgd']}"
        )

    results = {}
    if opt == 'sketchysgd':
        # Created up front so the key exists even when no CSV is present.
        results['auto'] = []

    for name in csv_names:
        if opt == 'sketchysgd':
            hyperparam = 'auto'
        else:
            # Assumes the filename contains '_seed_'; the slice below is only
            # meaningful in that case.
            start = name.find('_') + len('_')
            end = name.find('_seed_')
            hyperparam = name[start:end]  # lr or L, kept as a string
        df = pd.read_csv(os.path.join(results_dir, name))
        results.setdefault(hyperparam, []).append(df)

    return results

In [None]:
def compute_averages(opt_results):
    """Average every metric column over the runs of each hyperparameter.

    Parameters
    ----------
    opt_results : dict
        Maps a hyperparameter key to a list of DataFrames (one per run).
        The first column of each frame is skipped; every remaining column is
        treated as a metric.

    Returns
    -------
    dict
        ``averages[key][metric]`` is a NumPy array holding the element-wise
        mean of that metric over the runs stored under ``key``.
    """
    averages = {}

    for hyperparam, runs in opt_results.items():
        per_metric = {}
        averages[hyperparam] = per_metric
        num_runs = len(runs)
        # Column names are taken from the first run, dropping column 0.
        metric_names = list(runs[0].columns[1:])

        for name in metric_names:
            for run in runs:
                # Scale first, then sum: the running total is already the mean.
                scaled = 1/num_runs * run[name].to_numpy()
                # NaN entries become +inf. With numpy's defaults, existing
                # +/-inf entries are mapped to very large finite values.
                cleaned = np.nan_to_num(scaled, nan=np.inf)
                if name in per_metric:
                    per_metric[name] += cleaned
                else:
                    per_metric[name] = cleaned

    return averages

In [None]:
def get_best_peak_run(opt_averages, opt_name, metric):
    """Pick the hyperparameter whose averaged curve reaches the best peak.

    For a metric whose name contains 'loss' the best curve is the one with the
    smallest minimum; otherwise, for a name containing 'acc', the one with the
    largest maximum. Ties keep the earlier entry. For any other metric name
    the first entry is returned.

    Parameters
    ----------
    opt_averages : dict
        Maps a hyperparameter key to a dict of metric name -> array.
    opt_name : str
        Optimizer name; the hyperparameter key is only reported for 'sgd',
        'svrg', 'slbfgs' and 'lkatyusha'.
    metric : str
        Name of the metric to compare on.

    Returns
    -------
    tuple
        ``(best_run, best_hyperparam)``; ``best_hyperparam`` is None for other
        optimizers, and both are None when ``opt_averages`` is empty.
    """
    reports_hyperparam = opt_name in ['sgd', 'svrg', 'slbfgs', 'lkatyusha']

    def _beats(candidate, incumbent):
        # Strict comparisons so that the first of several equal peaks wins.
        if 'loss' in metric:
            return np.min(candidate[metric]) < np.min(incumbent[metric])
        if 'acc' in metric:
            return np.max(candidate[metric]) > np.max(incumbent[metric])
        return False

    best_run, best_hyperparam = None, None
    for key, candidate in opt_averages.items():
        if best_run is None or _beats(candidate, best_run):
            best_run = candidate
            if reports_hyperparam:
                best_hyperparam = key

    return best_run, best_hyperparam

In [None]:
def plot_results_metric(avg_results, opts, metric, dataset, colors, lims):
    """Plot one metric against cumulative time, one curve per optimizer.

    For each optimizer only the averaged run selected by
    ``get_best_peak_run`` is drawn. Metrics whose name contains 'loss' use a
    logarithmic y-axis. A new figure is created on every call.

    Parameters
    ----------
    avg_results : dict
        ``avg_results[opt][hyperparam][column]`` -> array; must contain the
        'times' column, and for 'sketchysgd' also 'train_loss' and
        'test_loss' (used by the diagnostic print below).
    opts : list of str
        Optimizers to draw, in legend order.
    metric : dict
        Has keys 'name' (column to plot) and 'label' (y-axis label).
    dataset : str
        Used as the figure title.
    colors : dict
        Maps optimizer name to a matplotlib color.
    lims : sequence
        Passed to ``plt.ylim``.
    """
    plt.figure()
    opt_times = np.zeros(len(opts))

    for i, opt in enumerate(opts):
        best_run, best_hyperparam = get_best_peak_run(avg_results[opt], opt, metric['name'])
        if opt in ['sgd', 'svrg', 'slbfgs']:
            print(f"{opt}: best LR = {float(best_hyperparam):.8f}")
        elif opt == 'lkatyusha':
            print(f"{opt}: best L = {float(best_hyperparam):.8f}")

        linestyle = 'dashed' if opt == 'sketchysgd' else 'solid'

        # Computed once; serves as the x-axis and gives the total run time.
        elapsed = np.cumsum(best_run['times'])
        draw = plt.semilogy if 'loss' in metric['name'] else plt.plot
        draw(elapsed, best_run[metric['name']],
             colors[opt], label=opt, linestyle=linestyle)

        opt_times[i] = elapsed[-1]
        if opt == 'sketchysgd':
            # Diagnostic output: train/test loss gap over the first 100 entries.
            print(best_run['train_loss'][0:100] - best_run['test_loss'][0:100])

    opt_times_sorted = np.sort(opt_times)
    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.xlabel('Time (s)')
    plt.ylabel(metric['label'])
    plt.title(dataset)
    plt.ylim(lims)
    # The x-axis is cut at the second-smallest total time. With a single
    # optimizer there is no second entry, so fall back to the only one; with
    # none, leave the default limits instead of raising IndexError.
    if opt_times_sorted.size > 0:
        plt.xlim(0, opt_times_sorted[min(1, opt_times_sorted.size - 1)])

In [None]:
def plot_results_multi_metric(model_type, dataset, opts, metrics, colors, lims):
    """Load and average the results once, then draw one figure per metric.

    Parameters
    ----------
    model_type, dataset : str
        Forwarded to ``load_results_opt`` to locate the result files.
    opts : list of str
        Optimizers to load and plot.
    metrics : list of dict
        Each entry is forwarded to ``plot_results_metric`` as ``metric``.
    colors : dict
        Optimizer name -> matplotlib color.
    lims : list
        y-limits paired positionally with ``metrics``; pairing stops at the
        shorter of the two sequences.
    """
    avg_results = {
        opt: compute_averages(load_results_opt(model_type, dataset, opt))
        for opt in opts
    }

    for metric, lim in zip(metrics, lims):
        plot_results_metric(avg_results, opts, metric, dataset, colors, lim)

In [None]:
def plot_avgs(model_type, dataset, opts, metric, colors, lim):
    """Plot the averaged curve of every hyperparameter for every optimizer.

    Results are loaded and averaged here. For the tuned optimizers ('sgd',
    'svrg', 'slbfgs', 'lkatyusha') the curve of the best hyperparameter is
    opaque and all others are translucent; 'sketchysgd' is drawn dashed and
    opaque. Metrics whose name contains 'loss' use a logarithmic y-axis.

    Parameters
    ----------
    model_type, dataset : str
        Forwarded to ``load_results_opt``; ``dataset`` is also the title.
    opts : list of str
        Optimizers to load and plot, in legend order.
    metric : dict
        Has keys 'name' (column to plot) and 'label' (y-axis label).
    colors : dict
        Optimizer name -> matplotlib color.
    lim : sequence
        Passed to ``plt.ylim``.
    """
    # Only needed for the hand-built legend handles below.
    from matplotlib.lines import Line2D

    tuned_opts = ['sgd', 'svrg', 'slbfgs', 'lkatyusha']

    avg_results = {}
    for opt in opts:
        avg_results[opt] = compute_averages(load_results_opt(model_type, dataset, opt))

    plt.figure()
    opt_times = np.zeros(len(opts))
    legend_elements = []
    for i, opt in enumerate(opts):
        _, best_hyperparam = get_best_peak_run(avg_results[opt], opt, metric['name'])
        if opt in ['sgd', 'svrg', 'slbfgs']:
            print(f"{opt}: best LR = {float(best_hyperparam):.8f}")
        elif opt == 'lkatyusha':
            print(f"{opt}: best L = {float(best_hyperparam):.8f}")

        # Legend appearance: one handle per optimizer, independent of the
        # per-curve transparency chosen further down.
        if opt == 'sketchysgd':
            legend_alpha, linestyle = 1, 'dashed'
        else:
            legend_alpha, linestyle = 0.2, 'solid'
        legend_elements.append(Line2D([0], [0], color=colors[opt], label=opt,
                                      alpha=legend_alpha, linestyle=linestyle))

        draw = plt.semilogy if 'loss' in metric['name'] else plt.plot
        opt_time = np.inf
        for key, run in avg_results[opt].items():
            # Tuned optimizers: fade every curve except the best hyperparameter.
            alpha = 0.2 if (opt in tuned_opts and key != best_hyperparam) else 1

            elapsed = np.cumsum(run['times'])
            draw(elapsed, run[metric['name']],
                 colors[opt], alpha=alpha, linestyle=linestyle)

            # Smallest total time of this optimizer across its hyperparameters.
            opt_time = min(opt_time, elapsed[-1])
        opt_times[i] = opt_time
    opt_times_sorted = np.sort(opt_times)

    plt.legend(handles=legend_elements, bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.xlabel('Time (s)')
    plt.ylabel(metric['label'])
    plt.title(dataset)
    plt.ylim(lim)
    # The x-axis is cut at the second-smallest per-optimizer time. With a
    # single optimizer fall back to its own time; with none, keep the default
    # limits instead of raising IndexError.
    if opt_times_sorted.size > 0:
        plt.xlim(0, opt_times_sorted[min(1, opt_times_sorted.size - 1)])

In [None]:
def get_model_type(dataset):
    """Return the model type used for a dataset.

    Parameters
    ----------
    dataset : str
        One of 'rcv1', 'news20', 'real-sim' (-> 'logistic') or
        'yearmsd', 'e2006', 'w8a' (-> 'least_squares').

    Returns
    -------
    str
        'logistic' or 'least_squares'.

    Raises
    ------
    ValueError
        If the dataset is not listed above. Previously this surfaced as an
        UnboundLocalError on the return statement.
    """
    # NOTE(review): 'w8a' is grouped with the least-squares datasets, yet the
    # w8a cell of this notebook plots test accuracy - confirm the grouping
    # matches the layout of the results directory.
    if dataset in ['rcv1', 'news20', 'real-sim']:
        return 'logistic'
    if dataset in ['yearmsd', 'e2006', 'w8a']:
        return 'least_squares'

    raise ValueError(f"Unknown dataset {dataset!r}: no model type is defined for it")

In [None]:
# Optimizers compared in every figure, in legend order.
opts = ['sketchysgd', 'sgd', 'svrg', 'slbfgs', 'lkatyusha']

# Matplotlib color used for each optimizer.
colors = dict(
    sketchysgd='k',
    sgd='tab:blue',
    svrg='tab:orange',
    slbfgs='tab:brown',
    lkatyusha='tab:olive',
)

# Plottable metrics: 'name' is the results column, 'label' the y-axis text.
metrics = [
    {'name': name, 'label': label}
    for name, label in (
        ('test_acc', 'Test Accuracy (%)'),
        ('train_acc', 'Train Accuracy (%)'),
        ('test_loss', 'Test Loss'),
        ('train_loss', 'Train Loss'),
    )
]

In [None]:
# rcv1: plot metrics[0] for every optimizer and save the figure.
# Relies on `metrics` still being the list whose first entry is test accuracy.
dataset = 'rcv1'
model_type = get_model_type(dataset)
plot_avgs(model_type=model_type, dataset=dataset, opts=opts,
          metric=metrics[0], colors=colors, lim=[90, 97])
plt.savefig('./rcv1_test_acc.pdf', bbox_inches='tight')

# Disabled alternative: one figure per metric, with one y-range per entry of
# `metrics` (the `colors` argument is required by the current signature).
# plot_results_multi_metric(model_type, dataset, opts, metrics, colors,
#                           [[92, 97], [92, 100], [1e-1, 2e-1], [2e-2, 2e-1]])

In [None]:
# news20: plot metrics[0] for every optimizer and save the figure.
# Relies on `metrics` still being the list whose first entry is test accuracy.
dataset = 'news20'
model_type = get_model_type(dataset)
plot_avgs(model_type=model_type, dataset=dataset, opts=opts,
          metric=metrics[0], colors=colors, lim=[80, 98])
plt.savefig('./news20_test_acc.pdf', bbox_inches='tight')

# Disabled alternative: one figure per metric, with one y-range per entry of
# `metrics` (the `colors` argument is required by the current signature).
# plot_results_multi_metric(model_type, dataset, opts, metrics, colors,
#                           [[83, 98], [83, 100], [1e-1, 6e-1], [1e-2, 7e-1]])

In [None]:
# real-sim: plot metrics[0] for every optimizer and save the figure.
# Relies on `metrics` still being the list whose first entry is test accuracy.
dataset = 'real-sim'
model_type = get_model_type(dataset)
plot_avgs(model_type=model_type, dataset=dataset, opts=opts,
          metric=metrics[0], colors=colors, lim=[90, 98])
plt.savefig('./real-sim_test_acc.pdf', bbox_inches='tight')

# Disabled alternative: one figure per metric, with one y-range per entry of
# `metrics` (the `colors` argument is required by the current signature).
# plot_results_multi_metric(model_type, dataset, opts, metrics, colors,
#                           [[96, 98], [97, 100], [6e-2, 2e-1], [2e-2, 2e-1]])

In [None]:
# yearmsd: test loss for all hyperparameters, then train loss for the best
# hyperparameter of each optimizer.
dataset = 'yearmsd'
model_type = get_model_type(dataset)

# Cell-local metric specs. The notebook-level `metrics` list is deliberately
# left untouched so that cells indexing `metrics[0]` for test accuracy still
# behave the same when re-run after this cell.
test_loss_metric = {'name': 'test_loss', 'label': 'Test Loss'}
train_loss_metric = {'name': 'train_loss', 'label': 'Train Loss'}

plot_avgs(model_type, dataset, opts, test_loss_metric, colors, [5e1, 1e4])
plt.savefig('./yearmsd_test_loss.pdf', bbox_inches='tight')

# One metric, hence one y-range. The second, identical range that used to be
# passed here was never consumed because limits are zipped with the metrics.
plot_results_multi_metric(model_type, dataset, opts, [train_loss_metric], colors, [[5e1, 5e3]])
plt.savefig('./yearmsd_train_loss.pdf', bbox_inches='tight')

In [None]:
# e2006: test loss for all hyperparameters, then train loss for the best
# hyperparameter of each optimizer.
dataset = 'e2006'
model_type = get_model_type(dataset)

# Cell-local metric specs. The notebook-level `metrics` list is deliberately
# left untouched so that cells indexing `metrics[0]` for test accuracy still
# behave the same when re-run after this cell.
test_loss_metric = {'name': 'test_loss', 'label': 'Test Loss'}
train_loss_metric = {'name': 'train_loss', 'label': 'Train Loss'}

plot_avgs(model_type, dataset, opts, test_loss_metric, colors, [1.2e-1, 7e-1])
plt.savefig('./e2006_test_loss.pdf', bbox_inches='tight')

# NOTE(review): only the first range, [1.2e-1, 7e-1], is applied because the
# limits are zipped with a single metric. It equals the test-loss range used
# above; the unused second range may be the one intended for train loss -
# confirm before changing the figure.
plot_results_multi_metric(model_type, dataset, opts, [train_loss_metric], colors,
                          [[1.2e-1, 7e-1], [1.4e-1, 5e-1]])
plt.savefig('./e2006_train_loss.pdf', bbox_inches='tight')

In [None]:
# w8a: test accuracy for all hyperparameters, then train loss for the best
# hyperparameter of each optimizer.
dataset = 'w8a'
model_type = get_model_type(dataset)

# Rebinds the notebook-level `metrics` to the four-entry list, so metrics[0]
# below is the test-accuracy entry.
metrics = [
    {'name': 'test_acc', 'label': 'Test Accuracy (%)'},
    {'name': 'train_acc', 'label': 'Train Accuracy (%)'},
    {'name': 'test_loss', 'label': 'Test Loss'},
    {'name': 'train_loss', 'label': 'Train Loss'},
]

plot_avgs(model_type=model_type, dataset=dataset, opts=opts,
          metric=metrics[0], colors=colors, lim=[90, 99])
plt.savefig('./w8a_test_acc.pdf', bbox_inches='tight')

# Rebinds `metrics` again: after this cell it holds only the train-loss entry,
# so cells that index metrics[0] for test accuracy need the list restored.
metrics = [{'name': 'train_loss', 'label': 'Train Loss'}]
plot_results_multi_metric(model_type=model_type, dataset=dataset, opts=opts,
                          metrics=metrics, colors=colors, lims=[[2e-2, 1e-1]])
plt.savefig('./w8a_train_loss.pdf', bbox_inches='tight')