## Imports and dataset

In [None]:
from src.models.optimized_bnn import BayesianNN, BayesianNNVI
from src.utils2 import plot_ppds

from src.attacks.distr_attacks_bnn_jax import fgsm_attack, mlmc_attack, kl_to_appd

import numpy as np
import pandas as pd
import jax
import jax.numpy as jnp
import numpyro

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from tqdm import tqdm
from joblib import Parallel, delayed

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import seaborn as sns
import matplotlib.pyplot as plt

# Seaborn defaults for every figure in the notebook: white grid, muted palette,
# serif fonts, and the "notebook" context with enlarged text and thicker lines.
sns.set_theme(font="serif", palette="muted", style="whitegrid")
sns.set_context("notebook", rc={"lines.linewidth": 2.5}, font_scale=1.5)

# Matplotlib overrides, written after the seaborn calls so that these values
# are the ones left in rcParams. Grouped by the part of the figure they affect.

# Axes: title, labels and frame.
plt.rcParams.update({
    'axes.titlesize': 18,
    'axes.titleweight': 'bold',
    'axes.labelsize': 14,
    'axes.edgecolor': 'black',
    'axes.linewidth': 1,
})
# Tick labels.
plt.rcParams.update({'xtick.labelsize': 12, 'ytick.labelsize': 12})
# Grid lines.
plt.rcParams.update({'grid.alpha': 0.5, 'grid.linestyle': '--'})
# Legend.
plt.rcParams.update({'legend.fontsize': 12, 'legend.frameon': False})
# Figure resolution.
plt.rcParams.update({'figure.dpi': 300})

In [3]:
import numpyro
# Ask NumPyro to expose 8 host (CPU) devices to XLA; the model fit further down
# requests num_chains=8, which matches this count.
# NOTE(review): NumPyro documents that this setting only takes effect when it is
# called before JAX initialises its backend. `jax` and the project modules are
# already imported in the first cell — confirm none of them touch the backend
# at import time, otherwise move this call to the very top of the notebook.
numpyro.set_host_device_count(8)

In [4]:
# Seed NumPy's global random number generator for reproducibility.
# NOTE(review): only NumPy is seeded here. Whether the JAX/NumPyro models and
# the attack routines are seeded is decided inside the project modules, which
# are not visible from this notebook — confirm before relying on exact numbers.
seed = 42
np.random.seed(seed)

# Wine dataset

In [5]:
# White-wine table of the UCI Wine Quality dataset (11 features, 4898 samples
# according to the original note). The last column is the quality score used as
# the regression response; every preceding column is a feature.
# NOTE(review): the original comment listed the response values as
# {3, 4, 5, 6, 7, 8}; verify with `np.unique(y)` once the data is loaded.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/wine-quality/winequality-white.csv"

# The CSV uses ';' as its field separator.
data = pd.read_csv(url, sep=";")

# Split the frame into the feature matrix and the response vector.
X = data.iloc[:, :-1].to_numpy()
y = data.iloc[:, -1].to_numpy()

In [6]:
# Min-max scale every feature, then hold out 30% of the rows as a test set.
# NOTE(review): the scaler is fitted on the full dataset *before* the split, so
# the test rows influence the per-feature min/max used for training. Changing
# this would invalidate any model cached on disk, so it is only flagged here.
scaler = MinMaxScaler().fit(X)
X_normalized = scaler.transform(X)

X_train, X_test, y_train, y_test = train_test_split(
    X_normalized,
    y,
    random_state=42,
    test_size=0.3,
)

In [7]:
# Sampling-based Bayesian NN with `hidden_units=3`. A previously saved fit is
# loaded when the checkpoint exists; otherwise the model is fitted on the
# training split and written to the same location for later runs.
bnn_checkpoint = '../src/models/3bnn'
model = BayesianNN(hidden_units=3, input_dim=X_train.shape[1])
try:
    model.load(bnn_checkpoint)
except FileNotFoundError:
    print('Not model found, fitting the model')
    # num_chains=8 matches the host device count configured earlier.
    bnn_fit_options = dict(num_warmup=500, num_samples=1000, num_chains=8)
    model.fit(X_train, y_train, **bnn_fit_options)
    model.save(bnn_checkpoint)

In [8]:
# Variational-inference counterpart of the Bayesian NN, also with
# `hidden_units=3`. Same load-or-fit pattern: reuse the saved checkpoint when
# present, otherwise fit for 5000 steps and cache the result.
bnn_vi_checkpoint = '../src/models/3bnnvi'
modelVI = BayesianNNVI(hidden_units=3, input_dim=X_train.shape[1])
try:
    modelVI.load(bnn_vi_checkpoint)
except FileNotFoundError:
    print('Not model found, fitting the model')
    modelVI.fit(X_train, y_train, num_steps=5000)
    modelVI.save(bnn_vi_checkpoint)

In [37]:
# Attack example on a single test point (row 17 of the test split).
# jnp.array already copies the data, so no explicit NumPy copy is needed.
x = jnp.array(X_test[17, :].reshape(1, -1))

# Draw the predictive samples once and derive both moments from the same draw;
# sampling separately for the mean and the std doubled the cost for no benefit.
ppd_samples = modelVI.sample_predictive_distribution(x, 1000)
mu = ppd_samples.mean()
std = ppd_samples.std()

# Adversarial target predictive distribution: mean shifted by +2, std doubled.
appd = numpyro.distributions.Normal(2 + mu, 2 * std)

# NOTE(review): the target is built from modelVI's predictive moments while the
# attack itself is run against `model` — confirm this mix is intentional.
x_adv_distr, x_adv_values = mlmc_attack(model, x, appd=appd, epsilon=2, R=20, lr=0.005, n_iter=1000)

In [82]:
def compute_kl(model, model_eval, i, epsilon):
    """Attack test point ``i`` on ``model`` and score the result on ``model_eval``.

    The adversarial target is a Normal whose mean is the clean predictive mean
    of ``model`` shifted by +2 and whose std is twice the clean predictive std.

    Parameters
    ----------
    model : object
        Model the attack is crafted against; must provide
        ``sample_predictive_distribution(x, n_samples)``.
    model_eval : object
        Model used to evaluate the perturbed input (same interface).
    i : int
        Row index into ``X_test``.
    epsilon : float
        Attack budget passed to ``mlmc_attack``; ``0`` skips the attack and
        evaluates the clean input.

    Returns
    -------
    tuple
        ``(kl, mean_gap, std_ratio)``: the value returned by ``kl_to_appd`` for
        the adversarial predictive moments versus the target moments, the
        absolute difference between the adversarial predictive mean and the
        target mean, and the adversarial predictive std divided by the target std.
    """
    x = jnp.array(X_test[i, :].reshape(1, -1))

    # One draw per distribution: both moments come from the same samples.
    clean_samples = model.sample_predictive_distribution(x, 1000)
    mu = clean_samples.mean()
    std = clean_samples.std()
    appd = numpyro.distributions.Normal(2 + mu, 2 * std)

    if epsilon == 0:
        # No budget: evaluate the unperturbed point (JAX arrays are immutable,
        # so sharing the reference is safe).
        x_adv = x
    else:
        x_adv, _ = mlmc_attack(model, x, appd, epsilon=epsilon, verbose=False, R=20, lr=0.001, n_iter=1000)

    adv_samples = model_eval.sample_predictive_distribution(x_adv, 1000)
    adv_mu = adv_samples.mean()
    adv_std = adv_samples.std()

    # Target variance is (2 * std) ** 2 == 4 * std ** 2.
    att_kl = kl_to_appd(adv_mu, adv_std ** 2, 2 + mu, 4 * std ** 2)
    return att_kl, abs(adv_mu - 2 - mu), adv_std / (2 * std)


epsilons = [0, 0.2, 0.5]
n = 40  # number of test points attacked per (model pair, epsilon) setting

# (model attacked, model evaluated) for each white-box / transfer combination.
model_pairs = {
    'mm': (model, model),
    'mvi': (model, modelVI),
    'vim': (modelVI, model),
    'vivi': (modelVI, modelVI),
}

# `results` is a list with one dict per repetition; each dict maps a pair name
# to a list (one entry per epsilon) of length-3 arrays holding the metrics of
# `compute_kl` averaged over the `n` test points.
results = []
for _ in tqdm(range(10)):
    res_it = {pair_name: [] for pair_name in model_pairs}
    for epsilon in epsilons:
        for pair_name, (attacked, evaluated) in model_pairs.items():
            point_metrics = Parallel(n_jobs=-1)(
                delayed(compute_kl)(attacked, evaluated, j, epsilon) for j in range(n)
            )
            res_it[pair_name].append(np.array(point_metrics).mean(axis=0))
    results.append(res_it)

  0%|          | 0/10 [00:00<?, ?it/s]python(47305) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47306) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47307) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47308) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47309) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47310) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47311) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47312) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47313) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
python(47314) MallocStackLogging: can't turn off malloc stack logging because it

In [84]:
# Average every metric over the repetitions stored in `results`. Each resulting
# array is indexed as [epsilon, metric], with metrics ordered KL, Mean, Std.
results_mm, results_mvi, results_vim, results_vivi = (
    np.array([run[pair_name] for run in results]).mean(axis=0)
    for pair_name in ('mm', 'mvi', 'vim', 'vivi')
)

# One section per metric; within a section the rows are mm, mvi, vim, vivi and
# the columns follow the order of the epsilon values.
for metric_idx, metric_name in enumerate(['KL', 'Mean', 'Std']):
    print(metric_name + ':')
    print(
        results_mm[:, metric_idx], '\n',
        results_mvi[:, metric_idx], '\n',
        results_vim[:, metric_idx], '\n',
        results_vivi[:, metric_idx],
    )
    print('---')

KL:
[4.688255   1.0635079  0.65986836] 
 [4.3347163 2.8120158 2.0422614] 
 [4.806561  1.7637587 0.8876586] 
 [4.3282533  2.339662   0.88005793]
---
Mean:
[2.        0.6645841 0.7218342] 
 [2.0237563 1.5307654 1.2116525] 
 [1.9762433  0.92916477 0.63984734] 
 [2.        1.3177689 0.2946269]
---
Std:
[0.5        0.5306341  0.64793974] 
 [0.5249086  0.5253553  0.52854174] 
 [0.47630548 0.48899323 0.5590891 ] 
 [0.5        0.49966702 0.50055546]
---
