In [1]:
import os
import sys
from functools import partial

import jax
import jax.numpy as jnp
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
import optax
import pandas
from jax.flatten_util import ravel_pytree
from jax.tree_util import tree_map, tree_structure
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

from bde.bde_builder import BdeBuilder
from bde.bde_evaluator import BDEPredictor
from bde.data.dataloader import DataLoader
from bde.data.preprocessor import DataPreProcessor
from bde.models.models import Fnn
from bde.sampler.mile_wrapper import MileWrapper
from bde.sampler.prior import Prior, PriorDist
from bde.sampler.probabilistic import ProbabilisticModel
from bde.sampler.warmup import warmup_wrapper
from bde.tests.metrics import metrics
from bde.training.trainer import FnnTrainer
from bde.viz.plotting import plot_pred_vs_true

In [2]:
# Load the airfoil self-noise dataset from OpenML.
# The version is pinned explicitly: several versions are active under this
# name, and without a pin scikit-learn emits a "multiple active versions"
# warning and silently falls back to the oldest active one (version 1).
data = fetch_openml(name="airfoil_self_noise", version=1, as_frame=True)

X = data.data.values   # shape (1503, 5)
y = data.target.values.reshape(-1, 1)  # shape (1503, 1)

# Fixed random_state so the 80/20 split is reproducible across runs.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

# Convert to JAX arrays
X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32)
X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32)

# Standardisation statistics come from the training split only, so no
# information from the test split leaks into preprocessing. The 1e-8 keeps
# the division below finite for a constant column.
Xmu, Xstd = jnp.mean(X_train, 0), jnp.std(X_train, 0) + 1e-8
Ymu, Ystd = jnp.mean(y_train, 0), jnp.std(y_train, 0) + 1e-8

# Standardised features/targets; the test split reuses the training statistics.
Xtr = (X_train - Xmu) / Xstd
Xte = (X_test  - Xmu) / Xstd
ytr = (y_train - Ymu) / Ystd
yte = (y_test  - Ymu) / Ystd

- version 1, status: active
  url: https://www.openml.org/search?type=data&id=43919
- version 8, status: active
  url: https://www.openml.org/search?type=data&id=44957



In [3]:
# Layer widths handed to BdeBuilder below.
# Presumably 5 input features -> two hidden layers of 16 -> 2 outputs; the last
# cell of this notebook reads the two output columns as a mean and a (pre-sigmoid)
# scale — TODO confirm against the Fnn definition.
sizes = [5, 16, 16, 2]

In [4]:
# Build the ensemble: 10 members sharing the `sizes` architecture, each set up
# with an Adam optimiser at learning rate 1e-4.
adam_optimizer = optax.adam(1e-4)
bde = BdeBuilder(sizes, n_members=10, epochs=1000, optimizer=adam_optimizer)

In [5]:
# Train every ensemble member on the standardised training split.
# The call's return value (the list of fitted members) is the cell output.
bde.fit_members(x=Xtr, y=ytr, epochs=1000)

0 2.820504665374756
100 0.41043931245803833
200 0.11688908189535141
300 -0.02331928350031376
400 -0.10943841934204102
500 -0.14141306281089783
600 -0.18712210655212402
700 -0.20854142308235168
800 -0.21600838005542755
900 -0.24235045909881592


[<bde.models.models.Fnn at 0x13709fc10>,
 <bde.models.models.Fnn at 0x137049010>,
 <bde.models.models.Fnn at 0x136df9cd0>,
 <bde.models.models.Fnn at 0x136d45e90>,
 <bde.models.models.Fnn at 0x136de6d10>,
 <bde.models.models.Fnn at 0x136d4a590>,
 <bde.models.models.Fnn at 0x1370ce610>,
 <bde.models.models.Fnn at 0x1370b45d0>,
 <bde.models.models.Fnn at 0x1370c8d90>,
 <bde.models.models.Fnn at 0x137035650>]

In [6]:
# Sanity check: print the length of each member's `params` container, one
# value per line, to confirm all members share the same structure.
for member in bde.members:
    n_param_entries = len(member.params)
    print(n_param_entries)

3
3
3
3
3
3
3
3
3
3


In [7]:
# NOTE: `partial` and `tree_map` come from the import cell at the top of the
# notebook; imports are no longer repeated mid-notebook.

# 1) Prototype model (same architecture for all members)
prior = PriorDist.STANDARDNORMAL.get_prior()
proto_module = bde.members[0]
model = ProbabilisticModel(
    module=proto_module,
    params=proto_module.params,   # just for counting/info
    prior=prior,
)

# 2) Single-chain log-density: the training data is bound here, so the
#    remaining argument is presumably the parameter pytree — TODO confirm
#    against ProbabilisticModel.log_unnormalized_posterior.
logdensity_fn = partial(model.log_unnormalized_posterior, x=Xtr, y=ytr)

# 3) (Optional) batched helper, if you ever need to score many params at once
logdensity_fn_batched = jax.vmap(logdensity_fn)

# 4) Stacked params for vmapped warmup/sampling: every leaf gains a leading
#    axis of length E (one entry per ensemble member).
params_list = [m.params for m in bde.members]                 # length E
params_e = tree_map(lambda *p: jnp.stack(p, axis=0), *params_list)  # (E, ...)


In [8]:
# Run the sampler warmup on the training data for the fitted ensemble.
# The sampling cell further down reads `warmup.state.position` and
# `warmup.parameters` (`L`, `step_size`, `sqrt_diag_cov`) from this result.
warmup = warmup_wrapper(Xtr, ytr, bde)

---Initialize warmup---
---Search for optimal Parameters---
Initial L:  20.049938
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=0.0013885498046875
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=-0.0031890869140625
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=0.0001220703125
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=0.019256591796875
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=0.0002899169921875
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=-0.00035858154296875
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=0.0018758773803710938
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=-0.000701904296875
step 0 | ok=True | step_size=0.004999999888241291 | cap=3.4028234663852886e+38 | dE=-0.00302

In [9]:
# Quick inspection: number of array dimensions of the tuned `sqrt_diag_cov`.
# Presumably (members, n_params) given the per-member warmup — TODO confirm.
warmup.parameters.sqrt_diag_cov.ndim

2

In [10]:
# One chain per ensemble member, each driven by its own PRNG key.
E = len(bde.members)
rng0 = jax.random.PRNGKey(42)
rng_keys_e = jax.random.split(rng0, E)

# Hand the warmup results over to the sampler: starting positions plus the
# tuned parameters.
init_positions_e = warmup.state.position  # pytree with leading axis E
tuned_params = warmup.parameters
L_e, step_size_e, sqrt_diag_e = (
    tuned_params.L,
    tuned_params.step_size,
    tuned_params.sqrt_diag_cov,
)

# Number of samples to draw; no thinning is applied (thinning=1).
num_samples = 2000

sampler = MileWrapper(logdensity_fn=logdensity_fn)
sampling_result = sampler.sample_batched(
    rng_keys_e=rng_keys_e,
    init_positions_e=init_positions_e,
    num_samples=num_samples,
    thinning=1,
    L_e=L_e,
    step_e=step_size_e,
    sqrt_diag_e=sqrt_diag_e,
    store_states=True,
)
positions_eT, infos_eT, states_e = sampling_result


In [11]:
# Flatten the per-step sampler diagnostics into 1-D arrays (all chains and
# steps pooled together). `infos_eT` is expected to expose `logdensity`,
# `energy_change` and `kinetic_change` fields.
ld, dE, kc = (
    jnp.ravel(getattr(infos_eT, field))
    for field in ("logdensity", "energy_change", "kinetic_change")
)

# Any NaN/inf in these would indicate a diverged chain.
print("finite?", *(bool(jnp.all(jnp.isfinite(arr))) for arr in (ld, dE, kc)))

# Summary of the per-step energy change: mean and the 5% / 50% / 95% quantiles.
print("dE mean:", float(dE.mean()))
dE_quantiles = jnp.quantile(dE, jnp.array([.05,.5,.95]))
print("dE q05/median/q95:", [float(q) for q in dE_quantiles])

finite? True True True
dE mean: 5.654772758483887
dE q05/median/q95: [-12.279135704040527, 0.4513664245605469, 38.77876281738281]


In [None]:
# Slice one ensemble member's samples out of the stacked position pytree.
# The original line referenced `self` and `i`, neither of which exists at
# notebook scope, so the cell raised NameError on a fresh run.
# NOTE(review): assumes the leading axis of every leaf in `positions_eT`
# indexes ensemble members — confirm against MileWrapper.sample_batched.
member_idx = 0
tree_map(lambda a: a[member_idx], positions_eT)

In [None]:
# Predict on the standardised test split using the sampled positions.
# The original passed `fnn`, which is never defined in this notebook
# (NameError on Restart & Run All). All members share one architecture, so
# the first member is used as the model here.
# NOTE(review): assumes BDEPredictor only needs the first argument as the
# network definition and takes the weights from `positions_eT` — confirm.
preds = BDEPredictor(bde.members[0], positions_eT, Xte)

# NOTE(review): the next cell treats these as a mean and a standard deviation
# in standardised target units — confirm against BDEPredictor.get_preds.
means, sigmas = preds.get_preds()


In [None]:
# Map predictions, uncertainties and targets back to the original target
# scale (inverse of the standardisation applied with Ymu / Ystd).
y_pred_sampled = means * Ystd + Ymu
y_err_sampled  = sigmas * Ystd
y_true = yte * Ystd + Ymu

# Flatten to 1-D NumPy arrays for matplotlib.
yt = np.asarray(y_true).ravel()
yp_sampled = np.asarray(y_pred_sampled).ravel()
ye_sampled = np.asarray(y_err_sampled).ravel()
ye_sampled = np.maximum(ye_sampled, 1e-8)  # guard against division by zero in the pull

fig = plt.figure(figsize=(7, 7))
gs  = gridspec.GridSpec(2, 1, height_ratios=[4, 1], hspace=0.05)

# Top panel: predicted vs. true with error bars and the y = x reference line.
ax = fig.add_subplot(gs[0])
ax.errorbar(yt, yp_sampled, yerr=ye_sampled, fmt='o', alpha=0.5)
m, M = float(min(yt.min(), yp_sampled.min())), float(max(yt.max(), yp_sampled.max()))
ax.plot([m, M], [m, M], 'r--', lw=1)
ax.set_ylabel("Predicted")
# Title fixed: this figure shows the sampled ensemble, not a single FNN; it
# previously carried the same title as the single-network baseline plot.
ax.set_title("Airfoil: BDE (sampled)")
ax.grid(True)
plt.setp(ax.get_xticklabels(), visible=False)

# Bottom panel: pull = (prediction - truth) / predicted uncertainty.
ax2 = fig.add_subplot(gs[1], sharex=ax)
pull_sampled = (yp_sampled - yt) / ye_sampled
ax2.axhline(0, color='k', ls='--')
ax2.scatter(yt, pull_sampled, s=10, alpha=0.5)
ax2.set_xlabel("True")
ax2.set_ylabel("Pull")
ax2.set_ylim(-3, 3)
ax2.grid(True)


In [None]:
# Baseline: a single network evaluated with its optimiser-trained weights,
# i.e. without any sampling.
# The original cell used `fnn` and `initial_params`, neither of which is
# defined anywhere in this notebook (NameError on Restart & Run All). The
# first ensemble member, with the params it already holds after
# `bde.fit_members`, is used instead, and nothing is mutated.
# NOTE(review): confirm that "unsampled" was meant to be the trained member
# weights rather than some other parameter set.
fnn = bde.members[0]
pred_unsampled = fnn.predict(Xte)

# Column 0 is read as the mean and column 1 as a pre-sigmoid scale; 1e-6
# keeps the scale strictly positive.
# NOTE(review): presumably mirrors the likelihood parameterisation used in
# training — confirm.
mu_n_unsampled  = pred_unsampled[..., 0:1]
sigma_n_unsampled = jax.nn.sigmoid(pred_unsampled[..., 1:2]) + 1e-6

# Back to the original target scale.
y_pred_unsampled = mu_n_unsampled * Ystd + Ymu
y_err_unsampled  = sigma_n_unsampled * Ystd

# Recomputed here so this cell does not depend on the previous plotting cell.
y_true = yte * Ystd + Ymu
yt = np.asarray(y_true).ravel()

print("y_true shape:", y_true.shape, "y_pred shape:", y_pred_unsampled.shape, "yerr shape:", y_err_unsampled.shape)

yp_unsampled = np.asarray(y_pred_unsampled).ravel()
ye_unsampled = np.asarray(y_err_unsampled).ravel()
ye_unsampled = np.maximum(ye_unsampled, 1e-8)  # guard against division by zero in the pull

acc_unsampled = bde.predictive_accuracy(y=yt, mu=yp_unsampled, sigma=ye_unsampled)
print(acc_unsampled)
print("mean sigma: ", jnp.mean(ye_unsampled))
print("mae: ", jnp.mean(jnp.abs(yt - yp_unsampled)))

fig = plt.figure(figsize=(7, 7))
gs  = gridspec.GridSpec(2, 1, height_ratios=[4, 1], hspace=0.05)

# Top panel: predicted vs. true with error bars and the y = x reference line.
ax = fig.add_subplot(gs[0])
ax.errorbar(yt, yp_unsampled, yerr=ye_unsampled, fmt='o', alpha=0.5)
m, M = float(min(yt.min(), yp_unsampled.min())), float(max(yt.max(), yp_unsampled.max()))
ax.plot([m, M], [m, M], 'r--', lw=1)
ax.set_ylabel("Predicted")
ax.set_title("Airfoil: single FNN")
ax.grid(True)
plt.setp(ax.get_xticklabels(), visible=False)

# Bottom panel: pull = (prediction - truth) / predicted uncertainty.
ax2 = fig.add_subplot(gs[1], sharex=ax)
pull_unsampled = (yp_unsampled - yt) / ye_unsampled
ax2.axhline(0, color='k', ls='--')
ax2.scatter(yt, pull_unsampled, s=10, alpha=0.5)
ax2.set_xlabel("True")
ax2.set_ylabel("Pull")
ax2.set_ylim(-3, 3)
ax2.grid(True)
