In [None]:
%load_ext autoreload
%autoreload 2
import sys
import os
import copy
import torch
import numpy as np
import abstract_gradient_training as agt
from models.fully_connected import FullyConnected
from datasets.uci import get_dataloaders

import matplotlib.pyplot as plt
import seaborn as sns
import matplotlib as mpl

In [5]:
# Notebook-wide defaults: the single-panel figure size used by plot_training when it
# creates its own axes, plus the base seaborn style and plotting context.
figsize = (5, 5)
sns.set_style(style='whitegrid')
sns.set_context(context='poster')

In [None]:
# configure the training parameters
batchsize = 10000
nominal_config = agt.AGTConfig(
    fragsize=10000,
    learning_rate=0.05,
    lr_decay=1.0,
    n_epochs=1,
    # Defaults to the device the original experiments used. Set the AGT_DEVICE
    # environment variable (e.g. "cuda:0" or "cpu") to run elsewhere without editing
    # the notebook.
    # NOTE(review): the device may be part of config.hash(); if so, changing it also
    # changes the cache file names used by run_sweep/plot_training -- confirm.
    device=os.environ.get("AGT_DEVICE", "cuda:1"),
    forward_bound="interval",
    backward_bound="interval",
    loss="mse",
    log_level="INFO",
    early_stopping=False,
    metadata=f"uci, batchsize={batchsize}",
    bound_kwargs={"interval_matmul": "exact"}
)

# get the data and nn model
# The seed is set before loading the data and building the model, presumably so the
# model's initial weights are reproducible.
# NOTE(review): run_sweep loads its own dataloaders, so these two are not used by the
# sweep. Removing this call may still change the RNG state consumed before
# FullyConnected initialises; keep it unless that is verified.
torch.manual_seed(0)
dl_train, dl_test = get_dataloaders(batchsize, batchsize, "houseelectric")
# NOTE(review): arguments are presumably (in_features=11, out_features=1,
# hidden width=50, hidden layers=1). An earlier comment claimed 64 hidden neurons,
# which does not match the 50 passed here -- confirm against FullyConnected.
model = FullyConnected(11, 1, 50, 1)

In [28]:
def run_sweep(config, sweep_parameter, sweep_values, results_dir=".results"):
    """Train once per sweep value and cache the logged loss trajectories to disk.

    For each value in ``sweep_values``, a fresh deep copy of ``config`` gets
    ``sweep_parameter`` set to that value. The copy's ``hash()`` names the cache file
    ``{results_dir}/uci_{hash}.pt``; if that file already exists, the run is skipped.
    Otherwise ``agt.poison_certified_training`` runs on the global ``model`` with
    freshly seeded "houseelectric" dataloaders. The first three entries of the first
    argument of every callback invocation are then saved with ``torch.save`` as
    ``(worst, nominal, best)`` lists.

    Args:
        config: Base configuration. It is deep-copied and never mutated.
        sweep_parameter: Name of the config attribute to vary.
        sweep_values: Values to assign to ``sweep_parameter``, one training run each.
        results_dir: Directory for the cached ``.pt`` files; created if missing.

    Relies on the notebook globals ``model``, ``batchsize``, ``agt`` and
    ``get_dataloaders``.
    """
    os.makedirs(results_dir, exist_ok=True)

    for val in sweep_values:
        # Use a fresh copy per value so the hash reflects only the base config plus this
        # sweep value, which is what plot_training hashes. Reusing one copy would leave
        # the previous iteration's callback on the config while hashing (which matters
        # if AGTConfig.hash() takes the callback into account).
        run_config = copy.deepcopy(config)
        setattr(run_config, sweep_parameter, val)
        fname = os.path.join(results_dir, f"uci_{run_config.hash()}.pt")
        if os.path.isfile(fname):
            continue  # already trained; reuse the cached trajectories

        training = []

        def log(*args):
            # Snapshot each callback payload so later in-place updates cannot alter it.
            training.append(copy.deepcopy(args))

        # Set after hashing so the callback never influences the cache key.
        run_config.callback = log

        torch.manual_seed(0)
        dl_train, dl_test = get_dataloaders(batchsize, batchsize, "houseelectric")
        # The trained parameters that are returned are not needed; only the logged
        # losses are kept.
        agt.poison_certified_training(model, run_config, dl_train, dl_test)

        # NOTE(review): assumes args[0] of each callback is indexable as
        # (worst, nominal, best) -- confirm against agt's callback contract.
        worst = [entry[0][0] for entry in training]
        nominal = [entry[0][1] for entry in training]
        best = [entry[0][2] for entry in training]
        torch.save((worst, nominal, best), fname)

In [29]:
def plot_training(config, sweep_parameter, sweep_values, label, ax=None, results_dir=".results"):
    """Draw the worst/best loss bands of a parameter sweep around the nominal curve.

    For each value in ``sweep_values``, the cached ``(worst, nominal, best)``
    trajectories written by ``run_sweep`` are loaded from
    ``{results_dir}/uci_{hash}.pt``. The hash comes from a copy of ``config`` with
    ``sweep_parameter`` set to that value. All worst/best curves are collected,
    together with one nominal curve taken from the last file loaded. They are sorted by
    their final value, and every gap between consecutive curves is filled as a band.
    The nominal curve is drawn on top.

    Args:
        config: Base configuration. It is deep-copied and never mutated.
        sweep_parameter: Name of the config attribute that was swept.
        sweep_values: Swept values. The lowest ``len(sweep_values)`` bands are labelled
            ``$label=value$`` in this order, which assumes the values are ordered so
            that the first one produces the outermost (lowest) band.
        label: LaTeX symbol used in the legend entries (without the ``$``).
        ax: Axes to draw on. If ``None``, a new ``figsize`` figure is created with
            axis labels and a (0, 0.5) y-range.
        results_dir: Directory holding the cached ``.pt`` files.

    Returns:
        ``([], [])``: two empty lists. They are kept only so callers that unpack the
        result keep working; nothing is collected into them.

    Raises:
        ValueError: If ``sweep_values`` is empty (there would be nothing to plot).

    Reads the notebook globals ``palette``, ``nominal_color`` and ``figsize``.
    """
    if not sweep_values:
        raise ValueError("sweep_values must be non-empty")

    if ax is None:
        _, ax = plt.subplots(figsize=figsize)
        ax.set_ylabel("Mean Squared Error")
        ax.set_ylim(0, 0.5)
        ax.set_xlabel("Training Iteration")

    # Two bound curves per sweep value, followed by a single nominal curve.
    curves = []
    nominal = None
    for val in sweep_values:
        run_config = copy.deepcopy(config)
        setattr(run_config, sweep_parameter, val)
        fname = os.path.join(results_dir, f"uci_{run_config.hash()}.pt")
        worst, nominal, best = torch.load(fname)
        curves.extend([worst, best])
    curves.append(nominal)

    iterations = np.arange(len(curves[0]))
    # Sort by final value so each consecutive pair of curves bounds one band.
    curves.sort(key=lambda curve: curve[-1])

    n_values = len(sweep_values)
    base_colors = list(palette)[:n_values]
    # Mirrored colours: the k-th band below the nominal curve matches the k-th band above it.
    band_colors = base_colors + base_colors[::-1]

    for i, (lower, upper) in enumerate(zip(curves, curves[1:])):
        # Opaque white underlay, presumably so gridlines don't show through the
        # translucent band drawn next.
        ax.fill_between(iterations, lower, upper, color="w", alpha=1.0, lw=0)
        band_kwargs = {"color": band_colors[i], "alpha": 0.6, "lw": 0}
        if i < n_values:
            band_kwargs["label"] = f'${label}={sweep_values[i]}$'
        ax.fill_between(iterations, lower, upper, **band_kwargs)

    ax.plot(nominal, color=nominal_color, lw=2)
    ax.legend(loc="upper right", fontsize="small")
    return [], []

In [None]:
# Styling for the four-panel figure (this overrides the seaborn style/context set earlier).
sns.set_theme(context="poster", style="whitegrid", font_scale=1.1)
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'

fig, axs = plt.subplots(1, 4, figsize=(20, 3.8), sharey=True, layout='constrained')
axs[0].set_ylim(-0.02, 0.5)  # the y-axis is shared, so this applies to every panel
for ax in axs:
    ax.set_xlim(-5, 145)

# plot_training reads these two names as notebook globals.
palette = sns.color_palette(palette="Dark2", n_colors=8)
nominal_color = next(iter(sns.color_palette(palette="Set1", n_colors=4)))

# One entry per panel:
#   (attributes fixed on top of nominal_config, swept attribute, swept values,
#    legend symbol, panel title)
# Raw strings stop LaTeX backslashes (\epsilon, \nu) from being read as
# invalid Python escape sequences.
panel_specs = [
    ({"epsilon": 0.01}, "k_poison", [10000, 5000, 1000], "n",
     r"Feature poisoning $(\epsilon=0.01)$"),
    ({"k_poison": 1000}, "epsilon", [0.15, 0.1, 0.05], r"\epsilon",
     "Feature poisoning $(n=1000)$"),
    ({"label_epsilon": 0.05}, "label_k_poison", [10000, 5000, 1000], "m",
     r"Label poisoning $(\nu=0.05)$"),
    ({"label_k_poison": 1000}, "label_epsilon", [0.75, 0.5, 0.25], r"\nu",
     "Label poisoning $(m=1000)$"),
]

for ax, (fixed, swept_attr, swept_values, symbol, title) in zip(axs, panel_specs):
    conf = copy.deepcopy(nominal_config)
    for attr, value in fixed.items():
        setattr(conf, attr, value)
    run_sweep(conf, swept_attr, swept_values)  # trains only runs missing from the cache
    plot_training(conf, swept_attr, swept_values, symbol, ax)
    ax.set_title(title, pad=15)

fig.supylabel("MSE + Bounds", fontsize="x-large")
os.makedirs(".figures", exist_ok=True)  # savefig does not create missing directories
fig.savefig(".figures/uci_training_1.pdf", bbox_inches="tight", dpi=300)