In [None]:
import os
# Select the Keras backend. Keras reads KERAS_BACKEND when it is first imported,
# so this assignment has to run before any `import keras`.
os.environ["KERAS_BACKEND"] = "torch"

In [None]:
import keras

# When the active backend is torch, switch autograd off globally for this session.
# NOTE(review): the training call later in this notebook still needs gradients,
# so the fitting code presumably re-enables them internally — confirm before
# reusing this pattern with a plain Keras model.
if keras.backend.backend() == "torch":
    import torch
    torch.autograd.set_grad_enabled(False)

In [None]:
import bayesflow as bf
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Training configuration used by the cells below.
batch_size = 256        # handed to the online dataset
steps_per_epoch = 128   # handed to approximator.fit
epochs = 64             # handed to approximator.fit

In [None]:
# Two-moons simulator, wrapped in an online dataset that is built with the
# configured batch size and later passed to approximator.fit.
simulator = bf.simulators.TwoMoonsSimulator()
dataset = bf.datasets.OnlineDataset(simulator, batch_size=batch_size)

Visualize the target distribution with rejection sampling:

In [None]:
def near_origin(batch):
    """Return True wherever the norm of the simulated observation batch["x"] is below 0.01."""
    return keras.ops.norm(batch["x"], axis=-1) < 0.01


# rejection_sample returns at least the requested number of samples (possibly more)
samples = simulator.rejection_sample(
    (1024,),
    near_origin,
    numpy=True,
)

# display the shape of the accepted observations
keras.ops.shape(samples["x"])

In [None]:
# Pull columns 0 and 1 of theta out as the plot coordinates "x" and "y".
# Note: this rebinds `samples`; the "theta" key is gone afterwards, so this cell
# cannot be re-run without first re-running the sampling cell.
theta = samples["theta"]
samples = {"x": theta[:, 0], "y": theta[:, 1]}

In [None]:
# Convert every plotted column to a NumPy array before handing it to seaborn.
samples = {
    axis_name: keras.ops.convert_to_numpy(column)
    for axis_name, column in samples.items()
}

In [None]:
# Scatter the reference draws. seaborn returns the Axes it drew on, so the
# remaining styling is applied to that object instead of the pyplot state machine.
ax = sns.scatterplot(samples, x="x", y="y", size=1, alpha=0.25, lw=0, legend=False)
ax.set_xlim((-0.5, 0.5))
ax.set_ylim((-0.5, 0.5))
ax.set_aspect("equal")  # same scale on both axes
ax.set_xlabel(r"$\theta_1$")
ax.set_ylabel(r"$\theta_2$")
ax.set_title("Target Samples")
plt.show()

In [None]:
# Generative network used by the approximator below: a coupling flow with
# default settings.
inference_network = bf.networks.CouplingFlow()
# Alternative: uncomment the next line to use a flow-matching network instead.
# inference_network = bf.networks.FlowMatching(subnet="resnet")

In [None]:
# Wire the network into an approximator: "theta", "r" and "alpha" are the
# variables to infer, and "x" is the variable the network is conditioned on.
approximator = bf.Approximator(
    inference_network=inference_network,
    inference_variables=["theta", "r", "alpha"],
    inference_conditions=["x"],
)

In [None]:
# Step size passed to the AdamW optimizer created in the next cell.
learning_rate = 1e-4

In [None]:
# AdamW with the configured step size; weight decay is explicitly set to zero.
optimizer = keras.optimizers.AdamW(
    learning_rate=learning_rate,
    weight_decay=0.0,
)

In [None]:
# Attach the optimizer to the approximator before training.
approximator.compile(optimizer=optimizer)

In [None]:
# Train on the online dataset for the configured number of epochs and steps;
# the returned object carries the recorded metrics.
fit_history = approximator.fit(
    dataset,
    epochs=epochs,
    steps_per_epoch=steps_per_epoch,
)

In [None]:
# Pull the recorded loss values out of the training history; the full metrics
# dict is kept as well for further inspection.
loss_history = fit_history.history["loss"]
metrics_history = fit_history.history

In [None]:
# Plot the recorded training loss against the epoch index.
sns.lineplot(loss_history)
plt.xlabel("Epoch")
plt.ylabel("Loss")
# Title and explicit show() match the other plot cells; show() also stops the
# cell from echoing the repr of the last Text object.
plt.title("Training Loss")
plt.show()

In [None]:
# Condition on x = (0, 0): build a batch of all-zero observations and request the
# same number of draws from the trained approximator.
# NOTE(review): presumably the requested count has to match the number of
# conditioning rows — confirm against approximator.sample.
num_samples = 256
data = {"x": keras.ops.zeros((num_samples, 2))}
samples = approximator.sample(num_samples, data=data, numpy=True)

In [None]:
# Pull columns 0 and 1 of the sampled theta out as the plot coordinates.
# Note: this rebinds `samples`; the "theta" key is gone afterwards, so this cell
# cannot be re-run without first re-running the sampling cell.
theta = samples["theta"]
samples = {"x": theta[:, 0], "y": theta[:, 1]}

In [None]:
# Scatter the approximator's draws with the same limits and aspect as the target
# plot so the two figures can be compared directly. Styling goes through the
# Axes object that seaborn returns.
ax = sns.scatterplot(samples, x="x", y="y", size=1, alpha=0.25, lw=0, legend=False)
ax.set_xlim((-0.5, 0.5))
ax.set_ylim((-0.5, 0.5))
ax.set_aspect("equal")  # same scale on both axes
ax.set_xlabel(r"$\theta_1$")
ax.set_ylabel(r"$\theta_2$")
ax.set_title("Learned Samples")
plt.show()