In [1]:
import sys
import os

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

from bde.bde import BdeRegressor, BdeClassifier
from bde.task import TaskType
from bde.loss.loss import *
from sklearn.datasets import fetch_openml, load_iris
from sklearn.model_selection import train_test_split
import jax.numpy as jnp

from bde.viz.plotting import plot_pred_vs_true, plot_confusion_matrix, plot_reliability_curve, plot_roc_curve

data = fetch_openml(name="airfoil_self_noise", as_frame=True)

X = data.data.values  # shape (1503, 5)
y = data.target.values.reshape(-1, 1)  # shape (1503, 1)

X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42)

    # Convert to JAX arrays
X_train = jnp.array(X_train, dtype=jnp.float32)
y_train = jnp.array(y_train, dtype=jnp.float32)
X_test = jnp.array(X_test, dtype=jnp.float32)
y_test = jnp.array(y_test, dtype=jnp.float32)

Xmu, Xstd = jnp.mean(X_train, 0), jnp.std(X_train, 0) + 1e-8
Ymu, Ystd = jnp.mean(y_train, 0), jnp.std(y_train, 0) + 1e-8

Xtr = (X_train - Xmu) / Xstd
Xte = (X_test - Xmu) / Xstd
ytr = (y_train - Ymu) / Ystd
yte = (y_test - Ymu) / Ystd

sizes = [5, 16, 16, 2]

- version 1, status: active
  url: https://www.openml.org/search?type=data&id=43919
- version 8, status: active
  url: https://www.openml.org/search?type=data&id=44957



In [2]:
regressor = BdeRegressor(
        hidden_layers=[16, 16],
        n_members=11,
        seed=0,
        loss=GaussianNLL(),
        epochs=100,
        lr=1e-3,
        warmup_steps=500,
        n_samples=100,
        n_thinning=10,
    )

regressor.fit(x=Xtr, y=ytr)

means, sigmas = regressor.predict(Xte, mean_and_std=True)


backend: cpu
devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
local_device_count: 8
Kernel devices: 8
0 2.034358024597168


MCLMC warmup:   0%|          | 0/500 [00:00<?, ?it/s]


step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0001220703125
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-0.0001220703125
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-0.0001220703125
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0001220703125
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-0.000244140625
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0001220703125
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 1 | ok=True | step_si

In [4]:
print("RSME: ", jnp.sqrt(jnp.mean((means - yte) ** 2)))
plot_pred_vs_true(
    y_pred=means,
    y_true=yte,
    y_pred_err=sigmas,
    title="trial",
    savepath="plots_regression"
)

mean, intervals = regressor.predict(Xte, credible_intervals=[0.9, 0.95])

print("Credible intervals shape:", intervals.shape)  # (len(q), N)

    # for plotting, pick the 95% interval
lower = intervals[0]  # q=0.9 or 0.95 depending on order
upper = intervals[1]  # if you asked for 2 quantiles

plot_pred_vs_true(
    y_pred=means,
    y_true=yte,
    y_pred_err=(upper - lower) / 2,  # approx half-width as "sigma"
    title="trial_with_intervals",
    savepath="plots_regression",
)

score = regressor.score(Xtr, ytr)
print(f"the sklearn score is {score}")

RSME:  1.4352245
Credible intervals shape: (2, 301)
the sklearn score is 0.9414751529693604


In [6]:
iris = load_iris()
X = iris.data.astype("float32")
y = iris.target.astype("int32")  # 0, 1, 2

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Convert to JAX
Xtr, Xte = jnp.array(X_train), jnp.array(X_test)
ytr, yte = jnp.array(y_train), jnp.array(y_test)

classifier = BdeClassifier(
    n_members=5,
    hidden_layers=[16, 16],
    seed=0,
    loss=CategoricalCrossEntropy(),
    activation="relu",
    epochs=50,
    lr=1e-3,
    warmup_steps=200,
    n_samples=50,
    n_thinning=5
)

classifier.fit(x=Xtr, y=ytr)

preds = classifier.predict(Xte)
probs = classifier.predict_proba(Xte)
print("Predicted class probabilities:\n", probs)
print("Predicted class labels:\n", preds)
print("True labels:\n", yte)

savepath = "plots_classification"
classes = list(range(3))

backend: cpu
devices: [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
local_device_count: 8
Kernel devices: 8
0 1.643818736076355


MCLMC warmup:   0%|          | 0/200 [00:01<?, ?it/s]


step 0 | ok=True | step_size=0.0010000000474974513 | dE=-3.0517578125e-05
step 0 | ok=True | step_size=0.0010000000474974513 | dE=0.0
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-3.0517578125e-05
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-3.0517578125e-05
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-3.0517578125e-05
step 0 | ok=True | step_size=0.0010000000474974513 | dE=6.103515625e-05
step 1 | ok=True | step_size=0.021542686969041824 | dE=-0.000152587890625
step 0 | ok=True | step_size=0.0010000000474974513 | dE=3.0517578125e-05
step 1 | ok=True | step_size=0.021542686969041824 | dE=3.0517578125e-05
step 1 | ok=True | step_size=0.021537713706493378 | dE=-6.103515625e-05
step 1 | ok=True | step_size=0.021542686969041824 | dE=-0.001922607421875
step 0 | ok=True | step_size=0.0010000000474974513 | dE=-3.0517578125e-05
step 1 | ok=True | step_size=0.02154434472322464 | dE=-9.1552734375e-05
step 2 | ok=True | step_size=0.02421579882502556 | dE=0.00042