In [None]:
import os
os.environ['KERAS_BACKEND'] = 'torch'  # todo: not working for jax

import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

import keras
from keras import ops

import bayesflow as bf

In [None]:
def theta_prior():
    """Draw one parameter vector from the prior.

    Returns
    -------
    dict
        A single entry ``"theta"`` holding a length-2 NumPy array drawn
        uniformly from [-1, 1) using NumPy's global random state.
    """
    return {"theta": np.random.uniform(low=-1, high=1, size=2)}

def forward_model(theta):
    """Simulate one two-dimensional observation for a parameter vector.

    Two values are drawn from NumPy's global random state, in this order:
    an angle, uniform on [-pi/2, pi/2), and a radius, normal with mean 0.1
    and standard deviation 0.01.

    Parameters
    ----------
    theta : indexable
        Only ``theta[0]`` and ``theta[1]`` are read.

    Returns
    -------
    dict
        A single entry ``"x"`` holding a NumPy array of shape ``(2,)``.
    """
    angle = np.random.uniform(-np.pi / 2, np.pi / 2)
    radius = np.random.normal(0.1, 0.01)
    root_two = np.sqrt(2)

    theta_sum = theta[0] + theta[1]
    theta_diff = -theta[0] + theta[1]

    # The absolute value of the sum makes the first coordinate identical for
    # theta and -theta summed components, i.e. it folds the sign away.
    first = -np.abs(theta_sum) / root_two + radius * np.cos(angle) + 0.25
    second = theta_diff / root_two + radius * np.sin(angle)
    return {"x": np.array([first, second])}

# Combine the prior and the forward model into one simulator object.
# NOTE(review): presumably the "theta" entry returned by theta_prior is passed
# to forward_model by argument name - confirm against the BayesFlow docs.
simulator = bf.make_simulator([theta_prior, forward_model])

In [None]:
# Adapter that reshapes the simulator output into what the approximator expects.
adapter = bf.adapters.Adapter()

# 1) convert any non-arrays to numpy arrays
adapter = adapter.to_array()

# 2) convert from numpy's default float64 to deep learning friendly float32
adapter = adapter.convert_dtype("float64", "float32")

# 3) rename the variables to match the required approximator inputs:
#    the parameters are what we infer, the observations are what we condition on
adapter = adapter.rename("theta", "inference_variables")
adapter = adapter.rename("x", "inference_conditions")

In [None]:
# Training configuration shared by all workflows in this notebook.
# The offline training set holds num_training_batches * batch_size simulations.
num_training_batches = 512
# Number of simulations drawn for the validation / diagnostics data set.
num_validation_sets = 100
# Mini-batch size passed to fit_offline.
batch_size = 64
# Number of passes over the offline training set.
epochs = 25

In [None]:
# Pre-simulate a fixed training set (num_training_batches * batch_size draws) for offline training.
training_data = simulator.sample(num_training_batches * batch_size)
# Separate, smaller set used as validation data during training and as test data for the diagnostics.
validation_data = simulator.sample(num_validation_sets)

In [None]:
# Update the EMA once per 1% of all training steps (epochs * batches per epoch).
# Integer floor division avoids float truncation surprises, and the lower bound
# of 1 keeps the interval usable as a modulus when fewer than 100 training
# steps are configured (an interval of 0 would raise ZeroDivisionError in EMA).
total_training_steps = epochs * num_training_batches
ema_update_every = max(1, int(total_training_steps // 100))
print(f"EMA update every {ema_update_every} steps of {num_training_batches* epochs} total training steps ({ema_update_every / (num_training_batches* epochs)*100}%).")

In [None]:
class EMA(keras.callbacks.Callback):
    """Track an exponential moving average (EMA) of the model's trainable variables.

    Every ``update_every`` training batches each stored "shadow" value is
    replaced by ``beta * shadow + (1 - beta) * current``. The shadow values live
    in this callback; they are only written into the model by
    ``_swap_to_shadow`` and removed again by ``_swap_from_shadow``. When
    ``use_for_validation`` is true this swap is done in ``on_test_begin`` /
    ``on_test_end``.

    Parameters
    ----------
    update_every : int
        Number of training batches between two EMA updates. Must be >= 1.
    beta : float, default 0.9
        Weight of the previous shadow value in each update.
    use_for_validation : bool, default False
        If true, evaluate with the EMA weights and restore the training
        weights afterwards.

    Raises
    ------
    ValueError
        If ``update_every`` is smaller than 1.
    """

    # NOTE(review): `use_for_validation` was reported as not working. The storage
    # aliasing fixed in `_snapshot` is a likely cause on the torch backend -
    # re-verify after this change.
    def __init__(self, update_every, beta=0.9, use_for_validation=False):
        super().__init__()
        self.beta = float(beta)
        self.update_every = int(update_every)
        if self.update_every < 1:
            # The interval is used as a modulus in on_train_batch_end; fail here
            # with a clear message instead of a ZeroDivisionError mid-training.
            raise ValueError(f"update_every must be >= 1, got {update_every!r}.")
        self.use_for_validation = bool(use_for_validation)
        # EMA values, one tensor per trainable variable (None until first use).
        self._shadow = None
        # Training weights saved while the EMA weights are loaded into the model.
        self._backup = None
        # Number of training batches seen since on_train_begin.
        self._step = 0
        # Number of variables the slots were created for.
        self._n_vars = 0
        print(f"EMA model update every {update_every} steps.")

    def _snapshot(self, v):
        """Return a gradient-free copy of the current value of ``v``."""
        t = ops.convert_to_tensor(v)
        try:
            t = ops.stop_gradient(t)
        except Exception as e:
            # Best effort: keep the tensor as-is if the backend refuses.
            print(e)
        # Copy explicitly: on the torch backend the converted/detached tensor can
        # share storage with the live parameter, so a later in-place `assign`
        # would silently overwrite the snapshot (and thereby the backup used to
        # restore the training weights).
        return ops.copy(t)

    def _init_slots_from_tv(self, tv):
        """(Re)create shadow and backup slots from the variables in ``tv``."""
        self._shadow = [self._snapshot(v) for v in tv]
        self._backup = [None] * len(tv)
        self._n_vars = len(tv)

    def _ensure_slots(self):
        """Return the trainable variables, creating the slots if needed.

        The slots are rebuilt when they do not exist yet or when the number of
        trainable variables differs from the number they were created for.
        """
        tv = self.model.trainable_variables
        if self._shadow is None or self._n_vars != len(tv):
            self._init_slots_from_tv(tv)
        return tv

    def on_train_begin(self, logs=None):
        """Reset the step counter and make sure the slots exist."""
        self._step = 0
        self._ensure_slots()

    def on_train_batch_end(self, batch, logs=None):
        """Update the shadow values on every ``update_every``-th batch."""
        self._step += 1
        if self._step % self.update_every != 0:
            return
        tv = self._ensure_slots()
        b = self.beta
        new_shadow = []
        for s, v in zip(self._shadow, tv):
            v_now = self._snapshot(v)
            # Keep the shadow dtype stable even if the variable dtype differs.
            if ops.dtype(s) != ops.dtype(v_now):
                v_now = ops.cast(v_now, ops.dtype(s))
            new_shadow.append(b * s + (1.0 - b) * v_now)
        self._shadow = new_shadow

    def _swap_to_shadow(self):
        """Back up the current weights and load the EMA weights into the model."""
        tv = self._ensure_slots()
        for i, v in enumerate(tv):
            # Must be an independent copy: `assign` below may write in place.
            self._backup[i] = self._snapshot(v)
            w = self._shadow[i]
            if ops.dtype(w) != v.dtype:
                w = ops.cast(w, v.dtype)
            v.assign(w)

    def _swap_from_shadow(self):
        """Restore the weights saved by ``_swap_to_shadow`` and clear the backup.

        Variables without a stored backup are left untouched, so calling this
        without a preceding (or after a partially failed) ``_swap_to_shadow``
        does not raise.
        """
        tv = self._ensure_slots()
        for v, saved in zip(tv, self._backup):
            if saved is not None:
                v.assign(saved)
        self._backup = [None] * len(tv)

    def on_test_begin(self, logs=None):
        """Load the EMA weights for evaluation if ``use_for_validation`` is set."""
        if not self.use_for_validation:
            return
        self._swap_to_shadow()

    def on_test_end(self, logs=None):
        """Restore the training weights if ``use_for_validation`` is set."""
        if not self.use_for_validation:
            return
        self._swap_from_shadow()


In [None]:
# Callback that tracks the EMA of the trainable variables during training.
# use_for_validation keeps its default (False), so validation runs on the raw weights.
ema_cb = EMA(update_every=ema_update_every)

# NOTE: despite "diffusion" in the variable name, the active inference network
# is a CouplingFlow; the DiffusionModel line is commented out.
workflow_diffusion_ema = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    #inference_network=bf.networks.DiffusionModel()
    inference_network=bf.networks.CouplingFlow()
)

# Offline training on the pre-simulated data with the EMA callback attached.
history_ema = workflow_diffusion_ema.fit_offline(
    data=training_data,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=validation_data,
    callbacks=[ema_cb]
)
#workflow_diffusion_ema.approximator.save(filepath='test.keras')

In [None]:
def save_both(model, ema_cb, path_noema="model_noema.keras", path_ema="model_ema.keras"):
    """Save ``model`` twice: with its current weights and with the EMA weights.

    Parameters
    ----------
    model
        Object with a ``save(path)`` method.
    ema_cb
        Callback providing ``_swap_to_shadow()`` and ``_swap_from_shadow()``.
    path_noema : str
        Target file for the current (non-EMA) weights.
    path_ema : str
        Target file for the EMA weights.
    """

    def _save_with_ema_weights():
        # Load the EMA weights, save, and always swap back - even if saving fails.
        ema_cb._swap_to_shadow()
        try:
            model.save(path_ema)
        finally:
            ema_cb._swap_from_shadow()

    # The non-EMA checkpoint is written first, before any weights are touched.
    model.save(path_noema)
    _save_with_ema_weights()

# Write both weight sets to disk, then load each file back as its own model object.
save_both(workflow_diffusion_ema.approximator, ema_cb)
model_noema = keras.saving.load_model("model_noema.keras")
model_ema   = keras.saving.load_model("model_ema.keras")

In [None]:
# Diagnostics for the reloaded model that holds the regular (non-EMA) weights.
# This rebinds the workflow's approximator; later cells see the reloaded model.
workflow_diffusion_ema.approximator = model_noema
workflow_diffusion_ema.plot_default_diagnostics(test_data=validation_data, num_samples=100, calibration_ecdf_kwargs={"difference": True});

In [None]:
# Same diagnostics for the reloaded model that holds the EMA weights.
# This rebinds the workflow's approximator again.
workflow_diffusion_ema.approximator = model_ema
workflow_diffusion_ema.plot_default_diagnostics(test_data=validation_data, num_samples=100, calibration_ecdf_kwargs={"difference": True});

In [None]:
# Baseline for comparison: same simulator, adapter, network type and training
# settings as the EMA workflow, but trained without the EMA callback.
# NOTE: despite "diffusion" in the variable name, the active inference network
# is a CouplingFlow; the DiffusionModel line is commented out.
workflow_diffusion = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    #inference_network=bf.networks.DiffusionModel()
    inference_network=bf.networks.CouplingFlow()
)

history = workflow_diffusion.fit_offline(
    data=training_data,
    epochs=epochs,
    batch_size=batch_size,
    validation_data=validation_data,
)

In [None]:
#workflow_diffusion.approximator.save("model.keras")
#workflow_diffusion.approximator = keras.saving.load_model("model.keras")

In [None]:
# Diagnostics for the baseline model trained without the EMA callback.
workflow_diffusion.plot_default_diagnostics(test_data=validation_data, num_samples=100, calibration_ecdf_kwargs={"difference": True});

In [None]:
# Workflows, panel titles and colors for the side-by-side posterior plot
# (samples are drawn from each workflow's approximator; the workflows' own
# methods could be used as well).
# `workflow_diffusion_ema` is listed twice on purpose: the plotting loop assigns
# `model_noema` and `model_ema` to its `approximator` for the second and third panel.
workflows = [
    workflow_diffusion,
    workflow_diffusion_ema,
    workflow_diffusion_ema,
]
# Panel titles, in the same order as `workflows`.
names = [
    "Standard Model",
    "Model with EMA (disabled)",
    "Model with EMA"
]
# Scatter colors, in the same order as `workflows`.
colors = ["#153c7a", "#7a1515", "#8c1740"]

In [None]:
# Number of posterior draws per model
num_samples = 1000

# Single observation to condition on: x = (0, 0)
conditions = {"x": np.array([[0.0, 0.0]]).astype("float32")}

# One panel per entry in `workflows`
f, axes = plt.subplots(1, len(workflows), figsize=(12, 6), layout="constrained")

# Panels 1 and 2 share one workflow object, so its approximator is replaced by
# the matching reloaded model right before that panel is evaluated.
reloaded_by_panel = {1: model_noema, 2: model_ema}

for i, (ax, w, name, color) in enumerate(zip(axes, workflows, names, colors)):
    # NOTE(review): this seeds NumPy only; whether it also fixes the draws of
    # the approximator depends on the backend's random generator - confirm.
    np.random.seed(0)
    if i in reloaded_by_panel:
        w.approximator = reloaded_by_panel[i]
    print(w.compute_default_diagnostics(test_data=validation_data, num_samples=100))

    # Posterior draws for the single conditioning observation
    samples = w.approximator.sample(conditions=conditions, num_samples=num_samples)["theta"]
    theta_1 = samples[0, :, 0]
    theta_2 = samples[0, :, 1]

    # Scatter plot of the draws
    ax.scatter(theta_1, theta_2, color=color, alpha=0.75, s=0.5)
    ax.set_title(f"{name}", fontsize=16)
    ax.set_xlabel(r"$\theta_1$", fontsize=15)
    ax.set_ylabel(r"$\theta_2$", fontsize=15)
    ax.grid(alpha=0.3)
    sns.despine(ax=ax)

    # Fix the aspect ratio first, then the shared axis limits
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim([-0.5, 0.5])
    ax.set_ylim([-0.5, 0.5])

plt.show()

In [None]:
# Round-trip check: posterior draws from the in-memory baseline model next to
# draws from the same model after saving it to disk and loading it back.
colors = ["#153c7a", "#7a1515", "#8c1740"]
conditions = {"x": np.array([[0.0, 0.0]]).astype("float32")}
num_samples = 1000

# Save and load exactly once, up front. The loaded copy is kept in its own
# variable and `workflow_diffusion.approximator` is left untouched, so the cell
# can be re-run and "Model" always refers to the workflow's own approximator.
in_memory_model = workflow_diffusion.approximator
in_memory_model.save("model.keras")
loaded_model = keras.saving.load_model("model.keras")

# Prepare figure
f, axes = plt.subplots(1, 2, figsize=(12, 6), layout="constrained")

panels = zip(axes, ['Model', 'Loaded Model'], [in_memory_model, loaded_model], colors)
for ax, name, approximator, color in panels:
    # Obtain samples
    samples = approximator.sample(conditions=conditions, num_samples=num_samples)["theta"]

    # Plot samples
    ax.scatter(samples[0, :, 0], samples[0, :, 1], color=color, alpha=0.75, s=0.5)
    sns.despine(ax=ax)
    ax.set_title(f"{name}", fontsize=16)
    ax.grid(alpha=0.3)
    ax.set_aspect("equal", adjustable="box")
    ax.set_xlim([-0.5, 0.5])
    ax.set_ylim([-0.5, 0.5])
    ax.set_xlabel(r"$\theta_1$", fontsize=15)
    ax.set_ylabel(r"$\theta_2$", fontsize=15)

plt.show()