Constrain all samplers to the same running-time budget, and estimate their performance at that budget.

In [1]:
import os
import json
import h5py
import numpy as np
import torch as t
import matplotlib.pyplot as plt
import pickle
import gc
import pandas as pd

from pathlib import Path
import warnings
warnings.filterwarnings("ignore")
from scipy.interpolate import interp1d

In [2]:
def get_center_and_std(metrics):
    """Summarise each group of values by its mean and standard deviation.

    Args:
        metrics: an iterable whose items are themselves sequences of numbers
            (items may have different lengths). It is consumed exactly once.

    Returns:
        A pair ``(center, std)`` of 1-D numpy arrays where ``center[i]`` is
        ``np.mean`` and ``std[i]`` is ``np.std`` (population std) of item ``i``.
        Both arrays are empty when ``metrics`` is empty.
    """
    summary = [(np.mean(group), np.std(group)) for group in metrics]
    center = np.array([mean for mean, _ in summary])
    std = np.array([spread for _, spread in summary])
    return center, std

In [3]:
# Evaluation splits and metrics read from every evaluations.pkl entry.
_VAL_OR_TESTS = ("val", "test")
_USED_METRICS = ("lp_ensemble", "acc_ensemble")

# Row name of each sampler in the timing CSV (index column).
_TIMING_NAMES = {
    "VanillaSGLD": "Identity",
    "WenzelSGLD": "Wenzel",
    "pSGLD": "RMSprop",
    "MongeSGLD": "Monge",
    "ShampooSGLD": "Shampoo",
}


def _run_configs(lr_values, extra_args):
    """Expand a sampler's grid into run ids: ``lr`` or ``f"{lr}_{extra}"``."""
    if extra_args is None:
        return list(lr_values)
    # Outer loop over extras, inner over learning rates (original ordering).
    return [f"{lr}_{extra}" for extra in extra_args for lr in lr_values]


def _load_run_summary(run_dir, sub_dirs, num_epochs):
    """Load all trials of one run and return per-epoch mean/std of each metric.

    Returns ``{split: {metric: {"center": array, "std": array}}}`` where each
    array has ``num_epochs`` entries: the mean / std across trials at that
    epoch (as computed by ``get_center_and_std``).

    Raises whatever opening/unpickling/indexing raises (e.g. missing file,
    missing key, or ``IndexError`` if a trial logged more than ``num_epochs``
    evaluations); the caller treats any such failure as an unusable run.
    """
    per_epoch = {
        split: {metric: [[] for _ in range(num_epochs)] for metric in _USED_METRICS}
        for split in _VAL_OR_TESTS
    }
    for sub_dir in sub_dirs:
        path = os.path.join(run_dir, sub_dir, "evaluations.pkl")
        # NOTE: pickle.load can execute arbitrary code; only use on trusted,
        # locally produced result files.
        with open(path, "rb") as f:
            evaluations = pickle.load(f)
        for split in _VAL_OR_TESTS:
            for index, entry in enumerate(evaluations[split]):
                for metric in _USED_METRICS:
                    per_epoch[split][metric][index].append(entry[metric])

    summary = {}
    for split, by_metric in per_epoch.items():
        summary[split] = {}
        for metric, values in by_metric.items():
            center, std = get_center_and_std(values)
            summary[split][metric] = {"center": center, "std": std}
    return summary


def get_metrics(parent_dir, lrs, other_args, csv_path, fastest_metric="Identity",
                trials=10, num_epochs=400, verbose=False):
    """Compare samplers at a common running-time budget.

    The CSV at ``csv_path`` is read with its first column as index; its ``"T"``
    column is treated as the time of one epoch for each sampler row. The common
    budget is ``num_epochs * T[fastest_metric]``.

    For every sampler in ``lrs``, each candidate configuration is loaded from
    ``<parent_dir>/<sampler>_<config>/<trial>/evaluations.pkl`` for trials
    ``"1" .. str(trials)``. The trial-mean validation ``lp_ensemble`` curve is
    linearly interpolated (over time ``(epoch + 1) * T``) at the budget, and the
    configuration with the highest value is chosen. Configurations that fail
    to load are skipped.

    Args:
        parent_dir: directory holding the ``<sampler>_<config>`` run folders.
        lrs: mapping sampler name -> list of learning rates to try.
        other_args: mapping sampler name -> list of extra suffixes; for those
            samplers the run ids are ``f"{lr}_{suffix}"``.
        csv_path: timing CSV path.
        fastest_metric: timing-CSV row whose full run defines the budget.
        trials: number of trial sub-directories per run.
        num_epochs: number of evaluations per trial.
        verbose: if True, print each configuration that could not be loaded.

    Returns:
        ``{"timestamp": budget, <sampler>: {"best_lr", "lr", "lp_ensemble",
        "acc_ensemble"}}`` where ``"lr"`` duplicates ``"best_lr"`` and each
        metric is ``[mean, std]`` of the test curve at the budget, rounded to
        4 decimals.

    Raises:
        ValueError: if no configuration of some sampler could be loaded.
        AssertionError: if a sampler with a known epoch time stops before the
            budget.
    """
    df = pd.read_csv(csv_path, index_col=0)
    epoch_times = df["T"]
    final_metrics = {"timestamp": num_epochs * epoch_times[fastest_metric]}
    timestamp = final_metrics["timestamp"]
    sub_dirs = [str(i + 1) for i in range(trials)]

    for key, value in lrs.items():
        final_metrics[key] = {}
        extra_args = other_args[key] if key in other_args else None

        summaries = {}
        for config in _run_configs(value, extra_args):
            run_dir = os.path.join(parent_dir, f"{key}_{config}")
            try:
                summaries[config] = _load_run_summary(run_dir, sub_dirs, num_epochs)
            except Exception as err:  # best effort: skip missing/broken runs
                if verbose:
                    print(f"Skipping {key} {config}: {err!r}")

        per_epoch_time = epoch_times[_TIMING_NAMES[key]]
        times = [(index + 1) * per_epoch_time for index in range(num_epochs)]
        # A NaN epoch time (sampler not timed) makes every interpolated value
        # NaN; otherwise every sampler must actually reach the budget.
        if not np.isnan(per_epoch_time):
            assert times[-1] >= timestamp, (
                f"{key} runs end at t={times[-1]}, before the budget t={timestamp}"
            )

        scores = {
            config: interp1d(times, summary["val"]["lp_ensemble"]["center"])(timestamp).item()
            for config, summary in summaries.items()
        }
        # NaN never compares greater than anything, so a NaN score listed
        # first would otherwise "win" max(). Rank only non-NaN scores when
        # any exist; fall back to all scores (original behaviour) otherwise.
        ranked = {config: score for config, score in scores.items() if not np.isnan(score)}
        ranked = ranked or scores
        if not ranked:
            raise ValueError(f"No loadable runs for {key} under {parent_dir}")
        best_lr = max(ranked, key=ranked.get)

        final_metrics[key]["best_lr"] = best_lr
        final_metrics[key]["lr"] = best_lr
        for used_metric in _USED_METRICS:
            test_summary = summaries[best_lr]["test"][used_metric]
            center_f = interp1d(times, test_summary["center"])
            std_f = interp1d(times, test_summary["std"])
            final_metrics[key][used_metric] = [
                np.round(center_f(timestamp).item(), 4),
                np.round(std_f(timestamp).item(), 4),
            ]

    return final_metrics
    

In [4]:
# Sweep configuration for the mnist_400_gaussian results; the time budget is
# taken from the "Wenzel" row of the timing CSV.
parent_dir = '../results/mnist_400_gaussian_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_400_gaussian_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.01, 0.025, 0.05, 0.075, 0.1],
    "WenzelSGLD": [0.025, 0.05, 0.075, 0.1, 0.25, 0.5],
    "pSGLD": [0.000075, 0.0001, 0.00025, 0.0005, 0.00075],
    "MongeSGLD": [0.01, 0.025, 0.05, 0.075, 0.1, 0.25],
    "ShampooSGLD": [0.00075, 0.001, 0.0025, 0.005, 0.0075],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": [
        "monge_alpha_2=2.25",
        "monge_alpha_2=2.0",
        "monge_alpha_2=1.75",
        "monge_alpha_2=1.5",
        "monge_alpha_2=1.25",
        "monge_alpha_2=1.0",
        "monge_alpha_2=0.75",
        "monge_alpha_2=0.5",
        "monge_alpha_2=0.25",
        "monge_alpha_2=0.1",
        "monge_alpha_2=0.05",
    ],
}

get_metrics(parent_dir, lrs, other_args, csv_path, fastest_metric="Wenzel")


{'timestamp': 3080.0,
 'VanillaSGLD': {'best_lr': 0.05,
  'lr': 0.05,
  'lp_ensemble': [-0.1311, 0.0004],
  'acc_ensemble': [0.9684, 0.0005]},
 'WenzelSGLD': {'best_lr': 0.075,
  'lr': 0.075,
  'lp_ensemble': [-0.1315, 0.0002],
  'acc_ensemble': [0.9686, 0.0003]},
 'pSGLD': {'best_lr': 0.00025,
  'lr': 0.00025,
  'lp_ensemble': [-0.131, 0.0005],
  'acc_ensemble': [0.9689, 0.0006]},
 'MongeSGLD': {'best_lr': '0.01_monge_alpha_2=2.25',
  'lr': '0.01_monge_alpha_2=2.25',
  'lp_ensemble': [nan, nan],
  'acc_ensemble': [nan, nan]},
 'ShampooSGLD': {'best_lr': 0.005,
  'lr': 0.005,
  'lp_ensemble': [-0.1304, 0.0002],
  'acc_ensemble': [0.9684, 0.0004]}}

In [5]:
# Sweep configuration for the mnist_400_horseshoe results; the time budget
# uses get_metrics' default timing row.
parent_dir = '../results/mnist_400_horseshoe_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_400_horseshoe_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "WenzelSGLD": [0.075, 0.1, 0.25, 0.5, 0.75, 1.0],
    "pSGLD": [0.0001, 0.00025, 0.0005, 0.00075, 0.001],
    "MongeSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "ShampooSGLD": [0.001, 0.0025, 0.005, 0.0075, 0.01],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": [
        "monge_alpha_2=3.0",
        "monge_alpha_2=2.75",
        "monge_alpha_2=2.5",
        "monge_alpha_2=2.25",
        "monge_alpha_2=2.0",
        "monge_alpha_2=1.75",
        "monge_alpha_2=1.5",
        "monge_alpha_2=1.25",
        "monge_alpha_2=1.0",
        "monge_alpha_2=0.75",
        "monge_alpha_2=0.5",
        "monge_alpha_2=0.25",
        "monge_alpha_2=0.1",
        "monge_alpha_2=0.05",
    ],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 3320.0000000000005,
 'VanillaSGLD': {'best_lr': 0.25,
  'lr': 0.25,
  'lp_ensemble': [-0.075, 0.0004],
  'acc_ensemble': [0.9823, 0.0003]},
 'WenzelSGLD': {'best_lr': 0.75,
  'lr': 0.75,
  'lp_ensemble': [-0.0777, 0.0009],
  'acc_ensemble': [0.9806, 0.0003]},
 'pSGLD': {'best_lr': 0.00075,
  'lr': 0.00075,
  'lp_ensemble': [-0.068, 0.0013],
  'acc_ensemble': [0.9792, 0.0007]},
 'MongeSGLD': {'best_lr': '0.25_monge_alpha_2=1.25',
  'lr': '0.25_monge_alpha_2=1.25',
  'lp_ensemble': [-0.065, 0.0006],
  'acc_ensemble': [0.9826, 0.0006]},
 'ShampooSGLD': {'best_lr': 0.005,
  'lr': 0.005,
  'lp_ensemble': [-0.0691, 0.0009],
  'acc_ensemble': [0.9812, 0.0008]}}

In [6]:
# Sweep configuration for the mnist_800_gaussian results; the time budget
# uses get_metrics' default timing row.
parent_dir = '../results/mnist_800_gaussian_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_800_gaussian_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.01, 0.025, 0.05, 0.075, 0.1],
    "WenzelSGLD": [0.025, 0.05, 0.075, 0.1, 0.25, 0.5],
    "pSGLD": [0.000075, 0.0001, 0.00025, 0.0005, 0.00075],
    "MongeSGLD": [0.01, 0.025, 0.05, 0.075, 0.1, 0.25],
    "ShampooSGLD": [0.00075, 0.001, 0.0025, 0.005, 0.0075, 0.01],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": ["monge_alpha_2=1.0", "monge_alpha_2=0.5", "monge_alpha_2=0.1"],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 7959.999999999999,
 'VanillaSGLD': {'best_lr': 0.05,
  'lr': 0.05,
  'lp_ensemble': [-0.1667, 0.0004],
  'acc_ensemble': [0.9587, 0.0003]},
 'WenzelSGLD': {'best_lr': 0.05,
  'lr': 0.05,
  'lp_ensemble': [-0.1668, 0.0004],
  'acc_ensemble': [0.9583, 0.0004]},
 'pSGLD': {'best_lr': 0.0005,
  'lr': 0.0005,
  'lp_ensemble': [-0.1648, 0.0002],
  'acc_ensemble': [0.9612, 0.0003]},
 'MongeSGLD': {'best_lr': '0.01_monge_alpha_2=1.0',
  'lr': '0.01_monge_alpha_2=1.0',
  'lp_ensemble': [nan, nan],
  'acc_ensemble': [nan, nan]},
 'ShampooSGLD': {'best_lr': 0.0025,
  'lr': 0.0025,
  'lp_ensemble': [-0.1635, 0.0003],
  'acc_ensemble': [0.9596, 0.0004]}}

In [7]:
# Sweep configuration for the mnist_800_horseshoe results; the time budget
# uses get_metrics' default timing row.
parent_dir = '../results/mnist_800_horseshoe_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_800_horseshoe_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "WenzelSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "pSGLD": [0.0001, 0.00025, 0.0005, 0.00075, 0.001],
    "MongeSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "ShampooSGLD": [0.001, 0.0025, 0.005, 0.0075, 0.01],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": [
        "monge_alpha_2=1.0",
        "monge_alpha_2=0.75",
        "monge_alpha_2=0.5",
        "monge_alpha_2=0.25",
        "monge_alpha_2=0.1",
    ],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 9920.0,
 'VanillaSGLD': {'best_lr': 0.5,
  'lr': 0.5,
  'lp_ensemble': [-0.0798, 0.0003],
  'acc_ensemble': [0.9818, 0.0004]},
 'WenzelSGLD': {'best_lr': 0.5,
  'lr': 0.5,
  'lp_ensemble': [-0.0842, 0.0008],
  'acc_ensemble': [0.9787, 0.0002]},
 'pSGLD': {'best_lr': 0.0005,
  'lr': 0.0005,
  'lp_ensemble': [-0.0656, 0.0013],
  'acc_ensemble': [0.9801, 0.0007]},
 'MongeSGLD': {'best_lr': '0.25_monge_alpha_2=0.5',
  'lr': '0.25_monge_alpha_2=0.5',
  'lp_ensemble': [-0.0628, 0.001],
  'acc_ensemble': [0.9835, 0.0007]},
 'ShampooSGLD': {'best_lr': 0.005,
  'lr': 0.005,
  'lp_ensemble': [-0.0663, 0.0009],
  'acc_ensemble': [0.9819, 0.0006]}}

In [8]:
# Sweep configuration for the mnist_1200_gaussian results; the time budget
# uses get_metrics' default timing row.
parent_dir = '../results/mnist_1200_gaussian_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_1200_gaussian_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.01, 0.025, 0.05, 0.075, 0.1],
    "WenzelSGLD": [0.025, 0.05, 0.075, 0.1, 0.25, 0.5],
    "pSGLD": [0.000075, 0.0001, 0.00025, 0.0005, 0.00075],
    "MongeSGLD": [0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0],
    "ShampooSGLD": [0.00075, 0.001, 0.0025, 0.005, 0.0075, 0.01, 0.025],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": [
        "monge_alpha_2=1.0",
        "monge_alpha_2=0.75",
        "monge_alpha_2=0.5",
        "monge_alpha_2=0.25",
        "monge_alpha_2=0.1",
    ],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 17400.0,
 'VanillaSGLD': {'best_lr': 0.025,
  'lr': 0.025,
  'lp_ensemble': [-0.1932, 0.0004],
  'acc_ensemble': [0.9514, 0.0004]},
 'WenzelSGLD': {'best_lr': 0.05,
  'lr': 0.05,
  'lp_ensemble': [-0.1935, 0.0004],
  'acc_ensemble': [0.9519, 0.0003]},
 'pSGLD': {'best_lr': 0.0005,
  'lr': 0.0005,
  'lp_ensemble': [-0.1894, 0.0006],
  'acc_ensemble': [0.9566, 0.0003]},
 'MongeSGLD': {'best_lr': '0.025_monge_alpha_2=0.75',
  'lr': '0.025_monge_alpha_2=0.75',
  'lp_ensemble': [-0.1785, 0.0005],
  'acc_ensemble': [0.956, 0.0005]},
 'ShampooSGLD': {'best_lr': 0.0025,
  'lr': 0.0025,
  'lp_ensemble': [-0.1878, 0.0003],
  'acc_ensemble': [0.9541, 0.0003]}}

In [9]:
# Sweep configuration for the mnist_1200_horseshoe results; the time budget
# uses get_metrics' default timing row.
parent_dir = '../results/mnist_1200_horseshoe_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/mnist_1200_horseshoe_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "WenzelSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "pSGLD": [0.0001, 0.00025, 0.0005, 0.00075, 0.001],
    "MongeSGLD": [0.075, 0.1, 0.25, 0.5],
    "ShampooSGLD": [0.00075, 0.001, 0.0025, 0.005, 0.0075, 0.01],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": ["monge_alpha_2=0.1", "monge_alpha_2=0.075", "monge_alpha_2=0.05"],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 22520.0,
 'VanillaSGLD': {'best_lr': 0.25,
  'lr': 0.25,
  'lp_ensemble': [-0.0809, 0.0004],
  'acc_ensemble': [0.9814, 0.0002]},
 'WenzelSGLD': {'best_lr': 0.25,
  'lr': 0.25,
  'lp_ensemble': [-0.1009, 0.0009],
  'acc_ensemble': [0.9745, 0.0005]},
 'pSGLD': {'best_lr': 0.0005,
  'lr': 0.0005,
  'lp_ensemble': [-0.063, 0.0011],
  'acc_ensemble': [0.9808, 0.0007]},
 'MongeSGLD': {'best_lr': '0.25_monge_alpha_2=0.075',
  'lr': '0.25_monge_alpha_2=0.075',
  'lp_ensemble': [-0.0699, 0.0005],
  'acc_ensemble': [0.9832, 0.0005]},
 'ShampooSGLD': {'best_lr': 0.005,
  'lr': 0.005,
  'lp_ensemble': [-0.0599, 0.0005],
  'acc_ensemble': [0.9831, 0.0003]}}

In [10]:
# Sweep configuration for the cifar10_googleresnet_gaussian results; the time
# budget is taken from the "Wenzel" row of the timing CSV.
parent_dir = '../results/cifar10_googleresnet_gaussian_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/cifar10_googleresnet_gaussian_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.05, 0.075, 0.1, 0.25, 0.5],
    "WenzelSGLD": [0.1, 0.25, 0.5, 0.75, 1.0],
    "pSGLD": [0.00025, 0.0005, 0.00075, 0.001, 0.0025, 0.005],
    "MongeSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "ShampooSGLD": [0.005, 0.0075, 0.01, 0.025, 0.05],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": ["monge_alpha_2=1.0", "monge_alpha_2=0.5", "monge_alpha_2=0.1"],
}

get_metrics(parent_dir, lrs, other_args, csv_path, fastest_metric="Wenzel")


{'timestamp': 3000.0,
 'VanillaSGLD': {'best_lr': 0.1,
  'lr': 0.1,
  'lp_ensemble': [-0.4634, 0.0039],
  'acc_ensemble': [0.8589, 0.0013]},
 'WenzelSGLD': {'best_lr': 0.5,
  'lr': 0.5,
  'lp_ensemble': [-0.4869, 0.0047],
  'acc_ensemble': [0.8533, 0.0023]},
 'pSGLD': {'best_lr': 0.00075,
  'lr': 0.00075,
  'lp_ensemble': [-0.4828, 0.0041],
  'acc_ensemble': [0.8562, 0.0018]},
 'MongeSGLD': {'best_lr': '0.075_monge_alpha_2=1.0',
  'lr': '0.075_monge_alpha_2=1.0',
  'lp_ensemble': [nan, nan],
  'acc_ensemble': [nan, nan]},
 'ShampooSGLD': {'best_lr': 0.025,
  'lr': 0.025,
  'lp_ensemble': [-0.4858, 0.0047],
  'acc_ensemble': [0.8595, 0.0023]}}

In [11]:
# Sweep configuration for the cifar10_correlatedgoogleresnet_convcorrnormal
# results; the time budget uses get_metrics' default timing row.
parent_dir = '../results/cifar10_correlatedgoogleresnet_convcorrnormal_1.0_100_flat_400_1000_evaluations/'
csv_path = "csvs/cifar10_correlatedgoogleresnet_convcorrnormal_1.0_100_flat_400_1000_evaluations.csv"

lrs = {
    "VanillaSGLD": [0.05, 0.075, 0.1, 0.25, 0.5],
    "WenzelSGLD": [0.25, 0.5, 0.75, 1.0, 1.25],
    "pSGLD": [0.00025, 0.0005, 0.00075, 0.001, 0.0025],
    "MongeSGLD": [0.075, 0.1, 0.25, 0.5, 0.75],
    "ShampooSGLD": [0.0075, 0.01, 0.025, 0.05, 0.075],
}

# MongeSGLD additionally sweeps monge_alpha_2 (appended to each lr in run ids).
other_args = {
    "MongeSGLD": ["monge_alpha_2=1.0", "monge_alpha_2=0.5", "monge_alpha_2=0.1"],
}

get_metrics(parent_dir, lrs, other_args, csv_path)


{'timestamp': 4640.0,
 'VanillaSGLD': {'best_lr': 0.1,
  'lr': 0.1,
  'lp_ensemble': [-0.4437, 0.004],
  'acc_ensemble': [0.8641, 0.0025]},
 'WenzelSGLD': {'best_lr': 0.75,
  'lr': 0.75,
  'lp_ensemble': [-0.4624, 0.005],
  'acc_ensemble': [0.8612, 0.002]},
 'pSGLD': {'best_lr': 0.00075,
  'lr': 0.00075,
  'lp_ensemble': [-0.4594, 0.003],
  'acc_ensemble': [0.8631, 0.0026]},
 'MongeSGLD': {'best_lr': '0.075_monge_alpha_2=1.0',
  'lr': '0.075_monge_alpha_2=1.0',
  'lp_ensemble': [nan, nan],
  'acc_ensemble': [nan, nan]},
 'ShampooSGLD': {'best_lr': 0.025,
  'lr': 0.025,
  'lp_ensemble': [-0.4506, 0.0023],
  'acc_ensemble': [0.8697, 0.0015]}}