# Visualising Score & Schedules of the Diffusion Model


In [None]:
import os

# Respect a backend chosen by the user; otherwise fall back to JAX.
# This has to run before `keras` is imported further down.
if "KERAS_BACKEND" in os.environ:
    print(f"Using '{os.environ['KERAS_BACKEND']}' backend")
else:
    os.environ["KERAS_BACKEND"] = "jax"

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import keras
import bayesflow as bf

## Simulator<a class="anchor" id="simulator"></a>

In [None]:
def theta_prior():
    """Draw a 2-D parameter vector uniformly from [-1, 1) x [-1, 1).

    Returns
    -------
    dict
        ``{"theta": ndarray of shape (2,)}``.
    """
    return {"theta": np.random.uniform(low=-1, high=1, size=2)}

def forward_model(theta):
    """Simulate one 2-D observation for the parameter vector ``theta``.

    A random angle ``alpha`` (uniform on [-pi/2, pi/2)) and radius ``r``
    (normal with mean 0.1, std 0.01) add a small arc-shaped perturbation.
    The first coordinate depends on ``|theta[0] + theta[1]|``.

    Returns
    -------
    dict
        ``{"x": ndarray of shape (2,)}``.
    """
    # Draw order (angle first, then radius) is kept so seeded runs reproduce.
    half_pi = np.pi / 2
    alpha = np.random.uniform(-half_pi, half_pi)
    r = np.random.normal(0.1, 0.01)

    root_two = np.sqrt(2)
    theta_sum = theta[0] + theta[1]
    theta_diff = -theta[0] + theta[1]

    first = -np.abs(theta_sum) / root_two + r * np.cos(alpha) + 0.25
    second = theta_diff / root_two + r * np.sin(alpha)
    return {"x": np.array([first, second])}

In [None]:
# Combine the prior and the forward model into a single simulator object.
# NOTE(review): presumably `forward_model` receives the `theta` produced by
# `theta_prior` by keyword name — confirm against the bayesflow docs.
simulator = bf.make_simulator([theta_prior, forward_model])

In [None]:
# Adapter that reshapes simulator output into what the approximator consumes:
# arrays cast from float64 to float32, with `theta` and `x` renamed to the
# `inference_variables` / `inference_conditions` keys.
adapter = (
    bf.adapters.Adapter()
    .to_array()
    .convert_dtype("float64", "float32")
    .rename("theta", "inference_variables")
    .rename("x", "inference_conditions")
)
adapter  # bare expression: shows the adapter's repr as the notebook cell output

# Training

In [None]:
# Training budget. The offline training set holds
# num_training_batches * batch_size simulations (see the sampling cell below).
num_training_batches = 512
num_validation_sets = 300  # number of simulations drawn for validation
batch_size = 64
epochs = 50

In [None]:
# Pre-simulate fixed training and validation sets for offline training.
training_data = simulator.sample(num_training_batches * batch_size)
validation_data = simulator.sample(num_validation_sets)

In [None]:
# Pick the noise schedule by index: 0 -> 'cosine', 1 -> 'edm' (currently 'edm').
# The chosen name is reused later in plot titles.
noise_schedule = ['cosine', 'edm'][1]
diffusion_model = bf.networks.DiffusionModel(
    noise_schedule=noise_schedule
)

In [None]:
# Bundle simulator, adapter and inference network into one workflow object;
# all training, sampling and plotting cells below go through it.
diffusion_model_workflow = bf.BasicWorkflow(
    simulator=simulator,
    adapter=adapter,
    inference_network=diffusion_model,
)

In [None]:
# Train on the pre-simulated data; `history` keeps the return value of
# `fit_offline` (presumably the training history — confirm against bayesflow docs).
history = diffusion_model_workflow.fit_offline(
    training_data, 
    epochs=epochs,
    batch_size=batch_size, 
    validation_data=validation_data,
)

In [None]:
%%time
samples_s = diffusion_model_workflow.sample(num_samples=1000, method="euler_maruyama",
                                           conditions={"x":np.array([[0.0, 0.0]], dtype=np.float32)})
plt.scatter(samples_s["theta"][0, :, 0], samples_s["theta"][0, :, 1], alpha=0.75, s=0.5, label="euler maruyama")
plt.gca().set_aspect("equal", adjustable="box")
plt.xlim([-0.5, 0.5])
plt.ylim([-0.5, 0.5])

samples = diffusion_model_workflow.sample(num_samples=1000,method="euler",
                                          conditions={"x":np.array([[0.0, 0.0]], dtype=np.float32)})
plt.scatter(samples["theta"][0, :, 0], samples["theta"][0, :, 1], alpha=0.75, s=0.5, label="euler")
plt.xlim([-0.5, 0.5])
plt.ylim([-0.5, 0.5])
plt.legend()
plt.show()

# Visualizing the Trajectory

In [None]:
def euler_backward_like(workflow, conditions, t_start=1.0, t_end=0.0, steps=100, num_samples=1, stochastic_solver=False):
    """Integrate the diffusion sampler with fixed-step Euler and record the path.

    Starting from a draw of the network's base distribution at ``t_start``, the
    state is advanced to ``t_end`` in ``steps`` equal steps using the network's
    ``velocity`` (plus ``diffusion_term`` times Gaussian noise when
    ``stochastic_solver`` is true, i.e. an Euler-Maruyama-style update).

    Parameters
    ----------
    workflow
        Object exposing ``approximator`` and ``inference_network``; every model
        call goes through this argument.
    conditions : dict
        Raw conditions, passed through ``workflow.approximator._prepare_data``.
        Must have shape (batch_size, ..., dims).
    t_start, t_end : float
        Integration interval; the step ``dt`` is negative when integrating toward 0.
    steps : int
        Number of Euler steps.
    num_samples : int
        Number of parallel trajectories per condition.
    stochastic_solver : bool
        Forwarded to ``velocity``; also enables the noise injection.

    Returns
    -------
    traj
        States at every grid time including the final one,
        shape [steps+1, batch, num_samples, dims], passed through
        ``standardize_layers["inference_variables"](..., forward=False)``
        (presumably the inverse standardization — confirm) and converted to NumPy.
    vels : ndarray
        Velocities evaluated at the first ``steps`` grid times,
        shape [steps, batch, num_samples, dims]. These are the raw network
        outputs; no inverse standardization is applied to them.
    times : ndarray
        float32 time grid of length ``steps + 1``.
    """
    approximator = workflow.approximator
    network = workflow.inference_network

    # conditions must always have shape (batch_size, ..., dims)
    conditions_prep = approximator._prepare_data(conditions)['inference_conditions']
    batch_size = keras.ops.shape(conditions_prep)[0]
    inference_conditions = keras.ops.expand_dims(conditions_prep, axis=1)
    inference_conditions = keras.ops.broadcast_to(
        inference_conditions, (batch_size, num_samples, *keras.ops.shape(inference_conditions)[2:])
    )

    # Read t from the grid instead of accumulating t += dt, so the recorded
    # times and the times fed to the network cannot drift apart.
    time_grid = np.linspace(t_start, t_end, steps + 1)
    dt = (t_end - t_start) / steps  # negative if integrating toward 0
    noise_scale = float(np.sqrt(np.abs(dt)))

    # One initial state per (condition, sample) pair.
    x = network.base_distribution.sample((batch_size, num_samples))

    traj = []
    vels = []
    for k in range(steps):
        t = float(time_grid[k])
        # convert_to_numpy works on every Keras backend (JAX arrays have no .numpy()).
        traj.append(keras.ops.convert_to_numpy(x))
        v_curr = network.velocity(
            xz=x, time=t, conditions=inference_conditions, stochastic_solver=stochastic_solver, training=False
        )
        if stochastic_solver:
            diff_curr = network.diffusion_term(
                xz=x, time=t, training=False
            )
            noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x)) * noise_scale
            x = x + diff_curr * noise

        x = x + dt * v_curr
        vels.append(keras.ops.convert_to_numpy(v_curr))

    # Record the state reached at t_end so traj lines up with the time grid.
    traj.append(keras.ops.convert_to_numpy(x))

    traj = np.stack(traj, axis=0)      # shape [steps+1, batch, num_samples, dims]
    vels = np.stack(vels, axis=0)      # shape [steps,   batch, num_samples, dims]
    times = time_grid.astype(np.float32)

    traj = approximator.standardize_layers["inference_variables"](traj, forward=False)
    traj = keras.ops.convert_to_numpy(traj)
    return traj, vels, times

# Integrate one deterministic trajectory for the observation x = (0, 0).
traj, vels, times = euler_backward_like(
    diffusion_model_workflow,
    conditions={"x": np.array([[0.0, 0.0]], dtype=np.float32)},
    steps=100,
    stochastic_solver=False,
)
# keep only the first batch item and the first sample
traj = traj[:, 0, 0]
vels = vels[:, 0, 0]

# speed along the path, one value per recorded velocity
vel_norm = np.linalg.norm(vels, axis=-1)

# trajectory plot
fig_traj, ax_traj = plt.subplots()
ax_traj.plot(traj[:, 0], traj[:, 1], linewidth=2)
ax_traj.scatter(traj[0, 0], traj[0, 1], s=60, marker='o', label='start')
ax_traj.scatter(traj[-1, 0], traj[-1, 1], s=60, marker='x', label='end')
ax_traj.set_title("Trajectory")
ax_traj.set_xlabel("x")
ax_traj.set_ylabel("y")
ax_traj.axis('equal')
ax_traj.legend()
plt.show()

# velocity over time (velocities exist for every grid time except the last)
fig_vel, ax_vel = plt.subplots()
ax_vel.plot(times[:-1], vel_norm, linewidth=2)
ax_vel.set_title("Velocity norm over time")
ax_vel.set_xlabel("time")
ax_vel.set_ylabel("||v||")
plt.show()

In [None]:
# Overlay 20 deterministic trajectories on the `euler` posterior samples.
fig_det, ax_det = plt.subplots()
for i in range(20):
    traj, vels, times = euler_backward_like(
        diffusion_model_workflow,
        conditions={"x": np.array([[0.0, 0.0]], dtype=np.float32)},
        steps=50,
        stochastic_solver=False,
    )
    # first batch item, first sample
    traj = traj[:, 0, 0]
    vels = vels[:, 0, 0]

    # only the first start marker gets a legend entry
    start_label = 'start' if i == 0 else None
    ax_det.plot(traj[:, 0], traj[:, 1], linewidth=2, color='blue')
    ax_det.scatter(traj[0, 0], traj[0, 1], s=60, marker='x', label=start_label, color='blue', alpha=0.5)

theta_euler = samples["theta"]
ax_det.scatter(theta_euler[0, :, 0], theta_euler[0, :, 1], alpha=0.5, s=0.5, color='red', label='samples')
ax_det.set_title(f"Trajectory {noise_schedule.title()}")
ax_det.set_xlabel("x")
ax_det.set_ylabel("y")
ax_det.axis('equal')
ax_det.legend()
plt.show()

In [None]:
def moving_average(x, w=10):
    """Return the running mean of 1-D ``x`` over windows of length ``w``.

    Only fully overlapping windows are kept (``mode="valid"``), so the result
    has ``len(x) - w + 1`` entries when ``len(x) >= w``.
    """
    box_kernel = np.ones(w) / w
    return np.convolve(x, box_kernel, mode="valid")



# Overlay 10 stochastic trajectories (smoothed with a moving average, since the
# raw paths are noisy) on the `euler_maruyama` posterior samples.
fig_sde, ax_sde = plt.subplots()
for i in range(10):
    traj, vels, times = euler_backward_like(
        diffusion_model_workflow,
        conditions={"x": np.array([[0.0, 0.0]], dtype=np.float32)},
        steps=50,
        stochastic_solver=True,
    )
    # first batch item, first sample
    traj = traj[:, 0, 0]
    vels = vels[:, 0, 0]

    x_smooth = moving_average(traj[:, 0])
    y_smooth = moving_average(traj[:, 1])

    # only the first start marker gets a legend entry
    start_label = 'start' if i == 0 else None
    ax_sde.plot(x_smooth, y_smooth, linewidth=2, color='blue')
    ax_sde.scatter(x_smooth[0], y_smooth[0], s=60, marker='x', label=start_label, color='blue', alpha=0.2)

theta_sde = samples_s["theta"]
ax_sde.scatter(theta_sde[0, :, 0], theta_sde[0, :, 1], alpha=0.2, s=0.5, color='red', label='samples')
ax_sde.set_title(f"Trajectory {noise_schedule.title()}, stochastic sampler (moving average)")
ax_sde.set_xlabel("x")
ax_sde.set_ylabel("y")
ax_sde.axis('equal')
ax_sde.legend()
plt.show()

In [None]:
def velocity_field_plot(workflow, conditions, times, traj, stochastic_solver, grid_limits=(-3,3), grid_points=20, name=None):
    """Plot the network's velocity field at several times, with a trajectory on top.

    One subplot is drawn per entry of ``times``. The velocity is evaluated on a
    regular ``grid_points x grid_points`` grid spanning ``grid_limits`` on both
    axes. Arrow positions are the grid passed through
    ``standardize_layers["inference_variables"](..., forward=False)``
    (presumably the inverse standardization — confirm), so they share the
    coordinate system of ``traj``.

    Parameters
    ----------
    workflow
        Object exposing ``approximator`` and ``inference_network``; every model
        call goes through this argument.
    conditions : dict
        Raw conditions, passed through ``workflow.approximator._prepare_data``.
        Must have shape (batch_size, ..., dims); only the first batch item's
        field is plotted.
    times : sequence of float
        Times at which the field is evaluated.
    traj : array-like, indexable as ``traj[:, 0]`` and ``traj[:, 1]``
        Trajectory to overlay (first column x, second column y).
    stochastic_solver : bool
        Forwarded to ``inference_network.velocity``.
    grid_limits : tuple
        (low, high) used for both grid axes.
    grid_points : int
        Grid resolution per axis.
    name : str, optional
        Prefix for each subplot title.
    """
    approximator = workflow.approximator
    network = workflow.inference_network

    # grid (both axes use the same ticks)
    ticks = np.linspace(grid_limits[0], grid_limits[1], grid_points)
    X, Y = np.meshgrid(ticks, ticks)
    flat_grid = np.stack([X, Y], axis=-1).reshape(-1, 2)  # [grid_points*grid_points, 2]
    grid = flat_grid[None]  # [1, grid_points*grid_points, 2]

    grid_transf = approximator.standardize_layers["inference_variables"](flat_grid, forward=False)
    # convert_to_numpy works on every Keras backend
    grid_transf = keras.ops.convert_to_numpy(grid_transf)
    XY_transf = grid_transf.reshape(grid_points, grid_points, 2) # [G, G, 2]
    X_transf = XY_transf[..., 0]
    Y_transf = XY_transf[..., 1]

    # conditions must always have shape (batch_size, ..., dims)
    conditions_prep = approximator._prepare_data(conditions)['inference_conditions']
    batch_size = keras.ops.shape(conditions_prep)[0]
    inference_conditions = keras.ops.expand_dims(conditions_prep, axis=1)
    inference_conditions = keras.ops.broadcast_to(
        inference_conditions, (batch_size, grid_points*grid_points, *keras.ops.shape(inference_conditions)[2:])
    )

    fig, axes = plt.subplots(1, len(times), figsize=(5*len(times), 5), layout='constrained', sharex=True, sharey=True, squeeze=False)

    for i, t in enumerate(times):
        v = network.velocity(
            xz=grid, time=float(t), conditions=inference_conditions, stochastic_solver=stochastic_solver, training=False
        )
        v = keras.ops.convert_to_numpy(v)  # JAX arrays have no .numpy()
        v = v[0, :, :2]  # [num_points, 2]

        U = v[:,0].reshape(grid_points, grid_points)
        V = v[:,1].reshape(grid_points, grid_points)

        ax = axes[0, i]
        # NOTE(review): only the y-component is negated here. With dt < 0 the
        # sampler moves along -v, so (-U, -V) or (U, V) would be the consistent
        # choices — confirm the intended arrow direction. Arrow components are
        # also not rescaled by the inverse standardization applied to positions.
        ax.quiver(X_transf, Y_transf, U, -V, angles="xy", label="velocity field", alpha=0.5)
        ax.plot(traj[:,0], traj[:,1], color="red", linewidth=2, label="trajectory")
        ax.scatter(traj[0,0], traj[0,1], color="green", s=60, label="start")
        ax.scatter(traj[-1,0], traj[-1,1], color="black", s=60, label="end")

        if name is not None:
            ax.set_title(f"{name} time={t:.2f}")
        else:
            ax.set_title(f"time={t:.2f}")
        ax.set_xlim((X_transf.min(), X_transf.max()))
        ax.set_ylim((Y_transf.min(), Y_transf.max()))
        ax.legend()
    plt.show()

In [None]:
# Velocity-field snapshots at t = 1, 0.5, 0: deterministic sampler first,
# then the stochastic one, each with a freshly integrated trajectory overlaid.
for use_stochastic, sampler_kind in ((False, "deterministic"), (True, "stochastic")):
    traj, vels, times = euler_backward_like(
        diffusion_model_workflow,
        conditions={"x": np.array([[0.0, 0.0]], dtype=np.float32)},
        steps=100,
        stochastic_solver=use_stochastic,
    )
    # first batch item, first sample
    traj = traj[:, 0, 0]
    vels = vels[:, 0, 0]
    velocity_field_plot(
        diffusion_model_workflow,
        conditions={"x": np.array([[0.0, 0.0]], dtype=np.float32)},
        times=[1.0, 0.5, 0.0],
        traj=traj,
        stochastic_solver=use_stochastic,
        name=f"{sampler_kind} sampling, {noise_schedule},",
    )

In [None]:
from bayesflow.networks.diffusion_model.schedules import EDMNoiseSchedule, CosineNoiseSchedule, NoiseSchedule

In [None]:
class FlowMatching(NoiseSchedule):
    """Noise schedule with log-SNR ``lambda(t) = 2 * log((1 - t) / t)``.

    Only the pieces needed for the plots in this notebook are implemented;
    the inverse map and the derivative raise ``NotImplementedError``.
    """

    def __init__(self):
        super().__init__(name="Flow Matching Schedule", variance_type="preserving", weighting=None)

    def get_log_snr(self, t, training):
        """Return the log signal-to-noise ratio (lambda) at diffusion time ``t``.

        ``training`` is accepted for interface compatibility and not used.
        The value is infinite at ``t = 0`` and ``t = 1``.
        """
        signal_to_noise_root = (1 - t) / t
        return 2 * keras.ops.log(signal_to_noise_root)

    def get_t_from_log_snr(self, log_snr_t, training: bool):
        """Not implemented for this schedule."""
        raise NotImplementedError

    def derivative_log_snr(self, log_snr_t, training):
        """Not implemented for this schedule."""
        raise NotImplementedError

    def get_weights_for_snr(self, log_snr_t):
        """Return ``1 + exp(-lambda) + 2 * exp(-lambda / 2)``.

        Algebraically this equals ``(1 + exp(-lambda / 2)) ** 2``.
        """
        inverse_snr = keras.ops.exp(-log_snr_t)
        inverse_snr_root = keras.ops.exp(-log_snr_t / 2)
        return 1 + inverse_snr + 2 * inverse_snr_root

In [None]:
# Instantiate the three schedules compared below and give the built-in ones
# human-readable names for the plot legends.
edm = EDMNoiseSchedule()
cosine = CosineNoiseSchedule()
edm.name = "EDM Schedule"
cosine.name = "Cosine Schedule"
fm = FlowMatching()

# Dense time grid on [0, 1], endpoints included.
# NOTE: this name shadows the stdlib `time` module if it is imported in this notebook.
time = keras.ops.linspace(0.0, 1.0, 10000)
colors = ["blue", "orange", "green"]  # one color per entry of `schedules`, same order
schedules = [edm, cosine, fm]

In [None]:
# Plot log-SNR over time for each schedule. Training and inference curves are
# drawn separately only when they actually differ.
plt.figure(figsize=(6,3), layout='constrained')
# convert_to_numpy works on every Keras backend (JAX arrays have no .numpy()).
time_np = keras.ops.convert_to_numpy(time)
for i, schedule in enumerate(schedules):
    training_schedule = keras.ops.convert_to_numpy(schedule.get_log_snr(time, training=True))
    inference_schedule = keras.ops.convert_to_numpy(schedule.get_log_snr(time, training=False))

    # "Differs anywhere" is the right test: requiring a difference at every
    # grid point would label partially overlapping schedules "Training & Inference".
    # equal_nan=True keeps undefined points from counting as a difference.
    if not np.array_equal(training_schedule, inference_schedule, equal_nan=True):
        plt.plot(time_np, training_schedule, label=f"{schedule.name} Training", color=colors[i])
        plt.plot(time_np, inference_schedule, label=f"{schedule.name} Inference", linestyle="--", color=colors[i])
    else:
        plt.plot(time_np, training_schedule, label=f"{schedule.name} Training & Inference", color=colors[i])
plt.legend()
plt.ylabel(r"log SNR $\lambda$")
plt.xlabel("time")
plt.show()

In [None]:
# Left: distribution of log-SNR values when t is uniform on (0, 1).
# Right: the same values resampled in proportion to each schedule's weights.
fig, ax = plt.subplots(ncols=2, figsize=(6,3), layout='constrained', sharey=True, sharex=True)
for i, schedule in enumerate(schedules):
    # Endpoints are dropped (time[1:-1]); FlowMatching's log-SNR is infinite there.
    log_snr_tensor = schedule.get_log_snr(time[1:-1], training=True)
    # Convert once to NumPy so matplotlib and np.random.choice get plain arrays
    # on every Keras backend (JAX arrays have no .numpy()).
    training_schedule = keras.ops.convert_to_numpy(log_snr_tensor)
    training_weights = keras.ops.convert_to_numpy(schedule.get_weights_for_snr(log_snr_tensor))
    ax[0].hist(training_schedule, density=True, color=colors[i], label=f"{schedule.name}", alpha=0.7, bins=20)

    log_snr = np.random.choice(training_schedule, p=training_weights / training_weights.sum(), replace=True, size=10000)
    ax[1].hist(log_snr, density=True, color=colors[i], label=f"{schedule.name}", alpha=0.7, bins=20)

for a in ax:
    a.legend(loc="upper right")
    a.set_xlabel(r"log SNR $\lambda$")
ax[0].set_ylabel(r"Density")
ax[0].set_title("Raw Schedules")
ax[1].set_title("With Weighting Functions")
plt.show()

In [None]:
from scipy.stats import norm

def sech(x):
    """Return the hyperbolic secant ``1 / cosh(x)``, elementwise for arrays."""
    return 1 / np.cosh(x)

In [None]:
# Closed-form weighting curves over log-SNR, each scaled to a maximum of 1
# so their shapes can be compared in one plot.
lambda_t = np.linspace(-15, 15, 100)

fm_w = np.exp(-lambda_t / 2)
# Gaussian density in lambda (mean 2.4, std 2.4) times (exp(-lambda) + 1**2).
edm_w = (np.exp(-lambda_t) + 1**2) * norm.pdf(lambda_t, loc=2.4, scale=2.4)
cosine_w = sech(lambda_t / 2)

for curve_label, weights in (("flow matching", fm_w), ("edm", edm_w), ("cosine", cosine_w)):
    plt.plot(lambda_t, weights / weights.max(), label=curve_label)

plt.xlabel(r"log SNR $\lambda$")
plt.ylabel("Implied Weighting Function")
plt.legend()
plt.show()