# 🚀 Simulation-Based Inference Tutorial

In [None]:
# @title 1. Installation and Imports
# Install necessary packages quietly
!pip install sbi corner -q

import os
import itertools
from collections.abc import Callable

import torch
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sbi.inference import NPE, simulate_for_sbi
from sbi.utils import BoxUniform
from sbi.utils.user_input_checks import process_simulator
from sbi.analysis import pairplot, conditional_pairplot
from scipy import stats
import warnings

# Suppress minor warnings
# NOTE: this installs a process-wide filter, so every UserWarning (including any
# emitted by the libraries imported above) stays hidden for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)

# Confirmation message so the notebook output shows this cell ran to completion.
print("✅ Libraries installed and imported.")

In [None]:
# Use all available CPU cores for parallelization to speed up simulations.
# `os.cpu_count()` returns None when the core count cannot be determined, so fall
# back to a single worker instead of passing None on to `simulate_for_sbi`.
num_workers = os.cpu_count() or 1
print(f"⚙️ Using {num_workers} available CPU cores for parallel processing.")

In [None]:
# @title 2. Helper and Plotting Functions 🛠️
def analyze_posterior_statistics(
        posterior_samples: torch.Tensor,
        param_names: list[str],
        true_params: torch.Tensor | list[float] | None = None,
):
    """
    Analyze and print posterior statistics, including correlations and comparison to true values.

    Args:
        posterior_samples: Samples from the posterior distribution, shape (n_samples, n_params).
        param_names: A list of names for each parameter.
        true_params: The true parameter values for comparison (optional).
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    if isinstance(true_params, list):
        true_params = torch.tensor(true_params)

    # --- Calculate Statistics ---
    posterior_mean = posterior_samples.mean(dim=0)
    posterior_std = posterior_samples.std(dim=0)
    posterior_median = posterior_samples.median(dim=0).values
    lower_ci = torch.quantile(posterior_samples, 0.025, dim=0)
    upper_ci = torch.quantile(posterior_samples, 0.975, dim=0)
    posterior_corr = np.corrcoef(posterior_samples.T)

    # --- Print Statistics Table ---
    print("📊 Posterior Statistics")
    print("=" * 70)
    print(f"{'Parameter':<22} {'Mean ± Std':<20} {'Median':<15} {'95% Credible Interval':<20}")
    print("-" * 70)
    for i, name in enumerate(param_names):
        mean_std_str = f"{posterior_mean[i]:.3f} ± {posterior_std[i]:.3f}"
        median_str = f"{posterior_median[i]:.3f}"
        ci_str = f"[{lower_ci[i]:.3f}, {upper_ci[i]:.3f}]"
        print(f"{name:<22} {mean_std_str:<20} {median_str:<15} {ci_str:<20}")
    print("-" * 70)

    # --- Print True Parameters Comparison (if provided) ---
    if true_params is not None:
        print("\n🎯 Comparison with True Parameters:")
        for i, name in enumerate(param_names):
            in_ci = bool(lower_ci[i] <= true_params[i] <= upper_ci[i])
            symbol = "✅" if in_ci else "❌"
            print(f"{symbol} {name:<22} True value: {true_params[i]:.3f} (In 95% CI: {in_ci})")

    # --- Print Correlations ---
    print("\n🔗 Parameter Correlations:")
    n_params = posterior_samples.shape[1]
    param_pairs = list(itertools.combinations(range(n_params), 2))
    for i, j in param_pairs:
        corr_value = posterior_corr[i, j]
        name_i = param_names[i].split("(")[0].strip()
        name_j = param_names[j].split("(")[0].strip()
        print(f"{name_i} / {name_j}: {corr_value:+.3f}")


def generate_posterior_predictive_simulations(
        posterior,
        observed_data: torch.Tensor,
        simulate_func: Callable,
        prior: torch.distributions.Distribution,
        num_simulations: int = 1000,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Generate posterior predictive simulations and the simulation for the MAP estimate.

    Args:
        posterior: A trained sbi posterior. NOTE: this function calls
            ``posterior.set_default_x(observed_data)``, so the posterior's default
            observation is changed as a side effect.
        observed_data: The observation to condition the posterior on.
        simulate_func: A NumPy simulator taking a 1-D parameter array. It is wrapped
            with ``process_simulator(..., True)`` for the predictive draws and is
            also called directly on the MAP parameters.
        prior: The prior of the inference problem (required by ``process_simulator``).
        num_simulations: Number of posterior draws to simulate.

    Returns:
        ``(map_simulation, predictive_simulations)``: the simulation for the MAP
        parameters as a tensor, and the simulations for ``num_simulations``
        posterior draws.

    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    posterior.set_default_x(observed_data)
    batch_simulator = process_simulator(simulate_func, prior, True)

    # Generate simulations from posterior samples (the posterior acts as the proposal).
    # `num_workers` is the module-level setting from the setup cell.
    _, predictive_simulations = simulate_for_sbi(
        batch_simulator,
        posterior,
        num_simulations=num_simulations,
        num_workers=num_workers,
    )

    # Get the Maximum A Posteriori (MAP) estimate and simulate it. `simulate_func` is
    # a NumPy simulator, so hand it a flat NumPy parameter vector (detached, on CPU),
    # the same input type the `process_simulator(..., True)` wrapper gives it.
    map_estimate = posterior.map()
    map_params = map_estimate.detach().cpu().reshape(-1).numpy()
    map_simulation = simulate_func(map_params)

    # `as_tensor` keeps a NumPy result's dtype (like `torch.from_numpy`) and also
    # accepts simulators that already return tensors.
    return torch.as_tensor(map_simulation), predictive_simulations


def plot_posterior_predictions(
        predictions: torch.Tensor,
        map_prediction: torch.Tensor,
        time_span: float = 200.0,
        dt: float = 0.1,
        labels: list[str] | tuple[str, ...] = ("Prey", "Predator"),
        colors: list[str] | tuple[str, ...] = ("red", "blue"),
):
    """
    Plot posterior predictive time series with uncertainty bands.

    Args:
        predictions: Posterior predictive simulations. They are indexed here as
            ``predictions[sample, time, series]``, with quantiles taken over samples.
        map_prediction: The MAP simulation, indexed as ``map_prediction[time, series]``.
        time_span: Total simulated time. Kept for backward compatibility; the time
            axis is built from the data length and ``dt``, so it always matches the
            simulator's ``int(t_span / dt)`` samples.
        dt: Time step between consecutive samples.
        labels: One legend name per series.
        colors: One base matplotlib color per series. The MAP line uses the
            ``"dark" + color`` variant, so each entry must have a "dark" named form.

    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    print("Plotting posterior predictive time series...")
    lower_bound = torch.quantile(predictions, 0.05, dim=0)
    upper_bound = torch.quantile(predictions, 0.95, dim=0)
    # Derive the axis from the data: `np.arange(0, time_span, dt)` rounds its length
    # up, while the simulator truncates, which can leave x and y with different lengths.
    n_steps = map_prediction.shape[0]
    time_axis = np.arange(n_steps) * dt

    fig, ax = plt.subplots(figsize=(14, 7))
    for i, (label, color) in enumerate(zip(labels, colors)):
        # Plot MAP prediction
        ax.plot(
            time_axis,
            map_prediction[:, i],
            color=f"dark{color}",
            linestyle="--",
            lw=2.5,
            label=f"MAP {label} Prediction",
        )
        # Plot uncertainty bands (90% credible interval)
        ax.fill_between(
            time_axis,
            lower_bound[:, i],
            upper_bound[:, i],
            color=color,
            alpha=0.2,
            label=f"90% Credible Interval ({label})",
        )

    ax.set_xlabel("Time", fontsize=14)
    ax.set_ylabel("Population", fontsize=14)
    ax.set_title("Predicted Future Population Dynamics", fontsize=18)
    ax.legend(loc="upper right")
    ax.grid(True, which="both", linestyle="--", linewidth=0.5)
    plt.tight_layout()
    plt.show()

# Confirmation message so the notebook output shows the helper cell finished.
print("✅ Utility functions defined.")

## 🎾 Part 1: The Ball Throw Physics Simulator

Our first model is a simple physics simulation of a projectile with air resistance.

**The Goal:** Imagine you are at a sports event. You can't see the athlete throw, but you can measure two things about the ball's flight:
1.  **Landing Distance:** How far it traveled horizontally.
2.  **Maximum Height:** The peak of its arc.

Based *only* on these two observations, can we infer the **three hidden parameters** of the throw?
-   `Initial Velocity` ($v_0$): How fast the ball was thrown.
-   `Launch Angle` ($\theta$): The angle of the throw.
-   `Friction` ($\mu$): A coefficient for air resistance.

**The Physics:**
Let's assume that the trajectory is governed by these differential equations:
-   Horizontal motion: $\frac{d^2x}{dt^2} = W - \mu \cdot \frac{dx}{dt}$ (where $W$ is the wind's acceleration; $W = 0$ unless we add wind in Exercise 2)
-   Vertical motion: $\frac{d^2y}{dt^2} = -g - \mu \cdot \frac{dy}{dt}$ (where $g$ is gravity)

In [None]:
# @title Simulator and Prior Definition
def ball_throw_simulator(
        params: torch.Tensor | np.ndarray, return_trajectory: bool = False
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Simulates a ball throw with air resistance and returns its summary statistics.

    Args:
        params: A tensor/array containing [initial_velocity, launch_angle, friction, wind?].
        return_trajectory: If True, also returns the x and y coordinates of the flight path.

    Returns:
        A tensor with [landing_distance, max_height].
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    if isinstance(params, torch.Tensor):
        params = params.detach().cpu().numpy()

    v0, angle, friction = params[0], params[1], params[2]
    wind = params[3] if len(params) > 3 else 0.0
    g, dt = 9.81, 0.01  # Gravity and time step

    x, y = 0.0, 0.0
    vx, vy = v0 * np.cos(angle), v0 * np.sin(angle)
    x_traj, y_traj = ([x], [y])
    max_height = 0.0

    # Simulate up to 10,000 steps (100 seconds)
    for _ in range(10_000):
        # Update velocities with friction and gravity/wind
        vx += (wind - friction * vx) * dt
        vy += (-g - friction * vy) * dt
        # Update position
        x_new, y_new = x + vx * dt, y + vy * dt

        # Stop if the ball hits the ground
        if y_new < 0:
            # Interpolate to find the exact landing spot
            t_impact = -y / vy
            landing_distance = x + vx * t_impact
            break

        x, y = x_new, y_new
        max_height = max(max_height, y)
        x_traj.append(x)
        y_traj.append(y)
    else: # Failsafe in case it never lands
        landing_distance = x

    # Add a small amount of observational noise to make it more realistic
    noise_scale = 0.05
    landing_distance *= (1 + np.random.randn() * noise_scale)
    max_height *= (1 + np.random.randn() * noise_scale)

    # Ensure observations are positive
    observations = torch.tensor([max(0.1, landing_distance), max(0.1, max_height)], dtype=torch.float32)

    if return_trajectory:
        return observations, torch.tensor(x_traj), torch.tensor(y_traj)
    return observations


def create_ball_throw_prior(include_wind: bool = False):
    """
    Build the BoxUniform prior over the ball-throw parameters.

    Parameter order is [v0, angle, friction], or [v0, angle, friction, wind] when
    ``include_wind`` is True. The angle is in radians (0.2-1.4 rad, roughly 11° to 80°).
    Wind ranges from -5 (headwind) to +5 (tailwind).
    """
    # (low, high) bounds per parameter, in parameter order.
    bounds = [(5.0, 30.0), (0.2, 1.4), (0.0, 0.5)]
    if include_wind:
        bounds.append((-5.0, 5.0))
    lows, highs = zip(*bounds)
    return BoxUniform(low=torch.tensor(lows), high=torch.tensor(highs))


def plot_trajectories(params_set, labels):
    """Draw the simulated flight path of each parameter set as one labelled line."""
    plt.figure(figsize=(10, 6))
    for throw_params, throw_label in zip(params_set, labels):
        simulated = ball_throw_simulator(throw_params, return_trajectory=True)
        xs, ys = simulated[1], simulated[2]
        plt.plot(xs, ys, label=throw_label, lw=2.5)

    # Decorations: title, axis labels, legend, grid; keep the plot in the first quadrant.
    plt.title("Sample Ball Throw Trajectories", fontsize=16)
    plt.xlabel("Distance (m)")
    plt.ylabel("Height (m)")
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.6)
    plt.ylim(bottom=0)
    plt.xlim(left=0)
    plt.show()

# Confirmation message so the notebook output shows the simulator cell finished.
print("✅ Ball throw simulator defined.")

In [None]:
# @title Interactive Trajectory Explorer
# @markdown Drag the sliders to see how each parameter affects the ball's trajectory. This builds intuition for what our algorithm will learn.
initial_velocity = 21  #@param {type:"slider", min:5.0, max:30.0, step:0.5}
launch_angle_degrees = 27  #@param {type:"slider", min:10, max:80, step:1}
friction_coefficient = 0.1 #@param {type:"slider", min:0.0, max:0.5, step:0.01}

# Convert degrees to radians for the simulator
launch_angle_radians = launch_angle_degrees * (np.pi / 180.0)
interactive_params = torch.tensor([initial_velocity, launch_angle_radians, friction_coefficient])

# Define some fixed throws for comparison
strong_throw_params = torch.tensor([25.0, 0.7, 0.1])  # A powerful throw
high_arc_params = torch.tensor([15.0, 1.2, 0.1])    # A high, looping throw

# Build the legend text from the parameters so it always matches what is plotted
# (1.2 rad is about 68.75°, which rounds to 69°).
throw_labels = [
    f"Strong Throw ({float(strong_throw_params[0]):.0f} m/s, "
    f"{np.rad2deg(float(strong_throw_params[1])):.0f}°)",
    f"High-Arc Throw ({float(high_arc_params[0]):.0f} m/s, "
    f"{np.rad2deg(float(high_arc_params[1])):.0f}°)",
    "Your Custom Throw",
]

plot_trajectories(
    [strong_throw_params, high_arc_params, interactive_params],
    throw_labels,
)

In [None]:
# @title Perform Simulation-Based Inference
# Runs the full SBI loop for the 3-parameter ball throw: prior -> simulations ->
# NPE training -> posterior sampling -> plots and statistics.
# NOTE: no random seed is set in this cell, so simulations, training and sampling
# (and therefore the printed numbers) differ between runs.
# --- Step 1: Define the Prior ---
prior = create_ball_throw_prior(include_wind=False)
param_names_3d = ["v₀ (velocity)", "θ (angle)", "μ (friction)"]

# --- Step 2: Set up the SBI pipeline ---
# The `process_simulator` function wraps our Python simulator so `sbi` can use it.
# The third argument is sbi's `is_numpy_simulator` flag. It is False here because
# ball_throw_simulator accepts torch tensors itself and returns a torch tensor.
simulator = process_simulator(ball_throw_simulator, prior, False)
# We use Neural Posterior Estimation (NPE) as our inference algorithm.
npe = NPE(prior=prior)

# --- Step 3: Generate Training Data ---
# We run the simulator many times with parameters drawn from the prior.
# This is the "training set" for our neural network.
num_simulations = 2000  # Use 10,000+ for higher accuracy in a real project.
print(f"⚙️ Generating {num_simulations} simulations... (This may take a moment)")
theta, x = simulate_for_sbi(
    simulator,
    prior,
    num_simulations=num_simulations,
    num_workers=num_workers,
)
print(f"✅ Generated {len(theta)} simulation-observation pairs.")

# --- Step 4: Train the Neural Network ---
# The network learns the relationship between parameters (theta) and simulation outcomes (x).
print("\n🧠 Training the neural posterior estimator...")
npe.append_simulations(theta, x).train()
print("✅ Training complete!")

# --- Step 5: Define the Observation and Build the Posterior ---
# This is the data we actually observed. We'll use the "Strong Throw" as our target.
# NOTE(review): these numbers are hard-coded rather than simulated from
# `strong_throw_params`; presumably they approximate that throw's outcome. Confirm.
observation_strong_throw = torch.tensor([50.5, 12.1])
posterior = npe.build_posterior()
print(f"\n🎯 Our observation: Distance={observation_strong_throw[0]}m, Height={observation_strong_throw[1]}m")

# --- Step 6: Sample from the Posterior ---
# We draw samples from the learned posterior distribution, conditioned on our observation.
print("📈 Sampling from the posterior distribution...")
posterior_samples = posterior.sample((10000,), x=observation_strong_throw)
print(f"✅ Drew {len(posterior_samples)} posterior samples.")
print("\n🎉 Inference complete! Let's analyze the results.")

# --- Step 7: Analyze and Visualize the Results ---
# The pairplot shows the 1D and 2D marginals of the posterior.
# `points` overlays the true parameters (`strong_throw_params`) we're trying to recover.
fig, axes = pairplot(
    [posterior_samples],
    points=strong_throw_params.unsqueeze(0),
    labels=param_names_3d,
    figsize=(8, 8),
);
plt.suptitle("Posterior Distribution for the Strong Throw", fontsize=16, y=1.02)
plt.show()

# Print detailed statistics
_ = analyze_posterior_statistics(posterior_samples, param_names_3d, strong_throw_params)

### 💡 Interpreting Posterior Correlations: The "Wiggle Room"

You might notice in the pair plots that our results aren't single, sharp points but are often elongated or slanted "blobs." This is a key insight, *not* a flaw! It shows that the data can be explained by a range of different parameter combinations. This "wiggle room" happens for two important reasons:

1.  **Parameter Trade-Offs:** Often, a change in one parameter can be compensated for by a change in another. The posterior plot reveals the exact nature of this trade-off.

2.  **Identifiability Limits:** The summary statistics we use (distance and height) might not be perfect. Multiple different parameter sets could produce summary statistics that are very similar, making it hard for the model to distinguish between them. The posterior honestly reflects this ambiguity.


#### Why This is So Useful

Understanding these correlations is a core benefit of using SBI over simple optimization. Instead of just one "best" answer, we get a complete map of all plausible solutions. This helps us:

* **Pinpoint what the data tells us:** A narrow posterior for a parameter means the data has strongly constrained it. A broad, correlated posterior tells us the data can only constrain a specific *combination* of parameters.

* **Know when we need more data:** If key parameters remain too uncertain or correlated, the plot is telling us that we might need more informative summary statistics or different kinds of data to break the trade-offs.

* **Uncover scientific insights:** These statistical correlations often reflect real-world phenomena. For this model, it reveals the physical relationship between velocity, angle, friction, and wind (next section), providing a complete picture of the problem.

### 📈 Conditional Posterior Distribution

In [None]:
from sbi.analysis import conditional_pairplot

# `conditional_pairplot` evaluates the posterior density at its default
# observation, so build a fresh posterior and attach the strong-throw data to it.
posterior = npe.build_posterior()
posterior.set_default_x(observation_strong_throw)

# Plot ranges for [v0, angle, friction]; these equal the 3-D prior bounds.
ball_limits = torch.tensor([[5.0, 30.0], [0.2, 1.4], [0.0, 0.5]])

# `condition` supplies the parameter values used for the dimensions not shown in
# each panel; `points` marks the true parameters.
_ = conditional_pairplot(
    density=posterior,
    condition=strong_throw_params,
    points=[strong_throw_params.unsqueeze(0)],
    limits=ball_limits,
    labels=param_names_3d,
)

### 🤸 Exercise 1: Two Different Athletes

**Scenario:** We have data from two athletes, but we only see the final result of their throws.

* **Athlete A** is a "power thrower": **50.3m** distance, **12.1m** max height.
* **Athlete B** is a "high-arc specialist": **12.1m** distance, **9.5m** max height.

**Question:** Based *only* on these two outcomes, what can we infer about their throwing styles (velocity, angle, friction)?

**Implementation:** We can reuse our already-trained `posterior` object. We just need to condition it on the new observations for each athlete.

### 💨 Exercise 2: Inferring an Unknown Wind

**Scenario:** A throw is made under unknown weather conditions. The ball travels an unusually long distance for its apparent arc. We suspect there was a tailwind helping it along.

**Task:** Add **wind strength** as a fourth parameter to our model and infer it from a new observation.

* **New Parameter:** `wind` (m/s), can be negative (headwind) or positive (tailwind).
* **New Prior:** A 4D uniform distribution including wind from -5.0 to 5.0 m/s.
* **Observation:** A throw lands at **65.0 meters** and reached a max height of **11.0 meters**.

**Implementation:** Since we've added a new parameter, our previous neural network is no longer valid. We must retrain from scratch with a new 4D prior and new simulations.

## 🦊 Part 2: Lotka-Volterra Predator-Prey Model

Next, we'll tackle a classic model in theoretical ecology: the Lotka-Volterra equations, which describe the population dynamics of a predator and its prey.

**The Goal:** Given a time-series of prey and predator populations, can we infer the four fundamental parameters that govern their interaction?

**The Equations:**
$$
\begin{align*}
\frac{d\text{Prey}}{dt} &= \alpha \cdot \text{Prey} - \beta \cdot \text{Prey} \cdot \text{Predator} \\
\frac{d\text{Predator}}{dt} &= \delta \cdot \text{Prey} \cdot \text{Predator} - \gamma \cdot \text{Predator}
\end{align*}
$$

**The Parameters:**
-   `α (alpha)`: **Prey birth rate**.
-   `β (beta)`: **Predation rate** (how effectively predators hunt prey).
-   `δ (delta)`: **Predator reproduction rate** (how efficiently predators turn food into offspring).
-   `γ (gamma)`: **Predator death rate**.

**The Challenge:** Unlike the ball throw, the output of this simulator is not just two numbers, but a long time series. We need to compress this data into a set of **summary statistics** that capture its essential features.

In [None]:
# @title Simulator and Summary Statistics
def lotka_volterra_simulation(
        parameters: np.ndarray,
        t_span: float = 200.0,
        dt: float = 0.1,
        y0: np.ndarray = np.asarray([40.0, 9.0])
) -> np.ndarray:
    """
    Simulates the Lotka-Volterra model using a simple Euler method.

    Args:
        parameters: A numpy array [alpha, beta, delta, gamma].
        t_span: Total simulation time.
        dt: Time step for the simulation.
        y0: Initial populations [prey, predator].

    Returns:
        A numpy array of shape (timesteps, 2) with the population dynamics.
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    alpha, beta, delta, gamma = parameters
    timesteps = int(t_span / dt)
    y = np.zeros((timesteps, 2))
    y[0] = y0

    for i in range(1, timesteps):
        prey, predator = y[i-1]
        dprey_dt = alpha * prey - beta * prey * predator
        dpredator_dt = delta * prey * predator - gamma * predator
        y[i] = y[i-1] + np.asarray([dprey_dt, dpredator_dt]) * dt
        # Ensure populations don't go below zero
        y[i][y[i] < 0] = 0
    return y


def _get_stats(population: np.ndarray, use_autocorrelation: bool) -> np.ndarray:
    """
    Helper to calculate summary stats for a single population time series.
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    # Moments: mean, std, max, skewness, kurtosis
    moments = np.array([
        np.mean(population), np.std(population), np.max(population),
        stats.skew(population), stats.kurtosis(population)
    ])
    if not use_autocorrelation:
        return moments

    # Autocorrelation at specific time lags
    mean_centered_pop = population - np.mean(population)
    autocorr_full = np.correlate(mean_centered_pop, mean_centered_pop, mode="full")
    lag_0_corr = autocorr_full[autocorr_full.size // 2]

    if lag_0_corr > 1e-6:
        normalized_autocorr = (autocorr_full / lag_0_corr)[autocorr_full.size // 2 :]
        # Lags correspond to time delays of 1, 5, 10, 20, and 40 units of time
        lags_to_take = [10, 50, 100, 200, 400]
        autocorr = normalized_autocorr[lags_to_take]
    else: # If variance is zero, autocorrelation is undefined
        autocorr = np.zeros(5)

    return np.concatenate([moments, autocorr])


def summarize_simulation(
        simulation_result: np.ndarray,
        use_autocorrelation: bool = False,
        noise_std: float = 1.0,
) -> np.ndarray:
    """
    Converts a simulation time series into summary statistics.

    Gaussian noise with standard deviation ``noise_std`` (in the same units as the
    populations) is added to every entry first, to simulate real-world measurement
    error. The default of 1.0 was chosen for the simulator's population scale
    (initial populations 40 and 9). Data rescaled to much smaller values needs a
    proportionally smaller ``noise_std``, otherwise the noise dominates the statistics.

    Args:
        simulation_result: Array of shape (timesteps, 2); column 0 is prey, column 1 predator.
        use_autocorrelation: Also include the autocorrelation features from `_get_stats`.
        noise_std: Standard deviation of the additive measurement noise.

    Returns:
        The prey statistics followed by the predator statistics, as one 1-D array.

    Raises:
        ValueError: If ``simulation_result`` is not 2-D with at least two columns.
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    populations = np.asarray(simulation_result)
    if populations.ndim != 2 or populations.shape[1] < 2:
        raise ValueError(
            f"Expected a (timesteps, 2) array, got shape {populations.shape}"
        )
    # The noise is drawn even when noise_std == 0, so the global RNG stream is unchanged.
    noise = np.random.randn(*populations.shape) * noise_std
    noisy_populations = populations + noise
    prey_stats = _get_stats(noisy_populations[:, 0], use_autocorrelation)
    predator_stats = _get_stats(noisy_populations[:, 1], use_autocorrelation)
    return np.concatenate([prey_stats, predator_stats])

def get_summary_labels(use_autocorrelation: bool = False) -> list:
    """
    Name every entry of the summary-statistics vector, prey block first.

    The order matches the statistics vector: five moments, then (optionally) five
    autocorrelation lags, for prey and then for predator.
    Source:
        https://github.com/janfb/euroscipy-2025-sbi-tutorial
    """
    stat_names = ["Mean", "Std", "Max", "Skew", "Kurtosis"]
    if use_autocorrelation:
        stat_names += [f"ACF Lag {lag}" for lag in (10, 50, 100, 200, 400)]
    return [
        f"{species} {stat}"
        for species in ("Prey", "Predator")
        for stat in stat_names
    ]


def lotka_volterra_sbi_simulator(parameters, use_autocorrelation=False):
    """
    Master simulator for SBI: run the Lotka-Volterra model and summarize it.

    Runs `lotka_volterra_simulation` with its default time span, step and initial
    populations, then passes the time series to `summarize_simulation`. SBI thus
    sees a fixed-length summary vector rather than the raw series.
    """
    time_series = lotka_volterra_simulation(parameters)
    return summarize_simulation(time_series, use_autocorrelation)

# Confirmation message so the notebook output shows the LV definitions cell finished.
print("✅ Lotka-Volterra simulator and summarizer defined.")

In [None]:
# @title Interactive LV Explorer & Summary Statistics
# @markdown Drag the sliders to see how the parameters change the population cycles.
α = 0.12  #@param {type:"slider", min:0.05, max:0.15, step:0.01}
β = 0.015  #@param {type:"slider", min:0.01, max:0.03, step:0.005}
δ = 0.01  #@param {type:"slider", min:0.005, max:0.03, step:0.005}
γ = 0.085  #@param {type:"slider", min:0.005, max:0.15, step:0.01}

# --- Generate a ground-truth simulation ---
true_params_lv = np.asarray([α, β, δ, γ])
observed_time_series = lotka_volterra_simulation(true_params_lv)

# Time axis built from the series itself (one point per 0.1 time units, the
# simulator's default dt), so x and y always have the same length.
lv_time_axis = np.arange(len(observed_time_series)) * 0.1

fig, ax = plt.subplots(1, 1, figsize=(10, 4))
ax.plot(lv_time_axis, observed_time_series)
ax.legend(["Prey", "Predator"])
ax.set_xlabel("Time")
ax.set_ylabel("Population")
ax.set_title("Lotka-Volterra Population Dynamics")
plt.show()

# --- Calculate summary statistics for this "observed" data ---
# We will start by using only the basic moment-based statistics.
USE_AUTOCORRELATION = False # This is the key setting for Exercise 3
observed_summary = summarize_simulation(observed_time_series, use_autocorrelation=USE_AUTOCORRELATION)
data_labels = get_summary_labels(use_autocorrelation=USE_AUTOCORRELATION)

# Describe the statistics actually used, so the header stays correct in Exercise 3.
stats_description = "moments + autocorrelation" if USE_AUTOCORRELATION else "moments only"
print(f"📋 Observed Summary Statistics (using {stats_description}):")
for label, value in zip(data_labels, observed_summary):
    print(f"{label:20s}: {value:.2f}")

In [None]:
# @title Perform SBI on the Lotka-Volterra Model
# --- Step 1: Define Prior ---
lv_param_names = ["α (prey birth)", "β (predation)", "δ (predator effic.)", "γ (predator death)"]
lower_bound = torch.tensor([0.05, 0.01, 0.005, 0.005])
upper_bound = torch.tensor([0.15, 0.03, 0.03, 0.15])
prior_lv = BoxUniform(low=lower_bound, high=upper_bound)

# --- Step 2: Set up SBI Pipeline ---
# Note: we pass the `use_autocorrelation` flag to our simulator wrapper.
def _lv_summary_simulator(params):
    """Summary-statistics LV simulator using the current USE_AUTOCORRELATION setting."""
    return lotka_volterra_sbi_simulator(params, use_autocorrelation=USE_AUTOCORRELATION)

# The LV simulator integrates with NumPy and returns a NumPy summary vector, so it
# is registered as a NumPy simulator (`is_numpy_simulator=True`). This is the same
# flag `generate_posterior_predictive_simulations` uses for the raw LV simulator.
simulator_lv = process_simulator(_lv_summary_simulator, prior_lv, True)
npe_lv = NPE(prior=prior_lv)

# --- Step 3: Generate Data & Train ---
num_simulations_lv = 2000
print(f"⚙️ Generating {num_simulations_lv} simulations...")
theta_lv, x_lv = simulate_for_sbi(simulator_lv, prior_lv, num_simulations=num_simulations_lv, num_workers=num_workers)
print("🧠 Training neural posterior estimator...")
npe_lv.append_simulations(theta_lv, x_lv).train()
posterior_lv = npe_lv.build_posterior()
print("✅ Training Complete!")

# --- Step 4: Sample from Posterior ---
print("\n📈 Sampling from posterior...")
observed_data_lv = torch.tensor(observed_summary, dtype=torch.float32)
posterior_samples_lv = posterior_lv.sample((10000,), x=observed_data_lv)
print(f"✅ Drew {len(posterior_samples_lv)} posterior samples.")

# --- Step 5: Analyze and Visualize ---
# Plot limits are the prior bounds, passed as plain (low, high) float pairs.
fig, axes = pairplot(
    posterior_samples_lv,
    points=true_params_lv,
    labels=[r"$\alpha$", r"$\beta$", r"$\delta$", r"$\gamma$"],
    figsize=(9, 9),
    limits=list(zip(lower_bound.tolist(), upper_bound.tolist())),
)
plt.suptitle("Posterior Distribution of Lotka-Volterra Parameters", fontsize=20, y=1.02)
plt.show()

_ = analyze_posterior_statistics(
    posterior_samples=posterior_samples_lv,
    param_names=lv_param_names,
    true_params=torch.from_numpy(true_params_lv),
)
# @title Posterior Predictive Check
# --- Step 6: Posterior Predictive Check ---
# This is a crucial step: can our inferred parameters generate simulations that look like the original data?
map_sim, pred_sims = generate_posterior_predictive_simulations(
    posterior=posterior_lv,
    observed_data=observed_data_lv,
    simulate_func=lotka_volterra_simulation, # Use the raw simulator here
    prior=prior_lv,
    num_simulations=1000,
)

plot_posterior_predictions(predictions=pred_sims, map_prediction=map_sim)

### 🔬 Exercise 3: The Value of Better Statistics

**Scenario:** Our first analysis (using only moments) worked, but the posterior predictive plot shows that the uncertainty in our forecast grows very quickly. Can we do better?

**Hypothesis:** The moments (mean, std, etc.) don't capture the *temporal structure* of the time series (e.g., the speed of the oscillations). Autocorrelation, which measures how a signal correlates with a delayed copy of itself, should provide this missing information.

**Your Task:**
1.  Go back to the **"Interactive LV Explorer & Summary Statistics"** cell.
2.  Set the variable `USE_AUTOCORRELATION = True`.
3.  Re-run that cell and all the subsequent cells for the Lotka-Volterra model.
4.  Compare the new results (posterior distributions and predictive plots) with the old ones.


## 🌎 Exercise 4: Real-World Fox and Rabbit Data

**Scenario:** We have obtained real-world data on the number of hunted **foxes (predator)** and **wild rabbits (prey)** in Saxony, Germany, from 1991 to 2023. Can we use our Lotka-Volterra model and SBI to find parameters that describe this real ecological system and predict its future?

**The Data:**
-   Source: Saxony State Ministry for the Environment and Agriculture ([Excel file embedded in the PDF document](https://www.medienservice.sachsen.de/medien/medienobjekte/117580)).

**Your Task:**
1.  **Preprocess the Data:** Calculate summary statistics from the real time-series data. This will be our `observation`.
2.  **Define a Prior:** Choose a reasonable prior range for the LV parameters. We may need a wider prior than before, as real-world dynamics can be different.
3.  **Run the Inference:** Execute the full SBI pipeline using the real data as the target observation.
4.  **Analyze and Predict:** Analyze the inferred posterior and generate posterior predictive simulations to forecast the populations into the future.


In [None]:
# @title Data Preparation
# Data source: https://www.medienservice.sachsen.de/medien/medienobjekte/117580

foxes_raw = np.array([8100.0, 16446.0, 22152.0, 24413.0, 30010.0, 23240.0, 28922.0,
                      30949.0, 32598.0, 26475.0, 29037.0, 28537.0, 23503.0, 24619.0,
                      26604.0, 21376.0, 28169.0, 27091.0, 24705.0, 24592.0, 22235.0,
                      18618.0, 13496.0, 14365.0, 16479.0, 14752.0, 13332.0, 14893.0,
                      16303.0, 17797.0, 13869.0, 14000.0, 16262.0])
rabbits_raw = np.array([274.0, 355.0, 293.0, 271.0, 174.0, 73.0, 100.0, 91.0, 69.0,
                        73.0, 45.0, 37.0, 63.0, 47.0, 91.0, 25.0, 36.0, 55.0, 32.0,
                        37.0, 38.0, 21.0, 71.0, 44.0, 46.0, 18.0, 10.0, 0.0, 21.0,
                        0.0, 1.0, 8.0, 1.0])

# One value per year starting in 1991. The year axis is derived from the data, so
# adding or removing a year cannot silently misalign the series.
if len(foxes_raw) != len(rabbits_raw):
    raise ValueError(
        f"Fox and rabbit series differ in length: {len(foxes_raw)} vs {len(rabbits_raw)}"
    )
times = np.arange(1991, 1991 + len(foxes_raw))

# We scale the data to be in a similar numerical range as our simulator (e.g., by thousands)
# This helps with numerical stability.
foxes = foxes_raw / 1000
rabbits = rabbits_raw / 1000
real_data_timeseries = np.stack([rabbits, foxes], axis=1) # Note: Prey first, then Predator

fig, ax1 = plt.subplots(figsize=(15, 7))
color_fox = 'tab:blue'
ax1.set_xlabel('Year', fontsize=12)
ax1.set_ylabel('Foxes (in thousands)', color=color_fox, fontsize=12)
ax1.plot(times, foxes, color=color_fox, marker='o', label='Füchse (Foxes)')
ax1.tick_params(axis='y', labelcolor=color_fox)
ax2 = ax1.twinx()
color_rabbit = 'tab:red'
ax2.set_ylabel('Rabbits (in thousands)', color=color_rabbit, fontsize=12)
ax2.plot(times, rabbits, color=color_rabbit, marker='x', linestyle='--', label='Wildkaninchen (Rabbits)')
ax2.tick_params(axis='y', labelcolor=color_rabbit)
plt.title(f'Foxes and Rabbits in Saxony ({times[0]}-{times[-1]})', fontsize=16, pad=20)
ax1.legend(loc='upper left')
ax2.legend(loc='upper right')
fig.tight_layout()
plt.show()


## 📚 References and Credits

* **sbi Package Documentation:** [sbi.readthedocs.io](https://sbi.readthedocs.io/)
* This tutorial was inspired by and adapts materials from:
    * Boelts, J. (2025). EuroSciPy 2025: Simulation-Based Inference Tutorial. [github.com/janfb/euroscipy-2025-sbi-tutorial](https://github.com/janfb/euroscipy-2025-sbi-tutorial)
    * Deistler, M., et al. (2025). Simulation-Based Inference: A Practical Guide.
    * Cranmer, K., Brehmer, J., & Louppe, G. (2020). The frontier of simulation-based inference. *PNAS*.