## 1. Problem Setting

Consider the inverse problem of recovering an unknown $u \in \mathbb{R}^d$ with prior $u\sim \rho$ from a noisy observation $y \in \mathbb{R}^k$:
$$y = G(u) + \eta, \quad \eta \sim \mathcal{N}(0, \Gamma)$$
where $G: \mathbb{R}^d \rightarrow \mathbb{R}^k$ is the observation operator.

### 1.1 Bayesian Formulation
$$\pi^y(u) = \frac{1}{Z} \ell(y|u) \rho(u)$$
* **Likelihood**: $\ell(y|u) \propto \exp \left( -\frac{1}{2} \|y - G(u)\|^2_{\Gamma^{-1}} \right)$
* **Prior**: $\rho(u) = \mathcal{N}(u; 0, I)$
* **Evidence**: $Z = \int \ell(y|u) \rho(u) du$

**GOAL:** Approximate $\pi^y(u)$.

### 1.2 Biochemical Oxygen Demand (BOD) Model
We estimate $u = [u_1, u_2]^\top \in \mathbb{R}^2$ from observations $y \in \mathbb{R}^5$. Physical parameters $A$ and $B$ are mapped from $u$ via the standard normal CDF $\Phi(\cdot)$ to enforce uniform priors $A \sim U(0.4, 1.2)$ and $B \sim U(0.01, 0.31)$:
\begin{align*}
A(u_1) &= 0.4 + 0.8\Phi(u_1) \\
B(u_2) &= 0.01 + 0.3\Phi(u_2)
\end{align*}

The forward operator $G(u)$ is defined by the BOD equation evaluated at $t \in \{1, 2, 3, 4, 5\}$:
$$\mathfrak{B}(t; u) = A(u_1) \left( 1 - \exp(-B(u_2)t) \right)$$

The noise covariance is $\Gamma = 10^{-3} I_5$.

### 1.3 Prior Transformation and Forward Operator
The Gaussian prior $\rho(u) = \mathcal{N}(u; 0, I_2)$ induces uniform distributions on the physical parameters $A$ and $B$ through the Probability Integral Transform. Since $\Phi(u_i) \sim U(0, 1)$ for $u_i \sim \mathcal{N}(0, 1)$, the affine transformations ensure:
* $A = 0.4 + 0.8\Phi(u_1) \sim U(0.4, 1.2)$
* $B = 0.01 + 0.3\Phi(u_2) \sim U(0.01, 0.31)$

The forward operator $G: \mathbb{R}^2 \rightarrow \mathbb{R}^5$ maps the latent vector $u$ to the observation space:
$$G(u) = \begin{bmatrix} \mathfrak{B}(1; u) \\ \mathfrak{B}(2; u) \\ \mathfrak{B}(3; u) \\ \mathfrak{B}(4; u) \\ \mathfrak{B}(5; u) \end{bmatrix} = \begin{bmatrix} A(u_1)(1 - e^{-B(u_2) \cdot 1}) \\ A(u_1)(1 - e^{-B(u_2) \cdot 2}) \\ A(u_1)(1 - e^{-B(u_2) \cdot 3}) \\ A(u_1)(1 - e^{-B(u_2) \cdot 4}) \\ A(u_1)(1 - e^{-B(u_2) \cdot 5}) \end{bmatrix}$$

This formulation allows for unconstrained optimization in $u$-space while satisfying the physical constraints on $A$ and $B$.
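
As a quick numerical check of the mapping above, the following minimal sketch evaluates $G(u)$ using only the Python standard library. The helper names `std_normal_cdf` and `forward_operator` are illustrative and are not part of the notebook's `BODPosterior` class; $\Phi$ is computed from the error function.

```python
import math

def std_normal_cdf(x: float) -> float:
    """Standard normal CDF Phi(x), written with the error function."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def forward_operator(u1: float, u2: float, times=(1, 2, 3, 4, 5)):
    """Evaluate G(u) = [B(t; u) for t in times] for the BOD model."""
    A = 0.4 + 0.8 * std_normal_cdf(u1)   # lies in [0.4, 1.2]
    B = 0.01 + 0.3 * std_normal_cdf(u2)  # lies in [0.01, 0.31]
    return [A * (1.0 - math.exp(-B * t)) for t in times]

# At the prior mean u = (0, 0), Phi(0) = 0.5, so A = 0.8 and B = 0.16.
g_at_origin = forward_operator(0.0, 0.0)

# Extreme latent values still map to admissible physical parameters.
g_extreme = forward_operator(-10.0, 10.0)

print(g_at_origin)
print(g_extreme)
```

Because $\Phi$ is bounded in $[0, 1]$, any real-valued $u$ yields admissible $A$ and $B$, which is what makes unconstrained optimization in $u$-space possible.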

In [None]:
import torch
import torch.distributions as distributions
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np

class BODPosterior:
    def __init__(self, y=None):
        self.dim = 2
        self.obs_time = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
        self.gamma_var = torch.tensor(1e-3)
        self.prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
        self.lik = distributions.MultivariateNormal(torch.zeros(5), torch.eye(5) * self.gamma_var)
        self.y = y

    def forward_model(self, x):
        # x: [N, 2]
        A = 0.4 + 0.8 * torch.distributions.Normal(0, 1).cdf(x[:, 0])
        B = 0.01 + 0.3 * torch.distributions.Normal(0, 1).cdf(x[:, 1])
        return A.unsqueeze(1) * (1 - torch.exp(-B.unsqueeze(1) * self.obs_time))

    def log_prob(self, x):
        # x: [N, 2]
        log_lik = self.lik.log_prob(self.y - self.forward_model(x))
        log_prior = self.prior.log_prob(x)
        return log_lik + log_prior

    def plot_posterior(self, ax=None, domain_x1=[-1.5, 1.5], domain_x2=[-0.75, 2.25], n_grid=100):
        if ax is None:
            fig, ax = plt.subplots(figsize=(6, 5))
        else:
            fig = ax.get_figure()

        x1 = torch.linspace(domain_x1[0], domain_x1[1], n_grid)
        x2 = torch.linspace(domain_x2[0], domain_x2[1], n_grid)
        xg = torch.meshgrid(x1, x2, indexing='ij')
        xx = torch.stack([xg[0].flatten(), xg[1].flatten()], dim=1)

        with torch.no_grad():
            density = torch.exp(self.log_prob(xx)).reshape(n_grid, n_grid).numpy()

        # Colormap setup
        threshold = np.percentile(density, 5)
        colors_list = plt.get_cmap('Oranges')(np.linspace(0, 0.8, 256))
        for i in range(50):
            alpha = i / 50.0
            colors_list[i] = (1 - alpha) * np.array([1, 1, 1, 1]) + alpha * colors_list[i]
        new_cmap = mcolors.LinearSegmentedColormap.from_list('WhiteOrange', colors_list)

        density_plot = np.where(density < threshold, 0, density)
        cf = ax.contourf(xg[0].numpy(), xg[1].numpy(), density_plot, cmap=new_cmap, levels=50, alpha=0.8)
        ax.set_xlabel("$u_1$"), ax.set_ylabel("$u_2$")
        ax.set_xlim(domain_x1), ax.set_ylim(domain_x2)
        return fig, ax

# Execution
y_obs = torch.tensor([0.1615, 0.1868, 0.3949, 0.3728, 0.4177])
pi = BODPosterior(y=y_obs)
pi.plot_posterior()

## 2. MAP Estimation

Maximum A Posteriori (MAP) estimation identifies the mode of the posterior distribution by solving a deterministic optimization problem:
$$u_{\text{MAP}} = \arg\max_{u \in \mathbb{R}^d} \pi^y(u)$$

By taking the negative logarithm and discarding terms independent of $u$ (the evidence $Z$ and Gaussian normalization constants), the maximization is equivalent to:
\begin{align*}
u_{\text{MAP}} &= \arg\min_{u} \left( -\log \ell(y|u) - \log \rho(u) \right) \\
&= \arg\min_{u} \left( \frac{1}{2} \| y - G(u) \|^2_{\Gamma^{-1}} + \frac{1}{2} \| u \|^2_2 \right)
\end{align*}

The MAP loss function is defined as:
$$\mathcal{L}_{\text{MAP}}(u) = \frac{1}{2} \| y - G(u) \|^2_{\Gamma^{-1}} + \frac{1}{2} \| u \|^2_2.$$
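
To make the two terms of $\mathcal{L}_{\text{MAP}}$ concrete, here is a minimal standard-library sketch that evaluates the loss for the BOD model and locates an approximate minimizer by a coarse grid search. The grid search is only an illustrative stand-in for a gradient-based optimizer, and the helper names (`phi`, `forward`, `map_loss`) are assumptions made for this example.

```python
import math

Y_OBS = [0.1615, 0.1868, 0.3949, 0.3728, 0.4177]  # observations used in this notebook
GAMMA_VAR = 1e-3                                   # Gamma = 1e-3 * I_5

def phi(x):
    """Standard normal CDF."""
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def forward(u):
    """BOD forward operator G(u) at t = 1, ..., 5."""
    A = 0.4 + 0.8 * phi(u[0])
    B = 0.01 + 0.3 * phi(u[1])
    return [A * (1.0 - math.exp(-B * t)) for t in (1, 2, 3, 4, 5)]

def map_loss(u):
    """0.5 * ||y - G(u)||^2 / gamma_var + 0.5 * ||u||^2."""
    misfit = sum((y - g) ** 2 for y, g in zip(Y_OBS, forward(u))) / GAMMA_VAR
    penalty = u[0] ** 2 + u[1] ** 2
    return 0.5 * misfit + 0.5 * penalty

# Coarse grid over the window [-1.5, 1.5] x [-0.75, 2.25] with spacing 0.05.
grid = [(-1.5 + 0.05 * i, -0.75 + 0.05 * j) for i in range(61) for j in range(61)]
u_best = min(grid, key=map_loss)

print("approximate MAP point on the grid:", u_best)
print("loss at that point:", map_loss(u_best))
```

Since $\Gamma^{-1} = 10^{3} I_5$, the data-misfit term is weighted a thousand times more heavily than the squared residuals themselves, so the prior term mainly matters in directions the data constrain weakly.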


In [None]:
import torch
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, display
import ipywidgets as widgets

def get_map_loss(u, obj):
    # Data misfit + Tikhonov regularization
    diff = obj.y - obj.forward_model(u.unsqueeze(0)).squeeze(0)
    return 0.5 * torch.sum(diff**2) / obj.gamma_var + 0.5 * torch.sum(u**2)

def run_map_optimization(u1_init=-1.2, u2_init=1.5, lr=0.1, steps=50, show_animation=True):
    # Re-initialize parameter and optimizer
    u_param = torch.tensor([u1_init, u2_init], requires_grad=True)
    optimizer = optim.Adam([u_param], lr=lr)

    history = []
    for _ in range(steps):
        history.append(u_param.detach().clone().numpy())
        optimizer.zero_grad()
        loss = get_map_loss(u_param, pi)
        loss.backward()
        optimizer.step()
    history = np.array(history)

    # 1. Plot Key Frames
    key_steps = [0, max(1, steps//10), max(1, steps//2), steps-1]
    fig_key, axes = plt.subplots(1, 4, figsize=(20, 4))
    for idx, step in enumerate(key_steps):
        pi.plot_posterior(ax=axes[idx])
        axes[idx].plot(history[:step+1, 0], history[:step+1, 1], 'k--', alpha=0.6)
        axes[idx].scatter(history[step, 0], history[step, 1], c='red', s=40)
        axes[idx].set_title(f"Step {step}")
    plt.show()

    # 2. Plot Animation (Optional)
    if show_animation:
        fig, ax = plt.subplots(figsize=(6, 5))
        def update(frame):
            ax.clear()
            pi.plot_posterior(ax=ax)
            ax.plot(history[:frame+1, 0], history[:frame+1, 1], 'k--', lw=1)
            # Current iterate: x-coordinate is u_1 (column 0), y-coordinate is u_2 (column 1)
            ax.scatter(history[frame, 0], history[frame, 1], c='red', edgecolors='k', s=50, zorder=5)
            ax.set_title(f"MAP Iteration: {frame}")

        ani = FuncAnimation(fig, update, frames=len(history), interval=50)
        plt.close()
        return HTML(ani.to_jshtml())

# UI Components setup
u1_input = widgets.FloatText(value=-1.2, description='u1:')
u2_input = widgets.FloatText(value=1.5, description='u2:')
lr_input = widgets.FloatLogSlider(value=0.1, base=10, min=-3, max=0, step=0.1, description='LR:')
steps_input = widgets.IntSlider(value=50, min=10, max=200, step=10, description='Steps:')
anim_checkbox = widgets.Checkbox(value=False, description='Show Animation')

# Layout: Horizontal box for numerical/slider inputs
ui_top = widgets.HBox([u1_input, u2_input, lr_input, steps_input])

# Output area that the "Run Optimization" button writes into.
# The optimization runs only on a button click, not on every widget change.
out = widgets.Output()

# Display UI
run_button = widgets.Button(description="Run Optimization", button_style='primary')

def on_button_clicked(b):
    with out:
        out.clear_output(wait=True)
        result = run_map_optimization(u1_input.value, u2_input.value, lr_input.value, steps_input.value, anim_checkbox.value)
        if result is not None:  # an animation is returned only when it was requested
            display(result)

run_button.on_click(on_button_clicked)
display(ui_top, anim_checkbox, run_button, out)

## 3. Gaussian Variational Inference (VI)

### 3.1 The Variational Objective
Minimizing $D_{\text{KL}}(q_\theta \| \pi^y)$ is equivalent to minimizing the variational loss $\mathcal{L}_{\text{VI}}(\theta)$ (the negative Evidence Lower Bound):
\begin{align*}
\mathcal{L}_{\text{VI}}(\theta) &= \mathbb{E}_{u \sim q_\theta} \left[ \log q_\theta(u) - \log \left( \ell(y|u)\rho(u) \right) \right] \\
&= \mathbb{E}_{u \sim q_\theta} \left[ \mathcal{L}_{\text{MAP}}(u) + \log q_\theta(u) \right] + C
\end{align*}
where $C$ contains terms independent of $\theta$.

### 3.2 Gaussian Variational Families
We define $q_\theta = \mathcal{N}(m, \Sigma)$ with $\theta = (m, L)$. To ensure $\Sigma$ is positive definite, we optimize a lower-triangular factor $L$ with positive diagonal entries such that $\Sigma = LL^\top$.

* **Mean-Field Approximation**: $\Sigma$ is restricted to a diagonal matrix, where $L = \text{diag}(\sigma_1, \sigma_2)$.
* **Full-Rank Gaussian**: $L$ is a general lower-triangular matrix, allowing for posterior correlations.
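
For this Gaussian family the entropy term of $\mathcal{L}_{\text{VI}}$ is available in closed form, which makes the role of $L$ explicit. Since $\det \Sigma = \prod_{i} L_{ii}^2$ with $L_{ii} > 0$:
\begin{align*}
\mathbb{E}_{u \sim q_\theta} \left[ \log q_\theta(u) \right] &= -\frac{d}{2} \log(2\pi e) - \frac{1}{2} \log \det \Sigma \\
&= -\frac{d}{2} \log(2\pi e) - \sum_{i=1}^{d} \log L_{ii}
\end{align*}
Substituting into the variational loss gives
$$\mathcal{L}_{\text{VI}}(\theta) = \mathbb{E}_{u \sim q_\theta} \left[ \mathcal{L}_{\text{MAP}}(u) \right] - \sum_{i=1}^{d} \log L_{ii} + C'$$
where $C'$ is independent of $\theta$. In the mean-field case the sum reduces to $\log \sigma_1 + \log \sigma_2$. The first term pulls $q_\theta$ toward regions of low MAP loss, while the second penalizes collapse of the covariance.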

### 3.3 Reparameterization Trick
We express $u$ as a transformation of a standard normal random variable $\epsilon \sim \mathcal{N}(0, I_d)$:
$$u = m + L\epsilon.$$
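
The following minimal sketch illustrates the reparameterization with the Python standard library only: samples $u = m + L\epsilon$ are drawn and their empirical covariance is compared with $LL^\top$. The values of `m` and `L` are arbitrary illustrative choices, not fitted quantities.

```python
import random

random.seed(0)

m = [0.5, -1.0]                # illustrative mean (assumption)
L = [[1.0, 0.0], [0.5, 0.8]]   # lower-triangular factor with positive diagonal (assumption)

def reparameterize(m, L, eps):
    """Return u = m + L @ eps for a 2x2 lower-triangular L."""
    return [m[0] + L[0][0] * eps[0],
            m[1] + L[1][0] * eps[0] + L[1][1] * eps[1]]

n = 20000
samples = [reparameterize(m, L, [random.gauss(0, 1), random.gauss(0, 1)])
           for _ in range(n)]

# Empirical mean and covariance of the transformed samples.
mean = [sum(s[k] for s in samples) / n for k in range(2)]

def cov(i, j):
    return sum((s[i] - mean[i]) * (s[j] - mean[j]) for s in samples) / (n - 1)

sample_cov = [[cov(0, 0), cov(0, 1)], [cov(1, 0), cov(1, 1)]]
target_cov = [[1.0, 0.5], [0.5, 0.89]]   # L @ L.T computed by hand

print("sample mean:", mean)
print("sample cov :", sample_cov)
print("L L^T      :", target_cov)
```

Because the randomness enters only through $\epsilon$, the sample $u$ is a differentiable function of $(m, L)$, which is what allows gradients of a Monte Carlo loss estimate to flow back to the variational parameters.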

In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from IPython.display import HTML, display
import ipywidgets as widgets
from tqdm.auto import tqdm
from matplotlib.animation import ArtistAnimation

# --- Target Definition ---

class BODPosterior:
    def __init__(self, y=None):
        self.dim = 2
        self.obs_time = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32)
        self.gamma_var = torch.tensor(1e-3)
        self.prior = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
        self.lik = dist.MultivariateNormal(torch.zeros(5), torch.eye(5) * self.gamma_var)
        self.y = y

    def forward_model(self, x):
        A = 0.4 + 0.8 * torch.distributions.Normal(0, 1).cdf(x[:, 0])
        B = 0.01 + 0.3 * torch.distributions.Normal(0, 1).cdf(x[:, 1])
        return A.unsqueeze(1) * (1 - torch.exp(-B.unsqueeze(1) * self.obs_time))

    def log_prob(self, x):
        log_lik = self.lik.log_prob(self.y - self.forward_model(x))
        log_prior = self.prior.log_prob(x)
        return log_lik + log_prior

    def plot_posterior(self, ax=None, domain_x1=[-1.5, 1.5], domain_x2=[-0.75, 2.25], n_grid=100):
        if ax is None:
            fig, ax = plt.subplots(figsize=(6, 5))
        else:
            fig = ax.get_figure()

        x1 = torch.linspace(domain_x1[0], domain_x1[1], n_grid)
        x2 = torch.linspace(domain_x2[0], domain_x2[1], n_grid)
        xg = torch.meshgrid(x1, x2, indexing='ij')
        xx = torch.stack([xg[0].flatten(), xg[1].flatten()], dim=1)

        with torch.no_grad():
            density = torch.exp(self.log_prob(xx)).reshape(n_grid, n_grid).numpy()

        threshold = np.percentile(density, 5)
        colors_list = plt.get_cmap('Oranges')(np.linspace(0, 0.8, 256))
        for i in range(50):
            alpha = i / 50.0
            colors_list[i] = (1 - alpha) * np.array([1, 1, 1, 1]) + alpha * colors_list[i]
        new_cmap = mcolors.LinearSegmentedColormap.from_list('WhiteOrange', colors_list)

        density_plot = np.where(density < threshold, 0, density)
        ax.contourf(xg[0].numpy(), xg[1].numpy(), density_plot, cmap=new_cmap, levels=50, alpha=0.8)
        ax.set_xlabel("$u_1$"), ax.set_ylabel("$u_2$")
        ax.set_xlim(domain_x1), ax.set_ylim(domain_x2)
        return fig, ax

# --- Variational Models ---

class MeanFieldGaussianApproximation(nn.Module):
    def __init__(self, pi_obj, init_m):
        super().__init__()
        self.pi = pi_obj
        self.m = nn.Parameter(init_m.clone())
        self.log_sigma = nn.Parameter(torch.zeros(2))

    def get_distribution(self):
        return dist.MultivariateNormal(self.m, scale_tril=torch.diag(torch.exp(self.log_sigma)))

    def sample(self, batch_size):
        return self.get_distribution().rsample((batch_size,))

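    def neg_log_prob(self, x):
        # Per-sample integrand of the variational loss: log q(x) - log(l(y|x) * rho(x)).
        # Its mean over samples x ~ q estimates the negative ELBO.
        return self.get_distribution().log_prob(x) - self.pi.log_prob(x)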

class FullGaussianApproximation(nn.Module):
    def __init__(self, pi_obj, init_m):
        super().__init__()
        self.pi = pi_obj
        self.m = nn.Parameter(init_m.clone())
        self.l_params = nn.Parameter(torch.zeros(3))

    def get_distribution(self):
        L = torch.zeros((2, 2), device=self.m.device)
        L[0, 0] = torch.exp(self.l_params[0])
        L[1, 0] = self.l_params[1]
        L[1, 1] = torch.exp(self.l_params[2])
        return dist.MultivariateNormal(self.m, scale_tril=L)

    def sample(self, batch_size):
        return self.get_distribution().rsample((batch_size,))

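    def neg_log_prob(self, x):
        # Per-sample integrand of the variational loss: log q(x) - log(l(y|x) * rho(x)).
        # Its mean over samples x ~ q estimates the negative ELBO.
        return self.get_distribution().log_prob(x) - self.pi.log_prob(x)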

# --- Main Execution ---

def run_vi_ml_style(u1_init=0.0, u2_init=0.0, cov_type='full',
                    lr=1e-1, num_epochs=100, batch_size=64, show_anim=False):
    init_m = torch.tensor([u1_init, u2_init])
    record_freq = 5

    target = MeanFieldGaussianApproximation(pi, init_m) if cov_type == 'diagonal' else FullGaussianApproximation(pi, init_m)
    optimizer = optim.Adam(target.parameters(), lr=lr)

    milestones = [0, int(num_epochs * 0.25), int(num_epochs * 0.5), num_epochs - 1]
    fig_static, axes = plt.subplots(1, 4, figsize=(20, 5))

    # Grid for contours
    x1_line = torch.linspace(-1.5, 1.5, 100)
    x2_line = torch.linspace(-0.75, 2.25, 100)
    xg = torch.meshgrid(x1_line, x2_line, indexing='ij')
    xx = torch.stack([xg[0].flatten(), xg[1].flatten()], dim=1)

    frames = []
    if show_anim:
        fig_anim, ax_anim = plt.subplots(figsize=(6, 5))
        pi.plot_posterior(ax=ax_anim)
        plt.close(fig_anim)

    pbar = tqdm(range(num_epochs), desc="Training VI")
    plot_idx = 0

    for epoch in pbar:
        if epoch in milestones:
            ax = axes[plot_idx]
            pi.plot_posterior(ax=ax)
            with torch.no_grad():
                xs = target.sample(batch_size).cpu().numpy()
                ax.plot(xs[:, 0], xs[:, 1], '.b', markersize=6, alpha=0.6)
                log_px = target.get_distribution().log_prob(xx)
                px = torch.exp(log_px).reshape(100, 100).numpy()
                ax.contour(xg[0].numpy(), xg[1].numpy(), px, colors='black', levels=5)
            ax.set_title(f"Epoch {epoch}")
            plot_idx += 1

        if show_anim and (epoch % record_freq == 0 or epoch == num_epochs - 1):
            with torch.no_grad():
                xs = target.sample(batch_size).cpu().numpy()
                px = torch.exp(target.get_distribution().log_prob(xx)).reshape(100, 100).numpy()

                pts, = ax_anim.plot(xs[:, 0], xs[:, 1], '.b', markersize=6, alpha=0.6)
                cnt = ax_anim.contour(xg[0].numpy(), xg[1].numpy(), px, colors='black', levels=5)
                txt = ax_anim.text(0.05, 0.92, f'Epoch: {epoch}', transform=ax_anim.transAxes, fontweight='bold')

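                # Collect this frame's artists. Older Matplotlib exposes the contour
                # artists via cnt.collections; where that attribute is absent, the
                # ContourSet itself is used as the artist.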
                frames.append([pts, txt] + list(cnt.collections if hasattr(cnt, 'collections') else [cnt]))

        # Train
        x = target.sample(batch_size)
        loss = target.neg_log_prob(x).mean()
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    plt.figure(fig_static.number); plt.tight_layout(); plt.show()

    if show_anim and frames:
        ani = ArtistAnimation(fig_anim, frames, interval=100, blit=False)
        display(HTML(ani.to_jshtml()))

# --- UI Setup ---
y_obs = torch.tensor([0.1615, 0.1868, 0.3949, 0.3728, 0.4177])
pi = BODPosterior(y=y_obs)

u1_in = widgets.FloatText(value=0.0, description='u1 init:')
u2_in = widgets.FloatText(value=0.0, description='u2 init:')
cov_in = widgets.Dropdown(options=[('Diagonal', 'diagonal'), ('Full', 'full')], value='full', description='Type:')
lr_in = widgets.FloatLogSlider(value=0.1, base=10, min=-3, max=0, step=0.1, description='LR:')
epochs_in = widgets.IntSlider(value=100, min=10, max=500, step=10, description='Epochs:')
bs_in = widgets.IntSlider(value=64, min=16, max=512, step=16, description='Batch:')
anim_in = widgets.Checkbox(value=False, description='Show Animation')

ui_box = widgets.VBox([widgets.HBox([u1_in, u2_in, cov_in]), widgets.HBox([lr_in, epochs_in, bs_in, anim_in])])
run_btn = widgets.Button(description="Run VI Gaussian", button_style='success')
out = widgets.Output()

def on_click(b):
    with out:
        out.clear_output(wait=True)
        run_vi_ml_style(u1_in.value, u2_in.value, cov_in.value, lr_in.value, epochs_in.value, bs_in.value, anim_in.value)

run_btn.on_click(on_click)
display(ui_box, run_btn, out)

## 4. Variational Inference via Transport Maps

Instead of prescribing a fixed density family, we learn a **transport map** $T(\cdot; \theta): \mathbb{R}^d \to \mathbb{R}^d$ that pushes forward a simple reference distribution $\varrho$ (typically $\mathcal{N}(0, I)$) to an approximation of the target posterior $\pi^y$.

### 4.1 Variational Formulation
The push-forward density $q_\theta = T_\#\varrho$ is defined via the change of variables formula:
$$q_\theta(u) = \varrho(T^{-1}(u; \theta)) \left| \det \nabla_u T^{-1}(u; \theta) \right|$$
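
The change of variables formula can be checked in the simplest possible setting. The sketch below uses a one-dimensional affine map $T(z) = az + b$ (an illustrative assumption, not the map used in this notebook), for which the push-forward of $\mathcal{N}(0, 1)$ is known to be $\mathcal{N}(b, a^2)$.

```python
import math

def normal_pdf(x, mean=0.0, std=1.0):
    """Density of N(mean, std^2) at x."""
    z = (x - mean) / std
    return math.exp(-0.5 * z * z) / (std * math.sqrt(2.0 * math.pi))

a, b = 2.0, -1.0   # illustrative map T(z) = a * z + b (assumption)

def T_inv(u):
    return (u - b) / a

def pushforward_density(u):
    # q(u) = ref(T^{-1}(u)) * |d T^{-1} / du|, where d T^{-1} / du = 1 / a
    return normal_pdf(T_inv(u)) * abs(1.0 / a)

test_points = [-4.0 + 0.5 * i for i in range(17)]
max_err = max(abs(pushforward_density(u) - normal_pdf(u, mean=b, std=abs(a)))
              for u in test_points)

print("largest deviation from N(b, a^2):", max_err)
```

The same two ingredients, an evaluable inverse and its Jacobian determinant, are what a flow architecture must provide in higher dimensions.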

Minimizing the KL divergence $D_{KL}(q_\theta \| \pi^y)$ is equivalent to minimizing the divergence in the reference space:
$$\min_\theta D_{KL}(\varrho \| (T^{-1})_\# \pi^y)$$
where $(T^{-1})_\# \pi^y$ is the pull-back of the posterior to the reference space. Dropping the terms that do not depend on $\theta$ (the entropy of $\varrho$ and the evidence $Z$) leads to the following loss function:
$$\mathcal{L}_{\text{Flow}}(\theta) = -\mathbb{E}_{z \sim \varrho} \left[ \log \rho(T(z; \theta)) + \log \ell(y | T(z; \theta)) + \log \left| \det \nabla_z T(z; \theta) \right| \right]$$

Given i.i.d. samples $z^{(1)}, \dots, z^{(N)} \sim \varrho$, the empirical objective is:
$$\widehat{\theta} = \arg\min_{\theta} \frac{1}{N} \sum_{i=1}^N \left[ \mathcal{L}_{\text{MAP}}(T(z^{(i)}; \theta)) - \log \left| \det \nabla_z T(z^{(i)}; \theta) \right| \right]$$
where $\mathcal{L}_{\text{MAP}}$ is the negative log-posterior (data misfit + regularization) defined in Section 2.
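
The empirical objective can be sketched without any neural network by substituting a very simple map. The example below uses a diagonal affine map $T(z) = m + \sigma \odot z$, for which $\log |\det \nabla_z T| = \sum_i \log \sigma_i$; this stand-in is an assumption made for illustration and is not the Real NVP map introduced next. All helper names are hypothetical.

```python
import math
import random

Y_OBS = [0.1615, 0.1868, 0.3949, 0.3728, 0.4177]  # observations used in this notebook
GAMMA_VAR = 1e-3

def phi(x):
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def map_loss(u):
    """L_MAP(u): data misfit plus prior penalty for the BOD model."""
    A = 0.4 + 0.8 * phi(u[0])
    B = 0.01 + 0.3 * phi(u[1])
    g = [A * (1.0 - math.exp(-B * t)) for t in (1, 2, 3, 4, 5)]
    misfit = sum((y - gi) ** 2 for y, gi in zip(Y_OBS, g)) / GAMMA_VAR
    return 0.5 * misfit + 0.5 * (u[0] ** 2 + u[1] ** 2)

def affine_map(z, m, sigma):
    """T(z) = m + sigma * z (elementwise) and its log|det grad T|."""
    u = [m[k] + sigma[k] * z[k] for k in range(2)]
    log_det = sum(math.log(s) for s in sigma)
    return u, log_det

def empirical_flow_loss(m, sigma, n=256, seed=0):
    """Monte Carlo average of L_MAP(T(z)) - log|det grad T(z)| over z ~ N(0, I)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        z = [rng.gauss(0, 1), rng.gauss(0, 1)]
        u, log_det = affine_map(z, m, sigma)
        total += map_loss(u) - log_det
    return total / n

loss_value = empirical_flow_loss(m=[-0.9, 0.8], sigma=[0.3, 0.3])
print("empirical loss:", loss_value)
```

With this particular map the objective coincides with the mean-field Gaussian loss of Section 3; a more expressive $T$ changes only how `affine_map` is implemented.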

### 4.2 Model Architecture: Real NVP
We use **Normalizing Flows** built from real-valued non-volume preserving (Real NVP) layers, so that $T$ is invertible and has a tractable Jacobian determinant. The map is constructed as a composition of $K$ affine coupling layers:
$$T = L_K \circ L_{K-1} \circ \cdots \circ L_1$$
where each $L_k$ is an instance of the coupling layers defined below.

#### Affine Coupling Layers
For a 2D input $z = [z_1, z_2]^\top$, we alternate between two types of layers.

**Type 1 ($L^{(1)}$):** Updates $z_2$ based on $z_1$
\begin{equation*}
L^{(1)}(z) = \begin{bmatrix} z_1 \\ z_2 \exp(s(z_1)) + t(z_1) \end{bmatrix}
\end{equation*}

**Type 2 ($L^{(2)}$):** Updates $z_1$ based on $z_2$
\begin{equation*}
L^{(2)}(z) = \begin{bmatrix} z_1 \exp(s(z_2)) + t(z_2) \\ z_2 \end{bmatrix}
\end{equation*}

These can be unified using a binary mask $m \in \{0, 1\}^2$ and the **Hadamard product** $\odot$:
$$L(z; m) = m \odot z + (1 - m) \odot \left( z \odot \exp(s(m \odot z)) + t(m \odot z) \right)$$
where $s(\cdot)$ and $t(\cdot)$ are scale and translation neural networks. For $L^{(1)}$, $m=[1, 0]^\top$; for $L^{(2)}$, $m=[0, 1]^\top$. In practice, the sequence $L_1, \dots, L_K$ is formed by alternating these masks.



#### Triangular Jacobian and Log-Determinant
The efficiency of Real NVP lies in the structure of its Jacobian matrix $\nabla_z L$. For $L^{(1)}$, the Jacobian is:
\begin{equation*}
\nabla_z L^{(1)} = \begin{bmatrix}
1 & 0 \\
\frac{\partial (L^{(1)}(z))_2}{\partial z_1} & \exp(s(z_1))
\end{bmatrix}
\end{equation*}
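Since the Jacobian is lower triangular (upper triangular for $L^{(2)}$), its determinant is the product of the diagonal elements. The log-determinant for a single layer is the sum over the components of the masked scale output:
$$\log \left| \det \nabla_z L \right| = \sum_{j=1}^{d} \left[ (1 - m) \odot s(m \odot z) \right]_j$$
For the total map $T$, the chain rule gives the total log-determinant as the sum of these per-layer terms, each evaluated at the input of its own layer.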

#### Inverse Mapping
The architecture allows for an exact, closed-form inverse. For $L^{(1)}$, the inverse is:
\begin{equation*}
(L^{(1)})^{-1}(z') = \begin{bmatrix} z'_1 \\ (z'_2 - t(z'_1)) \exp(-s(z'_1)) \end{bmatrix}
\end{equation*}
This ensures we can evaluate the approximate density $q_\theta$ and sample from the approximate posterior without iterative solvers.
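
The three properties above (forward map, log-determinant, closed-form inverse) can be verified on a single Type 1 layer. In this minimal sketch the functions `s` and `t` are simple hand-written stand-ins for the scale and translation networks, chosen only for illustration; the log-determinant is compared against a finite-difference Jacobian.

```python
import math

def s(z1):
    """Hypothetical scale function standing in for a neural network."""
    return math.tanh(z1)

def t(z1):
    """Hypothetical translation function standing in for a neural network."""
    return 0.5 * z1

def coupling_forward(z):
    """Type 1 layer: keep z1, transform z2. Returns (z', log|det J|)."""
    z1, z2 = z
    return [z1, z2 * math.exp(s(z1)) + t(z1)], s(z1)

def coupling_inverse(zp):
    """Closed-form inverse of the Type 1 layer."""
    z1, z2p = zp
    return [z1, (z2p - t(z1)) * math.exp(-s(z1))]

def numerical_log_det(z, h=1e-6):
    """log|det J| from central finite differences; cols[k][i] = d f_i / d z_k."""
    cols = []
    for k in range(2):
        z_plus, z_minus = list(z), list(z)
        z_plus[k] += h
        z_minus[k] -= h
        f_plus, _ = coupling_forward(z_plus)
        f_minus, _ = coupling_forward(z_minus)
        cols.append([(f_plus[i] - f_minus[i]) / (2.0 * h) for i in range(2)])
    det = cols[0][0] * cols[1][1] - cols[1][0] * cols[0][1]
    return math.log(abs(det))

z = [0.7, -1.3]
z_prime, log_det = coupling_forward(z)
z_back = coupling_inverse(z_prime)
roundtrip_error = max(abs(a - b) for a, b in zip(z, z_back))
log_det_error = abs(log_det - numerical_log_det(z))

print("round-trip error:", roundtrip_error)
print("log-det error   :", log_det_error)
```

Note that the inverse never requires inverting `s` or `t` themselves, which is why these functions can be arbitrary neural networks.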

In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from IPython.display import HTML, display
import ipywidgets as widgets
from tqdm.auto import tqdm
from matplotlib.animation import ArtistAnimation

# --- RealNVP Model Components ---

class RealNVP(nn.Module):
    """
    Implementation of RealNVP (Real-valued Non-Volume Preserving) flow.
    It learns a bijective mapping T: Z -> X between a simple reference
    distribution (Z) and a complex target posterior (X).
    """
    def __init__(self, nets, nett, masks, pi_obj, ref_dist):
        super(RealNVP, self).__init__()
        self.pi = pi_obj   # Target log-posterior
        self.ref = ref_dist # Reference distribution (usually Standard Gaussian)
        self.mask = nn.Parameter(masks, requires_grad=False)

        # Scaling (s) and Translation (t) networks for each coupling layer
        self.s = nn.ModuleList([nets() for _ in range(len(masks))])
        self.t = nn.ModuleList([nett() for _ in range(len(masks))])

    def T(self, z):
        """
        Forward transform: Reference Z -> Target X.
        Used for sampling: x = T(z).
        """
        log_det_J, x = z.new_zeros(z.shape[0]), z
        for i in range(len(self.s)):
            x_ = x * self.mask[i]
            # s and t only depend on masked dimensions
            s = self.s[i](x_) * (1 - self.mask[i])
            t = self.t[i](x_) * (1 - self.mask[i])
            # Masked coordinates pass through; the others are updated as x * exp(s) + t
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
            log_det_J += s.sum(dim=1)
        return x, log_det_J

    def Tinv(self, x):
        """
        Inverse transform: Target X -> Reference Z.
        Used for density estimation and KL divergence: z = T_inv(x).
        """
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.s))):
            z_ = self.mask[i] * z
            s = self.s[i](z_) * (1 - self.mask[i])
            t = self.t[i](z_) * (1 - self.mask[i])
            # Masked coordinates pass through; the others are inverted as (z - t) * exp(-s)
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_J -= s.sum(dim=1)
        return z, log_det_J

    def log_prob_transport(self, z):
        """
        Calculates log pi(T(z)) + log|det J_T(z)|.
        Maximizing this is equivalent to minimizing KL(q || pi).
        """
        x, logdet = self.T(z)
        return self.pi.log_prob(x) + logdet

    def approximate_log_prob(self, x):
        """
        Calculates the density of the push-forward distribution q(x).
        log q(x) = log ref(T_inv(x)) + log|det J_T_inv(x)|.
        """
        z, logdet = self.Tinv(x)
        return self.ref.log_prob(z) + logdet

    def sample(self, batch_size):
        """Generates samples from the variational distribution q(x)."""
        z = self.ref.sample((batch_size,))
        x, _ = self.T(z)
        return x

# --- Training Execution Function ---

def run_transport_vi(lr=1e-3, num_epochs=200, batch_size=64,
                     num_layers=6, hidden_dim=256, show_anim=False):
    """
    Executes the VI training loop and returns the trained RealNVP model.
    """
    # Setup Target and Reference
    ref = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))

    # Define Network Architectures
    nets = lambda: nn.Sequential(nn.Linear(2, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, 2), nn.Tanh())
    nett = lambda: nn.Sequential(nn.Linear(2, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, 2))

    # Create alternating binary masks
    mask_list = [[0, 1], [1, 0]] * (num_layers // 2)
    masks = torch.tensor(mask_list, dtype=torch.float32)

    model = RealNVP(nets, nett, masks, pi, ref)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Visualization setup
    milestones = [0, int(num_epochs * 0.25), int(num_epochs * 0.5), num_epochs - 1]
    fig_static, axes = plt.subplots(1, 4, figsize=(20, 5))

    x1_line = torch.linspace(-1.5, 1.5, 100)
    x2_line = torch.linspace(-0.75, 2.25, 100)
    xg = torch.meshgrid(x1_line, x2_line, indexing='ij')
    xx = torch.stack([xg[0].flatten(), xg[1].flatten()], dim=1)

    frames = []
    if show_anim:
        fig_anim, ax_anim = plt.subplots(figsize=(6, 5))
        pi.plot_posterior(ax=ax_anim)
        plt.close(fig_anim)

    pbar = tqdm(range(num_epochs), desc="Training Transport VI")
    plot_idx = 0

    for epoch in pbar:
        # Optimization
        z_samples = model.ref.sample((batch_size,))
        loss = -model.log_prob_transport(z_samples).mean()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # Plotting Static Milestones
        if epoch in milestones:
            ax = axes[plot_idx]
            pi.plot_posterior(ax=ax)
            with torch.no_grad():
                xs = model.sample(batch_size).cpu().numpy()
                ax.plot(xs[:, 0], xs[:, 1], '.b', markersize=4, alpha=0.5)
                log_qx = model.approximate_log_prob(xx)
                qx = torch.exp(log_qx).reshape(100, 100).numpy()
                ax.contour(xg[0].numpy(), xg[1].numpy(), qx, colors='black', levels=7)
            ax.set_title(f"Epoch {epoch}")
            plot_idx += 1

        # Recording Animation
        if show_anim and (epoch % 5 == 0 or epoch == num_epochs - 1):
            with torch.no_grad():
                xs = model.sample(batch_size).cpu().numpy()
                qx = torch.exp(model.approximate_log_prob(xx)).reshape(100, 100).numpy()
                pts, = ax_anim.plot(xs[:, 0], xs[:, 1], '.b', markersize=4, alpha=0.5)
                cnt = ax_anim.contour(xg[0].numpy(), xg[1].numpy(), qx, colors='black', levels=7)
                txt = ax_anim.text(0.05, 0.92, f'Epoch: {epoch}', transform=ax_anim.transAxes, fontweight='bold')
                frames.append([pts, txt] + list(cnt.collections if hasattr(cnt, 'collections') else [cnt]))

        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    plt.figure(fig_static.number); plt.tight_layout(); plt.show()

    if show_anim and frames:
        ani = ArtistAnimation(fig_anim, frames, interval=100, blit=False)
        display(HTML(ani.to_jshtml()))

    return model

# --- UI Setup ---

lr_in = widgets.FloatLogSlider(value=1e-3, base=10, min=-4, max=-1, step=0.1, description='LR:')
epochs_in = widgets.IntSlider(value=100, min=50, max=1000, step=50, description='Epochs:')
layers_in = widgets.IntSlider(value=6, min=2, max=12, step=2, description='Layers:')
dim_in = widgets.Dropdown(options=[64, 128, 256, 512], value=256, description='Hidden Dim:')
bs_in = widgets.IntSlider(value=64, min=32, max=256, step=32, description='Batch:')
anim_in = widgets.Checkbox(value=False, description='Show Animation')

ui_box = widgets.VBox([
    widgets.HBox([lr_in, epochs_in, bs_in]),
    widgets.HBox([layers_in, dim_in, anim_in])
])
run_btn = widgets.Button(description="Run Transport VI", button_style='info')
out = widgets.Output()

# Placeholder for the trained model in the global scope
flow = None

def on_click(b):
    global flow
    with out:
        out.clear_output(wait=True)
        # Train and capture the model output
        flow = run_transport_vi(lr_in.value, epochs_in.value, bs_in.value,
                                layers_in.value, dim_in.value, anim_in.value)

run_btn.on_click(on_click)
display(ui_box, run_btn, out)

In [None]:
# --- Posterior Sampling & Latent Mapping Test Block ---

def run_sampling_test(model, n_samples=300):
    """
    Test the trained flow model by visualizing samples and their mapping
    back to the latent space using the training observation.

    Enhanced smoothness for background fills.
    """
    pi_test = BODPosterior(y=y_obs)

    u1_range, u2_range = [-1.5, 1.5], [-0.75, 2.25]
    z_range = [-4, 4]

    # 1. High-resolution Grids for smoother backgrounds
    u_res, z_res = 200, 200 # Increased resolution
    u1 = torch.linspace(u1_range[0], u1_range[1], u_res)
    u2 = torch.linspace(u2_range[0], u2_range[1], u_res)
    ug1, ug2 = torch.meshgrid(u1, u2, indexing='ij')
    uu = torch.stack([ug1.flatten(), ug2.flatten()], dim=1)

    z1 = torch.linspace(z_range[0], z_range[1], z_res)
    z2 = torch.linspace(z_range[0], z_range[1], z_res)
    zg1, zg2 = torch.meshgrid(z1, z2, indexing='ij')
    zz = torch.stack([zg1.flatten(), zg2.flatten()], dim=1)

    # 2. Model Inference
    model.eval()
    with torch.no_grad():
        xs = model.sample(n_samples)
        zs, _ = model.Tinv(xs)

        log_qx = model.approximate_log_prob(uu)
        log_pi = pi_test.log_prob(uu)
        log_pz = model.ref.log_prob(zz)

    # 3. Plotting
    plt.figure(figsize=(14, 6))

    # --- Left Plot: Target Space (x-space) ---
    plt.subplot(1, 2, 1)
    # Background: Smooth true posterior
    pi_density = torch.exp(log_pi).reshape(u_res, u_res).numpy()
    plt.contourf(ug1.numpy(), ug2.numpy(), pi_density,
                 levels=100, cmap='Oranges', alpha=0.4) # High levels for smoothness

    # Contours: Model's learned density q(x)
    plt.contour(ug1.numpy(), ug2.numpy(), torch.exp(log_qx).reshape(u_res, u_res).numpy(),
                levels=8, colors='black', linewidths=0.8)

    # Points: Blue samples
    plt.plot(xs[:, 0], xs[:, 1], '.b', markersize=4, alpha=0.5, label=r'Samples $x \sim q(x)$')
    plt.title(r"Target Space: $x = T(z)$")
    plt.xlabel("$u_1$"); plt.ylabel("$u_2$"); plt.xlim(u1_range); plt.ylim(u2_range); plt.legend()

    # --- Right Plot: Latent Space (z-space) ---
    plt.subplot(1, 2, 2)

    pz_density = torch.exp(log_pz).reshape(z_res, z_res).numpy()
    max_pz = pz_density.max()

    # Ultra-smooth non-linear levels for the Greens background
    # 200 levels create a nearly continuous gradient
    levs_smooth = np.linspace(0, 1, 200)**1.5 * max_pz

    # Fill: Smooth Reference density N(0,I)
    plt.contourf(zg1.numpy(), zg2.numpy(), pz_density,
                 levels=levs_smooth, cmap='Greens', alpha=0.85)

    # Contours: Analytical Reference lines (fewer lines to avoid clutter)
    plt.contour(zg1.numpy(), zg2.numpy(), pz_density,
                levels=levs_smooth[::25], colors='black', linewidths=0.6, alpha=0.4)

    # Points: Blue samples pulled back via T_inv
    plt.plot(zs[:, 0], zs[:, 1], '.b', markersize=4, alpha=0.5, label=r'Latent $z = T^{-1}(x)$')

    plt.title(r"Latent Space: $z = T^{-1}(x)$")
    plt.xlabel("$z_1$"); plt.ylabel("$z_2$"); plt.xlim(z_range); plt.ylim(z_range); plt.legend()

    plt.tight_layout()
    plt.show()

# --- Execution ---
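# `flow` is initialized to None in the previous cell, so check for a trained model
# rather than for the mere existence of the name.
if globals().get('flow') is not None:
    run_sampling_test(flow, n_samples=400)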

## 5. Amortized Variational Inference

In the previous section, the transport map $T$ was optimized for a specific observation $y$. To avoid re-training the model for every new data point, we use **Amortized Variational Inference**. Here, we learn a single **conditional transport map** $T(z; y, \theta)$ that maps the reference distribution $\varrho$ to the posterior $\pi^y$ for any given $y$.

### 5.1 Conditional Variational Formulation
The objective is to minimize the expected KL divergence over the distribution of observations $p(y)$:
$$\min_\theta \mathbb{E}_{y \sim p(y)} \left[ D_{KL}(q_{\theta}(\cdot | y) \| \pi^y) \right]$$
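where $q_{\theta}(\cdot \mid y) = T(\cdot; y, \theta)_\# \varrho$ is the push-forward of the reference $\varrho$ under the conditional map. Up to additive terms that do not depend on $\theta$, the resulting loss function $\mathcal{L}_{\text{Amort}}(\theta)$ extends the previous formulation by incorporating $y$ as a contextual input: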
$$\mathcal{L}_{\text{Amort}}(\theta) = -\mathbb{E}_{y \sim p(y)} \mathbb{E}_{z \sim \varrho} \left[ \log \rho(T(z; y, \theta)) + \log \ell(y | T(z; y, \theta)) + \log \left| \det \nabla_z T(z; y, \theta) \right| \right]$$
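
The loss follows from the change-of-variables formula. The sketch below writes the evidence as $Z(y)$ to make its dependence on $y$ explicit (a notational addition; Section 1.1 writes $Z$). With $u = T(z; y, \theta)$ and $z \sim \varrho$:
\begin{align*}
\log q_\theta(u \mid y) &= \log \varrho(z) - \log \left| \det \nabla_z T(z; y, \theta) \right| \\
\log \pi^y(u) &= \log \ell(y \mid u) + \log \rho(u) - \log Z(y) \\
D_{KL}(q_{\theta}(\cdot \mid y) \| \pi^y) &= \mathbb{E}_{z \sim \varrho} \left[ \log q_\theta(u \mid y) - \log \pi^y(u) \right] \\
&= -\mathbb{E}_{z \sim \varrho} \left[ \log \rho(T(z; y, \theta)) + \log \ell(y \mid T(z; y, \theta)) + \log \left| \det \nabla_z T(z; y, \theta) \right| \right] + \mathbb{E}_{z \sim \varrho} \left[ \log \varrho(z) \right] + \log Z(y)
\end{align*}
The last two terms are independent of $\theta$, so taking the expectation over $y \sim p(y)$ and discarding them yields $\mathcal{L}_{\text{Amort}}(\theta)$.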

### 5.2 Conditional Real NVP Architecture
To amortize the inference, we modify the affine coupling layers $L^{(1)}$ and $L^{(2)}$ defined in Section 4.2. The key change is that the scale and translation networks, $s(\cdot)$ and $t(\cdot)$, now also take the observation $y$ as an input.

For a 2D input $z$, the **Conditional Type 1 layer** $L^{(1)}(z, y)$ is:
\begin{equation*}
L^{(1)}(z, y) = \begin{bmatrix} z_1 \\ z_2 \exp(s(z_1, y)) + t(z_1, y) \end{bmatrix}
\end{equation*}

The unified masked formulation becomes:
$$L(z; y, m) = m \odot z + (1 - m) \odot \left( z \odot \exp(s(m \odot z, y)) + t(m \odot z, y) \right)$$
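
As a quick check, substituting the mask $m = (1, 0)^\top$ into the masked formulation recovers the Conditional Type 1 layer. Here $s_2$ and $t_2$ denote the second components of the vector-valued outputs of $s$ and $t$ (an indexing assumption; the layer definition above writes them without an index):
\begin{align*}
m \odot z &= \begin{bmatrix} z_1 \\ 0 \end{bmatrix}, \qquad 1 - m = \begin{bmatrix} 0 \\ 1 \end{bmatrix} \\
L(z; y, m) &= \begin{bmatrix} z_1 \\ 0 \end{bmatrix} + \begin{bmatrix} 0 \\ z_2 \exp(s_2(m \odot z, y)) + t_2(m \odot z, y) \end{bmatrix} = \begin{bmatrix} z_1 \\ z_2 \exp(s_2(m \odot z, y)) + t_2(m \odot z, y) \end{bmatrix}
\end{align*}
Because $m \odot z$ carries only $z_1$, the networks depend on $(z_1, y)$ alone, which matches $L^{(1)}(z, y)$. The mask $m = (0, 1)^\top$ exchanges the roles of $z_1$ and $z_2$ in the same way.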

#### Contextual Embedding
In practice, the observation $y$ is often high-dimensional. In that case, $y$ is typically passed through an **embedding network** (or summary network) $h_\phi(y)$, such as a CNN or ResNet, which extracts relevant features before they enter the coupling layers:
$$s = s(m \odot z, h_\phi(y)), \quad t = t(m \odot z, h_\phi(y))$$
The parameters $\theta$ of the transport map then include the weights of both the coupling layers and the embedding network. For the BOD model, $y \in \mathbb{R}^5$ is low-dimensional, so the implementation below concatenates $y$ directly with the masked input and uses no embedding network.

#### Jacobian and Inverse
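The Jacobian with respect to $z$ remains block-triangular, because the condition $y$ is held fixed during both the push-forward and the pull-back. Its log-determinant is the sum of the scale outputs over the transformed components:
$$\log \left| \det \nabla_z L(z; y, m) \right| = \sum_{j=1}^{d} (1 - m_j) \, s_j(m \odot z, y)$$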
The inverse mapping also follows the same closed-form structure as in Section 4.2, conditioned on $y$:
$$(L^{(1)})^{-1}(z', y) = \begin{bmatrix} z'_1 \\ (z'_2 - t(z'_1, y)) \exp(-s(z'_1, y)) \end{bmatrix}$$
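
The forward map, its closed-form inverse, and the log-determinant can be checked numerically. The following is a minimal sketch in plain Python for the Conditional Type 1 layer. The functions `s_net` and `t_net` are hypothetical closed-form stand-ins for the neural networks, and the value of `y` is arbitrary.

```python
import math

def s_net(z1, y):
    # Hypothetical scale function standing in for the network s(z_1, y)
    return math.tanh(0.5 * z1 + 0.1 * sum(y))

def t_net(z1, y):
    # Hypothetical translation function standing in for the network t(z_1, y)
    return 0.3 * z1 - 0.2 * y[0]

def coupling_forward(z, y):
    # L^(1)(z, y): keep z_1, transform z_2; the log|det| equals s(z_1, y)
    z1, z2 = z
    s = s_net(z1, y)
    return (z1, z2 * math.exp(s) + t_net(z1, y)), s

def coupling_inverse(x, y):
    # Closed-form inverse: x_1 is unchanged, so s and t can be recomputed from it
    x1, x2 = x
    s = s_net(x1, y)
    return (x1, (x2 - t_net(x1, y)) * math.exp(-s))

y = [0.18, 0.32, 0.42, 0.49, 0.54]   # arbitrary observation vector in R^5
z = (0.7, -1.2)

x, log_det = coupling_forward(z, y)
z_back = coupling_inverse(x, y)

# The Jacobian is lower triangular with diagonal (1, exp(s)), so its determinant
# is d x_2 / d z_2. Compare against a finite-difference estimate.
eps = 1e-6
x_plus, _ = coupling_forward((z[0], z[1] + eps), y)
fd_log_det = math.log(abs((x_plus[1] - x[1]) / eps))

print("round trip:", z, "->", z_back)
print("log|det| analytic vs finite difference:", log_det, fd_log_det)
```

Because the untouched component $z_1$ passes through unchanged, the inverse can re-evaluate $s$ and $t$ on exactly the same inputs as the forward pass, so neither network ever needs to be inverted.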



### 5.3 Training and Inference
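The training loop below draws pairs $(u^{(i)}, y^{(i)})$ from the joint distribution by simulation, $u^{(i)} \sim \rho$ and $y^{(i)} = G(u^{(i)}) + \eta^{(i)}$, and minimizes the negative conditional log-likelihood $-\frac{1}{N}\sum_{i=1}^{N} \log q_\theta(u^{(i)} \mid y^{(i)})$. Up to a constant, this is the forward divergence $\mathbb{E}_{y \sim p(y)}\left[ D_{KL}(\pi^y \| q_\theta(\cdot \mid y)) \right]$ rather than the reverse KL of Section 5.1; it requires only simulator samples, with $\log q_\theta$ evaluated through the inverse map and its log-determinant. At test time, for a *new* observation $y^*$, we obtain approximate posterior samples $u \sim q_{\widehat{\theta}}(\cdot \mid y^*) \approx \pi^{y^*}$ simply by passing $z \sim \varrho$ through the pre-trained map $T(\cdot; y^*, \widehat{\theta})$, bypassing any further optimization or MCMC sampling.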
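
To make the train-once, sample-for-any-$y$ workflow concrete, here is a minimal sketch on a toy problem that is deliberately not the BOD model: a scalar linear-Gaussian model $u \sim \mathcal{N}(0, 1)$, $y = u + \eta$ with $\eta \sim \mathcal{N}(0, \sigma^2)$, where the exact posterior is $\mathcal{N}\left(y / (1 + \sigma^2), \sigma^2 / (1 + \sigma^2)\right)$. The conditional map is assumed to be affine, $T(z; y) = a y + b z$, and it is fitted by maximizing $\log q(u \mid y)$ over simulated joint samples, the same kind of objective as the training loop later in this section. The values of `sigma2`, `n_train`, and `y_star` are illustrative choices.

```python
import math
import random

random.seed(0)
sigma2 = 0.25      # noise variance of the toy model (illustrative)
n_train = 20000    # number of simulated joint samples (illustrative)

# 1. Simulate joint samples: u ~ N(0, 1), y = u + eta, eta ~ N(0, sigma2)
us = [random.gauss(0.0, 1.0) for _ in range(n_train)]
ys = [u + random.gauss(0.0, math.sqrt(sigma2)) for u in us]

# 2. Fit q(u | y) = N(a * y, b^2), i.e. the map T(z; y) = a * y + b * z,
#    by maximum likelihood. For this Gaussian family the maximizer is closed-form:
#    a is the least-squares slope and b^2 is the mean squared residual.
a = sum(u * y for u, y in zip(us, ys)) / sum(y * y for y in ys)
b = math.sqrt(sum((u - a * y) ** 2 for u, y in zip(us, ys)) / n_train)

# 3. Amortized inference: for any new y*, push reference samples through T(.; y*)
def sample_posterior(y_star, n):
    return [a * y_star + b * random.gauss(0.0, 1.0) for _ in range(n)]

y_star = 1.5
samples = sample_posterior(y_star, 5)

# Exact posterior parameters, for comparison
exact_a = 1.0 / (1.0 + sigma2)
exact_b = math.sqrt(sigma2 / (1.0 + sigma2))
print("fitted a, b:", a, b)
print("exact  a, b:", exact_a, exact_b)
print("posterior samples for y* =", y_star, ":", samples)
```

In this toy setting the affine family contains the exact posterior, so the fitted map should approach it as the number of simulated pairs grows. For the nonlinear BOD model no such guarantee holds, and the quality of the approximation depends on the expressiveness of the flow.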

In [None]:
import torch
import torch.nn as nn
import torch.distributions as dist
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from tqdm.auto import tqdm

# --- 1. Helper: Joint Sampling Logic ---

def generate_joint_samples(posterior, batch_size):
    """
    Functionality: Generates joint samples (u, y) for training.
    Input: posterior (BODPosterior instance), batch_size.
    Output: u (parameters), y (simulated data with noise).
    """
    # 1. Sample u from prior N(0, I)
    u = posterior.prior.sample((batch_size,))

    # 2. Forward model (get mean y)
    y_mean = posterior.forward_model(u)

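    # 3. Add observation noise: y = G(u) + eta, eta ~ N(0, gamma_var * I)
    # Noise std = sqrt(gamma_var); as_tensor also accepts a plain Python float
    noise_std = torch.sqrt(torch.as_tensor(posterior.gamma_var, dtype=y_mean.dtype))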
    y = y_mean + torch.randn_like(y_mean) * noise_std

    return u, y

# --- 2. Amortized Flow Model ---

class AmortizedFlow(nn.Module):
    """Conditional RealNVP: every coupling layer also receives the observation y."""
    def __init__(self, nets, nett, masks, ref_dist):
        super().__init__()
        self.ref = ref_dist
        self.mask = nn.Parameter(masks, requires_grad=False)
        self.s = nn.ModuleList([nets() for _ in range(len(masks))])
        self.t = nn.ModuleList([nett() for _ in range(len(masks))])

    def T(self, z, y):
        # Push-forward x = T(z; y): apply the coupling layers in order.
        x = z
        for i in range(len(self.s)):
            x_ = x * self.mask[i]                # components passed through unchanged
            xy = torch.cat([x_, y], dim=1)       # networks see the masked input and y
            s = self.s[i](xy) * (1 - self.mask[i])
            t = self.t[i](xy) * (1 - self.mask[i])
            x = x_ + (1 - self.mask[i]) * (x * torch.exp(s) + t)
        return x

    def Tinv(self, x, y):
        # Pull-back z = T^{-1}(x; y): undo the layers in reverse order and
        # accumulate log|det dz/dx|, which is minus the sum of the scales s.
        log_det_J, z = x.new_zeros(x.shape[0]), x
        for i in reversed(range(len(self.s))):
            z_ = z * self.mask[i]
            zy = torch.cat([z_, y], dim=1)
            s = self.s[i](zy) * (1 - self.mask[i])
            t = self.t[i](zy) * (1 - self.mask[i])
            z = (1 - self.mask[i]) * (z - t) * torch.exp(-s) + z_
            log_det_J -= s.sum(dim=1)
        return z, log_det_J

    def log_prob(self, x, y):
        z, logdet = self.Tinv(x, y)
        return self.ref.log_prob(z) + logdet

    def sample(self, batch_size, y):
        if y.dim() == 1:
            y = y.unsqueeze(0).repeat(batch_size, 1)
        elif y.shape[0] == 1:
            y = y.repeat(batch_size, 1)
        z = self.ref.sample((batch_size,))
        x = self.T(z, y)
        return x

# --- 3. Training Function ---

def run_amortized_vi(lr=1e-3, num_epochs=500, batch_size=128,
                     num_layers=6, hidden_dim=256):
    """
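    Functionality: Trains the conditional flow by maximizing log q(u | y) on
    freshly simulated joint samples (each "epoch" is one gradient step on one batch).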
    Visualization:
      - Samples a random y.
      - Creates a temporary BODPosterior(y=random_y) to plot ground truth.
      - Overlays model predictions.
      - Uses adaptive axes limits based on model samples.
    """
    ref = dist.MultivariateNormal(torch.zeros(2), torch.eye(2))
    in_dim = 2 + 5

    nets = lambda: nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, 2), nn.Tanh())
    nett = lambda: nn.Sequential(nn.Linear(in_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, hidden_dim), nn.LeakyReLU(),
                                 nn.Linear(hidden_dim, 2))

    masks = torch.tensor([[0, 1], [1, 0]] * (num_layers // 2), dtype=torch.float32)
    model = AmortizedFlow(nets, nett, masks, ref)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Setup visualization
    fig, axes = plt.subplots(1, 4, figsize=(24, 6))
    milestones = [0, int(num_epochs * 0.25), int(num_epochs * 0.5), num_epochs - 1]
    plot_idx = 0

    pbar = tqdm(range(num_epochs), desc="Training Amortized VI")

    for epoch in pbar:
        model.train()
        # 1. Train Step
        # Assumes 'pi' (global BODPosterior instance) is available for accessing prior/simulator
        u_batch, y_batch = generate_joint_samples(pi, batch_size)
        # Negative conditional log-likelihood of u given y (forward-KL objective)
        loss = -model.log_prob(u_batch, y_batch).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 2. Adaptive Visualization Step
        if epoch in milestones:
            model.eval()
            with torch.no_grad():
                # A. Select Random Y for Validation (simulate from prior)
                _, y_vis = generate_joint_samples(pi, 1) # Shape [1, 5]

                # B. Generate Samples first to determine Adaptive Domain
                u_samples = model.sample(200, y_vis).cpu().numpy()
                u_min, u_max = u_samples.min(axis=0), u_samples.max(axis=0)
                margin = 0.5
                d_x1 = [float(u_min[0] - margin), float(u_max[0] + margin)]
                d_x2 = [float(u_min[1] - margin), float(u_max[1] + margin)]

                # C. Re-generate Grid based on Adaptive Domain
                x1_line = torch.linspace(d_x1[0], d_x1[1], 100)
                x2_line = torch.linspace(d_x2[0], d_x2[1], 100)
                xg = torch.meshgrid(x1_line, x2_line, indexing='ij')
                xx = torch.stack([xg[0].flatten(), xg[1].flatten()], dim=1)

                # D. Compute Model Density on New Grid
                y_grid = y_vis.repeat(xx.shape[0], 1)
                log_qu = model.log_prob(xx, y_grid)
                qu_model = torch.exp(log_qu).reshape(100, 100).numpy()

                # E. Plotting
                ax = axes[plot_idx]

                # IMPORTANT: Create a temporary posterior instance for this SPECIFIC random y
                # This allows .plot_posterior to use the correct likelihood for visualization
                pi_vis = BODPosterior(y=y_vis)

                # Call plot_posterior on the TEMP instance with adaptive domains
                pi_vis.plot_posterior(ax=ax, domain_x1=d_x1, domain_x2=d_x2)

                # Overlay Model (Black Contours)
                ax.contour(xg[0].numpy(), xg[1].numpy(), qu_model, colors='black', levels=7)

                # Overlay Samples (Blue Dots)
                ax.plot(u_samples[:, 0], u_samples[:, 1], '.b', markersize=4, alpha=0.4)

                ax.set_title(f"Epoch {epoch}\nAdaptive View")
                plot_idx += 1

        pbar.set_postfix({'loss': f"{loss.item():.4f}"})

    plt.tight_layout()
    plt.show()
    return model

# --- 4. UI Setup ---

lr_in = widgets.FloatLogSlider(value=1e-3, base=10, min=-4, max=-1, step=0.1, description='LR:')
epochs_in = widgets.IntSlider(value=500, min=100, max=2000, step=100, description='Epochs:')
layers_in = widgets.IntSlider(value=6, min=2, max=12, step=2, description='Layers:')
dim_in = widgets.Dropdown(options=[64, 128, 256, 512], value=256, description='Hidden Dim:')
bs_in = widgets.IntSlider(value=128, min=32, max=512, step=32, description='Batch:')

ui_box = widgets.VBox([
    widgets.HBox([lr_in, epochs_in, bs_in]),
    widgets.HBox([layers_in, dim_in])
])
run_btn = widgets.Button(description="Run Training", button_style='success')
out = widgets.Output()

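def on_click(b):
    # Keep the trained model in the global 'flow', which the test cell below reads.
    # Note: this rebinds any earlier model stored under the same name.
    global flow
    with out:
        out.clear_output(wait=True)
        # Assumes 'pi' is initialized in the global scope (e.g., pi = BODPosterior(y_obs))
        flow = run_amortized_vi(lr_in.value, epochs_in.value, bs_in.value,
                                layers_in.value, dim_in.value)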

run_btn.on_click(on_click)
display(ui_box, run_btn, out)

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# --- Amortized Posterior Sampling & Latent Mapping Test Block ---

def run_amortized_test(model, test_y=None, n_samples=400, title_suffix=""):
    """
    Test the amortized flow model on a specific y.
    1. Visualizes how samples from q(u|y) match the true posterior p(u|y).
    2. Visualizes how the model pulls samples back to the latent N(0, I).
    """
    if test_y is None:
        test_y = pi.y

    # Ensure test_y is [1, 5]
    if test_y.dim() == 1:
        test_y = test_y.unsqueeze(0)

    # Define a local posterior object for this specific y to compute Ground Truth
    # (We detach to ensure no gradients are tracked during plotting)
    pi_test = BODPosterior(y=test_y.detach())

    # --- 1. Model Inference & Auto-Scaling ---
    model.eval()
    with torch.no_grad():
        # A. Sample u ~ q(u|y)
        us = model.sample(n_samples, test_y)

        # B. Inverse transform z = T^-1(u, y)
        y_context_samples = test_y.repeat(n_samples, 1)
        zs, _ = model.Tinv(us, y_context_samples)

        # C. Auto-scale Plotting Range based on samples
        u_min, u_max = us.min(dim=0)[0].cpu(), us.max(dim=0)[0].cpu()
        margin = 0.6
        u_range = [
            [float(u_min[0] - margin), float(u_max[0] + margin)],
            [float(u_min[1] - margin), float(u_max[1] + margin)]
        ]

    # --- 2. Grid Setup ---
    res = 150
    z_range = [-4, 4]

    u1 = torch.linspace(u_range[0][0], u_range[0][1], res)
    u2 = torch.linspace(u_range[1][0], u_range[1][1], res)
    ug1, ug2 = torch.meshgrid(u1, u2, indexing='ij')
    uu = torch.stack([ug1.flatten(), ug2.flatten()], dim=1)

    z1 = torch.linspace(z_range[0], z_range[1], res)
    z2 = torch.linspace(z_range[0], z_range[1], res)
    zg1, zg2 = torch.meshgrid(z1, z2, indexing='ij')
    zz = torch.stack([zg1.flatten(), zg2.flatten()], dim=1)

    # --- 3. Compute Densities ---
    with torch.no_grad():
        y_context_grid = test_y.repeat(uu.shape[0], 1)

        # Flow Density q(u|y)
        log_qu = model.log_prob(uu, y_context_grid)
        qu_density = torch.exp(log_qu).reshape(res, res).numpy()

        # True Posterior Density p(u|y)
        log_pi = pi_test.log_prob(uu)
        pi_density = torch.exp(log_pi).reshape(res, res).numpy()

        # Latent Density Reference
        log_pz = model.ref.log_prob(zz)
        pz_density = torch.exp(log_pz).reshape(res, res).numpy()

    # --- 4. Plotting ---
    plt.figure(figsize=(14, 5))

    # Left Plot: Target Space
    plt.subplot(1, 2, 1)
    plt.contourf(ug1.numpy(), ug2.numpy(), pi_density, levels=50, cmap='Oranges', alpha=0.4)
    plt.contour(ug1.numpy(), ug2.numpy(), qu_density, levels=8, colors='black', linewidths=1.0)
    plt.plot(us[:, 0].cpu(), us[:, 1].cpu(), '.b', markersize=4, alpha=0.5, label='Flow Samples')

    # Format y for title
    y_vals = test_y.flatten().cpu().numpy()
    y_str = "[" + ", ".join([f"{v:.3f}" for v in y_vals]) + "]"

    plt.title(f"Target Space (u) | {title_suffix}\ny = {y_str}", fontsize=10)
    plt.xlabel("$u_1$"); plt.ylabel("$u_2$"); plt.legend()

    # Right Plot: Latent Space
    plt.subplot(1, 2, 2)
    levs_smooth = np.linspace(0, 1, 200)**1.5 * pz_density.max()
    plt.contourf(zg1.numpy(), zg2.numpy(), pz_density, levels=levs_smooth, cmap='Greens', alpha=0.85)
    plt.plot(zs[:, 0].cpu(), zs[:, 1].cpu(), '.b', markersize=4, alpha=0.5, label='Latent z')

    plt.title("Latent Space (z)\nShould be Standard Normal")
    plt.xlabel("$z_1$"); plt.ylabel("$z_2$"); plt.xlim(z_range); plt.ylim(z_range); plt.legend()

    plt.tight_layout()
    plt.show()

# --- Execution with 3 Specific Test Cases ---

if 'flow' in globals() and flow is not None:
    # 1. User specified Y
    y1 = torch.tensor([-0.0072, 0.0427, 0.0505, 0.0638, 0.0729], dtype=torch.float32)

    # 2. Standard Y (Common BOD observation)
    y2 = torch.tensor([0.1615, 0.1868, 0.3949, 0.3728, 0.4177], dtype=torch.float32)

    # 3. High Value Y (To test generalization to stronger signals)
    y3 = torch.tensor([0.2, 0.4, 0.6, 0.75, 0.85], dtype=torch.float32)

    test_cases = [
        (y1, "User Case (Low values)"),
        (y2, "Standard Case (Medium values)"),
        (y3, "Synthetic Case (High values)")
    ]

    print(f"{'='*20} Amortized Inference Test {'='*20}")

    for i, (y_val, desc) in enumerate(test_cases):
        print(f"\nTest Case {i+1}: {desc}")
        run_amortized_test(flow, test_y=y_val, title_suffix=desc)