# Exercise 3: Apply SBI to Your Own Problem 🚀

**Time:** 20 minutes  
**Goal:** Apply what you've learned to a new simulator

## 🎯 Learning Objectives

By the end of this exercise, you will:
1. ✅ Adapt the SBI workflow to a new problem
2. ✅ Define appropriate priors for your parameters
3. ✅ Run inference and diagnostics on your simulator
4. ✅ Leave with working code you can adapt

## Choose Your Adventure!

We provide two well-tested example simulators, or you can bring your own:

### 🎾 Option A: Ball Throw Physics
- **Story**: You're analyzing baseball pitches or golf drives
- **Physics**: Projectile motion with air resistance
- **Challenge**: Infer launch conditions from landing position

### 🦠 Option B: SIR Epidemic Model
- **Story**: You're tracking disease spread in a community
- **Model**: Classic compartmental epidemic model
- **Challenge**: Infer transmission rates from outbreak data

### 🔬 Option C: Your Own Simulator
- Bring your research problem!
- We'll help you adapt it

## Setup

In [None]:
# Core imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# SBI imports
from sbi.inference import NPE, simulate_for_sbi

# Our example simulators
import sys

sys.path.append("..")
from simulators.ball_throw import ball_throw_simulator, create_ball_throw_prior
from simulators.sir_model import sir_epidemic_simulator, create_sir_prior

# Set style
plt.style.use("seaborn-v0_8-darkgrid")
sns.set_palette("colorblind")

# Random seed
torch.manual_seed(42)
np.random.seed(42)

print("✅ Ready to apply SBI to your problem!")

## Part 1: Explore the Simulators

Let's understand what each simulator does before choosing one.

### 🎾 Ball Throw Physics

This simulator models projectile motion with air resistance:

**Differential equations:**
- Horizontal: `d²x/dt² = wind - friction·dx/dt`
- Vertical: `d²y/dt² = -gravity - friction·dy/dt`

**Parameters to infer:**
1. Initial velocity (5-30 m/s)
2. Launch angle (0.2-1.4 radians ≈ 11°-80°)
3. Friction coefficient (0.0-0.5)

**What we observe:**
- Landing distance (meters)
- Maximum height reached (meters)
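
To make the two equations concrete, here is a sketch of how such a simulator can be integrated numerically. It is a hypothetical stand-in, not the code in `simulators/ball_throw.py`. It uses a fixed-step forward-Euler scheme, assumes g = 9.81 m/s², launches from ground level, and adds no noise. The name `ball_throw_sketch` and its arguments are illustrative only.

```python
import numpy as np


def ball_throw_sketch(v0, angle, friction, wind=0.0, gravity=9.81, dt=1e-3, max_steps=100_000):
    """Noise-free forward-Euler integration of the ball-throw equations.

    Hypothetical stand-in, not simulators.ball_throw. Assumed: g = 9.81 m/s^2,
    launch from the origin at ground level, fixed time step dt (s).
    Returns (landing_distance, max_height) in meters.
    """
    x, y = 0.0, 0.0
    vx, vy = v0 * np.cos(angle), v0 * np.sin(angle)
    max_height = 0.0
    for _ in range(max_steps):
        # Accelerations: d2x/dt2 = wind - friction*vx, d2y/dt2 = -gravity - friction*vy
        ax = wind - friction * vx
        ay = -gravity - friction * vy
        # Explicit Euler: advance position with the current velocity, then update velocity
        x_new, y_new = x + vx * dt, y + vy * dt
        vx, vy = vx + ax * dt, vy + ay * dt
        if y_new < 0.0:
            # Interpolate linearly between the last two points to find where y = 0
            frac = y / (y - y_new)
            return x + frac * (x_new - x), max_height
        x, y = x_new, y_new
        max_height = max(max_height, y)
    raise RuntimeError("ball did not land within max_steps; increase max_steps or dt")


distance, max_height = ball_throw_sketch(v0=15.0, angle=0.8, friction=0.1)
print(f"distance={distance:.1f} m, max_height={max_height:.1f} m")
```

With `friction=0` the results should approach the textbook values v₀²·sin(2θ)/g for the range and (v₀·sin θ)²/(2g) for the peak height. This gives a quick way to sanity-check the integrator.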

In [None]:
# Test the ball throw simulator
test_params = torch.tensor([15.0, 0.8, 0.1])  # 15 m/s, ~45°, low friction
observations = ball_throw_simulator(test_params)

print("🎾 Ball Throw Test:")
print(
    f"  Parameters: v₀={test_params[0]:.1f} m/s, θ={test_params[1]:.2f} rad, μ={test_params[2]:.2f}"
)
print(
    f"  Observations: distance={observations[0]:.1f}m, max_height={observations[1]:.1f}m"
)

# Visualize a trajectory
obs, x_traj, y_traj = ball_throw_simulator(test_params, return_trajectory=True)

plt.figure(figsize=(10, 4))
plt.plot(x_traj, y_traj, "b-", linewidth=2, label="Trajectory")
plt.scatter([obs[0].item()], [0], color="red", s=100, zorder=5, label="Landing")
plt.scatter(
    [x_traj[np.argmax(y_traj)]],
    [obs[1].item()],
    color="green",
    s=100,
    zorder=5,
    label="Peak",
)
plt.xlabel("Distance (m)", fontsize=12)
plt.ylabel("Height (m)", fontsize=12)
plt.title("Ball Trajectory with Air Resistance", fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()

print(
    "\n💡 We observe only the landing distance and max height, not the full trajectory!"
)

### 🦠 SIR Epidemic Model

This simulator models disease spread through a population:

**Compartments:**
- **S**usceptible: Can catch the disease
- **I**nfected: Currently sick and contagious
- **R**ecovered: Immune after recovery

**Differential equations:**
- `dS/dt = -β·S·I/N` (infection)
- `dI/dt = β·S·I/N - γ·I` (infection - recovery)
- `dR/dt = γ·I` (recovery)

**Parameters to infer:**
1. β: Infection rate (0.1-2.0 per day)
2. γ: Recovery rate (0.05-0.5 per day)
3. I₀: Initial infected count (1-100 people)

**What we observe:**
- Peak number of infected
- Time to reach peak (days)
- Total recovered at end
- Epidemic duration (days)
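
The SIR equations can be sketched the same way. The function below is a hypothetical, noise-free stand-in for `sir_epidemic_simulator`, not its actual code. It makes several assumptions:

- a population of N = 10,000 (the value in the plot title below)
- 0.1-day Euler steps over one year
- epidemic duration defined as the first time after the peak when fewer than one person is infected

The real simulator's definitions may differ.

```python
import numpy as np


def sir_sketch(beta, gamma, i0, population=10_000, dt=0.1, t_max=365.0):
    """Noise-free forward-Euler integration of the SIR equations.

    Hypothetical stand-in for sir_epidemic_simulator. Assumed: N = 10,000,
    0.1-day steps, and "duration" = first time after the peak with I < 1.
    Returns (peak_infected, time_to_peak, total_recovered, duration).
    """
    n_steps = int(round(t_max / dt))
    t = np.arange(n_steps + 1) * dt
    S, I, R = np.empty(n_steps + 1), np.empty(n_steps + 1), np.empty(n_steps + 1)
    S[0], I[0], R[0] = population - i0, i0, 0.0
    for k in range(n_steps):
        new_infections = beta * S[k] * I[k] / population * dt  # beta*S*I/N term
        new_recoveries = gamma * I[k] * dt                       # gamma*I term
        S[k + 1] = S[k] - new_infections
        I[k + 1] = I[k] + new_infections - new_recoveries
        R[k + 1] = R[k] + new_recoveries
    peak = int(np.argmax(I))
    below_one = np.nonzero(I[peak:] < 1.0)[0]
    duration = t[peak + below_one[0]] if below_one.size else t_max
    return I[peak], t[peak], R[-1], duration


peak_infected, time_to_peak, total_recovered, duration = sir_sketch(beta=0.5, gamma=0.1, i0=10)
print(f"peak={peak_infected:.0f}, t_peak={time_to_peak:.1f} d, "
      f"recovered={total_recovered:.0f}, duration={duration:.1f} d")
```

When β/γ < 1, infections decline from day 0. In that case the peak equals I₀ and occurs at t = 0.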

In [None]:
# Test the SIR simulator
test_params = torch.tensor([0.5, 0.1, 10])  # β=0.5, γ=0.1, I₀=10
observations = sir_epidemic_simulator(test_params)


print("🦠 SIR Epidemic Test:")
print(
    f"  Parameters: β={test_params[0]:.2f}, γ={test_params[1]:.2f}, I₀={test_params[2]:.0f}"
)
print(f"  Basic reproduction number R₀ = β/γ = {test_params[0] / test_params[1]:.1f}")
print("\n  Observations:")
print(f"    Peak infected: {observations[0]:.0f} people")
print(f"    Time to peak: {observations[1]:.0f} days")
print(f"    Total recovered: {observations[2]:.0f} people")
print(f"    Epidemic duration: {observations[3]:.0f} days")

# Visualize epidemic curves
obs, time_series = sir_epidemic_simulator(test_params, return_time_series=True)

plt.figure(figsize=(10, 5))
plt.plot(
    time_series["t"], time_series["S"], label="Susceptible", linewidth=2, color="blue"
)
plt.plot(time_series["t"], time_series["I"], label="Infected", linewidth=2, color="red")
plt.plot(
    time_series["t"], time_series["R"], label="Recovered", linewidth=2, color="green"
)

# Mark observations
peak_idx = np.argmax(time_series["I"])
plt.scatter(
    [time_series["t"][peak_idx]],
    [time_series["I"][peak_idx]],
    color="red",
    s=100,
    zorder=5,
    label=f"Peak: {obs[0]:.0f}",
)

plt.xlabel("Time (days)", fontsize=12)
plt.ylabel("Number of individuals", fontsize=12)
plt.title("SIR Epidemic Dynamics (Population = 10,000)", fontsize=14)
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.show()

print("\n💡 We observe summary statistics, not the full time series!")

### 🔬 Your Own Simulator

If you brought your own simulator, adapt this template:

In [None]:
def your_simulator(params):
    """
    Template for your own simulator.

    Requirements:
    1. Takes parameters (torch.Tensor or numpy array)
    2. Returns observations (torch.Tensor)
    3. Should include some stochasticity (noise)
    4. Runs reasonably fast (< 1 second)
    """
    # Convert to torch if needed
    if isinstance(params, np.ndarray):
        params = torch.tensor(params, dtype=torch.float32)

    # Your simulation code here
    # ...

    # Add observation noise (important!)
    # observations = observations * (1 + torch.randn_like(observations) * 0.05)

    # Return as a float32 torch tensor (torch.as_tensor avoids a copy warning
    # if observations is already a tensor)
    # return torch.as_tensor(observations, dtype=torch.float32)

    pass  # Remove this when implementing
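
For reference, here is one way the template could be filled in. It uses a hypothetical noisy exponential-decay model with parameters A (initial amount) and k (decay rate). The function returns a numpy array, so it assumes you would wrap it with `process_simulator(..., is_numpy_simulator=True)`, which should handle the conversion between numpy and torch. The observation times and the 5% noise level are arbitrary choices.

```python
import numpy as np


def decay_simulator(params, rng=None):
    """Hypothetical example simulator: noisy exponential decay.

    params: [A, k], initial amount A and decay rate k (numpy array, list, or CPU tensor).
    Returns the amount at five fixed times with 5% multiplicative noise,
    as a float32 numpy array of shape (5,).
    """
    rng = np.random.default_rng() if rng is None else rng
    A, k = np.asarray(params, dtype=np.float64)
    times = np.array([0.5, 1.0, 2.0, 4.0, 8.0])  # assumed observation times
    clean = A * np.exp(-k * times)
    # Multiplicative observation noise, as in the template's commented-out line
    noisy = clean * (1 + rng.standard_normal(times.shape) * 0.05)
    return noisy.astype(np.float32)


# Usage: one noisy simulation with a seeded generator
print(decay_simulator([10.0, 0.5], rng=np.random.default_rng(0)))
```

Passing a seeded `np.random.default_rng(...)` makes a single call reproducible. Without it, each call draws fresh noise from an unseeded generator, which is unaffected by `np.random.seed(42)` in Setup.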

## Part 2: Choose Your Simulator and Run SBI

**👇 Choose ONE option below.** Option A is active by default; to switch, comment out the Option A lines and uncomment another section:

In [None]:
# We import a convenience function from sbi to process the simulator.
# This enables us to run parallel simulations below.
from sbi.utils.user_input_checks import process_simulator

# ========== OPTION A: Ball Throw ==========
prior = create_ball_throw_prior(include_wind=False)  # Set True to include wind
simulator = process_simulator(ball_throw_simulator, prior, False)
param_names = ["v₀ (m/s)", "θ (rad)", "μ (friction)"]
obs_names = ["distance (m)", "max height (m)"]

# ========== OPTION B: SIR Model ==========
# prior = create_sir_prior()
# simulator = process_simulator(sir_epidemic_simulator, prior, False)
# param_names = ["β (infection)", "γ (recovery)", "I₀ (initial)"]
# obs_names = ["peak infected", "time to peak", "total recovered", "duration"]

# ========== OPTION C: Your Simulator ==========
# from sbi.utils import BoxUniform
# prior = BoxUniform(
#     low=torch.tensor([...]),   # Your parameter lower bounds
#     high=torch.tensor([...])   # Your parameter upper bounds
# )
# simulator = process_simulator(your_simulator, prior, is_numpy_simulator=?)  # True if it expects numpy inputs
# param_names = [...]  # Your parameter names
# obs_names = [...]    # Your observable names

print(f"Selected simulator: {simulator.__name__}")
print(f"Parameters: {param_names}")
print(f"Observables: {obs_names}")

## Part 3: Generate "Observed" Data and Training Data

In a real application, this would be your experimental data.
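
Generating training data comes down to two steps: draw parameters from the prior, then run the simulator once per draw. Each `theta` row must stay paired with its `x` row. The numpy sketch below spells out that loop for a box-uniform prior; the helper names are hypothetical. In the exercise, `simulate_for_sbi` (imported in Setup) plays this role for your torch prior and can spread the simulations over `num_workers` processes.

```python
import numpy as np


def sample_box_uniform(low, high, num_samples, rng):
    """Draw num_samples parameter vectors uniformly from the box [low, high)."""
    low, high = np.asarray(low, dtype=float), np.asarray(high, dtype=float)
    return rng.uniform(low, high, size=(num_samples, low.size))


def simulate_dataset(simulator, low, high, num_simulations, seed=0):
    """Hypothetical helper: one simulator call per prior draw, returned as (theta, x)."""
    rng = np.random.default_rng(seed)
    theta = sample_box_uniform(low, high, num_simulations, rng)
    # Row i of x is the simulator output for row i of theta
    x = np.stack([simulator(t) for t in theta])
    return theta, x


# Usage with a toy two-parameter simulator (illustration only)
theta_demo, x_demo = simulate_dataset(
    lambda t: np.array([t.sum(), t[0] * t[1]]),
    low=[5.0, 0.2],
    high=[30.0, 1.4],
    num_simulations=100,
)
print(theta_demo.shape, x_demo.shape)  # (100, 2) (100, 2)
```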

In [None]:
# Generate synthetic observation (ground truth for testing)
true_params = prior.sample((1,))
observed_data = simulator(true_params)

print("\n🎯 True parameters (hidden in real applications):")
for i, name in enumerate(param_names):
    print(f"  {name}: {true_params[0, i]:.3f}")

print("\n📊 Observed data:")
for i, name in enumerate(obs_names):
    print(f"  {name}: {observed_data[0, i]:.3f}")

# Generating training data
num_simulations = 10000
num_workers = 4  # Adjust based on your system

# TODO
# theta, x = ?

## Part 4: Run Neural Posterior Estimation 🚀

The same 4-step workflow from Exercise 1!

In [None]:
# Step 1: Create NPE object
# TODO

# Step 2: Train on simulations
print("🏃 Training neural network...")

# TODO

# Step 3: Build posterior for our observation

# TODO

# Step 4: Sample from posterior

# TODO

print("\n✅ Inference complete! Let's see what we learned...")

## Part 5: Visualize Results 📊

### Instructions

- Use the SBI `pairplot` function to visualize the posterior in a corner plot
- Pass the prior boundaries as limits to see whether the posterior is more constrained
  than the prior
- Look for visual signs of correlations in the off-diagonal plots
- Quantify correlations by calculating the correlation matrix of the posterior samples
- Interpret what these correlations mean for your simulator - do they make physical
  sense?
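
Here is a minimal sketch of the two numeric pieces of this part, assuming the Option A prior ranges from Part 1:

1. building `[low, high]` limits from the prior bounds
2. turning samples into a correlation matrix

The samples are synthetic correlated Gaussian draws standing in for a real posterior, so the numbers carry no physical meaning. With sbi's torch samples you could call `samples.numpy()` first, or use `torch.corrcoef(samples.T)`. You would then pass `limits=...` and `labels=param_names` to `pairplot`. The 0.5 threshold in the hypothetical `strong_pairs` helper is an arbitrary cut-off.

```python
import numpy as np

# Prior bounds for Option A (v0, theta, mu), as listed in Part 1,
# turned into one [low, high] pair per parameter for pairplot's limits.
prior_low, prior_high = [5.0, 0.2, 0.0], [30.0, 1.4, 0.5]
limits_option_a = [[lo, hi] for lo, hi in zip(prior_low, prior_high)]

# Synthetic stand-in for posterior samples (NOT real results):
# correlated Gaussian draws with a known covariance.
rng = np.random.default_rng(0)
demo_cov = np.array([[1.0, -0.8, 0.0],
                     [-0.8, 1.0, 0.3],
                     [0.0, 0.3, 1.0]])
demo_samples = rng.multivariate_normal(mean=np.zeros(3), cov=demo_cov, size=5000)

# Rows are samples and columns are parameters, hence rowvar=False
demo_corr = np.corrcoef(demo_samples, rowvar=False)
print(np.round(demo_corr, 2))


def strong_pairs(corr, names, threshold=0.5):
    """Return (name_i, name_j, r) for each parameter pair with |r| > threshold."""
    pairs = []
    for i in range(len(names)):
        for j in range(i + 1, len(names)):
            if abs(corr[i, j]) > threshold:
                pairs.append((names[i], names[j], float(corr[i, j])))
    return pairs


print(strong_pairs(demo_corr, ["p0", "p1", "p2"]))
```

A strong correlation between two parameters usually signals a trade-off: different combinations of them explain the observation about equally well.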

In [None]:
# Check notebook 01 for detailed explanations of the steps
from sbi.analysis import pairplot

## Part 6: Diagnostics - Predictive Checks, Training Convergence, Calibration 🔍

### Instructions

- Perform **prior predictive checks**: Sample parameters from prior, simulate
  observations, check if they cover reasonable ranges
- Check **training convergence**: Plot training losses to ensure the neural network
  converged properly
- Run **posterior predictive checks**: Sample parameters from posterior, simulate new
  observations, compare to actual observed data
- Perform **calibration checks**: Use SBI's calibration functions to test if uncertainty
  estimates are well-calibrated
- Use the plotting and analysis functions from notebook 02 for all diagnostic
  visualizations
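
The logic behind the predictive and calibration checks is easiest to see on a toy problem where the exact posterior is known. The sketch below assumes a hypothetical conjugate Gaussian model, θ ~ N(0, 1) and x | θ ~ N(θ, 0.5²). Its analytic posterior stands in for a trained NPE posterior.

- `predictive_quantile` reports where an observation falls among data simulated from parameter draws. Feed it prior draws for a prior predictive check, or posterior draws for a posterior predictive check.
- `sbc_ranks` implements simulation-based calibration. It draws θ* from the prior, simulates x, samples the posterior given x, and records how many posterior samples fall below θ*. For a calibrated posterior these ranks are uniform. Shrinking the posterior width (`scale < 1`) mimics an overconfident estimator and piles the ranks up at both ends.

In the exercise, use the sbi diagnostics from notebook 02 rather than these hypothetical helpers.

```python
import numpy as np

TOY_SIGMA = 0.5  # assumed observation noise of the toy simulator


def toy_simulator(theta, rng):
    """Hypothetical toy simulator: x ~ N(theta, TOY_SIGMA^2); theta may be an array."""
    return rng.normal(loc=theta, scale=TOY_SIGMA)


def exact_posterior_samples(x_obs, n, rng, scale=1.0):
    """n draws from the analytic posterior under a N(0, 1) prior.

    scale < 1 shrinks the posterior width, mimicking an overconfident estimator.
    """
    mean = x_obs / (1 + TOY_SIGMA**2)
    std = np.sqrt(TOY_SIGMA**2 / (1 + TOY_SIGMA**2)) * scale
    return rng.normal(loc=mean, scale=std, size=n)


def predictive_quantile(param_samples, x_obs, rng):
    """Fraction of data simulated from param_samples that falls below x_obs.

    Use prior draws for a prior predictive check and posterior draws for a
    posterior predictive check; values near 0 or 1 mean the model rarely
    produces data like x_obs.
    """
    x_sim = toy_simulator(np.asarray(param_samples), rng)
    return float(np.mean(x_sim < x_obs))


def sbc_ranks(num_trials=2000, num_post=100, scale=1.0, seed=0):
    """Simulation-based calibration ranks, uniform on 0..num_post if calibrated."""
    rng = np.random.default_rng(seed)
    ranks = np.empty(num_trials, dtype=int)
    for i in range(num_trials):
        theta_true = rng.normal()                 # theta* ~ prior N(0, 1)
        x = toy_simulator(theta_true, rng)        # x ~ simulator(theta*)
        post = exact_posterior_samples(x, num_post, rng, scale)
        ranks[i] = np.sum(post < theta_true)      # rank of theta* among posterior draws
    return ranks


calibrated = sbc_ranks(scale=1.0)
overconfident = sbc_ranks(scale=0.5)
print("calibrated rank histogram:   ", np.histogram(calibrated, bins=5, range=(0, 101))[0])
print("overconfident rank histogram:", np.histogram(overconfident, bins=5, range=(0, 101))[0])
```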

In [None]:
# Check notebook 02 for detailed explanations of the steps.

## 🎉 Congratulations!

You've successfully:
- ✅ Applied SBI to different problems
- ✅ Learned the universal NPE workflow
- ✅ Performed diagnostic checks
- ✅ Explored how choices affect inference

### 🔑 Key Takeaways:

1. **SBI is universal** - Same workflow for any simulator!
2. **Prior choice matters** - Must cover true parameters
3. **Diagnostics are essential** - Always check predictive distributions
4. **More data = less uncertainty** - More observations narrow the posterior; more simulations make the neural estimate of it more accurate
5. **Parameters can be correlated** - Trade-offs and identifiability

### 🚀 Next Steps:

For your research:
1. **Start simple** - Test with synthetic data first
2. **Scale up gradually** - Increase complexity step by step
3. **Use Sequential NPE** - More simulation-efficient for expensive simulators when you need the posterior for a single observation
4. **Try other methods** - NLE, NRE for different use cases

---

## 🙏 Thank you for participating!

**Now go forth and quantify uncertainty in your simulators!**