# Foundation Models for SBI: NPE-PFN

**Time: ~15 minutes**

In the previous notebooks, we:
1. **Notebook 3**: Learned the `sbi` workflow for NPE (training normalizing flows)
2. **Notebook 4**: Explored different summary statistics
3. **Notebook 5**: Learned how to diagnose our posteriors

In all these approaches, we had to **train** a neural network on our simulated data. What if we could skip training entirely?

> **Foundation Models for SBI**: Pre-trained models that work "out of the box" on new problems!

## What We'll Learn

1. **NPE-PFN**: A foundation model for SBI based on Prior-data Fitted Networks (TabPFN)
2. **Amortized inference without training**: Just provide (θ, x) pairs and get a posterior!
3. **TSNPE-PFN**: Sequential version that focuses simulations on a specific observation

**Paper**: [Simulation-Based Inference with the Prior-Data Fitted Networks](https://arxiv.org/abs/2407.20482)

**Repository**: [github.com/mackelab/npe-pfn](https://github.com/mackelab/npe-pfn)

---
## Installation

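The `npe-pfn` package needs to be cloned from GitHub. **Run the cell below once** to set it up. If the cell installs `tabpfn`, restart the kernel before continuing. The install can change the versions of dependencies that are already imported (here it replaced scikit-learn 1.8.0 with 1.7.2), and a stale kernel can then fail with import errors in the next cell.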

In [None]:
# Setup npe-pfn (run once)
import os
import shutil
import subprocess
import sys

# Get absolute path (works in notebooks)
npe_pfn_path = os.path.abspath("npe-pfn")

# Clone if needed
if not os.path.exists(npe_pfn_path):
    print("Cloning npe-pfn repository...")
    subprocess.run(["git", "clone", "https://github.com/mackelab/npe-pfn", npe_pfn_path], check=True)

# Add to Python path (so we can import without pip install)
if npe_pfn_path not in sys.path:
    sys.path.insert(0, npe_pfn_path)
    print(f"Added {npe_pfn_path} to sys.path")

# Install only tabpfn (the one missing dependency) into this kernel's interpreter,
# using uv if it is available and falling back to pip otherwise
try:
    import tabpfn
    print("✓ tabpfn already installed")
except ImportError:
    print("Installing tabpfn...")
    if shutil.which("uv"):
        installer = ["uv", "pip", "install", "--python", sys.executable]
    else:
        installer = [sys.executable, "-m", "pip", "install"]
    subprocess.run([*installer, "tabpfn"], check=True)
    print("✓ tabpfn installed. Restart the kernel before running the next cell.")

print("✓ Ready to import from npe_pfn")

In [7]:
import matplotlib.pyplot as plt
import torch
import numpy as np

from sbi.inference import NPE
from sbi.analysis import pairplot

from npe_pfn import TabPFN_Based_NPE_PFN, run_tsnpe_pfn

from simulators import (
    create_lotka_volterra_prior,
    generate_observed_data,
    lotka_volterra_simulator,
    simulate,
)

# For reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Force CPU (NPE-PFN works on CPU)
device = "cpu"

---
## Setup: Same Lotka-Volterra Problem

We continue with the predator-prey model from previous notebooks. The goal is to infer the 4 Lotka-Volterra parameters from summary statistics.
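
For reference, the standard form of the Lotka-Volterra equations, with prey population $X$ and predator population $Y$, is written out below. We assume the `simulators` module follows this convention with the same parameter names as the pairplot labels later in this notebook ($\alpha, \beta, \delta, \gamma$); check `simulators.py` for the exact parameterization and ordering.

$$
\frac{dX}{dt} = \alpha X - \beta X Y, \qquad \frac{dY}{dt} = \delta X Y - \gamma Y
$$

Here $\alpha$ is the prey growth rate, $\beta$ the predation rate, $\delta$ the predator growth per prey eaten, and $\gamma$ the predator death rate.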

In [None]:
# Setup prior and observed data
prior = create_lotka_volterra_prior()
x_o, theta_o = generate_observed_data(use_autocorrelation=True)

# For visualization later
time = np.arange(0, 200, 0.1)
ts_observed = simulate(theta_o.numpy())

---
## Think First!

Before we use NPE-PFN, let's understand the key concepts:

**Question 1**: Standard NPE requires training a neural network for each new problem. What are the advantages and disadvantages of this?

**Question 2**: How can a pre-trained model work on problems it has never seen before?

<details>
<summary>Click to reveal answers</summary>

1. **Standard NPE trade-offs:**
   - **Advantages**: Tailored to your specific problem, can handle complex posteriors
   - **Disadvantages**: Requires training time, need many simulations, hyperparameter tuning

2. **How NPE-PFN generalizes:**
   - Based on TabPFN, a transformer trained on synthetic tabular regression problems
   - Learns "how to do regression" rather than a specific regression task
   - At test time: Conditions on your (θ, x) pairs as context → outputs posterior!
   - This is **in-context learning**: The model learns from examples you provide

</details>

---
## Part 1: Standard NPE (Baseline)

First, let's run standard NPE as a baseline for comparison.

**Note**: we use only 100 simulations here, which is extremely few for training a normalizing flow.

In [None]:
# Generate training data
num_simulations = 100

theta = prior.sample((num_simulations,))
x = lotka_volterra_simulator(theta, use_autocorrelation=True)

print(f"Generated {num_simulations} simulations")
print(f"theta shape: {theta.shape}, x shape: {x.shape}")

In [None]:
# Train standard NPE
print("Training standard NPE...")
npe = NPE(prior)
npe.append_simulations(theta, x).train()

posterior_npe = npe.build_posterior()
samples_npe = posterior_npe.sample((10_000,), x=x_o)

print(f"\nStandard NPE trained! Posterior samples shape: {samples_npe.shape}")

---
## Part 2: NPE-PFN (Foundation Model)

Now let's use NPE-PFN — a foundation model that requires **no training**!

### How NPE-PFN Works

NPE-PFN is based on [TabPFN](https://arxiv.org/abs/2207.01848), a transformer that was pre-trained to solve tabular prediction problems. The key insight:

1. **Pre-training**: TabPFN was trained on millions of synthetic tabular datasets, learning to predict a target column from feature columns
2. **In-context learning**: At test time, it takes your (θ, x) pairs as "context"
3. **Posterior prediction**: Given a new observation x_o, it predicts the posterior over θ

**No gradient updates needed** — just forward passes through the pre-trained network!
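
To make "in-context learning" concrete, here is a minimal, self-contained sketch. A tabular regressor like TabPFN predicts a one-dimensional target, so we assume a multi-dimensional θ is handled autoregressively, p(θ | x_o) = ∏_j p(θ_j | x_o, θ_<j), with each factor predicted from the context. The helper `knn_gaussian_conditional` is a hypothetical, deliberately crude stand-in for TabPFN (a Gaussian fitted to the targets of the k nearest context rows); it is not the npe-pfn implementation. The point is that sampling only reads the (θ, x) pairs; no parameters are fitted.

```python
import torch


def knn_gaussian_conditional(context_inputs, context_targets, query, k=20):
    """Hypothetical stand-in for TabPFN: fit a Gaussian to the targets of the
    k context rows nearest to `query`. Nothing is trained; it only reads the context."""
    dists = torch.cdist(query.unsqueeze(0), context_inputs).squeeze(0)
    nearest = dists.topk(k, largest=False).indices
    neighbours = context_targets[nearest]
    return torch.distributions.Normal(neighbours.mean(), neighbours.std() + 1e-6)


def sample_autoregressive(thetas, xs, x_o, num_samples, k=20):
    """Draw theta ~ p(theta | x_o) one dimension at a time:
    p(theta | x_o) = prod_j p(theta_j | x_o, theta_<j)."""
    dim = thetas.shape[1]
    samples = torch.empty(num_samples, dim)
    for n in range(num_samples):
        for j in range(dim):
            # Features for dimension j: the data x plus all earlier parameter dims
            inputs = torch.cat([xs, thetas[:, :j]], dim=1)
            query = torch.cat([x_o, samples[n, :j]])
            samples[n, j] = knn_gaussian_conditional(inputs, thetas[:, j], query, k).sample()
    return samples


# Toy problem: theta ~ N(0, I), x = theta + small noise, so the posterior sits near x_o
torch.manual_seed(0)
thetas = torch.randn(2000, 2)
xs = thetas + 0.1 * torch.randn(2000, 2)
x_o = torch.tensor([0.5, -0.5])
posterior_samples = sample_autoregressive(thetas, xs, x_o, num_samples=200)
print(posterior_samples.mean(dim=0))
```

Roughly speaking, NPE-PFN replaces this nearest-neighbour rule with the conditional distributions predicted by the pre-trained transformer, which is far more accurate and scales much better to higher-dimensional x.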

### The NPE-PFN Workflow

```python
# 1. Create the posterior object (loads pre-trained model)
npe_pfn_posterior = TabPFN_Based_NPE_PFN(prior=prior)

# 2. Append simulations (these become the "context" for in-context learning)
npe_pfn_posterior.append_simulations(thetas, xs)

# 3. Sample from posterior (no training step!)
samples = npe_pfn_posterior.sample((num_samples,), x=x_o)
```

**Note**: The default context size is 10,000 simulations. If you provide more, NPE-PFN will filter them based on Euclidean distance to x_o.
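
The filtering rule in the note above can be sketched in a few lines. `filter_context` is a hypothetical helper, not part of npe-pfn: it keeps the `max_context` simulations whose summary statistics are closest to x_o in Euclidean distance. Whether npe-pfn standardizes the statistics before measuring distance is not stated here, so treat feature scaling as an assumption to check in the repository.

```python
import torch


def filter_context(thetas, xs, x_o, max_context=10_000):
    """Hypothetical helper: keep the max_context (theta, x) pairs whose x is
    nearest to x_o in Euclidean distance."""
    if xs.shape[0] <= max_context:
        return thetas, xs
    # x_o may have shape (d,) or (1, d); broadcasting handles both
    dists = torch.linalg.vector_norm(xs - x_o, dim=-1)
    nearest = torch.topk(dists, max_context, largest=False).indices
    return thetas[nearest], xs[nearest]


# Toy usage: 1-D statistics 0, 1, ..., 49; keep the 5 closest to x_o = 20.2
thetas = torch.arange(50.0).unsqueeze(1) * 10
xs = torch.arange(50.0).unsqueeze(1)
ctx_thetas, ctx_xs = filter_context(thetas, xs, torch.tensor([20.2]), max_context=5)
print(sorted(ctx_xs.squeeze(1).tolist()))  # → [18.0, 19.0, 20.0, 21.0, 22.0]
```

Because the selection depends on x_o, a filtered context is tailored to one observation, so a posterior built from it is no longer fully amortized over all possible observations.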

### Your Task: Run NPE-PFN

Complete the code below to run inference with NPE-PFN.

**Hints**:
- Create the posterior with `TabPFN_Based_NPE_PFN(prior=prior)`
- Use `.append_simulations(theta, x)` to provide context (same data as standard NPE)
- Sample with `.sample((num_samples,), x=x_o)`

In [None]:
# SOLUTION: NPE-PFN inference
print("Running NPE-PFN (no training needed!)...")

# TODO for students: Create NPE-PFN posterior and append simulations
# Hint: Use TabPFN_Based_NPE_PFN(prior=prior)

npe_pfn_posterior = TabPFN_Based_NPE_PFN(prior=prior)
npe_pfn_posterior.append_simulations(theta, x)

# TODO for students: Sample from the posterior
# Hint: No training needed! Just call .sample()

samples_pfn = npe_pfn_posterior.sample((10_000,), x=x_o)

print(f"\nNPE-PFN done! Posterior samples shape: {samples_pfn.shape}")
print("Notice: No training step was needed!")

---
## Part 3: TSNPE-PFN (Sequential Version)

Standard NPE-PFN uses simulations from the prior. But what if we want to focus on a specific observation?

**TSNPE-PFN** (Truncated Sequential NPE-PFN) is a sequential variant that:
1. Starts with simulations from the prior
2. Uses the current posterior estimate to identify which prior regions are consistent with x_o
3. Draws new simulations only from that truncated region of the prior
4. Iterates to refine the posterior

This mirrors TSNPE (Truncated Sequential NPE), with the trained flow replaced by the foundation model.
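
The truncation step is the heart of the sequential rounds. The sketch below assumes TSNPE-PFN follows the usual TSNPE recipe: new parameters are drawn from the prior but kept only if they fall inside the highest-probability region (HPR) of the current posterior estimate, with the region's boundary set at a low quantile of the estimate's log-density. The function name `sample_truncated_prior`, the 1e-4 quantile, and the Gaussian playing the role of the current posterior estimate are all illustrative choices; in this notebook, `run_tsnpe_pfn` takes care of the rounds for you.

```python
import torch
from torch.distributions import Independent, Normal, Uniform


def sample_truncated_prior(prior, posterior_estimate, num_samples, quantile=1e-4, num_ref=10_000):
    """Rejection-sample the prior, keeping only draws inside the (1 - quantile)
    highest-probability region of the current posterior estimate."""
    # Empirical threshold: roughly `quantile` of posterior mass has lower log-density
    ref = posterior_estimate.sample((num_ref,))
    threshold = torch.quantile(posterior_estimate.log_prob(ref), quantile)

    accepted, num_accepted = [], 0
    while num_accepted < num_samples:
        candidates = prior.sample((num_samples,))
        keep = posterior_estimate.log_prob(candidates) >= threshold
        accepted.append(candidates[keep])
        num_accepted += int(keep.sum())
    return torch.cat(accepted)[:num_samples]


# Toy usage: a broad uniform prior and a narrow Gaussian "posterior estimate"
torch.manual_seed(0)
prior = Independent(Uniform(torch.full((2,), -5.0), torch.full((2,), 5.0)), 1)
posterior_estimate = Independent(Normal(torch.tensor([1.0, -1.0]), torch.tensor([0.3, 0.3])), 1)
theta_next_round = sample_truncated_prior(prior, posterior_estimate, num_samples=500)
print(theta_next_round.shape)  # → torch.Size([500, 2])
```

Sampling from a restricted prior, rather than from the posterior estimate itself, is what lets truncated methods skip the proposal corrections that classic SNPE needs.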

### The TSNPE-PFN Workflow

```python
# All-in-one function that handles the sequential rounds
tsnpe_pfn_posterior, tsnpe_pfn_samples = run_tsnpe_pfn(
    prior=prior,
    simulator=simulator,
    x_o=x_o,
    num_simulations_per_round=500,  # Simulations per sequential round
    num_rounds=3,                    # Number of sequential rounds
    num_posterior_samples=10_000,    # Final posterior samples
)
```

### Your Task: Run TSNPE-PFN

Complete the code below to run sequential inference with TSNPE-PFN.

**Hints**:
- Use `run_tsnpe_pfn()` with the prior, simulator, and observation
- The simulator should match the format we used: `lotka_volterra_simulator(theta, use_autocorrelation=True)`

In [None]:
# SOLUTION: TSNPE-PFN sequential inference
print("Running TSNPE-PFN (sequential, focused on x_o)...\n")

# Define simulator wrapper for TSNPE-PFN
def simulator(theta):
    return lotka_volterra_simulator(theta, use_autocorrelation=True)

# TODO for students: Run TSNPE-PFN
# Hint: Use run_tsnpe_pfn() with the prior, simulator, and observation

tsnpe_pfn_posterior, samples_tsnpe = run_tsnpe_pfn(
    prior=prior,
    simulator=simulator,
    x_o=x_o,
    num_simulations_per_round=500,
    num_rounds=3,
    num_posterior_samples=10_000,
)

print(f"\nTSNPE-PFN done! Posterior samples shape: {samples_tsnpe.shape}")

---
## Comparing All Methods

Now let's compare the posteriors from all three methods:
1. **Standard NPE**: Trained normalizing flow
2. **NPE-PFN**: Foundation model (no training)
3. **TSNPE-PFN**: Sequential foundation model

In [None]:
# Compare posteriors with pairplot
param_labels = [r"$\alpha$", r"$\beta$", r"$\delta$", r"$\gamma$"]
limits = [[0.05, 0.15], [0.01, 0.03], [0.005, 0.03], [0.005, 0.15]]

fig, axes = pairplot(
    [samples_npe, samples_pfn, samples_tsnpe],
    limits=limits,
    labels=param_labels,
    figsize=(10, 10),
    points=theta_o,
    points_colors="red",
    diag="kde",
)

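# Add a legend with one proxy handle per method
# (assumes pairplot colours the three sample sets with the default C0, C1, C2 cycle)
from matplotlib.lines import Line2D

handles = [Line2D([0], [0], color=c, linewidth=2) for c in ["C0", "C1", "C2"]]
fig.legend(
    handles,
    ["Standard NPE", "NPE-PFN", "TSNPE-PFN"],
    loc="upper right",
    bbox_to_anchor=(0.95, 0.95),
)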
plt.suptitle("Posterior Comparison: Standard NPE vs Foundation Models", y=1.02, fontsize=14)
plt.show()

In [None]:
# Compare posterior statistics
print("Parameter Recovery Comparison")
print("=" * 70)
print(f"{'Parameter':<10} {'True':<10} {'NPE':<15} {'NPE-PFN':<15} {'TSNPE-PFN':<15}")
print("-" * 70)

for i, (name, true_val) in enumerate(zip(["α", "β", "δ", "γ"], theta_o.flatten())):
    npe_mean = samples_npe[:, i].mean().item()
    npe_std = samples_npe[:, i].std().item()
    pfn_mean = samples_pfn[:, i].mean().item()
    pfn_std = samples_pfn[:, i].std().item()
    tsnpe_mean = samples_tsnpe[:, i].mean().item()
    tsnpe_std = samples_tsnpe[:, i].std().item()

    # Format each "mean±std" cell first so the columns line up with the header
    npe_cell = f"{npe_mean:.4f}±{npe_std:.3f}"
    pfn_cell = f"{pfn_mean:.4f}±{pfn_std:.3f}"
    tsnpe_cell = f"{tsnpe_mean:.4f}±{tsnpe_std:.3f}"
    print(f"{name:<10} {true_val.item():<10.4f} {npe_cell:<15} {pfn_cell:<15} {tsnpe_cell:<15}")

In [None]:
# Posterior Predictive Check for all methods
def plot_posterior_predictive_comparison(samples_list, labels, theta_o, ts_observed, time, n_samples=30):
    """Compare posterior predictive simulations for multiple methods."""
    fig, axes = plt.subplots(len(samples_list), 2, figsize=(14, 4*len(samples_list)))

    colors = ["C0", "C1", "C2"]

    for row, (samples, label, color) in enumerate(zip(samples_list, labels, colors)):
        indices = np.random.choice(len(samples), size=n_samples, replace=False)

        for idx in indices:
            theta_sample = samples[idx].numpy()
            ts_sample = simulate(theta_sample)
            axes[row, 0].plot(time, ts_sample[:, 0], color=color, alpha=0.2, linewidth=0.5)
            axes[row, 1].plot(time, ts_sample[:, 1], color=color, alpha=0.2, linewidth=0.5)

        # Plot ground truth
        axes[row, 0].plot(time, ts_observed[:, 0], color="black", linewidth=2, label="Observed")
        axes[row, 1].plot(time, ts_observed[:, 1], color="black", linewidth=2, label="Observed")

        axes[row, 0].set_ylabel("Population")
        axes[row, 0].set_title(f"Prey - {label}")
        axes[row, 0].legend()

        axes[row, 1].set_title(f"Predator - {label}")
        axes[row, 1].legend()

    axes[-1, 0].set_xlabel("Time (days)")
    axes[-1, 1].set_xlabel("Time (days)")

    plt.tight_layout()
    plt.show()

plot_posterior_predictive_comparison(
    [samples_npe, samples_pfn, samples_tsnpe],
    ["Standard NPE", "NPE-PFN", "TSNPE-PFN"],
    theta_o, ts_observed, time
)

---
## Summary

### Methods Comparison

| Method | Training | Simulations | Best For |
|--------|----------|-------------|----------|
| **Standard NPE** | Required (minutes) | Many (1000+) | Production, complex posteriors |
| **NPE-PFN** | None! | Moderate (100-10000) | Quick prototyping, iteration |
| **TSNPE-PFN** | None! | Fewer (focused) | Single observation, refinement |

### Key Takeaways

1. **Foundation models skip training**: NPE-PFN produces posteriors without a training step, using in-context learning (sampling still costs forward passes through the transformer)
2. **Trade-offs exist**: Foundation models may be less flexible than trained models for complex problems
3. **Sequential variants help**: TSNPE-PFN focuses simulations for better efficiency
4. **Great for prototyping**: Try NPE-PFN first, then train a custom model if needed

### The NPE-PFN Pattern

```python
from npe_pfn import TabPFN_Based_NPE_PFN, run_tsnpe_pfn

# Amortized (works for any x_o)
posterior = TabPFN_Based_NPE_PFN(prior=prior)
posterior.append_simulations(thetas, xs)
samples = posterior.sample((N,), x=x_o)

# Sequential (focused on one x_o)
posterior, samples = run_tsnpe_pfn(
    prior=prior, simulator=simulator, x_o=x_o,
    num_simulations_per_round=500, num_rounds=3
)
```

**Further reading**: [NPE-PFN paper](https://arxiv.org/abs/2407.20482) | [GitHub repository](https://github.com/mackelab/npe-pfn)

---
## Learning Goals

After this notebook, you should be able to:

- ✅ Explain what foundation models are and why they're useful for SBI
- ✅ Use NPE-PFN for posterior estimation without training
- ✅ Apply TSNPE-PFN for sequential, observation-focused inference
- ✅ Compare foundation models with standard NPE approaches