In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import math
import os

In [None]:
def load_ablation_results(model_type, dataset, hyperparam):
    """Load every per-run CSV of a SketchySGD ablation, grouped by hyperparameter value.

    Results are read from ``./sketchysgd_<hyperparam>_ablation/<model_type>/<dataset>/``.
    Only ``.csv`` files are loaded. The hyperparameter value is parsed from each
    filename by splitting on ``'_'``: field 2 for ``'update_freq'`` and field 1
    for ``'rank'``.

    Parameters
    ----------
    model_type : str
        Model sub-directory (e.g. ``'logistic'`` or ``'least_squares'``).
    dataset : str
        Dataset sub-directory (e.g. ``'rcv1'``).
    hyperparam : str
        Ablated hyperparameter; must be ``'update_freq'`` or ``'rank'``.

    Returns
    -------
    dict[str, list[pandas.DataFrame]]
        Maps each hyperparameter value (the raw filename token, e.g. ``'10'``)
        to the list of DataFrames loaded for it, one per CSV file.

    Raises
    ------
    ValueError
        If ``hyperparam`` is not a supported name. Previously this surfaced as an
        ``UnboundLocalError`` on the first CSV file.
    """
    # Position of the hyperparameter value in the underscore-separated filename.
    field_index = {'update_freq': 2, 'rank': 1}
    if hyperparam not in field_index:
        raise ValueError(
            f"Unsupported hyperparam {hyperparam!r}; expected one of {sorted(field_index)}"
        )
    idx = field_index[hyperparam]

    results = {}
    results_dir = os.path.join('./sketchysgd_' + hyperparam + '_ablation', model_type, dataset)

    # os.listdir returns entries in arbitrary order; sort so that the run order
    # (and therefore any downstream floating-point averaging) is reproducible.
    for filename in sorted(os.listdir(results_dir)):
        if not filename.endswith('.csv'):
            continue
        hyperparam_val = filename.split('_')[idx]
        df = pd.read_csv(os.path.join(results_dir, filename))
        results.setdefault(hyperparam_val, []).append(df)
    return results

In [None]:
def compute_averages(opt_results):
    """Average every metric column across the runs recorded for each hyperparameter value.

    Parameters
    ----------
    opt_results : dict[str, list[pandas.DataFrame]]
        Maps a hyperparameter value to its per-run DataFrames. The metric columns
        are taken from the first run; its first column is skipped (presumably the
        saved index; TODO confirm against the CSV writer).

    Returns
    -------
    dict[str, dict[str, numpy.ndarray]]
        ``averages[key][metric]`` is the element-wise mean of ``metric`` over all
        runs stored under ``key``.

    Raises
    ------
    ValueError
        If a hyperparameter value has no runs, or its runs have different lengths
        for some metric. Previously a shorter run could be silently broadcast into
        the average instead of raising.
    """
    averages = {}
    for key, runs in opt_results.items():
        if not runs:
            raise ValueError(f"No runs found for hyperparameter value {key!r}")
        data_cols = list(runs[0].columns[1:])  # metric columns, from the first run
        averages[key] = {}
        for metric in data_cols:
            columns = [run[metric].to_numpy() for run in runs]
            lengths = {len(col) for col in columns}
            if len(lengths) != 1:
                raise ValueError(
                    f"Runs for hyperparameter value {key!r} have different lengths "
                    f"for metric {metric!r}: {sorted(lengths)}"
                )
            # Row i of the stack is run i; averaging over axis 0 gives the per-step mean.
            averages[key][metric] = np.mean(np.stack(columns), axis=0)
    return averages

In [None]:
def plot_results_metric(avg_results, hyperparam, metric, dataset, lims):
    """Plot one metric against cumulative time, drawing one curve per hyperparameter value.

    A new pyplot figure is opened and left as the current figure, so the caller
    can ``plt.savefig`` it. Curves are ordered by numeric hyperparameter value.
    The key ``'infty'`` is treated as +infinity, so it comes last. Metrics whose
    name contains ``'loss'`` are drawn with a log-scaled y axis.

    Parameters
    ----------
    avg_results : dict
        ``avg_results[value][column]`` -> array; must contain a ``'times'`` column
        (per-step times, accumulated with ``np.cumsum`` for the x axis).
    hyperparam : dict
        Needs ``'label'`` (legend prefix).
    metric : dict
        Needs ``'name'`` (column to plot) and ``'label'`` (y-axis label).
    dataset : str
        Used as the figure title.
    lims : sequence
        Passed to ``plt.ylim``.
    """
    plt.figure()

    def as_number(value):
        return math.inf if value == 'infty' else float(value)

    # Tie-break on the raw key string so equal numeric values keep a fixed order.
    ordered_values = sorted(avg_results, key=lambda value: (as_number(value), value))

    for value in ordered_values:
        if value == 'infty':
            value_text = r'$\infty$'
        else:
            # Two decimals, then strip trailing zeros and a dangling point ('10.00' -> '10').
            value_text = f'{float(value):.2f}'.rstrip('0').rstrip('.')
        curve_label = hyperparam['label'] + ' = ' + value_text

        run = avg_results[value]
        elapsed = np.cumsum(run['times'])
        plotter = plt.semilogy if 'loss' in metric['name'] else plt.plot
        plotter(elapsed, run[metric['name']], label=curve_label)

    plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
    plt.xlabel('Time (s)')
    plt.ylabel(metric['label'])
    plt.title(dataset)
    plt.ylim(lims)

In [None]:
def plot_results_multi_metric(model_type, dataset, hyperparam, metrics, lims):
    """Load, average and plot one ablation study, producing one figure per metric.

    Parameters
    ----------
    model_type, dataset : str
        Select the results directory to load.
    hyperparam : dict
        ``{'name': <ablated hyperparameter>, 'label': <legend text>}``.
    metrics : list[dict]
        Each ``{'name': <column>, 'label': <y-axis label>}``; one figure each.
    lims : list
        y-axis limits, paired positionally with ``metrics``.

    Raises
    ------
    ValueError
        If ``metrics`` and ``lims`` differ in length. ``zip`` would otherwise
        silently drop the unmatched metrics. The check runs before any files are read.
    """
    if len(metrics) != len(lims):
        raise ValueError(
            f"Got {len(metrics)} metrics but {len(lims)} y-limit entries; they must match"
        )

    ablation_results = load_ablation_results(model_type, dataset, hyperparam['name'])
    avg_results = compute_averages(ablation_results)

    for metric, lim in zip(metrics, lims):
        plot_results_metric(avg_results, hyperparam, metric, dataset, lim)

In [None]:
def get_model_type(dataset):
    """Return the model family whose results sub-directory holds ``dataset``.

    Parameters
    ----------
    dataset : str
        One of ``'rcv1'``, ``'news20'``, ``'real-sim'``, ``'yearmsd'``,
        ``'e2006'``, ``'w8a'``.

    Returns
    -------
    str
        ``'logistic'`` or ``'least_squares'``.

    Raises
    ------
    ValueError
        For an unknown dataset. Previously this surfaced as an ``UnboundLocalError``.
    """
    model_types = {
        'rcv1': 'logistic',
        'news20': 'logistic',
        'real-sim': 'logistic',
        'yearmsd': 'least_squares',
        'e2006': 'least_squares',
        # NOTE(review): 'w8a' is mapped to 'least_squares', yet its plotting cell
        # reports test accuracy; confirm this grouping is intended.
        'w8a': 'least_squares',
    }
    try:
        return model_types[dataset]
    except KeyError:
        raise ValueError(
            f"Unknown dataset {dataset!r}; expected one of {sorted(model_types)}"
        ) from None

In [None]:
# Full set of plottable metrics: (CSV column name, axis label) pairs.
# The per-dataset cells below each rebind `metrics` to the subset they plot.
metrics = [
    {'name': column, 'label': axis_label}
    for column, axis_label in (
        ('test_acc', 'Test Accuracy (%)'),
        ('train_acc', 'Train Accuracy (%)'),
        ('test_loss', 'Test Loss'),
        ('train_loss', 'Train Loss'),
    )
]

In [None]:
# rcv1: test-accuracy curves for the update-frequency and rank ablations.
dataset = 'rcv1'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# y-limits previously used with the full metric list
# [test_acc, train_acc, test_loss, train_loss]:
#   update_freq: [[94, 96.5], [96, 100], [1e-1, 2e-1], [2e-2, 2e-1]]
#   rank:        [[94, 96.5], [96, 100], [1e-1, 2e-1], [2e-2, 2e-1]]

metrics = [{'name': 'test_acc', 'label': 'Test Accuracy (%)'}]
for _hp, _ylims, _pdf in ((hyperparam_uf, [[94, 96.5]], 'rcv1_uf.pdf'),
                          (hyperparam_rk, [[94, 96.5]], 'rcv1_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')

In [None]:
# news20: test-accuracy curves for the update-frequency and rank ablations.
dataset = 'news20'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# y-limits previously used with the full metric list
# [test_acc, train_acc, test_loss, train_loss]:
#   update_freq: [[89, 97], [89, 100], [8e-2, 4e-1], [2e-2, 4e-1]]
#   rank:        [[91, 97], [93, 100], [8e-2, 3e-1], [2e-2, 3e-1]]

metrics = [{'name': 'test_acc', 'label': 'Test Accuracy (%)'}]
for _hp, _ylims, _pdf in ((hyperparam_uf, [[89, 97]], 'news20_uf.pdf'),
                          (hyperparam_rk, [[91, 97]], 'news20_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')

In [None]:
# real-sim: test-accuracy curves for the update-frequency and rank ablations.
dataset = 'real-sim'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# y-limits previously used with the full metric list
# [test_acc, train_acc, test_loss, train_loss]:
#   update_freq: [[97, 98], [98, 100], [6e-2, 1e-1], [2e-2, 1e-1]]
#   rank:        [[96.5, 97.5], [98, 100], [6e-2, 1e-1], [2e-2, 1e-1]]

metrics = [{'name': 'test_acc', 'label': 'Test Accuracy (%)'}]
for _hp, _ylims, _pdf in ((hyperparam_uf, [[97, 98]], 'real-sim_uf.pdf'),
                          (hyperparam_rk, [[96.5, 97.5]], 'real-sim_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')

In [None]:
# yearmsd: test-loss curves for the update-frequency and rank ablations.
dataset = 'yearmsd'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# To also plot train loss, use
#   metrics = [{'name': 'test_loss', 'label': 'Test Loss'},
#              {'name': 'train_loss', 'label': 'Train Loss'}]
# with y-limits [[5e1, 1e2], [5e1, 1e2]] for both ablations.

metrics = [{'name': 'test_loss', 'label': 'Test Loss'}]
for _hp, _ylims, _pdf in ((hyperparam_uf, [[5e1, 1e2]], 'yearmsd_uf.pdf'),
                          (hyperparam_rk, [[5e1, 1e2]], 'yearmsd_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')

In [None]:
# e2006: test-loss curves for the update-frequency and rank ablations.
dataset = 'e2006'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# Optionally append {'name': 'train_loss', 'label': 'Train Loss'} to `metrics`
# (and a matching entry to each y-limit list).
metrics = [{'name': 'test_loss', 'label': 'Test Loss'}]

for _hp, _ylims, _pdf in ((hyperparam_uf, [[1.2e-1, 2e-1]], 'e2006_uf.pdf'),
                          (hyperparam_rk, [[1.2e-1, 2e-1]], 'e2006_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')

In [None]:
# w8a: test-accuracy curves for the update-frequency and rank ablations.
dataset = 'w8a'
model_type = get_model_type(dataset)
hyperparam_uf = {'name': 'update_freq', 'label': 'Update frequency'}
hyperparam_rk = {'name': 'rank', 'label': 'Rank'}

# y-limits previously used with the full metric list
# [test_acc, train_acc, test_loss, train_loss]:
#   update_freq: [[97, 99], [97, 99], [2e-2, 4e-2], [2e-2, 4e-2]]
#   rank:        [[98, 99], [98, 99], [2e-2, 4e-2], [2e-2, 4e-2]]

metrics = [{'name': 'test_acc', 'label': 'Test Accuracy (%)'}]
for _hp, _ylims, _pdf in ((hyperparam_uf, [[97, 99]], 'w8a_uf.pdf'),
                          (hyperparam_rk, [[97, 99]], 'w8a_rk.pdf')):
    plot_results_multi_metric(model_type, dataset, _hp, metrics, _ylims)
    plt.savefig(_pdf, bbox_inches='tight')