In [1]:
%load_ext autoreload
%autoreload 2
import os
import copy
import torch
import numpy as np
import abstract_gradient_training as agt
from models.fully_connected import FullyConnected
from datasets.uci import get_dataloaders

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

In [2]:
# Notebook-wide plotting defaults: seaborn whitegrid style at "poster" scale.
sns.set_style(style="whitegrid")
sns.set_context(context="poster")

# Default size used by plot_training when it has to create its own figure.
figsize = (5, 5)

In [3]:
# configure the training parameters
# Nominal (un-swept) values; run_sweep / plot_training override one of the
# batch size, width or depth at a time and fall back to these otherwise.
batchsize_n = 10000
hidden_dim_n = 50
hidden_lay_n = 1
n_iters = 100  # forwarded as n_batches to get_dataloaders in run_sweep
seed = 0

# Keep using the second GPU when it exists (the original hard-coded choice),
# but fall back gracefully so the notebook still runs on single-GPU or
# CPU-only machines instead of failing at training time.
if torch.cuda.device_count() > 1:
    train_device = "cuda:1"
elif torch.cuda.is_available():
    train_device = "cuda:0"
else:
    train_device = "cpu"

nominal_config = agt.AGTConfig(
    fragsize=10000,
    learning_rate=0.05,
    lr_decay=1.0,
    k_poison=100,
    epsilon=0.01,
    n_epochs=1,
    device=train_device,
    forward_bound="interval",
    backward_bound="interval",
    loss="mse",
    log_level="INFO",
    early_stopping=False,
    # Placeholder only: the sweep helpers replace metadata on their own copies.
    metadata=f"uci, batchsize={batchsize_n}",
    bound_kwargs={"interval_matmul": "exact"}
)

In [4]:
def run_sweep(config, sweep_parameter, sweep_values):
    """Train one model per sweep value and cache its per-iteration loss traces.

    For each value, a single setting (learning rate, batch size, hidden width
    or hidden depth) is overridden while the others stay at their nominal
    module-level values. The triple of traces collected through the training
    callback is written to ``.results/uci_<config hash>.pt``. Values whose
    result file already exists are skipped, so re-running is cheap.

    Args:
        config: Base AGT configuration. It is deep-copied for every sweep
            value, so the caller's object is never modified.
        sweep_parameter: One of ``"learning_rate"``, ``"batchsize"``,
            ``"hidden_dim"`` or ``"hidden_lay"``.
        sweep_values: Iterable of values to assign to ``sweep_parameter``.

    Raises:
        ValueError: If ``sweep_parameter`` is not one of the supported names
            (previously such a name was silently ignored and the nominal
            configuration was used instead).
    """
    supported = ("learning_rate", "batchsize", "hidden_dim", "hidden_lay")
    if sweep_parameter not in supported:
        raise ValueError(
            f"Unknown sweep_parameter {sweep_parameter!r}; expected one of {supported}"
        )

    for val in sweep_values:
        # Fresh copy per value: the hash below is then always computed on a
        # config that carries no callback left over from a previous value,
        # matching how plot_training derives the same file name.
        run_config = copy.deepcopy(config)

        # NOTE: these three local names are embedded verbatim in the metadata
        # string by the f-string "=" specifier, so they must not be renamed.
        hidden_lay, hidden_dim, batchsize = hidden_lay_n, hidden_dim_n, batchsize_n
        if sweep_parameter == "learning_rate":
            run_config.learning_rate = val
        elif sweep_parameter == "batchsize":
            batchsize = val
        elif sweep_parameter == "hidden_dim":
            hidden_dim = val
        elif sweep_parameter == "hidden_lay":
            hidden_lay = val
        run_config.metadata = f"{seed=}, {batchsize=}, {hidden_dim=}, {hidden_lay=}"
        conf_hash = run_config.hash()
        fname = f".results/uci_{conf_hash}.pt"
        if os.path.isfile(fname):
            continue  # already trained and cached

        # Make sure the cache directory exists *before* the expensive training
        # run, so the final torch.save cannot fail on a missing folder.
        os.makedirs(os.path.dirname(fname), exist_ok=True)

        training = []

        def log(*args):
            # Deep-copy so later in-place updates by the trainer cannot alter
            # what was recorded for this iteration.
            training.append(copy.deepcopy(args))

        run_config.callback = log

        torch.manual_seed(seed)
        # Width and depth come from the sweep (or the nominal defaults).
        # presumably the arguments are (in_dim, out_dim, hidden_dim, hidden_lay)
        # -- TODO confirm against models.fully_connected.
        model = FullyConnected(11, 1, hidden_dim, hidden_lay)
        dl_train, dl_test = get_dataloaders(batchsize, batchsize, "houseelectric", n_batches=n_iters)
        # The returned parameter bounds are not needed here; only the values
        # reported through the callback are stored.
        agt.poison_certified_training(model, run_config, dl_train, dl_test)

        # The first callback argument is indexed as (worst, nominal, best).
        worst = [t[0][0] for t in training]
        nominal = [t[0][1] for t in training]
        best = [t[0][2] for t in training]
        torch.save((worst, nominal, best), fname)
        


In [9]:

def plot_training(config, sweep_parameter, sweep_values, label, ax=None):
    """Plot the cached loss traces of a sweep as shaded bands with a nominal line.

    For each sweep value the result file written by ``run_sweep`` is located
    via the same config hash, loaded, and drawn on ``ax``: the region between
    the worst and best traces is filled and the nominal trace is drawn on top.
    Colours are taken from the module-level ``palette`` and are reused
    cyclically if there are more sweep values than colours.

    Args:
        config: Base AGT configuration; deep-copied per value, never mutated.
        sweep_parameter: One of ``"learning_rate"``, ``"batchsize"``,
            ``"hidden_dim"`` or ``"hidden_lay"``.
        sweep_values: Iterable of values that were swept.
        label: Symbol used in the legend entry (rendered as ``$label=value$``).
        ax: Axes to draw on. If ``None``, a new figure of size ``figsize`` is
            created and its axis labels / y-limits are set here.

    Returns:
        tuple[list, list]: ``(vals, mse)``. Both lists are always empty; they
        are kept only so the return shape stays unchanged for callers.

    Raises:
        ValueError: If ``sweep_parameter`` is not supported, or ``palette`` is
            empty while there is something to plot.
        FileNotFoundError: If the cached results for a sweep value are missing.
    """
    supported = ("learning_rate", "batchsize", "hidden_dim", "hidden_lay")
    if sweep_parameter not in supported:
        raise ValueError(
            f"Unknown sweep_parameter {sweep_parameter!r}; expected one of {supported}"
        )

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        ax.set_ylabel("Mean Squared Error")
        ax.set_ylim(0, 0.5)
        ax.set_xlabel("Training Iteration")

    vals = []
    mse = []
    # Materialise the palette so colours can be indexed (and wrapped around).
    colors = list(palette)

    for idx, val in enumerate(sweep_values):
        run_config = copy.deepcopy(config)

        # NOTE: these three local names are embedded verbatim in the metadata
        # string by the f-string "=" specifier, so they must not be renamed.
        hidden_lay, hidden_dim, batchsize = hidden_lay_n, hidden_dim_n, batchsize_n
        if sweep_parameter == "learning_rate":
            run_config.learning_rate = val
        elif sweep_parameter == "batchsize":
            batchsize = val
        elif sweep_parameter == "hidden_dim":
            hidden_dim = val
        elif sweep_parameter == "hidden_lay":
            hidden_lay = val
        run_config.metadata = f"{seed=}, {batchsize=}, {hidden_dim=}, {hidden_lay=}"
        conf_hash = run_config.hash()
        fname = f".results/uci_{conf_hash}.pt"
        if not os.path.isfile(fname):
            # The hashed file name alone says nothing about which run is missing.
            raise FileNotFoundError(
                f"No cached results for {sweep_parameter}={val!r} at {fname}; call run_sweep first."
            )
        worst, nominal, best = torch.load(fname)

        if not colors:
            raise ValueError("palette is empty; cannot pick a colour for the sweep")
        color = colors[idx % len(colors)]

        legend_label = f'${label}={val}$'
        steps = range(len(nominal))
        # Opaque white underlay first, so overlapping translucent bands from
        # earlier sweep values do not bleed through this one.
        ax.fill_between(steps, worst, best, color="#ffffff", alpha=1.0, lw=0)
        ax.fill_between(steps, worst, best, color=color, alpha=0.6, label=legend_label, lw=0)
        ax.plot(nominal, color=color)
    ax.legend(loc="upper right")
    return vals, mse

In [None]:
# Figure styling: poster-sized seaborn theme with STIX fonts for text and math.
sns.set_theme(context="poster", style="whitegrid", font_scale=1.1)
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'

fig, axs = plt.subplots(1, 4, figsize=(20, 4.0), sharey=True, layout='constrained')
# The y-axis is shared, so setting the limits on one panel applies to all.
axs[0].set_ylim(0, 0.5)

# Read by plot_training as a module-level global; must be defined before it runs.
palette = sns.color_palette(palette="Dark2", n_colors=4)
for panel in axs:
    panel.set_xlim(-5, 100)

# Values swept for each hyperparameter (all other settings stay nominal).
hls = [3, 2, 1]
hds = [400, 300, 100]
batchsizes = [100, 1000, 10000]
learning_rates = [1e-1, 5e-2, 2e-2]

# One panel per sweep: (sweep parameter, values, legend symbol, panel title).
sweeps = [
    ("hidden_lay", hls, "d", "Depth $(d)$"),
    ("hidden_dim", hds, "w", "Width $(w)$"),
    ("batchsize", batchsizes, "b", "Batch Size $(b)$"),
    ("learning_rate", learning_rates, "\\alpha", "Learning Rate $(\\alpha)$"),
]
for sweep_ax, (sweep_name, sweep_vals, sweep_symbol, sweep_title) in zip(axs, sweeps):
    conf = copy.deepcopy(nominal_config)
    run_sweep(conf, sweep_name, sweep_vals)  # trains only what is not cached yet
    plot_training(conf, sweep_name, sweep_vals, sweep_symbol, sweep_ax)
    sweep_ax.set_title(sweep_title, pad=15)

fig.supylabel("MSE + Bounds", ha="center", fontsize="x-large")
fig.supxlabel("Training Iteration", y=-0.15, fontsize="x-large")

# Create the output folder up front so saving cannot fail on a fresh checkout.
os.makedirs(".figures", exist_ok=True)
fig.savefig(".figures/uci_training_2.pdf", bbox_inches="tight", dpi=300)